Implement task restart policies #280

Open · wants to merge 44 commits into base: main

Commits (44)
7f752b3
Added placeholder tests for proposed methods
ianmkenney Jul 16, 2024
dd8f0e9
Added models for new node types
ianmkenney Jul 16, 2024
da17e45
Updated new GufeTokenizable models in statestore
ianmkenney Jul 17, 2024
b7f63d4
Added placeholder unit tests for new models
ianmkenney Jul 17, 2024
6a167f1
Added validation and unit tests for storage models
ianmkenney Jul 18, 2024
a10e235
Added `taskhub_sk` to `TaskRestartPattern`
ianmkenney Jul 22, 2024
b99d8ef
Added `statestore` methods for restart patterns
ianmkenney Jul 22, 2024
39f9868
Added APPLIES relationship when adding pattern
ianmkenney Jul 25, 2024
988155f
Establish APPLIES when actioning a Task
ianmkenney Jul 26, 2024
d3f25f8
Canceling a Task removes the APPLIES relationship
ianmkenney Jul 26, 2024
510ae66
Task status changes affect APPLIES relationship
ianmkenney Aug 1, 2024
2310fd5
Tests for Task status change on APPLIES
ianmkenney Aug 4, 2024
ea2851f
Added method (unimplemented) calls for restarts
ianmkenney Aug 4, 2024
8e011be
Implemented add_protocol_dag_result_ref_traceback
ianmkenney Aug 5, 2024
4f07dde
Started implementation of restart resolution
ianmkenney Aug 6, 2024
78c4551
Tracebacks now include key data from its source units
ianmkenney Aug 7, 2024
7acc003
Built out custom fixture for testing restart policies
ianmkenney Aug 13, 2024
03d9fa1
Added the `chainable` decorator to Neo4jStore
ianmkenney Aug 19, 2024
aad97e3
Resolve task restarts now sets all remaining tasks to waiting
ianmkenney Aug 19, 2024
a655dc7
Corrected resolution logic
ianmkenney Aug 19, 2024
5bb6700
Extracted complexity out of test_resolve_task_restarts
ianmkenney Aug 23, 2024
fe4b87b
resolve restart of tasks with no tracebacks
ianmkenney Aug 23, 2024
8a6f980
Replaced many maps with a for loop
ianmkenney Aug 23, 2024
93eb5f5
Small changes from review
dotsdl Sep 4, 2024
0900f39
Chainable now uses the update_wrapper function
ianmkenney Sep 9, 2024
c8ddafc
Updated Traceback class
ianmkenney Sep 9, 2024
2a59499
Renamed Traceback to Tracebacks
ianmkenney Sep 9, 2024
148d048
Updated cancel and increment logic
ianmkenney Sep 9, 2024
645b2e4
Fixed query for deleting the APPLIES relationship
ianmkenney Sep 9, 2024
3a8eeca
Removed unused testing fixture
ianmkenney Sep 9, 2024
ea6e66f
Clarified comment and added complimentary assertion
ianmkenney Sep 9, 2024
7a4b114
Small changes to Tracebacks
dotsdl Sep 13, 2024
cf0e961
Merge pull request #286 from OpenFreeEnergy/feature/iss-277-restart-p…
ianmkenney Sep 19, 2024
6066796
Fix for Tracebacks unit tests
ianmkenney Sep 24, 2024
fcf77a0
Added API endpoints for managing restart policies
ianmkenney Sep 25, 2024
cea16bc
Added untested client method for task restart policies
ianmkenney Oct 1, 2024
a4da776
Added testing for client methods dealing with restart policies
ianmkenney Oct 1, 2024
fdc25a7
`get_taskhub` calls `get_taskhubs`
ianmkenney Oct 7, 2024
51194ff
Updated docstrings
ianmkenney Oct 7, 2024
f03417c
Merge branch 'main' into feature/iss-277-restart-policy
ianmkenney Oct 8, 2024
977c896
Added docstrings to client methods
ianmkenney Oct 21, 2024
2d2d8f6
Added Task restart patterns to user guide
ianmkenney Oct 21, 2024
d7dcd5c
Link to python classes and methods in restart pattern section
ianmkenney Oct 21, 2024
006e689
Merge branch 'main' into feature/iss-277-restart-policy
dotsdl Oct 25, 2024
7 changes: 6 additions & 1 deletion alchemiscale/compute/api.py
@@ -13,6 +13,7 @@
from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi.middleware.gzip import GZipMiddleware
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
from gufe.protocols import ProtocolDAGResult

from ..base.api import (
QueryGUFEHandler,
@@ -329,7 +330,7 @@
validate_scopes(task_sk.scope, token)

pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr = GufeTokenizable.from_dict(pdr)
pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr)


tf_sk, _ = n4js.get_task_transformation(
task=task_scoped_key,
@@ -351,7 +352,11 @@
if protocoldagresultref.ok:
n4js.set_task_complete(tasks=[task_sk])
else:
n4js.add_protocol_dag_result_ref_tracebacks(
pdr.protocol_unit_failures, result_sk
)
n4js.set_task_error(tasks=[task_sk])
n4js.resolve_task_restarts(tasks=[task_sk])


return result_sk
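Conceptually, restart resolution compares the tracebacks recorded for an errored Task against each registered pattern with a regular-expression search, spending one retry per match. A rough standalone sketch of that idea follows; the function name and dict-based bookkeeping are illustrative, not the actual `Neo4jStore.resolve_task_restarts` implementation, which also cancels the Task once its retries are exhausted:

```python
import re


def resolve_restart(tracebacks: list[str], patterns: dict[str, int]) -> bool:
    """Return True if the Task should be restarted (set back to 'waiting').

    `patterns` maps a regex pattern to its remaining retries; the first
    matching pattern with retries left is decremented and triggers a
    restart. If no such pattern matches, the Task stays errored.
    """
    for pattern, retries_left in patterns.items():
        if retries_left > 0 and any(re.search(pattern, tb) for tb in tracebacks):
            patterns[pattern] = retries_left - 1
            return True
    return False
```

For example, with `patterns = {"OutOfMemory": 1}`, the first matching failure consumes the single retry and returns `True`; a second identical failure returns `False`, mirroring the "exceeded retries leaves the Task errored" behavior described in the client docstrings.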

80 changes: 80 additions & 0 deletions alchemiscale/interface/api.py
@@ -947,6 +947,86 @@
return status[0].value


@router.post("/networks/{network_scoped_key}/restartpolicy/add")
def add_task_restart_patterns(
network_scoped_key: str,
*,
patterns: list[str] = Body(embed=True),
number_of_retries: int = Body(embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
taskhub_scoped_key = n4js.get_taskhub(ScopedKey.from_str(network_scoped_key))
n4js.add_task_restart_patterns(taskhub_scoped_key, patterns, number_of_retries)



@router.post("/networks/{network_scoped_key}/restartpolicy/remove")
def remove_task_restart_patterns(
network_scoped_key: str,
*,
patterns: list[str] = Body(embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
taskhub_scoped_key = n4js.get_taskhub(ScopedKey.from_str(network_scoped_key))
n4js.remove_task_restart_patterns(taskhub_scoped_key, patterns)



@router.get("/networks/{network_scoped_key}/restartpolicy/clear")
def clear_task_restart_patterns(
network_scoped_key: str,
*,
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
taskhub_scoped_key = n4js.get_taskhub(ScopedKey.from_str(network_scoped_key))
n4js.clear_task_restart_patterns(taskhub_scoped_key)
return [network_scoped_key]



@router.post("/bulk/networks/restartpolicy/get")
def get_task_restart_patterns(
*,
networks: list[str] = Body(embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
) -> dict[str, set[tuple[str, int]]]:

network_scoped_keys = [ScopedKey.from_str(network) for network in networks]
taskhub_scoped_keys = n4js.get_taskhubs(network_scoped_keys)


taskhub_network_map = {
taskhub_scoped_key: network_scoped_key
for taskhub_scoped_key, network_scoped_key in zip(
taskhub_scoped_keys, network_scoped_keys
)
}

restart_patterns = n4js.get_task_restart_patterns(taskhub_scoped_keys)


as_str = {}
for key, value in restart_patterns.items():
network_scoped_key = taskhub_network_map[key]
as_str[str(network_scoped_key)] = value


return as_str



@router.post("/networks/{network_scoped_key}/restartpolicy/maxretries")
def set_task_restart_patterns_max_retries(
network_scoped_key: str,
*,
patterns: list[str] = Body(embed=True),
max_retries: int = Body(embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
taskhub_scoped_key = n4js.get_taskhub(ScopedKey.from_str(network_scoped_key))
n4js.set_task_restart_patterns_max_retries(
taskhub_scoped_key, patterns, max_retries
)


@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
102 changes: 102 additions & 0 deletions alchemiscale/interface/client.py
@@ -1740,3 +1740,105 @@ def get_task_failures(
)

return pdrs

def add_task_restart_patterns(
self,
network_scoped_key: ScopedKey,
patterns: list[str],
num_allowed_restarts: int,
) -> ScopedKey:
"""Add a list of restart patterns to an `AlchemicalNetwork`.

Parameters
----------
network_scoped_key: ScopedKey
The ScopedKey for the AlchemicalNetwork to add the patterns to.
patterns: list[str]
The regular expression strings to compare to ProtocolUnitFailure tracebacks.
Matching patterns will set the Task status back to 'waiting'.
num_allowed_restarts: int
The number of times each pattern will be able to restart each `Task`. When
this number is exceeded, the `Task` is canceled from the `AlchemicalNetwork`
and left with the `error` status.

Returns
-------
network_scoped_key: ScopedKey
The ScopedKey of the AlchemicalNetwork the patterns were added to.
"""
data = {"patterns": patterns, "number_of_retries": num_allowed_restarts}
self._post_resource(f"/networks/{network_scoped_key}/restartpolicy/add", data)
return network_scoped_key

def get_task_restart_patterns(
self, network_scoped_key: ScopedKey
) -> dict[str, int]:
"""Get the Task restart patterns enforced on an AlchemicalNetwork, along with the number of retries allowed for each pattern.

Parameters
----------
network_scoped_key: ScopedKey
The ScopedKey of the AlchemicalNetwork to query.

Returns
-------
patterns : dict[str, int]
A dictionary mapping each pattern enforced on the `AlchemicalNetwork` to the number of
retries that pattern allows.
"""
data = {"networks": [str(network_scoped_key)]}
mapped_patterns = self._post_resource(
"/bulk/networks/restartpolicy/get", data=data
)
network_patterns = mapped_patterns[str(network_scoped_key)]
patterns_with_retries = {pattern: retry for pattern, retry in network_patterns}
return patterns_with_retries

def set_task_restart_patterns_allowed_restarts(
self,
network_scoped_key: ScopedKey,
patterns: list[str],
num_allowed_restarts: int,
) -> None:
"""Set the number of restarts that the given patterns are allowed to perform within an AlchemicalNetwork.

Parameters
----------
network_scoped_key : ScopedKey
The ScopedKey of the `AlchemicalNetwork` enforced by `patterns`.
patterns: list[str]
The patterns to set the number of allowed restarts for.
num_allowed_restarts : int
The new number of allowed restarts.
"""
data = {"patterns": patterns, "max_retries": num_allowed_restarts}
self._post_resource(
f"/networks/{network_scoped_key}/restartpolicy/maxretries", data
)

def remove_task_restart_patterns(
self, network_scoped_key: ScopedKey, patterns: list[str]
) -> None:
"""Remove specific patterns from an `AlchemicalNetwork`.

Parameters
----------
network_scoped_key : ScopedKey
The ScopedKey of the `AlchemicalNetwork` enforced by `patterns`.
patterns: list[str]
The patterns to remove from the `AlchemicalNetwork`.
"""
data = {"patterns": patterns}
self._post_resource(
f"/networks/{network_scoped_key}/restartpolicy/remove", data
)

def clear_task_restart_patterns(self, network_scoped_key: ScopedKey) -> None:
"""Clear all restart patterns from an `AlchemicalNetwork`.

Parameters
----------
network_scoped_key : ScopedKey
The ScopedKey of the `AlchemicalNetwork` to be cleared of restart patterns.
"""
self._query_resource(f"/networks/{network_scoped_key}/restartpolicy/clear")
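The bulk `restartpolicy/get` endpoint maps each network's scoped key to a collection of `(pattern, max_retries)` pairs, which `get_task_restart_patterns` flattens into a plain dict. A minimal sketch of that conversion, using an invented scoped-key string and payload purely for illustration:

```python
# Illustrative server response: network scoped key -> (pattern, max_retries) pairs.
# The key below is a made-up example, not a real ScopedKey.
mapped_patterns = {
    "AlchemicalNetwork-abc123-my_org-my_campaign-my_project": [
        (r".*OutOfMemoryError.*", 3),
        (r".*Segmentation fault.*", 1),
    ]
}

network_key = "AlchemicalNetwork-abc123-my_org-my_campaign-my_project"

# Same flattening the client performs before returning to the caller.
patterns_with_retries = {
    pattern: retries for pattern, retries in mapped_patterns[network_key]
}
```

After this step, `patterns_with_retries` maps each regex string directly to its retry budget, matching the `dict[str, int]` return type documented on `get_task_restart_patterns`.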
109 changes: 107 additions & 2 deletions alchemiscale/storage/models.py
@@ -8,12 +8,12 @@
from copy import copy
from datetime import datetime
from enum import Enum
from typing import Union, Dict, Optional
from typing import Union, Optional, List
from uuid import uuid4
import hashlib


from pydantic import BaseModel, Field
from pydantic import BaseModel
from gufe.tokenization import GufeTokenizable, GufeKey

from ..models import ScopedKey, Scope
@@ -143,6 +143,111 @@
return super()._defaults()


class TaskRestartPattern(GufeTokenizable):
"""A pattern to compare returned Task tracebacks to.

Attributes
----------
pattern: str
A regular expression pattern compared against the tracebacks returned by errored Tasks.
max_retries: int
The number of times the pattern can trigger a restart for a Task.
taskhub_sk: str
The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key.
"""

pattern: str
max_retries: int
taskhub_sk: str

def __init__(
self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey]
):

if not isinstance(pattern, str) or pattern == "":
raise ValueError("`pattern` must be a non-empty string")

self.pattern = pattern

if not isinstance(max_retries, int) or max_retries <= 0:
raise ValueError("`max_retries` must have a positive integer value.")
self.max_retries = max_retries

self.taskhub_scoped_key = str(taskhub_scoped_key)

def _gufe_tokenize(self):
key_string = self.pattern + self.taskhub_scoped_key
return hashlib.md5(key_string.encode()).hexdigest()

@classmethod
def _defaults(cls):
raise NotImplementedError

@classmethod
def _from_dict(cls, dct):
return cls(**dct)

def _to_dict(self):
return {
"pattern": self.pattern,
"max_retries": self.max_retries,
"taskhub_scoped_key": self.taskhub_scoped_key,
}

# TODO: should this also compare taskhub scoped keys?
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.pattern == other.pattern
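The `_gufe_tokenize` method above derives the pattern's key from the pattern string concatenated with its TaskHub scoped key, so the same pattern bound to two different TaskHubs yields distinct keys. A standalone sketch of that hashing scheme (the helper name is illustrative):

```python
import hashlib


def pattern_token(pattern: str, taskhub_scoped_key: str) -> str:
    # Deterministic key: md5 over the pattern plus its TaskHub scoped key,
    # mirroring TaskRestartPattern._gufe_tokenize above.
    return hashlib.md5((pattern + taskhub_scoped_key).encode()).hexdigest()


a = pattern_token(r".*OutOfMemoryError.*", "TaskHub-1")
b = pattern_token(r".*OutOfMemoryError.*", "TaskHub-2")
```

Here `a != b` even though the regex is identical, while repeated calls with the same arguments always reproduce the same 32-character hex digest.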



class Tracebacks(GufeTokenizable):
"""Tracebacks returned by failed ProtocolUnits, along with their source and failure keys.

Attributes
----------
tracebacks: list[str]
The tracebacks returned with the ProtocolUnitFailures.
source_keys: list[ScopedKey]
The ScopedKeys of the ProtocolUnits that failed.
failure_keys: list[ScopedKey]
The ScopedKeys of the ProtocolUnitFailures.
"""

def __init__(
self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str]
):
value_error = ValueError(
"`tracebacks` must be a non-empty list of string values"
)
if not isinstance(tracebacks, list) or tracebacks == []:
raise value_error
else:
# tracebacks is known to be a list here; verify every entry is a non-empty string
all_string_values = all([isinstance(value, str) for value in tracebacks])
if not all_string_values or "" in tracebacks:
raise value_error

# TODO: validate
self.tracebacks = tracebacks
self.source_keys = source_keys
self.failure_keys = failure_keys

@classmethod
def _defaults(cls):
return super()._defaults()

@classmethod
def _from_dict(cls, dct):
return cls(**dct)

def _to_dict(self):
return {
"tracebacks": self.tracebacks,
"source_keys": self.source_keys,
"failure_keys": self.failure_keys,
}
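The constructor above accepts `tracebacks` only as a non-empty list of non-empty strings, raising the same `ValueError` for every malformed input. A standalone sketch of that check, with the helper name invented for illustration:

```python
def validate_tracebacks(tracebacks) -> None:
    """Raise ValueError unless `tracebacks` is a non-empty list of non-empty strings."""
    error = ValueError("`tracebacks` must be a non-empty list of string values")
    if not isinstance(tracebacks, list) or tracebacks == []:
        raise error
    # every entry must be a string, and none may be empty
    if not all(isinstance(value, str) for value in tracebacks) or "" in tracebacks:
        raise error
```

Under this rule, `[]`, `[""]`, a bare string, `None`, and lists containing non-string values are all rejected, while any list of non-empty strings passes silently.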


class TaskHub(GufeTokenizable):
"""
