-
Notifications
You must be signed in to change notification settings - Fork 1
/
tree.py
396 lines (273 loc) · 10.7 KB
/
tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
""" This module contains the data representation for the structure of trees.
Note that we separate the structure of the branched tree from its
realization, that would contain things such as charges, positions, etc.
that evolve even when the structure is fixed.
"""
from random import random as uniform
from scipy.sparse import lil_matrix, csr_matrix
from numpy import *
class Tree(object):
    """ Instances of the Tree class contain the topological information
    of a tree discharge: i.e. they encapsulate the relations between different
    segments in a tree but not about locations, conductivities etc.

    Segments are kept in ``self.segments`` in the order they were added; the
    position of a segment in that list is its ``index``, which is also the
    row used for it in every per-segment data array.
    """

    def __init__(self):
        # We must carry a global (tree-level) index to access the parameter
        # data arrays.
        self.n = 0
        self.segments = []
        self.root = None

    def add_segment(self, segment):
        """ Adds a *segment* to this tree. Returns the index of the segment
        inside the tree. The first segment ever added becomes the root. """
        if self.n == 0:
            self.root = segment
        self.segments.append(segment)
        index = self.n
        self.n += 1
        return index

    def parents(self, root_index=0):
        """ Builds an array with the indices to each segment's parent.
        Segments without a parent (the root) get the index *root_index*. """
        p = zeros((self.n,), dtype='i')
        for i, segment in enumerate(self.segments):
            # Explicit test instead of catching AttributeError, which would
            # also hide a genuinely broken parent object.
            if segment.parent is None:
                p[i] = root_index
            else:
                p[i] = segment.parent.index
        return p

    def make_root(self):
        """ Creates a segment node to be root of this tree. """
        root = Segment()
        root.set_tree(self)
        self.root = root
        return root

    def terminals(self):
        """ Finds all segments contained in the tree that do not
        have any children. Returns an integer array with segment indices;
        it is integer-typed even when empty, so it can always be used for
        indexing data arrays. """
        return array([i for i, segment in enumerate(self.segments)
                      if not segment.children], dtype=int)

    def branches(self):
        """ Finds all indices of segments that branch in the tree (i.e. have
        more than one child). Returns an integer array, integer-typed even
        when empty. """
        return array([i for i, segment in enumerate(self.segments)
                      if len(segment.children) > 1], dtype=int)

    def extend(self, indices):
        """ Extends the tree adding one child to each segment indexed
        by *indices*, in that order. This is used to extend a propagating tree.
        The new segments receive consecutive indices starting at ``self.n``.
        """
        for i in indices:
            self.segments[i].add_child(Segment())

    def zeros(self, dim=None):
        """ Returns an array that can hold all the data needed for a variable
        in this tree's segments. For multi-dimension data, use *dim*. """
        if dim is None:
            return zeros((self.n,))
        return zeros((self.n, dim))

    def lengths(self, endpoints):
        """ Returns an array with the segment lengths of the tree, given
        an array with the *endpoints* (one row per segment). A segment's
        length is the distance between its endpoint and its parent's; the
        root is compared with segment 0 (``self.parents()`` default). """
        parents = self.parents()
        return sqrt(sum((endpoints - endpoints[parents, :])**2, axis=1))

    def midpoints(self, endpoints):
        """ Returns an array with the segment midpoints of the tree, given
        an array with the *endpoints*. """
        parents = self.parents()
        return 0.5 * (endpoints + endpoints[parents, :])

    def ohm_matrix(self, endpoints, fix=()):
        """ Builds a matrix M that will provide the evolution of charges
        in every segment of the tree as dq/dt = M . phi, where phi is
        the potential at the center of each segment and '.' is the dot product.
        This function builds the matrix from scratch. Usually it is much
        better to keep updating the matrix as the tree grows.

        * *endpoints* must contain an array with the endpoints.
        * *fix* contains indices of nodes with a fixed charge; their rows
          are set to zero. Usually that means the root node.

        Returns a ``scipy.sparse.csr_matrix`` of shape ``(n, n)``.
        """
        l = self.lengths(endpoints)
        # Only segments with a parent have a meaningful length. The root's
        # inverse length is never read below, so leave it at 0 instead of
        # dividing by zero (which emitted a RuntimeWarning and produced inf).
        has_parent = self.parents(root_index=-1) >= 0
        linv = zeros(l.shape)
        linv[has_parent] = 1.0 / l[has_parent]

        # We build the matrix in LIL format first, later we convert to a
        # format more efficient for matrix-vector multiplications.
        M = lil_matrix((self.n, self.n))
        for segment in self:
            i = segment.index
            m = 0.0
            # Each connection is weighted by the inverse length of the
            # lower (child) segment of the pair.
            for other in segment.children:
                j = other.index
                M[i, j] = linv[j]
                m -= linv[j]
            if segment.parent is not None:
                j = segment.parent.index
                M[i, j] = linv[i]
                m -= linv[i]
            M[i, i] = m

        for f in fix:
            M[f, :] = 0

        return csr_matrix(M)

    def branch_label(self, labels=None, label=1, segment=None):
        """ Returns an array with an integer for each node that is unique
        for the branch where it sits.

        Labels follow a binary-heap numbering: child k of a fork on a branch
        labelled L starts a branch labelled 2*L + k. NOTE: this is only
        collision-free when forks have at most two children.
        """
        if labels is None:
            # int64 because labels double at every branching level; int32
            # overflows after about 30 levels.
            labels = zeros((self.n,), dtype='i8')
        if segment is None:
            segment = self.root

        # Walk down the unbranched chain iteratively; recurse only at forks.
        while True:
            labels[segment.index] = label
            if len(segment.children) != 1:
                break
            segment = segment.children[0]

        for i, c in enumerate(segment.children):
            self.branch_label(labels, label=2 * label + i, segment=c)

        return labels

    def branch_distance(self, endpoints, dist=None, segment=None,
                        lengths=None):
        """ Returns an array with the distance of each node from the
        branching immediately above it. The distance is calculated along the
        branch, summing the lengths of the preceding segments of that branch
        (so the first segment of every branch gets 0). """
        if dist is None:
            dist = zeros((self.n,), dtype='d')
        if segment is None:
            segment = self.root
        if lengths is None:
            lengths = self.lengths(endpoints)

        l = 0
        while True:
            dist[segment.index] = l
            l += lengths[segment.index]
            if len(segment.children) != 1:
                break
            segment = segment.children[0]

        for c in segment.children:
            self.branch_distance(endpoints, dist, segment=c, lengths=lengths)

        return dist

    def reconnects(self, endpoints, rmin=5e-4, dmin=1e-3):
        """ Finds reconnections in a tree.

        Returns True if the endpoint of some terminal segment lies within
        *rmin* of the endpoint of a segment carrying a different branch label,
        ignoring pairs where either segment is within *dmin* (measured along
        its branch) of the start of its branch.
        """
        term = self.terminals()
        rterm = endpoints[term, :]
        labels = self.branch_label()
        lterm = labels[term]
        dist = self.branch_distance(endpoints)
        dterm = dist[term]

        # We look only at node pairs where one of the nodes is a terminal:
        # rows run over all segments, columns over the terminals.
        r2 = sum((rterm[newaxis, :, :] - endpoints[:, newaxis, :])**2, axis=2)
        dlabel = lterm[newaxis, :] - labels[:, newaxis]

        # Close pairs on different branches. These still include nodes right
        # after a branching event, which are close to each other only because
        # they sit near the shared branching point; the distance condition
        # below removes them.
        close = logical_and(dlabel != 0, r2 <= rmin**2)
        away_from_fork = logical_and(dterm[newaxis, :] > dmin,
                                     dist[:, newaxis] > dmin)
        return bool(logical_and(close, away_from_fork).any())

    def save(self, fname):
        """ Saves the tree structure into file *fname* as two integer
        columns: segment index and parent index. """
        parents = self.parents()
        i = arange(self.n)
        savetxt(fname, c_[i, parents], fmt='%d')

    @staticmethod
    def loadtxt(fname):
        """ Loads a tree structure from a txt file written by ``save``
        [DEBUG]. """
        # ndmin=2 keeps a one-row file (single-segment tree) from being
        # collapsed into scalars by unpack.
        indices, parents = loadtxt(fname, unpack=True, ndmin=2)
        return Tree.from_parents(parents)

    @staticmethod
    def from_parents(parents):
        """ Builds a tree from a sequence of parent indices.

        Entry 0 is the root (its value is ignored); every other entry must
        refer to a segment that appears before it. Float input, as returned
        by ``numpy.loadtxt``, is accepted.
        """
        # Cast to int: numpy float scalars cannot index the Python list
        # t.segments.
        parents = asarray(parents).astype(int)
        t = Tree()
        for i in range(parents.shape[0]):
            if i == 0:
                t.make_root()
            else:
                t.segments[parents[i]].add_child(Segment())
        return t

    def __iter__(self):
        return iter(self.segments)
class Segment(object):
    """ A single node of a :class:`Tree`.

    A segment stores only topology: its ``parent`` (``None`` when it has
    none), its list of ``children`` and the ``tree`` it belongs to. The
    attribute ``index`` is created when the segment is attached to a tree
    and selects this segment's entry in per-segment data arrays.
    """

    def __init__(self):
        self.tree = None
        self.parent = None
        self.children = []

    def set_tree(self, tree):
        """ Attaches the segment to *tree*, storing the index it hands out. """
        self.tree = tree
        slot = tree.add_segment(self)
        self.index = slot

    def set_parent(self, parent):
        """ Records *parent* and joins the tree that *parent* belongs to. """
        self.parent = parent
        self.set_tree(parent.tree)

    def get(self, a):
        """ Returns this segment's entry of the per-segment array *a*. """
        idx = self.index
        return a[idx]

    def set(self, a, value):
        """ Stores *value* as this segment's entry of the array *a*. """
        idx = self.index
        a[idx] = value

    def iter_adjacent(self):
        """ Yields the neighbours of this segment: the parent first, when
        there is one, then every child in order. """
        if self.parent is not None:
            yield self.parent
        for seg in self.children:
            yield seg

    def add_child(self, other):
        """ Attaches the :class:`Segment` *other* below this one. The child
        joins this segment's tree (receiving its index) before it is
        appended to ``children``. """
        other.set_parent(self)
        self.children.append(other)
def random_branching_tree(n, p):
    """ Builds a random branched tree. This produces nice pictures and can
    be useful for testing.

    Starting from a root, the oldest unexpanded leaf is expanded *n* times:
    it always gets one child and, with probability *p*, a second one. The
    resulting tree therefore has between n + 1 and 2n + 1 segments.
    """
    from collections import deque

    tree = Tree()
    root = tree.make_root()
    # FIFO queue of segments not yet expanded; deque gives O(1) popleft
    # (list.pop(0) is O(len)).
    leaves = deque([root])
    # range rather than xrange, which does not exist on Python 3.
    for _ in range(n):
        leaf = leaves.popleft()

        # Every expanded leaf has at least one descendant.
        s = Segment()
        leaf.add_child(s)
        leaves.append(s)

        # With probability p it has two children.
        # NOTE: `random` is not imported by name in this module; it
        # presumably resolves to numpy.random via `from numpy import *`.
        if random.uniform() < p:
            s = Segment()
            leaf.add_child(s)
            leaves.append(s)

    return tree
def sample_endpoints(tree):
    """ Gives endpoints to a tree structure. Useful for plotting sample
    trees [DEBUG].

    Returns an array from ``tree.zeros(dim=3)`` holding one endpoint per
    segment. A segment without a parent is placed at the origin. Every other
    segment is displaced from its parent by a step equal to the parent's step
    damped componentwise by (0.9, 1.0, 0.95), plus an offset scaled by
    exp(y_parent / 100). The offset is (0, 0, -1) for an only child, or
    (-1, 0, -1) / (1, 0, -1) for the two children of a fork, plus uniform
    noise in [-0.1, 0.1) per component. Segments with more than two children
    are not supported (``deltav`` has no entry for them).
    """
    r = tree.zeros(dim=3)
    if tree.root is None:
        # Empty tree: nothing to place.
        return r

    deltav = {1: array([[0, 0, -1]]),
              2: array([[-1, 0, -1], [1, 0, -1]])}
    damping = array([0.9, 1.0, 0.95])

    # Depth-first traversal with an explicit stack; the former recursive
    # version hit the recursion limit on trees deeper than ~1000 segments.
    # An entry (segment, v, k, owner) carries the step v of *owner* (the
    # segment whose children list this one was taken from) and the child
    # position k. The child's own step is derived only when the entry is
    # popped. As a result, the random draws happen in exactly the same order
    # as in a recursive pre-order traversal. For the root, k and owner are
    # None and v is used as given.
    stack = [(tree.root, array([0, 0, 0]), None, None)]
    while stack:
        seg, v, k, owner = stack.pop()
        if owner is not None:
            n = len(owner.children)
            lr = owner.get(r)
            v = (v * damping +
                 (deltav[n][k] + random.uniform(-0.1, 0.1, size=3))
                 * exp(lr[1] / 100))

        if seg.parent is None:
            seg.set(r, (0, 0, 0))
        else:
            seg.set(r, seg.parent.get(r) + v)

        # Push children in reverse so that the first child is popped first.
        children = seg.children
        for k in range(len(children) - 1, -1, -1):
            stack.append((children[k], v, k, seg))

    return r
def test():
    """ Draws a random sample tree projected on the x-z plane [DEBUG]. """
    # matplotlib.pyplot instead of the discouraged pylab interface; it
    # provides the same plot/show functions.
    import matplotlib.pyplot as plt

    tree = random_branching_tree(1000, 0.05)
    r = sample_endpoints(tree)
    for segment in tree:
        ep = segment.get(r)
        # Explicit check rather than catching AttributeError, which would
        # also swallow unrelated errors raised inside get(). A segment
        # without a parent is drawn from the origin.
        if segment.parent is None:
            ip = array([0, 0, 0])
        else:
            ip = segment.parent.get(r)
        plt.plot([ip[0], ep[0]], [ip[2], ep[2]], lw=0.8, c='k')
    plt.show()
# Running the module as a script draws a random sample tree.
if __name__ == '__main__':
    test()