diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 9ba92127..5443b5b4 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -1163,11 +1163,12 @@ defmodule Bumblebee do defp get_repo_files({:hf, repository_id, opts}) do subdir = opts[:subdir] url = HuggingFace.Hub.file_listing_url(repository_id, subdir, opts[:revision]) + cache_scope = repository_id_to_cache_scope(repository_id) result = HuggingFace.Hub.cached_download( url, - Keyword.take(opts, [:cache_dir, :offline, :auth_token]) + [cache_scope: cache_scope] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token]) ) with {:ok, path} <- result, @@ -1211,13 +1212,21 @@ defmodule Bumblebee do end url = HuggingFace.Hub.file_url(repository_id, filename, opts[:revision]) + cache_scope = repository_id_to_cache_scope(repository_id) HuggingFace.Hub.cached_download( url, - [etag: etag] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token]) + [etag: etag, cache_scope: cache_scope] ++ + Keyword.take(opts, [:cache_dir, :offline, :auth_token]) ) end + defp repository_id_to_cache_scope(repository_id) do + repository_id + |> String.replace("/", "--") + |> String.replace(~r/[^\w-]/, "") + end + defp normalize_repository!({:hf, repository_id}) when is_binary(repository_id) do {:hf, repository_id, []} end diff --git a/lib/bumblebee/huggingface/hub.ex b/lib/bumblebee/huggingface/hub.ex index 1035a724..6a730301 100644 --- a/lib/bumblebee/huggingface/hub.ex +++ b/lib/bumblebee/huggingface/hub.ex @@ -47,6 +47,9 @@ defmodule Bumblebee.HuggingFace.Hub do ETag value, however if the value is already known, it can be passed as an option instead (to skip the extra request) + * `:cache_scope` - a namespace to put the cached files under in + the cache directory + """ @spec cached_download(String.t(), keyword()) :: {:ok, String.t()} | {:error, String.t()} def cached_download(url, opts \\ []) do @@ -56,6 +59,13 @@ defmodule Bumblebee.HuggingFace.Hub do dir = Path.join(cache_dir, "huggingface") + dir = + if cache_scope = opts[:cache_scope] do + Path.join(dir, cache_scope) + else + dir + end + File.mkdir_p!(dir) headers =