visualkeras custom VGG16 show dimension example

An example showing the visualkeras ``layered_view`` function applied to a
``tf.keras.Model``, with layer dimensions shown.
# Authors: The scikit-plots developers
# SPDX-License-Identifier: BSD-3-Clause
# Collect garbage left over from any previous run before building a new model.
import gc

gc.collect()

import tensorflow as tf

# Reset TensorFlow/Keras session state so this example starts from a clean slate.
tf.keras.backend.clear_session()

from scikitplot import visualkeras
# create VGG16
image_size = 224
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(shape=(image_size, image_size, 3)))


def _add_padded_convs(filters, repeats):
    """Append ``repeats`` pairs of ZeroPadding2D((1, 1)) + 3x3 ReLU Conv2D to ``model``."""
    for _ in range(repeats):
        model.add(tf.keras.layers.ZeroPadding2D((1, 1)))
        model.add(tf.keras.layers.Conv2D(filters, activation='relu', kernel_size=(3, 3)))


# Stages 1-4: padded convolutions, a visualkeras spacing gap, then
# 2x2 max pooling with stride 2.
for stage_filters, stage_repeats in ((64, 2), (128, 2), (256, 3), (512, 3)):
    _add_padded_convs(stage_filters, stage_repeats)
    model.add(visualkeras.SpacingDummyLayer())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2)))

# Stage 5: unlike the earlier stages, pooling comes first and the spacing
# gap is placed after it, just before Flatten.
_add_padded_convs(512, 3)
model.add(tf.keras.layers.MaxPooling2D())
model.add(visualkeras.SpacingDummyLayer())

# Classifier head: two 4096-unit ReLU layers each followed by 50% dropout,
# then a 1000-way softmax.
model.add(tf.keras.layers.Flatten())
for _ in range(2):
    model.add(tf.keras.layers.Dense(4096, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1000, activation='softmax'))
# Now visualize the model!
from collections import defaultdict

# Fill colour per layer type for the colour-mapped rendering below. Because
# this is a defaultdict(dict), looking up an unlisted layer type yields an
# empty style dict rather than raising KeyError.
color_map = defaultdict(dict, {
    tf.keras.layers.Conv2D: {'fill': 'orange'},
    tf.keras.layers.ZeroPadding2D: {'fill': 'gray'},
    tf.keras.layers.Dropout: {'fill': 'pink'},
    tf.keras.layers.MaxPooling2D: {'fill': 'red'},
    tf.keras.layers.Dense: {'fill': 'green'},
    tf.keras.layers.Flatten: {'fill': 'teal'},
})
from PIL import ImageFont


def get_font(size=32):
    """Return a TrueType font for the diagram labels, or Pillow's default font.

    Candidate font files for the current platform (as reported by
    ``platform.system()``) are tried in order. If none of them loads, a message
    is printed and ``ImageFont.load_default()`` is returned. This also happens
    when the platform is not Windows, macOS or Linux; previously that case
    raised an uncaught ``ValueError``.

    Parameters
    ----------
    size : int, default 32
        Point size passed to ``ImageFont.truetype``.

    Returns
    -------
    PIL.ImageFont.FreeTypeFont or PIL.ImageFont.ImageFont
        The first candidate font that loads, otherwise Pillow's default font.
    """
    import platform

    # The original hard-coded locations come first. Further absolute paths
    # and bare file names follow; Pillow may resolve bare names against the
    # platform's standard font directories.
    candidates_by_platform = {
        'windows': ("arial.ttf",),
        'darwin': (  # macOS
            "/Library/Fonts/Arial.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/System/Library/Fonts/Helvetica.ttc",
        ),
        'linux': (
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            "DejaVuSans-Bold.ttf",
        ),
    }
    for font_path in candidates_by_platform.get(platform.system().lower(), ()):
        try:
            return ImageFont.truetype(font_path, size)
        except OSError:
            # Font file missing or unreadable; try the next candidate.
            continue

    # Fallback font if no candidate could be loaded (or the platform is unknown).
    print("Font not found, using default font.")
    return ImageFont.load_default()
# Example usage
font = get_font()

# Every call below passes a ``to_file`` path under ../result_images. Create
# that directory up front so saving does not fail when it is missing, for
# example when the script is run outside the docs build tree.
import os

os.makedirs('../result_images', exist_ok=True)

img_vgg16_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_show_dimension.png',
    type_ignore=[visualkeras.SpacingDummyLayer]
)
img_vgg16_legend_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_legend_show_dimension.png',
    type_ignore=[visualkeras.SpacingDummyLayer], font=font
)
# Keep the spacing layers (nothing ignored) and remove the default spacing.
img_vgg16_spacing_layers_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_spacing_layers_show_dimension.png',
    type_ignore=[], spacing=0
)
img_vgg16_type_ignore_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_type_ignore_show_dimension.png',
    type_ignore=[tf.keras.layers.ZeroPadding2D, tf.keras.layers.Dropout, tf.keras.layers.Flatten, visualkeras.SpacingDummyLayer]
)
img_vgg16_color_map_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_color_map_show_dimension.png',
    type_ignore=[visualkeras.SpacingDummyLayer], color_map=color_map
)
img_vgg16_flat_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_flat_show_dimension.png',
    type_ignore=[visualkeras.SpacingDummyLayer], draw_volume=False
)
img_vgg16_scaling_show_dimension = visualkeras.layered_view(
    model, legend=True, show_dimension=True,
    to_file='../result_images/vgg16_scaling_show_dimension.png',
    type_ignore=[visualkeras.SpacingDummyLayer], scale_xy=1, scale_z=1, max_z=1000
)
# Display the rendered images inline. matplotlib is optional here: each
# image was also given a ``to_file`` path above. Only a missing matplotlib
# is tolerated; any other error is no longer silently swallowed.
try:
    import matplotlib.pyplot as plt
except ImportError:
    print("matplotlib is not installed; skipping inline display.")
else:
    for rendered in (
        img_vgg16_show_dimension,
        img_vgg16_legend_show_dimension,
        img_vgg16_spacing_layers_show_dimension,
        img_vgg16_type_ignore_show_dimension,
        img_vgg16_color_map_show_dimension,
        img_vgg16_flat_show_dimension,
        img_vgg16_scaling_show_dimension,
    ):
        plt.imshow(rendered)
        plt.axis('off')
        plt.show()

Font not found, using default font.
Total running time of the script: (0 minutes 8.816 seconds)
Related examples