Blacken
RaulPPelaez committed Aug 9, 2023
1 parent aa8c593 commit 07cab4e
Showing 1 changed file with 30 additions and 17 deletions.
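The whitespace and line-break changes below are consistent with running the Black formatter over the file. A minimal sketch of reproducing that reformatting through Black's Python API follows; the Black version and the exact invocation are assumptions, not recorded in this commit, and a different Black version could produce slightly different output.

import black

with open("torchmdnet/module.py") as f:
    source = f.read()

# format_str applies Black's default style; on an already-Blackened file the
# output equals the input.
formatted = black.format_str(source, mode=black.Mode())
print(formatted == source)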
47 changes: 30 additions & 17 deletions torchmdnet/module.py
@@ -13,9 +13,9 @@ class LNNP(LightningModule):
"""
Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
"""

def __init__(self, hparams, prior_model=None, mean=None, std=None):
super(LNNP, self).__init__()

if "charge" not in hparams:
hparams["charge"] = False
if "spin" not in hparams:
@@ -57,14 +57,15 @@ def configure_optimizers(self):
}
return [optimizer], [lr_scheduler]
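For reference, a hedged sketch of the optimizer/scheduler pairing convention that PyTorch Lightning expects from configure_optimizers; the optimizer choice and hyperparameter values below are placeholders, not the ones used in this repository.

import torch

def configure_optimizers_sketch(model: torch.nn.Module):
    # Placeholder optimizer and scheduler; Lightning consumes the returned
    # ([optimizers], [lr_scheduler dicts]) pair.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=10)
    lr_scheduler = {
        "scheduler": scheduler,
        "monitor": "val_loss",
        "interval": "epoch",
        "frequency": 1,
    }
    return [optimizer], [lr_scheduler]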

def forward(self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor]]:
def forward(
self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)
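A hedged usage sketch of this forward signature; the helper name, the untyped lnnp argument, and the assumption that an LNNP instance is already constructed are illustrative, not part of this diff.

from typing import Dict, Optional, Tuple
from torch import Tensor

def predict(lnnp, z: Tensor, pos: Tensor, batch: Optional[Tensor] = None,
            extra_args: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Tensor]]:
    # lnnp is assumed to be an already constructed LNNP; q and s stay None
    # when total charge and spin are not used.
    return lnnp(z, pos, batch=batch, q=None, s=None, extra_args=extra_args)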

def training_step(self, batch, batch_idx):
@@ -89,7 +90,7 @@ def test_step(self, batch, batch_idx):
def step(self, batch, loss_fn, stage):
with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
extra_args = batch.to_dict()
for a in ('y', 'neg_dy', 'z', 'pos', 'batch', 'q', 's'):
for a in ("y", "neg_dy", "z", "pos", "batch", "q", "s"):
if a in extra_args:
del extra_args[a]
# TODO: the model doesn't necessarily need to return a derivative once
@@ -100,10 +101,9 @@ def step(self, batch, loss_fn, stage):
batch=batch.batch,
q=batch.q if self.hparams.charge else None,
s=batch.s if self.hparams.spin else None,
extra_args=extra_args
extra_args=extra_args,
)
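The loop above strips the keys that are passed to forward() explicitly, so only genuinely extra per-sample fields travel through extra_args. A small, hedged sketch of the same filtering written as a dict comprehension; the function name is illustrative and not part of this diff.

from typing import Dict
from torch import Tensor

def split_extra_args(batch_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # Keys consumed directly by forward(); everything else is forwarded
    # to the model via extra_args.
    known = ("y", "neg_dy", "z", "pos", "batch", "q", "s")
    return {k: v for k, v in batch_dict.items() if k not in known}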


loss_y, loss_neg_dy = 0, 0
if self.hparams.derivative:
if "y" not in batch:
@@ -186,13 +186,21 @@ def on_validation_epoch_end(self):
# if prediction and derivative are present, also log them separately
if len(self.losses["train_y"]) > 0 and len(self.losses["train_neg_dy"]) > 0:
result_dict["train_loss_y"] = torch.stack(self.losses["train_y"]).mean()
result_dict["train_loss_neg_dy"] = torch.stack(self.losses["train_neg_dy"]).mean()
result_dict["train_loss_neg_dy"] = torch.stack(
self.losses["train_neg_dy"]
).mean()
result_dict["val_loss_y"] = torch.stack(self.losses["val_y"]).mean()
result_dict["val_loss_neg_dy"] = torch.stack(self.losses["val_neg_dy"]).mean()
result_dict["val_loss_neg_dy"] = torch.stack(
self.losses["val_neg_dy"]
).mean()

if len(self.losses["test"]) > 0:
result_dict["test_loss_y"] = torch.stack(self.losses["test_y"]).mean()
result_dict["test_loss_neg_dy"] = torch.stack(self.losses["test_neg_dy"]).mean()
result_dict["test_loss_y"] = torch.stack(
self.losses["test_y"]
).mean()
result_dict["test_loss_neg_dy"] = torch.stack(
self.losses["test_neg_dy"]
).mean()

self.log_dict(result_dict, sync_dist=True)
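The per-metric entries above all follow the same epoch-end aggregation pattern; a runnable sketch with made-up values:

import torch

# Each element stands in for one scalar loss collected during the epoch.
epoch_losses = [torch.tensor(0.12), torch.tensor(0.10), torch.tensor(0.11)]
epoch_mean = torch.stack(epoch_losses).mean()  # tensor(0.1100)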

@@ -212,4 +220,9 @@ def _reset_losses_dict(self):
}

def _reset_ema_dict(self):
self.ema = {"train_y": None, "val_y": None, "train_neg_dy": None, "val_neg_dy": None}
self.ema = {
"train_y": None,
"val_y": None,
"train_neg_dy": None,
"val_neg_dy": None,
}
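These entries are presumably filled in elsewhere in the module with an exponential moving average of the corresponding losses; a generic sketch of that update, where the smoothing factor alpha and its value are assumptions rather than values taken from this repository.

from typing import Optional

def ema_update(prev: Optional[float], value: float, alpha: float = 0.1) -> float:
    # Standard exponential moving average; the first observation seeds the EMA.
    return value if prev is None else alpha * value + (1.0 - alpha) * prev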
