# cv_classifier.py
import json

import numpy as np
import torch
from torch_geometric.loader import DataLoader

from experiments.autoencoding.classification_metrics import (
    BestClassPositivesNegatives, ClassAccuracy, ClassAP, ClassAUC,
    ClassPositivesNegatives, CrossEntropyLoss,
)
from experiments.rho_trainer import Trainer
from experiments.utils import Logger
from params import (
    CLASSIFICATION_DATASET, CLASSIFICATION_TRAINER_PARAMS, CLASSIFIER_PARAMS,
    QEMU_FFMPEG_PARAMS,
)
from utils import get_classification_model, get_dataset

# Number of independently seeded training runs.
RUNS = 10

# Datasets the trained classifier is additionally evaluated on (cross-dataset
# generalization); set to [] to skip cross-evaluation.
CROSSEVALUATION_DATASETS = [QEMU_FFMPEG_PARAMS]
# CROSSEVALUATION_DATASETS = []


def set_seed(seed):
    """Seed NumPy and PyTorch, recording the seed in the trainer params."""
    CLASSIFICATION_TRAINER_PARAMS["seed"] = seed
    np.random.seed(CLASSIFICATION_TRAINER_PARAMS["seed"])
    torch.manual_seed(CLASSIFICATION_TRAINER_PARAMS["seed"])
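
# Note: set_seed writes the seed back into CLASSIFICATION_TRAINER_PARAMS, so
# calling set_seed(CLASSIFICATION_TRAINER_PARAMS["seed"] + 1) in the loop
# below advances the seed by exactly one per run, and the seed logged with
# each result row is the one that run actually used.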

EXPERIMENT_NAME = f'{CLASSIFIER_PARAMS["name"]}_{CLASSIFICATION_DATASET["name"]}_{CLASSIFICATION_TRAINER_PARAMS["name"]}'


def write_result(split, variant, classifier, dataset, trainer, result, params):
    """Append one evaluation result as a semicolon-separated row."""
    with open("results/results.csv", "a") as f:
        f.write(f"{split};{variant};{classifier};{dataset};{trainer};{json.dumps(params)};{json.dumps(result)}\n")


def cross_eval(trainer, common_params, datasets, variant):
    """Evaluate the trained model on additional datasets.

    Temporarily swaps the trainer's test loader for each cross-evaluation
    dataset, logs the result, and restores the original loader afterwards.
    """
    prev_loader = trainer.test_loader
    for dataset_params in datasets:
        dataset_params["overwrite_cache"] = False
        dataset = get_dataset(dataset_params)
        trainer.test_loader = DataLoader(dataset, batch_size=trainer.params["batch_size"],
                                         shuffle=False, num_workers=4, prefetch_factor=1)
        write_result(split=dataset_params["name"], variant=variant,
                     result=trainer.evaluate(train_set=False), **common_params)
    trainer.test_loader = prev_loader
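
# Main experiment: train RUNS independently seeded classifiers and log
# train/test metrics twice per run, once for the model at the end of training
# ("End") and once for the saved checkpoint ("Checkpoint"), plus
# cross-dataset results for both variants.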
with Logger("results/" + EXPERIMENT_NAME + "/out.log"):
    for i in range(RUNS):
        print(f"Starting run {i} of {EXPERIMENT_NAME}")
        set_seed(CLASSIFICATION_TRAINER_PARAMS["seed"] + 1)  # fresh seed per run
        CLASSIFICATION_DATASET["overwrite_cache"] = False
        dataset = get_dataset(CLASSIFICATION_DATASET)
        # Derive the model dimensions from the dataset.
        CLASSIFIER_PARAMS["features"] = dataset.get_input_size()
        CLASSIFIER_PARAMS["edge_dim"] = dataset.get_edge_size()
        CLASSIFIER_PARAMS["classes"] = len(dataset.get_classes())
        if CLASSIFIER_PARAMS["classes"] == 2:
            # Binary classification uses a single output unit.
            CLASSIFIER_PARAMS["classes"] = 1
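        # Assumption: with a single output the model emits one logit, and the
        # CrossEntropyLoss imported from classification_metrics (a project
        # class, not torch.nn.CrossEntropyLoss) presumably handles that
        # single-logit binary case.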
        model = get_classification_model(CLASSIFIER_PARAMS)
        losses = [CrossEntropyLoss()]
        metrics = [ClassAccuracy(), ClassPositivesNegatives(), ClassAUC(), ClassAP(), BestClassPositivesNegatives()]
        CLASSIFICATION_TRAINER_PARAMS["experiment_name"] = EXPERIMENT_NAME
        trainer = Trainer(model, losses, metrics, dataset, CLASSIFICATION_TRAINER_PARAMS)
        trainer.train()
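        # The trainer presumably saves a checkpoint during training (e.g. the
        # best validation model); trainer.load_checkpoint() below restores it
        # so the final-epoch model and the checkpointed model can be compared.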
        # Everything identifying this run, logged alongside each result row.
        params = {**CLASSIFICATION_DATASET, **CLASSIFIER_PARAMS, **CLASSIFICATION_TRAINER_PARAMS}
        common_params = {
            "classifier": CLASSIFIER_PARAMS["name"],
            "dataset": CLASSIFICATION_DATASET["name"],
            "trainer": CLASSIFICATION_TRAINER_PARAMS["name"],
            "params": params
        }
        # Evaluate the model as it stands at the end of training ("End").
        write_result(
            "Train", "End", CLASSIFIER_PARAMS["name"], CLASSIFICATION_DATASET["name"],
            CLASSIFICATION_TRAINER_PARAMS["name"], trainer.evaluate(train_set=True), params
        )
        write_result(
            "Test", "End", CLASSIFIER_PARAMS["name"], CLASSIFICATION_DATASET["name"],
            CLASSIFICATION_TRAINER_PARAMS["name"], trainer.evaluate(train_set=False), params
        )
        cross_eval(trainer, common_params, CROSSEVALUATION_DATASETS, "End")

        # Restore the saved checkpoint and repeat the evaluation ("Checkpoint").
        trainer.load_checkpoint()
        write_result(
            "Train", "Checkpoint", CLASSIFIER_PARAMS["name"], CLASSIFICATION_DATASET["name"],
            CLASSIFICATION_TRAINER_PARAMS["name"], trainer.evaluate(train_set=True), params
        )
        write_result(
            "Test", "Checkpoint", CLASSIFIER_PARAMS["name"], CLASSIFICATION_DATASET["name"],
            CLASSIFICATION_TRAINER_PARAMS["name"], trainer.evaluate(train_set=False), params
        )
        cross_eval(trainer, common_params, CROSSEVALUATION_DATASETS, "Checkpoint")

        # Free resources so the next run starts from a clean slate.
        trainer.cleanup()
        del trainer
        del model
        del dataset
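
# Hypothetical follow-up (not part of this script): the logged runs can be
# aggregated with pandas, with column names mirroring write_result above.
#   import pandas as pd
#   df = pd.read_csv("results/results.csv", sep=";", header=None,
#                    names=["split", "variant", "classifier", "dataset",
#                           "trainer", "params", "result"])
#   print(df.groupby(["split", "variant"]).size())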