visualkeras: transformers example

An example showing how to use visualkeras with a tf.keras.Model, a Module, or a TFPreTrainedModel. Here a Hugging Face transformer is wrapped inside a Keras functional model.

# Authors: The scikit-plots developers
# SPDX-License-Identifier: BSD-3-Clause

Force garbage collection

# Run a full garbage collection before building anything, presumably to free
# memory left behind by earlier code in the same session (TODO confirm intent).
import gc

# gc.collect() returns the number of unreachable objects it found; the
# rendered page shows that value directly below this cell.
gc.collect()
3
# NOTE: this example was run with protobuf pinned to 5.29.4, presumably for
# TensorFlow / transformers compatibility (confirm before changing the pin):
# pip install protobuf==5.29.4
import tensorflow as tf

# Reset Keras' global state (including the counters used to auto-generate
# layer names) so this example starts from a clean slate.
tf.keras.backend.clear_session()

from transformers import TFAutoModel

from scikitplot import visualkeras
# Load the pretrained "microsoft/mpnet-base" encoder from the Hugging Face Hub.
# The load prints warnings about unused LM-head weights and a newly
# initialized pooler. That is expected here: the model is only visualized,
# never used for inference.
transformer_model = TFAutoModel.from_pretrained("microsoft/mpnet-base")


# Define a Keras-compatible wrapper for the Hugging Face model
def wrap_transformer_model(inputs):
    """Run the module-level ``transformer_model`` on a pair of input tensors.

    Intended as the function of a ``tf.keras.layers.Lambda`` so the Hugging
    Face model can sit inside a functional Keras graph.

    Parameters
    ----------
    inputs : sequence of exactly two tensors
        ``(input_ids, attention_mask)``, unpacked positionally.

    Returns
    -------
    The ``last_hidden_state`` attribute of the encoder's output, which is the
    tensor the rest of the example visualizes.
    """
    token_ids, mask = inputs
    return transformer_model(
        input_ids=token_ids, attention_mask=mask
    ).last_hidden_state
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFMPNetModel: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing TFMPNetModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFMPNetModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFMPNetModel were not initialized from the PyTorch model and are newly initialized: ['mpnet.pooler.dense.weight', 'mpnet.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# Define Keras model inputs: one fixed-length sequence of token ids and its
# attention mask per example.
SEQ_LEN = 128  # tokens per example
HIDDEN_SIZE = 768  # width of MPNet-base's last_hidden_state

input_ids = tf.keras.Input(shape=(SEQ_LEN,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.Input(
    shape=(SEQ_LEN,), dtype=tf.int32, name="attention_mask"
)

# Pass inputs through the transformer model using a Lambda layer. The output
# shape is declared explicitly rather than left for Keras to infer from the
# wrapped Hugging Face model.
last_hidden_state = tf.keras.layers.Lambda(
    wrap_transformer_model,
    output_shape=(SEQ_LEN, HIDDEN_SIZE),
    name="microsoft_mpnet-base",
)([input_ids, attention_mask])

# Append a trailing channel axis so the (batch, SEQ_LEN, HIDDEN_SIZE) hidden
# states become a single-channel "image" of shape
# (batch, SEQ_LEN, HIDDEN_SIZE, 1) for the channels-last Conv2D stack below.
# NOTE: the target must be (SEQ_LEN, HIDDEN_SIZE, 1). A target of
# (-1, SEQ_LEN, HIDDEN_SIZE) yields (batch, 1, SEQ_LEN, HIDDEN_SIZE): a height-1
# feature map that the first 2x2 MaxPooling2D collapses to height 0, leaving
# every later conv/pool layer operating on empty tensors.
reshaped_output = tf.keras.layers.Reshape(
    (SEQ_LEN, HIDDEN_SIZE, 1), name="reshape"
)(last_hidden_state)

# Three conv stages: Conv2D -> BatchNorm -> Dropout -> 2x2 MaxPool.
# Each pool halves both spatial dims, from (SEQ_LEN, HIDDEN_SIZE) down to
# (SEQ_LEN / 8, HIDDEN_SIZE / 8).
x = tf.keras.layers.Conv2D(
    512, (3, 3), activation="relu", padding="same", name="conv2d_1"
)(reshaped_output)
x = tf.keras.layers.BatchNormalization(name="batchnorm_1")(x)
x = tf.keras.layers.Dropout(0.3, name="dropout_1")(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name="maxpool_1")(x)

x = tf.keras.layers.Conv2D(
    256, (3, 3), activation="relu", padding="same", name="conv2d_2"
)(x)
x = tf.keras.layers.BatchNormalization(name="batchnorm_2")(x)
x = tf.keras.layers.Dropout(0.3, name="dropout_2")(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name="maxpool_2")(x)

x = tf.keras.layers.Conv2D(
    128, (3, 3), activation="relu", padding="same", name="conv2d_3"
)(x)
x = tf.keras.layers.BatchNormalization(name="batchnorm_3")(x)
x = tf.keras.layers.Dropout(0.4, name="dropout_3")(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name="maxpool_3")(x)

# Collapse the spatial dims to one 128-wide vector before the Dense layers
x = tf.keras.layers.GlobalAveragePooling2D(name="globalaveragepool")(x)

# Add Dense layers
x = tf.keras.layers.Dense(512, activation="relu", name="dense_1")(x)
x = tf.keras.layers.Dropout(0.5, name="dropout_4")(x)
x = tf.keras.layers.Dense(128, activation="relu", name="dense_2")(x)

# Add output layer: a placeholder 2-class head, present only so the
# visualized graph ends in a classifier.
dummy_output = tf.keras.layers.Dense(
    2, activation="softmax", name="dummy_classification_head"
)(x)

# Wrap into a Keras model
wrapped_model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=dummy_output)

# Print the layer table with an extra "Trainable" column. All other summary()
# arguments (line_length, positions, print_fn, expand_nested, layer_range)
# are left at their Keras defaults. Reference:
# https://github.com/keras-team/keras/blob/v3.3.3/keras/src/models/model.py#L217
# https://github.com/keras-team/keras/blob/master/keras/src/utils/summary_utils.py#L121
wrapped_model.summary(show_trainable=True)
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)      ┃ Output Shape    ┃   Param # ┃ Connected to   ┃ Trai… ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩
│ input_ids         │ (None, 128)     │         0 │ -              │   -   │
│ (InputLayer)      │                 │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ attention_mask    │ (None, 128)     │         0 │ -              │   -   │
│ (InputLayer)      │                 │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ microsoft_mpnet-… │ (None, 128,     │         0 │ input_ids[0][… │   -   │
│ (Lambda)          │ 768)            │           │ attention_mas… │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ reshape (Reshape) │ (None, 1, 128,  │         0 │ microsoft_mpn… │   -   │
│                   │ 768)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ conv2d_1 (Conv2D) │ (None, 1, 128,  │ 3,539,456 │ reshape[0][0]  │   Y   │
│                   │ 512)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ batchnorm_1       │ (None, 1, 128,  │     2,048 │ conv2d_1[0][0] │   Y   │
│ (BatchNormalizat… │ 512)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dropout_1         │ (None, 1, 128,  │         0 │ batchnorm_1[0… │   -   │
│ (Dropout)         │ 512)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ maxpool_1         │ (None, 0, 64,   │         0 │ dropout_1[0][… │   -   │
│ (MaxPooling2D)    │ 512)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ conv2d_2 (Conv2D) │ (None, 0, 64,   │ 1,179,904 │ maxpool_1[0][… │   Y   │
│                   │ 256)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ batchnorm_2       │ (None, 0, 64,   │     1,024 │ conv2d_2[0][0] │   Y   │
│ (BatchNormalizat… │ 256)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dropout_2         │ (None, 0, 64,   │         0 │ batchnorm_2[0… │   -   │
│ (Dropout)         │ 256)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ maxpool_2         │ (None, 0, 32,   │         0 │ dropout_2[0][… │   -   │
│ (MaxPooling2D)    │ 256)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ conv2d_3 (Conv2D) │ (None, 0, 32,   │   295,040 │ maxpool_2[0][… │   Y   │
│                   │ 128)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ batchnorm_3       │ (None, 0, 32,   │       512 │ conv2d_3[0][0] │   Y   │
│ (BatchNormalizat… │ 128)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dropout_3         │ (None, 0, 32,   │         0 │ batchnorm_3[0… │   -   │
│ (Dropout)         │ 128)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ maxpool_3         │ (None, 0, 16,   │         0 │ dropout_3[0][… │   -   │
│ (MaxPooling2D)    │ 128)            │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ globalaveragepool │ (None, 128)     │         0 │ maxpool_3[0][… │   -   │
│ (GlobalAveragePo… │                 │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dense_1 (Dense)   │ (None, 512)     │    66,048 │ globalaverage… │   Y   │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dropout_4         │ (None, 512)     │         0 │ dense_1[0][0]  │   -   │
│ (Dropout)         │                 │           │                │       │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dense_2 (Dense)   │ (None, 128)     │    65,664 │ dropout_4[0][… │   Y   │
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
│ dummy_classifica… │ (None, 2)       │       258 │ dense_2[0][0]  │   Y   │
│ (Dense)           │                 │           │                │       │
└───────────────────┴─────────────────┴───────────┴────────────────┴───────┘
 Total params: 5,149,954 (19.65 MB)
 Trainable params: 5,148,162 (19.64 MB)
 Non-trainable params: 1,792 (7.00 KB)
# Render the wrapped model as a layered diagram and save it to disk.
# The options are collected up front (same order as passed) so they can be
# read and tweaked in one place.
layered_view_options = {
    # Annotate the figure with a legend and each layer's dimensions.
    "legend": True,
    "show_dimension": True,
    # Clamp and scale the drawn box sizes.
    "min_z": 1,
    "min_xy": 1,
    "max_z": 4096,
    "max_xy": 4096,
    "scale_z": 1,
    "scale_xy": 1,
    # Text rendering.
    "font": {"font_size": 99},
    "text_callable": "default",
    # Output file: timestamped, never overwriting an existing image.
    "save_fig": True,
    "save_fig_filename": "nlp_mpnet_with_tf_layers.png",
    "overwrite": False,
    "add_timestamp": True,
    "verbose": True,
}
img_nlp_mpnet_with_tf_layers = visualkeras.layered_view(
    wrapped_model, **layered_view_options
)
plot nlp mpnet with tf layers
[INFO] Saving path to: /home/circleci/repo/galleries/examples/visualkeras_NLP/result_images/nlp_mpnet_with_tf_layers_20250422_154025Z.png
[INFO] Image saved using Matplotlib: /home/circleci/repo/galleries/examples/visualkeras_NLP/result_images/nlp_mpnet_with_tf_layers_20250422_154025Z.png

Tags: model-type: classification model-workflow: model building plot-type: visualkeras domain: neural network level: advanced purpose: showcase

Total running time of the script: (0 minutes 9.808 seconds)

Related examples

visualkeras: autoencoder example

visualkeras: autoencoder example

Visualkeras: Spam Classification Conv1D Dense Example

Visualkeras: Spam Classification Conv1D Dense Example

visualkeras: custom vgg16 show dimension example

visualkeras: custom vgg16 show dimension example

visualkeras: custom vgg16 example

visualkeras: custom vgg16 example

Gallery generated by Sphinx-Gallery