Understanding Transformer Models in Deep Learning: Explaining Predictions with Salience Maps

Deep learning models, especially transformer models, have achieved remarkable success in various natural language processing tasks. However, understanding how these models make predictions can be challenging. In this blog post, we will delve into methods for explaining transformer models and focus on one specific technique: generating salience maps.

1. Explaining Transformer Models:

Transformer models revolutionized the field of natural language processing (NLP) with their attention mechanisms. Before explaining the details of salience maps, let’s explore a few general methods for understanding transformer models.

a) Attention Visualization:
Visualizing attention weights shows which input tokens each position attends to most strongly. Inspecting these weights gives some insight into what the model focuses on during processing, although a high attention weight does not always mean that a token drove the final prediction.
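
To make the idea of an attention weight concrete, here is a minimal sketch of scaled dot-product attention for a single query. The token list and the 2-dimensional key and query vectors are made-up illustrative values, not taken from any real model; in practice the weights would be read from the model itself.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(query, keys):
    """Scaled dot-product attention weights of one query over all keys."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    return softmax(scores)

# Hypothetical key vectors for three tokens (illustrative values only)
tokens = ["the", "movie", "great"]
keys = [[0.1, 0.0], [1.0, 0.5], [2.0, 1.0]]
query = [1.0, 1.0]  # hypothetical query vector for the position being inspected

weights = attention_weights(query, keys)
for token, w in zip(tokens, weights):
    print(f"{token:>6}: {w:.3f}")
```

The weights form a probability distribution over the tokens, which is what an attention heatmap displays for each query position.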

b) Gradient-based Attribution Methods:
These methods compute the gradient of the model’s prediction with respect to the input. In practice this means the token embeddings, since discrete token IDs are not differentiable. The gradient is then turned into an importance score for each token, indicating its contribution to the model’s output.
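
As a rough illustration, the sketch below applies one simple gradient-based method, gradient times input, to a toy classifier. The classifier (a sigmoid over a weighted sum of token embeddings), the weight vector, and the embedding values are all assumptions made for this example; a real transformer would obtain the gradient through automatic differentiation.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(embeddings, w):
    """Toy classifier: sigmoid of the dot product of w with the summed token embeddings."""
    z = sum(sum(wd * ed for wd, ed in zip(w, e)) for e in embeddings)
    return sigmoid(z)

def gradient_x_input(embeddings, w):
    """Per-token score: sum over dimensions of d(prediction)/d(embedding) * embedding."""
    p = predict(embeddings, w)
    dp_dz = p * (1.0 - p)  # derivative of the sigmoid at the current logit
    # dz/d(embedding[d]) = w[d] for every token in this toy model
    return [sum(dp_dz * wd * ed for wd, ed in zip(w, e)) for e in embeddings]

# Hypothetical embeddings and weights (illustrative values only)
tokens = ["the", "movie", "great"]
embeddings = [[0.1, 0.0], [0.5, -0.5], [1.0, 0.5]]
w = [1.0, 2.0]

scores = gradient_x_input(embeddings, w)
for token, s in zip(tokens, scores):
    print(f"{token:>6}: {s:+.4f}")
```

A positive score means the token pushes the prediction up, and a negative score means it pushes the prediction down.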

c) Influence Functions:
Influence functions estimate how a model’s prediction would change if a given training example were removed or upweighted. By identifying the most influential training examples, we gain a better understanding of the model’s behavior and can debug the model or its training data accordingly.
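
Influence functions are usually described as an efficient approximation to leave-one-out retraining: how would a prediction change if one training example were removed? The sketch below does that retraining by brute force on a tiny one-parameter regression, which is a stand-in for the idea and not an influence-function implementation. The data points and the test input are hypothetical.

```python
def fit_slope(points):
    """Least-squares slope for the model y = w * x (no intercept)."""
    num = sum(x * y for x, y in points)
    den = sum(x * x for x, _ in points)
    return num / den

# Hypothetical training set; the last point is a deliberate outlier
train = [(1.0, 1.1), (2.0, 1.9), (3.0, 3.2), (4.0, 8.0)]
x_test = 5.0

full_prediction = fit_slope(train) * x_test

# Change in the test prediction when each training point is left out
effects = []
for i in range(len(train)):
    reduced = train[:i] + train[i + 1:]
    effects.append(fit_slope(reduced) * x_test - full_prediction)

most_influential = max(range(len(train)), key=lambda i: abs(effects[i]))
print("most influential training index:", most_influential)
```

For a large model, retraining once per example is not feasible, which is why approximations are used instead.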

2. Salience Maps: Capturing Token Importance:

One method for explaining transformer models is generating salience maps. Salience maps highlight the most important tokens in the input sequence that contribute significantly to the model’s predictions. Let’s explore the steps involved in obtaining salience maps.

a) Loading the Model and Tokenizer:
To generate salience maps, we first load the pre-trained transformer model and its corresponding tokenizer. These components are essential for preprocessing the input text and executing the model.

b) Preprocessing the Input:
Text data is tokenized and encoded as numerical sequences suitable for the transformer model. Proper preprocessing ensures compatibility between the input and the transformer’s expectations.

c) Using Integrated Gradients:
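Integrated Gradients is a popular gradient-based attribution method. Rather than using a single gradient, it averages the gradients of the model’s prediction along a straight-line path from a baseline input (for example, a sequence of padding tokens) to the actual input, and scales the result by the difference between the input and the baseline. For transformers this is done on the token embeddings, and the per-dimension attributions are summed to give one importance score per token. These scores indicate how much a token contributes to the model’s output.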
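
The path integral behind Integrated Gradients can be approximated with a simple Riemann sum. The following sketch does this for a toy model with a hand-written gradient; the model, its coefficients, the input, and the all-zeros baseline are assumptions chosen for illustration. It also checks the completeness property: the attributions should add up to the difference between the model output at the input and at the baseline.

```python
import math

WEIGHTS = [2.0, -1.0, 0.5]  # hypothetical coefficients of the toy model

def toy_model(x):
    """Toy differentiable model: sigmoid of a weighted sum of the inputs."""
    z = sum(w * xi for w, xi in zip(WEIGHTS, x))
    return 1.0 / (1.0 + math.exp(-z))

def toy_gradient(x):
    """Analytic gradient of toy_model with respect to each input."""
    p = toy_model(x)
    dp_dz = p * (1.0 - p)
    return [w * dp_dz for w in WEIGHTS]

def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(toy_gradient(point)):
            total[i] += g
    # Average gradient along the path, scaled by (input - baseline)
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline)

# Completeness check: the gap should be close to zero
gap = sum(attributions) - (toy_model(x) - toy_model(baseline))
print("attributions:", attributions)
print("completeness gap:", gap)
```

Increasing the number of steps shrinks the gap, which is the same quantity that attribution libraries typically report as a convergence delta.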

d) Normalizing and Visualizing the Salience Weights:
The per-token importance scores obtained from Integrated Gradients are normalized so that they are comparable across inputs, for example by dividing by the sum of their absolute values so that the absolute values sum to 1. The normalized scores are the salience weights. These weights are then visualized as a heatmap overlaid on the input sequence, highlighting the most important tokens.
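
Because attributions can be negative, dividing by their plain sum can blow up or flip signs when positive and negative scores nearly cancel. One option, sketched below under that assumption, is to divide by the sum of absolute values instead. The function name and the sample scores are hypothetical.

```python
def normalize_scores(scores, eps=1e-12):
    """Scale scores so that their absolute values sum to 1, keeping the signs."""
    total = sum(abs(s) for s in scores)
    if total < eps:
        # All scores are (almost) zero, so there is nothing to highlight
        return [0.0 for _ in scores]
    return [s / total for s in scores]

# Hypothetical raw attributions whose plain sum is small relative to their size
raw = [0.6, -0.5, 0.05, 0.05]
weights = normalize_scores(raw)

print("divided by plain sum:   ", [s / sum(raw) for s in raw])
print("divided by absolute sum:", weights)
```

With this scaling every weight stays in the range from -1 to 1, which maps cleanly onto a heatmap color scale.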

Example Code:

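import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients
import seaborn as sns
import matplotlib.pyplot as plt

# Load the pre-trained tokenizer and model.
# Note: "bert-base-uncased" ships without a fine-tuned classification head, so the head
# is randomly initialized here; use a fine-tuned checkpoint for meaningful attributions.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.eval()

# Define a sample input sequence
input_text = "The movie was great, I really enjoyed it."

# Tokenize the input text and encode it as tensors (special tokens are added by default)
encoded = tokenizer(input_text, return_tensors="pt")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

# Define the baseline: [PAD] tokens everywhere, keeping [CLS] and [SEP] in place
baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
baseline_ids[0, 0] = tokenizer.cls_token_id
baseline_ids[0, -1] = tokenizer.sep_token_id

# Captum expects a forward function that returns a tensor of logits
def forward_func(ids, mask):
    return model(input_ids=ids, attention_mask=mask).logits

# Token IDs are integers and cannot be differentiated, so attribute at the embedding layer
integrated_grads = LayerIntegratedGradients(forward_func, model.bert.embeddings)

# Calculate the attributions for class 0; the result has shape (1, seq_len, hidden_size)
attributions, delta = integrated_grads.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_mask,),
    target=0,
    return_convergence_delta=True,
)

# Sum over the embedding dimension to get one score per token, then normalize so that
# the absolute values sum to 1 (attributions can be negative)
token_scores = attributions.sum(dim=-1).squeeze(0)
salience_weights = (token_scores / token_scores.abs().sum()).detach().numpy()

# Extract the list of tokens from the input sequence
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Plot the salience weights as a heatmap annotated with the tokens
plt.figure(figsize=(10, 5))
sns.heatmap([salience_weights], annot=[tokens], cmap="Blues", fmt="")
plt.title("Salience Map for Input Text")
plt.ylabel("Salience Weights")
plt.show()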


Understanding the inner workings of transformer models is crucial for their interpretability and trustworthiness. Explaining transformer models using methods like attention visualization, gradient-based attribution, influence functions, and salience maps helps shed light on their decision-making processes. Salience maps, in particular, offer intuitive insights by highlighting the important tokens in the input. By leveraging techniques like Integrated Gradients, we can generate salience maps and gain a better understanding of how transformer models process and interpret language.

In conclusion, generating salience maps and employing other methods to explain transformer models enhance their interpretability and enable more robust analysis and improvement. These techniques contribute to the broader field of explainable AI, promoting transparency and trust in deep learning models.


Author: robot learner
Reprint policy: Unless otherwise stated, all articles in this blog are licensed under CC BY 4.0. If you reproduce this article, please credit the source: robot learner.