From 7b0d06cf0cc51f29e2ea24d4b4e8d9d0a590006f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Fri, 23 Feb 2024 16:20:14 +0100
Subject: [PATCH] Make sure the initial decoding cache has the proper types
 (#346)

---
 lib/bumblebee/text/generation.ex  | 29 ++++++++++++++++++++++++++---
 test/bumblebee/text/bart_test.exs | 28 ++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 5ee2d36a..17bfcf8b 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -249,7 +249,7 @@ defmodule Bumblebee.Text.Generation do
         })

         max_length = max_length_fun.(1)
-        inputs = prepare_decoder_inputs(inputs, "decoder_", spec, max_length)
+        inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)

         {inputs, inputs["decoder_input_ids"], max_length}
@@ -260,7 +260,7 @@ defmodule Bumblebee.Text.Generation do
       prepare_inputs_fun = fn inputs, _params ->
         sequence_length = Nx.axis_size(inputs["input_ids"], 1)
         max_length = max_length_fun.(sequence_length)
-        inputs = prepare_decoder_inputs(inputs, "", spec, max_length)
+        inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
         {inputs, inputs["input_ids"], max_length}
       end

@@ -279,7 +279,7 @@ defmodule Bumblebee.Text.Generation do
     inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
   end

-  defp prepare_decoder_inputs(inputs, prefix, spec, max_length) do
+  defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
     input_ids = inputs[prefix <> "input_ids"]

     attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
@@ -295,9 +295,32 @@ defmodule Bumblebee.Text.Generation do

     batch_size = Nx.axis_size(input_ids, 0)
     cache = init_cache(spec, batch_size, max_length, inputs)
+
+    output_policy = model_output_policy(model)
+
+    # TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
+    # match Axon compiler
+
+    # Cast all float cache tensors to match the model output. This way
+    # we make sure the cache we pass as input has the same types as
+    # the updated cache returned from the model
+    cache =
+      Bumblebee.Utils.Nx.map(cache, fn tensor ->
+        if Nx.Type.integer?(Nx.type(tensor)) do
+          tensor
+        else
+          Axon.MixedPrecision.cast(output_policy, tensor, :output)
+        end
+      end)
+
     Map.put(inputs, "cache", cache)
   end

+  defp model_output_policy(model) do
+    {node, _} = Axon.pop_node(model)
+    node.policy
+  end
+
   defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
     inputs
     |> Map.replace!(prefix <> "input_ids", token_ids)
diff --git a/test/bumblebee/text/bart_test.exs b/test/bumblebee/text/bart_test.exs
index 515c6903..dcb715f8 100644
--- a/test/bumblebee/text/bart_test.exs
+++ b/test/bumblebee/text/bart_test.exs
@@ -154,4 +154,32 @@ defmodule Bumblebee.Text.BartTest do

     assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
   end
+
+  test "generation with :for_conditional_generation and lower precision" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"},
+               type: :f16
+             )
+
+    {:ok, generation_config} =
+      Bumblebee.load_generation_config(
+        {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
+      )
+
+    assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+      "seed" => Nx.tensor([0])
+    }
+
+    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
+
+    generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
+    %{token_ids: token_ids} = generate.(params, inputs)
+
+    assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
+  end
 end