Skip to content

Commit

Permalink
begin alignment with new API
Browse files Browse the repository at this point in the history
  • Loading branch information
TShapinsky committed Dec 4, 2023
1 parent 401cc01 commit e13f8b5
Showing 1 changed file with 84 additions and 62 deletions.
146 changes: 84 additions & 62 deletions alfalfa_client/alfalfa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ****************************************************************************************************

import functools
import json
import os
from collections import OrderedDict
Expand All @@ -48,7 +49,7 @@
)

ModelID = str
SiteID = str
RunID = str


class AlfalfaClient:
Expand All @@ -69,6 +70,7 @@ def __init__(self, host: str = 'http://localhost', api_version: str = 'v2'):

self.host = host
self.api_version = api_version
self.point_translation_map = {}

@property
def url(self):
Expand All @@ -91,30 +93,30 @@ def _request(self, endpoint: str, method="POST", parameters=None) -> requests.Re
return response

@parallelize
def status(self, site_id: Union[SiteID, List[SiteID]]) -> str:
"""Get status of site
def status(self, run_id: Union[RunID, List[RunID]]) -> str:
"""Get status of run
:param site_id: id of site or list of ids
:returns: status of site
:param run_id: id of run or list of ids
:returns: status of run
"""
response = self._request(f"sites/{site_id}", method="GET").json()
return response["data"]["status"]
response = self._request(f"runs/{run_id}", method="GET").json()
return response["status"]

@parallelize
def get_error_log(self, site_id: Union[SiteID, List[SiteID]]) -> str:
"""Get error log from site
def get_error_log(self, run_id: Union[RunID, List[RunID]]) -> str:
"""Get error log from run
:param site_id: id of site or list of ids
:returns: error log from site
:param run_id: id of run or list of ids
:returns: error log from run
"""
response = self._request(f"sites/{site_id}", method="GET").json()
return response["data"]["errorLog"]
response = self._request(f"runs/{run_id}", method="GET").json()
return response["errorLog"]

@parallelize
def wait(self, site_id: Union[SiteID, List[SiteID]], desired_status: str, timeout: float = 600) -> None:
"""Wait for a site to have a certain status or timeout with error
def wait(self, run_id: Union[RunID, List[RunID]], desired_status: str, timeout: float = 600) -> None:
"""Wait for a run to have a certain status or timeout with error
:param site_id: id of site or list of ids
:param run_id: id of run or list of ids
:param desired_status: status to wait for
:param timeout: timeout length in seconds
"""
Expand All @@ -124,13 +126,13 @@ def wait(self, site_id: Union[SiteID, List[SiteID]], desired_status: str, timeou
current_status = None
while time() - timeout < start_time:
try:
current_status = self.status(site_id)
current_status = self.status(run_id)
except HTTPError as e:
if e.response.status_code != 404:
raise e

if current_status == "error":
error_log = self.get_error_log(site_id)
error_log = self.get_error_log(run_id)
raise AlfalfaException(error_log)

if current_status != previous_status:
Expand Down Expand Up @@ -167,7 +169,7 @@ def upload_model(self, model_path: os.PathLike) -> ModelID:

return model_id

def create_run_from_model(self, model_id: Union[ModelID, List[ModelID]], wait_for_status: bool = True) -> SiteID:
def create_run_from_model(self, model_id: Union[ModelID, List[ModelID]], wait_for_status: bool = True) -> RunID:
"""Create a run from a model
:param model_id: id of model to create a run from or list of ids
Expand All @@ -183,7 +185,7 @@ def create_run_from_model(self, model_id: Union[ModelID, List[ModelID]], wait_fo
return run_id

@parallelize
def submit(self, model_path: Union[str, List[str]], wait_for_status: bool = True) -> SiteID:
def submit(self, model_path: Union[str, List[str]], wait_for_status: bool = True) -> RunID:
"""Submit a model to alfalfa
:param model_path: path to the model to upload or list of paths
Expand All @@ -194,17 +196,17 @@ def submit(self, model_path: Union[str, List[str]], wait_for_status: bool = True

model_id = self.upload_model(model_path)

# After the file has been uploaded, then tell BOPTEST to process the site
# After the file has been uploaded, then tell BOPTEST to process the run
# This is done not via the haystack api, but through a REST api
run_id = self.create_run_from_model(model_id, wait_for_status=wait_for_status)

return run_id

@parallelize
def start(self, site_id: Union[SiteID, List[SiteID]], start_datetime: Union[Number, datetime], end_datetime: Union[Number, datetime], timescale: int = 5, external_clock: bool = False, realtime: bool = False, wait_for_status: bool = True):
def start(self, run_id: Union[RunID, List[RunID]], start_datetime: Union[Number, datetime], end_datetime: Union[Number, datetime], timescale: int = 5, external_clock: bool = False, realtime: bool = False, wait_for_status: bool = True):
"""Start one run from a model.
:param site_id: id of site or list of ids
:param run_id: id of run or list of ids
:param start_datetime: time to start the model from
:param end_datetime: time to stop the model at (may not be honored for external_clock=True)
:param timescale: multiple of real time to run model at (for external_clock=False)
Expand All @@ -220,100 +222,120 @@ def start(self, site_id: Union[SiteID, List[SiteID]], start_datetime: Union[Numb
'realtime': realtime
}

response = self._request(f"sites/{site_id}/start", parameters=parameters)
response = self._request(f"runs/{run_id}/start", parameters=parameters)

assert response.status_code == 204, "Got wrong status_code from alfalfa"

if wait_for_status:
self.wait(site_id, "running")
self.wait(run_id, "running")

@parallelize
def stop(self, site_id: Union[SiteID, List[SiteID]], wait_for_status: bool = True):
def stop(self, run_id: Union[RunID, List[RunID]], wait_for_status: bool = True):
"""Stop a run
:param site_id: id of the site or list of ids
:param wait_for_status: wait for the site to be "complete" before returning
:param run_id: id of the run or list of ids
:param wait_for_status: wait for the run to be "complete" before returning
"""

response = self._request(f"sites/{site_id}/stop")
response = self._request(f"runs/{run_id}/stop")

assert response.status_code == 204, "Got wrong status_code from alfalfa"

if wait_for_status:
self.wait(site_id, "complete")
self.wait(run_id, "complete")

@parallelize
def advance(self, site_id: Union[SiteID, List[SiteID]]) -> None:
"""Advance a site 1 timestep
def advance(self, run_id: Union[RunID, List[RunID]]) -> None:
"""Advance a run 1 timestep
:param site_id: id of site or list of ids"""
self._request(f"sites/{site_id}/advance")
:param run_id: id of run or list of ids"""
self._request(f"runs/{run_id}/advance")

def get_inputs(self, site_id: str) -> List[str]:
"""Get inputs of site
def get_inputs(self, run_id: str) -> List[str]:
"""Get inputs of run
:param site_id: id of site
:param run_id: id of run
:returns: list of input names"""

response = self._request(f"sites/{site_id}/points/inputs", method="GET")
response = self._request(f"runs/{run_id}/points", method="POST",
parameters={ "pointTypes": ["INPUT", "BIDIRECTIONAL"]})
response_body = response.json()
inputs = []
for point in response_body["data"]:
for point in response_body:
if point["name"] != "":
inputs.append(point["name"])
return inputs

def set_inputs(self, site_id: str, inputs: dict) -> None:
"""Set inputs of site
def set_inputs(self, run_id: str, inputs: dict) -> None:
"""Set inputs of run
:param site_id: id of site
:param run_id: id of run
:param inputs: dictionary of point names and input values"""
point_writes = []
for name, value in inputs.items():
point_writes.append({'name': name, 'value': value})
self._request(f"sites/{site_id}/points/inputs", method="PUT", parameters={'points': point_writes})
id = self._get_point_translation(run_id, name)
if id:
point_writes.append({'id': id, 'value': value})
self._request(f"runs/{run_id}/points/values", method="PUT", parameters={'points': point_writes})

def get_outputs(self, site_id: str) -> dict:
"""Get outputs of site
def get_outputs(self, run_id: str) -> dict:
"""Get outputs of run
:param site_id: id of site
:param run_id: id of run
:returns: dictionary of output names and values"""
response = self._request(f"sites/{site_id}/points/outputs", method="GET")
response = self._request(f"runs/{run_id}/points/values", method="POST",
parameters={ "pointTypes": ["OUTPUT", "BIDIRECTIONAL"]})
response_body = response.json()
outputs = {}
for point in response_body["data"]:
for point in response_body:
name = self._get_point_translation(run_id, point["id"])
if "value" in point.keys():
outputs[point["name"]] = point["value"]
outputs[name] = point["value"]
else:
outputs[point["name"]] = None
outputs[name] = None

return outputs

@parallelize
def get_sim_time(self, site_id: Union[SiteID, List[SiteID]]) -> datetime:
"""Get sim_time of site
def get_sim_time(self, run_id: Union[RunID, List[RunID]]) -> datetime:
"""Get sim_time of run
:param site_id: id of site or list of ids
:param run_id: id of site or list of ids
:returns: datetime of site
"""
response = self._request(f"sites/{site_id}/time", method="GET")
response = self._request(f"runs/{run_id}/time", method="GET")
response_body = response.json()
return datetime.strptime(response_body["time"], '%Y-%m-%d %H:%M:%S')

def set_alias(self, alias: str, site_id: SiteID) -> None:
"""Set alias to point to a site_id
def set_alias(self, alias: str, run_id: RunID) -> None:
"""Set alias to point to a run_id
:param site_id: id of site to point alias to
:param run_id: id of run to point alias to
:param alias: alias to use"""

self._request(f"aliases/{alias}", method="PUT", parameters={"siteId": site_id})
self._request(f"aliases/{alias}", method="PUT", parameters={"runId": run_id})

def get_alias(self, alias: str) -> SiteID:
"""Get site_id from alias
def get_alias(self, alias: str) -> RunID:
"""Get run_id from alias
:param alias: alias
:returns: Id of site associated with alias"""
:returns: Id of run associated with alias"""

response = self._request(f"aliases/{alias}", method="GET")
response_body = response.json()
return response_body

def _get_point_translation(self, *args):
if args in self.point_translation_map:
return self.point_translation_map[args]
if args not in self.point_translation_map:
self._fetch_points(args[0])
if args in self.point_translation_map:
return self.point_translation_map[args]
return None

def _fetch_points(self, run_id):
response = self._request(f"runs/{run_id}/points", method = "GET")
for point in response.json():
self.point_translation_map[(run_id, point["name"])] = point["id"]
self.point_translation_map[(run_id, point["id"])] = point["name"]

0 comments on commit e13f8b5

Please sign in to comment.