-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_net.py
66 lines (49 loc) · 1.93 KB
/
plot_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python
import numpy as n
import pylab as pl
import math as m
import os
import sys
def getErrorRates(fileName):
    """Parse a training log into training and held-out error rates.

    Lines that do not contain a lowercase 's' are skipped.  Every other
    line contributes ``1 - float(<third whitespace-separated field>)``:
    to the ``'test'`` list when the line contains ``'TEST'``, otherwise to
    the ``'train'`` list.  (The third field is presumably an accuracy, so
    the stored value is an error rate -- TODO confirm against the logger.)

    Args:
        fileName: path of the log file to read.

    Returns:
        dict with keys ``'train'`` and ``'test'``, each a list of floats in
        the order the lines appear in the file.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError: if a kept line has fewer than three fields.
        ValueError: if a kept line's third field is not numeric.
    """
    errRates = {'train': [], 'test': []}
    # The context manager guarantees the handle is closed; iterating the
    # file object avoids loading the whole log into memory at once.
    with open(fileName, 'r') as logFile:
        for line in logFile:
            # NOTE(review): the lowercase-'s' test presumably filters out
            # header/noise lines -- confirm it matches the log format.
            if 's' not in line:
                continue
            key = 'test' if 'TEST' in line else 'train'
            # str.split() with no argument already ignores surrounding
            # whitespace, so no strip() is needed.
            errRates[key].append(1. - float(line.split()[2]))
    return errRates
def _expandTestErrors(testRates, numPoints, testFreq):
    """Stretch held-out error rates to one value per training point.

    Each rate is repeated ``testFreq`` times (this assumes one test line
    was logged per ``testFreq`` training lines -- TODO confirm).  The
    result is then padded with the last rate, or truncated, so that it has
    exactly ``numPoints`` entries.  Returns ``[]`` when ``testRates`` is
    empty.  The caller guarantees ``testFreq >= 1``.
    """
    if not testRates:
        return []
    # Same result as the old row_stack/tile/flatten sequence, in one call.
    expanded = list(n.repeat(testRates, testFreq))
    expanded += [expanded[-1]] * max(0, numPoints - len(expanded))
    return expanded[:numPoints]


def _epochTicks(numCycles, batchesPerEpoch):
    """Return ``(locations, labels)`` for x-axis ticks at epoch boundaries.

    One tick is placed after each complete epoch.  Only every
    ``granularity``-th tick is labelled with its epoch number, which keeps
    the axis to roughly 20 labels.  Both lists are empty when fewer than
    ``batchesPerEpoch`` points were logged.
    """
    locations = list(range(batchesPerEpoch, numCycles + 1, batchesPerEpoch))
    numEpochs = len(locations)
    granularity = max(1, int(m.ceil(numEpochs / 20.)))  # aim for ~20 labels
    if granularity > 1:
        # Round coarse spacings up to a multiple of 10.  Unconditional
        # rounding made short runs (< 10 epochs) show no labels at all.
        granularity = int(m.ceil(granularity / 10.) * 10)
    labels = [str(loc // batchesPerEpoch)
              if i % granularity == granularity - 1 else ''
              for i, loc in enumerate(locations)]
    return locations, labels


def showCost(fileName, testFreq, batchesPerEpoch=50000 // 128):
    """Plot training and held-out error rates parsed from a log file.

    Draws into figure 1; the caller is responsible for ``pl.show()``.

    Args:
        fileName: log file in the format read by ``getErrorRates``.
        testFreq: number of training points each held-out error rate is
            stretched over; must be >= 1.
        batchesPerEpoch: training batches per epoch, used to place epoch
            ticks; must be >= 1.  The default reproduces the value that was
            previously hard-coded (presumably 50000 examples in batches of
            128).  TODO: derive it from the actual training set size.

    Raises:
        ValueError: if ``testFreq`` or ``batchesPerEpoch`` is less than 1.
    """
    if testFreq < 1:
        raise ValueError('testFreq must be at least 1, got %r' % (testFreq,))
    if batchesPerEpoch < 1:
        raise ValueError('batchesPerEpoch must be at least 1, got %r'
                         % (batchesPerEpoch,))

    errRates = getErrorRates(fileName)
    trainErrors = errRates['train']
    numCycles = len(trainErrors)
    if numCycles == 0:
        # Nothing to plot; previously this crashed on x[0] / row_stack([]).
        print('No training error rates found in %s' % fileName)
        return
    testErrors = _expandTestErrors(errRates['test'], numCycles, testFreq)

    pl.figure(1)
    x = range(numCycles)
    # Single parenthesized argument: prints the same under Python 2 and 3.
    print('Plotting Range: %d to %d' % (x[0], x[-1]))
    pl.plot(x, trainErrors, 'k-', label='Training')
    if testErrors:
        pl.plot(x, testErrors, 'r-', label='Held-Out')
    pl.legend()

    tickLocations, tickLabels = _epochTicks(numCycles, batchesPerEpoch)
    if tickLocations:
        # Ticks are labelled with epoch numbers, so name the axis to match.
        pl.xticks(tickLocations, tickLabels)
        pl.xlabel('Epoch')
    else:
        # Shorter than one epoch: keep matplotlib's default batch ticks.
        pl.xlabel('Batches')
    pl.ylabel('Error Rate')
    print('Training points: %d' % numCycles)
    print('Test points: %d' % len(errRates['test']))
if __name__ == "__main__":
    # Usage: plot_net.py LOGFILE [TEST_FREQ]   (TEST_FREQ defaults to 463)
    # sys.exit(<str>) writes the message to stderr and exits with status 1.
    if len(sys.argv) not in (2, 3):
        sys.exit('Usage: %s LOGFILE [TEST_FREQ]' % sys.argv[0])
    logPath = sys.argv[1]
    testFreq = int(sys.argv[2]) if len(sys.argv) == 3 else 463
    if not os.path.exists(logPath):
        sys.exit('File Not Found')
    showCost(logPath, testFreq)
    pl.show()