Re-Attach to running alignment chunks on retry (#372)
morsecodist authored Jul 10, 2024
1 parent a7c54be commit b7c97a8
Showing 1 changed file with 74 additions and 21 deletions.
lib/idseq_utils/idseq_utils/batch_run_helpers.py: 95 changes (74 additions, 21 deletions)
@@ -1,3 +1,4 @@
+import hashlib
import json
import logging
import os
@@ -9,7 +10,7 @@
from os import listdir
from multiprocessing import Pool
from subprocess import run
-from typing import Dict, List
+from typing import Dict, List, Optional
from urllib.parse import urlparse

from idseq_utils.diamond_scatter import blastx_join
@@ -19,6 +20,11 @@
from botocore.exceptions import ClientError
from botocore.config import Config

+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
+)
log = logging.getLogger(__name__)

MAX_CHUNKS_IN_FLIGHT = 30 # TODO: remove this constant, currently does nothing since we have at most 30 index chunks
@@ -83,25 +89,62 @@ def _get_job_status(job_id, use_batch_api=False):
raise e


+class BatchJobCache:
+    """
+    BatchJobCache saves job IDs so the coordinator can re-attach to running batch jobs when the coordinator fails.
+    The output should always be the same if the inputs are the same; however, we also incorporate the batch_args
+    into the cache key because a retry on spot vs. on-demand will result in a different batch queue.
+    """
+    def __init__(self, bucket: str, prefix: str, inputs: Dict[str, str]):
+        self.bucket = bucket
+        self.prefix = prefix
+        self.inputs = inputs
+
+    def _key(self, batch_args: Dict) -> str:
+        hash = hashlib.sha256()
+        cache_dict = {"inputs": self.inputs, "batch_args": batch_args}
+        hash.update(json.dumps(cache_dict, sort_keys=True).encode())
+        return os.path.join(self.prefix, hash.hexdigest())
+
+    def get(self, batch_args: Dict) -> Optional[str]:
+        try:
+            resp = _s3_client.get_object(Bucket=self.bucket, Key=self._key(batch_args))
+            return resp["Body"].read().decode()
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "NoSuchKey":
+                return None
+            else:
+                raise e
+
+    def put(self, batch_args: Dict, job_id: str):
+        _s3_client.put_object(
+            Bucket=self.bucket,
+            Key=self._key(batch_args),
+            Body=job_id.encode(),
+            Tagging="AlignmentCoordination=True",
+        )


def _run_batch_job(
job_name: str,
job_queue: str,
job_definition: str,
environment: Dict[str, str],
retries: int,
+    cache: BatchJobCache,
):
-    response = _batch_client.submit_job(
-        jobName=job_name,
-        jobQueue=job_queue,
-        jobDefinition=job_definition,
-        containerOverrides={
+    submit_args = {
+        "jobName": job_name,
+        "jobQueue": job_queue,
+        "jobDefinition": job_definition,
+        "containerOverrides": {
"environment": [{"name": k, "value": v} for k, v in environment.items()],
"memory": 130816,
"vcpus": 24,
},
-        retryStrategy={"attempts": retries},
-    )
-    job_id = response["jobId"]
+        "retryStrategy": {"attempts": retries},
+    }

def _log_status(status: str):
level = logging.INFO if status != "FAILED" else logging.ERROR
@@ -119,7 +162,14 @@ def _log_status(status: str):
),
)

_log_status("SUBMITTED")
job_id = cache.get(submit_args)
if job_id:
log.info(f"reattach to batch job: {job_id}")
else:
response = _batch_client.submit_job(**submit_args)
job_id = response["jobId"]
cache.put(submit_args, job_id)
_log_status("SUBMITTED")

delay = 60 + random.randint(
-60 // 2, 60 // 2
@@ -194,26 +244,19 @@ def _job_queue(provisioning_model: str):
input_bucket, input_key = _bucket_and_key(wdl_input_uri)

wdl_output_uri = os.path.join(chunk_dir, f"{chunk_id}-output.json")
output_bucket, output_key = _bucket_and_key(wdl_output_uri)

wdl_workflow_uri = f"s3://idseq-workflows/{aligner}-{aligner_wdl_version}/{aligner}.wdl"

-    # if this job fails we don't want to re-run chunks that have already been processed
-    # the presence of the output file means the chunk has already been processed
-    try:
-        _s3_client.head_object(Bucket=output_bucket, Key=output_key)
-        log.info(f"skipping chunk, output already exists: {wdl_output_uri}")
-        return
-    except ClientError as e:
-        # raise the error if it is anything other than "not found"
-        if e.response["Error"]["Code"] != "404":
-            raise e
+    cache_prefix_uri = os.path.join(chunk_dir, "batch_job_cache/")
+    cache_bucket, cache_prefix = _bucket_and_key(cache_prefix_uri)
+    cache = BatchJobCache(cache_bucket, cache_prefix, inputs)

_s3_client.put_object(
Bucket=input_bucket,
Key=input_key,
Body=json.dumps(inputs).encode(),
ContentType="application/json",
Tagging="AlignmentCoordination=True",
)

environment = {
@@ -231,6 +274,7 @@
job_definition=job_definition,
environment=environment,
retries=2,
+            cache=cache,
)
except BatchJobFailed:
_run_batch_job(
@@ -239,6 +283,7 @@
job_definition=job_definition,
environment=environment,
retries=1,
+            cache=cache,
)


@@ -263,6 +308,7 @@ def run_alignment(
):
bucket, prefix = _bucket_and_key(db_path)
chunk_dir = os.path.join(input_dir, f"{aligner}-chunks")
+    _, chunk_prefix = _bucket_and_key(chunk_dir)
chunks = (
[
input_dir,
@@ -281,9 +327,16 @@
run(["s3parcp", "--recursive", chunk_dir, "chunks"], check=True)
if os.path.exists(os.path.join("chunks", "cache")):
shutil.rmtree(os.path.join("chunks", "cache"))
if os.path.exists(os.path.join("chunks", "batch_job_cache")):
shutil.rmtree(os.path.join("chunks", "batch_job_cache"))
for fn in listdir("chunks"):
if fn.endswith("json"):
os.remove(os.path.join("chunks", fn))
+            _s3_client.put_object_tagging(
+                Bucket=bucket,
+                Key=os.path.join(chunk_prefix, fn),
+                Tagging={"TagSet": [{"Key": "AlignmentCoordination", "Value": "True"}]},
+            )
if aligner == "diamond":
blastx_join("chunks", result_path, aligner_args, *queries)
else:

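The new BatchJobCache keys each job ID by a sha256 of the JSON-serialized WDL inputs plus the Batch submit arguments, so a restarted coordinator recomputes the same S3 key and re-attaches to the job its predecessor already submitted, while the on-demand fallback (a different jobQueue) hashes to a fresh key and submits normally. Below is a minimal sketch of that keying behaviour, with an in-memory dict standing in for S3; the bucket, queue, and input names are made up for illustration and are not part of the commit.

import hashlib
import json
import os
from typing import Dict, Optional


class FakeJobCache:
    """Mirrors BatchJobCache's keying, but stores job IDs in a dict instead of S3."""

    def __init__(self, prefix: str, inputs: Dict[str, str]):
        self.prefix = prefix
        self.inputs = inputs
        self.store: Dict[str, str] = {}  # stands in for the S3 bucket

    def _key(self, batch_args: Dict) -> str:
        # Same recipe as the commit: sha256 over the sorted-keys JSON of inputs + batch_args.
        digest = hashlib.sha256(
            json.dumps({"inputs": self.inputs, "batch_args": batch_args}, sort_keys=True).encode()
        ).hexdigest()
        return os.path.join(self.prefix, digest)

    def get(self, batch_args: Dict) -> Optional[str]:
        return self.store.get(self._key(batch_args))

    def put(self, batch_args: Dict, job_id: str) -> None:
        self.store[self._key(batch_args)] = job_id


# Hypothetical names for illustration only.
inputs = {"query_0": "s3://example-bucket/sample-1/diamond-chunks/0.fa"}
spot_args = {"jobQueue": "example-spot-queue", "retryStrategy": {"attempts": 2}}
on_demand_args = {"jobQueue": "example-ec2-queue", "retryStrategy": {"attempts": 1}}

cache = FakeJobCache("sample-1/diamond-chunks/batch_job_cache", inputs)
assert cache.get(spot_args) is None            # first attempt: nothing cached, so submit a job
cache.put(spot_args, "job-1234")               # record the submitted Batch job ID

restarted = FakeJobCache("sample-1/diamond-chunks/batch_job_cache", inputs)
restarted.store = cache.store                  # same "bucket" contents survive a coordinator restart
assert restarted.get(spot_args) == "job-1234"  # identical args -> same key -> re-attach
assert restarted.get(on_demand_args) is None   # different queue -> different key -> new submission

In the commit itself, the cached job ID lands under each chunk's batch_job_cache/ prefix in S3 and is tagged AlignmentCoordination=True, which is also why run_alignment now prunes the downloaded batch_job_cache directory before joining chunk outputs.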