Skip to content

Commit

Permalink
Add weka as a new S3-like scheme (#16)
Browse files: browse the repository at this point in the history
  • Loading branch information
2015aroras authored May 9, 2024
1 parent 95bcc38 commit eb56a9f
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/olmo_core/io.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,7 +47,7 @@ def file_size(path: PathOrStr) -> int:
parsed = urlparse(str(path))
if parsed.scheme == "gs":
return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
- elif parsed.scheme in ("s3", "r2"):
+ elif parsed.scheme in ("s3", "r2", "weka"):
return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme in ("http", "https"):
return _http_file_size(str(path))
Expand All @@ -73,7 +73,7 @@ def get_bytes_range(path: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
parsed = urlparse(str(path))
if parsed.scheme == "gs":
return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
- elif parsed.scheme in ("s3", "r2"):
+ elif parsed.scheme in ("s3", "r2", "weka"):
return _s3_get_bytes_range(
parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
)
Expand Down Expand Up @@ -106,7 +106,7 @@ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
parsed = urlparse(target)
if parsed.scheme == "gs":
_gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
- elif parsed.scheme in ("s3", "r2"):
+ elif parsed.scheme in ("s3", "r2", "weka"):
_s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
else:
raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
Expand Down Expand Up @@ -145,7 +145,7 @@ def file_exists(path: PathOrStr) -> bool:
return False
else:
return True
- elif parsed.scheme in ("s3", "r2"):
+ elif parsed.scheme in ("s3", "r2", "weka"):
try:
_s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
except FileNotFoundError:
Expand All @@ -172,7 +172,7 @@ def clear_directory(dir: PathOrStr):
from urllib.parse import urlparse

parsed = urlparse(str(dir))
- if parsed.scheme in ("s3", "r2"):
+ if parsed.scheme in ("s3", "r2", "weka"):
return _s3_clear_directory(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "file":
return clear_directory(str(dir).replace("file://", "", 1))
Expand Down Expand Up @@ -344,6 +344,14 @@ def _get_s3_profile_name(scheme: str) -> Optional[str]:
)

return profile_name
+ if scheme == "weka":
+ profile_name = os.environ.get("WEKA_PROFILE")
+ if profile_name is None:
+ raise OLMoEnvironmentError(
+ "WEKA profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
+ )
+
+ return profile_name

raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")

Expand All @@ -359,6 +367,14 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
)

return r2_endpoint_url
+ if scheme == "weka":
+ weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
+ if weka_endpoint_url is None:
+ raise OLMoEnvironmentError(
+ "WEKA endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
+ )
+
+ return weka_endpoint_url

raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")

Expand Down

0 comments on commit eb56a9f

Please sign in to comment.