diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 35d90a3d..0f22e9f0 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -299,7 +299,7 @@ defmodule Bumblebee do
   """
   @type model_info :: %{
           model: Axon.t(),
-          params: map(),
+          params: %Axon.ModelState{},
           spec: Bumblebee.ModelSpec.t()
         }

diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index 10fb050f..88fe3ef2 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -29,7 +29,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
      `Bumblebee.Conversion.PyTorchLoader.load!/1`

  """
-  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()
+  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
  def load_params!(model, input_template, path, opts \\ []) do
    opts =
      opts
@@ -55,25 +55,27 @@ defmodule Bumblebee.Conversion.PyTorchParams do
        end)
        |> Enum.reduce(&Map.merge/2)

-      params_expr = Axon.trace_init(model, input_template)
+      model_state = Axon.trace_init(model, input_template)
+      params_expr = model_state.data

      {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
+      model_state = %{model_state | data: params}

      params_complete? = diff.missing == [] and diff.mismatched == []

-      params =
+      model_state =
        if params_complete? do
-          params
+          model_state
        else
          {init_fun, _} = Axon.build(model, compiler: Nx.Defn.Evaluator)
-          init_fun.(input_template, params)
+          init_fun.(input_template, model_state)
        end

      if Keyword.get(opts, :log_params_diff, not params_complete?) do
        log_params_diff(diff)
      end

-      params
+      model_state
    end)
  end

diff --git a/mix.lock b/mix.lock
index 0f23d0fd..58f378bd 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "axon": {:git, "https://github.com/elixir-nx/axon.git", "7e0e5930ac4b8d2a89f48106b8121e103e597c89", []},
+  "axon": {:git, "https://github.com/elixir-nx/axon.git", "054eb4c1c224582528e8d1603ad08e7c4088f21c", []},
  "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
  "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
  "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},
diff --git a/test/bumblebee/conversion/pytorch_params_test.exs b/test/bumblebee/conversion/pytorch_params_test.exs
index c73dbef7..8d58ec56 100644
--- a/test/bumblebee/conversion/pytorch_params_test.exs
+++ b/test/bumblebee/conversion/pytorch_params_test.exs
@@ -39,7 +39,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

      log =
        ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
            PyTorchParams.load_params!(model, input_template(), path,
              params_mapping: params_mapping()
            )
@@ -89,7 +89,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

      log =
        ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
            PyTorchParams.load_params!(model, input_template(), path,
              params_mapping: params_mapping()
            )
@@ -107,7 +107,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

      log =
        ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
            PyTorchParams.load_params!(model, input_template(), path,
              params_mapping: params_mapping()
            )
diff --git a/test/bumblebee/text/roberta_test.exs b/test/bumblebee/text/roberta_test.exs
index b6f2d6bb..b4f91933 100644
--- a/test/bumblebee/text/roberta_test.exs
+++ b/test/bumblebee/text/roberta_test.exs
@@ -33,7 +33,10 @@ defmodule Bumblebee.Text.RobertaTest do
      assert %Bumblebee.Text.Roberta{architecture: :for_masked_language_modeling} = spec

      # TODO: remove once we load tied embeddings
-      params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+      params =
+        update_in(params, [Access.key!(:data)], fn data ->
+          put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+        end)

      inputs = %{
        "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
@@ -157,7 +160,10 @@ defmodule Bumblebee.Text.RobertaTest do
      assert %Bumblebee.Text.Roberta{architecture: :for_causal_language_modeling} = spec

      # TODO: remove once we load tied embeddings
-      params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+      params =
+        update_in(params, [Access.key!(:data)], fn data ->
+          put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+        end)

      inputs = %{
        "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
diff --git a/test/bumblebee_test.exs b/test/bumblebee_test.exs
index 213415a7..41266560 100644
--- a/test/bumblebee_test.exs
+++ b/test/bumblebee_test.exs
@@ -76,7 +76,7 @@ defmodule BumblebeeTest do
    end

    test "passing :type casts params accordingly" do
-      assert {:ok, %{params: params}} =
+      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: :bf16
               )
@@ -84,7 +84,7 @@ defmodule BumblebeeTest do
      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}

-      assert {:ok, %{params: params}} =
+      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: Axon.MixedPrecision.create_policy(params: :f16)
               )
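The net effect of this patch is that params are carried in an %Axon.ModelState{} struct, with the raw layer-keyed params map one level down under its :data field. A minimal sketch of what callers see after the change, assuming the pinned Axon revision (the model repo and tensor paths are the ones used in the tests above; passing the model state straight to a built predict function is an assumption about the matching Axon API, not something this diff shows):

# Loading now yields an %Axon.ModelState{} instead of a bare params map.
{:ok, %{model: model, params: %Axon.ModelState{data: data} = model_state}} =
  Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})

# The layer-keyed params map is still available under :data.
Nx.type(data["decoder.blocks.0.ffn.output"]["kernel"])

# Assumption: with the updated Axon, predict functions take the model
# state directly, where they previously took the params map.
{_init_fun, predict_fun} = Axon.build(model)
predict_fun.(model_state, %{"input_ids" => Nx.tensor([[10, 20, 30]])})

This is also why the RoBERTa tests wrap their tied-embeddings workaround in update_in/3 with Access.key!(:data): the put_in now has to reach inside the struct rather than operate on the map itself.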