From a9d10b967ca5f9523efb19e6327a75c2215b4910 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 30 May 2024 10:24:21 -0400
Subject: [PATCH 1/9] Use model state

---
 lib/bumblebee/conversion/pytorch_params.ex | 7 ++++---
 mix.lock                                   | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index 10fb050f..f0d67a92 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -100,13 +100,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do

     {params, diff} =
       layers
-      |> Enum.filter(fn {_layer, layer_name} -> params_expr[layer_name] end)
+      |> Enum.filter(fn {_layer, layer_name} -> params_expr.data[layer_name] end)
       |> Enum.map_reduce(diff, fn {layer, layer_name}, diff ->
         params_source = params_source(layer_name, prefixes, params_mapping)

         {params, diff} =
           Enum.reduce(layer.parameters, {[], diff}, fn param, {params, diff} ->
-            param_expr = params_expr[layer_name][param.name]
+            param_expr = params_expr.data[layer_name][param.name]

             {sources, builder_fun} =
               case params_source do
@@ -168,7 +168,8 @@ defmodule Bumblebee.Conversion.PyTorchParams do
         {{layer_name, Map.new(params)}, diff}
       end)

-    params = Map.new(params)
+    params_data = Map.new(params)
+    params = %{params_expr | data: params_data} |> IO.inspect

     diff = %{
       missing: Enum.reverse(diff.missing),
diff --git a/mix.lock b/mix.lock
index 0f23d0fd..b4c9a5fc 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "axon": {:git, "https://github.com/elixir-nx/axon.git", "7e0e5930ac4b8d2a89f48106b8121e103e597c89", []},
+  "axon": {:git, "https://github.com/elixir-nx/axon.git", "f48dcd11ffc6a8c8a553a37d8fd01f1e54be3c03", []},
   "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
   "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
   "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},

From f23a0e857734598ee60fbf03c5ba0b8b4a26f518 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 30 May 2024 10:25:21 -0400
Subject: [PATCH 2/9] Formatting

---
 lib/bumblebee/conversion/pytorch_params.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index f0d67a92..82ed77f2 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -169,7 +169,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
       end)

     params_data = Map.new(params)
-    params = %{params_expr | data: params_data} |> IO.inspect
+    params = %{params_expr | data: params_data} |> IO.inspect()

     diff = %{
       missing: Enum.reverse(diff.missing),

From be929af24cc7642dc8db4b524f8f2d13a819a459 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 30 May 2024 10:25:44 -0400
Subject: [PATCH 3/9] Remove inspect

---
 lib/bumblebee/conversion/pytorch_params.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index 82ed77f2..9d662123 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -169,7 +169,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
       end)

     params_data = Map.new(params)
-    params = %{params_expr | data: params_data} |> IO.inspect()
+    params = %{params_expr | data: params_data}

     diff = %{
       missing: Enum.reverse(diff.missing),

From ee398f43372db046c3b7f29cebd9a30f639252e2 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 30 May 2024 10:48:42 -0400
Subject: [PATCH 4/9] Fix failing tests

---
 test/bumblebee/conversion/pytorch_params_test.exs | 6 +++---
 test/bumblebee/text/roberta_test.exs              | 4 +++-
 test/bumblebee_test.exs                           | 4 ++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/test/bumblebee/conversion/pytorch_params_test.exs b/test/bumblebee/conversion/pytorch_params_test.exs
index c73dbef7..8d58ec56 100644
--- a/test/bumblebee/conversion/pytorch_params_test.exs
+++ b/test/bumblebee/conversion/pytorch_params_test.exs
@@ -39,7 +39,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

       log =
         ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
             PyTorchParams.load_params!(model, input_template(), path,
               params_mapping: params_mapping()
             )
@@ -89,7 +89,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

       log =
         ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
             PyTorchParams.load_params!(model, input_template(), path,
               params_mapping: params_mapping()
             )
@@ -107,7 +107,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

       log =
         ExUnit.CaptureLog.capture_log(fn ->
-          params =
+          %Axon.ModelState{data: params} =
             PyTorchParams.load_params!(model, input_template(), path,
               params_mapping: params_mapping()
             )
diff --git a/test/bumblebee/text/roberta_test.exs b/test/bumblebee/text/roberta_test.exs
index b6f2d6bb..186371f0 100644
--- a/test/bumblebee/text/roberta_test.exs
+++ b/test/bumblebee/text/roberta_test.exs
@@ -157,7 +157,9 @@ defmodule Bumblebee.Text.RobertaTest do
       assert %Bumblebee.Text.Roberta{architecture: :for_causal_language_modeling} = spec

       # TODO: remove once we load tied embeddings
-      params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+      params = update_in(params, [Access.key!(:data)], fn data ->
+        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+      end)

       inputs = %{
         "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
diff --git a/test/bumblebee_test.exs b/test/bumblebee_test.exs
index 213415a7..41266560 100644
--- a/test/bumblebee_test.exs
+++ b/test/bumblebee_test.exs
@@ -76,7 +76,7 @@ defmodule BumblebeeTest do
     end

     test "passing :type casts params accordingly" do
-      assert {:ok, %{params: params}} =
+      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: :bf16
               )
@@ -84,7 +84,7 @@ defmodule BumblebeeTest do
       assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
       assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}

-      assert {:ok, %{params: params}} =
+      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: Axon.MixedPrecision.create_policy(params: :f16)
               )

From c8f63da951f6d0f38dfd3a132f3a631fc5853aaf Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 30 May 2024 12:13:17 -0400
Subject: [PATCH 5/9] Fix another test

---
 test/bumblebee/text/roberta_test.exs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/bumblebee/text/roberta_test.exs b/test/bumblebee/text/roberta_test.exs
index 186371f0..235e81d1 100644
--- a/test/bumblebee/text/roberta_test.exs
+++ b/test/bumblebee/text/roberta_test.exs
@@ -33,7 +33,9 @@ defmodule Bumblebee.Text.RobertaTest do
       assert %Bumblebee.Text.Roberta{architecture: :for_masked_language_modeling} = spec

       # TODO: remove once we load tied embeddings
-      params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+      params = update_in(params, [Access.key!(:data)], fn data ->
+        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+      end)

       inputs = %{
         "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),

From b5084019313d01337cc30e36d8955ad56f5a0840 Mon Sep 17 00:00:00 2001
From: Jonatan Kłosko
Date: Fri, 31 May 2024 14:41:02 +0700
Subject: [PATCH 6/9] Up

---
 lib/bumblebee.ex                           |  2 +-
 lib/bumblebee/conversion/pytorch_params.ex | 21 +++++++++++----------
 test/bumblebee/text/roberta_test.exs       | 14 ++++++++------
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 35d90a3d..0f22e9f0 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -299,7 +299,7 @@ defmodule Bumblebee do
   """
   @type model_info :: %{
           model: Axon.t(),
-          params: map(),
+          params: %Axon.ModelState{},
           spec: Bumblebee.ModelSpec.t()
         }

diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index 9d662123..88fe3ef2 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -29,7 +29,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
       `Bumblebee.Conversion.PyTorchLoader.load!/1`

   """
-  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()
+  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
   def load_params!(model, input_template, path, opts \\ []) do
     opts =
       opts
@@ -55,25 +55,27 @@ defmodule Bumblebee.Conversion.PyTorchParams do

         end)
         |> Enum.reduce(&Map.merge/2)

-      params_expr = Axon.trace_init(model, input_template)
+      model_state = Axon.trace_init(model, input_template)
+      params_expr = model_state.data

       {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
+      model_state = %{model_state | data: params}

       params_complete? = diff.missing == [] and diff.mismatched == []

-      params =
+      model_state =
         if params_complete? do
-          params
+          model_state
         else
           {init_fun, _} = Axon.build(model, compiler: Nx.Defn.Evaluator)
-          init_fun.(input_template, params)
+          init_fun.(input_template, model_state)
         end

       if Keyword.get(opts, :log_params_diff, not params_complete?) do
         log_params_diff(diff)
       end

-      params
+      model_state
     end)
   end
@@ -100,13 +102,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do

     {params, diff} =
       layers
-      |> Enum.filter(fn {_layer, layer_name} -> params_expr.data[layer_name] end)
+      |> Enum.filter(fn {_layer, layer_name} -> params_expr[layer_name] end)
       |> Enum.map_reduce(diff, fn {layer, layer_name}, diff ->
         params_source = params_source(layer_name, prefixes, params_mapping)

         {params, diff} =
           Enum.reduce(layer.parameters, {[], diff}, fn param, {params, diff} ->
-            param_expr = params_expr.data[layer_name][param.name]
+            param_expr = params_expr[layer_name][param.name]

             {sources, builder_fun} =
               case params_source do
@@ -168,8 +170,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
         {{layer_name, Map.new(params)}, diff}
       end)

-    params_data = Map.new(params)
-    params = %{params_expr | data: params_data}
+    params = Map.new(params)

     diff = %{
       missing: Enum.reverse(diff.missing),
diff --git a/test/bumblebee/text/roberta_test.exs b/test/bumblebee/text/roberta_test.exs
index 235e81d1..b4f91933 100644
--- a/test/bumblebee/text/roberta_test.exs
+++ b/test/bumblebee/text/roberta_test.exs
@@ -33,9 +33,10 @@ defmodule Bumblebee.Text.RobertaTest do
       assert %Bumblebee.Text.Roberta{architecture: :for_masked_language_modeling} = spec

       # TODO: remove once we load tied embeddings
-      params = update_in(params, [Access.key!(:data)], fn data ->
-        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
-      end)
+      params =
+        update_in(params, [Access.key!(:data)], fn data ->
+          put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+        end)

       inputs = %{
         "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
@@ -159,9 +160,10 @@ defmodule Bumblebee.Text.RobertaTest do
       assert %Bumblebee.Text.Roberta{architecture: :for_causal_language_modeling} = spec

       # TODO: remove once we load tied embeddings
-      params = update_in(params, [Access.key!(:data)], fn data ->
-        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
-      end)
+      params =
+        update_in(params, [Access.key!(:data)], fn data ->
+          put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+        end)

       inputs = %{
         "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),

From d330a7aa3186549ad9a90afd8be507883336b1a7 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 30 Jul 2024 10:35:38 -0400
Subject: [PATCH 7/9] Bump axon

---
 mix.lock | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mix.lock b/mix.lock
index b4c9a5fc..9681c1f7 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "axon": {:git, "https://github.com/elixir-nx/axon.git", "f48dcd11ffc6a8c8a553a37d8fd01f1e54be3c03", []},
+  "axon": {:git, "https://github.com/elixir-nx/axon.git", "a54ee13acf7b1492524cc8764bcd7bc5e88482d0", []},
   "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
   "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
   "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},

From a1fa298fcbafd3aad23a57a090a6861a53bb7cc7 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 30 Jul 2024 11:11:27 -0400
Subject: [PATCH 8/9] Bump Axon

---
 mix.lock | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mix.lock b/mix.lock
index 9681c1f7..af7d1eb6 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "axon": {:git, "https://github.com/elixir-nx/axon.git", "a54ee13acf7b1492524cc8764bcd7bc5e88482d0", []},
+  "axon": {:git, "https://github.com/elixir-nx/axon.git", "8e0a6d913e7a862f37b8c36fdc6141f01c163128", []},
   "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
   "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
   "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},

From 5db51c0a8737c3eb99f709d7d8b37875c656a026 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 30 Jul 2024 11:17:48 -0400
Subject: [PATCH 9/9] Bump axon one more time

---
 mix.lock | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mix.lock b/mix.lock
index af7d1eb6..58f378bd 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "axon": {:git, "https://github.com/elixir-nx/axon.git", "8e0a6d913e7a862f37b8c36fdc6141f01c163128", []},
+  "axon": {:git, "https://github.com/elixir-nx/axon.git", "054eb4c1c224582528e8d1603ad08e7c4088f21c", []},
   "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
   "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
   "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},
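
Taken together, this series changes how Bumblebee represents parameters: load_params!/4 and Bumblebee.load_model/2 now return Axon's %Axon.ModelState{} struct rather than a plain map, with the old layer-name-to-tensor map living under the struct's :data field. Below is a minimal sketch of what that means for downstream code, assuming the Axon revision pinned above; the model repository and layer names are borrowed from the tests in patch 4, and the input tensor is illustrative:

    # Params now arrive wrapped in %Axon.ModelState{}, matching the
    # updated model_info() type in lib/bumblebee.ex.
    {:ok, %{model: model, params: %Axon.ModelState{data: data} = model_state}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})

    # Individual parameter tensors are reached through the :data field...
    kernel = data["decoder.blocks.0.ffn.output"]["kernel"]

    # ...while Axon's built predict function consumes the struct whole.
    # The input_ids values here are arbitrary token ids for illustration.
    inputs = %{"input_ids" => Nx.tensor([[10, 20, 30, 40]])}
    {_init_fun, predict_fun} = Axon.build(model)
    outputs = predict_fun.(model_state, inputs)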