-
Notifications
You must be signed in to change notification settings - Fork 6
/
SMC.py
75 lines (60 loc) · 2.85 KB
/
SMC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from programGraph import *
from API import *
from pointerNetwork import *
import time
class SMC(Solver):
    """Sequential Monte Carlo search over program graphs.

    Each round starts from the empty ``ProgramGraph`` and runs up to
    ``maximumLength`` extension steps.  At every step each particle is
    extended with objects sampled from ``self.model``.  Every resulting graph
    is passed to ``self._report``.  The population is then resampled with
    weights proportional to ``frequency * exp(-distance)``, where
    ``distance`` comes from ``self.model.distance``.  After a round, the
    particle budget is multiplied by ``exponentialGrowthFactor`` and a new
    round begins.  This repeats until the timeout expires.
    """

    def __init__(self, model, _=None,
                 maximumLength=8,
                 initialParticles=100, exponentialGrowthFactor=2,
                 fitnessWeight=2.):
        """Configure the search.

        Args:
            model: Object providing ``specEncoder``, ``distance`` and
                ``repeatedlySample``; also passed to ``ScopeEncoding``.
            _: Unused placeholder argument.
            maximumLength: Number of extension steps per round.
            initialParticles: Particle budget for the first round.
            exponentialGrowthFactor: Multiplier applied to the particle
                budget after each round.
            fitnessWeight: Stored on the instance.  NOTE(review): not read by
                ``_infer``; confirm whether it was meant to scale the
                distance term of the resampling weights.
        """
        self.maximumLength = maximumLength
        self.initialParticles = initialParticles
        self.exponentialGrowthFactor = exponentialGrowthFactor
        self.fitnessWeight = fitnessWeight
        self.model = model

    def _infer(self, spec, loss, timeout):
        """Run SMC rounds until ``timeout`` seconds have elapsed.

        Results are delivered through ``self._report``; the method itself
        returns ``None``.

        Args:
            spec: Specification.  It is encoded with
                ``self.model.specEncoder`` and also passed to
                ``ScopeEncoding``.
            loss: Unused by this method.
            timeout: Wall-clock budget in seconds, measured with
                ``time.time()``.  It is checked once per extension step,
                after sampling.
        """
        # math and numpy otherwise reach this module only through the
        # wildcard imports; import them explicitly so this method does not
        # depend on what those modules happen to re-export.
        import math
        import numpy as np

        startTime = time.time()
        numberOfParticles = self.initialParticles

        specEncoding = self.model.specEncoder(spec)

        # Maps from an object to its embedding
        objectEncodings = ScopeEncoding(self.model, spec)

        # Memoizes the model's distance per graph, so a graph seen again
        # (e.g. the empty graph at the start of every round) is not
        # re-encoded.
        _distance = {}

        def distance(g):
            if g in _distance:
                return _distance[g]
            se = objectEncodings.encoding(list(g.objects()))
            d = self.model.distance(se, specEncoding)
            _distance[g] = d
            return d

        class Particle():
            """A graph plus how many particles currently sit on it."""

            def __init__(self, graph, frequency):
                self.frequency = frequency
                self.graph = graph
                self.distance = distance(graph)

        while True:
            # Each round restarts from the empty program with the current budget.
            population = [Particle(ProgramGraph([]), numberOfParticles)]

            for _ in range(self.maximumLength):
                # Count how many samples landed on each distinct graph.
                sampleFrequency = {}
                for p in population:
                    # NOTE(review): presumably draws p.frequency samples.
                    # A None sample leaves the graph unchanged.
                    for newObject in self.model.repeatedlySample(specEncoding, p.graph,
                                                                 objectEncodings, p.frequency):
                        if newObject is None:
                            newGraph = p.graph
                        else:
                            newGraph = p.graph.extend(newObject)
                        sampleFrequency[newGraph] = sampleFrequency.get(newGraph, 0) + 1

                if time.time() - startTime >= timeout:
                    return

                for g in sampleFrequency:
                    self._report(g)

                # Nothing was sampled, so there is nothing to resample
                # (max() below would raise on an empty list).  End this round
                # and retry with a larger budget.
                if not sampleFrequency:
                    break

                # Convert graphs to particles
                samples = [Particle(g, f) for g, f in sampleFrequency.items()]

                # Resample with weight frequency * exp(-distance).  Weights
                # are computed in log space and shifted by the maximum for
                # numerical stability.  The max and the normalizing sum are
                # computed once rather than once per element.
                logWeights = [math.log(p.frequency) - p.distance for p in samples]
                maxLogWeight = max(logWeights)
                unnormalized = [math.exp(lw - maxLogWeight) for lw in logWeights]
                total = sum(unnormalized)
                ps = [u / total for u in unnormalized]
                sampleFrequencies = np.random.multinomial(numberOfParticles, ps)

                # Keep only particles that received at least one draw.
                population = []
                for particle, frequency in zip(samples, sampleFrequencies):
                    if frequency > 0:
                        particle.frequency = frequency
                        population.append(particle)

            numberOfParticles *= self.exponentialGrowthFactor