Skip to content

Commit

Permalink
Group cache files by repository (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Feb 13, 2024
1 parent 70589fa commit b68ad2b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1163,11 +1163,12 @@ defmodule Bumblebee do
defp get_repo_files({:hf, repository_id, opts}) do
subdir = opts[:subdir]
url = HuggingFace.Hub.file_listing_url(repository_id, subdir, opts[:revision])
cache_scope = repository_id_to_cache_scope(repository_id)

result =
HuggingFace.Hub.cached_download(
url,
Keyword.take(opts, [:cache_dir, :offline, :auth_token])
[cache_scope: cache_scope] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token])
)

with {:ok, path} <- result,
Expand Down Expand Up @@ -1211,13 +1212,21 @@ defmodule Bumblebee do
end

url = HuggingFace.Hub.file_url(repository_id, filename, opts[:revision])
cache_scope = repository_id_to_cache_scope(repository_id)

HuggingFace.Hub.cached_download(
url,
[etag: etag] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token])
[etag: etag, cache_scope: cache_scope] ++
Keyword.take(opts, [:cache_dir, :offline, :auth_token])
)
end

# Builds the cache scope string for a repository id: every "/" is turned
# into "--", then any character that is neither a word character nor a
# hyphen is removed.
defp repository_id_to_cache_scope(repository_id) do
  dashed = String.replace(repository_id, "/", "--")
  Regex.replace(~r/[^\w-]/, dashed, "")
end

# Expands the two-element `{:hf, repository_id}` form into the
# three-element form carrying an empty options list.
defp normalize_repository!({:hf, id}) when is_binary(id), do: {:hf, id, []}
Expand Down
10 changes: 10 additions & 0 deletions lib/bumblebee/huggingface/hub.ex
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ defmodule Bumblebee.HuggingFace.Hub do
ETag value, however if the value is already known, it can be
passed as an option instead (to skip the extra request)
* `:cache_scope` - a namespace to put the cached files under in
the cache directory
"""
@spec cached_download(String.t(), keyword()) :: {:ok, String.t()} | {:error, String.t()}
def cached_download(url, opts \\ []) do
Expand All @@ -56,6 +59,13 @@ defmodule Bumblebee.HuggingFace.Hub do

dir = Path.join(cache_dir, "huggingface")

dir =
if cache_scope = opts[:cache_scope] do
Path.join(dir, cache_scope)
else
dir
end

File.mkdir_p!(dir)

headers =
Expand Down

0 comments on commit b68ad2b

Please sign in to comment.