From aa8c593f73b8bfd16f27fa8073bdcf358a1bf8a6 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Wed, 9 Aug 2023 12:41:40 +0200
Subject: [PATCH] Reproduce the trick used before to test during training

Alas, it requires reloading the dataloaders every epoch when
test_interval > 0.
---
 torchmdnet/data.py          | 21 ++++++++++++---------
 torchmdnet/module.py        | 24 ++++++++----------------
 torchmdnet/scripts/train.py |  3 ++-
 3 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/torchmdnet/data.py b/torchmdnet/data.py
index a7a30d6e4..702df4f46 100644
--- a/torchmdnet/data.py
+++ b/torchmdnet/data.py
@@ -92,10 +92,10 @@ def train_dataloader(self):
 
     def val_dataloader(self):
         loaders = [self._get_dataloader(self.val_dataset, "val")]
-        if (
-            len(self.test_dataset) > 0
-            and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
-        ):
+        # To allow reporting the performance on the test dataset during training,
+        # we send the trainer two dataloaders every few epochs and modify the
+        # validation step to interpret the second dataloader as test data.
+        if self._is_test_during_training_epoch():
             loaders.append(self._get_dataloader(self.test_dataset, "test"))
         return loaders
 
@@ -116,13 +116,16 @@ def mean(self):
     def std(self):
         return self._std
 
-    def _get_dataloader(self, dataset, stage, store_dataloader=True):
-        store_dataloader = (
-            store_dataloader and self.trainer.reload_dataloaders_every_n_epochs <= 0
+    def _is_test_during_training_epoch(self):
+        return (
+            len(self.test_dataset) > 0
+            and self.hparams["test_interval"] > 0
+            and self.trainer.current_epoch > 0
+            and self.trainer.current_epoch % self.hparams["test_interval"] == 0
         )
+
+    def _get_dataloader(self, dataset, stage, store_dataloader=True):
         if stage in self._saved_dataloaders and store_dataloader:
-            # storing the dataloaders like this breaks calls to trainer.reload_train_val_dataloaders
-            # but makes it possible that the dataloaders are not recreated on every testing epoch
             return self._saved_dataloaders[stage]
 
         if stage == "train":
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 084e45bb2..74f6b1553 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -72,12 +72,16 @@ def training_step(self, batch, batch_idx):
 
     def validation_step(self, batch, batch_idx, *args):
         # If args is not empty the first (and only) element is the dataloader_idx
-        # We want to test every couple of epochs, but this is not supported by Lightning.
+        # We want to run the test set every few epochs just for reporting, but this is not supported by Lightning.
         # Instead, we trick it by providing two validation dataloaders and interpreting the second one as test.
+        # The datamodule takes care of sending the second dataloader only when it is needed.
         is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
-        loss_fn = mse_loss if is_val else l1_loss
-        step_type = "val" if is_val else "test"
-        return self.step(batch, loss_fn, step_type)
+        step_kwargs = (
+            {"loss_fn": mse_loss, "stage": "val"}
+            if is_val
+            else {"loss_fn": l1_loss, "stage": "test"}
+        )
+        return self.step(batch, **step_kwargs)
 
     def test_step(self, batch, batch_idx):
         return self.step(batch, l1_loss, "test")
@@ -165,18 +169,6 @@ def optimizer_step(self, *args, **kwargs):
         super().optimizer_step(*args, **kwargs)
         optimizer.zero_grad()
 
-    def on_train_epoch_end(self, training_step_outputs=None):
-        # Handle the resetting of validation dataloaders
-        dm = self.trainer.datamodule
-        if hasattr(dm, "test_dataset") and len(dm.test_dataset) > 0:
-            should_reset = (
-                self.current_epoch % self.hparams.test_interval == 0
-                or (self.current_epoch + 1) % self.hparams.test_interval == 0
-            )
-            if should_reset:
-                # Using the new way to reset dataloaders in PyTorch Lightning v2.0.0
-                self.trainer.validate_loop.setup_data()
-
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
             # construct dict of logged metrics
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 760484047..097346cdf 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -174,7 +174,9 @@ def main():
         logger=_logger,
         precision=args.precision,
         gradient_clip_val=args.gradient_clipping,
-        inference_mode=False
+        inference_mode=False,
+        # Lightning requires reload_dataloaders_every_n_epochs >= 0, so use 0 (never reload) to disable.
+        reload_dataloaders_every_n_epochs=1 if 0 < args.test_interval < args.num_epochs else 0,
     )
 
     trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model)
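
Note for reviewers: below is a minimal, self-contained sketch of the trick this
patch reproduces, assuming pytorch_lightning >= 2.0. ToyModule, ToyData, and the
toy dataset are hypothetical stand-ins, not torchmd-net code; they only show how
a second validation dataloader, recreated each epoch via
reload_dataloaders_every_n_epochs=1, reaches validation_step with
dataloader_idx == 1 and gets reported as test metrics.

    import torch
    import pytorch_lightning as pl
    from torch.nn.functional import mse_loss, l1_loss
    from torch.utils.data import DataLoader, TensorDataset


    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return mse_loss(self.lin(x), y)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            # dataloader_idx 0 is the real validation set; dataloader_idx 1,
            # when present, is the test set smuggled in during training.
            x, y = batch
            stage = "val" if dataloader_idx == 0 else "test"
            loss_fn = mse_loss if stage == "val" else l1_loss
            self.log(f"{stage}_loss", loss_fn(self.lin(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)


    class ToyData(pl.LightningDataModule):
        def __init__(self, test_interval=2):
            super().__init__()
            self.test_interval = test_interval
            x = torch.randn(64, 1)
            ds = TensorDataset(x, 2.0 * x)
            self.train_ds = self.val_ds = self.test_ds = ds

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=16)

        def val_dataloader(self):
            # Return the test dataloader as a second validation dataloader
            # only on epochs where test metrics should be reported.
            loaders = [DataLoader(self.val_ds, batch_size=16)]
            epoch = self.trainer.current_epoch
            if epoch > 0 and epoch % self.test_interval == 0:
                loaders.append(DataLoader(self.test_ds, batch_size=16))
            return loaders


    if __name__ == "__main__":
        trainer = pl.Trainer(
            max_epochs=5,
            # Force val_dataloader() to be called again every epoch so the
            # second loader can appear and disappear between epochs.
            reload_dataloaders_every_n_epochs=1,
            log_every_n_steps=1,
            enable_checkpointing=False,
        )
        trainer.fit(ToyModule(), ToyData())

Because Lightning caches dataloaders by default, reload_dataloaders_every_n_epochs=1
is what lets the second loader come and go between epochs; that per-epoch reload is
the cost the commit message alludes to.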