Adapting Grad-CAM for Image-to-Text Models: A Step-by-Step Guide
Grad-CAM is a powerful visualization tool originally designed for CNN architectures to highlight what parts of an image influence neural network decisions. Today, I’ll show you how I’ve adapted Grad-CAM to work with an image-to-text transformer model, specifically using the TrOCR model from Hugging Face.
Step 1: Token Generation from the Model
The first step is to generate tokens with our TrOCR model. These token IDs are the model's reading of the text in the image; we later feed them back into the decoder to compute gradients for each predicted token.
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
image_path = "your_image_path.jpg"
image = Image.open(image_path).convert("RGB")
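def get_generated_tokens(image, model, processor):
    # Preprocess the image into the pixel tensor the encoder expects
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Autoregressively generate token IDs (generate runs without gradient tracking)
    generated_tokens = model.generate(pixel_values=pixel_values, max_length=50)
    return generated_tokens, pixel_values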
The generated tokens will look something like tensor([[ 2, 14200, 2022, 2]]), where token 2 is a special token that marks both the start and the end of the sequence; the IDs in between encode the recognized text.
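The position of each token matters later: the full code feeds generated_tokens[:, :-1] back into the decoder, so the logits at decoder position t are the prediction for generated_tokens[0, t + 1]. The sketch below lists those pairs for the example sequence above; prediction_targets is a hypothetical helper, not part of Transformers.

```python
import torch

def prediction_targets(generated_tokens: torch.Tensor) -> list:
    """Hypothetical helper: for each decoder position t, pair the token fed in
    at t with the token that the logits at t are meant to predict (t + 1)."""
    tokens = generated_tokens[0].tolist()
    return [(t, tokens[t], tokens[t + 1]) for t in range(len(tokens) - 1)]

# Example sequence from above: start token, two content tokens, end token
generated_tokens = torch.tensor([[2, 14200, 2022, 2]])
for position, fed_in, predicted in prediction_targets(generated_tokens):
    print(f"position {position}: input {fed_in} -> predicts {predicted}")
```

Position 0 receives the start token and predicts the first content token, while the last position predicts the end token.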
Step 2: Layer Selection for Grad-CAM
Choosing the right layer is crucial because Grad-CAM is only as informative as the activations it uses: they should still carry spatial (per-patch) information while correlating with the output predictions. In transformers, this is typically one of the final encoder layers.
If you are unsure which layer to pick or what its output shape is, simply print the model or use the torchsummary library to see the output shape of each layer.
For the above model, I have selected the last layer of the ViT encoder.
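layer_name = model.encoder.encoder.layer[-1].output  # note: this is the module object itself, not a string
Note: .output here is the output submodule of the last ViT layer (its final dense projection plus the residual connection). Hooking this submodule gives a plain tensor, whereas Hugging Face layer modules themselves often return a tuple or a dictionary-like output. For a plain PyTorch model, the layer module itself is usually enough.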
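As an alternative to torchsummary, a minimal sketch using plain forward hooks can print the output shape of every submodule after one forward pass. print_output_shapes is a hypothetical helper; it reports tuple outputs by their first element and skips non-tensor outputs such as dictionary-like model outputs. The toy nn.Sequential only stands in for a real model.

```python
import torch
from torch import nn

def print_output_shapes(model: nn.Module, *inputs) -> dict:
    """Hypothetical helper: run one forward pass and record the output shape
    of every named submodule (the root module itself is skipped)."""
    shapes = {}
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Hugging Face layers often return tuples; report the first element
            if isinstance(output, (tuple, list)) and len(output) > 0:
                output = output[0]
            if isinstance(output, torch.Tensor):
                shapes[name] = tuple(output.shape)
        return hook

    for name, module in model.named_modules():
        if name:  # "" is the root module
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:  # always detach the temporary hooks
            handle.remove()
    for name, shape in shapes.items():
        print(f"{name:30s} {shape}")
    return shapes

# Toy stand-in for a transformer block: (batch, tokens, features) in and out
toy = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
shapes = print_output_shapes(toy, torch.randn(1, 577, 16))
```

With TrOCR, calling it as print_output_shapes(model.encoder, pixel_values) should list every encoder submodule, although the output for a full ViT is long.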
Step 3: Attaching Hooks to Capture Outputs and Gradients
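We attach a forward hook to the selected layer to capture its output during the forward pass. Because this output is an intermediate (non-leaf) tensor, we call retain_grad() on it so that PyTorch keeps its gradient after the backward pass. The hook should only be active during a gradient-enabled forward pass: model.generate runs under torch.no_grad(), and retain_grad() raises an error on tensors that do not require gradients.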
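last_layer_output = None
def save_output(module, input, output):
    global last_layer_output
    last_layer_output = output
    output.retain_grad()  # keep the gradient of this non-leaf tensor after backward()
last_layer = layer_name
# Keep the handle so the hook can be removed (hook_handle.remove()) once the gradients are captured
hook_handle = last_layer.register_forward_hook(save_output)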
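The hook-plus-retain_grad pattern can be tried in isolation on a toy model. This sketch stores the activation in a dictionary instead of a global variable and removes the hook through the handle returned by register_forward_hook; the toy model and all names are illustrative only.

```python
import torch
from torch import nn

captured = {}

def save_activation(module, inputs, output):
    # Keep a reference to the activation and ask autograd to fill in its .grad
    captured["activation"] = output
    output.retain_grad()

# Toy model standing in for "selected layer -> rest of the network"
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 3))
target_layer = toy_model[0]

hook_handle = target_layer.register_forward_hook(save_activation)
logits = toy_model(torch.randn(1, 5, 8))   # (batch, tokens, classes)
hook_handle.remove()                       # detach the hook so repeated calls don't stack hooks

logits[0, 0, 1].backward()                 # gradient of a single scalar score
activation = captured["activation"]
print(activation.shape, activation.grad.shape)
```

Because the toy layers act on each token independently, only token 0 receives a non-zero gradient here; in a real transformer, attention spreads the gradient across positions.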
Step 4: Targeting Specific Tokens
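We select one generated token at a time and compute how much each part of the input image contributed to predicting it, which gives insight into the model's decisions. Because the decoder receives generated_tokens[:, :-1] as input, the logits at position t predict the token at position t + 1, so the score to backpropagate is outputs.logits[0, t, generated_tokens[0, t + 1]].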
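Step 4 can be sketched with toy tensors: pick a decoder position t, take the logit of the token it is supposed to predict (t + 1), and backpropagate that single score. The random states, the linear head and the token IDs below are illustrative assumptions standing in for the real decoder.

```python
import torch

torch.manual_seed(0)
vocab_size, hidden = 10, 4
tokens = torch.tensor([[2, 7, 5, 2]])     # toy sequence: start, two content tokens, end
# Toy decoder states for the inputs tokens[:, :-1] (3 positions); a stand-in for real activations
states = torch.randn(1, tokens.shape[1] - 1, hidden, requires_grad=True)
head = torch.randn(hidden, vocab_size)     # toy output projection
logits = states @ head                     # (1, 3, vocab_size)

t = 1                                      # decoder position to explain
target_id = tokens[0, t + 1]               # logits at position t predict token t + 1
score = logits[0, t, target_id]
score.backward()

print(int(target_id))                      # 5
print(states.grad[0])                      # only row t is non-zero
```

In this toy, only position t gets a gradient, equal to the head's column for the target token. In TrOCR, the same single-logit backward pass flows through cross-attention into every encoder patch, which is what Grad-CAM then maps.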
Step 5: Reshaping Layer Outputs
Transformers output activations in a different format from CNNs, so we reshape them to mimic CNN feature maps before applying Grad-CAM:
Understanding the output shape of the selected layer (24 x 24 = 576 image patches plus one [CLS] token):
- (batch_size, tokens, features or channels) -> (1, 577, 768)
- Remove the first token ([CLS]) for ViT -> (1, 576, 768)
- Reshape into a square grid, which works here because 576 = 24 x 24 -> (1, 24, 24, 768)
- Transpose so the features become channels, as in a CNN -> (1, 768, 24, 24)
def reshape_transform_vit_huggingface(x):
    activations = x[:, 1:, :]  # Remove the first token, which is used for classification in some architectures
    side_length = int(np.sqrt(activations.shape[1]))  # Assuming the feature map is square
    activations = activations.view(activations.shape[0], side_length, side_length, activations.shape[2])
    activations = activations.transpose(2, 3).transpose(1, 2)  # (B, H, W, C) -> (B, C, H, W)
    return activations
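A quick sanity check for the reshape is to feed it a dummy tensor in which every feature of token k holds the value k, using the shapes listed above. The snippet repeats the function with reshape and permute (equivalent to the view and transpose calls above) so that it runs on its own.

```python
import numpy as np
import torch

def reshape_transform_vit_huggingface(x):
    activations = x[:, 1:, :]                           # drop the [CLS] token -> (B, 576, C)
    side_length = int(np.sqrt(activations.shape[1]))    # 576 patches -> 24 x 24 grid
    activations = activations.reshape(activations.shape[0], side_length, side_length, activations.shape[2])
    return activations.permute(0, 3, 1, 2)              # (B, H, W, C) -> (B, C, H, W)

# Dummy layer output of shape (1, 577, 768) where every feature of token k equals k
x = torch.arange(577, dtype=torch.float32).view(1, 577, 1).expand(1, 577, 768)
maps = reshape_transform_vit_huggingface(x)
print(maps.shape)                # torch.Size([1, 768, 24, 24])
print(maps[0, 0, 1, 0].item())   # 25.0: the first patch of the second row is token 25
```

Patch tokens fill the grid row by row, so token k (counting [CLS] as token 0) lands at row (k - 1) // 24 and column (k - 1) % 24.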
Step 6: Applying Grad-CAM
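Finally, we apply the Grad-CAM algorithm to highlight the image regions that matter for each token. The gradients of the target token's logit with respect to the selected layer's activations are averaged over the spatial grid to give one weight per channel; the activation maps are then multiplied by these weights, summed over channels, and passed through a ReLU to form the heatmap.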
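Written out, this is the standard Grad-CAM formulation, where A^k is the k-th reshaped feature map (k = 1, ..., 768, each 24 × 24), y^c is the logit of the selected token, and Z = 576 is the number of spatial positions:

```latex
\alpha_k^{c} = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^{c}}{\partial A^{k}_{ij}},
\qquad
L^{c}_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_{k} \alpha_k^{c} A^{k} \right)
```

The mean over dimensions (2, 3) computes the weights alpha, the sum over dimension 1 computes the weighted combination over k, and the final division by the maximum is only a normalization for display.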
# layer_output and grad are the hooked activations and their gradient, each of shape (1, 577, 768)
transform_output = reshape_transform_vit_huggingface(layer_output)
transform_grad = reshape_transform_vit_huggingface(grad)
# (a) Channel weights: average the gradients over the spatial dimensions
weights = torch.mean(transform_grad, dim=(2, 3), keepdim=True)
# (b) Weighted combination of the activation maps, summed over channels
grad_cam = torch.sum(weights * transform_output, dim=1, keepdim=True)
# (c) ReLU: keep only positive contributions
grad_cam = torch.relu(grad_cam)
grad_cam = grad_cam.squeeze(0)  # Remove the batch dimension -> (1, 24, 24)
# (d) Normalize to [0, 1] (optional, but helps visualization); the epsilon avoids division by zero
grad_cam = grad_cam / (grad_cam.max() + 1e-8)
print("Grad-CAM shape:", grad_cam.shape)
# Upsample to the image resolution; PIL's image.size is (width, height)
heatmap = torch.nn.functional.interpolate(grad_cam.unsqueeze(0), size=(image.size[1], image.size[0]), mode='bilinear', align_corners=False)
heatmap = heatmap.squeeze().detach().numpy()
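One detail worth checking in the upsampling step is the argument order: PIL's image.size is (width, height), while interpolate expects size=(height, width). This minimal sketch wraps the step in a hypothetical upsample_cam function and tries it on a synthetic 24 × 24 map and a blank 384 × 64 image.

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

def upsample_cam(cam: torch.Tensor, image: Image.Image) -> np.ndarray:
    """Hypothetical helper: resize a (1, h, w) Grad-CAM map to the image's pixel grid."""
    width, height = image.size            # PIL reports (width, height)
    heatmap = F.interpolate(cam.unsqueeze(0), size=(height, width), mode="bilinear", align_corners=False)
    return heatmap.squeeze().detach().numpy()

# Synthetic inputs: a random 24 x 24 map and a wide, short blank "text line" image
cam = torch.rand(1, 24, 24)
line_image = Image.new("RGB", (384, 64))
heatmap = upsample_cam(cam, line_image)
print(heatmap.shape)                      # (64, 384)
```

Swapping the two sizes would still run without error but would stretch the heatmap onto the wrong aspect ratio, which is easy to miss on square test images.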
Implementation in Python
Here’s the complete Python code that accomplishes all the above steps using PyTorch, PIL for image handling, and matplotlib for visualization:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Load the pre-trained processor and model
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# Load and process the image
image_path = "00809.jpg"
image = Image.open(image_path).convert("RGB")
def reshape_transform_vit_huggingface(x):
    activations = x[:, 1:, :]  # drop the [CLS] token
    side_length = int(np.sqrt(activations.shape[1]))  # assumes a square patch grid
    activations = activations.view(activations.shape[0], side_length, side_length, activations.shape[2])
    activations = activations.transpose(2, 3).transpose(1, 2)  # (B, H, W, C) -> (B, C, H, W)
    return activations
def get_generated_tokens(image, model, processor):
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Generate token IDs (generate runs without gradient tracking)
    generated_tokens = model.generate(pixel_values=pixel_values, max_length=50)
    return generated_tokens, pixel_values
last_layer_output = None
def get_activations_and_gradient(pixel_values, model, processor, generated_tokens, layer_name, token_index=0):
    # Decoder inputs are generated_tokens[:, :-1], so the logits at position t predict token t + 1
    target_id = generated_tokens[0, token_index + 1]
    text = processor.decode(target_id, skip_special_tokens=False)
    def save_output(module, input, output):
        global last_layer_output
        last_layer_output = output
        output.retain_grad()  # keep the gradient of this non-leaf tensor
    # Register the hook for this forward pass only, so hooks don't pile up across calls
    hook_handle = layer_name.register_forward_hook(save_output)
    model.zero_grad()
    outputs = model(pixel_values=pixel_values, decoder_input_ids=generated_tokens[:, :-1], return_dict=True)
    hook_handle.remove()
    # Backward pass on the logit of the token predicted at this position
    selected_logit = outputs.logits[0, token_index, target_id]
    selected_logit.backward()
    return last_layer_output, last_layer_output.grad, text
def apply_gradcam(layer_output, grad, image, index, text):
    transform_output = reshape_transform_vit_huggingface(layer_output)
    transform_grad = reshape_transform_vit_huggingface(grad)
    # (a) Channel weights: average the gradients over the spatial dimensions
    weights = torch.mean(transform_grad, dim=(2, 3), keepdim=True)
    # (b) Weighted combination of the activation maps, summed over channels
    grad_cam = torch.sum(weights * transform_output, dim=1, keepdim=True)
    # (c) ReLU: keep only positive contributions
    grad_cam = torch.relu(grad_cam)
    grad_cam = grad_cam.squeeze(0)  # remove the batch dimension -> (1, 24, 24)
    # (d) Normalize to [0, 1]; the epsilon avoids division by zero when the map is empty
    grad_cam = grad_cam / (grad_cam.max() + 1e-8)
    print("Grad-CAM shape:", grad_cam.shape)
    # Upsample to the image size; PIL's image.size is (width, height)
    heatmap = torch.nn.functional.interpolate(grad_cam.unsqueeze(0), size=(image.size[1], image.size[0]), mode='bilinear', align_corners=False)
    heatmap = heatmap.squeeze().detach().numpy()
    # Color the heatmap with the jet colormap and blend it 50/50 with the input image
    colored = Image.fromarray((plt.cm.jet(heatmap) * 255).astype(np.uint8)).convert('RGBA')
    blended = Image.blend(image.convert('RGBA'), colored, alpha=0.5)
    blended.save(f"blended_image_{index}.png", format='PNG')
    return {f"{text}": f"blended_image_{index}.png"}
layer_name = model.encoder.encoder.layer[-1].output  # output submodule of the last ViT layer
generated_tokens, pixel_values = get_generated_tokens(image, model, processor)
print(generated_tokens)
# One Grad-CAM map per decoder position (each position predicts the next generated token)
for index in range(generated_tokens.shape[1] - 1):
    layer_output, grad, text = get_activations_and_gradient(pixel_values, model, processor, generated_tokens, layer_name, token_index=index)
    data = apply_gradcam(layer_output, grad, image, index, text)
    print(data)
Conclusion
By adapting Grad-CAM for use with a transformer model, we can gain insights into which parts of the image the model focuses on when generating text. This technique can be incredibly useful for debugging and improving model performance, particularly in applications like automated content description and OCR.
I hope you found this guide helpful. For more insights and discussions on technology and innovation, feel free to follow me on LinkedIn: www.linkedin.com/in/meetvpatel. I look forward to connecting with you!