-
Notifications
You must be signed in to change notification settings - Fork 28
/
submission.py
107 lines (85 loc) · 3.7 KB
/
submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from zipfile import ZipFile
import torch
import numpy as np
from tqdm import tqdm
from config.default import cfg
from lib.datasets.datamodules import DataModule
from lib.models.builder import build_model
from lib.utils.data import data_to_model_device
from transforms3d.quaternions import mat2quat
@dataclass
class Pose:
    """One estimated pose, printable as a single whitespace-separated line.

    The fields are stored exactly as passed in; nothing is validated or
    converted by this class.
    """
    image_name: str  # written first on the output line
    q: np.ndarray  # rotation values (predict() in this file fills it from mat2quat)
    t: np.ndarray  # translation values
    inliers: float  # written last on the line; exact meaning is not visible here

    def __str__(self) -> str:
        """Return ``'<image_name> <q values> <t values> <inliers>'``.

        Floating-point array entries are rendered with six decimals.
        """

        def _render(values: np.ndarray) -> str:
            rendered = np.array2string(
                values,
                formatter={'float': lambda v: f'{v:.6f}'},
                max_line_width=1000,  # wide enough to keep the values on one line
            )
            # array2string wraps the values in '[' and ']'; drop both brackets.
            return rendered[1:-1]

        q_text = _render(self.q)
        t_text = _render(self.t)
        return f'{self.image_name} {q_text} {t_text} {self.inliers}'
def predict(loader, model):
    """Run ``model`` over every batch of ``loader`` and group poses by scene.

    Args:
        loader: iterable of batches. Each batch is indexed with the keys
            ``'scene_id'``, ``'pair_names'`` and ``'inliers'`` after the
            forward pass (``'inliers'`` is presumably filled in by the model;
            confirm against the model code).
        model: callable that takes a batch and returns
            ``(R_batched, t_batched)``, indexable per batch element.

    Returns:
        A ``defaultdict(list)`` mapping each scene id to the list of ``Pose``
        objects estimated for that scene. Batch elements whose rotation or
        translation contains NaN or Inf are skipped.
    """
    results_dict = defaultdict(list)

    for data in tqdm(loader):
        # run inference
        data = data_to_model_device(data, model)
        with torch.no_grad():
            R_batched, t_batched = model(data)

        for i_batch, scene in enumerate(data['scene_id']):
            R = R_batched[i_batch].unsqueeze(0).detach().cpu().numpy()
            t = t_batched[i_batch].reshape(-1).detach().cpu().numpy()
            inliers = data['inliers'][i_batch].item()
            # index 1 of 'pair_names' is used as the query image name
            query_img = data['pair_names'][1][i_batch]

            # ignore frames without a usable pose (e.g. not enough feature
            # matches): any NaN/Inf in either R or t disqualifies the frame,
            # so non-finite rotations never reach mat2quat.
            if not (np.isfinite(R).all() and np.isfinite(t).all()):
                continue

            # populate results_dict
            results_dict[scene].append(
                Pose(
                    image_name=query_img,
                    q=mat2quat(R).reshape(-1),
                    t=t,  # already flattened above
                    inliers=inliers,
                )
            )

    return results_dict
def save_submission(results_dict: dict, output_path: Path):
    """Write the per-scene results into a zip archive at ``output_path``.

    Args:
        results_dict: mapping from scene id to an iterable of objects; each
            object contributes ``str(obj)`` as one line.
        output_path: location of the zip file to create (opened in ``'w'``
            mode, so an existing file is replaced).

    Each scene becomes an archive member named ``pose_<scene>.txt`` whose
    lines are joined with ``'\\n'`` (no trailing newline) and UTF-8 encoded.
    """
    with ZipFile(output_path, 'w') as archive:
        for scene_id, scene_poses in results_dict.items():
            member_name = f'pose_{scene_id}.txt'
            payload = '\n'.join(map(str, scene_poses)).encode('utf-8')
            archive.writestr(member_name, payload)
def eval(args):
    """Build the dataloader and model, run inference and write the submission.

    Args:
        args: namespace with the attributes ``config`` (path to a config file
            merged on top of ``config/datasets/mapfree.yaml``), ``checkpoint``
            (passed to ``build_model``), ``output_root`` (a ``Path``; created
            if missing) and ``split`` (``'test'`` or ``'val'``).

    Raises:
        NotImplementedError: if ``args.split`` is neither ``'test'`` nor
            ``'val'``. The check happens before the global ``cfg`` is touched.

    Side effects:
        Mutates the global ``cfg`` and writes
        ``<output_root>/submission.zip``.
    """
    # Reject unsupported splits up front, before any global config is mutated.
    # NotImplementedError is the exception class; the NotImplemented constant
    # is not callable and would fail with a TypeError here.
    if args.split not in ('test', 'val'):
        raise NotImplementedError(f'Invalid split: {args.split}')

    # Load configs
    cfg.merge_from_file('config/datasets/mapfree.yaml')
    cfg.merge_from_file(args.config)

    # Create dataloader
    cfg.TRAINING.NUM_WORKERS = 8
    if args.split == 'test':
        cfg.TRAINING.BATCH_SIZE = 8
        dataloader = DataModule(cfg, drop_last_val=False).test_dataloader()
    else:  # 'val'
        cfg.TRAINING.BATCH_SIZE = 12
        dataloader = DataModule(cfg, drop_last_val=False).val_dataloader()

    # Create model
    model = build_model(cfg, args.checkpoint)

    # Get predictions from model
    results_dict = predict(dataloader, model)

    # Save predictions to txt per scene within zip
    args.output_root.mkdir(parents=True, exist_ok=True)
    save_submission(results_dict, args.output_root / 'submission.zip')
if __name__ == '__main__':
    # Script entry point: parse the command-line options and hand them to eval().
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', help='path to config file')
    cli.add_argument(
        '--checkpoint',
        default='',
        help='path to model checkpoint (models with learned parameters)',
    )
    cli.add_argument('--output_root', '-o', default=Path('results/'), type=Path)
    cli.add_argument(
        '--split',
        default='test',
        choices=('val', 'test'),
        help='Dataset split to use for evaluation. Choose from test or val. Default: test',
    )
    args = cli.parse_args()
    eval(args)