From ea2851f799ca5ad9389c3fde6eac48ae7f411f17 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Sun, 4 Aug 2024 15:23:38 -0700 Subject: [PATCH 01/20] Added method (unimplemented) calls for restarts New statestore method placeholders: - add_task_traceback - resolve_task_restarts The compute api will add a Task Traceback and resolve restarts for returned failed Tasks. When a list of restart patterns is added, restarts are resolved. --- alchemiscale/compute/api.py | 5 ++++- alchemiscale/storage/statestore.py | 33 +++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py index db21d5b8..df4844c8 100644 --- a/alchemiscale/compute/api.py +++ b/alchemiscale/compute/api.py @@ -12,6 +12,7 @@ from fastapi import FastAPI, APIRouter, Body, Depends from fastapi.middleware.gzip import GZipMiddleware from gufe.tokenization import GufeTokenizable, JSON_HANDLER +from gufe.protocols import ProtocolDAGResult from ..base.api import ( QueryGUFEHandler, @@ -248,7 +249,7 @@ def set_task_result( validate_scopes(task_sk.scope, token) pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder) - pdr = GufeTokenizable.from_dict(pdr) + pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr) tf_sk, _ = n4js.get_task_transformation( task=task_scoped_key, @@ -270,7 +271,9 @@ def set_task_result( if protocoldagresultref.ok: n4js.set_task_complete(tasks=[task_sk]) else: + n4js.add_task_traceback(task_sk, pdr.protocol_unit_failures, result_sk) n4js.set_task_error(tasks=[task_sk]) + n4js.resolve_task_restarts(tasks=[task_sk]) return result_sk diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 07d05d02..bf1bf6fc 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -16,6 +16,7 @@ import networkx as nx from gufe import AlchemicalNetwork, Transformation, NonTransformation, Settings from gufe.tokenization import GufeTokenizable, 
GufeKey, JSON_HANDLER +from gufe.protocols import ProtocolUnitFailure from neo4j import Transaction, GraphDatabase, Driver @@ -2416,6 +2417,14 @@ def get_task_failures(self, task: ScopedKey) -> List[ProtocolDAGResultRef]: """ return self._get_protocoldagresultrefs(q, task) + def add_task_traceback( + self, + task_scoped_key: ScopedKey, + protocol_unit_failures: List[ProtocolUnitFailure], + protocol_dag_result_ref_scoped_key: ScopedKey, + ): + raise NotImplementedError + def set_task_status( self, tasks: List[ScopedKey], status: TaskStatusEnum, raise_error: bool = False ) -> List[Optional[ScopedKey]]: @@ -2778,15 +2787,17 @@ def add_task_restart_patterns( RETURN task """ + actioned_task_records = ( + tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub)) + .to_eager_result() + .records + ) + subgraph = Subgraph() actioned_task_nodes = [] - for actioned_tasks_record in ( - tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub)) - .to_eager_result() - .records - ): + for actioned_tasks_record in actioned_task_records: actioned_task_nodes.append( record_data_to_node(actioned_tasks_record["task"]) ) @@ -2821,6 +2832,15 @@ def add_task_restart_patterns( ) merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key") + actioned_task_scoped_keys: List[ScopedKey] = [] + + for actioned_task_record in actioned_task_records: + actioned_task_scoped_keys.append( + ScopedKey(actioned_task_record["task"]["_scoped_key"]) + ) + + self.resolve_task_restarts(actioned_task_scoped_keys) + # TODO: fill in docstring def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]): q = """ @@ -2878,6 +2898,9 @@ def get_task_restart_patterns( return data + def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey]): + raise NotImplementedError + ## authentication def create_credentialed_entity(self, entity: CredentialedEntity): From 8e011beccce70dc4fcac041c36dc673d087990b0 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 5 Aug 2024 10:49:57 -0700 
Subject: [PATCH 02/20] Implemented add_protocol_dag_result_ref_traceback * Renamed add_task_traceback to add_protocol_dag_result_ref_traceback * Added tests for add_protocol_dag_result_ref_traceback --- alchemiscale/compute/api.py | 4 +- alchemiscale/storage/statestore.py | 44 +++++++++++++- .../integration/storage/test_statestore.py | 58 +++++++++++++++++++ 3 files changed, 102 insertions(+), 4 deletions(-) diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py index df4844c8..a50f6d93 100644 --- a/alchemiscale/compute/api.py +++ b/alchemiscale/compute/api.py @@ -271,7 +271,9 @@ def set_task_result( if protocoldagresultref.ok: n4js.set_task_complete(tasks=[task_sk]) else: - n4js.add_task_traceback(task_sk, pdr.protocol_unit_failures, result_sk) + n4js.add_protocol_dag_result_ref_traceback( + pdr.protocol_unit_failures, result_sk + ) n4js.set_task_error(tasks=[task_sk]) n4js.resolve_task_restarts(tasks=[task_sk]) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index bf1bf6fc..f1902421 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -30,6 +30,7 @@ TaskHub, TaskRestartPattern, TaskStatusEnum, + Traceback, ) from ..strategies import Strategy from ..models import Scope, ScopedKey @@ -2417,13 +2418,50 @@ def get_task_failures(self, task: ScopedKey) -> List[ProtocolDAGResultRef]: """ return self._get_protocoldagresultrefs(q, task) - def add_task_traceback( + def add_protocol_dag_result_ref_traceback( self, - task_scoped_key: ScopedKey, protocol_unit_failures: List[ProtocolUnitFailure], protocol_dag_result_ref_scoped_key: ScopedKey, ): - raise NotImplementedError + subgraph = Subgraph() + + with self.transaction() as tx: + + query = """ + MATCH (pdrr:ProtocolDAGResultRef {`_scoped_key`: $protocol_dag_result_ref_scoped_key}) + RETURN pdrr + """ + + pdrr_result = tx.run( + query, + protocol_dag_result_ref_scoped_key=str( + protocol_dag_result_ref_scoped_key + ), + 
).to_eager_result() + + try: + protocol_dag_result_ref_node = record_data_to_node( + pdrr_result.records[0]["pdrr"] + ) + except IndexError: + raise KeyError("Could not find ProtocolDAGResultRef in database.") + + tracebacks = list(map(lambda puf: puf.traceback, protocol_unit_failures)) + traceback = Traceback(tracebacks) + + _, traceback_node, _ = self._gufe_to_subgraph( + traceback.to_shallow_dict(), + labels=["GufeTokenizable", traceback.__class__.__name__], + gufe_key=traceback.key, + scope=protocol_dag_result_ref_scoped_key.scope, + ) + + subgraph |= Relationship.type("DETAILS")( + traceback_node, + protocol_dag_result_ref_node, + ) + + merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key") def set_task_status( self, tasks: List[ScopedKey], status: TaskStatusEnum, raise_error: bool = False diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index c7901840..94d6e110 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -1944,6 +1944,64 @@ def test_get_task_failures( assert pdr_ref_sk in failure_pdr_ref_sks assert pdr_ref2_sk in failure_pdr_ref_sks + @pytest.mark.parametrize("failure_count", (1, 2, 3, 4)) + def test_add_protocol_dag_result_ref_traceback( + self, + network_tyk2_failure, + n4js, + scope_test, + transformation_failure, + protocoldagresults_failure, + failure_count: int, + ): + + an = network_tyk2_failure.copy_with_replacements( + name=network_tyk2_failure.name + + "_test_add_protocol_dag_result_ref_traceback" + ) + n4js.assemble_network(an, scope_test) + transformation_scoped_key = n4js.get_scoped_key( + transformation_failure, scope_test + ) + + # create a task; pretend we computed it, submit reference for pre-baked + # result + task_scoped_key = n4js.create_task(transformation_scoped_key) + + protocol_unit_failure = protocoldagresults_failure[0].protocol_unit_failures[0] + + 
pdrr = ProtocolDAGResultRef( + scope=task_scoped_key.scope, + obj_key=protocoldagresults_failure[0].key, + ok=protocoldagresults_failure[0].ok(), + ) + + # push the result + pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr) + + protocol_unit_failures = [] + for failure_index in range(failure_count): + protocol_unit_failures.append( + protocol_unit_failure.copy_with_replacements( + traceback=protocol_unit_failure.traceback + "_" + str(failure_index) + ) + ) + + n4js.add_protocol_dag_result_ref_traceback( + protocol_unit_failures, pdrr_scoped_key + ) + + query = """ + MATCH (traceback:Traceback)-[:DETAILS]->(:ProtocolDAGResultRef {`_scoped_key`: $pdrr_scoped_key}) + RETURN traceback + """ + + results = n4js.execute_query(query, pdrr_scoped_key=str(pdrr_scoped_key)) + + returned_tracebacks = results.records[0]["traceback"]["tracebacks"] + + assert returned_tracebacks == [puf.traceback for puf in protocol_unit_failures] + ### task restart policies class TestTaskRestartPolicy: From 4f07dde8303b334c797a6c68d12e31fa445b2197 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 6 Aug 2024 07:17:37 -0700 Subject: [PATCH 03/20] Started implementation of restart resolution --- alchemiscale/storage/statestore.py | 15 +++++++ .../integration/storage/test_statestore.py | 41 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index f1902421..2b847539 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2937,6 +2937,21 @@ def get_task_restart_patterns( return data def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey]): + + query = """ + UNWIND $task_scoped_keys AS task_scoped_key + MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern) + CALL { + WITH task + OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(traceback:Traceback) + RETURN traceback + 
ORDER BY pdrr.date DESCENDING + LIMIT 1 + } + WITH traceback + RETURN task, app, trp, traceback + """ + raise NotImplementedError ## authentication diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 94d6e110..c5e90fe4 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -2283,6 +2283,47 @@ def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test): assert taskhub_grouped_patterns == expected_results + @pytest.mark.xfail(raises=NotImplementedError) + def test_resolve_task_restarts( + self, + n4js, + network_tyk2_failure, + scope_test, + transformation_failure, + protocoldagresults_failure, + ): + + an = network_tyk2_failure.copy_with_replacements( + name=network_tyk2_failure.name + + "_test_add_protocol_dag_result_ref_traceback" + ) + n4js.assemble_network(an, scope_test) + transformation_scoped_key = n4js.get_scoped_key( + transformation_failure, scope_test + ) + + # create a task; pretend we computed it, submit reference for pre-baked + # result + task_scoped_key = n4js.create_task(transformation_scoped_key) + + protocol_unit_failure = protocoldagresults_failure[ + 0 + ].protocol_unit_failures[0] + + from datetime import datetime + + for index in range(5): + pdrr = ProtocolDAGResultRef( + scope=task_scoped_key.scope, + obj_key=protocoldagresults_failure[0].key, + ok=protocoldagresults_failure[0].ok(), + datetime_created=datetime.utcnow(), + ) + + pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr) + + raise NotImplementedError + @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): raise NotImplementedError From 78c45518293fd542ac6e17601be9a4a955ac4535 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 7 Aug 2024 15:14:30 -0700 Subject: [PATCH 04/20] Tracebacks now include key data from its source units --- 
alchemiscale/storage/models.py | 13 +++++++++++-- alchemiscale/storage/statestore.py | 12 ++++++++---- .../tests/integration/storage/test_statestore.py | 1 + 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py index 3dc69e0d..7fee8156 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -204,7 +204,9 @@ def __eq__(self, other): class Traceback(GufeTokenizable): - def __init__(self, tracebacks: List[str]): + def __init__( + self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str] + ): value_error = ValueError( "`tracebacks` must be a non-empty list of string values" ) @@ -216,7 +218,10 @@ def __init__(self, tracebacks: List[str]): if not all_string_values or "" in tracebacks: raise value_error + # TODO: validate self.tracebacks = tracebacks + self.source_keys = source_keys + self.failure_keys = failure_keys def _gufe_tokenize(self): return hashlib.md5(str(self.tracebacks).encode()).hexdigest() @@ -230,7 +235,11 @@ def _from_dict(cls, dct): return Traceback(**dct) def _to_dict(self): - return {"tracebacks": self.tracebacks} + return { + "tracebacks": self.tracebacks, + "source_keys": self.source_keys, + "failure_keys": self.failure_keys, + } class TaskHub(GufeTokenizable): diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 2b847539..5d3da98e 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2447,7 +2447,9 @@ def add_protocol_dag_result_ref_traceback( raise KeyError("Could not find ProtocolDAGResultRef in database.") tracebacks = list(map(lambda puf: puf.traceback, protocol_unit_failures)) - traceback = Traceback(tracebacks) + source_keys = list(map(lambda puf: puf.source_key, protocol_unit_failures)) + failure_keys = list(map(lambda puf: puf.key, protocol_unit_failures)) + traceback = Traceback(tracebacks, source_keys, failure_keys) _, traceback_node, _ = 
self._gufe_to_subgraph( traceback.to_shallow_dict(), @@ -2877,7 +2879,7 @@ def add_task_restart_patterns( ScopedKey(actioned_task_record["task"]["_scoped_key"]) ) - self.resolve_task_restarts(actioned_task_scoped_keys) + self.resolve_task_restarts(actioned_task_scoped_keys, transaction=tx) # TODO: fill in docstring def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]): @@ -2936,7 +2938,9 @@ def get_task_restart_patterns( return data - def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey]): + def resolve_task_restarts( + self, task_scoped_keys: List[ScopedKey], transaction=None + ): query = """ UNWIND $task_scoped_keys AS task_scoped_key @@ -2945,7 +2949,7 @@ def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey]): WITH task OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(traceback:Traceback) RETURN traceback - ORDER BY pdrr.date DESCENDING + ORDER BY pdrr.datetime_created DESCENDING LIMIT 1 } WITH traceback diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index c5e90fe4..bc5b0f50 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -1979,6 +1979,7 @@ def test_add_protocol_dag_result_ref_traceback( # push the result pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr) + # simulating many failures protocol_unit_failures = [] for failure_index in range(failure_count): protocol_unit_failures.append( From 7acc0036039324c86d2348908d990c8e121775f8 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 13 Aug 2024 08:57:16 -0700 Subject: [PATCH 05/20] Built out custom fixture for testing restart policies Implemented half of the resolve_task_restarts test --- alchemiscale/storage/statestore.py | 8 +- alchemiscale/tests/integration/conftest.py | 40 ++++++ .../tests/integration/interface/conftest.py | 34 +++++ 
.../integration/storage/test_statestore.py | 134 ++++++++++++++---- 4 files changed, 189 insertions(+), 27 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 5d3da98e..9255c988 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2876,10 +2876,13 @@ def add_task_restart_patterns( for actioned_task_record in actioned_task_records: actioned_task_scoped_keys.append( - ScopedKey(actioned_task_record["task"]["_scoped_key"]) + ScopedKey.from_str(actioned_task_record["task"]["_scoped_key"]) ) - self.resolve_task_restarts(actioned_task_scoped_keys, transaction=tx) + try: + self.resolve_task_restarts(actioned_task_scoped_keys, transaction=tx) + except NotImplementedError: + pass # TODO: fill in docstring def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]): @@ -2914,6 +2917,7 @@ def set_task_restart_patterns_max_retries( ) # TODO: fill in docstring + # TODO: validation of taskhubs variable, will fail in weird ways if not enforced def get_task_restart_patterns( self, taskhubs: List[ScopedKey] ) -> Dict[ScopedKey, Set[Tuple[str, int]]]: diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py index 1875981e..026f5866 100644 --- a/alchemiscale/tests/integration/conftest.py +++ b/alchemiscale/tests/integration/conftest.py @@ -167,6 +167,46 @@ def n4js(graph): return Neo4jStore(graph) +@fixture +def n4js_task_restart_policy( + n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test +): + + n4js = n4js_fresh + + _, taskhub_scoped_key_with_policy, _ = n4js.assemble_network( + network_tyk2, scope_test + ) + + _, taskhub_scoped_key_no_policy, _ = n4js.assemble_network( + network_tyk2.copy_with_replacements(name=network_tyk2.name + "_no_policy"), + scope_test, + ) + + transformation_1_scoped_key, transformation_2_scoped_key = map( + lambda transformation: n4js.get_scoped_key(transformation, scope_test), + 
list(network_tyk2.edges)[:2], + ) + + task_scoped_keys = n4js.create_tasks( + [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 4 + ) + + assert all(n4js.action_tasks(task_scoped_keys[:4], taskhub_scoped_key_no_policy)) + assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) + + patterns = [ + "This is an example pattern that will be used as a restart string. 1", + "This is an example pattern that will be used as a restart string. 2", + ] + + n4js.add_task_restart_patterns( + taskhub_scoped_key_with_policy, patterns=patterns, number_of_retries=2 + ) + + return n4js + + @fixture def n4js_fresh(graph): n4js = Neo4jStore(graph) diff --git a/alchemiscale/tests/integration/interface/conftest.py b/alchemiscale/tests/integration/interface/conftest.py index 2eb2c996..b24332e4 100644 --- a/alchemiscale/tests/integration/interface/conftest.py +++ b/alchemiscale/tests/integration/interface/conftest.py @@ -89,6 +89,40 @@ def n4js_preloaded( return n4js +from alchemiscale.storage.statestore import Neo4jStore + + +@pytest.fixture +def n4js_task_restart_policy( + n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test +): + + n4js = n4js_fresh + + _, taskhub_scoped_key_with_policy, _ = n4js.assemble_network( + network_tyk2, scope_test + ) + + _, taskhub_scoped_key_no_policy, _ = n4js.assemble_network( + network_tyk2.copy_with_replacements(name=network_tyk2.name + "_no_policy"), + scope_test, + ) + + transformation_1_scoped_key, transformation_2_scoped_key = map( + lambda transformation: n4js.get_scoped_key(transformation, scope_test), + network_tyk2.edges[:2], + ) + + task_scoped_keys = n4js.create_tasks( + [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 4 + ) + + assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_no_policy)) + assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) + + breakpoint() + + @pytest.fixture(scope="module") def 
scope_consistent_token_data_depends_override(scope_test): """Make a consistent helper to provide an override to the api.app while still accessing fixtures""" diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index bc5b0f50..0e08dc02 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -2,12 +2,15 @@ import random from typing import List, Dict from pathlib import Path +from functools import reduce from itertools import chain +from operator import and_ from collections import defaultdict import pytest from gufe import AlchemicalNetwork from gufe.tokenization import TOKENIZABLE_REGISTRY +from gufe.protocols import ProtocolUnitFailure from gufe.protocols.protocoldag import execute_DAG from alchemiscale.storage.statestore import Neo4jStore @@ -2287,43 +2290,124 @@ def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test): @pytest.mark.xfail(raises=NotImplementedError) def test_resolve_task_restarts( self, - n4js, - network_tyk2_failure, - scope_test, - transformation_failure, - protocoldagresults_failure, + scope_test: Scope, + n4js_task_restart_policy: Neo4jStore, ): - an = network_tyk2_failure.copy_with_replacements( - name=network_tyk2_failure.name - + "_test_add_protocol_dag_result_ref_traceback" + def spoof_failure(): + raise NotImplementedError + + # get the actioned tasks for each taskhub + taskhub_actioned_tasks = {} + for taskhub_scoped_key in n4js_task_restart_policy.query_taskhubs(): + taskhub_actioned_tasks[taskhub_scoped_key] = set( + n4js_task_restart_policy.get_taskhub_actioned_tasks( + [taskhub_scoped_key] + )[0] + ) + + restart_patterns = n4js_task_restart_policy.get_task_restart_patterns( + list(taskhub_actioned_tasks.keys()) ) - n4js.assemble_network(an, scope_test) - transformation_scoped_key = n4js.get_scoped_key( - transformation_failure, scope_test + + 
transformation_tasks = defaultdict(list) + for task in n4js_task_restart_policy.query_tasks( + status=TaskStatusEnum.waiting.value + ): + transformation_scoped_key, _ = ( + n4js_task_restart_policy.get_task_transformation( + task, return_gufe=False + ) + ) + transformation_tasks[transformation_scoped_key].append(task) + + # get a list of all tasks for more convient calls of the resolve method + all_tasks = [] + for task_group in transformation_tasks.values(): + all_tasks.extend(task_group) + + taskhub_scoped_key_no_policy = None + taskhub_scoped_key_with_policy = None + + for taskhub_scoped_key, patterns in restart_patterns.items(): + if not patterns: + taskhub_scoped_key_no_policy = taskhub_scoped_key + continue + else: + taskhub_scoped_key_with_policy = taskhub_scoped_key + continue + + if patterns and taskhub_scoped_key_with_policy: + raise AssertionError("More than one TaskHub has restart patterns") + + assert ( + taskhub_scoped_key_no_policy + and taskhub_scoped_key_with_policy + and (taskhub_scoped_key_no_policy != taskhub_scoped_key_with_policy) ) - # create a task; pretend we computed it, submit reference for pre-baked - # result - task_scoped_key = n4js.create_task(transformation_scoped_key) + # we first check the behavior involving tasks that are actioned by both taskhubs + # this involves confirming: + # + # 1. Completed Tasks do not have an actions relationship with either TaskHub + # 2. A Task entering the error state is switched back to waiting if any restart patterns apply + # 3. A Task entering the error state is left in the error state if no patterns apply and only the TaskHub with + # an enforcing task restart policy exists + # + # Tasks will be set to the error state with a spoofing method, which will create a fake ProtocolDAGResultRef + # and Traceback. This is done since making a protocol fail systematically in the testing environment is not + # obvious at this time. 
+ + # reduce down all tasks until only the common elements between taskhubs exist + tasks_actioned_by_all_taskhubs: List[ScopedKey] = list( + reduce(and_, taskhub_actioned_tasks.values(), set(all_tasks)) + ) - protocol_unit_failure = protocoldagresults_failure[ - 0 - ].protocol_unit_failures[0] + assert len(tasks_actioned_by_all_taskhubs) == 4 - from datetime import datetime + # we're going to just pass the first 2 and fail the second 2 + tasks_to_complete = tasks_actioned_by_all_taskhubs[:2] + tasks_to_fail = tasks_actioned_by_all_taskhubs[3:] - for index in range(5): - pdrr = ProtocolDAGResultRef( - scope=task_scoped_key.scope, - obj_key=protocoldagresults_failure[0].key, - ok=protocoldagresults_failure[0].ok(), + # TODO: either check the results after the loop or within it, whichever makes more sense + for task in tasks_to_complete: + n4js_task_restart_policy.set_task_running([task]) + ok_pdrr = ProtocolDAGResultRef( + ok=True, datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, ) - pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr) + _ = n4js_task_restart_policy.set_task_result(task, ok_pdrr) - raise NotImplementedError + # this should do nothing to the database state since all + # relationships are removed in the previous method call + # TODO: perhaps counts of the connections will be a good test + n4js_task_restart_policy.set_task_complete([task]) + + # TODO: it's unclear the best way to fake a systematic error here + for i, task in enumerate(tasks_to_fail): + n4js_task_restart_policy.set_task_running([task]) + + not_ok_pdrr = ProtocolDAGResultRef( + ok=False, + datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, + ) + + error_messages = ( + "Error message 1", + "Error message 2", + "Error message 3", + ) + + n4js_task_restart_policy.add_protocol_dag_result_ref_traceback() + n4js_task_restart_policy.set_task_error([task]) + + # always feed in all tasks to test for side effects + 
n4js_task_restart_policy.resolve_task_restarts(all_tasks) @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): From 03d9fa1c0218676558f4aa4fd998a6d7a5447f96 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 19 Aug 2024 10:29:43 -0700 Subject: [PATCH 06/20] Added the `chainable` decorator to Neo4jStore With this decorator, if a transaction isn't passed as a keyword arg, one is automatically created (and closed). This allows a chaining behavior where many method calls share a single transaction object. --- alchemiscale/storage/statestore.py | 119 +++++++++++++----- .../tests/integration/interface/conftest.py | 2 - 2 files changed, 87 insertions(+), 34 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 9255c988..7a38e882 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -8,6 +8,7 @@ from datetime import datetime from contextlib import contextmanager import json +import re from functools import lru_cache from typing import Dict, List, Optional, Union, Tuple, Set import weakref @@ -175,6 +176,17 @@ def transaction(self, ignore_exceptions=False) -> Transaction: else: tx.commit() + def chainable(func): + def inner(self, *args, **kwargs): + if kwargs.get("tx") is not None: + return func(self, *args, **kwargs) + + with self.transaction() as tx: + kwargs.update(tx=tx) + return func(self, *args, **kwargs) + + return inner + def execute_query(self, *args, **kwargs): kwargs.update({"database_": self.db_name}) return self.graph.execute_query(*args, **kwargs) @@ -1590,10 +1602,12 @@ def get_task_weights( return weights + @chainable def cancel_tasks( self, tasks: List[ScopedKey], taskhub: ScopedKey, + tx=None, ) -> List[Union[ScopedKey, None]]: """Remove Tasks from the TaskHub for a given AlchemicalNetwork. 
@@ -1604,31 +1618,30 @@ def cancel_tasks( """ canceled_sks = [] - with self.transaction() as tx: - for task in tasks: - query = """ - // get our task hub, as well as the task :ACTIONS relationship we want to remove - MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key}) - DELETE ar + for task in tasks: + query = """ + // get our task hub, as well as the task :ACTIONS relationship we want to remove + MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key}) + DELETE ar + WITH task + CALL { WITH task - CALL { - WITH task - MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern) - DELETE applies - } + MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern) + DELETE applies + } - RETURN task - """ - _task = tx.run( - query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task) - ).to_eager_result() + RETURN task + """ + _task = tx.run( + query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task) + ).to_eager_result() - if _task.records: - sk = _task.records[0].data()["task"]["_scoped_key"] - canceled_sks.append(ScopedKey.from_str(sk)) - else: - canceled_sks.append(None) + if _task.records: + sk = _task.records[0].data()["task"]["_scoped_key"] + canceled_sks.append(ScopedKey.from_str(sk)) + else: + canceled_sks.append(None) return canceled_sks @@ -2879,10 +2892,7 @@ def add_task_restart_patterns( ScopedKey.from_str(actioned_task_record["task"]["_scoped_key"]) ) - try: - self.resolve_task_restarts(actioned_task_scoped_keys, transaction=tx) - except NotImplementedError: - pass + self.resolve_task_restarts(actioned_task_scoped_keys, tx=tx) # TODO: fill in docstring def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]): @@ -2942,13 +2952,12 @@ def get_task_restart_patterns( return data - def resolve_task_restarts( - self, task_scoped_keys: List[ScopedKey], transaction=None - ): + @chainable + def resolve_task_restarts(self, 
task_scoped_keys: List[ScopedKey], *, tx=None): query = """ UNWIND $task_scoped_keys AS task_scoped_key - MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern) + MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) CALL { WITH task OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(traceback:Traceback) @@ -2956,11 +2965,57 @@ def resolve_task_restarts( ORDER BY pdrr.datetime_created DESCENDING LIMIT 1 } - WITH traceback - RETURN task, app, trp, traceback + WITH task, traceback, trp, app, taskhub + RETURN task, traceback, trp, app, taskhub """ - raise NotImplementedError + results = tx.run( + query, + task_scoped_keys=list(map(str, task_scoped_keys)), + error=TaskStatusEnum.error.value, + ).to_eager_result() + + if not results: + return + + to_increment: List[Tuple[str, str]] = [] + to_cancel: List[Tuple[str, str]] = [] + for record in results.records: + task_restart_pattern = record["trp"] + applies_relationship = record["app"] + task = record["task"] + taskhub = record["taskhub"] + # TODO: what happens if there is no traceback? i.e. 
older errored tasks + traceback = record["traceback"] + + num_retries = applies_relationship["num_retries"] + max_retries = task_restart_pattern["max_retries"] + pattern = task_restart_pattern["pattern"] + tracebacks: List[str] = traceback["tracebacks"] + + # exit early if we already know a task is being canceled on a TaskHub + if (task["_scoped_key"], taskhub["_scoped_key"]) in to_cancel: + continue + + # we will always increment (even above the max_retries) and + # cancel later + to_increment.append( + (task["_scoped_key"], task_restart_pattern["_scoped_key"]) + ) + if any([re.search(pattern, message) for message in tracebacks]): + if num_retries + 1 > max_retries: + to_cancel.append((task["_scoped_key"], taskhub["_scoped_key"])) + + increment_query = """ + UNWIND $trp_and_task_pairs as pairs + WITH pairs[0] as task_scoped_key, pairs[1] as task_restart_pattern_scoped_key + MATCH (:Task {`_scoped_key`: task_scoped_key})<-[app:APPLIES]-(:TaskRestartPattern {`_scoped_key`: task_restart_pattern_scoped_key}) + SET app.num_retries = app.num_retries + 1 + """ + + tx.run(increment_query, trp_and_task_pairs=to_increment) + for task, taskhub in to_cancel: + self.cancel_tasks([task], taskhub, tx=tx) ## authentication diff --git a/alchemiscale/tests/integration/interface/conftest.py b/alchemiscale/tests/integration/interface/conftest.py index b24332e4..d7c5da6a 100644 --- a/alchemiscale/tests/integration/interface/conftest.py +++ b/alchemiscale/tests/integration/interface/conftest.py @@ -120,8 +120,6 @@ def n4js_task_restart_policy( assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_no_policy)) assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) - breakpoint() - @pytest.fixture(scope="module") def scope_consistent_token_data_depends_override(scope_test): From aad97e3c0c6438d1dff3fd1b21623a5351b888e4 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 19 Aug 2024 11:57:13 -0700 Subject: [PATCH 07/20] Resolve task restarts now sets 
all remaining tasks to waiting --- alchemiscale/storage/statestore.py | 25 +++++++++- .../integration/storage/test_statestore.py | 50 +++++++++++++++---- 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 7a38e882..b6c05966 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -11,6 +11,7 @@ import re from functools import lru_cache from typing import Dict, List, Optional, Union, Tuple, Set +from collections.abc import Iterable import weakref import numpy as np @@ -2952,9 +2953,13 @@ def get_task_restart_patterns( return data + # TODO: docstrings @chainable - def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey], *, tx=None): + def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=None): + # Given the scoped keys of a list of Tasks, find all tasks that have an + # error status and have a TaskRestartPattern applied. A subquery is executed + # to optionally get the latest traceback associated with the task query = """ UNWIND $task_scoped_keys AS task_scoped_key MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) @@ -2978,6 +2983,9 @@ def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey], *, tx=None): if not results: return + # iterate over all of the results to determine if an applied pattern needs + # to be iterated or if the task needs to be cancelled outright + to_increment: List[Tuple[str, str]] = [] to_cancel: List[Tuple[str, str]] = [] for record in results.records: @@ -3017,6 +3025,21 @@ def resolve_task_restarts(self, task_scoped_keys: List[ScopedKey], *, tx=None): for task, taskhub in to_cancel: self.cancel_tasks([task], taskhub, tx=tx) + # any remaining tasks must then be okay to switch to waiting + + renew_waiting_status_query = """ + UNWIND $task_scoped_keys AS task_scoped_key + MATCH (task:Task 
{status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) + SET task.status = $waiting + """ + + tx.run( + renew_waiting_status_query, + task_scoped_keys=list(map(str, task_scoped_keys)), + waiting=TaskStatusEnum.waiting.value, + error=TaskStatusEnum.error.value, + ) + ## authentication def create_credentialed_entity(self, entity: CredentialedEntity): diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 0e08dc02..64993239 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -2294,9 +2294,6 @@ def test_resolve_task_restarts( n4js_task_restart_policy: Neo4jStore, ): - def spoof_failure(): - raise NotImplementedError - # get the actioned tasks for each taskhub taskhub_actioned_tasks = {} for taskhub_scoped_key in n4js_task_restart_policy.query_taskhubs(): @@ -2367,7 +2364,7 @@ def spoof_failure(): # we're going to just pass the first 2 and fail the second 2 tasks_to_complete = tasks_actioned_by_all_taskhubs[:2] - tasks_to_fail = tasks_actioned_by_all_taskhubs[3:] + tasks_to_fail = tasks_actioned_by_all_taskhubs[2:] # TODO: either check the results after the loop or within it, whichever makes more sense for task in tasks_to_complete: @@ -2386,7 +2383,6 @@ def spoof_failure(): # TODO: perhaps counts of the connections will be a good test n4js_task_restart_policy.set_task_complete([task]) - # TODO: it's unclear the best way to fake a systematic error here for i, task in enumerate(tasks_to_fail): n4js_task_restart_policy.set_task_running([task]) @@ -2397,18 +2393,50 @@ def spoof_failure(): scope=task.scope, ) - error_messages = ( - "Error message 1", - "Error message 2", - "Error message 3", - ) + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] + + protocol_unit_failures = [] + for j, message in 
enumerate(error_messages): + puf = ProtocolUnitFailure( + source_key=f"FakeProtocolUnitKey-123{j}", + inputs={}, + outputs={}, + exception=RuntimeError, + traceback=message, + ) + protocol_unit_failures.append(puf) - n4js_task_restart_policy.add_protocol_dag_result_ref_traceback() + pdrr_scoped_key = n4js_task_restart_policy.set_task_result( + task, not_ok_pdrr + ) + # the following mimics what the compute API would do for a failed task + n4js_task_restart_policy.add_protocol_dag_result_ref_traceback( + protocol_unit_failures, pdrr_scoped_key + ) n4js_task_restart_policy.set_task_error([task]) # always feed in all tasks to test for side effects n4js_task_restart_policy.resolve_task_restarts(all_tasks) + # both tasks should have the waiting status and the APPLIES + # relationship num_retries should have incremented by 1 + + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {`_scoped_key`: task_scoped_key, status: $waiting})<-[:APPLIES {num_retries: 1}]-(:TaskRestartPattern {max_retries: 2}) + RETURN count(DISTINCT task) as renewed_waiting_tasks + """ + + renewed_waiting = n4js_task_restart_policy.execute_query( + query, + task_scoped_keys=list(map(str, tasks_to_fail)), + waiting=TaskStatusEnum.waiting.value, + ).records[0]["renewed_waiting_tasks"] + + assert renewed_waiting == 2 + @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): raise NotImplementedError From a655dc7bb50c90a86bfbd3168d8690e6b5fe288b Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 19 Aug 2024 16:33:48 -0700 Subject: [PATCH 08/20] Corrected resolution logic --- alchemiscale/storage/statestore.py | 41 +++++++++++++------ alchemiscale/tests/integration/conftest.py | 4 +- .../integration/storage/test_statestore.py | 35 +++++++++++++++- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index b6c05966..e08ec1c8 100644 --- 
a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -11,6 +11,7 @@ import re from functools import lru_cache from typing import Dict, List, Optional, Union, Tuple, Set +from collections import defaultdict from collections.abc import Iterable import weakref import numpy as np @@ -2986,8 +2987,14 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non # iterate over all of the results to determine if an applied pattern needs # to be iterated or if the task needs to be cancelled outright + # Keep track of which task/taskhub pairs would need to be canceled + # None => the pair never had a matching restart pattern + # True => at least one patterns max_retries was exceeded + # False => at least one regex matched, but no pattern max_retries were exceeded + cancel_map: defaultdict[Tuple[str, str], Optional[bool]] = defaultdict( + lambda: None + ) to_increment: List[Tuple[str, str]] = [] - to_cancel: List[Tuple[str, str]] = [] for record in results.records: task_restart_pattern = record["trp"] applies_relationship = record["app"] @@ -2996,23 +3003,27 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non # TODO: what happens if there is no traceback? i.e. 
older errored tasks traceback = record["traceback"] + task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"]) + + # we have already determined that the task is to be canceled + # is only ever truthy when we say a task needs to be canceled + if cancel_map[task_taskhub_tuple]: + continue + num_retries = applies_relationship["num_retries"] max_retries = task_restart_pattern["max_retries"] pattern = task_restart_pattern["pattern"] tracebacks: List[str] = traceback["tracebacks"] - # exit early if we already know a task is being canceled on a TaskHub - if (task["_scoped_key"], taskhub["_scoped_key"]) in to_cancel: - continue - - # we will always increment (even above the max_retries) and - # cancel later - to_increment.append( - (task["_scoped_key"], task_restart_pattern["_scoped_key"]) - ) if any([re.search(pattern, message) for message in tracebacks]): if num_retries + 1 > max_retries: - to_cancel.append((task["_scoped_key"], taskhub["_scoped_key"])) + cancel_map[task_taskhub_tuple] = True + else: + # to_increment.append(task_taskhub_tuple) + to_increment.append( + (task["_scoped_key"], task_restart_pattern["_scoped_key"]) + ) + cancel_map[task_taskhub_tuple] = False increment_query = """ UNWIND $trp_and_task_pairs as pairs @@ -3022,11 +3033,15 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non """ tx.run(increment_query, trp_and_task_pairs=to_increment) - for task, taskhub in to_cancel: + + # cancel all tasks that didn't trigger any restart patterns (None) + # or exceeded a patterns max_retries value (True) + for (task, taskhub), _ in filter( + lambda values: values[1] is True or values[1] is None, cancel_map.items() + ): self.cancel_tasks([task], taskhub, tx=tx) # any remaining tasks must then be okay to switch to waiting - renew_waiting_status_query = """ UNWIND $task_scoped_keys AS task_scoped_key MATCH (task:Task {status: $error, `_scoped_key`: 
task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py index 026f5866..1a415156 100644 --- a/alchemiscale/tests/integration/conftest.py +++ b/alchemiscale/tests/integration/conftest.py @@ -196,8 +196,8 @@ def n4js_task_restart_policy( assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) patterns = [ - "This is an example pattern that will be used as a restart string. 1", - "This is an example pattern that will be used as a restart string. 2", + r"Error message \d, round \d", + "This is an example pattern that will be used as a restart string.", ] n4js.add_task_restart_patterns( diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 64993239..334c1999 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -2396,7 +2396,6 @@ def test_resolve_task_restarts( error_messages = [ f"Error message {repeat}, round {i}" for repeat in range(3) ] - protocol_unit_failures = [] for j, message in enumerate(error_messages): puf = ProtocolUnitFailure( @@ -2437,6 +2436,40 @@ def test_resolve_task_restarts( assert renewed_waiting == 2 + # we want the resolve restarts to cancel a task. 
+ # deconstruct the tasks to fail, where the first + # one will be cancelled and the second will once again be continued + # but with an additional traceback + task_to_cancel, task_to_wait = tasks_to_fail + + query = """ + MATCH (task:Task {`_scoped_key`: $task_scoped_key_fail})<-[app:APPLIES]-(:TaskRestartPattern) + SET app.num_retries = 2 + SET task.status = $error + """ + + n4js_task_restart_policy.execute_query( + query, + task_scoped_key_fail=str(task_to_cancel), + task_scoped_key_wait=str(task_to_wait), + error=TaskStatusEnum.error.value, + ) + + n4js_task_restart_policy.resolve_task_restarts(tasks_to_fail) + + query = """ + MATCH (task:Task {_scoped_key: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {_scoped_key: $taskhub_scoped_key}) + RETURN task + """ + + results = n4js_task_restart_policy.execute_query( + query, + task_scoped_key=str(task_to_cancel), + taskhub_scoped_key=str(taskhub_scoped_key_with_policy), + ) + + assert len(results.records) == 0 + @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): raise NotImplementedError From 5bb67001e08e105871bc0ee4608a8c7c21e8c11e Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 23 Aug 2024 12:04:14 -0700 Subject: [PATCH 09/20] Extracted complexity out of test_resolve_task_restarts --- .../integration/storage/test_statestore.py | 137 +++++++----------- .../tests/integration/storage/utils.py | 90 ++++++++++++ 2 files changed, 144 insertions(+), 83 deletions(-) create mode 100644 alchemiscale/tests/integration/storage/utils.py diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 334c1999..5f5a1eb1 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -4,7 +4,7 @@ from pathlib import Path from functools import reduce from itertools import chain -from operator import and_ +import operator from collections 
import defaultdict import pytest @@ -31,6 +31,13 @@ ) from alchemiscale.security.auth import hash_key +from alchemiscale.tests.integration.storage.utils import ( + complete_tasks, + fail_task, + tasks_are_errored, + tasks_are_not_actioned_on_taskhub, +) + class TestStateStore: ... @@ -2013,7 +2020,7 @@ class TestTaskRestartPolicy: @pytest.mark.parametrize("status", ("complete", "invalid", "deleted")) def test_task_status_change(self, n4js, network_tyk2, scope_test, status): an = network_tyk2.copy_with_replacements( - name=network_tyk2.name + f"_test_task_status_change" + name=network_tyk2.name + "_test_task_status_change" ) _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) transformation = list(an.edges)[0] @@ -2287,34 +2294,29 @@ def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test): assert taskhub_grouped_patterns == expected_results - @pytest.mark.xfail(raises=NotImplementedError) def test_resolve_task_restarts( self, scope_test: Scope, n4js_task_restart_policy: Neo4jStore, ): + n4js = n4js_task_restart_policy # get the actioned tasks for each taskhub taskhub_actioned_tasks = {} - for taskhub_scoped_key in n4js_task_restart_policy.query_taskhubs(): + for taskhub_scoped_key in n4js.query_taskhubs(): taskhub_actioned_tasks[taskhub_scoped_key] = set( - n4js_task_restart_policy.get_taskhub_actioned_tasks( - [taskhub_scoped_key] - )[0] + n4js.get_taskhub_actioned_tasks([taskhub_scoped_key])[0] ) - restart_patterns = n4js_task_restart_policy.get_task_restart_patterns( + restart_patterns = n4js.get_task_restart_patterns( list(taskhub_actioned_tasks.keys()) ) - transformation_tasks = defaultdict(list) - for task in n4js_task_restart_policy.query_tasks( - status=TaskStatusEnum.waiting.value - ): - transformation_scoped_key, _ = ( - n4js_task_restart_policy.get_task_transformation( - task, return_gufe=False - ) + # create a map of the transformations and all of the tasks that perform them + transformation_tasks: dict[ScopedKey, 
list[ScopedKey]] = defaultdict(list) + for task in n4js.query_tasks(status=TaskStatusEnum.waiting.value): + transformation_scoped_key, _ = n4js.get_task_transformation( + task, return_gufe=False ) transformation_tasks[transformation_scoped_key].append(task) @@ -2326,6 +2328,7 @@ def test_resolve_task_restarts( taskhub_scoped_key_no_policy = None taskhub_scoped_key_with_policy = None + # bind taskhub scoped keys to variables for convenience later for taskhub_scoped_key, patterns in restart_patterns.items(): if not patterns: taskhub_scoped_key_no_policy = taskhub_scoped_key @@ -2357,7 +2360,7 @@ def test_resolve_task_restarts( # reduce down all tasks until only the common elements between taskhubs exist tasks_actioned_by_all_taskhubs: List[ScopedKey] = list( - reduce(and_, taskhub_actioned_tasks.values(), set(all_tasks)) + reduce(operator.and_, taskhub_actioned_tasks.values(), set(all_tasks)) ) assert len(tasks_actioned_by_all_taskhubs) == 4 @@ -2366,69 +2369,43 @@ def test_resolve_task_restarts( tasks_to_complete = tasks_actioned_by_all_taskhubs[:2] tasks_to_fail = tasks_actioned_by_all_taskhubs[2:] - # TODO: either check the results after the loop or within it, whichever makes more sense - for task in tasks_to_complete: - n4js_task_restart_policy.set_task_running([task]) - ok_pdrr = ProtocolDAGResultRef( - ok=True, - datetime_created=datetime.utcnow(), - obj_key=task.gufe_key, - scope=task.scope, - ) + complete_tasks(n4js, tasks_to_complete) - _ = n4js_task_restart_policy.set_task_result(task, ok_pdrr) + records = n4js.execute_query( + """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key})-[:RESULTS_IN]->(:ProtocolDAGResultRef) + RETURN count(task) as task_count + """, + task_scoped_keys=list(map(str, tasks_to_complete)), + ).records - # this should do nothing to the database state since all - # relationships are removed in the previous method call - # TODO: perhaps counts of the connections will be a good test - 
n4js_task_restart_policy.set_task_complete([task]) + assert records[0]["task_count"] == 2 + # test the behavior of the compute API for i, task in enumerate(tasks_to_fail): - n4js_task_restart_policy.set_task_running([task]) - - not_ok_pdrr = ProtocolDAGResultRef( - ok=False, - datetime_created=datetime.utcnow(), - obj_key=task.gufe_key, - scope=task.scope, - ) - error_messages = [ f"Error message {repeat}, round {i}" for repeat in range(3) ] - protocol_unit_failures = [] - for j, message in enumerate(error_messages): - puf = ProtocolUnitFailure( - source_key=f"FakeProtocolUnitKey-123{j}", - inputs={}, - outputs={}, - exception=RuntimeError, - traceback=message, - ) - protocol_unit_failures.append(puf) - pdrr_scoped_key = n4js_task_restart_policy.set_task_result( - task, not_ok_pdrr + fail_task( + n4js, + task, + resolve=False, + error_messages=error_messages, ) - # the following mimics what the compute API would do for a failed task - n4js_task_restart_policy.add_protocol_dag_result_ref_traceback( - protocol_unit_failures, pdrr_scoped_key - ) - n4js_task_restart_policy.set_task_error([task]) - # always feed in all tasks to test for side effects - n4js_task_restart_policy.resolve_task_restarts(all_tasks) + n4js.resolve_task_restarts(all_tasks) # both tasks should have the waiting status and the APPLIES # relationship num_retries should have incremented by 1 - query = """ UNWIND $task_scoped_keys as task_scoped_key MATCH (task:Task {`_scoped_key`: task_scoped_key, status: $waiting})<-[:APPLIES {num_retries: 1}]-(:TaskRestartPattern {max_retries: 2}) RETURN count(DISTINCT task) as renewed_waiting_tasks """ - renewed_waiting = n4js_task_restart_policy.execute_query( + renewed_waiting = n4js.execute_query( query, task_scoped_keys=list(map(str, tasks_to_fail)), waiting=TaskStatusEnum.waiting.value, @@ -2442,33 +2419,27 @@ def test_resolve_task_restarts( # but with an additional traceback task_to_cancel, task_to_wait = tasks_to_fail - query = """ - MATCH (task:Task 
{`_scoped_key`: $task_scoped_key_fail})<-[app:APPLIES]-(:TaskRestartPattern) - SET app.num_retries = 2 - SET task.status = $error - """ - - n4js_task_restart_policy.execute_query( - query, - task_scoped_key_fail=str(task_to_cancel), - task_scoped_key_wait=str(task_to_wait), - error=TaskStatusEnum.error.value, - ) + for _ in range(2): + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] - n4js_task_restart_policy.resolve_task_restarts(tasks_to_fail) + fail_task( + n4js, + task_to_cancel, + resolve=False, + error_messages=error_messages, + ) - query = """ - MATCH (task:Task {_scoped_key: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {_scoped_key: $taskhub_scoped_key}) - RETURN task - """ + n4js.resolve_task_restarts(tasks_to_fail) - results = n4js_task_restart_policy.execute_query( - query, - task_scoped_key=str(task_to_cancel), - taskhub_scoped_key=str(taskhub_scoped_key_with_policy), + assert tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_cancel], + taskhub_scoped_key_with_policy, ) - assert len(results.records) == 0 + assert tasks_are_errored(n4js, [task_to_cancel]) @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): diff --git a/alchemiscale/tests/integration/storage/utils.py b/alchemiscale/tests/integration/storage/utils.py new file mode 100644 index 00000000..91e4a268 --- /dev/null +++ b/alchemiscale/tests/integration/storage/utils.py @@ -0,0 +1,90 @@ +from datetime import datetime + +from gufe.protocols import ProtocolUnitFailure + +from alchemiscale.storage.statestore import Neo4jStore +from alchemiscale import ScopedKey +from alchemiscale.storage.models import TaskStatusEnum, ProtocolDAGResultRef + + +def tasks_are_not_actioned_on_taskhub( + n4js: Neo4jStore, + task_scoped_keys: list[ScopedKey], + taskhub_scoped_key: ScopedKey, +) -> bool: + + actioned_tasks = n4js.get_taskhub_actioned_tasks([taskhub_scoped_key]) + + for task in task_scoped_keys: + if task in 
actioned_tasks[0].keys(): + return False + return True + + +def tasks_are_errored(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool: + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key, status: $error}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_keys=list(map(str, task_scoped_keys)), + error=TaskStatusEnum.error.value, + ) + + return len(results.records) == len(task_scoped_keys) + + +def complete_tasks( + n4js: Neo4jStore, + tasks: list[ScopedKey], +): + n4js.set_task_running(tasks) + for task in tasks: + ok_pdrr = ProtocolDAGResultRef( + ok=True, + datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, + ) + + _ = n4js.set_task_result(task, ok_pdrr) + + n4js.set_task_complete(tasks) + + +def fail_task( + n4js: Neo4jStore, + task: ScopedKey, + resolve: bool = False, + error_messages: list[str] = [], +) -> None: + n4js.set_task_running([task]) + + not_ok_pdrr = ProtocolDAGResultRef( + ok=False, + datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, + ) + + protocol_unit_failures = [] + for j, message in enumerate(error_messages): + puf = ProtocolUnitFailure( + source_key=f"FakeProtocolUnitKey-123{j}", + inputs={}, + outputs={}, + exception=RuntimeError, + traceback=message, + ) + protocol_unit_failures.append(puf) + + pdrr_scoped_key = n4js.set_task_result(task, not_ok_pdrr) + + n4js.add_protocol_dag_result_ref_traceback(protocol_unit_failures, pdrr_scoped_key) + n4js.set_task_error([task]) + + if resolve: + n4js.resolve_task_restarts([task]) From fe4b87be49d61d90f9b8d943cfd57e2a56db1fa4 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 23 Aug 2024 12:56:11 -0700 Subject: [PATCH 10/20] resolve restart of tasks with no tracebacks --- alchemiscale/storage/statestore.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/alchemiscale/storage/statestore.py 
b/alchemiscale/storage/statestore.py index e08ec1c8..a2a3b4a1 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2944,7 +2944,9 @@ def get_task_restart_patterns( q, taskhub_scoped_keys=list(map(str, taskhubs)) ).records - data = {taskhub: set() for taskhub in taskhubs} + data: dict[ScopedKey, set[tuple[str, int]]] = { + taskhub: set() for taskhub in taskhubs + } for record in records: pattern = record["trp"]["pattern"] @@ -3000,11 +3002,15 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non applies_relationship = record["app"] task = record["task"] taskhub = record["taskhub"] - # TODO: what happens if there is no traceback? i.e. older errored tasks traceback = record["traceback"] task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"]) + # TODO: remove in v1.0.0 + # tasks that errored, prior to the indtroduction of task restart policies will have no tracebacks in the database + if traceback is None: + cancel_map[task_taskhub_tuple] = True + # we have already determined that the task is to be canceled # is only ever truthy when we say a task needs to be canceled if cancel_map[task_taskhub_tuple]: From 8a6f98041388804f02cf530df1956474457402c5 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 23 Aug 2024 13:17:31 -0700 Subject: [PATCH 11/20] Replaced many maps with a for loop --- alchemiscale/storage/statestore.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index a2a3b4a1..446aef26 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2461,9 +2461,15 @@ def add_protocol_dag_result_ref_traceback( except IndexError: raise KeyError("Could not find ProtocolDAGResultRef in database.") - tracebacks = list(map(lambda puf: puf.traceback, protocol_unit_failures)) - source_keys = list(map(lambda puf: puf.source_key, protocol_unit_failures)) - 
failure_keys = list(map(lambda puf: puf.key, protocol_unit_failures)) + failure_keys = [] + source_keys = [] + tracebacks = [] + + for puf in protocol_unit_failures: + failure_keys.append(puf.key) + source_keys.append(puf.source_key) + tracebacks.append(puf.traceback) + traceback = Traceback(tracebacks, source_keys, failure_keys) _, traceback_node, _ = self._gufe_to_subgraph( From 93eb5f5e9fc8eff6a802658fd2cd11be1384003b Mon Sep 17 00:00:00 2001 From: David Dotson Date: Wed, 4 Sep 2024 10:53:18 -0700 Subject: [PATCH 12/20] Small changes from review --- alchemiscale/storage/statestore.py | 1 - alchemiscale/tests/integration/conftest.py | 23 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 446aef26..ffe4bbc4 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -3031,7 +3031,6 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non if num_retries + 1 > max_retries: cancel_map[task_taskhub_tuple] = True else: - # to_increment.append(task_taskhub_tuple) to_increment.append( (task["_scoped_key"], task_restart_pattern["_scoped_key"]) ) diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py index 1a415156..ed9e2b31 100644 --- a/alchemiscale/tests/integration/conftest.py +++ b/alchemiscale/tests/integration/conftest.py @@ -167,6 +167,16 @@ def n4js(graph): return Neo4jStore(graph) +@fixture +def n4js_fresh(graph): + n4js = Neo4jStore(graph) + + n4js.reset() + n4js.initialize() + + return n4js + + @fixture def n4js_task_restart_policy( n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test @@ -188,10 +198,13 @@ def n4js_task_restart_policy( list(network_tyk2.edges)[:2], ) + # create 4 tasks for each of the 2 selected transformations task_scoped_keys = n4js.create_tasks( [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 
4 ) + # action the tasks for transformation 1 on the taskhub with no policy + # action the tasks for both transformations on the taskhub with a policy assert all(n4js.action_tasks(task_scoped_keys[:4], taskhub_scoped_key_no_policy)) assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) @@ -207,16 +220,6 @@ def n4js_task_restart_policy( return n4js -@fixture -def n4js_fresh(graph): - n4js = Neo4jStore(graph) - - n4js.reset() - n4js.initialize() - - return n4js - - @fixture(scope="module") def s3objectstore_settings(): os.environ["AWS_ACCESS_KEY_ID"] = "test-key-id" From 0900f392e9811fd7bbe0f44d04f9d05594c32287 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 09:50:09 -0700 Subject: [PATCH 13/20] Chainable now uses the update_wrapper function --- alchemiscale/storage/statestore.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index ffe4bbc4..4af43441 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -9,7 +9,7 @@ from contextlib import contextmanager import json import re -from functools import lru_cache +from functools import lru_cache, update_wrapper from typing import Dict, List, Optional, Union, Tuple, Set from collections import defaultdict from collections.abc import Iterable @@ -187,6 +187,8 @@ def inner(self, *args, **kwargs): kwargs.update(tx=tx) return func(self, *args, **kwargs) + update_wrapper(inner, func) + return inner def execute_query(self, *args, **kwargs): From c8ddafc6e773cdcfbf7daa57b1c83ae0b2982218 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 09:51:09 -0700 Subject: [PATCH 14/20] Updated Traceback class * Removed custom tokenization * Implemented _defaults to allow default tokenization to work --- alchemiscale/storage/models.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/alchemiscale/storage/models.py 
b/alchemiscale/storage/models.py index 7fee8156..f9c20e9c 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -223,12 +223,9 @@ def __init__( self.source_keys = source_keys self.failure_keys = failure_keys - def _gufe_tokenize(self): - return hashlib.md5(str(self.tracebacks).encode()).hexdigest() - @classmethod def _defaults(cls): - raise NotImplementedError + return {"tracebacks": [], "source_keys": [], "failure_keys": []} @classmethod def _from_dict(cls, dct): From 2a59499acdb1b9d9f082ae4ad02d90b0e905c105 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 10:18:44 -0700 Subject: [PATCH 15/20] Renamed Traceback to Tracebacks --- alchemiscale/compute/api.py | 2 +- alchemiscale/storage/models.py | 4 ++-- alchemiscale/storage/statestore.py | 24 ++++++++++--------- .../integration/storage/test_statestore.py | 6 ++--- .../tests/integration/storage/utils.py | 2 +- .../tests/unit/test_storage_models.py | 22 ++++++++--------- 6 files changed, 31 insertions(+), 29 deletions(-) diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py index a50f6d93..f3bff55c 100644 --- a/alchemiscale/compute/api.py +++ b/alchemiscale/compute/api.py @@ -271,7 +271,7 @@ def set_task_result( if protocoldagresultref.ok: n4js.set_task_complete(tasks=[task_sk]) else: - n4js.add_protocol_dag_result_ref_traceback( + n4js.add_protocol_dag_result_ref_tracebacks( pdr.protocol_unit_failures, result_sk ) n4js.set_task_error(tasks=[task_sk]) diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py index f9c20e9c..618467ce 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -202,7 +202,7 @@ def __eq__(self, other): return self.pattern == other.pattern -class Traceback(GufeTokenizable): +class Tracebacks(GufeTokenizable): def __init__( self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str] @@ -229,7 +229,7 @@ def _defaults(cls): @classmethod def _from_dict(cls, dct): - return 
Traceback(**dct) + return Tracebacks(**dct) def _to_dict(self): return { diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 4af43441..81ecd40b 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -33,7 +33,7 @@ TaskHub, TaskRestartPattern, TaskStatusEnum, - Traceback, + Tracebacks, ) from ..strategies import Strategy from ..models import Scope, ScopedKey @@ -2435,7 +2435,7 @@ def get_task_failures(self, task: ScopedKey) -> List[ProtocolDAGResultRef]: """ return self._get_protocoldagresultrefs(q, task) - def add_protocol_dag_result_ref_traceback( + def add_protocol_dag_result_ref_tracebacks( self, protocol_unit_failures: List[ProtocolUnitFailure], protocol_dag_result_ref_scoped_key: ScopedKey, @@ -2472,7 +2472,7 @@ def add_protocol_dag_result_ref_traceback( source_keys.append(puf.source_key) tracebacks.append(puf.traceback) - traceback = Traceback(tracebacks, source_keys, failure_keys) + traceback = Tracebacks(tracebacks, source_keys, failure_keys) _, traceback_node, _ = self._gufe_to_subgraph( traceback.to_shallow_dict(), @@ -2976,13 +2976,13 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) CALL { WITH task - OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(traceback:Traceback) - RETURN traceback + OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(tracebacks:Tracebacks) + RETURN tracebacks ORDER BY pdrr.datetime_created DESCENDING LIMIT 1 } - WITH task, traceback, trp, app, taskhub - RETURN task, traceback, trp, app, taskhub + WITH task, tracebacks, trp, app, taskhub + RETURN task, tracebacks, trp, app, taskhub """ results = tx.run( @@ -3010,13 +3010,13 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non 
applies_relationship = record["app"] task = record["task"] taskhub = record["taskhub"] - traceback = record["traceback"] + _tracebacks = record["tracebacks"] task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"]) # TODO: remove in v1.0.0 # tasks that errored, prior to the indtroduction of task restart policies will have no tracebacks in the database - if traceback is None: + if _tracebacks is None: cancel_map[task_taskhub_tuple] = True # we have already determined that the task is to be canceled @@ -3027,9 +3027,11 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non num_retries = applies_relationship["num_retries"] max_retries = task_restart_pattern["max_retries"] pattern = task_restart_pattern["pattern"] - tracebacks: List[str] = traceback["tracebacks"] + tracebacks: List[str] = _tracebacks["tracebacks"] - if any([re.search(pattern, message) for message in tracebacks]): + compiled_pattern = re.compile(pattern) + + if any([compiled_pattern.search(message) for message in tracebacks]): if num_retries + 1 > max_retries: cancel_map[task_taskhub_tuple] = True else: diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 5f5a1eb1..c85d97ae 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -1998,12 +1998,12 @@ def test_add_protocol_dag_result_ref_traceback( ) ) - n4js.add_protocol_dag_result_ref_traceback( + n4js.add_protocol_dag_result_ref_tracebacks( protocol_unit_failures, pdrr_scoped_key ) query = """ - MATCH (traceback:Traceback)-[:DETAILS]->(:ProtocolDAGResultRef {`_scoped_key`: $pdrr_scoped_key}) + MATCH (traceback:Tracebacks)-[:DETAILS]->(:ProtocolDAGResultRef {`_scoped_key`: $pdrr_scoped_key}) RETURN traceback """ @@ -2355,7 +2355,7 @@ def test_resolve_task_restarts( # an enforcing task restart policy exists # # Tasks will be set to the error state with 
a spoofing method, which will create a fake ProtocolDAGResultRef - # and Traceback. This is done since making a protocol fail systematically in the testing environment is not + # and Tracebacks. This is done since making a protocol fail systematically in the testing environment is not # obvious at this time. # reduce down all tasks until only the common elements between taskhubs exist diff --git a/alchemiscale/tests/integration/storage/utils.py b/alchemiscale/tests/integration/storage/utils.py index 91e4a268..43ec3979 100644 --- a/alchemiscale/tests/integration/storage/utils.py +++ b/alchemiscale/tests/integration/storage/utils.py @@ -83,7 +83,7 @@ def fail_task( pdrr_scoped_key = n4js.set_task_result(task, not_ok_pdrr) - n4js.add_protocol_dag_result_ref_traceback(protocol_unit_failures, pdrr_scoped_key) + n4js.add_protocol_dag_result_ref_tracebacks(protocol_unit_failures, pdrr_scoped_key) n4js.set_task_error([task]) if resolve: diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py index 55dc872f..391a1063 100644 --- a/alchemiscale/tests/unit/test_storage_models.py +++ b/alchemiscale/tests/unit/test_storage_models.py @@ -4,7 +4,7 @@ NetworkStateEnum, NetworkMark, TaskRestartPattern, - Traceback, + Tracebacks, ) from alchemiscale import ScopedKey @@ -137,40 +137,40 @@ def test_from_dict(self): assert trp_reconstructed.taskhub_scoped_key == original_taskhub_scoped_key -class TestTraceback(object): +class TestTracebacks(object): valid_entry = ["traceback1", "traceback2", "traceback3"] tracebacks_value_error = "`tracebacks` must be a non-empty list of string values" def test_empty_string_element(self): with pytest.raises(ValueError, match=self.tracebacks_value_error): - Traceback(self.valid_entry + [""]) + Tracebacks(self.valid_entry + [""]) def test_non_list_parameter(self): with pytest.raises(ValueError, match=self.tracebacks_value_error): - Traceback(None) + Tracebacks(None) with pytest.raises(ValueError, 
match=self.tracebacks_value_error): - Traceback(100) + Tracebacks(100) with pytest.raises(ValueError, match=self.tracebacks_value_error): - Traceback("not a list, but still an iterable that yields strings") + Tracebacks("not a list, but still an iterable that yields strings") def test_list_non_string_elements(self): with pytest.raises(ValueError, match=self.tracebacks_value_error): - Traceback(self.valid_entry + [None]) + Tracebacks(self.valid_entry + [None]) def test_empty_list(self): with pytest.raises(ValueError, match=self.tracebacks_value_error): - Traceback([]) + Tracebacks([]) def test_to_dict(self): - tb = Traceback(self.valid_entry) + tb = Tracebacks(self.valid_entry) tb_dict = tb.to_dict() assert len(tb_dict) == 4 - assert tb_dict.pop("__qualname__") == "Traceback" + assert tb_dict.pop("__qualname__") == "Tracebacks" assert tb_dict.pop("__module__") == "alchemiscale.storage.models" # light test of the version key @@ -184,7 +184,7 @@ def test_to_dict(self): assert expected == tb_dict def test_from_dict(self): - tb_orig = Traceback(self.valid_entry) + tb_orig = Tracebacks(self.valid_entry) tb_dict = tb_orig.to_dict() tb_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(tb_dict) From 148d048510c142474a4b720a94cc107f87b337c2 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 11:04:19 -0700 Subject: [PATCH 16/20] Updated cancel and increment logic cancel_map has been changed from a defaultdict to a base dict and instead using the dict.get method to return None. Additionally added a set of all task/taskhub pairs that is later used to determine what should be canceled. I've also added grouping on taskhubs so the number of calls to cancel_tasks is minimized. 
--- alchemiscale/storage/statestore.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 81ecd40b..f88234e7 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -3001,10 +3001,9 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non # None => the pair never had a matching restart pattern # True => at least one patterns max_retries was exceeded # False => at least one regex matched, but no pattern max_retries were exceeded - cancel_map: defaultdict[Tuple[str, str], Optional[bool]] = defaultdict( - lambda: None - ) + cancel_map: dict[Tuple[str, str], Optional[bool]] = {} to_increment: List[Tuple[str, str]] = [] + all_task_taskhub_pairs: set[Tuple[str, str]] = set() for record in results.records: task_restart_pattern = record["trp"] applies_relationship = record["app"] @@ -3014,14 +3013,16 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"]) + all_task_taskhub_pairs.add(task_taskhub_tuple) + # TODO: remove in v1.0.0 # tasks that errored, prior to the indtroduction of task restart policies will have no tracebacks in the database if _tracebacks is None: cancel_map[task_taskhub_tuple] = True - # we have already determined that the task is to be canceled - # is only ever truthy when we say a task needs to be canceled - if cancel_map[task_taskhub_tuple]: + # we have already determined that the task is to be canceled. + # this is only ever truthy when we say a task needs to be canceled. 
+ if cancel_map.get(task_taskhub_tuple): continue num_retries = applies_relationship["num_retries"] @@ -3051,10 +3052,14 @@ def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=Non # cancel all tasks that didn't trigger any restart patterns (None) # or exceeded a patterns max_retries value (True) - for (task, taskhub), _ in filter( - lambda values: values[1] is True or values[1] is None, cancel_map.items() - ): - self.cancel_tasks([task], taskhub, tx=tx) + cancel_groups: defaultdict[str, list[str]] = defaultdict(list) + for task_taskhub_pair in all_task_taskhub_pairs: + cancel_result = cancel_map.get(task_taskhub_pair) + if cancel_result is True or cancel_result is None: + cancel_groups[task_taskhub_pair[1]].append(task_taskhub_pair[0]) + + for taskhub, tasks in cancel_groups.items(): + self.cancel_tasks(tasks, taskhub, tx=tx) # any remaining tasks must then be okay to switch to waiting renew_waiting_status_query = """ From 645b2e47a9dc1f8a4eb0bf8190899bbb2ff8ad4d Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 11:12:06 -0700 Subject: [PATCH 17/20] Fixed query for deleting the APPLIES relationship --- alchemiscale/storage/statestore.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index f88234e7..03c8f6d5 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -1628,10 +1628,10 @@ def cancel_tasks( MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key}) DELETE ar - WITH task + WITH task, th CALL { - WITH task - MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern) + WITH task, th + MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)-[:ENFORCES]->(th) DELETE applies } From 3a8eeca158f07e6f3c0de5783af72b86018184ed Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 11:17:07 -0700 Subject: [PATCH 18/20] Removed unused testing 
fixture --- .../tests/integration/interface/conftest.py | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/alchemiscale/tests/integration/interface/conftest.py b/alchemiscale/tests/integration/interface/conftest.py index d7c5da6a..2eb2c996 100644 --- a/alchemiscale/tests/integration/interface/conftest.py +++ b/alchemiscale/tests/integration/interface/conftest.py @@ -89,38 +89,6 @@ def n4js_preloaded( return n4js -from alchemiscale.storage.statestore import Neo4jStore - - -@pytest.fixture -def n4js_task_restart_policy( - n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test -): - - n4js = n4js_fresh - - _, taskhub_scoped_key_with_policy, _ = n4js.assemble_network( - network_tyk2, scope_test - ) - - _, taskhub_scoped_key_no_policy, _ = n4js.assemble_network( - network_tyk2.copy_with_replacements(name=network_tyk2.name + "_no_policy"), - scope_test, - ) - - transformation_1_scoped_key, transformation_2_scoped_key = map( - lambda transformation: n4js.get_scoped_key(transformation, scope_test), - network_tyk2.edges[:2], - ) - - task_scoped_keys = n4js.create_tasks( - [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 4 - ) - - assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_no_policy)) - assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) - - @pytest.fixture(scope="module") def scope_consistent_token_data_depends_override(scope_test): """Make a consistent helper to provide an override to the api.app while still accessing fixtures""" From ea6e66f4daa087934fffe6c0133e80911bc8a3fb Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Sep 2024 11:45:05 -0700 Subject: [PATCH 19/20] Clarified comment and added complimentary assertion Also expanded test to check behavior of the task that was meant to be waiting. 
--- .../integration/storage/test_statestore.py | 50 +++++++++++++++++-- .../tests/integration/storage/utils.py | 16 ++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index c85d97ae..9a5e71ef 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -36,6 +36,7 @@ fail_task, tasks_are_errored, tasks_are_not_actioned_on_taskhub, + tasks_are_waiting, ) @@ -2351,8 +2352,8 @@ def test_resolve_task_restarts( # # 1. Completed Tasks do not have an actions relationship with either TaskHub # 2. A Task entering the error state is switched back to waiting if any restart patterns apply - # 3. A Task entering the error state is left in the error state if no patterns apply and only the TaskHub with - # an enforcing task restart policy exists + # 3. A Task entering the error state is left in the error state if no patterns apply and only the TaskHub without + # an enforcing task restart policy actions the Task # # Tasks will be set to the error state with a spoofing method, which will create a fake ProtocolDAGResultRef # and Tracebacks. This is done since making a protocol fail systematically in the testing environment is not @@ -2360,7 +2361,7 @@ def test_resolve_task_restarts( # reduce down all tasks until only the common elements between taskhubs exist tasks_actioned_by_all_taskhubs: List[ScopedKey] = list( - reduce(operator.and_, taskhub_actioned_tasks.values(), set(all_tasks)) + reduce(operator.and_, taskhub_actioned_tasks.values()) ) assert len(tasks_actioned_by_all_taskhubs) == 4 @@ -2415,10 +2416,10 @@ def test_resolve_task_restarts( # we want the resolve restarts to cancel a task. 
# deconstruct the tasks to fail, where the first - # one will be cancelled and the second will once again be continued - # but with an additional traceback + # one will be cancelled and the second will continue to wait task_to_cancel, task_to_wait = tasks_to_fail + # error out the first task for _ in range(2): error_messages = [ f"Error message {repeat}, round {i}" for repeat in range(3) @@ -2433,14 +2434,53 @@ def test_resolve_task_restarts( n4js.resolve_task_restarts(tasks_to_fail) + # check that it is no longer actioned on the enforced taskhub assert tasks_are_not_actioned_on_taskhub( n4js, [task_to_cancel], taskhub_scoped_key_with_policy, ) + # check that it is still actioned on the unenforced taskhub + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_cancel], + taskhub_scoped_key_no_policy, + ) + + # it should still be errored though! assert tasks_are_errored(n4js, [task_to_cancel]) + # fail the second task one time + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] + + fail_task( + n4js, + task_to_wait, + resolve=False, + error_messages=error_messages, + ) + + n4js.resolve_task_restarts(tasks_to_fail) + + # check that the waiting task is actioned on both taskhubs + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_wait], + taskhub_scoped_key_with_policy, + ) + + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_wait], + taskhub_scoped_key_no_policy, + ) + + # it should be waiting + assert tasks_are_waiting(n4js, [task_to_wait]) + @pytest.mark.xfail(raises=NotImplementedError) def test_task_actioning_applies_relationship(self): raise NotImplementedError diff --git a/alchemiscale/tests/integration/storage/utils.py b/alchemiscale/tests/integration/storage/utils.py index 43ec3979..40514a53 100644 --- a/alchemiscale/tests/integration/storage/utils.py +++ b/alchemiscale/tests/integration/storage/utils.py @@ -37,6 +37,22 @@ def tasks_are_errored(n4js: Neo4jStore, task_scoped_keys: 
list[ScopedKey]) -> bo return len(results.records) == len(task_scoped_keys) +def tasks_are_waiting(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool: + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key, status: $waiting}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_keys=list(map(str, task_scoped_keys)), + waiting=TaskStatusEnum.waiting.value, + ) + + return len(results.records) == len(task_scoped_keys) + + def complete_tasks( n4js: Neo4jStore, tasks: list[ScopedKey], From 7a4b1149f63468c30f2c38404e91944317598f52 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 12 Sep 2024 18:27:12 -0700 Subject: [PATCH 20/20] Small changes to Tracebacks We don't want to change `_defaults()` from what's done in the base class unless we have real default values to leave out of the hash. --- alchemiscale/storage/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py index 618467ce..1d8e1679 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -225,11 +225,11 @@ def __init__( @classmethod def _defaults(cls): - return {"tracebacks": [], "source_keys": [], "failure_keys": []} + return super()._defaults() @classmethod def _from_dict(cls, dct): - return Tracebacks(**dct) + return cls(**dct) def _to_dict(self): return {