Skip to content

Commit

Permalink
Merge pull request #232 from nlesc-nano/anchor_filter
Browse files Browse the repository at this point in the history
ENH: Add the `anchor.multi_anchor_filter` option
  • Loading branch information
BvB93 authored Jan 21, 2022
2 parents aee208e + c2d79df commit c424729
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 5 deletions.
16 changes: 14 additions & 2 deletions CAT/attachment/ligand_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from rdkit import Chem

from ..logger import logger
from ..utils import get_template, AnchorTup, KindEnum, get_formula, FormatEnum
from ..utils import get_template, AnchorTup, KindEnum, get_formula, FormatEnum, MultiAnchorEnum
from ..mol_utils import separate_mod # noqa: F401
from ..workflows import MOL, FORMULA, HDF5_INDEX, OPT
from ..settings_dataframe import SettingsDataFrame
Expand Down Expand Up @@ -201,12 +201,24 @@ def find_substructure(
ligand_idx_dict[anchor_tup].append(idx_tup)
ref_set.add(anchor_idx_tup)

# Apply some further filtering to the ligands
if condition is not None:
if not condition(sum((len(i) for i in ligand_idx_dict.values()), 0)):
if not condition(sum((len(j) for j in ligand_idx_dict.values()), 0)):
err = (f"Failed to satisfy the passed condition ({condition!r}) for "
f"ligand: {ligand.properties.name!r}")
logger.error(err)
return []
else:
for anchor_tup, j in ligand_idx_dict.items():
if anchor_tup.multi_anchor_filter == MultiAnchorEnum.ALL:
pass
elif anchor_tup.multi_anchor_filter == MultiAnchorEnum.FIRST and len(j) > 1:
ligand_idx_dict[anchor_tup] = j[:1]
elif anchor_tup.multi_anchor_filter == MultiAnchorEnum.RAISE and len(j) > 1:
logger.error(
f"Found multiple valid functional groups for {ligand.properties.name!r}"
)
return []

ret = []
idx_dict_items = chain.from_iterable(zip(repeat(k), v) for k, v in ligand_idx_dict.items())
Expand Down
18 changes: 16 additions & 2 deletions CAT/data_handling/anchor_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from schema import Schema, Use, Optional
from typing_extensions import TypedDict, SupportsIndex

from ..utils import AnchorTup, KindEnum, FormatEnum
from ..utils import AnchorTup, KindEnum, FormatEnum, MultiAnchorEnum
from ..attachment.ligand_anchoring import _smiles_to_rdmol, get_functional_groups

__all__ = ["parse_anchors"]
Expand All @@ -26,6 +26,7 @@ class _UnparsedAnchorDict(_UnparsedAnchorDictBase, total=False):
dihedral: "None | SupportsFloat | SupportsIndex | bytes | str"
kind: "None | str | KindEnum"
group_format: "None | str | FormatEnum"
multi_anchor_filter: "None | str | MultiAnchorEnum"


class _AnchorDict(TypedDict):
Expand All @@ -35,7 +36,8 @@ class _AnchorDict(TypedDict):
kind: KindEnum
angle_offset: "None | float"
dihedral: "None | float"
group_format: "FormatEnum"
group_format: FormatEnum
multi_anchor_filter: MultiAnchorEnum


def _parse_group_idx(item: "SupportsIndex | Iterable[SupportsIndex]") -> Tuple[int, ...]:
Expand Down Expand Up @@ -85,6 +87,17 @@ def _parse_group_format(typ: "None | str | FormatEnum") -> FormatEnum:
raise TypeError("`group_format` expected None or a string")


def _parse_multi_anchor_filter(typ: "None | str | MultiAnchorEnum") -> MultiAnchorEnum:
"""Parse the ``multi_anchor_filter`` option."""
if typ is None:
return MultiAnchorEnum.ALL
elif isinstance(typ, MultiAnchorEnum):
return typ
elif isinstance(typ, str):
return MultiAnchorEnum[typ.upper()]
raise TypeError("`multi_anchor_filter` expected None or a string")


_UNIT_PATTERN = re.compile(r"([\.\_0-9]+)(\s+)?(\w+)?")


Expand Down Expand Up @@ -143,6 +156,7 @@ def _symbol_to_rdmol(symbol: str) -> Chem.Mol:
Optional("angle_offset", default=None): Use(_parse_angle_offset),
Optional("dihedral", default=None): Use(_parse_angle_offset),
Optional("group_format", default=FormatEnum.SMILES): Use(_parse_group_format),
Optional("multi_anchor_filter", default=MultiAnchorEnum.ALL): Use(_parse_multi_anchor_filter),
})

#: A collection of symbols used for different kinds of dummy atoms.
Expand Down
9 changes: 9 additions & 0 deletions CAT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,14 @@ class AllignmentEnum(enum.Enum):
SURFACE = 1


class MultiAnchorEnum(enum.Enum):
"""An enum with different actions for when ligands with multiple anchors are found."""

ALL = 0
FIRST = 1
RAISE = 2


class AnchorTup(NamedTuple):
"""A named tuple with anchoring operation instructions."""

Expand All @@ -562,6 +570,7 @@ class AnchorTup(NamedTuple):
angle_offset: "None | float" = None
dihedral: "None | float" = None
group_format: FormatEnum = FormatEnum.SMILES
multi_anchor_filter: MultiAnchorEnum = MultiAnchorEnum.ALL


class AllignmentTup(NamedTuple):
Expand Down
15 changes: 15 additions & 0 deletions docs/4_optional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ Ligand
* :attr:`anchor.kind`
* :attr:`anchor.angle_offset`
* :attr:`anchor.dihedral`
* :attr:`anchor.multi_anchor_filter`

.. note::

Expand Down Expand Up @@ -790,6 +791,20 @@ Ligand
but if so desired one can explicitly pass the unit: ``dihedral: "0.5 rad"``.


.. attribute:: optional.ligand.anchor.multi_anchor_filter

:Parameter: * **Type** - :class:`str`
* **Default value** – :data:`"ALL"`

How ligands with multiple valid anchor sites are to-be treated.

Accepts one of the following options:

* ``"all"``: Construct a new ligand for each valid anchor/ligand combination.
* ``"first"``: Pick only the first valid functional group, all others are ignored.
* ``"raise"``: Treat a ligand as invalid if it has multiple valid anchoring sites.


.. attribute:: optional.ligand.split

:Parameter: * **Type** - :class:`bool`
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ligand_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from schema import SchemaError
from packaging.version import Version

from CAT.utils import get_template, KindEnum, AnchorTup, FormatEnum
from CAT.utils import get_template, KindEnum, AnchorTup, FormatEnum, MultiAnchorEnum
from CAT.base import prep_input
from CAT.attachment.ligand_anchoring import (
get_functional_groups, _smiles_to_rdmol, find_substructure, init_ligand_anchoring
Expand Down Expand Up @@ -279,6 +279,7 @@ class TestInputParsing:
invalid_group=({"group": "OC", "group_idx": 0, "group": 1.0}, SchemaError),
invalid_group_format=({"group": "OC", "group_idx": 0, "group_format": 1}, SchemaError),
invalid_kind=({"group": "OC", "group_idx": 0, "kind": 1}, SchemaError),
invalid_multi_anchor_filter=({"group": "OC", "group_idx": 0, "multi_anchor_filter": 1}, SchemaError),
)

@pytest.mark.parametrize("inp,exc_type", PARAM_RAISE.values(), ids=PARAM_RAISE.keys())
Expand Down Expand Up @@ -323,6 +324,9 @@ def test_raise_core(self, inp: Any, exc_type: "type[Exception]") -> None:
group_format_none={"group": "OCC", "group_idx": 0, "group_format": None},
group_format_str={"group": "OCC", "group_idx": 0, "group_format": "SMARTS"},
group_format_enum={"group": "OCC", "group_idx": 0, "group_format": FormatEnum.SMARTS},
multi_anchor_filter_none={"group": "OCC", "group_idx": 0, "multi_anchor_filter": None},
multi_anchor_filter_str={"group": "OCC", "group_idx": 0, "multi_anchor_filter": "ALL"},
multi_anchor_filter_enum={"group": "OCC", "group_idx": 0, "multi_anchor_filter": MultiAnchorEnum.ALL},
)
_PARAM_PASS2 = OrderedDict(
idx_scalar=AnchorTup(None, group="OCC", group_idx=(0,)),
Expand All @@ -348,6 +352,9 @@ def test_raise_core(self, inp: Any, exc_type: "type[Exception]") -> None:
group_format_none=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMILES),
group_format_str=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMARTS),
group_format_enum=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMARTS),
multi_anchor_filter_none=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL),
multi_anchor_filter_str=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL),
multi_anchor_filter_enum=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL),
)
PARAM_PASS = OrderedDict({
k: (v1, v2) for (k, v1), v2 in zip(_PARAM_PASS1.items(), _PARAM_PASS2.values())
Expand Down

0 comments on commit c424729

Please sign in to comment.