-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_train.py
39 lines (35 loc) · 1.08 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import matplotlib
matplotlib.use('Agg')
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from args import get_parser
import torch
import random
import datetime
from utils.train_utils import trainIters
def _apply_hoct_config(args):
    """Overwrite parsed CLI options with the fixed Hoct training configuration.

    Every attribute assigned here replaces whatever ``get_parser()`` produced,
    so the corresponding command-line flags have no effect for these fields.

    Args:
        args: The namespace returned by ``parser.parse_args()`` (any object
            that accepts attribute assignment). Mutated in place.
    """
    args.use_gpu = torch.cuda.is_available()
    args.log_term = False
    # Windows UNC network-share path; it is also reused as models_path below.
    args.hoct_dir = r'\\nv-nas01\Data\DME_recurrent\Model\2023_05_17_17'
    args.dataset = 'Hoct'
    args.num_workers = 0
    args.max_epoch = 20
    args.length_clip = 5
    args.batch_size = 2
    args.print_every = 100
    args.maxseqlen = 5  # Equal to the number of labels.
    args.models_path = args.hoct_dir


def _ensure_models_dir(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and
    # tolerates the directory appearing between a check and the creation,
    # which an isdir()/mkdir() pair does not.
    os.makedirs(path, exist_ok=True)


def _seed_everything(seed, use_gpu, gpu_id):
    """Seed Python's and PyTorch's RNGs and optionally select a CUDA device.

    Args:
        seed: Seed applied to ``random``, ``torch`` and, when ``use_gpu`` is
            true, the generator of the selected CUDA device.
        use_gpu: When true, select device ``gpu_id`` and seed its generator.
        gpu_id: CUDA device index passed to ``torch.cuda.set_device``.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    if use_gpu:
        # Select the device first: torch.cuda.manual_seed seeds the *current*
        # device only.
        torch.cuda.set_device(device=gpu_id)
        torch.cuda.manual_seed(seed)


def main():
    """Parse arguments, apply the Hoct configuration, seed RNGs and train."""
    parser = get_parser()
    args = parser.parse_args()
    _apply_hoct_config(args)
    _ensure_models_dir(args.models_path)
    # Name has hour resolution (day_month_year-hour), so two runs started
    # within the same hour get the same model_name.
    # NOTE(review): presumably trainIters uses this for output naming — verify.
    args.model_name = datetime.datetime.now().strftime("%d_%m_%y-%H")
    _seed_everything(args.seed, args.use_gpu, args.gpu_id)
    trainIters(args)


if __name__ == "__main__":
    main()