BERT: Visualizing Attention

BertViz is an interactive tool for visualizing attention in Transformer models such as BERT and its variants. It displays the self-attention weights of every layer and of each head within multi-head attention. This helps researchers and practitioners inspect how the model relates the tokens of a given input text to one another.
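The weights BertViz draws are the attention probabilities of scaled dot-product attention: for each head, softmax(QK^T / sqrt(d_k)), where Q and K hold the query and key vectors of the tokens and d_k is their dimension. The following is a minimal plain-Python sketch of that matrix for a single head, using made-up toy vectors rather than anything taken from BERT. In the real model, each head of each layer computes its own such matrix from its own learned projections.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Scaled dot-product attention weights, softmax(Q K^T / sqrt(d_k)), row by row.

    queries, keys: one vector per token, all of length d_k.
    Returns a seq_len x seq_len matrix; row i is how token i distributes its attention.
    """
    scale = math.sqrt(len(keys[0]))
    weights = []
    for q in queries:
        # Dot product of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        weights.append(softmax(scores))
    return weights


# Toy query/key vectors for three tokens (illustrative values only)
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

W = attention_weights(Q, K)
for row in W:
    print([round(w, 3) for w in row])
```

Each row sums to 1, which is why BertViz can draw a row as a set of lines whose thickness reflects how one token's attention is split across all tokens.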


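To install BertViz, run the following command in a terminal:

pip install bertviz

In a Jupyter notebook, use the %pip magic instead, which installs the package into the notebook kernel's environment:

%pip install bertviz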


We first import the visualization functions from bertviz and the required classes from transformers, including its utils module, which we use to silence warnings. We then specify the BERT model version (bert-base-uncased) and load the tokenizer and model with the AutoTokenizer and AutoModel classes. Passing output_attentions=True makes the model return its attention weights.

%config InlineBackend.figure_format='retina'

from bertviz import model_view, head_view
from bertviz.neuron_view import show  # Neuron view; expects the model classes from bertviz.transformers_neuron_view
from transformers import AutoTokenizer, AutoModel, utils

utils.logging.set_verbosity_error()  # Suppress standard warnings

model_version = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_version)
model = AutoModel.from_pretrained(model_version, output_attentions=True)  # Return attention weights on every forward pass
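
Next, we define two sentences, sentence_a and sentence_b, as the input text for the visualization. We encode them as a single sentence pair with tokenizer.encode(), which for BERT yields the sequence [CLS] sentence_a [SEP] sentence_b [SEP], and pass the resulting ids to the model. Because the model was loaded with output_attentions=True, its output includes the attention weights: a tuple with one tensor per layer, each of shape (batch_size, num_heads, seq_len, seq_len).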

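Once we have the attention weights and the tokens, we can visualize them with bertviz, which offers three views: head_view shows one or more attention heads within a single layer, model_view gives an overview of every layer and head at once, and the neuron view (bertviz.neuron_view.show) shows the query and key vectors behind an individual head's attention. Each view is a function: head_view and model_view take the attention weights and tokens, while the neuron view takes the model, tokenizer, and raw sentences.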

sentence_a = "I went to the store."
sentence_b = "At the store, I bought fresh strawberries."

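inputs = tokenizer.encode(sentence_a, sentence_b, return_tensors="pt")  # Sentence pair -> ids for [CLS] a [SEP] b [SEP]
outputs = model(inputs)  # Forward pass
attention = outputs.attentions  # Tuple of per-layer attention tensors
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Token strings for batch item 0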
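Because both sentences are encoded as one sequence, it helps to know where sentence_b begins. head_view and model_view accept this position as an optional sentence_b_start argument. The following is a minimal sketch, assuming the BERT pair layout [CLS] a [SEP] b [SEP]. The helper find_sentence_b_start is hypothetical and not part of bertviz; it returns the index right after the first [SEP]. The token list in the demo is hand-written for illustration, not actual tokenizer output.

```python
def find_sentence_b_start(tokens, sep_token="[SEP]"):
    """Return the index of the first token of the second segment, or None.

    Assumes the BERT sentence-pair layout [CLS] a ... [SEP] b ... [SEP],
    so the second segment starts right after the first separator.
    Returns None when there is no second segment.
    """
    if sep_token not in tokens:
        return None
    start = tokens.index(sep_token) + 1
    return start if start < len(tokens) else None


# Hand-written token list in the BERT pair layout (named so it does not overwrite `tokens`)
example_tokens = ["[CLS]", "i", "went", "home", ".", "[SEP]", "then", "i", "slept", ".", "[SEP]"]
example_start = find_sentence_b_start(example_tokens)
print(example_start, example_tokens[example_start])
```

With the real tokens from above, the result could be passed on as head_view(attention, tokens, sentence_b_start=find_sentence_b_start(tokens)), which lets the view filter attention within and between the two sentences.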

For example, to display the head view, pass the attention weights and tokens to head_view:

head_view(attention, tokens)
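head_view draws each row of a head's attention matrix as lines from one token to the tokens it attends to. The same information can also be read numerically. The following is a minimal sketch using a hypothetical helper, top_attended, that operates on a plain nested list. The 4-token matrix is made up for illustration. With the real output, a matrix of this form for batch item 0 would be attention[layer][0, head].tolist().

```python
def top_attended(attn_matrix, tokens, query_index, k=3):
    """Return the k tokens that token `query_index` attends to most, with their weights.

    attn_matrix: seq_len x seq_len nested list for one layer and one head,
    where row i holds token i's attention distribution over all tokens.
    """
    row = attn_matrix[query_index]
    # Rank key positions by attention weight, highest first
    ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    return [(tokens[j], row[j]) for j in ranked[:k]]


# Toy 4-token example with made-up weights (each row sums to 1)
demo_tokens = ["[CLS]", "the", "store", "[SEP]"]
demo_matrix = [
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.10, 0.60, 0.10],
    [0.10, 0.50, 0.30, 0.10],
    [0.40, 0.20, 0.20, 0.20],
]
demo_top = top_attended(demo_matrix, demo_tokens, query_index=2, k=2)
print(demo_top)
```

In this toy matrix, "store" attends most to "the". In the interactive view, that would appear as the thickest line leaving "store".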
