visualkeras transformers example

An example showing visualkeras applied to a tf.keras.Model that wraps a Hugging Face TFPreTrainedModel (here, microsoft/mpnet-base).

[Figure: plot nlp mpnet with tf layers]

Out:
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFMPNetModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing TFMPNetModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFMPNetModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFMPNetModel were not initialized from the PyTorch model and are newly initialized: ['mpnet.pooler.dense.weight', 'mpnet.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
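These warnings are emitted while converting the PyTorch checkpoint to TensorFlow; as the message itself notes, they are expected here, and the newly initialized pooler weights do not matter for merely visualizing the architecture.

Before the full script, here is a minimal sketch of the core visualkeras call on a tiny stand-in model (the Sequential model below is illustrative only, not part of the example):

import tensorflow as tf

from scikitplot import visualkeras

# A tiny stand-in model, just to show the basic API.
toy = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# layered_view renders the architecture and returns a PIL image.
visualkeras.layered_view(toy, legend=True)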

# Authors: The scikit-plots developers
# SPDX-License-Identifier: BSD-3-Clause

# Force garbage collection
import gc

gc.collect()
import tensorflow as tf
# Clear any session to reset the state of TensorFlow/Keras
tf.keras.backend.clear_session()

from transformers import TFAutoModel

from scikitplot import visualkeras

# Load the Hugging Face transformer model
transformer_model = TFAutoModel.from_pretrained("microsoft/mpnet-base")

# Define a Keras-compatible wrapper for the Hugging Face model
def wrap_transformer_model(inputs):
    input_ids, attention_mask = inputs
    outputs = transformer_model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state  # Return the last hidden state for visualization
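As a quick sanity check, the wrapper can be run eagerly on dummy tensors; for mpnet-base (hidden size 768) this should print (1, 128, 768). The dummy inputs below are illustrative only:

# Sanity check: the last hidden state should be (batch, seq_len, hidden).
dummy_ids = tf.zeros((1, 128), dtype=tf.int32)
dummy_mask = tf.ones((1, 128), dtype=tf.int32)
print(wrap_transformer_model([dummy_ids, dummy_mask]).shape)  # (1, 128, 768)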

# Define Keras model inputs
input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype=tf.int32, name="attention_mask")

# Pass inputs through the transformer model using a Lambda layer
last_hidden_state = tf.keras.layers.Lambda(
    wrap_transformer_model,
    output_shape=(128, 768),  # Explicitly specify the output shape
    name='microsoft_mpnet-base',
)([input_ids, attention_mask])

# The transformer output has shape (batch_size, 128, 768); Conv2D expects a
# 4-D input, so append a single channel dimension with a Reshape layer:
# (batch_size, 128, 768) -> (batch_size, 128, 768, 1).
# The same could be done inside a Lambda layer, e.g.:
# def reshape_last_hidden_state(x):
#     return tf.reshape(x, (-1, 128, 768, 1))
# reshaped_output = tf.keras.layers.Lambda(reshape_last_hidden_state)(last_hidden_state)
reshaped_output = tf.keras.layers.Reshape((128, 768, 1))(last_hidden_state)

# Add different layers to the model
x = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same', name='conv2d_1')(reshaped_output)
x = tf.keras.layers.BatchNormalization(name='batchnorm_1')(x)
x = tf.keras.layers.Dropout(0.3, name='dropout_1')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='maxpool_1')(x)

x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='conv2d_2')(x)
x = tf.keras.layers.BatchNormalization(name='batchnorm_2')(x)
x = tf.keras.layers.Dropout(0.3, name='dropout_2')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='maxpool_2')(x)

x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2d_3')(x)
x = tf.keras.layers.BatchNormalization(name='batchnorm_3')(x)
x = tf.keras.layers.Dropout(0.4, name='dropout_3')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='maxpool_3')(x)
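
# With the (128, 768, 1) input above, 'same' padding preserves the spatial size
# at each Conv2D and every (2, 2) max-pool halves it:
# (128, 768) -> (64, 384) -> (32, 192) -> (16, 96),
# while the channel count follows the Conv2D filters: 512 -> 256 -> 128.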

# Add GlobalAveragePooling2D before the Dense layers
x = tf.keras.layers.GlobalAveragePooling2D(name='globalaveragepool')(x)

# Add Dense layers
x = tf.keras.layers.Dense(512, activation='relu', name='dense_1')(x)
x = tf.keras.layers.Dropout(0.5, name='dropout_4')(x)
x = tf.keras.layers.Dense(128, activation='relu', name='dense_2')(x)

# Add output layer (classification head)
dummy_output = tf.keras.layers.Dense(2, activation='softmax', name="dummy_classification_head")(x)

# Wrap into a Keras model
wrapped_model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=dummy_output)
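An optional textual summary can cross-check the shapes that visualkeras will draw:

# Optional: print layer-by-layer output shapes and parameter counts.
wrapped_model.summary()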

# Visualize the wrapped model
img_nlp_mpnet_with_tf_layers = visualkeras.layered_view(
    wrapped_model,
    legend=True,
    show_dimension=True,
    scale_xy=1, scale_z=1, max_z=250,
    to_file="../result_images/nlp_mpnet_with_tf_layers.png"
)
try:
    import matplotlib.pyplot as plt
    plt.imshow(img_nlp_mpnet_with_tf_layers)
    plt.axis('off')
    plt.show()
except ImportError:
    pass
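
The diagram is also written to the to_file path given above. A short sketch for reloading it later, assuming Pillow is installed (visualkeras returns PIL images, so Pillow is available):

# Reload the saved diagram from disk (same path as to_file above).
from PIL import Image

img = Image.open("../result_images/nlp_mpnet_with_tf_layers.png")
print(img.size)  # (width, height) in pixels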

Total running time of the script: (0 minutes 23.743 seconds)

Related examples

visualkeras custom vgg16 show dimension example

visualkeras custom vgg16 example

visualkeras autoencoder example

Visualkeras Spam Classification Conv1D Dense Example

Gallery generated by Sphinx-Gallery