diff --git a/examples/ET-QM9.yaml b/examples/ET-QM9.yaml index 24d4ba242..08b27a410 100644 --- a/examples/ET-QM9.yaml +++ b/examples/ET-QM9.yaml @@ -55,4 +55,4 @@ train_size: 110000 trainable_rbf: false val_size: 10000 weight_decay: 0.0 -dtype: float +precision: 32 diff --git a/tests/test_model.py b/tests/test_model.py index e2e5010ee..80e2461ef 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,6 +7,7 @@ from torchmdnet import models from torchmdnet.models.model import create_model from torchmdnet.models import output_modules +from torchmdnet.models.utils import dtype_mapping from utils import load_example_args, create_example_batch @@ -14,11 +15,11 @@ @mark.parametrize("model_name", models.__all__) @mark.parametrize("use_batch", [True, False]) @mark.parametrize("explicit_q_s", [True, False]) -@mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_forward(model_name, use_batch, explicit_q_s, dtype): +@mark.parametrize("precision", [32, 64]) +def test_forward(model_name, use_batch, explicit_q_s, precision): z, pos, batch = create_example_batch() - pos = pos.to(dtype=dtype) - model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype)) + pos = pos.to(dtype=dtype_mapping[precision]) + model = create_model(load_example_args(model_name, prior_model=None, precision=precision)) batch = batch if use_batch else None if explicit_q_s: model(z, pos, batch=batch, q=None, s=None) @@ -28,10 +29,10 @@ def test_forward(model_name, use_batch, explicit_q_s, dtype): @mark.parametrize("model_name", models.__all__) @mark.parametrize("output_model", output_modules.__all__) -@mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_forward_output_modules(model_name, output_model, dtype): +@mark.parametrize("precision", [32,64]) +def test_forward_output_modules(model_name, output_model, precision): z, pos, batch = create_example_batch() - args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype) + args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision) model = create_model(args) model(z, pos, batch=batch) @@ -146,7 +147,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): @mark.parametrize("model_name", models.__all__) def test_gradients(model_name): pl.seed_everything(1234) - dtype = torch.float64 + precision = 64 output_model = "Scalar" # create model and sample batch derivative = output_model in ["Scalar", "EquivariantScalar"] @@ -155,12 +156,12 @@ def test_gradients(model_name): remove_prior=True, output_model=output_model, derivative=derivative, - dtype=dtype, + precision=precision ) model = create_model(args) z, pos, batch = create_example_batch(n_atoms=5) pos.requires_grad_(True) - pos = pos.to(dtype) + pos = pos.to(torch.float64) torch.autograd.gradcheck( model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 ) diff --git a/tests/test_module.py b/tests/test_module.py index 17002d7c8..9631a6996 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -24,7 +24,8 @@ def test_load_model(): @mark.parametrize("model_name", models.__all__) @mark.parametrize("use_atomref", [True, False]) -def test_train(model_name, use_atomref, tmpdir): +@mark.parametrize("precision", [32, 64]) +def test_train(model_name, use_atomref, precision, tmpdir): args = load_example_args( model_name, remove_prior=not use_atomref, @@ -37,6 +38,7 @@ def test_train(model_name, use_atomref, tmpdir): num_layers=2, num_rbf=16, batch_size=8, + precision=precision, ) datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref)) @@ -47,6 +49,6 @@ def test_train(model_name, use_atomref, tmpdir): module = LNNP(args, prior_model=prior) - trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir) + trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"]) trainer.fit(module, datamodule) trainer.test(module, datamodule) diff --git a/tests/test_optimize.py b/tests/test_optimize.py index aeb064b48..c99fe87e9 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -3,7 +3,7 @@ import torch as pt from torchmdnet.models.model import create_model from torchmdnet.optimize import optimize - +from torchmdnet.models.utils import dtype_mapping @mark.parametrize("device", ["cpu", "cuda"]) @mark.parametrize("num_atoms", [10, 100]) @@ -39,6 +39,7 @@ def test_gn(device, num_atoms): "prior_model": None, "output_model": "Scalar", "reduce_op": "add", + "precision": 32, } ref_model = create_model(args).to(device) @@ -47,7 +48,7 @@ def test_gn(device, num_atoms): # Optimize the model model = optimize(ref_model).to(device) - + positions.to(dtype_mapping[args["precision"]]) # Execute the optimize model energy, gradient = model(elements, positions) diff --git a/tests/utils.py b/tests/utils.py index e5fa4ee1e..aa7588d6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,8 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") with open(config_file, "r") as f: args = yaml.load(f, Loader=yaml.FullLoader) - if "dtype" not in args: - args["dtype"] = "float" + if "precision" not in args: + args["precision"] = 32 args["model"] = model_name args["seed"] = 1234 if remove_prior: diff --git a/torchmdnet/data.py b/torchmdnet/data.py index 516158faf..f150d5f8d 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -6,8 +6,37 @@ from pytorch_lightning import LightningDataModule from pytorch_lightning.utilities import rank_zero_warn from torchmdnet import datasets +from torch_geometric.data import Dataset from torchmdnet.utils import make_splits, MissingEnergyException from torch_scatter import scatter +from torchmdnet.models.utils import dtype_mapping + + +class FloatCastDatasetWrapper(Dataset): + def __init__(self, dataset, dtype=torch.float64): + super(FloatCastDatasetWrapper, self).__init__( + dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter + ) + self.dataset = dataset + self.dtype = dtype + + def len(self): + return len(self.dataset) + + def get(self, idx): + data = self.dataset.get(idx) + for key, value in data: + if torch.is_tensor(value) and torch.is_floating_point(value): + setattr(data, key, value.to(self.dtype)) + return data + + def __getattr__(self, name): + # Check if the attribute exists in the underlying dataset + if hasattr(self.dataset, name): + return getattr(self.dataset, name) + raise AttributeError( + f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'" + ) class DataModule(LightningDataModule): @@ -34,6 +63,9 @@ def setup(self, stage): self.dataset = getattr(datasets, self.hparams["dataset"])( self.hparams["dataset_root"], **dataset_arg ) + self.dataset = FloatCastDatasetWrapper( + self.dataset, dtype_mapping[self.hparams["precision"]] + ) self.idx_train, self.idx_val, self.idx_test = make_splits( len(self.dataset), @@ -62,7 +94,7 @@ def val_dataloader(self): loaders = [self._get_dataloader(self.val_dataset, "val")] if ( len(self.test_dataset) > 0 - and (self.trainer.current_epoch+1) % self.hparams["test_interval"] == 0 + and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0 ): loaders.append(self._get_dataloader(self.test_dataset, "test")) return loaders diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 27992e735..8a56dd39b 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -25,8 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None): ------- nn.Module: An instance of the TorchMD_Net model. """ - args["dtype"] = "float32" if "dtype" not in args else args["dtype"] - args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"] + dtype = dtype_mapping[args["precision"]] shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -38,7 +37,7 @@ def create_model(args, prior_model=None, mean=None, std=None): cutoff_upper=args["cutoff_upper"], max_z=args["max_z"], max_num_neighbors=args["max_num_neighbors"], - dtype=args["dtype"] + dtype=dtype ) # representation network @@ -102,7 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None): args["embedding_dimension"], activation=args["activation"], reduce_op=args["reduce_op"], - dtype=args["dtype"], + dtype=dtype, ) # combine representation and output network @@ -113,7 +112,7 @@ def create_model(args, prior_model=None, mean=None, std=None): mean=mean, std=std, derivative=args["derivative"], - dtype=args["dtype"], + dtype=dtype, ) return model diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index c0aa439b1..bdc806742 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -526,4 +526,4 @@ def forward(self, x, v): "sigmoid": nn.Sigmoid, } -dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64} +dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64} diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 6b120df86..e39c48dc8 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -37,7 +37,7 @@ def get_args(): parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy') parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus') parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes') - parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision') + parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision') parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file') parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test') parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)') @@ -67,7 +67,6 @@ def get_args(): parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') # architectural args - parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64') parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')