train.py
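"""Fine-tune the Helsinki-NLP/opus-mt-ja-en MarianMT model on card text.

Reads parallel Japanese/English columns from train and validation CSVs via
CardTextDataset (dataset.py), optionally resumes from a saved model directory
(--model), and exports the final model and tokenizer (plus optional per-epoch
checkpoints) under <export>/checkpoints/.
"""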
# Import torch
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
# Import huggingface pre-trained models
from transformers import MarianTokenizer, MarianMTModel
# Import additional libraries
import argparse
import numpy as np
import os
import pickle
from tqdm import tqdm
import warnings
from dataset import CardTextDataset
# Training device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Define and set random states
RANDOM_STATE_DEFAULT = 42
# Default hyperparameters
LR_DEFAULT = 1e-4
EPOCH_DEFAULT = 5
# Pre-trained model name
HF_PRETRAINED_NAME = "Helsinki-NLP/opus-mt-ja-en"
# Additional tokens
TOKENIZER_EXTRA_TOKENS = [
    # Substitutes for character name and trait references
    "<TRAIT>", "<NAME>",
    # Substitutes for trigger icons
    "<SOUL>", "<CHOICE>", "<TREASURE>", "<SALVAGE>", "<STANDBY>",
    "<GATE>", "<BOUNCE>", "<STOCK>", "<SHOT>", "<DRAW>",
    # Tokens for keywords
    "【", "】", "AUTO", "ACT", "CONT", "COUNTER", "CLOCK",
    "トリガー",
]
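# Note: without these additions, Marian's SentencePiece tokenizer would split
# markers such as "<SOUL>" into several subword pieces; registering them as
# whole tokens (and resizing the embedding matrix in main() below) keeps each
# marker atomic.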
def main():
    # Parse the input arguments
    parser = argparse.ArgumentParser()
    # Dataset CSV file paths
    parser.add_argument("--train_csv", required=True)
    parser.add_argument("--val_csv", required=True)
    # Japanese text column name
    parser.add_argument("--ja", required=True)
    # English text column name
    parser.add_argument("--en", required=True)
    # Model and tokenizer path
    parser.add_argument("--model", default=None)
    # Model export path
    parser.add_argument("--export", required=True)
    # Training hyperparameters
    # Training epochs
    parser.add_argument("--epochs", default=EPOCH_DEFAULT, type=int)
    # Batch size
    parser.add_argument("--batch_size", default=1, type=int)
    # Initial learning rate
    parser.add_argument("--lr", default=LR_DEFAULT, type=float)
    # Optional checkpoint every n epochs
    parser.add_argument("--checkpoint", default=-1, type=int)
    # Optional random seed
    parser.add_argument("--seed", default=RANDOM_STATE_DEFAULT, type=int)
    args = parser.parse_args()
    # Seeding
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(args.seed)
    # Load pre-trained model from huggingface
    model_name = HF_PRETRAINED_NAME if args.model is None else args.model
    model = MarianMTModel.from_pretrained(model_name)
    # Truncation is requested per tokenizer call below; passing truncation=True
    # to from_pretrained does not enable it
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    # Add tokens and resize embeddings accordingly
    if args.model is None:
        tokenizer.add_tokens(TOKENIZER_EXTRA_TOKENS)
        model.resize_token_embeddings(len(tokenizer))
    # Load optimizer and scheduler (StepLR halves the learning rate every epoch)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.5)
    # Load train set
    train_set = CardTextDataset(args.train_csv, args.ja, args.en)
    train_dataloader = DataLoader(train_set, args.batch_size, shuffle=True)
    # Load validation set
    val_set = CardTextDataset(args.val_csv, args.ja, args.en)
    val_dataloader = DataLoader(val_set, args.batch_size)
    # Epoch range and per-epoch losses
    training_epochs = range(1, args.epochs + 1)
    train_losses, val_losses = [], []
    # Begin training
    if args.checkpoint <= 0 or args.checkpoint > args.epochs:
        warnings.warn("Checkpoint saving disabled")
    print(f"Training for {args.epochs} epochs on device: {DEVICE}")
    model.to(DEVICE)
    try:
        for epoch in training_epochs:
            # Enter training mode
            model.train()
            train_loss = []
            tqdm_train_loader = tqdm(train_dataloader, leave=False)
            for src_str, tgt_str in tqdm_train_loader:
                src_tokenized = tokenizer(src_str, max_length=512, padding="max_length", truncation=True)
                tgt_tokenized = tokenizer(tgt_str, max_length=512, padding="max_length", truncation=True)
                # Mask padding positions in the labels with -100 so the
                # cross-entropy loss ignores them
                labels = torch.tensor(tgt_tokenized["input_ids"], device=DEVICE)
                labels[labels == tokenizer.pad_token_id] = -100
                kwargs = {
                    "input_ids": torch.tensor(src_tokenized["input_ids"], device=DEVICE),
                    "attention_mask": torch.tensor(src_tokenized["attention_mask"], device=DEVICE),
                    "labels": labels,
                }
                optimizer.zero_grad()
                # Forward pass
                output = model(**kwargs)
                # Backward pass
                output.loss.backward()
                optimizer.step()
                # Save training batch loss and report the running average
                train_loss.append(output.loss.item())
                tqdm_train_loader.set_description(desc=f"Epoch {epoch}: {np.average(train_loss):.5f}")
            # Save average train loss for the epoch
            avg_train_loss = np.average(train_loss)
            # Update scheduler once per epoch
            scheduler.step()
            # Export a checkpoint if required
            if args.checkpoint > 0 and epoch % args.checkpoint == 0:
                model.save_pretrained(os.path.join(args.export, "checkpoints", f"model_checkpoint_{epoch}"))
                tokenizer.save_pretrained(os.path.join(args.export, "checkpoints", f"model_checkpoint_{epoch}"))
            # Enter evaluation mode
            model.eval()
            val_loss = []
            with torch.no_grad():
                for src_str, tgt_str in val_dataloader:
                    src_tokenized = tokenizer(src_str, max_length=512, padding="max_length", truncation=True)
                    tgt_tokenized = tokenizer(tgt_str, max_length=512, padding="max_length", truncation=True)
                    # Mask padding positions in the labels, as in training
                    labels = torch.tensor(tgt_tokenized["input_ids"], device=DEVICE)
                    labels[labels == tokenizer.pad_token_id] = -100
                    kwargs = {
                        "input_ids": torch.tensor(src_tokenized["input_ids"], device=DEVICE),
                        "attention_mask": torch.tensor(src_tokenized["attention_mask"], device=DEVICE),
                        "labels": labels,
                    }
                    # Forward pass
                    output = model(**kwargs)
                    # Save validation batch loss
                    val_loss.append(output.loss.item())
            # Save average validation loss for the epoch
            avg_val_loss = np.average(val_loss)
            train_losses.append(avg_train_loss)
            val_losses.append(avg_val_loss)
            # Report epoch loss
            print(f"Epoch {epoch}: train_loss={avg_train_loss:.5f}, val_loss={avg_val_loss:.5f}")
    except KeyboardInterrupt:
        print("Keyboard interrupt detected, aborting training")
    # Export last model and tokenizer
    model.save_pretrained(os.path.join(args.export, "checkpoints", "model_last"))
    tokenizer.save_pretrained(os.path.join(args.export, "checkpoints", "model_last"))
if __name__ == "__main__":
    main()
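# Example invocation (paths and column names are illustrative, not from the repo):
#   python train.py --train_csv data/train.csv --val_csv data/val.csv \
#       --ja ja_text --en en_text --export runs/opus-ja-en \
#       --epochs 10 --batch_size 8 --checkpoint 2
# The exported directory, e.g. runs/opus-ja-en/checkpoints/model_last, can be
# passed back via --model to resume fine-tuning. A minimal inference sketch
# with the exported weights (same illustrative path):
#   model = MarianMTModel.from_pretrained("runs/opus-ja-en/checkpoints/model_last")
#   tokenizer = MarianTokenizer.from_pretrained("runs/opus-ja-en/checkpoints/model_last")
#   batch = tokenizer(["<NAME>の【AUTO】能力"], return_tensors="pt", padding=True)
#   out = model.generate(**batch)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))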