diff --git a/spanet/evaluation.py b/spanet/evaluation.py index 3ce7468..385bbab 100644 --- a/spanet/evaluation.py +++ b/spanet/evaluation.py @@ -41,6 +41,7 @@ def load_model( event_info_file: Optional[str] = None, batch_size: Optional[int] = None, cuda: bool = False, + fp16: bool = False, checkpoint: Optional[str] = None ) -> JetReconstructionModel: # Load the best-performing checkpoint on validation data @@ -50,6 +51,8 @@ def load_model( checkpoint = torch.load(checkpoint, map_location='cpu') checkpoint = checkpoint["state_dict"] + if fp16: + checkpoint = tree_map(lambda x: x.half(), checkpoint) # Load the options that were used for this run and set the testing-dataset value options = Options.load(f"{log_directory}/options.json") @@ -80,7 +83,8 @@ def load_model( def evaluate_on_test_dataset( model: JetReconstructionModel, progress=progress, - return_full_output: bool = False + return_full_output: bool = False, + fp16: bool = False ) -> Union[Evaluation, Tuple[Evaluation, Outputs]]: full_assignments = defaultdict(list) full_assignment_probabilities = defaultdict(list) @@ -97,7 +101,9 @@ def evaluate_on_test_dataset( for batch in dataloader: sources = tuple(Source(x[0].to(model.device), x[1].to(model.device)) for x in batch.sources) - outputs = model.forward(sources) + + with torch.cuda.amp.autocast(enabled=fp16): + outputs = model.forward(sources) assignment_indices = extract_predictions([ np.nan_to_num(assignment.detach().cpu().numpy(), -np.inf) diff --git a/spanet/network/layers/stacked_encoder.py b/spanet/network/layers/stacked_encoder.py index 3b76ec2..8fbf9a9 100644 --- a/spanet/network/layers/stacked_encoder.py +++ b/spanet/network/layers/stacked_encoder.py @@ -74,7 +74,7 @@ def forward(self, encoded_vectors: Tensor, padding_mask: Tensor, sequence_mask: combined_sequence_mask = torch.cat((particle_sequence_mask, sequence_mask), dim=0) # ----------------------------------------------------------------------------- - # Run all of the vectors through transformer encoderx + # Run all of the vectors through transformer encoder # combined_vectors: [T + 1, B, D] # particle_vector: [B, D] # encoded_vectors: [T, B, D] diff --git a/spanet/predict.py b/spanet/predict.py index e8d6679..7d52b92 100644 --- a/spanet/predict.py +++ b/spanet/predict.py @@ -70,13 +70,14 @@ def main(log_directory: str, event_file: Optional[str], batch_size: Optional[int], output_vectors: bool, - gpu: bool): - model = load_model(log_directory, test_file, event_file, batch_size, gpu) + gpu: bool, + fp16: bool): + model = load_model(log_directory, test_file, event_file, batch_size, gpu, fp16=fp16) if output_vectors: - evaluation, full_outputs = evaluate_on_test_dataset(model, return_full_output=True) + evaluation, full_outputs = evaluate_on_test_dataset(model, return_full_output=True, fp16=fp16) else: - evaluation = evaluate_on_test_dataset(model, return_full_output=False) + evaluation = evaluate_on_test_dataset(model, return_full_output=False, fp16=fp16) full_outputs = None create_hdf5_output(output_file, model.testing_dataset, evaluation, full_outputs) @@ -102,6 +103,9 @@ def main(log_directory: str, parser.add_argument("-g", "--gpu", action="store_true", help="Evaluate network on the gpu.") + + parser.add_argument("-fp16", "--fp16", action="store_true", + help="Use Automatic Mixed Precision for inference.") parser.add_argument("-v", "--output_vectors", action="store_true", help="Include embedding vectors in output in an additional section of the HDF5.") diff --git a/spanet/test.py b/spanet/test.py index b767225..7f76d64 100644 --- a/spanet/test.py +++ b/spanet/test.py @@ -228,10 +228,11 @@ def main( batch_size: Optional[int], lines: int, gpu: bool, + fp16: bool, latex: bool ): - model = load_model(log_directory, test_file, event_file, batch_size, gpu) - evaluation = evaluate_on_test_dataset(model) + model = load_model(log_directory, test_file, event_file, batch_size, gpu, fp16=fp16) + evaluation = evaluate_on_test_dataset(model, fp16=fp16) # Flatten predictions predictions = list(evaluation.assignments.values()) @@ -269,6 +270,9 @@ def main( parser.add_argument("-g", "--gpu", action="store_true", help="Evaluate network on the gpu.") + parser.add_argument("-fp16", "--fp16", action="store_true", + help="Use Automatic Mixed Precision for inference.") + parser.add_argument("-tex", "--latex", action="store_true", help="Output a latex table.")