-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
65 lines (52 loc) · 1.42 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from datetime import datetime
from typing import List
def _model_defaults():
    """Return a fresh dict of the hyperparameters every model starts from.

    A new dict is built on each call so each model owns an independent copy;
    overrides are written into these dicts in place, so sharing one object
    between models would leak one model's override into the others.
    """
    return {
        "blocksize": 16,
        "batchlen": 64,
        "alice_lr": 0.001,
        "eve_lr": 0.001,
    }


# Default configuration; command-line overrides are applied on top of it by
# build_config(), which walks this nested dict.
defaults = {
    # Model selection and global switches.
    "avail_models": ["anc", "cryptonet", "cryptonet_anc"],
    "model": "anc",
    "debug": False,
    "tensorboard": False,
    # Default config options for Trainer
    "dropout": 0.1,
    "training": {
        # Timestamp string captured once, when this module is imported.
        "run": str(datetime.now()),
        "batches": 12800,
        "epochs": 10,
    },
    # Per-model defaults for ANC, Cryptonet and Cryptonet+ANC.
    "anc": _model_defaults(),
    "cryptonet": _model_defaults(),
    "cryptonet_anc": _model_defaults(),
    "save_model": True,
}
def build_config(argv: List[str]):
    """Apply ``name=value`` command-line overrides to the module-level ``defaults``.

    Each argument containing ``=`` is split at its FIRST ``=``, so values may
    themselves contain ``=``. The name is a ``-``-separated path into the
    nested ``defaults`` dict (e.g. ``training-epochs=20`` or
    ``anc-alice_lr=0.0005``). The value is converted to the type of the
    existing default by ``tree_traverse``. Arguments without ``=`` are
    skipped. Leading dashes on the name are ignored, so ``--training-epochs=20``
    works as well; previously they produced empty keys that never matched.

    Args:
        argv: list of argument strings to scan for overrides.

    Returns:
        The updated ``defaults`` dict. It is mutated in place, so any existing
        reference to ``defaults`` also sees the overrides.
    """
    for arg in argv:
        name, sep, value = arg.partition("=")
        if not sep:
            # Not an override (no "="); ignore it.
            continue
        keys = name.lstrip("-").split("-")
        tree_traverse(defaults, keys, value)
    return defaults
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _coerce(current, value):
    """Convert *value* to the type of *current*, the option's existing value.

    Calling the type directly on a string is wrong for two types:
    ``bool("False")`` is ``True``, and ``list("a,b")`` splits into characters.
    Bool strings are therefore parsed explicitly, and list strings are split
    on commas. Non-string values and every other type fall back to
    ``type(current)(value)``.

    Raises:
        ValueError: if *value* is a string that is not a recognised boolean
            spelling, or cannot be parsed by the target type (e.g. ``int``).
    """
    target = type(current)
    if not isinstance(value, str):
        return target(value)
    if target is bool:
        lowered = value.strip().lower()
        if lowered in _TRUE_STRINGS:
            return True
        if lowered in _FALSE_STRINGS:
            return False
        raise ValueError(f"cannot interpret {value!r} as a boolean")
    if target is list:
        # "a, b,c" -> ["a", "b", "c"]; empty items are dropped.
        return [item.strip() for item in value.split(",") if item.strip()]
    return target(value)


def tree_traverse(tree: dict, keys: List[str], value):
    """Set the option addressed by *keys* inside the nested dict *tree*.

    Descends one level per key. Unknown keys are silently ignored, leaving
    *tree* unchanged. When a non-dict (leaf) option is reached, *value* is
    converted to that option's current type and stored. Any extra keys past
    a leaf are ignored.

    Args:
        tree: nested option dict; modified in place.
        keys: path of keys from the root of *tree* to the option, e.g.
            ``["training", "epochs"]``. The list itself is not modified.
        value: the new value, usually a string from the command line.

    Returns:
        *tree* itself (the same object), so recursive calls can reassign it.

    Raises:
        IndexError: if *keys* is empty, including when the path stops at a
            nested dict instead of a leaf option.
        ValueError: if *value* cannot be converted to the option's type.
    """
    # Index/slice instead of pop(0) so the caller's list is not consumed.
    key, rest = keys[0], keys[1:]
    if key not in tree:
        return tree
    current = tree[key]
    if isinstance(current, dict):
        tree[key] = tree_traverse(current, rest, value)
    else:
        tree[key] = _coerce(current, value)
    return tree