Skip to content

Commit

Permalink
Add energy and force target
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Nov 23, 2023
1 parent 65aaf50 commit b8bb7ad
Show file tree
Hide file tree
Showing 2 changed files with 299 additions and 0 deletions.
148 changes: 148 additions & 0 deletions descent/targets/energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Train against relative energies and forces."""
import typing

import datasets
import datasets.table
import pyarrow
import smee
import torch

# Arrow schema of a single dataset row. Every numeric column is stored as a flat
# list of float64 values: ``create_dataset`` flattens the tensors before storing
# them, and ``predict`` reshapes coords / forces back to ``(n_confs, -1, 3)``.
DATA_SCHEMA = pyarrow.schema(
    [
        pyarrow.field("smiles", pyarrow.string()),
        pyarrow.field("coords", pyarrow.list_(pyarrow.float64())),
        pyarrow.field("energy", pyarrow.list_(pyarrow.float64())),
        pyarrow.field("forces", pyarrow.list_(pyarrow.float64())),
    ]
)


class Entry(typing.TypedDict):
    """Reference energies and forces for a set of conformers of one molecule."""

    smiles: str
    """Indexed SMILES pattern identifying the molecule that the energies and
    forces were computed for."""

    coords: torch.Tensor
    """Coordinates [Å] at which the energies and forces were evaluated, with
    ``shape=(n_confs, n_particles, 3)``."""
    energy: torch.Tensor
    """Reference energies [kcal/mol], one per conformer, with
    ``shape=(n_confs,)``."""
    forces: torch.Tensor
    """Reference forces [kcal/mol/Å] acting on every particle, with
    ``shape=(n_confs, n_particles, 3)``."""


def create_dataset(entries: list[Entry]) -> datasets.Dataset:
    """Create a dataset from a list of existing entries.

    The coordinates, energies and forces of each entry are flattened into 1-D
    lists of floats so that they fit the list columns of ``DATA_SCHEMA``.

    Args:
        entries: The entries to create the dataset from.

    Returns:
        The created dataset, with its output format set to ``"torch"``.
    """

    def _flatten(values) -> list[float]:
        # ``torch.as_tensor`` accepts tensors, arrays and nested sequences alike.
        # Unlike ``torch.tensor`` it neither copies an existing tensor nor emits
        # the copy-construct ``UserWarning`` for one. ``detach`` keeps any
        # autograd state out of the stored values.
        return torch.as_tensor(values).detach().flatten().tolist()

    table = pyarrow.Table.from_pylist(
        [
            {
                "smiles": entry["smiles"],
                "coords": _flatten(entry["coords"]),
                "energy": _flatten(entry["energy"]),
                "forces": _flatten(entry["forces"]),
            }
            for entry in entries
        ],
        schema=DATA_SCHEMA,
    )
    # TODO: validate rows
    dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
    dataset.set_format("torch")

    return dataset


def extract_smiles(dataset: datasets.Dataset) -> list[str]:
    """Collect the distinct SMILES strings stored in a dataset.

    Args:
        dataset: The dataset whose ``"smiles"`` column should be inspected.

    Returns:
        The distinct SMILES strings in ascending sorted order.
    """
    distinct_smiles = set(dataset.unique("smiles"))
    return sorted(distinct_smiles)


def predict(
    dataset: datasets.Dataset,
    force_field: smee.TensorForceField,
    topologies: dict[str, smee.TensorTopology],
    reference: typing.Literal["mean", "min"] = "mean",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.

    Args:
        dataset: The dataset to predict the energies and forces of.
        force_field: The force field to use to predict the energies and forces.
        topologies: The topologies of the molecules in the dataset. Each key should be
            a fully indexed SMILES string.
        reference: The reference energy to compute the relative energies with respect
            to. This should be either the "mean" energy of all conformers, or the
            energy of the conformer with the lowest reference energy ("min").

    Returns:
        A tuple of, in order: the reference relative energies, the predicted
        relative energies (both with ``shape=(n_confs,)``), the reference forces
        and the predicted forces (both with
        ``shape=(n_confs * n_atoms_per_conf, 3)``). Each tensor is the
        concatenation over all entries of the dataset.

    Raises:
        NotImplementedError: If ``reference`` is neither "mean" nor "min".
    """
    reference_kind = reference.lower()

    # Fail fast, before any (potentially expensive) energy evaluation.
    if reference_kind not in ("mean", "min"):
        raise NotImplementedError(f"invalid reference energy {reference}")

    energy_ref_all, energy_pred_all = [], []
    forces_ref_all, forces_pred_all = [], []

    for entry in dataset:
        energy_ref = entry["energy"]
        n_confs = len(energy_ref)

        # Rows are stored flattened, so restore the (conformer, particle, xyz) axes.
        forces_ref = entry["forces"].reshape(n_confs, -1, 3)
        # Detach so that the coordinates are a fresh leaf to differentiate against.
        coords = entry["coords"].reshape(n_confs, -1, 3).detach().requires_grad_(True)

        topology = topologies[entry["smiles"]]

        energy_pred = smee.compute_energy(topology, force_field, coords)
        # ``create_graph=True`` keeps the gradient itself differentiable, so that
        # a loss on the forces can be back-propagated further.
        energy_grad = torch.autograd.grad(
            energy_pred.sum(),
            coords,
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )[0]

        if energy_grad is None:
            # ``allow_unused=True`` yields ``None`` when the energy does not depend
            # on the coordinates at all, in which case the forces are zero.
            energy_grad = torch.zeros_like(coords)

        # A force is the *negative* gradient of the energy w.r.t. the coordinates.
        forces_pred = -energy_grad

        if reference_kind == "mean":
            energy_ref_0 = energy_ref.mean()
            energy_pred_0 = energy_pred.mean()
        else:
            # "min": use the conformer with the lowest *reference* energy as the
            # zero point of both the reference and the predicted energies.
            min_idx = energy_ref.argmin()

            energy_ref_0 = energy_ref[min_idx]
            energy_pred_0 = energy_pred[min_idx]

        energy_ref_all.append(energy_ref - energy_ref_0)
        forces_ref_all.append(forces_ref.reshape(-1, 3))

        energy_pred_all.append(energy_pred - energy_pred_0)
        forces_pred_all.append(forces_pred.reshape(-1, 3))

    return (
        torch.cat(energy_ref_all),
        torch.cat(energy_pred_all),
        torch.cat(forces_ref_all),
        torch.cat(forces_pred_all),
    )
151 changes: 151 additions & 0 deletions descent/tests/targets/test_energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import openff.interchange
import openff.toolkit
import pytest
import smee.converters
import torch

import descent.utils.dataset
from descent.targets.energy import Entry, create_dataset, extract_smiles, predict


@pytest.fixture
def mock_meoh_entry() -> Entry:
    """A synthetic methanol entry with two conformers of six atoms each.

    The numbers are simple ramps rather than physical data, which makes every
    stored value easy to recognise after a round trip through a dataset.
    """
    n_confs, n_atoms = 2, 6
    n_values = n_confs * n_atoms * 3

    ramp = torch.arange(n_values, dtype=torch.float32).reshape(n_confs, n_atoms, 3)

    return {
        "smiles": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
        "coords": ramp,
        "energy": 3.0 * torch.arange(n_confs, dtype=torch.float32),
        # offset so that the forces never coincide with the coordinates
        "forces": ramp + float(n_values),
    }


@pytest.fixture
def mock_hoh_entry() -> Entry:
    """A water entry with two conformers that differ only in the x offset of
    the hydrogens (±1.0 Å and ±0.7 Å); the reference forces are a ramp."""

    def _conformer(h_x: float) -> list[list[float]]:
        # first row sits at the origin, the other two are mirrored in x
        return [[0.0, 0.0, 0.0], [-h_x, -0.5, 0.0], [h_x, -0.5, 0.0]]

    coords = torch.tensor([_conformer(1.0), _conformer(0.7)])
    forces = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)

    return {
        "smiles": "[H:2][O:1][H:3]",
        "coords": coords,
        "energy": torch.tensor([2.0, 3.0]),
        "forces": forces,
    }


def test_create_dataset(mock_meoh_entry):
    """A single entry must be stored as exactly one row of flattened values."""
    dataset = create_dataset([mock_meoh_entry])
    assert len(dataset) == 1

    # coords and forces are stored flattened; the energies are already 1-D
    expected_row = {
        "smiles": mock_meoh_entry["smiles"],
        "coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
        "energy": pytest.approx(mock_meoh_entry["energy"]),
        "forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
    }

    entries = list(descent.utils.dataset.iter_dataset(dataset))
    assert entries == [expected_row]


def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
    """The SMILES of both entries should be returned in sorted order."""
    dataset = create_dataset([mock_meoh_entry, mock_hoh_entry])

    # "[C..." sorts before "[H...", so methanol comes first
    assert extract_smiles(dataset) == [
        "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
        "[H:2][O:1][H:3]",
    ]


# Reference forces of the mock water entry, flattened over its two conformers.
_HOH_FORCES_REF = torch.arange(18, dtype=torch.float32).reshape(6, 3)

# Expected predicted forces [kcal/mol/Å] for the mock water entry, i.e. the
# *negative* gradient of the energy with respect to the coordinates. In the first
# conformer the force on each hydrogen points towards the oxygen, in the second
# one away from it.
_HOH_FORCES_PRED = torch.tensor(
    [
        [0.0, -83.55978393554688, 0.0],
        [161.40325927734375, 41.77988815307617, 0.0],
        [-161.40325927734375, 41.77988815307617, 0.0],
        [0.0, 137.45770263671875, 0.0],
        [-102.62999725341797, -68.72884368896484, 0.0],
        [102.62999725341797, -68.72884368896484, 0.0],
    ]
)


@pytest.mark.parametrize(
    (
        "reference",
        "expected_energy_ref",
        "expected_forces_ref",
        "expected_energy_pred",
        "expected_forces_pred",
    ),
    [
        (
            "mean",
            torch.tensor([-0.5, 0.5]),
            _HOH_FORCES_REF,
            torch.tensor([7.899425506591797, -7.89942741394043]),
            _HOH_FORCES_PRED,
        ),
        (
            "min",
            torch.tensor([0.0, 1.0]),
            _HOH_FORCES_REF,
            torch.tensor([0.0, -15.798852920532227]),
            _HOH_FORCES_PRED,
        ),
    ],
)
def test_predict(
    reference,
    expected_energy_ref,
    expected_forces_ref,
    expected_energy_pred,
    expected_forces_pred,
    mock_hoh_entry,
):
    """Check the relative energies and forces predicted for water.

    Only the energies depend on ``reference``; the forces are identical for
    both choices of the zero point.
    """
    dataset = create_dataset([mock_hoh_entry])

    force_field, [topology] = smee.converters.convert_interchange(
        openff.interchange.Interchange.from_smirnoff(
            openff.toolkit.ForceField("openff-1.3.0.offxml"),
            openff.toolkit.Molecule.from_mapped_smiles(
                mock_hoh_entry["smiles"]
            ).to_topology(),
        )
    )
    topologies = {mock_hoh_entry["smiles"]: topology}

    energy_ref, energy_pred, forces_ref, forces_pred = predict(
        dataset, force_field, topologies, reference=reference
    )

    assert energy_pred.shape == expected_energy_pred.shape
    assert torch.allclose(energy_pred, expected_energy_pred)
    assert energy_ref.shape == expected_energy_ref.shape
    assert torch.allclose(energy_ref, expected_energy_ref)

    assert forces_pred.shape == expected_forces_pred.shape
    assert torch.allclose(forces_pred, expected_forces_pred)
    assert forces_ref.shape == expected_forces_ref.shape
    assert torch.allclose(forces_ref, expected_forces_ref)

0 comments on commit b8bb7ad

Please sign in to comment.