From eb56a9f0c2f63cf2e79e90da878a00d1a282cec9 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 9 May 2024 09:04:15 -0700 Subject: [PATCH] Add weka as a new S3-like scheme (#16) --- src/olmo_core/io.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index db9692e9..068190b1 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -47,7 +47,7 @@ def file_size(path: PathOrStr) -> int: parsed = urlparse(str(path)) if parsed.scheme == "gs": return _gcs_file_size(parsed.netloc, parsed.path.strip("/")) - elif parsed.scheme in ("s3", "r2"): + elif parsed.scheme in ("s3", "r2", "weka"): return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) elif parsed.scheme in ("http", "https"): return _http_file_size(str(path)) @@ -73,7 +73,7 @@ def get_bytes_range(path: PathOrStr, bytes_start: int, num_bytes: int) -> bytes: parsed = urlparse(str(path)) if parsed.scheme == "gs": return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes) - elif parsed.scheme in ("s3", "r2"): + elif parsed.scheme in ("s3", "r2", "weka"): return _s3_get_bytes_range( parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes ) @@ -106,7 +106,7 @@ def upload(source: PathOrStr, target: str, save_overwrite: bool = False): parsed = urlparse(target) if parsed.scheme == "gs": _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) - elif parsed.scheme in ("s3", "r2"): + elif parsed.scheme in ("s3", "r2", "weka"): _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) else: raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme") @@ -145,7 +145,7 @@ def file_exists(path: PathOrStr) -> bool: return False else: return True - elif parsed.scheme in ("s3", "r2"): + elif parsed.scheme in ("s3", "r2", "weka"): try: _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) except FileNotFoundError: @@ -172,7 +172,7 @@ def clear_directory(dir: PathOrStr): from urllib.parse import urlparse parsed = urlparse(str(dir)) - if parsed.scheme in ("s3", "r2"): + if parsed.scheme in ("s3", "r2", "weka"): return _s3_clear_directory(parsed.scheme, parsed.netloc, parsed.path.strip("/")) elif parsed.scheme == "file": return clear_directory(str(dir).replace("file://", "", 1)) @@ -344,6 +344,14 @@ def _get_s3_profile_name(scheme: str) -> Optional[str]: ) return profile_name + if scheme == "weka": + profile_name = os.environ.get("WEKA_PROFILE") + if profile_name is None: + raise OLMoEnvironmentError( + "WEKA profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?" + ) + + return profile_name raise NotImplementedError(f"Cannot get profile name for scheme {scheme}") @@ -359,6 +367,14 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]: ) return r2_endpoint_url + if scheme == "weka": + weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL") + if weka_endpoint_url is None: + raise OLMoEnvironmentError( + "WEKA endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?" + ) + + return weka_endpoint_url raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")