# NLIandDataset.py
import os
import re
import torch
from torch import nn
from d2l import torch as d2l
#@save
d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

# d2l.download_extract downloads the archive (verifying it against the SHA-1
# hash registered above) and returns the extraction directory. To use a local
# copy instead, point data_dir at it, e.g.:
# data_dir = 'C:\\Users\\RJZhang\\Desktop\\data\\snli_1.0\\snli_1.0'
data_dir = d2l.download_extract('SNLI')
#@save
def read_snli(data_dir, is_train):
    """Read the SNLI dataset into premises, hypotheses, and labels."""
    def extract_text(s):
        # Remove information that will not be used
        s = re.sub(r'\(', '', s)
        s = re.sub(r'\)', '', s)
        # Substitute two or more consecutive whitespace characters with a space
        s = re.sub(r'\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, 'r') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels
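
# Quick illustration of the parse-tree cleanup that extract_text performs.
# The sentence below is a hypothetical example written in SNLI's binary-parse
# style, not a row from the dataset; extract_text is local to read_snli, so
# the same two regex steps are inlined here.
_parse = '( ( A person ) ( ( is ( training ( his horse ) ) ) . ) )'
_text = re.sub(r'\s{2,}', ' ', re.sub(r'\(|\)', '', _parse)).strip()
print(_text)  # A person is training his horse .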
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('premise:', x0)
    print('hypothesis:', x1)
    print('label:', y)
test_data = read_snli(data_dir, is_train=False)
# Count how many examples carry each of the three labels in both splits
for data in [train_data, test_data]:
    print([data[2].count(i) for i in range(3)])
#@save
class SNLIDataset(torch.utils.data.Dataset):
    """A customized dataset to load the SNLI dataset."""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + all_hypothesis_tokens,
                                   min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        # Truncate or pad each token-id sequence to exactly num_steps tokens
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
            for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)
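
# Minimal usage sketch for SNLIDataset, left commented out because building a
# vocabulary over the full training split is slow; num_steps=50 is an
# arbitrary illustrative choice, and the printed label depends on the data.
# demo_set = SNLIDataset(train_data, num_steps=50)
# (p, h), y = demo_set[0]     # fixed-length token-id tensors and the label
# print(p.shape, h.shape, y)  # torch.Size([50]) torch.Size([50]) tensor(...)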
#@save
def load_data_snli(batch_size, num_steps=50):
    """Download the SNLI dataset and return data iterators and the vocabulary."""
    num_workers = d2l.get_dataloader_workers()
    # As above, a local copy can be substituted for the download, e.g.:
    # data_dir = 'C:\\Users\\RJZhang\\Desktop\\data\\snli_1.0\\snli_1.0'
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    # Reuse the training vocabulary so train and test share the same token ids
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab
if __name__ == '__main__':
    train_iter, test_iter, vocab = load_data_snli(128, 50)
    print(len(vocab))
    for X, Y in train_iter:
        print(X[0].shape)  # premise batch:    (batch_size, num_steps)
        print(X[1].shape)  # hypothesis batch: (batch_size, num_steps)
        print(Y.shape)     # label batch:      (batch_size,)
        break