Added fp16 to inference scripts
Alexanders101 committed Aug 23, 2023
1 parent 824c657 commit 3d65e92
Showing 4 changed files with 23 additions and 9 deletions.
spanet/evaluation.py (8 additions, 2 deletions)
@@ -41,6 +41,7 @@ def load_model(
         event_info_file: Optional[str] = None,
         batch_size: Optional[int] = None,
         cuda: bool = False,
+        fp16: bool = False,
         checkpoint: Optional[str] = None
 ) -> JetReconstructionModel:
     # Load the best-performing checkpoint on validation data
@@ -50,6 +51,8 @@ def load_model(
 
     checkpoint = torch.load(checkpoint, map_location='cpu')
     checkpoint = checkpoint["state_dict"]
+    if fp16:
+        checkpoint = tree_map(lambda x: x.half(), checkpoint)
 
     # Load the options that were used for this run and set the testing-dataset value
     options = Options.load(f"{log_directory}/options.json")
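
For reference, a minimal self-contained sketch of what the conversion above does. The real code uses tree_map (its import is not part of this diff); for a checkpoint whose state dict is a flat mapping of tensors, a plain comprehension is equivalent:

import torch

# Hypothetical stand-in for a loaded checkpoint's state dict.
state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}

# Cast every tensor to half precision, which is what
# tree_map(lambda x: x.half(), ...) does when each leaf is a tensor.
state_dict = {name: tensor.half() for name, tensor in state_dict.items()}

print(state_dict["linear.weight"].dtype)  # torch.float16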
@@ -80,7 +83,8 @@
 def evaluate_on_test_dataset(
         model: JetReconstructionModel,
         progress=progress,
-        return_full_output: bool = False
+        return_full_output: bool = False,
+        fp16: bool = False
 ) -> Union[Evaluation, Tuple[Evaluation, Outputs]]:
     full_assignments = defaultdict(list)
     full_assignment_probabilities = defaultdict(list)
@@ -97,7 +101,9 @@ def evaluate_on_test_dataset(
 
     for batch in dataloader:
         sources = tuple(Source(x[0].to(model.device), x[1].to(model.device)) for x in batch.sources)
-        outputs = model.forward(sources)
+
+        with torch.cuda.amp.autocast(enabled=fp16):
+            outputs = model.forward(sources)
 
         assignment_indices = extract_predictions([
             np.nan_to_num(assignment.detach().cpu().numpy(), -np.inf)
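
A minimal sketch of the autocast pattern introduced above, assuming a CUDA device is available. Inside the context, eligible ops such as linear layers run in fp16 while numerically sensitive ops stay in fp32; with enabled=False the context is a no-op, so a single flag toggles mixed precision without duplicating the forward call:

import torch

model = torch.nn.Linear(8, 2).cuda()
inputs = torch.randn(4, 8, device="cuda")

# enabled=True mirrors running the scripts with --fp16; with enabled=False
# the context does nothing and the forward pass stays in full precision.
with torch.cuda.amp.autocast(enabled=True):
    outputs = model(inputs)

print(outputs.dtype)  # torch.float16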
spanet/network/layers/stacked_encoder.py (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ def forward(self, encoded_vectors: Tensor, padding_mask: Tensor, sequence_mask:
         combined_sequence_mask = torch.cat((particle_sequence_mask, sequence_mask), dim=0)
 
         # -----------------------------------------------------------------------------
-        # Run all of the vectors through transformer encoderx
+        # Run all of the vectors through transformer encoder
         # combined_vectors: [T + 1, B, D]
         # particle_vector: [B, D]
         # encoded_vectors: [T, B, D]
spanet/predict.py (8 additions, 4 deletions)
@@ -70,13 +70,14 @@ def main(log_directory: str,
          event_file: Optional[str],
          batch_size: Optional[int],
          output_vectors: bool,
-         gpu: bool):
-    model = load_model(log_directory, test_file, event_file, batch_size, gpu)
+         gpu: bool,
+         fp16: bool):
+    model = load_model(log_directory, test_file, event_file, batch_size, gpu, fp16=fp16)
 
     if output_vectors:
-        evaluation, full_outputs = evaluate_on_test_dataset(model, return_full_output=True)
+        evaluation, full_outputs = evaluate_on_test_dataset(model, return_full_output=True, fp16=fp16)
     else:
-        evaluation = evaluate_on_test_dataset(model, return_full_output=False)
+        evaluation = evaluate_on_test_dataset(model, return_full_output=False, fp16=fp16)
         full_outputs = None
 
     create_hdf5_output(output_file, model.testing_dataset, evaluation, full_outputs)
@@ -102,6 +103,9 @@ def main(log_directory: str,
 
     parser.add_argument("-g", "--gpu", action="store_true",
                         help="Evaluate network on the gpu.")
+
+    parser.add_argument("-fp16", "--fp16", action="store_true",
+                        help="Use Automatic Mixed Precision for inference.")
 
     parser.add_argument("-v", "--output_vectors", action="store_true",
                         help="Include embedding vectors in output in an additional section of the HDF5.")
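
A small sketch of how the new flag behaves (illustrative, not code from this commit): argparse stores --fp16 as a boolean defaulting to False, so existing invocations keep full-precision inference:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-fp16", "--fp16", action="store_true",
                    help="Use Automatic Mixed Precision for inference.")

print(parser.parse_args([]).fp16)          # False
print(parser.parse_args(["--fp16"]).fp16)  # True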
spanet/test.py (6 additions, 2 deletions)
@@ -228,10 +228,11 @@ def main(
         batch_size: Optional[int],
         lines: int,
         gpu: bool,
+        fp16: bool,
         latex: bool
 ):
-    model = load_model(log_directory, test_file, event_file, batch_size, gpu)
-    evaluation = evaluate_on_test_dataset(model)
+    model = load_model(log_directory, test_file, event_file, batch_size, gpu, fp16=fp16)
+    evaluation = evaluate_on_test_dataset(model, fp16=fp16)
 
     # Flatten predictions
     predictions = list(evaluation.assignments.values())
@@ -269,6 +270,9 @@
     parser.add_argument("-g", "--gpu", action="store_true",
                         help="Evaluate network on the gpu.")
 
+    parser.add_argument("-fp16", "--fp16", action="store_true",
+                        help="Use Automatic Mixed Precision for inference.")
+
     parser.add_argument("-tex", "--latex", action="store_true",
                         help="Output a latex table.")

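Taken together, both entry points now thread the flag through load_model and evaluate_on_test_dataset. A hedged sketch of the Python API after this commit; the directory and file paths below are placeholders, not files from this repository:

from spanet.evaluation import load_model, evaluate_on_test_dataset

# Placeholder paths; substitute a real training log directory and test file.
model = load_model("./spanet_output/version_0", "./data/testing.h5",
                   cuda=True, fp16=True)
evaluation = evaluate_on_test_dataset(model, fp16=True)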
