Skip to content

Commit

Permalink
Transfer serving computation result to binary backend upfront (#282)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Nov 14, 2023
1 parent 57bdcce commit 4e0e178
Show file tree
Hide file tree
Showing 15 changed files with 32 additions and 35 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
7 changes: 3 additions & 4 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
%{image: image}
end

Bumblebee.Utils.Nx.composite_unflatten_batch(output, inputs.size)
output
|> Bumblebee.Utils.Nx.composite_unflatten_batch(inputs.size)
|> Shared.serving_post_computation()
end
end

Expand Down Expand Up @@ -318,9 +320,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
end

defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
# We use binary backend so we are not blocked by the serving computation
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)

for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
results =
for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
Expand Down
13 changes: 13 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ defmodule Bumblebee.Shared do
Nx.Batch.pad(batch, batch_size - size)
end

@doc """
Shared logic applied after serving computation to the resulting tensor
or container.
"""
@spec serving_post_computation(result) :: result when result: Nx.Tensor.t() | Nx.Container.t()
def serving_post_computation(result) do
# We transfer to binary backend so tensor access in post-processing
# is not blocked by the serving computation. It is also
# necessary when partitions are enabled since we may need to
# concatenate results for input exceeding the expected batch size.
Nx.backend_transfer(result, Nx.BinaryBackend)
end

@doc """
Compiles or wraps the function with just-in-time compilation.
Expand Down
1 change: 1 addition & 0 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ defmodule Bumblebee.Text.Conversation do
end

sequences[[.., start_idx..-1//1]]
|> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ defmodule Bumblebee.Text.FillMask do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ defmodule Bumblebee.Text.Generation do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
7 changes: 1 addition & 6 deletions lib/bumblebee/text/question_answering.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ defmodule Bumblebee.Text.QuestionAnswering do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)

predict_fun.(params, inputs)
predict_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down Expand Up @@ -103,10 +102,6 @@ defmodule Bumblebee.Text.QuestionAnswering do
{batch, {all_inputs, raw_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} ->
# We use binary backend so we are not blocked by the serving computation
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)

Enum.zip_with(
[raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
fn [{_question_text, context_text}, inputs, outputs] ->
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TextClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
5 changes: 1 addition & 4 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ defmodule Bumblebee.Text.TextEmbedding do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
embedding_fun.(params, inputs)
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -131,9 +131,6 @@ defmodule Bumblebee.Text.TextEmbedding do
{batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
# We use binary backend so we are not blocked by the serving computation
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)

for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
%{embedding: embedding}
end
Expand Down
9 changes: 1 addition & 8 deletions lib/bumblebee/text/token_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TokenClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -88,10 +88,6 @@ defmodule Bumblebee.Text.TokenClassification do
{batch, {all_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} ->
# We use binary backend so we are not blocked by the serving computation
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)

Enum.zip_with(
Utils.Nx.batch_to_list(inputs),
Utils.Nx.batch_to_list(scores),
Expand All @@ -110,9 +106,6 @@ defmodule Bumblebee.Text.TokenClassification do
end

defp gather_raw_entities(scores, tokenizer, inputs) do
# We use binary backend so we are not blocked by the serving computation
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)

{sequence_length, _} = Nx.shape(scores)
flat_special_tokens_mask = Nx.to_flat_list(inputs["special_tokens_mask"])
flat_input_ids = Nx.to_flat_list(inputs["input_ids"])
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/zero_shot_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do
scores = Axon.Activations.softmax(logits[[.., .., entailment_id]])
k = min(top_k, Nx.axis_size(scores, 1))
{top_scores, top_indices} = Nx.top_k(scores, k: k)
{top_scores, top_indices}
{top_scores, top_indices} |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ defmodule Bumblebee.Vision.ImageClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
5 changes: 1 addition & 4 deletions lib/bumblebee/vision/image_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
embedding_fun.(params, inputs)
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -94,9 +94,6 @@ defmodule Bumblebee.Vision.ImageEmbedding do
{Nx.Batch.concatenate([inputs]), multi?}
end)
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
# We use binary backend so we are not blocked by the serving computation
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)

for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
%{embedding: embedding}
end
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ defmodule Bumblebee.Vision.ImageToText do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
6 changes: 4 additions & 2 deletions test/bumblebee/text/text_embedding_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ defmodule Bumblebee.Text.TextEmbeddingTest do

text = "query: Cats are cute."

assert %{embedding: %Nx.Tensor{} = embedding1} = Nx.Serving.batched_run(test, text)
assert %{embedding: %Nx.Tensor{} = embedding2} = Nx.Serving.batched_run(test, text)
assert [
%{embedding: %Nx.Tensor{} = embedding1},
%{embedding: %Nx.Tensor{} = embedding2}
] = Nx.Serving.batched_run(test, [text, text])

assert_equal(embedding1, embedding2)
end
Expand Down

0 comments on commit 4e0e178

Please sign in to comment.