-
Notifications
You must be signed in to change notification settings - Fork 5
/
subgraph_sample.py
78 lines (66 loc) · 2.65 KB
/
subgraph_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
from util import load_data
def subgraph_sample(dataset, graph_list, border, nums = 500):
print("Sampling Data")
with open('dataset/%s/sampling.txt' % (dataset), 'w') as f:
f.write(str(len(graph_list))+'\n')
for graph in graph_list:
if graph.nodegroup == 1:
graph.sample_list = []
graph.unsample_list = []
graph.sample_x = []
n = graph.g.number_of_nodes()
K = int(min(border-1, n/2))
f.write(str(K)+'\n')
graph.K = K
for i in range(nums):
sample_idx = np.random.permutation(n)
j = 0
sample_set = set()
wait_set = []
cnt = 0
if (len(graph.neighbors[j]) == 0):
j += 1
wait_set.append(sample_idx[j])
while cnt < K:
if len(wait_set) != 0:
x = wait_set.pop()
else:
break
while x in sample_set:
if len(wait_set) != 0:
x = wait_set.pop()
else:
cnt = K
break
sample_set.add(x)
cnt += 1
wait_set.extend(graph.neighbors[x])
unsample_set = set(range(n)).difference(sample_set)
f.write(str(len(sample_set)) + ' ')
for x in list(sample_set):
f.write(str(x) + ' ')
for x in list(unsample_set):
f.write(str(x) + ' ')
f.write('\n')
else:
f.write('0\n')
print("%s Finished"%(dataset))
if __name__ == '__main__':
for dataset in ['PTC', "PROTEINS", "DD", 'FRANK', "IMDBBINARY"]:
np.random.seed(0)
graphs, num_classes = load_data(dataset, 0)
if dataset == 'PTC':
border = 35
elif dataset == "PROTEINS":
border = 50
elif dataset == "IMDBBINARY":
border = 25
elif dataset == "DD":
border = 400
elif dataset == "FRANK":
border = 22
for i in range(len(graphs)):
if graphs[i].g.number_of_nodes() >= border:
graphs[i].nodegroup += 1
subgraph_sample(dataset, graphs, border)