# utils.py — forked from clovaai/subword-qac (74 lines, 2.33 KB)
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
import os
import time
import math
import torch
from model import LMConfig, LanguageModel
class TrainLogger(object):
    """Accumulates per-batch training loss and counts, and formats
    progress strings for logging.

    Tracks elapsed wall-clock time, number of `add` calls (batches),
    summed loss, and query/token counts since the last `init()`.
    """

    def __init__(self):
        self.init()

    def init(self):
        """Reset every accumulator and restart the wall clock."""
        self.start = time.time()
        self.cnt_add = 0      # number of add() calls (i.e. batches)
        self.tot_loss = 0.    # sum of (per-query loss * n_query)
        self.cnt_query = 0
        self.cnt_token = 0

    def add(self, loss, n_query, n_token):
        """Record one batch: `loss` is the per-query average loss."""
        self.cnt_add += 1
        self.tot_loss += loss * n_query
        self.cnt_query += n_query
        self.cnt_token += n_token

    def average(self):
        """Return (loss per query, loss per token); 0. when a count is empty."""
        def _safe_div(num, den):
            return num / den if den != 0 else 0.
        return _safe_div(self.tot_loss, self.cnt_query), \
               _safe_div(self.tot_loss, self.cnt_token)

    def elapsed_time(self):
        """Seconds since the last init()."""
        return time.time() - self.start

    def print_str(self, time_avg_=False):
        """Format a one-line progress string.

        With time_avg_=True the time is reported as ms per batch,
        otherwise as total elapsed seconds.
        """
        loss_query, loss_token = self.average()
        elapsed = self.elapsed_time()
        if time_avg_:
            time_str = f"{elapsed * 1000. / self.cnt_add:6.2f} ms/batch"
        else:
            time_str = f"{elapsed:6.2f} s"
        return f"{time_str} | loss_query {loss_query:6.2f} | token_ppl {math.exp(loss_token):6.2f}"
def get_params(model):
    """Return a list of the model's parameters that require gradients."""
    return [p for p in model.parameters() if p.requires_grad]
def get_model(model):
    """Unwrap a parallel-wrapper if present.

    If `model` exposes a `.module` attribute (the convention used by
    torch DataParallel-style wrappers), return that; otherwise return
    `model` itself.
    """
    return getattr(model, 'module', model)
def model_save(path, model, optimizer=None):
    """Save model config and weights (and optionally optimizer state) under `path`.

    Writes three files into the directory `path`:
      - config.json   : str() of the model's config object
      - model.pt      : the (unwrapped) model's state_dict
      - optimizer.pt  : the optimizer's state_dict, only when `optimizer` is given
    """
    model_to_save = get_model(model)
    # Original leaked the file handle from a bare open(); use a context manager.
    with open(os.path.join(path, 'config.json'), 'w') as f:
        f.write(str(model_to_save.config))
    # torch.save accepts a filesystem path directly and manages the handle itself,
    # avoiding the unclosed open(..., 'wb') handles of the original.
    torch.save(model_to_save.state_dict(), os.path.join(path, 'model.pt'))
    if optimizer:
        torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pt'))
def model_load(path, model=None, optimizer=None):
    """Load model (and optionally optimizer) state from the directory `path`.

    If `model` is None, a fresh LanguageModel is built from the saved config;
    otherwise the given model is unwrapped and re-initialized from that config
    before its weights are loaded. Returns the loaded (unwrapped) model.
    """
    config = LMConfig(os.path.join(path, 'config.json'))
    if model is None:
        model_to_load = LanguageModel(config)
    else:
        model_to_load = get_model(model)
        # NOTE(review): re-running __init__ on an existing instance is unusual;
        # presumably it rebuilds the submodules to match the saved config — verify.
        model_to_load.__init__(config)
    # torch.load accepts a filesystem path and closes the file itself — the
    # original passed bare open(...) handles that were never closed.
    # map_location keeps storages where they are deserialized (CPU-safe load).
    model_state_dict = torch.load(os.path.join(path, 'model.pt'), map_location=lambda s, l: s)
    model_to_load.load_state_dict(model_state_dict)
    if optimizer:
        optimizer_state_dict = torch.load(os.path.join(path, 'optimizer.pt'), map_location=lambda s, l: s)
        optimizer.load_state_dict(optimizer_state_dict)
    return model_to_load