visualkeras: autoencoder example

An example showing how to use the visualkeras layered_view function to visualize a tf.keras.Model autoencoder.

# Authors: The scikit-plots developers
# SPDX-License-Identifier: BSD-3-Clause

pip install protobuf==5.29.4

import tensorflow as tf

# Clear any session to reset the state of TensorFlow/Keras
tf.keras.backend.clear_session()

Encoder model

# Encoder: maps a 28x28x1 image to a 16-dimensional feature vector
encoder_input = tf.keras.Input(shape=(28, 28, 1), name="img")
x = tf.keras.layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
x = tf.keras.layers.MaxPooling2D(3)(x)
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
x = tf.keras.layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = tf.keras.layers.GlobalMaxPooling2D()(x)
encoder = tf.keras.Model(encoder_input, encoder_output, name="encoder")

# Autoencoder: decoder layers that map the 16-dim vector back to a 28x28x1 image
x = tf.keras.layers.Reshape((4, 4, 1))(encoder_output)
x = tf.keras.layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = tf.keras.layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = tf.keras.layers.UpSampling2D(3)(x)
x = tf.keras.layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = tf.keras.layers.Conv2DTranspose(1, 3, activation="relu")(x)
autoencoder = tf.keras.Model(encoder_input, decoder_output, name="autoencoder")
autoencoder.summary()
Model: "autoencoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ img (InputLayer)                │ (None, 28, 28, 1)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 26, 26, 16)     │           160 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 24, 24, 32)     │         4,640 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 8, 8, 32)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 6, 6, 32)       │         9,248 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 4, 4, 16)       │         4,624 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_max_pooling2d            │ (None, 16)             │             0 │
│ (GlobalMaxPooling2D)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ reshape (Reshape)               │ (None, 4, 4, 1)        │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose                │ (None, 6, 6, 16)       │           160 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_1              │ (None, 8, 8, 32)       │         4,640 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ up_sampling2d (UpSampling2D)    │ (None, 24, 24, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_2              │ (None, 26, 26, 16)     │         4,624 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_3              │ (None, 28, 28, 1)      │           145 │
│ (Conv2DTranspose)               │                        │               │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 28,241 (110.32 KB)
 Trainable params: 28,241 (110.32 KB)
 Non-trainable params: 0 (0.00 B)
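The shapes and parameter counts in the summary follow from a few rules: a 3x3 Conv2D with the default "valid" padding and stride 1 trims 2 pixels from each spatial dimension, a Conv2DTranspose with the same settings adds 2, MaxPooling2D(3) uses strides equal to its pool size, and each convolution has k*k*c_in*c_out weights plus c_out biases. The pure-Python sketch below replays those rules layer by layer; trace is a hypothetical helper (not part of Keras), it assumes square feature maps, and the KB figure assumes 4-byte float32 weights.

```python
def trace(input_hw, input_c, ops, k=3):
    """Hypothetical helper: return [(layer, output_shape, params)] for a layer chain."""
    hw, c = input_hw, input_c
    rows = []
    for name, op, arg in ops:
        params = 0
        if op == "conv":          # Conv2D(arg, k): valid padding, stride 1
            params = k * k * c * arg + arg
            hw, c = hw - k + 1, arg
        elif op == "deconv":      # Conv2DTranspose(arg, k): valid padding, stride 1
            params = k * k * c * arg + arg
            hw, c = hw + k - 1, arg
        elif op == "maxpool":     # MaxPooling2D(arg): strides default to the pool size
            hw = (hw - arg) // arg + 1
        elif op == "upsample":    # UpSampling2D(arg): repeat rows and columns arg times
            hw = hw * arg
        elif op == "gmp":         # GlobalMaxPooling2D: drop both spatial dimensions
            hw = None
        elif op == "reshape":     # Reshape((arg, arg, 1)) from a flat vector
            assert c == arg * arg, "Reshape must preserve the element count"
            hw, c = arg, 1
        shape = (None, c) if hw is None else (None, hw, hw, c)
        rows.append((name, shape, params))
    return rows


ops = [
    ("conv2d", "conv", 16),
    ("conv2d_1", "conv", 32),
    ("max_pooling2d", "maxpool", 3),
    ("conv2d_2", "conv", 32),
    ("conv2d_3", "conv", 16),
    ("global_max_pooling2d", "gmp", None),
    ("reshape", "reshape", 4),
    ("conv2d_transpose", "deconv", 16),
    ("conv2d_transpose_1", "deconv", 32),
    ("up_sampling2d", "upsample", 3),
    ("conv2d_transpose_2", "deconv", 16),
    ("conv2d_transpose_3", "deconv", 1),
]

rows = trace(28, 1, ops)
for name, shape, params in rows:
    print(f"{name:22s} {str(shape):20s} {params:>7,}")

total = sum(p for _, _, p in rows)
print(f"Total params: {total:,} ({total * 4 / 1024:.2f} KB as float32)")
# → Total params: 28,241 (110.32 KB as float32)
```

Every printed shape and count should match the Keras summary above, which is a quick way to check a hand-designed encoder/decoder pair before building it.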

Build the model with an explicit input shape

# A functional model is already built when it is created; calling build()
# again with the same input shape is harmless and kept here for explicitness.
autoencoder.build(input_shape=(None, 28, 28, 1))  # batch size None, image shape (28, 28, 1)

# Create a dummy input tensor with a batch size of 1
dummy_input = tf.random.normal([1, 28, 28, 1])
# Run the dummy input through the model to trigger shape calculation
autoencoder_output = autoencoder(dummy_input)
# Check the output shape of the full autoencoder (not just the encoder)
print("Output shape after running model with dummy input:", autoencoder_output.shape)

# Check the output shape of each encoder layer
for layer in encoder.layers:
    # Keras 2 layers expose output_shape; Keras 3 layers do not
    if hasattr(layer, "output_shape"):
        print(f"{layer.name} output shape: {layer.output_shape}")
    if hasattr(layer, "output"):
        print(f"{layer.name} shape: {layer.output.shape}")
Output shape after running model with dummy input: (1, 28, 28, 1)
img shape: (None, 28, 28, 1)
conv2d shape: (None, 26, 26, 16)
conv2d_1 shape: (None, 24, 24, 32)
max_pooling2d shape: (None, 8, 8, 32)
conv2d_2 shape: (None, 6, 6, 32)
conv2d_3 shape: (None, 4, 4, 16)
global_max_pooling2d shape: (None, 16)
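The printed shapes change rank at the bottleneck: GlobalMaxPooling2D turns the (None, 4, 4, 16) output of conv2d_3 into a 16-value vector, and the decoder's Reshape((4, 4, 1)) folds that vector back into a 4x4 single-channel map. A NumPy sketch of those two steps on random data (not the model's weights), assuming the default channels_last layout:

```python
import numpy as np

# Stand-in for the encoder's last feature map: (batch, height, width, channels)
rng = np.random.default_rng(0)
features = rng.normal(size=(1, 4, 4, 16)).astype("float32")

# GlobalMaxPooling2D (channels_last): maximum over height and width
code = features.max(axis=(1, 2))            # shape (1, 16): one value per channel

# Reshape((4, 4, 1)): keep the batch axis, fold 16 values into 4x4x1 in row-major order
decoder_input = code.reshape(-1, 4, 4, 1)   # shape (1, 4, 4, 1)

print(features.shape, code.shape, decoder_input.shape)
# → (1, 4, 4, 16) (1, 16) (1, 4, 4, 1)
```

The reshape only works because 4 * 4 * 1 equals the 16 channels of conv2d_3; changing either side without the other would make the Reshape layer fail.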
from scikitplot import visualkeras

img_encoder = visualkeras.layered_view(
    encoder,
    text_callable="default",
    # to_file="result_images/encoder.png",
    save_fig=True,
    save_fig_filename="encoder.png",
)
img_encoder
[Figure: layered view of the encoder]
2025-06-27 09:09:35.492813: W scikitplot 140532842670976 utils_pil.py:203:load_font] Error loading system font: cannot open resource
2025-06-27 09:09:35.492891: W scikitplot 140532842670976 utils_pil.py:205:load_font] Falling back to PIL default font.
2025-06-27 09:09:35.493024: W scikitplot 140532842670976 layered.py:203:layered_view] The legend_text_spacing_offset parameter is deprecated andwill be removed in a future release.

<matplotlib.image.AxesImage object at 0x7fcf6c741c10>
img_autoencoder = visualkeras.layered_view(
    autoencoder,
    # to_file="result_images/autoencoder.png",
    save_fig=True,
    save_fig_filename="autoencoder.png",
)
img_autoencoder
[Figure: layered view of the autoencoder]
2025-06-27 09:09:35.660838: W scikitplot 140532842670976 utils_pil.py:203:load_font] Error loading system font: cannot open resource
2025-06-27 09:09:35.660916: W scikitplot 140532842670976 utils_pil.py:205:load_font] Falling back to PIL default font.
2025-06-27 09:09:35.661049: W scikitplot 140532842670976 layered.py:203:layered_view] The legend_text_spacing_offset parameter is deprecated andwill be removed in a future release.

<matplotlib.image.AxesImage object at 0x7fcf5476ec90>
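layered_view draws each layer as a box whose size follows the layer's output shape. The sketch below illustrates the clamp-and-scale idea behind the min_xy/max_xy/scale_xy and min_z/max_z/scale_z arguments passed in the next call. It is modeled on the original visualkeras package, so treat it as an assumption about how scikitplot.visualkeras sizes boxes; box_size is a hypothetical helper, not part of either API, and the last example's argument values are illustrative only.

```python
from math import prod


def box_size(output_shape, min_xy, max_xy, scale_xy, min_z, max_z, scale_z):
    """Hypothetical helper: (x, y, z) box size in pixels for one layer.

    Spatial dims are multiplied by scale_xy, the remaining (channel or feature)
    dims by scale_z, and each result is clamped to its [min, max] range.
    """

    def clamp(value, lo, hi):
        return min(max(value, lo), hi)

    dims = output_shape[1:]                  # drop the batch axis (None)
    if len(dims) == 3:                       # (height, width, channels)
        x = clamp(dims[0] * scale_xy, min_xy, max_xy)
        y = clamp(dims[1] * scale_xy, min_xy, max_xy)
        z = clamp(dims[2] * scale_z, min_z, max_z)
    else:                                    # flat output such as (None, 16)
        x = y = min_xy
        z = clamp(prod(dims) * scale_z, min_z, max_z)
    return x, y, z


# With minima of 1, maxima of 4096 and scales of 1, a box is its output shape in pixels
print(box_size((None, 28, 28, 1), 1, 4096, 1, 1, 4096, 1))      # → (28, 28, 1)
print(box_size((None, 16), 1, 4096, 1, 1, 4096, 1))             # → (1, 1, 16)
# Larger scales enlarge boxes, and a 1-channel layer is padded up to min_z
print(box_size((None, 28, 28, 1), 20, 2000, 4, 20, 400, 1.5))   # → (112, 112, 20)
```

Clamping keeps very thin layers visible and stops very wide ones from dominating the image, which is why small minima and large maxima give a view closer to the true proportions.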
img_autoencoder_text = visualkeras.layered_view(
    autoencoder,
    # Box sizes in pixels: lower and upper bounds, plus multipliers applied to
    # the spatial (xy) and channel (z) dimensions of each layer's output
    min_z=1,
    min_xy=1,
    max_z=4096,
    max_xy=4096,
    scale_z=1,
    scale_xy=1,
    # font={"font_size": 14},
    text_callable="default",
    # to_file="result_images/autoencoder_text.png",
    save_fig=True,
    save_fig_filename="autoencoder_text.png",
    overwrite=False,
    add_timestamp=True,  # appends a timestamp to the saved file name (see the log below)
    verbose=True,  # logs the path of the saved image
)
img_autoencoder_text
[Figure: layered view of the autoencoder with custom box sizes]
2025-06-27 09:09:35.798412: W scikitplot 140532842670976 utils_pil.py:203:load_font] Error loading system font: cannot open resource
2025-06-27 09:09:35.798497: W scikitplot 140532842670976 utils_pil.py:205:load_font] Falling back to PIL default font.
2025-06-27 09:09:35.798636: W scikitplot 140532842670976 layered.py:203:layered_view] The legend_text_spacing_offset parameter is deprecated andwill be removed in a future release.
[INFO] Saving path to: /home/circleci/repo/galleries/examples/visualkeras_CNN/result_images/autoencoder_text_20250627_090935Z.png

<matplotlib.image.AxesImage object at 0x7fcf543147d0>
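With add_timestamp=True, the [INFO] line shows that autoencoder_text.png was saved as autoencoder_text_20250627_090935Z.png. The sketch below reproduces that naming pattern with the standard library. timestamped_filename is a hypothetical helper, and the YYYYmmdd_HHMMSS format in UTC (assumed from the trailing "Z") is inferred from the logged path, not taken from the scikit-plots source.

```python
from datetime import datetime, timezone
from pathlib import Path


def timestamped_filename(filename, now=None):
    """Hypothetical helper: insert a UTC timestamp before the file extension."""
    now = now or datetime.now(timezone.utc)
    path = Path(filename)
    return f"{path.stem}_{now:%Y%m%d_%H%M%S}Z{path.suffix}"


# Timestamp of the run logged above
when = datetime(2025, 6, 27, 9, 9, 35, tzinfo=timezone.utc)
print(timestamped_filename("autoencoder_text.png", when))
# → autoencoder_text_20250627_090935Z.png
```

A per-run timestamp gives each saved image a distinct name, which presumably complements overwrite=False when the example is run repeatedly.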

Tags: model-type: classification model-workflow: model building plot-type: visualkeras domain: neural network level: beginner purpose: showcase

Total running time of the script: (0 minutes 3.840 seconds)

Related examples

visualkeras: transformers example

Visualkeras: Spam Classification Conv1D Dense Example

visualkeras: Spam Dense example

visualkeras: custom VGG example

Gallery generated by Sphinx-Gallery