Reproduce the trick used before to test during training
Alas, it requires reloading the dataloaders every epoch when test_interval > 0.
RaulPPelaez committed Aug 9, 2023
1 parent 625a75a commit aa8c593
Showing 3 changed files with 22 additions and 26 deletions.
21 changes: 12 additions & 9 deletions torchmdnet/data.py
@@ -92,10 +92,10 @@ def train_dataloader(self):

def val_dataloader(self):
loaders = [self._get_dataloader(self.val_dataset, "val")]
if (
len(self.test_dataset) > 0
and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
):
# To allow to report the performance on the testing dataset during training
# we send the trainer two dataloaders every few steps and modify the
# validation step to understand the second dataloader as test data.
if self._is_test_during_training_epoch():
loaders.append(self._get_dataloader(self.test_dataset, "test"))
return loaders

@@ -116,13 +116,16 @@ def mean(self):
def std(self):
return self._std

def _get_dataloader(self, dataset, stage, store_dataloader=True):
store_dataloader = (
store_dataloader and self.trainer.reload_dataloaders_every_n_epochs <= 0
def _is_test_during_training_epoch(self):
return (
len(self.test_dataset) > 0
and self.hparams["test_interval"] > 0
and self.trainer.current_epoch > 0
and self.trainer.current_epoch % self.hparams["test_interval"] == 0
)

def _get_dataloader(self, dataset, stage, store_dataloader=True):
if stage in self._saved_dataloaders and store_dataloader:
# storing the dataloaders like this breaks calls to trainer.reload_train_val_dataloaders
# but makes it possible that the dataloaders are not recreated on every testing epoch
return self._saved_dataloaders[stage]

if stage == "train":
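
As a reference for the mechanism used above, here is a minimal, hedged sketch of the two-dataloader trick in isolation: a LightningDataModule whose val_dataloader() appends the test loader only on epochs where intermediate test metrics should be reported, mirroring _is_test_during_training_epoch. ToyDataModule, its random tensors, and the batch size are illustrative assumptions, not the repository's code.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyDataModule(pl.LightningDataModule):
    def __init__(self, test_interval=2):
        super().__init__()
        self.test_interval = test_interval
        # Toy stand-ins for the real datasets.
        self.train_dataset = TensorDataset(torch.randn(64, 1), torch.randn(64, 1))
        self.val_dataset = TensorDataset(torch.randn(32, 1), torch.randn(32, 1))
        self.test_dataset = TensorDataset(torch.randn(32, 1), torch.randn(32, 1))

    def _make_loader(self, dataset):
        return DataLoader(dataset, batch_size=8)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        loaders = [self._make_loader(self.val_dataset)]
        epoch = self.trainer.current_epoch
        # Only append the test loader on epochs where the test metrics
        # should be reported, as _is_test_during_training_epoch does above.
        if (
            len(self.test_dataset) > 0
            and self.test_interval > 0
            and epoch > 0
            and epoch % self.test_interval == 0
        ):
            loaders.append(self._make_loader(self.test_dataset))
        return loaders
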
24 changes: 8 additions & 16 deletions torchmdnet/module.py
@@ -72,12 +72,16 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch, batch_idx, *args):
# If args is not empty the first (and only) element is the dataloader_idx
# We want to test every couple of epochs, but this is not supported by Lightning.
# We want to test every test_interval epochs just for reporting, but this is not supported by Lightning.
# Instead, we trick it by providing two validation dataloaders and interpreting the second one as test.
# The dataloader takes care of sending the two dataloaders only when the second one is needed.
is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
loss_fn = mse_loss if is_val else l1_loss
step_type = "val" if is_val else "test"
return self.step(batch, loss_fn, step_type)
step_type = (
{"loss_fn": mse_loss, "stage": "val"}
if is_val
else {"loss_fn": l1_loss, "stage": "test"}
)
return self.step(batch, **step_type)

def test_step(self, batch, batch_idx):
return self.step(batch, l1_loss, "test")
@@ -165,18 +169,6 @@ def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
optimizer.zero_grad()

def on_train_epoch_end(self, training_step_outputs=None):
# Handle the resetting of validation dataloaders
dm = self.trainer.datamodule
if hasattr(dm, "test_dataset") and len(dm.test_dataset) > 0:
should_reset = (
self.current_epoch % self.hparams.test_interval == 0
or (self.current_epoch + 1) % self.hparams.test_interval == 0
)
if should_reset:
# Using the new way to reset dataloaders in PyTorch Lightning v2.0.0
self.trainer.validate_loop.setup_data()

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
# construct dict of logged metrics
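
The model side of the same trick, again as a hedged sketch rather than torchmd-net's actual module: when Lightning passes a dataloader_idx to validation_step, index 0 is treated as validation (MSE loss) and index 1 as the test set (L1 loss), so test metrics get logged during training without a separate test loop. ToyModel and its single linear layer are stand-ins for illustration.

import torch
from torch.nn.functional import mse_loss, l1_loss
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = mse_loss(self.linear(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, *args):
        # If args is non-empty, its single element is the dataloader_idx.
        is_val = len(args) == 0 or args[0] == 0
        loss_fn = mse_loss if is_val else l1_loss
        stage = "val" if is_val else "test"
        x, y = batch
        loss = loss_fn(self.linear(x), y)
        self.log(f"{stage}_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
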
3 changes: 2 additions & 1 deletion torchmdnet/scripts/train.py
@@ -174,7 +174,8 @@ def main():
logger=_logger,
precision=args.precision,
gradient_clip_val=args.gradient_clipping,
inference_mode=False
inference_mode=False,
reload_dataloaders_every_n_epochs=1 if args.test_interval > 0 and args.test_interval < args.num_epochs else -1,
)

trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model)
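
A hedged sketch of how the pieces fit together, reusing ToyDataModule and ToyModel from the sketches above. Note that the diff passes -1 to disable reloading; recent Lightning releases validate reload_dataloaders_every_n_epochs as a non-negative integer, so this sketch uses 0 instead, and the test_interval/num_epochs literals stand in for the parsed command-line arguments.

import pytorch_lightning as pl

test_interval, num_epochs = 2, 6  # stand-ins for args.test_interval / args.num_epochs
trainer = pl.Trainer(
    max_epochs=num_epochs,
    # Reload the dataloaders every epoch only when intermediate testing is on,
    # so val_dataloader() can re-decide whether to include the test loader.
    reload_dataloaders_every_n_epochs=1 if 0 < test_interval < num_epochs else 0,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(ToyModel(), ToyDataModule(test_interval=test_interval))
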
