diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 7976b199..5ee2d36a 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -90,12 +90,35 @@ defmodule Bumblebee.Text.Generation do
   argument. Note that the inputs map should additionally include
   a `"seed"` tensor, with one value per input in the batch.
 
+  ## Streaming
+
+  This function sets up a hook that is invoked after every generated
+  token. The hook receives a map with the following attributes:
+
+    * `:token_id` - the newly generated token
+
+    * `:finished?` - a boolean indicating if the sequence is finished
+
+    * `:length` - the current length of the generated sequence. Once
+      the sequence is finished, the length does not increase
+
+  Each of the attributes is a tensor with a leading batch dimension.
+
+  When streaming you may not care about the output result, in which
+  case you can enable `:ignore_output` to reduce the output size.
+
   ## Options
 
     * `:logits_processors` - a list of numerical functions to modify
      predicted scores at each generation step. The functions are
      applied in order, after all default processors
 
+    * `:ignore_output` - if true, returns a dummy tensor that should
+      be ignored. This is useful when you consume the generated tokens
+      in a stream fashion via the hook, so that the full output does
+      not need to be transferred unnecessarily after the computation.
+      Defaults to `false`
+
   """
   @spec build_generate(
           Axon.t(),
@@ -103,9 +126,10 @@ defmodule Bumblebee.Text.Generation do
           Bumblebee.Text.GenerationConfig.t(),
           keyword()
         ) ::
-          (params :: map(), inputs :: map() -> %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()})
+          (params :: map(), inputs :: map() ->
+             %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()} | (ignored :: Nx.Tensor.t()))
   def build_generate(model, spec, config, opts \\ []) do
-    opts = Keyword.validate!(opts, logits_processors: [])
+    opts = Keyword.validate!(opts, logits_processors: [], ignore_output: false)
 
     decoder_start_token_id = config.decoder_start_token_id || config.bos_token_id
     eos_token_id = config.eos_token_id
@@ -148,7 +172,8 @@ defmodule Bumblebee.Text.Generation do
       traverse_cache_fun,
      pad_token_id: pad_token_id,
      eos_token_id: eos_token_id,
-      strategy: config.strategy
+      strategy: config.strategy,
+      ignore_output: opts[:ignore_output]
    )
  end
 
@@ -359,7 +384,7 @@ defmodule Bumblebee.Text.Generation do
 
     strategy = opts[:strategy]
 
-    {sequences, finished_length} =
+    state =
      case strategy.type do
        :greedy_search ->
          greedy(
@@ -400,11 +425,15 @@ defmodule Bumblebee.Text.Generation do
          )
      end
 
-    %{
-      # Output only the newly generated tokens
-      token_ids: sequences[[.., length..-1//1]],
-      length: finished_length - length
-    }
+    if opts[:ignore_output] do
+      state.ignored
+    else
+      %{
+        # Output only the newly generated tokens
+        token_ids: state.sequences[[.., length..-1//1]],
+        length: state.finished_length - length
+      }
+    end
   end
 
   deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed")
@@ -426,19 +455,15 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    {sequences, length = input_length, finished_length} =
-      init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
 
     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
-    {sequences, length, finished_length, inputs} =
-      if length > 1 do
+    {state, inputs} =
+      if state.length > 1 do
        greedy_step(
-          sequences,
-          length,
-          finished_length,
+          state,
          inputs,
-          input_length,
          predict_fun,
          params,
          logits_processor_fun,
@@ -447,19 +472,15 @@ defmodule Bumblebee.Text.Generation do
          eos_token_id: eos_token_id
        )
      else
-        {sequences, length, finished_length, inputs}
+        {state, inputs}
      end
 
-    {sequences, _length, finished_length, _inputs, _params} =
-      while {sequences, length, finished_length, inputs, params},
-            continue?(finished_length) do
-        {sequences, length, finished_length, inputs} =
+    {state, _inputs, _params} =
+      while {state, inputs, params}, continue?(state.finished_length) do
+        {state, inputs} =
          greedy_step(
-            sequences,
-            length,
-            finished_length,
+            state,
            inputs,
-            input_length,
            predict_fun,
            params,
            logits_processor_fun,
@@ -468,10 +489,10 @@ defmodule Bumblebee.Text.Generation do
            eos_token_id: eos_token_id
          )
 
-        {sequences, length, finished_length, inputs, params}
+        {state, inputs, params}
      end
 
-    {sequences, finished_length}
+    state
   end
 
   defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
@@ -484,7 +505,14 @@ defmodule Bumblebee.Text.Generation do
     # means that it has not been finished yet
     finished_length = Nx.broadcast(0, {batch_size})
 
-    {sequences, length, finished_length}
+    %{
+      sequences: sequences,
+      input_length: length,
+      length: length,
+      finished_length: finished_length,
+      # The ignored return value that we attach all hooks to
+      ignored: Nx.broadcast(0, {batch_size})
+    }
   end
 
   defnp continue?(finished_length) do
@@ -492,11 +520,8 @@ defmodule Bumblebee.Text.Generation do
   end
 
   defnp greedy_step(
-          sequences,
-          length,
-          finished_length,
+          state,
          inputs,
-          input_length,
          predict_fun,
          params,
          logits_processor_fun,
@@ -509,34 +534,24 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
+    logits = batch_process_logits(logits_processor_fun, logits, state)
     token_id = Nx.argmax(logits, axis: -1)
 
-    {sequences, length, finished_length} =
-      update_sequences(
-        sequences,
-        input_length,
-        length,
-        finished_length,
-        token_id,
-        pad_token_id,
-        eos_token_id
-      )
+    state = update_sequences(state, token_id, pad_token_id, eos_token_id)
 
     inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))
 
-    {sequences, length, finished_length, inputs}
+    {state, inputs}
   end
 
-  defnp update_sequences(
-          sequences,
-          input_length,
-          length,
-          finished_length,
-          token_id,
-          pad_token_id,
-          eos_token_id
-        ) do
+  defnp update_sequences(state, token_id, pad_token_id, eos_token_id) do
+    %{
+      sequences: sequences,
+      length: length,
+      input_length: input_length,
+      finished_length: finished_length
+    } = state
+
     token_id = Nx.select(finished_length > 0, pad_token_id, token_id)
     token_ids = Nx.new_axis(token_id, -1)
 
@@ -564,16 +579,17 @@ defmodule Bumblebee.Text.Generation do
 
     token = create_token()
     {token, _} = hook_token(token, data, :token)
-    attach_token(token, {sequences, length, finished_length})
+    state = %{state | sequences: sequences, length: length, finished_length: finished_length}
+    attach_token(token, state)
   end
 
-  defnp batch_process_logits(logits_processor_fun, logits, sequences, length, input_length) do
+  defnp batch_process_logits(logits_processor_fun, logits, state) do
     logits
     |> Nx.vectorize(:batch)
     |> logits_processor_fun.(%{
-      sequence: Nx.vectorize(sequences, :batch),
-      length: length,
-      input_length: input_length
+      sequence: Nx.vectorize(state.sequences, :batch),
+      length: state.length,
+      input_length: state.input_length
    })
    |> Nx.devectorize(keep_names: false)
  end
@@ -596,8 +612,7 @@ defmodule Bumblebee.Text.Generation do
     top_k = opts[:top_k]
     penalty_alpha = opts[:penalty_alpha]
 
-    {sequences, length = input_length, finished_length} =
-      init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
 
     # Step (1)
     # Initial pass to obtain hidden state and expand inputs to top-k
@@ -614,7 +629,7 @@ defmodule Bumblebee.Text.Generation do
     joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
+    logits = batch_process_logits(logits_processor_fun, logits, state)
     scores = Axon.Activations.softmax(logits, axis: -1)
     {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
 
@@ -629,10 +644,9 @@ defmodule Bumblebee.Text.Generation do
     # pick the best one using the contrastive rank. From the same model
     # pass we also get the next top-k continuation tokens
-    {sequences, _length, finished_length, _inputs, _params, _joint_hidden_state, _top_k_values} =
-      while {sequences, length, finished_length, inputs, params, joint_hidden_state,
-             {top_k_scores, top_k_token_ids}},
-            continue?(finished_length) do
+    {state, _inputs, _params, _joint_hidden_state, _top_k_values} =
+      while {state, inputs, params, joint_hidden_state, {top_k_scores, top_k_token_ids}},
+            continue?(state.finished_length) do
        outputs = predict_fun.(params, inputs)
 
        hidden_state = decoder_hidden_state(outputs)
 
@@ -643,33 +657,22 @@ defmodule Bumblebee.Text.Generation do
          contrastive_rank(
            context_hidden_state,
            hidden_state,
-            length,
+            state.length,
            top_k_scores,
            penalty_alpha,
            top_k
          )
 
        hidden_state = Utils.Nx.chunked_take(hidden_state, top_k, selected_idx)
-        joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, length, 0], hidden_state)
+        joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, state.length, 0], hidden_state)
 
        token_id = top_k_token_ids |> Nx.flatten() |> Utils.Nx.chunked_take(top_k, selected_idx)
 
-        {sequences, length, finished_length} =
-          update_sequences(
-            sequences,
-            input_length,
-            length,
-            finished_length,
-            token_id,
-            pad_token_id,
-            eos_token_id
-          )
+        state = update_sequences(state, token_id, pad_token_id, eos_token_id)
 
        logits = outputs.logits[[.., -1]]
        logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
-
-        logits =
-          batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
+        logits = batch_process_logits(logits_processor_fun, logits, state)
        scores = Axon.Activations.softmax(logits, axis: -1)
        {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
 
@@ -678,11 +681,10 @@ defmodule Bumblebee.Text.Generation do
        cache = reflect_cache(outputs.cache, top_k, selected_idx, traverse_cache_fun)
        inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))
 
-        {sequences, length, finished_length, inputs, params, joint_hidden_state,
-         {top_k_scores, top_k_token_ids}}
+        {state, inputs, params, joint_hidden_state, {top_k_scores, top_k_token_ids}}
      end
 
-    {sequences, finished_length}
+    state
   end
 
   deftransformp decoder_hidden_state(outputs) do
@@ -767,21 +769,17 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    {sequences, length = input_length, finished_length} =
-      init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
 
     prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()
 
     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
-    {sequences, length, finished_length, inputs, prng_key} =
-      if length > 1 do
+    {state, inputs, prng_key} =
+      if state.length > 1 do
        sampling_step(
-          sequences,
-          length,
-          finished_length,
+          state,
          inputs,
-          input_length,
          predict_fun,
          params,
          prng_key,
@@ -791,19 +789,15 @@ defmodule Bumblebee.Text.Generation do
          eos_token_id: eos_token_id
        )
      else
-        {sequences, length, finished_length, inputs, prng_key}
+        {state, inputs, prng_key}
      end
 
-    {sequences, _length, finished_length, _inputs, _params, _key} =
-      while {sequences, length, finished_length, inputs, params, prng_key},
-            continue?(finished_length) do
-        {sequences, length, finished_length, inputs, prng_key} =
+    {state, _inputs, _params, _key} =
+      while {state, inputs, params, prng_key}, continue?(state.finished_length) do
+        {state, inputs, prng_key} =
          sampling_step(
-            sequences,
-            length,
-            finished_length,
+            state,
            inputs,
-            input_length,
            predict_fun,
            params,
            prng_key,
@@ -813,18 +807,15 @@ defmodule Bumblebee.Text.Generation do
            eos_token_id: eos_token_id
          )
 
-        {sequences, length, finished_length, inputs, params, prng_key}
+        {state, inputs, params, prng_key}
      end
 
-    {sequences, finished_length}
+    state
   end
 
   defnp sampling_step(
-          sequences,
-          length,
-          finished_length,
+          state,
          inputs,
-          input_length,
          predict_fun,
          params,
          prng_key,
@@ -841,24 +832,15 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
+    logits = batch_process_logits(logits_processor_fun, logits, state)
     scores = Axon.Activations.softmax(logits)
     token_id = batched_choice(key, scores)
 
-    {sequences, length, finished_length} =
-      update_sequences(
-        sequences,
-        input_length,
-        length,
-        finished_length,
-        token_id,
-        pad_token_id,
-        eos_token_id
-      )
+    state = update_sequences(state, token_id, pad_token_id, eos_token_id)
 
     inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))
 
-    {sequences, length, finished_length, inputs, prng_key}
+    {state, inputs, prng_key}
   end
 
   deftransformp batched_choice(key, scores) do
diff --git a/lib/bumblebee/text/text_generation.ex b/lib/bumblebee/text/text_generation.ex
index 556ccd6f..90f9f581 100644
--- a/lib/bumblebee/text/text_generation.ex
+++ b/lib/bumblebee/text/text_generation.ex
@@ -42,7 +42,10 @@ defmodule Bumblebee.Text.TextGeneration do
        return_length: true
      )
 
-    generate_fun = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
+    generate_fun =
+      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
+        ignore_output: opts[:stream]
+      )
 
     batch_keys = Shared.sequence_batch_keys(sequence_length)
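
A minimal usage sketch of the streaming path introduced above, from the serving side. The model repository and prompt are illustrative placeholders; the `:stream` option is the one wired to `ignore_output` in the `Bumblebee.Text.TextGeneration` hunk, so generated tokens reach the caller through the `:token` hook rather than the final output tensor.

```elixir
# Hypothetical example: "gpt2" and the prompt are placeholders.
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# With stream: true the serving builds the generate function with
# ignore_output: true, so the full token_ids tensor is never
# transferred back after the computation.
serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config, stream: true)

# The streamed result can be consumed lazily, chunk by chunk,
# as the :token hook emits newly generated tokens.
serving
|> Nx.Serving.run("Elixir is a")
|> Enum.each(&IO.write/1)
```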