Skip to content

Commit

Permalink
Merge pull request #286 from OpenFreeEnergy/feature/iss-277-restart-p…
Browse files Browse the repository at this point in the history
…olicy_resolve_restarts

Restart policy: resolve restarts
  • Loading branch information
ianmkenney authored Sep 19, 2024
2 parents 2310fd5 + 7a4b114 commit cf0e961
Show file tree
Hide file tree
Showing 7 changed files with 665 additions and 52 deletions.
7 changes: 6 additions & 1 deletion alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi.middleware.gzip import GZipMiddleware
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
from gufe.protocols import ProtocolDAGResult

from ..base.api import (
QueryGUFEHandler,
Expand Down Expand Up @@ -248,7 +249,7 @@ def set_task_result(
validate_scopes(task_sk.scope, token)

pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr = GufeTokenizable.from_dict(pdr)
pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr)

tf_sk, _ = n4js.get_task_transformation(
task=task_scoped_key,
Expand All @@ -270,7 +271,11 @@ def set_task_result(
if protocoldagresultref.ok:
n4js.set_task_complete(tasks=[task_sk])
else:
n4js.add_protocol_dag_result_ref_tracebacks(
pdr.protocol_unit_failures, result_sk
)
n4js.set_task_error(tasks=[task_sk])
n4js.resolve_task_restarts(tasks=[task_sk])

return result_sk

Expand Down
22 changes: 14 additions & 8 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ def __eq__(self, other):
return self.pattern == other.pattern


class Traceback(GufeTokenizable):
class Tracebacks(GufeTokenizable):

def __init__(self, tracebacks: List[str]):
def __init__(
self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str]
):
value_error = ValueError(
"`tracebacks` must be a non-empty list of string values"
)
Expand All @@ -216,21 +218,25 @@ def __init__(self, tracebacks: List[str]):
if not all_string_values or "" in tracebacks:
raise value_error

# TODO: validate
self.tracebacks = tracebacks

def _gufe_tokenize(self):
return hashlib.md5(str(self.tracebacks).encode()).hexdigest()
self.source_keys = source_keys
self.failure_keys = failure_keys

@classmethod
def _defaults(cls):
raise NotImplementedError
return super()._defaults()

@classmethod
def _from_dict(cls, dct):
return Traceback(**dct)
return cls(**dct)

def _to_dict(self):
return {"tracebacks": self.tracebacks}
return {
"tracebacks": self.tracebacks,
"source_keys": self.source_keys,
"failure_keys": self.failure_keys,
}


class TaskHub(GufeTokenizable):
Expand Down
Loading

0 comments on commit cf0e961

Please sign in to comment.