From 87ea233a41eeefbf7c52c4b8902959c3d4e2d2cb Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 7 Aug 2023 13:15:34 +0200 Subject: [PATCH] Do not store dtype in args when creating the model --- torchmdnet/models/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 3fc7fd3da..cb6475481 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -26,7 +26,7 @@ def create_model(args, prior_model=None, mean=None, std=None): nn.Module: An instance of the TorchMD_Net model. """ args["precision"] = 32 if "precision" not in args else args["precision"] - args["dtype"] = dtype_mapping[args["precision"]] + dtype = dtype_mapping[args["precision"]] shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -38,7 +38,7 @@ def create_model(args, prior_model=None, mean=None, std=None): cutoff_upper=args["cutoff_upper"], max_z=args["max_z"], max_num_neighbors=args["max_num_neighbors"], - dtype=args["dtype"] + dtype=dtype ) # representation network @@ -102,7 +102,7 @@ def create_model(args, prior_model=None, mean=None, std=None): args["embedding_dimension"], activation=args["activation"], reduce_op=args["reduce_op"], - dtype=args["dtype"], + dtype=dtype, ) # combine representation and output network @@ -113,7 +113,7 @@ def create_model(args, prior_model=None, mean=None, std=None): mean=mean, std=std, derivative=args["derivative"], - dtype=args["dtype"], + dtype=dtype, ) return model