visualize_attention.py
"""
Script for visualizing attention mechanisms within BERT.
"""
import logging
import random
import re
import numpy as np
from tqdm import trange
import torch
from torch.utils.data import DataLoader
from models.finetuned_models import FineTunedBert
from utils.data_utils import IMDBDataset
from utils.model_utils import get_normalized_attention
from utils.visualization_utils import visualize_attention
# Disable unwanted warning messages from pytorch_transformers
# NOTE: Run once without the line below to confirm nothing else is wrong; the goal here is only to
# suppress the message "Token indices sequence length is longer than the specified maximum sequence
# length", since sequence length is already handled inside the tokenize() function
logging.getLogger('pytorch_transformers').setLevel(logging.CRITICAL)
# Specify DEVICE
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("DEVICE FOUND: %s" % DEVICE)
# Set seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Define hyperparameters
PRETRAINED_MODEL_NAME = 'bert-base-cased'
NUM_PRETRAINED_BERT_LAYERS = 12
NUM_ATTENTION_HEADS = 12
MAX_TOKENIZATION_LENGTH = 512
NUM_CLASSES = 2
TOP_DOWN = True
NUM_RECURRENT_LAYERS = 0
HIDDEN_SIZE = 128
REINITIALIZE_POOLER_PARAMETERS = False
USE_BIDIRECTIONAL = False
DROPOUT_RATE = 0.20
AGGREGATE_ON_CLS_TOKEN = True
CONCATENATE_HIDDEN_STATES = False
SAVED_MODEL_PATH = 'finetuned-bert-model-12VA.pt'
APPLY_CLEANING = False
TRUNCATION_METHOD = 'head-only'
NUM_WORKERS = 0
ATTENTION_VISUALIZATION_METHOD = 'custom' # specify which layer, head, and token yourself
LAYER_ID = 11
HEAD_ID = 11
TOKEN_ID = 0
EXCLUDE_SPECIAL_TOKENS = True # exclude [CLS] and [SEP] tokens
NUM_EXAMPLES = 5
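# NOTE: With the 'custom' method, LAYER_ID and HEAD_ID pick out a single attention head and TOKEN_ID
# picks the query token whose attention distribution is visualized; with 12 layers and 12 heads, the
# indices above (11, 11) presumably refer to the last head of the last encoder layer (0-indexed).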
# Initialize model
model = FineTunedBert(pretrained_model_name=PRETRAINED_MODEL_NAME,
                      num_pretrained_bert_layers=NUM_PRETRAINED_BERT_LAYERS,
                      max_tokenization_length=MAX_TOKENIZATION_LENGTH,
                      num_classes=NUM_CLASSES,
                      top_down=TOP_DOWN,
                      num_recurrent_layers=NUM_RECURRENT_LAYERS,
                      use_bidirectional=USE_BIDIRECTIONAL,
                      hidden_size=HIDDEN_SIZE,
                      reinitialize_pooler_parameters=REINITIALIZE_POOLER_PARAMETERS,
                      dropout_rate=DROPOUT_RATE,
                      aggregate_on_cls_token=AGGREGATE_ON_CLS_TOKEN,
                      concatenate_hidden_states=CONCATENATE_HIDDEN_STATES,
                      use_gpu=torch.cuda.is_available())
# Load model weights & assign model to correct device
model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
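# NOTE: This script only runs inference; if get_normalized_attention() does not switch modes itself,
# calling model.eval() here would disable dropout during the forward passes.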
# Get tokenizer
tokenizer = model.get_tokenizer()
# Acquire test iterator through data loader
test_dataset = IMDBDataset(input_directory='data/aclImdb/test',
                           tokenizer=tokenizer,
                           apply_cleaning=APPLY_CLEANING,
                           max_tokenization_length=MAX_TOKENIZATION_LENGTH,
                           truncation_method=TRUNCATION_METHOD,
                           device=DEVICE)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=len(test_dataset),
                         shuffle=True,
                         num_workers=NUM_WORKERS)
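# NOTE: batch_size=len(test_dataset) materializes the whole test split in a single batch so that all
# input ids are available at once below; this is convenient here but assumes the split fits in memory.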
# Get all test movie reviews from test data for attention visualizations
test_input_ids = next(iter(test_loader))[0].tolist()
print("NUMBER OF TEST EXAMPLES: %d" % len(test_input_ids))
for i in trange(NUM_EXAMPLES, desc='Attending to Test Reviews', leave=True):
    example_test_input_ids = test_input_ids[i]
    example_test_sentence = tokenizer.decode(token_ids=example_test_input_ids)

    # Extract the first component in case the tokenizer split the decoded text into two or more pieces
    # NOTE: This usually happens when there are multiple padding ([PAD]) tokens in the text
    if isinstance(example_test_sentence, list):
        example_test_sentence = example_test_sentence[0]

    # Remove all model-induced special tokens to visualize attention weights on the original tokens only
    example_test_sentence = example_test_sentence.replace('[CLS]', '')
    example_test_sentence = example_test_sentence.replace('[SEP]', '')
    example_test_sentence = example_test_sentence.replace('[UNK]', '')
    example_test_sentence = example_test_sentence.replace('[PAD]', '')
    example_test_sentence = example_test_sentence.strip()
    example_test_sentence = re.sub(' +', ' ', example_test_sentence)

    tokens_and_weights = get_normalized_attention(model=model,
                                                  raw_sentence=example_test_sentence,
                                                  method=ATTENTION_VISUALIZATION_METHOD,
                                                  n=LAYER_ID,
                                                  m=HEAD_ID,
                                                  k=TOKEN_ID,
                                                  exclude_special_tokens=EXCLUDE_SPECIAL_TOKENS,
                                                  normalization_method='min-max',
                                                  device=DEVICE)

    visualize_attention(window_name="Attention Visualization of " +
                                    "LAYER ID.: %d, HEAD ID.: %d, TOKEN ID.: %d on EXAMPLE ID.: %d" %
                                    (LAYER_ID, HEAD_ID, TOKEN_ID, i),
                        tokens_and_weights=tokens_and_weights)