visualkeras: custom vgg16 show dimension example

An example showing the visualkeras layered_view function applied to a tf.keras model.

 9 # Authors: The scikit-plots developers
10 # SPDX-License-Identifier: BSD-3-Clause

pip install protobuf==5.29.4

14 import tensorflow as tf
15
16 # Clear any session to reset the state of TensorFlow/Keras
17 tf.keras.backend.clear_session()
18
19 from scikitplot import visualkeras

Create the custom VGG16 model

23 image_size = 224  # used as both the height and the width of the input shape
24 model = tf.keras.models.Sequential()
25 model.add(tf.keras.layers.InputLayer(shape=(image_size, image_size, 3)))  # 3 channels; presumably RGB
26 # Block 1: two zero-padded 3x3 convolutions with 64 filters each.
27 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
28 model.add(tf.keras.layers.Conv2D(64, activation="relu", kernel_size=(3, 3)))
29 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
30 model.add(tf.keras.layers.Conv2D(64, activation="relu", kernel_size=(3, 3)))
31 model.add(visualkeras.SpacingDummyLayer())  # NOTE(review): presumably a visual-only spacer between blocks; confirm in the visualkeras docs
32 # Block 2: 2x2 max pooling, then two zero-padded 3x3 convolutions with 128 filters each.
33 model.add(tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2)))
34 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
35 model.add(tf.keras.layers.Conv2D(128, activation="relu", kernel_size=(3, 3)))
36 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
37 model.add(tf.keras.layers.Conv2D(128, activation="relu", kernel_size=(3, 3)))
38 model.add(visualkeras.SpacingDummyLayer())
39 # Block 3: 2x2 max pooling, then three zero-padded 3x3 convolutions with 256 filters each.
40 model.add(tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2)))
41 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
42 model.add(tf.keras.layers.Conv2D(256, activation="relu", kernel_size=(3, 3)))
43 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
44 model.add(tf.keras.layers.Conv2D(256, activation="relu", kernel_size=(3, 3)))
45 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
46 model.add(tf.keras.layers.Conv2D(256, activation="relu", kernel_size=(3, 3)))
47 model.add(visualkeras.SpacingDummyLayer())
48 # Block 4: 2x2 max pooling, then three zero-padded 3x3 convolutions with 512 filters each.
49 model.add(tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2)))
50 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
51 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
52 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
53 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
54 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
55 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
56 model.add(visualkeras.SpacingDummyLayer())
57 # Block 5: 2x2 max pooling, three zero-padded 3x3 convolutions with 512 filters each, then a final max pooling.
58 model.add(tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2)))
59 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
60 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
61 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
62 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
63 model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
64 model.add(tf.keras.layers.Conv2D(512, activation="relu", kernel_size=(3, 3)))
65 model.add(tf.keras.layers.MaxPooling2D())  # no explicit pool size/strides here, unlike the (2, 2) pooling calls above
66 model.add(visualkeras.SpacingDummyLayer())
67 # Flatten the convolutional output before the dense layers.
68 model.add(tf.keras.layers.Flatten())
69 # Classifier head: two 4096-unit ReLU layers, each followed by Dropout(0.5), then a 1000-unit softmax.
70 model.add(tf.keras.layers.Dense(4096, activation="relu"))
71 model.add(tf.keras.layers.Dropout(0.5))
72 model.add(tf.keras.layers.Dense(4096, activation="relu"))
73 model.add(tf.keras.layers.Dropout(0.5))
74 model.add(tf.keras.layers.Dense(1000, activation="softmax"))
75 # model.summary()

Now visualize the model!

80 from collections import defaultdict
81 # Per-layer-type style overrides keyed by layer class; only the "fill" entry is set for each.
82 color_map = defaultdict(dict)  # a layer class that is looked up before being set gets an empty dict
83 color_map[tf.keras.layers.Conv2D]["fill"] = "orange"
84 color_map[tf.keras.layers.ZeroPadding2D]["fill"] = "gray"
85 color_map[tf.keras.layers.Dropout]["fill"] = "pink"
86 color_map[tf.keras.layers.MaxPooling2D]["fill"] = "red"
87 color_map[tf.keras.layers.Dense]["fill"] = "green"
88 color_map[tf.keras.layers.Flatten]["fill"] = "teal"
91 from PIL import ImageFont
92 # The returned font object is not assigned; the layered_view calls below pass only a "font_size" entry.
93 ImageFont.load_default()
<PIL.ImageFont.FreeTypeFont object at 0x7f9a3806ccd0>
 97 img_vgg16_show_dimension = visualkeras.layered_view(
 98     model,
 99     legend=True,
100     show_dimension=True,  # NOTE(review): per the visualkeras docs this shows layer dimensions in the legend; confirm
101     type_ignore=[visualkeras.SpacingDummyLayer],  # NOTE(review): presumably leaves the spacer layers out of the drawing; confirm
102     font={
103         "font_size": 61,  # only the size is set; the font options below stay commented out
104         # 'use_default_font': False,
105         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
106     },
107     # to_file="result_images/vgg16_show_dimension.png",
108     save_fig=True,  # NOTE(review): presumably writes the figure to save_fig_filename; confirm against the scikitplot docs
109     save_fig_filename="vgg16_show_dimension.png",
110 )
111 img_vgg16_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99e87bab90>
114 img_vgg16_legend_show_dimension = visualkeras.layered_view(  # NOTE(review): same arguments as the img_vgg16_show_dimension call; only the output filename differs
115     model,
116     legend=True,
117     show_dimension=True,
118     type_ignore=[visualkeras.SpacingDummyLayer],
119     font={
120         "font_size": 61,  # only the size is set; the font options below stay commented out
121         # 'use_default_font': False,
122         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
123     },
124     # to_file="result_images/vgg16_legend_show_dimension.png",
125     save_fig=True,
126     save_fig_filename="vgg16_legend_show_dimension.png",
127 )
128 img_vgg16_legend_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99e814d710>
131 img_vgg16_spacing_layers_show_dimension = visualkeras.layered_view(
132     model,
133     legend=True,
134     show_dimension=True,
135     type_ignore=[],  # empty list: SpacingDummyLayer is not ignored in this variant
136     spacing=0,  # NOTE(review): presumably the gap between consecutive layers; confirm units in the visualkeras docs
137     font={
138         "font_size": 61,  # only the size is set; the font options below stay commented out
139         # 'use_default_font': False,
140         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
141     },
142     # to_file="result_images/vgg16_spacing_layers_show_dimension.png",
143     save_fig=True,
144     save_fig_filename="vgg16_spacing_layers_show_dimension.png",
145 )
146 img_vgg16_spacing_layers_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99e819d710>
149 img_vgg16_type_ignore_show_dimension = visualkeras.layered_view(
150     model,
151     legend=True,
152     show_dimension=True,
153     type_ignore=[  # this variant ignores padding, dropout and flatten layers in addition to the spacer layers
154         tf.keras.layers.ZeroPadding2D,
155         tf.keras.layers.Dropout,
156         tf.keras.layers.Flatten,
157         visualkeras.SpacingDummyLayer,
158     ],
159     font={
160         "font_size": 61,  # only the size is set; the font options below stay commented out
161         # 'use_default_font': False,
162         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
163     },
164     # to_file="result_images/vgg16_type_ignore_show_dimension.png",
165     save_fig=True,
166     save_fig_filename="vgg16_type_ignore_show_dimension.png",
167 )
168 img_vgg16_type_ignore_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99e800d710>
171 img_vgg16_color_map_show_dimension = visualkeras.layered_view(
172     model,
173     legend=True,
174     show_dimension=True,
175     type_ignore=[visualkeras.SpacingDummyLayer],
176     color_map=color_map,  # the per-layer-type "fill" colours built earlier in this script
177     font={
178         "font_size": 61,  # only the size is set; the font options below stay commented out
179         # 'use_default_font': False,
180         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
181     },
182     # to_file="result_images/vgg16_color_map_show_dimension.png",
183     save_fig=True,
184     save_fig_filename="vgg16_color_map_show_dimension.png",
185 )
186 img_vgg16_color_map_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99e8083a90>
189 img_vgg16_flat_show_dimension = visualkeras.layered_view(
190     model,
191     legend=True,
192     show_dimension=True,
193     type_ignore=[visualkeras.SpacingDummyLayer],
194     draw_volume=False,  # NOTE(review): per the visualkeras docs this switches from the 3D volume view to flat 2D boxes; confirm
195     font={
196         "font_size": 61,  # only the size is set; the font options below stay commented out
197         # 'use_default_font': False,
198         # 'font_path': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
199     },
200     # to_file="result_images/vgg16_flat_show_dimension.png",
201     save_fig=True,
202     save_fig_filename="vgg16_flat_show_dimension.png",
203 )
204 img_vgg16_flat_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99c4716690>
207 img_vgg16_scaling_show_dimension = visualkeras.layered_view(
208     model,
209     legend=True,
210     show_dimension=True,
211     type_ignore=[visualkeras.SpacingDummyLayer],
212     # min_z = 1,  # NOTE(review): every scaling argument here is commented out, so none is passed to this call
213     # min_xy = 1,
214     # max_z = 4096,
215     # max_xy = 4096,
216     # scale_z = 0.25,
217     # scale_xy = 5,
218     font={"font_size": 61},
219     # to_file="result_images/vgg16_scaling_show_dimension.png",
220     save_fig=True,
221     save_fig_filename="vgg16_scaling_show_dimension.png",
222 )
223 img_vgg16_scaling_show_dimension  # bare expression so the gallery displays the result
plot custom vgg16 show dimension
<matplotlib.image.AxesImage object at 0x7f99c4789690>

Tags: model-type: classification | model-workflow: model building | plot-type: visualkeras | domain: neural network | level: intermediate | purpose: showcase

Total running time of the script: (0 minutes 15.663 seconds)

Related examples

visualkeras: custom vgg16 example

visualkeras: custom vgg16 example

visualkeras: transformers example

visualkeras: transformers example

visualkeras: custom VGG example

visualkeras: custom VGG example

Visualkeras: Spam Classification Conv1D Dense Example

Visualkeras: Spam Classification Conv1D Dense Example

Gallery generated by Sphinx-Gallery