-
Notifications
You must be signed in to change notification settings - Fork 0
/
new_main.py
62 lines (52 loc) · 1.73 KB
/
new_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from MCTS import MCTS
from ARGS import ARGS
from data import Dataseto
from deep import ConnectNet
from deep import cross_entropy_loss_batch
from state import state
import torch
from deep import NetHandler
import numpy
from Trainer import Trainer
#from blackfire import probe
#probe.initialize()
#probe.enable()
torch.set_printoptions(linewidth=100, precision=2)
numpy.set_printoptions(linewidth=100, precision=2)
net = ConnectNet()
datasett = Dataseto()
testsett = Dataseto()
MC = MCTS(net)
NetHandler = NetHandler(net, ARGS)
NetHandler.train_init()
MC.self_play(datasett, root=state())
MC.self_play(testsett, root=state())
MC.self_play(datasett, root=state())
MC.self_play(testsett, root=state())
MC.self_play(datasett, root=state())
MC.self_play(testsett, root=state())
for i in range(1000):
print(i)
MC.self_play(datasett, root=state())
MC.self_play(testsett, root=state())
trainloader = torch.utils.data.DataLoader(datasett, batch_size=10, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testsett, batch_size=10, shuffle=True, num_workers=2)
NetHandler.train(trainloader)
NetHandler.test_error(testloader)
if (i % 7 == 0):
print("len",len(datasett))
datasett.reset()
testsett.reset()
print("RESET")
#trainloader = torch.utils.data.DataLoader(datasett, batch_size=10, shuffle=True, num_workers=2)
#for i, data in enumerate(trainloader, 0):
# print(data[1])
# print("\n\nnet out")
# print(net.forward(data[0]))
# print("\n\n Ploss")
# print(cross_entropy_loss_batch(net.forward(data[0])[0], data[1]))
# print("single")
# print(net.forward(data[0])[0][0])
# print(data[1][0])
# print(net.PLoss(net.forward(data[0])[0][0], data[1][0]))
#probe.end()