Skip to content

Commit

Permalink
Do not store dtype in args when creating the model
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Aug 7, 2023
1 parent c155f18 commit 87ea233
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
nn.Module: An instance of the TorchMD_Net model.
"""
args["precision"] = 32 if "precision" not in args else args["precision"]
args["dtype"] = dtype_mapping[args["precision"]]
dtype = dtype_mapping[args["precision"]]
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
Expand All @@ -38,7 +38,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
max_num_neighbors=args["max_num_neighbors"],
dtype=args["dtype"]
dtype=dtype
)

# representation network
Expand Down Expand Up @@ -102,7 +102,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
args["embedding_dimension"],
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=args["dtype"],
dtype=dtype,
)

# combine representation and output network
Expand All @@ -113,7 +113,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
mean=mean,
std=std,
derivative=args["derivative"],
dtype=args["dtype"],
dtype=dtype,
)
return model

Expand Down

0 comments on commit 87ea233

Please sign in to comment.