-
Notifications
You must be signed in to change notification settings - Fork 0
/
blocklm_utils.py
474 lines (458 loc) · 24.9 KB
/
blocklm_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
import torch
import torch.utils.data
import mpu
import random
import copy
import numpy as np
import math
from utils import print_rank_0
from scipy.stats import poisson
def rindex(lst, val, start=None):
    """Return the index of the last occurrence of ``val`` in ``lst``.

    The search runs backwards from ``start`` (inclusive) down to index 0.

    Args:
        lst: Indexable sequence supporting ``len()`` and ``lst[i]``.
        val: Value compared against each element with ``==``.
        start: Index to begin the backward scan at. ``None`` (the default)
            or any value past the end means "start at the last element".

    Returns:
        The highest index ``i <= start`` with ``lst[i] == val``, or ``-1`` if
        there is no such element (including when ``lst`` is empty).
    """
    last = len(lst) - 1
    # Clamp so an over-large start scans from the end instead of raising
    # IndexError on the first lookup.
    if start is None or start > last:
        start = last
    for i in range(start, -1, -1):
        if lst[i] == val:
            return i
    return -1
def index_in_list(lst, val, start=None):
    """Return the index of the first occurrence of ``val`` in ``lst``.

    Args:
        lst: Indexable sequence supporting ``len()`` and ``lst[i]``.
        val: Value compared against each element with ``==``.
        start: Index to begin the forward scan at. ``None`` (the default) or
            a negative value means "start at index 0".

    Returns:
        The lowest index ``i >= start`` with ``lst[i] == val``, or ``-1`` if
        there is no such element.
    """
    # A negative start would index from the end of the sequence and could
    # return a negative position that callers cannot tell apart from the
    # -1 "not found" sentinel, so treat it as 0.
    if start is None or start < 0:
        start = 0
    for i in range(start, len(lst)):
        if lst[i] == val:
            return i
    return -1
class ConstructBlockStrategy:
    """Turn raw token samples into blank-infilling training batches.

    ``construct_blocks`` is the entry point. For every call it picks one of
    three modes and builds the batch accordingly:

    * ``'bert'``     - mask several short spans sampled from a Poisson-shaped
      length distribution;
    * ``'sentence'`` - mask whole sentences (gap-sentence objective);
    * ``'gpt'``      - mask one long span to be generated.

    The per-mode probabilities are ``bert_prob``, ``gap_sentence_prob`` and
    ``1 - bert_prob - gap_sentence_prob``.
    """

    # Characters whose presence in a decoded token marks a sentence boundary.
    # Non-ASCII entries are written as escapes so they survive re-encoding.
    _SENTENCE_END_CHARS = (
        '.', '?', '!', ';', ':',
        '\u3002',  # ideographic full stop
        '\uff1f',  # fullwidth question mark
        '\uff01',  # fullwidth exclamation mark
        '\uff1b',  # fullwidth semicolon
        '\u2026',  # horizontal ellipsis
        '\n',
    )

    def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
                 gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
                 max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
                 short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
                 shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False):
        """Store the sampling configuration.

        Args:
            args: Object providing ``eod_token`` (end-of-document token id).
            tokenizer: Object providing ``get_command(name).Id``,
                ``IdToToken(id)`` and ``DecodeIds(ids)``.
            max_seq_length: Maximum sequence length; used for random position
                offsets, sentinel position ids and short-sequence splitting.
            bert_prob: Probability of the ``'bert'`` mode.
            gap_sentence_prob: Probability of the ``'sentence'`` mode.
            gpt_infill_prob: In ``'gpt'`` mode, probability of masking the
                tail of the sample rather than a randomly placed span.
            gpt_min_ratio: Lower bound of the generation length, as a
                fraction of the shortest sample in the batch.
            bert_ratio: Fraction of tokens to mask in ``'bert'`` mode.
            gap_sentence_ratio: Fraction of tokens to mask in ``'sentence'``
                mode.
            average_block_length: Mean of the Poisson distribution used for
                span lengths.
            max_block_length: Span lengths are drawn from
                ``1 .. max_block_length - 1``.
            block_mask_prob: Probability of replacing each target-side input
                token with ``dBLOCK`` (``'bert'`` task only).
            context_mask_ratio: Fraction of the text length to replace with
                ``dBLOCK`` in the context around masked spans.
            context_mask_range: Number of context tokens next to a span that
                are eligible for context masking.
            short_seq_prob: Probability of first splitting samples into
                shorter pieces.
            single_span_prob: Probability of masking exactly one span in
                ``'bert'`` mode.
            block_position_encoding: Use increasing intra-block positions
                instead of a constant 1.
            encoder_decoder: Return separate source/target sequences.
            shuffle_blocks: Shuffle the order of target blocks.
            sentinel_token: Use a distinct ``MASK{i}``/``sop{i}`` per span.
            task_mask: Use ``gMASK``/``sMASK`` instead of ``MASK`` for the
                generation / gap-sentence tasks.
            random_position: Add a random offset to position ids.
            masked_lm: Build plain masked-LM data instead of block data.
        """
        self.eod_token = args.eod_token
        self.tokenizer = tokenizer
        self.count = 0
        self.max_seq_length = max_seq_length
        self.rank = mpu.get_data_parallel_rank()
        self.world_size = mpu.get_data_parallel_world_size()
        # self.rank = 0
        # self.world_size = 1
        assert 0.0 <= bert_prob <= 1.0
        self.bert_prob = bert_prob
        self.gap_sentence_prob = gap_sentence_prob
        self.gpt_prob = 1 - bert_prob - gap_sentence_prob
        # Tolerate floating point round-off when the two probabilities sum to 1.
        assert self.gpt_prob >= -1e-10
        self.infill_prob = gpt_infill_prob
        self.gpt_min_ratio = gpt_min_ratio
        self.bert_ratio = bert_ratio
        self.gap_sentence_ratio = gap_sentence_ratio
        # Unnormalised weights for span lengths 1 .. max_block_length - 1.
        self.block_length_distribution = [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)]
        self.block_mask_prob = block_mask_prob
        self.context_mask_ratio = context_mask_ratio
        self.context_mask_range = context_mask_range
        self.short_seq_prob = short_seq_prob
        self.single_span_prob = single_span_prob
        self.block_position_encoding = block_position_encoding
        self.encoder_decoder = encoder_decoder
        self.shuffle_blocks = shuffle_blocks
        self.sentinel_token = sentinel_token
        self.generation_mask = 'gMASK' if task_mask else 'MASK'
        self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id
        self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
        self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id
        self.random_position = random_position
        self.masked_lm = masked_lm
        print_rank_0(
            f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
        print_rank_0(
            f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}")
        print_rank_0(f"block length distribution {self.block_length_distribution}")
        print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")

    def contains_sentence_end(self, tok):
        """Return True if the text of token id ``tok`` contains a sentence-ending character."""
        tok = self.tokenizer.IdToToken(tok)
        return any(char in tok for char in self._SENTENCE_END_CHARS)

    def _sample_block_length(self, rng):
        """Draw one span length from ``block_length_distribution`` using ``rng``."""
        return rng.choices(range(1, len(self.block_length_distribution) + 1),
                           weights=self.block_length_distribution)[0]

    @staticmethod
    def sample_spans(span_lengths, total_length, rng, offset=0):
        """Place spans of the given lengths, in order, inside a region.

        Args:
            span_lengths: Lengths of the spans to place.
            total_length: Length of the region available for placement.
            rng: ``random.Random``-like object providing ``randrange``.
            offset: Absolute index of the start of the region.

        Returns:
            List of ``(start, end)`` pairs (end exclusive), one per span.
            Consecutive spans are separated by at least one token.
        """
        blank_length = total_length - sum(span_lengths)
        # Slack left after reserving one separator token between spans.
        m = blank_length - len(span_lengths) + 1
        places = [rng.randrange(m + 1) for _ in range(len(span_lengths))]
        places.sort()
        spans = []
        for place, span_length in zip(places, span_lengths):
            start = offset + place
            end = offset + place + span_length
            spans.append((start, end))
            # Shift later spans past this one plus its separator.
            offset += span_length + 1
        return spans

    def sample_span_in_document(self, tokens, masked_lengths, rng):
        """Distribute spans over the documents contained in ``tokens``.

        Documents are delimited by ``self.eod_token``. They are processed from
        shortest to longest; every document except the last gets spans up to
        ``bert_ratio`` of its length and the last one takes as many of the
        remaining spans as fit.

        Args:
            tokens: Token ids. NOTE(review): the ``tokens == eod_token``
                comparison presumably expects a numpy array - confirm.
            masked_lengths: Span lengths to place; shuffled in place.
            rng: ``random.Random``-like source of randomness.

        Returns:
            List of ``(start, end)`` spans. It may hold fewer spans than
            ``masked_lengths`` when not all of them fit.
        """
        rng.shuffle(masked_lengths)
        mask_spans = []
        mask_index = 0
        indices = [-1] + np.where(tokens == self.eod_token)[0].tolist()
        last_index = len(tokens)
        documents = []
        for index in reversed(indices):
            start_index = index
            # Never mask the ENC token that opens a document.
            if start_index + 1 < len(tokens) and tokens[start_index + 1] == self.tokenizer.get_command('ENC').Id:
                start_index += 1
            length = last_index - start_index - 1
            # Exclude the final token of the sequence from the last document.
            if last_index == len(tokens) and length > 0:
                length -= 1
            documents.append((start_index + 1, length))
            last_index = index
        documents.sort(key=lambda x: x[1])
        for i, (offset, length) in enumerate(documents):
            if i == len(documents) - 1:
                # Longest document: take every remaining span that still fits,
                # counting one separator token per span.
                current_masked_length, current_count = 0, 0
                while mask_index + current_count < len(masked_lengths) and masked_lengths[
                        mask_index + current_count] + current_masked_length + current_count <= length:
                    current_masked_length += masked_lengths[mask_index + current_count]
                    current_count += 1
                if current_count > 0:
                    spans = self.sample_spans(masked_lengths[mask_index: mask_index + current_count], length, rng,
                                              offset=offset)
                    mask_spans += spans
                if mask_index + current_count < len(masked_lengths) - 1:
                    # Diagnostic output: more than one span could not be placed.
                    print(length, masked_lengths[mask_index:], masked_lengths[:mask_index], indices)
            else:
                current_masked_total = int(length * self.bert_ratio)
                current_masked_length, current_count = 0, 0
                while mask_index + current_count < len(masked_lengths) and masked_lengths[
                        mask_index + current_count] + current_masked_length <= current_masked_total:
                    current_masked_length += masked_lengths[mask_index + current_count]
                    current_count += 1
                if current_count > 0:
                    spans = self.sample_spans(masked_lengths[mask_index:mask_index + current_count], length,
                                              rng, offset=offset)
                    mask_spans += spans
                mask_index += current_count
        return mask_spans

    def make_masked_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'):
        """Build masked-LM data: replace every span with ``MASK`` tokens.

        Args:
            tokens: Token ids of the sample; left unmodified.
            loss_masks: Per-token loss mask of the sample.
            attention_mask: Unused; kept for signature parity with
                ``make_block_data``.
            block_spans: ``(start, end)`` spans to mask.
            rng: Unused.
            task: Unused.

        Returns:
            ``(tokens, targets, loss_masks, position_ids)`` where ``tokens``
            is the masked copy, ``targets`` the original ids and
            ``loss_masks`` is non-zero only inside masked spans.
        """
        position_ids = np.arange(len(tokens), dtype=np.int64)
        targets = copy.deepcopy(tokens)
        # Mask a copy so the caller's sample is not overwritten in place.
        tokens = copy.deepcopy(tokens)
        mask_id = self.tokenizer.get_command('MASK').Id
        mlm_masks = np.zeros(len(tokens), dtype=np.int64)
        for start, end in block_spans:
            for idx in range(start, end):
                tokens[idx] = mask_id
            mlm_masks[start: end] = 1
        loss_masks = loss_masks * mlm_masks
        return tokens, targets, loss_masks, position_ids

    def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'):
        """Build blank-infilling data from a sample and its masked spans.

        The source part is the text with every span replaced by one mask
        token; the target part is, for every span, a start-of-piece token
        followed by the span's tokens.

        Args:
            tokens: Token ids of the sample.
            loss_masks: Per-token loss mask of the sample (not used for the
                returned mask, which is rebuilt from scratch).
            attention_mask: Expected source length, checked with an assert,
                or ``None`` to have the source length returned instead.
            block_spans: ``(start, end)`` spans to mask; reordered in place.
            rng: ``random.Random``-like source of randomness.
            task: ``'bert'``, ``'generation'`` or ``'gap_sentence'``; selects
                the mask token and enables the ``'bert'``-only corruptions.

        Returns:
            If ``self.encoder_decoder``: ``(source_tokens, target_tokens,
            loss_masks)``. Otherwise ``(tokens, targets, loss_masks,
            position_ids)`` plus a trailing ``source_length`` when
            ``attention_mask`` is ``None``. ``position_ids`` has shape
            ``(2, len(tokens))``: absolute positions and intra-block positions.

        Raises:
            RuntimeError: If a target block contains ``self.eod_token``.
        """
        text_length = len(tokens)
        # Every masked span collapses to a single position in the source.
        position_ids = np.ones(len(tokens), dtype=np.int64)
        for start, end in block_spans:
            position_ids[start + 1: end] = 0
        position_ids = np.cumsum(position_ids) - 1
        if self.random_position and position_ids[-1] < self.max_seq_length - 1:
            position_bias = self.max_seq_length - position_ids[-1]
            position_bias = rng.randrange(0, position_bias)
            position_ids = position_ids + position_bias
        if self.encoder_decoder or not self.shuffle_blocks:
            block_spans.sort(key=lambda x: x[0])
        else:
            rng.shuffle(block_spans)
        if self.sentinel_token:
            block_spans = [(start, end, idx) for idx, (start, end) in enumerate(block_spans)]
        else:
            block_spans = [(start, end, 0) for start, end in block_spans]
        target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], []
        for start, end, idx in block_spans:
            sop_token = 'sop' if idx == 0 else f"sop{idx}"
            target_tokens.append([self.tokenizer.get_command(sop_token).Id])
            span_tokens = copy.deepcopy(tokens[start: end])
            if self.block_mask_prob > 0.0 and task == 'bert':
                for sub_idx in range(len(span_tokens)):
                    # Draw from the seeded per-call rng, not the global random
                    # module, so the corruption is reproducible.
                    if rng.random() < self.block_mask_prob:
                        span_tokens[sub_idx] = self.tokenizer.get_command('dBLOCK').Id
            target_tokens.append(span_tokens)
            targets.append(tokens[start: end])
            targets.append([self.tokenizer.get_command('eop').Id])
            if not self.sentinel_token:
                target_position_id = position_ids[start: end]
                target_position_ids.append(target_position_id)
                target_position_ids.append([target_position_id[0]])
            else:
                target_position_ids.append([self.max_seq_length] * (end - start + 1))
            if self.block_position_encoding:
                target_block_position_ids.append(np.arange(1, end - start + 2, dtype=np.int64))
            else:
                target_block_position_ids.append([1] * (end - start + 1))
        # The source side is always assembled in text order.
        block_spans.sort(key=lambda x: x[0])
        source_tokens, source_position_ids, local_spans = [], [], []
        last, current_length = 0, 0
        for start, end, idx in block_spans:
            if task == 'generation':
                mask_id = self.generation_mask
            elif task == 'gap_sentence':
                mask_id = self.gap_sentence_mask
            else:
                mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
                mask_id = self.tokenizer.get_command(mask_token).Id
            # local_spans records the unmasked stretches in source coordinates.
            local_spans.append((current_length, current_length + start - last))
            source_tokens.append(tokens[last: start])
            source_tokens.append([mask_id])
            source_position_ids.append(position_ids[last: start])
            source_position_ids.append([position_ids[start]])
            current_length += start - last + 1
            last = end
        if last < len(tokens):
            local_spans.append((current_length, current_length + len(tokens) - last))
            source_tokens.append(tokens[last:])
            source_position_ids.append(position_ids[last:])
        source_length = sum(map(len, source_tokens))
        if attention_mask is not None:
            assert source_length == attention_mask
        if target_tokens and self.eod_token in np.concatenate(target_tokens).tolist():
            print("Found EOS in target", self.tokenizer.DecodeIds(tokens))
            raise RuntimeError
        if self.encoder_decoder:
            target_tokens = target_tokens + [self.tokenizer.get_command('eop').Id]
            loss_masks = np.ones(len(target_tokens), dtype=np.int64)
            return source_tokens, target_tokens, loss_masks
        else:
            tokens = np.concatenate(source_tokens + target_tokens)
            if task == 'bert' and self.context_mask_ratio > 0:
                mask_candidates = set()
                for start, end in local_spans:
                    if start != 0:
                        local_end = min(end, start + self.context_mask_range)
                        mask_candidates.update(range(start, local_end))
                    if end != 0:
                        local_start = max(start, end - self.context_mask_range)
                        mask_candidates.update(range(local_start, end))
                # random.sample() no longer accepts a set (Python 3.11+) and
                # raises if asked for more items than exist, so sample from a
                # sorted list with a clamped count.
                mask_count = min(len(mask_candidates), int(self.context_mask_ratio * text_length))
                mask_pos = rng.sample(sorted(mask_candidates), mask_count)
                for pos in mask_pos:
                    tokens[pos] = self.tokenizer.get_command('dBLOCK').Id
            targets = np.concatenate(source_tokens + targets)
            loss_masks = np.ones(len(tokens), dtype=np.int64)
            loss_masks[:source_length] = 0
            position_ids = np.concatenate(source_position_ids + target_position_ids)
            block_position_ids = np.concatenate(
                [np.zeros(source_length, dtype=np.int64)] + target_block_position_ids)
            position_ids = np.stack([position_ids, block_position_ids], axis=0)
            if attention_mask is not None:
                return tokens, targets, loss_masks, position_ids
            else:
                return tokens, targets, loss_masks, position_ids, source_length

    def generate_blank_data(self, sample, masked_lengths, attention_mask, rng, task='bert'):
        """Sample span positions for ``sample`` and build the training data.

        Args:
            sample: Dict with ``'text'`` (token ids starting with the ``ENC``
                token) and ``'loss_mask'``.
            masked_lengths: Span lengths to mask; shuffled in place.
            attention_mask: Forwarded to the data builder.
            rng: ``random.Random``-like source of randomness.
            task: Forwarded to ``make_block_data``.

        Returns:
            The tuple produced by ``make_masked_data`` / ``make_block_data``,
            or ``None`` when not every span could be placed.
        """
        rng.shuffle(masked_lengths)
        tokens, loss_masks = sample['text'], sample['loss_mask']
        assert tokens[0] == self.tokenizer.get_command('ENC').Id
        block_spans = self.sample_span_in_document(tokens, masked_lengths, rng)
        if len(block_spans) < len(masked_lengths):
            return None
        if self.masked_lm:
            data = self.make_masked_data(tokens, loss_masks, attention_mask, block_spans, rng)
        else:
            data = self.make_block_data(tokens, loss_masks, attention_mask, block_spans, rng, task=task)
        return data

    def split_samples(self, samples, rng):
        """Cut every sample into several shorter pieces.

        A target length is drawn once per call. Each piece starts and ends on
        a sentence or document boundary where possible and is prefixed with
        the ``ENC`` token (with loss mask 0).

        Args:
            samples: List of dicts with ``'text'`` and ``'loss_mask'``.
            rng: ``random.Random``-like source of randomness.

        Returns:
            List of new sample dicts with the same two keys.
        """
        target_length = rng.randrange(32, self.max_seq_length - 1)
        num_splits = (self.max_seq_length - 1) // target_length
        new_samples = []
        cls_id = self.tokenizer.get_command('ENC').Id
        eos_id = self.tokenizer.get_command('eos').Id
        for sample in samples:
            # Drop the leading token; a fresh ENC is prepended to every piece.
            tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:]
            for _ in range(num_splits):
                if target_length >= len(tokens):
                    new_tokens, new_loss_masks = tokens, loss_masks
                else:
                    random_start = rng.randrange(0, len(tokens) - target_length)
                    # Move the start back to just after a sentence/document end.
                    while random_start > 0 and (tokens[random_start] == eos_id or not (
                            self.contains_sentence_end(tokens[random_start - 1])
                            or tokens[random_start - 1] == eos_id)):
                        random_start -= 1
                    random_end = random_start + target_length
                    # Move the end back to a sentence/document end.
                    while random_end > random_start and not (
                            self.contains_sentence_end(tokens[random_end - 1]) or tokens[random_end - 1] == eos_id):
                        random_end -= 1
                    # Give up on the clean ending if it makes the piece too short.
                    if random_end - random_start < target_length // 2:
                        random_end = random_start + target_length
                    new_tokens, new_loss_masks = tokens[random_start: random_end], loss_masks[random_start: random_end]
                new_tokens = np.concatenate(([cls_id], new_tokens))
                new_loss_masks = np.concatenate(([0], new_loss_masks))
                new_samples.append({'text': new_tokens, 'loss_mask': new_loss_masks})
        return new_samples

    def construct_blocks(self, samples):
        """Collate ``samples`` into one training batch.

        The random generator is seeded from the call counter, the dataloader
        worker id and the data-parallel rank, so each call/worker/rank draws a
        different but deterministic stream.

        Args:
            samples: List of dicts with ``'text'`` and ``'loss_mask'``.

        Returns:
            If ``self.encoder_decoder``: dict with ``'text'``, ``'target'``
            and ``'loss_mask'`` tensors. Otherwise a dict with ``'text'``,
            ``'target'``, ``'loss_mask'``, ``'position_id'`` and
            ``'attention_mask'`` tensors plus ``'mode'`` (``'bert'``,
            ``'sentence'`` or ``'gpt'``). Samples whose spans could not be
            placed are left out of the batch.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        else:
            worker_id, num_workers = 0, 1
        rng = random.Random((self.count * num_workers + worker_id) * self.world_size + self.rank)
        self.count += 1
        token_batch, loss_mask_batch, position_id_batch = [], [], []
        source_batch, target_batch = [], []
        if rng.random() < self.short_seq_prob:
            samples = self.split_samples(samples, rng)
        rand = rng.random()
        single_span = rand < self.single_span_prob
        # A single span always uses the 'bert' branch (when bert_prob > 0).
        rand = 0.0 if single_span else rng.random()
        attention_mask = []
        if rand < self.bert_prob:
            mode = 'bert'
            for sample in samples:
                if single_span:
                    masked_lengths = [self._sample_block_length(rng)]
                    masked_count = masked_lengths[0]
                else:
                    masked_lengths, masked_count = [], 0
                    while masked_count < int(self.bert_ratio * len(sample['text'])):
                        block_length = self._sample_block_length(rng)
                        masked_lengths.append(block_length)
                        masked_count += block_length
                if self.masked_lm:
                    sep = len(sample['text'])
                else:
                    # Source length: each span is replaced by one mask token.
                    sep = len(sample['text']) - masked_count + len(masked_lengths)
                data = self.generate_blank_data(sample, masked_lengths, sep, rng, task='bert')
                if data is not None:
                    if self.encoder_decoder:
                        source_tokens, target_tokens, loss_masks = data
                        source_batch.append(source_tokens)
                        target_batch.append(target_tokens)
                        loss_mask_batch.append(loss_masks)
                    else:
                        tokens, targets, loss_masks, position_ids = data
                        token_batch.append(tokens)
                        target_batch.append(targets)
                        loss_mask_batch.append(loss_masks)
                        position_id_batch.append(position_ids)
                        attention_mask.append(sep)
        elif rand < self.bert_prob + self.gap_sentence_prob:
            mode = 'sentence'
            for sample in samples:
                tokens, loss_masks = sample['text'], sample['loss_mask']
                sentence_spans = []
                last_index = 1 if tokens[0] == self.tokenizer.get_command('ENC').Id else 0
                for i in range(len(tokens)):
                    if self.contains_sentence_end(tokens[i]):
                        if last_index < i + 1:
                            sentence_spans.append((last_index, i + 1))
                        last_index = i + 1
                    elif tokens[i] == self.tokenizer.get_command('eos').Id:
                        last_index = i + 1
                if last_index < len(tokens):
                    sentence_spans.append((last_index, len(tokens)))
                if not sentence_spans and torch.distributed.get_rank() == 0:
                    try:
                        print(self.tokenizer.DecodeIds(tokens[1:]))
                    except IndexError:
                        print(tokens[1:])
                rng.shuffle(sentence_spans)
                block_spans, block_length = [], 0
                for start, end in sentence_spans:
                    block_spans.append((start, end))
                    block_length += end - start
                    if block_length >= int(self.gap_sentence_ratio * len(tokens)):
                        break
                data = self.make_block_data(tokens, loss_masks, None, block_spans, rng, task='gap_sentence')
                tokens, targets, loss_masks, position_ids, sep = data
                token_batch.append(tokens)
                target_batch.append(targets)
                loss_mask_batch.append(loss_masks)
                position_id_batch.append(position_ids)
                attention_mask.append(sep)
        else:
            # start_indices = [index_in_list(sample['loss_mask'], 1) for sample in samples]
            # end_indices = [rindex(sample['loss_mask'], 1) for sample in samples]
            # start_index, end_index = max(start_indices), min(end_indices) - self.min_generation_length
            # if end_index < start_index + 1:
            #     end_index = start_index + 1
            # division = rng.randrange(start_index, end_index)
            mode = 'gpt'
            max_generation_length = rng.randint(int(self.gpt_min_ratio * min(map(lambda x: len(x['text']), samples))),
                                                max(map(lambda x: len(x['text']), samples)) - 2)
            for sample in samples:
                generation_length = min(max_generation_length, len(sample['text']) - 2)
                # Source length: the generated span is replaced by one mask token.
                sep = len(sample['text']) - generation_length + 1
                multiple_doc = index_in_list(sample['text'], self.tokenizer.get_command('eos').Id) not in [-1, len(
                    sample['text']) - 1]
                if multiple_doc or rng.random() < self.infill_prob:
                    # Mask the tail of the sample and generate it left to right.
                    division = len(sample['text']) - generation_length
                    tokens, loss_masks = sample['text'], sample['loss_mask']
                    source_tokens, target_tokens = tokens[:division], tokens[division:]
                    target_masks = loss_masks[division:]
                    tokens = np.concatenate((
                        source_tokens, [self.generation_mask, self.tokenizer.get_command('sop').Id],
                        target_tokens[:-1]))
                    targets = np.concatenate((source_tokens, [self.generation_mask], target_tokens))
                    loss_masks = np.concatenate((np.zeros(len(source_tokens) + 1, dtype=np.int64), target_masks))
                    position_ids = np.arange(len(source_tokens) + len(target_tokens) + 1, dtype=np.int64)
                    # Every generated token shares the position of the mask token.
                    position_ids[len(source_tokens) + 1:] = len(source_tokens)
                    if self.block_position_encoding:
                        block_position_ids = np.concatenate(
                            (np.zeros(len(source_tokens), dtype=np.int64),
                             np.arange(len(target_tokens) + 1, dtype=np.int64)))
                    else:
                        # Source tokens and the mask token get 0, the
                        # len(target_tokens) generated positions get 1, so the
                        # length matches position_ids for np.stack.
                        block_position_ids = np.concatenate((np.zeros(len(source_tokens) + 1, dtype=np.int64),
                                                             np.ones(len(target_tokens), dtype=np.int64)))
                    position_ids = np.stack([position_ids, block_position_ids], axis=0)
                else:
                    data = self.generate_blank_data(sample, [generation_length], sep, rng, task='generation')
                    if data is None:
                        # The span could not be placed: report and skip the
                        # sample, as the 'bert' branch does.
                        print(sample, generation_length, multiple_doc)
                        continue
                    tokens, targets, loss_masks, position_ids = data
                token_batch.append(tokens)
                target_batch.append(targets)
                loss_mask_batch.append(loss_masks)
                position_id_batch.append(position_ids)
                attention_mask.append(sep)
        if self.encoder_decoder:
            return {
                'text': torch.tensor(source_batch, dtype=torch.long),
                'target': torch.tensor(target_batch, dtype=torch.long),
                'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long)}
        else:
            token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch(token_batch, target_batch,
                                                                                           loss_mask_batch,
                                                                                           position_id_batch)
            return {'text': torch.tensor(token_batch, dtype=torch.long),
                    'target': torch.tensor(target_batch, dtype=torch.long),
                    'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long),
                    'position_id': torch.tensor(position_id_batch, dtype=torch.long),
                    'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                    'mode': mode}

    @staticmethod
    def pad_batch(token_batch, target_batch, loss_mask_batch, position_id_batch):
        """Right-pad every sequence in the batch with zeros to a common length.

        Args:
            token_batch: List of 1-D token id arrays.
            target_batch: List of 1-D target id arrays.
            loss_mask_batch: List of 1-D loss mask arrays.
            position_id_batch: List of arrays of shape ``(2, length)``.

        Returns:
            The four lists, padded to the longest sequence in ``token_batch``.
            They are returned unchanged when the batch is empty or all
            sequences already have the same length.
        """
        seq_lengths = list(map(len, token_batch))
        # An empty batch (every sample was skipped) needs no padding.
        if seq_lengths and seq_lengths.count(seq_lengths[0]) != len(seq_lengths):
            max_length = max(seq_lengths)
            token_batch = [np.concatenate((tokens, np.zeros(max_length - len(tokens), dtype=np.int64)))
                           for tokens in token_batch]
            target_batch = [np.concatenate((targets, np.zeros(max_length - len(targets), dtype=np.int64)))
                            for targets in target_batch]
            loss_mask_batch = [np.concatenate((loss_masks, np.zeros(max_length - len(loss_masks), dtype=np.int64)))
                               for loss_masks in loss_mask_batch]
            position_id_batch = [
                np.concatenate((position_ids, np.zeros((2, max_length - position_ids.shape[1]), dtype=np.int64)),
                               axis=1) for position_ids in position_id_batch]
        return token_batch, target_batch, loss_mask_batch, position_id_batch