-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
65aaf50
commit b8bb7ad
Showing
2 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
"""Train against relative energies and forces.""" | ||
import typing | ||
|
||
import datasets | ||
import datasets.table | ||
import pyarrow | ||
import smee | ||
import torch | ||
|
||
# Arrow schema of a single dataset row: the SMILES string of the molecule plus
# its coordinates, energies and forces, each stored as a flat list of 64-bit
# floats.
DATA_SCHEMA = pyarrow.schema(
    [
        pyarrow.field("smiles", pyarrow.string()),
        pyarrow.field("coords", pyarrow.list_(pyarrow.float64())),
        pyarrow.field("energy", pyarrow.list_(pyarrow.float64())),
        pyarrow.field("forces", pyarrow.list_(pyarrow.float64())),
    ]
)
|
||
|
||
class Entry(typing.TypedDict):
    """Reference energies and forces of one molecule over a set of conformers."""

    smiles: str
    """Indexed SMILES pattern identifying the molecule that the energies and
    forces were computed for."""

    coords: torch.Tensor
    """Coordinates [Å] at which the energies and forces were evaluated, with
    ``shape=(n_confs, n_particles, 3)``."""

    energy: torch.Tensor
    """Reference energies [kcal/mol], one per conformer, with
    ``shape=(n_confs,)``."""

    forces: torch.Tensor
    """Reference forces [kcal/mol/Å] with
    ``shape=(n_confs, n_particles, 3)``."""
|
||
|
||
def create_dataset(entries: list[Entry]) -> datasets.Dataset:
    """Create a dataset from a list of existing entries.

    The coordinates, energies and forces of each entry are flattened into 1-D
    lists of floats so that every row matches ``DATA_SCHEMA``.

    Args:
        entries: The entries to create the dataset from.

    Returns:
        The created in-memory dataset, with its output format set to
        ``"torch"``.
    """

    def _flatten(values) -> list[float]:
        # ``torch.as_tensor`` passes an existing tensor straight through, whereas
        # ``torch.tensor(tensor)`` copy-constructs it and emits a ``UserWarning``.
        # Sequences and arrays are still converted exactly as before.
        return torch.as_tensor(values).detach().flatten().tolist()

    table = pyarrow.Table.from_pylist(
        [
            {
                "smiles": entry["smiles"],
                "coords": _flatten(entry["coords"]),
                "energy": _flatten(entry["energy"]),
                "forces": _flatten(entry["forces"]),
            }
            for entry in entries
        ],
        schema=DATA_SCHEMA,
    )
    # TODO: validate rows
    dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
    dataset.set_format("torch")

    return dataset
|
||
|
||
def extract_smiles(dataset: datasets.Dataset) -> list[str]:
    """Collect the distinct SMILES patterns stored in a dataset.

    Args:
        dataset: The dataset whose ``"smiles"`` column should be inspected.

    Returns:
        The distinct SMILES patterns, sorted in ascending order.
    """
    distinct_smiles = set(dataset.unique("smiles"))
    return sorted(distinct_smiles)
|
||
|
||
def predict(
    dataset: datasets.Dataset,
    force_field: smee.TensorForceField,
    topologies: dict[str, smee.TensorTopology],
    reference: typing.Literal["mean", "min"] = "mean",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.

    Args:
        dataset: The dataset to predict the energies and forces of.
        force_field: The force field to use to predict the energies and forces.
        topologies: The topologies of the molecules in the dataset. Each key should be
            a fully indexed SMILES string.
        reference: The reference energy to compute the relative energies with respect
            to. This should be either the "mean" energy of all conformers, or the
            energy of the conformer with the lowest reference energy ("min"). The
            value is matched case-insensitively.

    Returns:
        A tuple of, in this order, the reference relative energies, the predicted
        relative energies, the reference forces and the predicted forces. The
        energies [kcal/mol] have ``shape=(n_confs,)`` and the forces [kcal/mol/Å]
        have ``shape=(n_confs * n_atoms_per_conf, 3)``, each concatenated over all
        entries of the dataset.

    Raises:
        NotImplementedError: If ``reference`` is neither "mean" nor "min".
    """
    reference_kind = reference.lower()

    # Reject an unsupported reference before doing any work rather than part way
    # through the loop.
    if reference_kind not in ("mean", "min"):
        raise NotImplementedError(f"invalid reference energy {reference}")

    energy_ref_all, energy_pred_all = [], []
    forces_ref_all, forces_pred_all = [], []

    for entry in dataset:
        energy_ref = entry["energy"]
        # Rows store flattened arrays, so the number of energies is what recovers
        # the (n_confs, n_particles, 3) layout.
        n_confs = len(energy_ref)

        forces_ref = entry["forces"].reshape(n_confs, -1, 3)
        # A fresh leaf tensor is needed so that the energy can be differentiated
        # with respect to the coordinates.
        coords = entry["coords"].reshape(n_confs, -1, 3).detach().requires_grad_(True)

        topology = topologies[entry["smiles"]]

        energy_pred = smee.compute_energy(topology, force_field, coords)
        # ``create_graph=True`` builds the graph of the derivative itself, so a
        # loss on the forces can be back-propagated further.
        (energy_grad,) = torch.autograd.grad(
            energy_pred.sum(),
            coords,
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )
        # The force is the *negative* gradient of the energy. With
        # ``allow_unused=True`` the gradient is ``None`` when the energy does not
        # depend on the coordinates, in which case the forces are zero.
        forces_pred = (
            torch.zeros_like(coords) if energy_grad is None else -energy_grad
        )

        if reference_kind == "mean":
            energy_ref_0 = energy_ref.mean()
            energy_pred_0 = energy_pred.mean()
        else:
            # "min": both energy sets are shifted by the conformer that has the
            # lowest *reference* energy.
            min_idx = energy_ref.argmin()

            energy_ref_0 = energy_ref[min_idx]
            energy_pred_0 = energy_pred[min_idx]

        energy_ref_all.append(energy_ref - energy_ref_0)
        forces_ref_all.append(forces_ref.reshape(-1, 3))

        energy_pred_all.append(energy_pred - energy_pred_0)
        forces_pred_all.append(forces_pred.reshape(-1, 3))

    return (
        torch.cat(energy_ref_all),
        torch.cat(energy_pred_all),
        torch.cat(forces_ref_all),
        torch.cat(forces_pred_all),
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import openff.interchange | ||
import openff.toolkit | ||
import pytest | ||
import smee.converters | ||
import torch | ||
|
||
import descent.utils.dataset | ||
from descent.targets.energy import Entry, create_dataset, extract_smiles, predict | ||
|
||
|
||
@pytest.fixture
def mock_meoh_entry() -> Entry:
    """Return a mock methanol entry with two conformers of six atoms each.

    The numbers are synthetic ramps rather than physical data: the coordinates
    count up from 0, the forces from 36, and the energies are 0 and 3.
    """
    n_confs, n_atoms = 2, 6
    n_values = n_confs * n_atoms * 3

    coords = torch.arange(n_values, dtype=torch.float32).reshape(n_confs, n_atoms, 3)
    energy = 3.0 * torch.arange(n_confs, dtype=torch.float32)
    forces = (
        torch.arange(n_values, dtype=torch.float32).reshape(n_confs, n_atoms, 3) + 36.0
    )

    return {
        "smiles": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
        "coords": coords,
        "energy": energy,
        "forces": forces,
    }
|
||
|
||
@pytest.fixture
def mock_hoh_entry() -> Entry:
    """Return a mock water entry with two conformers of three atoms each.

    The first atom sits at the origin and the other two are mirrored about the
    y-axis, at x = ±1.0 in the first conformer and x = ±0.7 in the second. The
    energies and forces are synthetic values.
    """
    wide_conformer = [[0.0, 0.0, 0.0], [-1.0, -0.5, 0.0], [1.0, -0.5, 0.0]]
    narrow_conformer = [[0.0, 0.0, 0.0], [-0.7, -0.5, 0.0], [0.7, -0.5, 0.0]]

    return {
        "smiles": "[H:2][O:1][H:3]",
        "coords": torch.tensor([wide_conformer, narrow_conformer]),
        "energy": torch.tensor([2.0, 3.0]),
        "forces": torch.arange(18, dtype=torch.float32).reshape(2, 3, 3),
    }
|
||
|
||
def test_create_dataset(mock_meoh_entry):
    """A single entry should be stored as one row holding flattened arrays."""
    dataset = create_dataset([mock_meoh_entry])
    assert len(dataset) == 1

    # Coordinates and forces are expected back as flat sequences.
    expected_entry = {
        "smiles": mock_meoh_entry["smiles"],
        "coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
        "energy": pytest.approx(mock_meoh_entry["energy"]),
        "forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
    }

    actual_entries = [*descent.utils.dataset.iter_dataset(dataset)]
    assert actual_entries == [expected_entry]
|
||
|
||
def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
    """The SMILES of every entry should be returned once, in sorted order."""
    dataset = create_dataset([mock_meoh_entry, mock_hoh_entry])

    assert extract_smiles(dataset) == [
        "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
        "[H:2][O:1][H:3]",
    ]
|
||
|
||
# Reference forces of the mock water entry, one row per atom, with the atoms of
# the first conformer followed by those of the second.
_HOH_EXPECTED_FORCES_REF = torch.tensor(
    [
        [0.0, 1.0, 2.0],
        [3.0, 4.0, 5.0],
        [6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0],
        [12.0, 13.0, 14.0],
        [15.0, 16.0, 17.0],
    ]
)
# Predicted forces of the mock water entry, i.e. the *negative* gradient of the
# energy. In the first conformer both bonds to the atom at the origin are
# ~1.12 Å long, well beyond a typical O-H bond length (~0.97 Å), so the atom at
# x = -1 is pulled back towards the origin (+x). In the second conformer the
# bonds are compressed (~0.86 Å) and the signs flip.
_HOH_EXPECTED_FORCES_PRED = torch.tensor(
    [
        [0.0, -83.55978393554688, 0.0],
        [161.40325927734375, 41.77988815307617, 0.0],
        [-161.40325927734375, 41.77988815307617, 0.0],
        [0.0, 137.45770263671875, 0.0],
        [-102.62999725341797, -68.72884368896484, 0.0],
        [102.62999725341797, -68.72884368896484, 0.0],
    ]
)


@pytest.mark.parametrize(
    "reference, "
    "expected_energy_ref, expected_forces_ref, "
    "expected_energy_pred, expected_forces_pred",
    [
        (
            "mean",
            torch.tensor([-0.5, 0.5]),
            _HOH_EXPECTED_FORCES_REF,
            torch.tensor([7.899425506591797, -7.89942741394043]),
            _HOH_EXPECTED_FORCES_PRED,
        ),
        (
            "min",
            torch.tensor([0.0, 1.0]),
            _HOH_EXPECTED_FORCES_REF,
            torch.tensor([0.0, -15.798852920532227]),
            _HOH_EXPECTED_FORCES_PRED,
        ),
    ],
)
def test_predict(
    reference,
    expected_energy_ref,
    expected_forces_ref,
    expected_energy_pred,
    expected_forces_pred,
    mock_hoh_entry,
):
    """Check the relative energies and the forces predicted for the water entry.

    The forces do not depend on the choice of reference energy, so both parameter
    sets share the same expected force tables; only the energy offsets differ.
    """
    dataset = create_dataset([mock_hoh_entry])

    force_field, [topology] = smee.converters.convert_interchange(
        openff.interchange.Interchange.from_smirnoff(
            openff.toolkit.ForceField("openff-1.3.0.offxml"),
            openff.toolkit.Molecule.from_mapped_smiles(
                mock_hoh_entry["smiles"]
            ).to_topology(),
        )
    )
    topologies = {mock_hoh_entry["smiles"]: topology}

    energy_ref, energy_pred, forces_ref, forces_pred = predict(
        dataset, force_field, topologies, reference=reference
    )

    assert energy_pred.shape == expected_energy_pred.shape
    assert torch.allclose(energy_pred, expected_energy_pred)
    assert energy_ref.shape == expected_energy_ref.shape
    assert torch.allclose(energy_ref, expected_energy_ref)

    assert forces_pred.shape == expected_forces_pred.shape
    assert torch.allclose(forces_pred, expected_forces_pred)
    assert forces_ref.shape == expected_forces_ref.shape
    assert torch.allclose(forces_ref, expected_forces_ref)