diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 3ad41a2b..0eb09339 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -116,6 +116,7 @@ defmodule Bumblebee do
     "CLIPModel" => {Bumblebee.Multimodal.Clip, :base},
     "CLIPTextModel" => {Bumblebee.Text.ClipText, :base},
     "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base},
+    "ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base},
     "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification},
     "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base},
     "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification},
diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index ff5ef4f7..10fb050f 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -300,8 +300,18 @@ defmodule Bumblebee.Conversion.PyTorchParams do
   defp match_template(name, template), do: match_template(name, template, %{})
 
   defp match_template(<<_, _::binary>> = name, <<"{", template::binary>>, substitutes) do
-    [value, name] = String.split(name, ".", parts: 2)
-    [key, template] = String.split(template, "}.", parts: 2)
+    {value, name} =
+      case String.split(name, ".", parts: 2) do
+        [value] -> {value, ""}
+        [value, name] -> {value, name}
+      end
+
+    {key, template} =
+      case String.split(template, "}", parts: 2) do
+        [key, ""] -> {key, ""}
+        [key, "." <> template] -> {key, template}
+      end
+
     match_template(name, template, put_in(substitutes[key], value))
   end
 
diff --git a/lib/bumblebee/diffusion/controlnet.ex b/lib/bumblebee/diffusion/controlnet.ex
new file mode 100644
index 00000000..0e1ff85f
--- /dev/null
+++ b/lib/bumblebee/diffusion/controlnet.ex
@@ -0,0 +1,509 @@
+defmodule Bumblebee.Diffusion.ControlNet do
+  alias Bumblebee.Shared
+
+  options = [
+    sample_size: [
+      default: 64,
+      doc: "the size of the input spatial dimensions"
+    ],
+    in_channels: [
+      default: 4,
+      doc: "the number of channels in the input"
+    ],
+    out_channels: [
+      default: 4,
+      doc: "the number of channels in the output"
+    ],
+    embedding_flip_sin_to_cos: [
+      default: true,
+      doc: "whether to flip the sin to cos in the sinusoidal timestep embedding"
+    ],
+    embedding_frequency_correction_term: [
+      default: 0,
+      doc: ~S"""
+      controls the frequency formula in the timestep sinusoidal embedding. The frequency is computed
+      as $\\omega_i = \\frac{1}{10000^{\\frac{i}{n - s}}}$, for $i \\in \\{0, ..., n-1\\}$, where $n$
+      is half of the embedding size and $s$ is the shift. Historically, certain implementations of
+      sinusoidal embedding used $s=0$, while others used $s=1$
+      """
+    ],
+    hidden_sizes: [
+      default: [320, 640, 1280, 1280],
+      doc: "the dimensionality of hidden layers in each upsample/downsample block"
+    ],
+    depth: [
+      default: 2,
+      doc: "the number of residual blocks in each upsample/downsample block"
+    ],
+    down_block_types: [
+      default: [
+        :cross_attention_down_block,
+        :cross_attention_down_block,
+        :cross_attention_down_block,
+        :down_block
+      ],
+      doc:
+        "a list of downsample block types. The supported blocks are: `:down_block`, `:cross_attention_down_block`"
+    ],
+    up_block_types: [
+      default: [
+        :up_block,
+        :cross_attention_up_block,
+        :cross_attention_up_block,
+        :cross_attention_up_block
+      ],
+      doc:
+        "a list of upsample block types. The supported blocks are: `:up_block`, `:cross_attention_up_block`"
+    ],
+    downsample_padding: [
+      default: [{1, 1}, {1, 1}],
+      doc: "the padding to use in the downsample convolution"
+    ],
+    mid_block_scale_factor: [
+      default: 1,
+      doc: "the scale factor to use for the mid block"
+    ],
+    num_attention_heads: [
+      default: 8,
+      doc:
+        "the number of attention heads for each attention layer. Optionally can be a list with one number per block"
+    ],
+    cross_attention_size: [
+      default: 1024,
+      doc: "the dimensionality of the cross attention features"
+    ],
+    use_linear_projection: [
+      default: false,
+      doc:
+        "whether the input/output projection of the transformer block should be linear or convolutional"
+    ],
+    activation: [
+      default: :silu,
+      doc: "the activation function"
+    ],
+    group_norm_num_groups: [
+      default: 32,
+      doc: "the number of groups used by the group normalization layers"
+    ],
+    group_norm_epsilon: [
+      default: 1.0e-5,
+      doc: "the epsilon used by the group normalization layers"
+    ],
+    conditioning_embedding_hidden_sizes: [
+      default: [16, 32, 96, 256],
+      doc: "the dimensionality of hidden layers in the conditioning input embedding"
+    ]
+  ]
+
+  @moduledoc """
+  ControlNet model with two spatial dimensions and conditioning state.
+
+  ## Architectures
+
+    * `:base` - the ControlNet model
+
+  ## Inputs
+
+    * `"sample"` - `{batch_size, sample_size, sample_size, in_channels}`
+
+      Sample input with two spatial dimensions.
+
+    * `"timestep"` - `{}`
+
+      The timestep used to parameterize model behaviour in a multi-step
+      process, such as diffusion.
+
+    * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}`
+
+      The conditioning state (context) to use with cross-attention.
+
+    * `"conditioning"` - `{batch_size, conditioning_size, conditioning_size, 3}`
+
+      The conditioning spatial input.
+
+    * `"conditioning_scale"` - `{}`
+
+      The scale factor applied to the ControlNet outputs, can be any
+      shape broadcastable against the output states. Defaults to `1`
+
+  ## Configuration
+
+  #{Shared.options_doc(options)}
+  """
+
+  defstruct [architecture: :base] ++ Shared.option_defaults(options)
+
+  @behaviour Bumblebee.ModelSpec
+  @behaviour Bumblebee.Configurable
+
+  import Bumblebee.Utils.Model, only: [join: 2]
+
+  alias Bumblebee.Layers
+  alias Bumblebee.Diffusion
+
+  @impl true
+  def architectures(), do: [:base]
+
+  @impl true
+  def config(spec, opts) do
+    Shared.put_config_attrs(spec, opts)
+  end
+
+  @impl true
+  def input_template(spec) do
+    sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels}
+    timestep_shape = {}
+
+    conditioning_size =
+      spec.sample_size * 2 ** (length(spec.conditioning_embedding_hidden_sizes) - 1)
+
+    conditioning_shape = {1, conditioning_size, conditioning_size, 3}
+    encoder_hidden_state_shape = {1, 1, spec.cross_attention_size}
+
+    %{
+      "sample" => Nx.template(sample_shape, :f32),
+      "timestep" => Nx.template(timestep_shape, :u32),
+      "conditioning" => Nx.template(conditioning_shape, :f32),
+      "conditioning_scale" => Nx.template({}, :f32),
+      "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32)
+    }
+  end
+
+  @impl true
+  def model(%__MODULE__{architecture: :base} = spec) do
+    inputs(spec)
+    |> core(spec)
+    |> Layers.output()
+  end
+
+  defp inputs(spec) do
+    sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels}
+
+    conditioning_size =
+      spec.sample_size * 2 ** (length(spec.conditioning_embedding_hidden_sizes) - 1)
+
+    conditioning_shape = {nil, conditioning_size, conditioning_size, 3}
+
+    Bumblebee.Utils.Model.inputs_to_map([
+      Axon.input("sample", shape: sample_shape),
+      Axon.input("timestep", shape: {}),
+      Axon.input("conditioning", shape: conditioning_shape),
+      Axon.input("conditioning_scale", optional: true),
+      Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size})
+    ])
+  end
+
+  defp core(inputs, spec) do
+    sample = inputs["sample"]
+    timestep = inputs["timestep"]
+    conditioning = inputs["conditioning"]
+
+    conditioning_scale =
+      Bumblebee.Layers.default inputs["conditioning_scale"] do
+        Axon.constant(1)
+      end
+
+    encoder_hidden_state = inputs["encoder_hidden_state"]
+
+    timestep =
+      Axon.layer(
+        fn sample, timestep, _opts ->
+          Nx.broadcast(timestep, {Nx.axis_size(sample, 0)})
+        end,
+        [sample, timestep],
+        op_name: :broadcast
+      )
+
+    timestep_embedding =
+      timestep
+      |> Diffusion.Layers.timestep_sinusoidal_embedding(hd(spec.hidden_sizes),
+        flip_sin_to_cos: spec.embedding_flip_sin_to_cos,
+        frequency_correction_term: spec.embedding_frequency_correction_term
+      )
+      |> Diffusion.Layers.UNet.timestep_embedding_mlp(hd(spec.hidden_sizes) * 4,
+        name: "time_embedding"
+      )
+
+    sample =
+      Axon.conv(sample, hd(spec.hidden_sizes),
+        kernel_size: 3,
+        padding: [{1, 1}, {1, 1}],
+        name: "input_conv"
+      )
+
+    controlnet_conditioning_embeddings =
+      controlnet_embedding(conditioning, spec, name: "controlnet_conditioning_embedding")
+
+    sample = Axon.add(sample, controlnet_conditioning_embeddings)
+
+    {sample, down_block_states} =
+      down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks")
+
+    sample =
+      mid_block(sample, timestep_embedding, encoder_hidden_state, spec, name: "mid_block")
+
+    down_block_states = controlnet_down_blocks(down_block_states, name: "controlnet_down_blocks")
+
+    down_block_states =
+      for down_block_state <- Tuple.to_list(down_block_states) do
+        Axon.multiply(down_block_state, conditioning_scale)
+      end
+      |> List.to_tuple()
+
+    mid_block_state =
+      sample
+      |> controlnet_mid_block(spec, name: "controlnet_mid_block")
+      |> Axon.multiply(conditioning_scale)
+
+    %{
+      down_block_states: Axon.container(down_block_states),
+      mid_block_state: mid_block_state
+    }
+  end
+
+  defp controlnet_down_blocks(down_block_states, opts) do
+    name = opts[:name]
+
+    states =
+      for {{state, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_states)) do
+        Axon.conv(state, out_channels,
+          kernel_size: 1,
+          name: name |> join(i) |> join("zero_conv"),
+          kernel_initializer: :zeros
+        )
+      end
+
+    List.to_tuple(states)
+  end
+
+  defp controlnet_mid_block(input, spec, opts) do
+    name = opts[:name]
+
+    Axon.conv(input, List.last(spec.hidden_sizes),
+      kernel_size: 1,
+      name: name |> join("zero_conv"),
+      kernel_initializer: :zeros
+    )
+  end
+
+  defp controlnet_embedding(sample, spec, opts) do
+    name = opts[:name]
+
+    state =
+      Axon.conv(sample, hd(spec.conditioning_embedding_hidden_sizes),
+        kernel_size: 3,
+        padding: [{1, 1}, {1, 1}],
+        name: join(name, "input_conv"),
+        activation: :silu
+      )
+
+    size_pairs = Enum.chunk_every(spec.conditioning_embedding_hidden_sizes, 2, 1)
+
+    sample =
+      for {[in_size, out_size], i} <- Enum.with_index(size_pairs), reduce: state do
+        input ->
+          input
+          |> Axon.conv(in_size,
+            kernel_size: 3,
+            padding: [{1, 1}, {1, 1}],
+            name: name |> join("inner_convs") |> join(2 * i),
+            activation: :silu
+          )
+          |> Axon.conv(out_size,
+            kernel_size: 3,
+            padding: [{1, 1}, {1, 1}],
+            strides: 2,
+            name: name |> join("inner_convs") |> join(2 * i + 1),
+            activation: :silu
+          )
+      end
+
+    Axon.conv(sample, hd(spec.hidden_sizes),
+      kernel_size: 3,
+      padding: [{1, 1}, {1, 1}],
+      name: join(name, "output_conv"),
+      kernel_initializer: :zeros
+    )
+  end
+
+  defp down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, opts) do
+    name = opts[:name]
+
+    blocks =
+      Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)])
+
+    in_channels = hd(spec.hidden_sizes)
+    down_block_states = [{sample, in_channels}]
+
+    state = {sample, down_block_states, in_channels}
+
+    {sample, down_block_states, _} =
+      for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks),
+          reduce: state do
+        {sample, down_block_states, in_channels} ->
+          last_block? = idx == length(spec.hidden_sizes) - 1
+
+          {sample, states} =
+            Diffusion.Layers.UNet.down_block_2d(
+              block_type,
+              sample,
+              timestep_embedding,
+              encoder_hidden_state,
+              depth: spec.depth,
+              in_channels: in_channels,
+              out_channels: out_channels,
+              add_downsample: not last_block?,
+              downsample_padding: spec.downsample_padding,
+              activation: spec.activation,
+              norm_epsilon: spec.group_norm_epsilon,
+              norm_num_groups: spec.group_norm_num_groups,
+              num_attention_heads: num_attention_heads,
+              use_linear_projection: spec.use_linear_projection,
+              name: join(name, idx)
+            )
+
+          {sample, down_block_states ++ Tuple.to_list(states), out_channels}
+      end
+
+    {sample, List.to_tuple(down_block_states)}
+  end
+
+  defp mid_block(hidden_state, timesteps_embedding, encoder_hidden_state, spec, opts) do
+    Diffusion.Layers.UNet.mid_cross_attention_block_2d(
+      hidden_state,
+      timesteps_embedding,
+      encoder_hidden_state,
+      channels: List.last(spec.hidden_sizes),
+      activation: spec.activation,
+      norm_epsilon: spec.group_norm_epsilon,
+      norm_num_groups: spec.group_norm_num_groups,
+      output_scale_factor: spec.mid_block_scale_factor,
+      num_attention_heads: spec |> num_attention_heads_per_block() |> List.last(),
+      use_linear_projection: spec.use_linear_projection,
+      name: opts[:name]
+    )
+  end
+
+  defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do
+    spec.num_attention_heads
+  end
+
+  defp num_attention_heads_per_block(spec) when is_integer(spec.num_attention_heads) do
+    num_blocks = length(spec.down_block_types)
+    List.duplicate(spec.num_attention_heads, num_blocks)
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Config do
+    def load(spec, data) do
+      import Shared.Converters
+
+      opts =
+        convert!(data,
+          in_channels: {"in_channels", number()},
+          out_channels: {"out_channels", number()},
+          sample_size: {"sample_size", number()},
+          embedding_flip_sin_to_cos: {"flip_sin_to_cos", boolean()},
+          embedding_frequency_correction_term: {"freq_shift", number()},
+          hidden_sizes: {"block_out_channels", list(number())},
+          depth: {"layers_per_block", number()},
+          down_block_types: {
+            "down_block_types",
+            list(
+              mapping(%{
+                "DownBlock2D" => :down_block,
+                "CrossAttnDownBlock2D" => :cross_attention_down_block
+              })
+            )
+          },
+          up_block_types: {
+            "up_block_types",
+            list(
+              mapping(%{
+                "UpBlock2D" => :up_block,
+                "CrossAttnUpBlock2D" => :cross_attention_up_block
+              })
+            )
+          },
+          downsample_padding: {"downsample_padding", padding(2)},
+          mid_block_scale_factor: {"mid_block_scale_factor", number()},
+          num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])},
+          cross_attention_size: {"cross_attention_dim", number()},
+          use_linear_projection: {"use_linear_projection", boolean()},
+          activation: {"act_fn", activation()},
+          group_norm_num_groups: {"norm_num_groups", number()},
+          group_norm_epsilon: {"norm_eps", number()},
+          conditioning_embedding_hidden_sizes:
+            {"conditioning_embedding_out_channels", list(number())}
+        )
+
+      @for.config(spec, opts)
+    end
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Model do
+    alias Bumblebee.HuggingFace.Transformers
+
+    def params_mapping(_spec) do
+      block_mapping = %{
+        "transformers.{m}.norm" => "attentions.{m}.norm",
+        "transformers.{m}.input_projection" => "attentions.{m}.proj_in",
+        "transformers.{m}.output_projection" => "attentions.{m}.proj_out",
+        "transformers.{m}.blocks.{l}.self_attention.query" =>
+          "attentions.{m}.transformer_blocks.{l}.attn1.to_q",
+        "transformers.{m}.blocks.{l}.self_attention.key" =>
+          "attentions.{m}.transformer_blocks.{l}.attn1.to_k",
+        "transformers.{m}.blocks.{l}.self_attention.value" =>
+          "attentions.{m}.transformer_blocks.{l}.attn1.to_v",
+        "transformers.{m}.blocks.{l}.self_attention.output" =>
+          "attentions.{m}.transformer_blocks.{l}.attn1.to_out.0",
+        "transformers.{m}.blocks.{l}.cross_attention.query" =>
+          "attentions.{m}.transformer_blocks.{l}.attn2.to_q",
+        "transformers.{m}.blocks.{l}.cross_attention.key" =>
+          "attentions.{m}.transformer_blocks.{l}.attn2.to_k",
+        "transformers.{m}.blocks.{l}.cross_attention.value" =>
+          "attentions.{m}.transformer_blocks.{l}.attn2.to_v",
+        "transformers.{m}.blocks.{l}.cross_attention.output" =>
+          "attentions.{m}.transformer_blocks.{l}.attn2.to_out.0",
+        "transformers.{m}.blocks.{l}.ffn.intermediate" =>
+          "attentions.{m}.transformer_blocks.{l}.ff.net.0.proj",
+        "transformers.{m}.blocks.{l}.ffn.output" =>
+          "attentions.{m}.transformer_blocks.{l}.ff.net.2",
+        "transformers.{m}.blocks.{l}.self_attention_norm" =>
+          "attentions.{m}.transformer_blocks.{l}.norm1",
+        "transformers.{m}.blocks.{l}.cross_attention_norm" =>
+          "attentions.{m}.transformer_blocks.{l}.norm2",
+        "transformers.{m}.blocks.{l}.output_norm" =>
+          "attentions.{m}.transformer_blocks.{l}.norm3",
+        "residual_blocks.{m}.timestep_projection" => "resnets.{m}.time_emb_proj",
+        "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1",
+        "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1",
+        "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2",
+        "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2",
+        "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut",
+        "downsamples.{m}.conv" => "downsamplers.{m}.conv"
+      }
+
+      blocks_mapping =
+        ["down_blocks.{n}", "mid_block"]
+        |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1))
+        |> Enum.reduce(&Map.merge/2)
+
+      controlnet = %{
+        "controlnet_conditioning_embedding.input_conv" => "controlnet_cond_embedding.conv_in",
+        "controlnet_conditioning_embedding.inner_convs.{m}" =>
+          "controlnet_cond_embedding.blocks.{m}",
+        "controlnet_conditioning_embedding.output_conv" => "controlnet_cond_embedding.conv_out",
+        "controlnet_down_blocks.{m}.zero_conv" => "controlnet_down_blocks.{m}",
+        "controlnet_mid_block.zero_conv" => "controlnet_mid_block"
+      }
+
+      %{
+        "time_embedding.intermediate" => "time_embedding.linear_1",
+        "time_embedding.output" => "time_embedding.linear_2",
+        "input_conv" => "conv_in"
+      }
+      |> Map.merge(blocks_mapping)
+      |> Map.merge(controlnet)
+    end
+  end
+end
diff --git a/lib/bumblebee/diffusion/layers/unet.ex b/lib/bumblebee/diffusion/layers/unet.ex
index 0cd72fcb..8f0ad953 100644
--- a/lib/bumblebee/diffusion/layers/unet.ex
+++ b/lib/bumblebee/diffusion/layers/unet.ex
@@ -51,22 +51,22 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
         :cross_attention_up_block,
         sample,
         timestep_embedding,
-        residuals,
+        down_block_states,
         encoder_hidden_state,
         opts
       ) do
-    up_block_2d(sample, timestep_embedding, residuals, encoder_hidden_state, opts)
+    up_block_2d(sample, timestep_embedding, down_block_states, encoder_hidden_state, opts)
   end
 
   def up_block_2d(
         :up_block,
         sample,
         timestep_embedding,
-        residuals,
+        down_block_states,
         _encoder_hidden_state,
         opts
       ) do
-    up_block_2d(sample, timestep_embedding, residuals, nil, opts)
+    up_block_2d(sample, timestep_embedding, down_block_states, nil, opts)
   end
 
   @doc """
@@ -147,7 +147,7 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
   def up_block_2d(
         hidden_state,
         timestep_embedding,
-        residuals,
+        down_block_states,
         encoder_hidden_state,
         opts
       ) do
@@ -164,18 +164,18 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
     add_upsample = Keyword.get(opts, :add_upsample, true)
     name = opts[:name]
 
-    ^depth = length(residuals)
+    ^depth = length(down_block_states)
 
     hidden_state =
-      for {{residual, residual_channels}, idx} <- Enum.with_index(residuals),
+      for {{down_block_state, down_block_channels}, idx} <- Enum.with_index(down_block_states),
           reduce: hidden_state do
         hidden_state ->
           in_channels = if(idx == 0, do: in_channels, else: out_channels)
 
           hidden_state =
-            Axon.concatenate([hidden_state, residual], axis: -1)
+            Axon.concatenate([hidden_state, down_block_state], axis: -1)
             |> Diffusion.Layers.residual_block(
-              in_channels + residual_channels,
+              in_channels + down_block_channels,
               out_channels,
               timestep_embedding: timestep_embedding,
               norm_epsilon: norm_epsilon,
diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex
new file mode 100644
index 00000000..fee08e9f
--- /dev/null
+++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex
@@ -0,0 +1,569 @@
+defmodule Bumblebee.Diffusion.StableDiffusionControlNet do
+  @moduledoc """
+  High-level tasks based on Stable Diffusion with ControlNet.
+  """
+
+  import Nx.Defn
+
+  alias Bumblebee.Utils
+  alias Bumblebee.Shared
+
+  @type text_to_image_input ::
+          String.t()
+          | %{
+              :prompt => String.t(),
+              :conditioning => Nx.Tensor.t(),
+              optional(:conditioning_scale) => number(),
+              optional(:negative_prompt) => String.t(),
+              optional(:seed) => integer()
+            }
+  @type text_to_image_output :: %{results: list(text_to_image_result())}
+  @type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()}
+
+  @doc ~S"""
+  Build serving for prompt-driven image generation.
+
+  The serving accepts `t:text_to_image_input/0` and returns `t:text_to_image_output/0`.
+  A list of inputs is also supported.
+
+  You can specify `:safety_checker` model to automatically detect
+  when a generated image is offensive or harmful and filter it out.
+
+  ## Options
+
+    * `:safety_checker` - the safety checker model info map. When a
+      safety checker is used, each output entry has an additional
+      `:is_safe` property and unsafe images are automatically zeroed.
+      Make sure to also set `:safety_checker_featurizer`
+
+    * `:safety_checker_featurizer` - the featurizer to use to preprocess
+      the safety checker input images
+
+    * `:num_steps` - the number of denoising steps. More denoising
+      steps usually lead to higher image quality at the expense of
+      slower inference. Defaults to `50`
+
+    * `:num_images_per_prompt` - the number of images to generate for
+      each prompt. Defaults to `1`
+
+    * `:guidance_scale` - the scale used for classifier-free diffusion
+      guidance. Higher guidance scale makes the generated images more
+      closely reflect the text prompt. This parameter corresponds to
+      $\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
+      Defaults to `7.5`
+
+    * `:compile` - compiles all computations for predefined input shapes
+      during serving initialization. Should be a keyword list with the
+      following keys:
+
+        * `:batch_size` - the maximum batch size of the input. Inputs
+          are optionally padded to always match this batch size
+
+        * `:sequence_length` - the maximum input sequence length. Input
+          sequences are always padded/truncated to match that length
+
+      It is advised to set this option in production and also configure
+      a defn compiler using `:defn_options` to maximally reduce inference
+      time.
+
+    * `:defn_options` - the options for JIT compilation. Defaults to `[]`
+
+    * `:preallocate_params` - when `true`, explicitly allocates params
+      on the device configured by `:defn_options`. You may want to set
+      this option when using partitioned serving, to allocate params
+      on each of the devices. When using this option, you should first
+      load the parameters into the host. This can be done by passing
+      `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends.
+      Defaults to `false`
+
+  ## Examples
+
+      repository_id = "CompVis/stable-diffusion-v1-4"
+
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
+      {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
+      {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
+      {:ok, controlnet} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"})
+      {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
+      {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
+      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
+      {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
+
+      serving =
+        Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image(
+          clip,
+          unet,
+          vae,
+          controlnet,
+          tokenizer,
+          scheduler,
+          num_steps: 20,
+          num_images_per_prompt: 2,
+          safety_checker: safety_checker,
+          safety_checker_featurizer: featurizer,
+          compile: [batch_size: 1, sequence_length: 60],
+          defn_options: [compiler: EXLA]
+        )
+
+      prompt = "numbat in forest, detailed, digital art"
+
+      # The conditioning image matching the given ControlNet condition,
+      # such as edges, pose or depth. Here we use a simple handcrafted
+      # tensor
+      conditioning =
+        Nx.tensor(
+          [List.duplicate(255, 8) ++ List.duplicate(0, 24)],
+          type: :u8
+        )
+        |> Nx.new_axis(-1)
+        |> Nx.tile([256, 8, 3])
+        |> Nx.pad(0, [{192, 64, 0}, {192, 64, 0}, {0, 0, 0}])
+        |> Nx.transpose(axes: [1, 0, 2])
+
+      Nx.Serving.run(serving, %{prompt: prompt, conditioning: conditioning})
+      #=> %{
+      #=>   results: [
+      #=>     %{
+      #=>       image: #Nx.Tensor<
+      #=>         u8[512][512][3]
+      #=>         ...
+      #=>       >,
+      #=>       is_safe: true
+      #=>     },
+      #=>     %{
+      #=>       image: #Nx.Tensor<
+      #=>         u8[512][512][3]
+      #=>         ...
+      #=>       >,
+      #=>       is_safe: true
+      #=>     }
+      #=>   ]
+      #=> }
+
+  """
+  @spec text_to_image(
+          Bumblebee.model_info(),
+          Bumblebee.model_info(),
+          Bumblebee.model_info(),
+          Bumblebee.model_info(),
+          Bumblebee.Tokenizer.t(),
+          Bumblebee.Scheduler.t(),
+          keyword()
+        ) :: Nx.Serving.t()
+  def text_to_image(encoder, unet, vae, controlnet, tokenizer, scheduler, opts \\ []) do
+    opts =
+      Keyword.validate!(opts, [
+        :safety_checker,
+        :safety_checker_featurizer,
+        :compile,
+        num_steps: 50,
+        num_images_per_prompt: 1,
+        guidance_scale: 7.5,
+        defn_options: [],
+        preallocate_params: false
+      ])
+
+    safety_checker = opts[:safety_checker]
+    safety_checker_featurizer = opts[:safety_checker_featurizer]
+    num_steps = opts[:num_steps]
+    num_images_per_prompt = opts[:num_images_per_prompt]
+    preallocate_params = opts[:preallocate_params]
+    defn_options = opts[:defn_options]
+
+    if safety_checker != nil and safety_checker_featurizer == nil do
+      raise ArgumentError, "got :safety_checker but no :safety_checker_featurizer was specified"
+    end
+
+    safety_checker? = safety_checker != nil
+
+    compile =
+      if compile = opts[:compile] do
+        compile
+        |> Keyword.validate!([:batch_size, :sequence_length])
+        |> Shared.require_options!([:batch_size, :sequence_length])
+      end
+
+    batch_size = compile[:batch_size]
+    sequence_length = compile[:sequence_length]
+
+    conditioning_size =
+      controlnet.spec.sample_size *
+        2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1)
+
+    tokenizer =
+      Bumblebee.configure(tokenizer,
+        length: sequence_length,
+        return_token_type_ids: false,
+        return_attention_mask: false
+      )
+
+    {_, encoder_predict} = Axon.build(encoder.model)
+    {_, vae_predict} = Axon.build(vae.model)
+    {_, unet_predict} = Axon.build(unet.model)
+    {_, controlnet_predict} = Axon.build(controlnet.model)
+
+    scheduler_init = &Bumblebee.scheduler_init(scheduler, num_steps, &1, &2)
+    scheduler_step = &Bumblebee.scheduler_step(scheduler, &1, &2, &3)
+
+    image_fun =
+      &text_to_image_impl(
+        encoder_predict,
+        &1,
+        unet_predict,
+        &2,
+        vae_predict,
+        &3,
+        controlnet_predict,
+        &4,
+        scheduler_init,
+        scheduler_step,
+        &5,
+        num_images_per_prompt: opts[:num_images_per_prompt],
+        latents_sample_size: unet.spec.sample_size,
+        latents_channels: unet.spec.in_channels,
+        guidance_scale: opts[:guidance_scale]
+      )
+
+    safety_checker_fun =
+      if safety_checker do
+        {_, predict_fun} = Axon.build(safety_checker.model)
+        predict_fun
+      end
+
+    # Note that all of these are copied when using serving as a process
+    init_args = [
+      {image_fun, safety_checker_fun},
+      encoder.params,
+      unet.params,
+      vae.params,
+      controlnet.params,
+      {safety_checker?, safety_checker[:spec], safety_checker[:params]},
+      safety_checker_featurizer,
+      {compile != nil, batch_size, sequence_length, conditioning_size},
+      num_images_per_prompt,
+      preallocate_params
+    ]
+
+    Nx.Serving.new(
+      fn defn_options -> apply(&init/11, init_args ++ [defn_options]) end,
+      defn_options
+    )
+    |> Nx.Serving.batch_size(batch_size)
+    |> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer))
+    |> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker?))
+  end
+
+  defp init(
+         {image_fun, safety_checker_fun},
+         encoder_params,
+         unet_params,
+         vae_params,
+         controlnet_params,
+         {safety_checker?, safety_checker_spec, safety_checker_params},
+         safety_checker_featurizer,
+         {compile?, batch_size, sequence_length, conditioning_size},
+         num_images_per_prompt,
+         preallocate_params,
+         defn_options
+       ) do
+    encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options)
+    unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options)
+    vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options)
+
+    controlnet_params =
+      Shared.maybe_preallocate(controlnet_params, preallocate_params, defn_options)
+
+    image_fun =
+      Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
+        inputs = %{
+          "conditional_and_unconditional" => %{
+            "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
+          },
+          "seed" => Nx.template({batch_size}, :s64),
+          "conditioning" =>
+            Nx.template(
+              {batch_size, conditioning_size, conditioning_size, 3},
+              :f32
+            ),
+          "conditioning_scale" => Nx.template({batch_size}, :f32)
+        }
+
+        [encoder_params, unet_params, vae_params, controlnet_params, inputs]
+      end)
+
+    safety_checker_fun =
+      safety_checker_fun &&
+        Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn ->
+          inputs = %{
+            "pixel_values" =>
+              Shared.input_template(safety_checker_spec, "pixel_values", [
+                batch_size * num_images_per_prompt
+              ])
+          }
+
+          [safety_checker_params, inputs]
+        end)
+
+    safety_checker_params =
+      safety_checker_params &&
+        Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options)
+
+    fn inputs ->
+      inputs = Shared.maybe_pad(inputs, batch_size)
+
+      image = image_fun.(encoder_params, unet_params, vae_params, controlnet_params, inputs)
+
+      output =
+        if safety_checker? do
+          inputs = Bumblebee.apply_featurizer(safety_checker_featurizer, image)
+          outputs = safety_checker_fun.(safety_checker_params, inputs)
+          %{image: image, is_unsafe: outputs.is_unsafe}
+        else
+          %{image: image}
+        end
+
+      output
+      |> Utils.Nx.composite_unflatten_batch(Utils.Nx.batch_size(inputs))
+      |> Shared.serving_post_computation()
+    end
+  end
+
+  defp preprocess_image(image) do
+    NxImage.to_continuous(image, 0, 1)
+  end
+
+  defp client_preprocessing(input, tokenizer) do
+    {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)
+
+    seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)
+
+    # Note: we need to tokenize all sequences together, so that
+    # they are padded to the same length (if not specified)
+    prompts = Enum.flat_map(inputs, &[&1.prompt, &1.negative_prompt])
+
+    prompt_pairs =
+      Nx.with_default_backend(Nx.BinaryBackend, fn ->
+        inputs = Bumblebee.apply_tokenizer(tokenizer, prompts)
+        Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0))
+      end)
+
+    conditioning =
+      Enum.map(inputs, & &1.conditioning)
+      |> Nx.stack()
+      |> preprocess_image()
+
+    conditioning_scale =
+      Enum.map(inputs, & &1.conditioning_scale)
+      |> Nx.tensor(type: :f32, backend: Nx.BinaryBackend)
+
+    inputs = %{
+      "conditional_and_unconditional" => prompt_pairs,
+      "seed" => seed,
+      "conditioning" => conditioning,
+      "conditioning_scale" => conditioning_scale
+    }
+
+    {Nx.Batch.concatenate([inputs]), multi?}
+  end
+
+  defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
+    for outputs <- Utils.Nx.batch_to_list(outputs) do
+      results =
+        for outputs = %{image: image} <- Utils.Nx.batch_to_list(outputs) do
+          if safety_checker? do
+            if Nx.to_number(outputs.is_unsafe) == 1 do
+              %{image: zeroed(image), is_safe: false}
+            else
+              %{image: image, is_safe: true}
+            end
+          else
+            %{image: image}
+          end
+        end
+
+      %{results: results}
+    end
+    |> Shared.normalize_output(multi?)
+  end
+
+  defp zeroed(tensor) do
+    0
+    |> Nx.tensor(type: Nx.type(tensor), backend: Nx.BinaryBackend)
+    |> Nx.broadcast(Nx.shape(tensor))
+  end
+
+  defnp text_to_image_impl(
+          encoder_predict,
+          encoder_params,
+          unet_predict,
+          unet_params,
+          vae_predict,
+          vae_params,
+          controlnet_predict,
+          controlnet_params,
+          scheduler_init,
+          scheduler_step,
+          inputs,
+          opts \\ []
+        ) do
+    num_images_per_prompt = opts[:num_images_per_prompt]
+    latents_sample_size = opts[:latents_sample_size]
+    latents_in_channels = opts[:latents_channels]
+    guidance_scale = opts[:guidance_scale]
+
+    seed = inputs["seed"]
+
+    # The UNet and ControlNet samples have leading dimension
+    # 2 * batch_size * num_images_per_prompt (conditional block
+    # followed by unconditional block), so we repeat the per-prompt
+    # conditioning to match it
+    conditioning = repeat_per_sample(inputs["conditioning"], num_images_per_prompt)
+
+    # The scale needs trailing singleton axes to broadcast over the
+    # spatial and channel axes of the ControlNet outputs
+    conditioning_scale =
+      inputs["conditioning_scale"]
+      |> Nx.reshape({:auto, 1, 1, 1})
+      |> repeat_per_sample(num_images_per_prompt)
+
+    inputs =
+      inputs["conditional_and_unconditional"]
+      # Transpose conditional and unconditional to separate blocks
+      |> composite_transpose_leading()
+      |> Utils.Nx.composite_flatten_batch()
+
+    %{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs)
+
+    {_twice_batch_size, sequence_length, hidden_size} = Nx.shape(text_embeddings)
+
+    text_embeddings =
+      text_embeddings
+      |> Nx.new_axis(1)
+      |> Nx.tile([1, num_images_per_prompt, 1, 1])
+      |> Nx.reshape({:auto, sequence_length, hidden_size})
+
+    prng_key =
+      seed
+      |> Nx.vectorize(:batch)
+      |> Nx.Random.key()
+      |> Nx.Random.split(parts: num_images_per_prompt)
+      |> Nx.devectorize()
+      |> Nx.flatten(axes: [0, 1])
+      |> Nx.vectorize(:batch)
+
+    {latents, prng_key} =
+      Nx.Random.normal(prng_key,
+        shape: {latents_sample_size, latents_sample_size, latents_in_channels}
+      )
+
+    {scheduler_state, timesteps} = scheduler_init.(Nx.to_template(latents), prng_key)
+
+    latents = Nx.devectorize(latents)
+
+    {latents, _} =
+      while {latents,
+             {scheduler_state, text_embeddings, unet_params, conditioning, conditioning_scale,
+              controlnet_params}},
+            timestep <- timesteps do
+        sample = Nx.concatenate([latents, latents])
+
+        controlnet_inputs = %{
+          "conditioning" => conditioning,
+          "conditioning_scale" => conditioning_scale,
+          "sample" => sample,
+          "timestep" => timestep,
+          "encoder_hidden_state" => text_embeddings
+        }
+
+        %{down_block_states: down_block_states, mid_block_state: mid_block_state} =
+          controlnet_predict.(controlnet_params, controlnet_inputs)
+
+        unet_inputs =
+          %{
+            "sample" => sample,
+            "timestep" => timestep,
+            "encoder_hidden_state" => text_embeddings,
+            "additional_down_block_states" => down_block_states,
+            "additional_mid_block_state" => mid_block_state
+          }
+
+        %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs)
+
+        {noise_pred_conditional, noise_pred_unconditional} =
+          split_conditional_and_unconditional(noise_pred)
+
+        noise_pred =
+          noise_pred_unconditional +
+            guidance_scale * (noise_pred_conditional - noise_pred_unconditional)
+
+        {scheduler_state, latents} =
+          scheduler_step.(
+            scheduler_state,
+            Nx.vectorize(latents, :batch),
+            Nx.vectorize(noise_pred, :batch)
+          )
+
+        latents = Nx.devectorize(latents)
+
+        {latents,
+         {scheduler_state, text_embeddings, unet_params, conditioning, conditioning_scale,
+          controlnet_params}}
+      end
+
+    latents = latents * (1 / 0.18215)
+
+    %{sample: image} = vae_predict.(vae_params, latents)
+
+    NxImage.from_continuous(image, -1, 1)
+  end
+
+  deftransformp composite_transpose_leading(container) do
+    Utils.Nx.map(container, fn tensor ->
+      [first, second | rest] = Nx.axes(tensor)
+      Nx.transpose(tensor, axes: [second, first | rest])
+    end)
+  end
+
+  # Repeats every batch entry once per generated image and stacks the
+  # result twice, for the conditional and unconditional halves
+  deftransformp repeat_per_sample(tensor, num_images_per_prompt) do
+    [_batch_size | inner_shape] = tensor |> Nx.shape() |> Tuple.to_list()
+    inner_repetitions = List.duplicate(1, length(inner_shape))
+
+    repeated =
+      tensor
+      |> Nx.new_axis(1)
+      |> Nx.tile([1, num_images_per_prompt | inner_repetitions])
+      |> Nx.reshape(List.to_tuple([:auto | inner_shape]))
+
+    Nx.concatenate([repeated, repeated])
+  end
+
+  defnp split_conditional_and_unconditional(tensor) do
+    batch_size = Nx.axis_size(tensor, 0)
+    half_size = div(batch_size, 2)
+    {tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]}
+  end
+
+  defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt})
+
+  defp validate_input(%{prompt: prompt, conditioning: conditioning} = input) do
+    {:ok,
+     %{
+       prompt: prompt,
+       conditioning: conditioning,
+       conditioning_scale: input[:conditioning_scale] || 1.0,
+       negative_prompt: input[:negative_prompt] || "",
+       seed: input[:seed] || :erlang.system_time()
+     }}
+  end
+
+
defp validate_input(%{} = input) do + {:error, + "expected the input map to have :prompt and :conditioning key, got: #{inspect(input)}"} + end + + defp validate_input(input) do + {:error, "expected either a string or a map, got: #{inspect(input)}"} + end +end diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 03d39169..8045eca9 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -96,7 +96,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ] @moduledoc """ - U-Net model with two spatial dimensions and conditional state. + U-Net model with two spatial dimensions and conditioning state. ## Architectures @@ -115,7 +115,16 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}` - The conditional state (context) to use with cross-attention. + The conditioning state (context) to use with cross-attention. + + * `"additional_down_block_states"` + + Optional outputs matching the structure of down blocks, added as + part of the encoder-decoder skip connections. + + * `"additional_mid_block_state"` + + Optional output added to the mid block result. 
## Configuration @@ -166,7 +175,9 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), - Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) + Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), + Axon.input("additional_down_block_states", optional: true), + Axon.input("additional_mid_block_state", optional: true) ]) end @@ -208,12 +219,19 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do name: "input_conv" ) - {sample, down_block_residuals} = + {sample, down_block_states} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") + down_block_states = + maybe_add_down_block_states(down_block_states, inputs["additional_down_block_states"]) + + sample = + sample + |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") + |> maybe_add_mid_block_state(inputs["additional_mid_block_state"]) + sample - |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, + |> up_blocks(timestep_embedding, down_block_states, encoder_hidden_state, spec, name: "up_blocks" ) |> Axon.group_norm(spec.group_norm_num_groups, @@ -235,17 +253,17 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)]) in_channels = hd(spec.hidden_sizes) - down_block_residuals = [{sample, in_channels}] + down_block_states = [{sample, in_channels}] - state = {sample, down_block_residuals, in_channels} + state = {sample, down_block_states, in_channels} - {sample, down_block_residuals, _} = + {sample, down_block_states, _} = for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks), reduce: state do - {sample, down_block_residuals, in_channels} -> + {sample, 
down_block_states, in_channels} -> last_block? = idx == length(spec.hidden_sizes) - 1 - {sample, residuals} = + {sample, states} = Diffusion.Layers.UNet.down_block_2d( block_type, sample, @@ -264,10 +282,10 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do name: join(name, idx) ) - {sample, down_block_residuals ++ Tuple.to_list(residuals), out_channels} + {sample, down_block_states ++ Tuple.to_list(states), out_channels} end - {sample, List.to_tuple(down_block_residuals)} + {sample, List.to_tuple(down_block_states)} end defp mid_block(hidden_state, timesteps_embedding, encoder_hidden_state, spec, opts) do @@ -289,15 +307,15 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do defp up_blocks( sample, timestep_embedding, - down_block_residuals, + down_block_states, encoder_hidden_state, spec, opts ) do name = opts[:name] - down_block_residuals = - down_block_residuals + down_block_states = + down_block_states |> Tuple.to_list() |> Enum.reverse() |> Enum.chunk_every(spec.depth + 1) @@ -315,13 +333,13 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do reversed_hidden_sizes, spec.up_block_types, num_attention_heads_per_block, - down_block_residuals + down_block_states ] |> Enum.zip() |> Enum.with_index() {sample, _} = - for {{out_channels, block_type, num_attention_heads, residuals}, idx} <- blocks_and_chunks, + for {{out_channels, block_type, num_attention_heads, states}, idx} <- blocks_and_chunks, reduce: {sample, in_channels} do {sample, in_channels} -> last_block? 
= idx == length(spec.hidden_sizes) - 1 @@ -331,7 +349,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do block_type, sample, timestep_embedding, - residuals, + states, encoder_hidden_state, depth: spec.depth + 1, in_channels: in_channels, @@ -351,6 +369,32 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do sample end + defp maybe_add_mid_block_state(mid_block_state, additional_mid_block_state) do + maybe_add(mid_block_state, additional_mid_block_state) + end + + defp maybe_add_down_block_states(down_block_states, additional_down_block_states) do + down_block_states = Tuple.to_list(down_block_states) + + for {{down_block_state, out_channels}, i} <- Enum.with_index(down_block_states) do + additional_down_block_state = Axon.nx(additional_down_block_states, &elem(&1, i)) + {maybe_add(down_block_state, additional_down_block_state), out_channels} + end + |> List.to_tuple() + end + + defp maybe_add(left, maybe_right) do + Axon.layer( + fn left, maybe_right, _opts -> + case maybe_right do + %Axon.None{} -> left + right -> Nx.add(left, right) + end + end, + [left, Axon.optional(maybe_right)] + ) + end + defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do spec.num_attention_heads end diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index b3a0d6fd..a2e4151c 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -492,7 +492,7 @@ defmodule Bumblebee.Shared do def featurizer_resize_size(_images, %{height: height, width: width}), do: {height, width} def featurizer_resize_size(images, %{shortest_edge: size}) do - {height, width} = images_spacial_sizes(images) + {height, width} = images_spatial_sizes(images) {short, long} = if height < width, do: {height, width}, else: {width, height} @@ -502,7 +502,7 @@ defmodule Bumblebee.Shared do if height < width, do: {out_short, out_long}, else: {out_long, out_short} end - defp images_spacial_sizes(images) do + defp images_spatial_sizes(images) do height = 
Nx.axis_size(images, -3) width = Nx.axis_size(images, -2) {height, width} diff --git a/mix.exs b/mix.exs index b36c9748..b0677f3e 100644 --- a/mix.exs +++ b/mix.exs @@ -71,10 +71,12 @@ defmodule Bumblebee.MixProject do Bumblebee.Audio, Bumblebee.Text, Bumblebee.Vision, - Bumblebee.Diffusion.StableDiffusion + Bumblebee.Diffusion.StableDiffusion, + Bumblebee.Diffusion.StableDiffusionControlNet ], Models: [ Bumblebee.Audio.Whisper, + Bumblebee.Diffusion.ControlNet, Bumblebee.Diffusion.StableDiffusion.SafetyChecker, Bumblebee.Diffusion.UNet2DConditional, Bumblebee.Diffusion.VaeKl, diff --git a/test/bumblebee/diffusion/controlnet_test.exs b/test/bumblebee/diffusion/controlnet_test.exs new file mode 100644 index 00000000..ef1512d7 --- /dev/null +++ b/test/bumblebee/diffusion/controlnet_test.exs @@ -0,0 +1,66 @@ +defmodule Bumblebee.Diffusion.ControlNetTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) + + assert %Bumblebee.Diffusion.ControlNet{architecture: :base} = spec + + inputs = %{ + "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), + "timestep" => Nx.tensor(1), + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), + "conditioning" => Nx.broadcast(0.5, {1, 64, 64, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.mid_block_state) == {1, 16, 16, 64} + + assert_all_close( + outputs.mid_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-0.2818, 1.6207, -0.7002], [0.2391, 1.1387, 0.9682], [-0.6386, 0.7026, -0.4218]], + [[1.0681, 1.8418, -1.0586], [0.9387, 0.5971, 1.2284], [1.2914, 0.4060, -0.9559]], + [[0.5841, 1.2935, 0.0081], [0.7306, 0.2915, 0.7736], [0.0875, 0.9619, 0.4108]] + ] + ]) + ) + + assert tuple_size(outputs.down_block_states) == 6 + + first_down_block_state = elem(outputs.down_block_states, 0) + assert 
Nx.shape(first_down_block_state) == {1, 32, 32, 32} + + assert_all_close( + first_down_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-0.1423, 0.2804, -0.0497], [-0.1425, 0.2798, -0.0485], [-0.1426, 0.2794, -0.0488]], + [[-0.1419, 0.2810, -0.0493], [-0.1427, 0.2803, -0.0479], [-0.1427, 0.2800, -0.0486]], + [[-0.1417, 0.2812, -0.0494], [-0.1427, 0.2807, -0.0480], [-0.1426, 0.2804, -0.0486]] + ] + ]) + ) + + last_down_block_state = elem(outputs.down_block_states, 5) + assert Nx.shape(last_down_block_state) == {1, 16, 16, 64} + + assert_all_close( + last_down_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-1.1169, 0.8087, 0.1024], [0.4832, 0.0686, 1.0149], [-0.3314, 0.1486, 0.4445]], + [[0.5770, 0.3195, -0.2008], [1.5692, -0.1771, 0.7669], [0.4908, 0.1258, 0.0694]], + [[0.4694, -0.3723, 0.1505], [1.7356, -0.4214, 0.8929], [0.4702, 0.2400, 0.1213]] + ] + ]) + ) + end +end diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs new file mode 100644 index 00000000..5db38905 --- /dev/null +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -0,0 +1,112 @@ +defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + @moduletag serving_test_tags() + + describe "text_to_image/6" do + test "generates image for a text prompt with controlnet" do + # Since we don't assert on the result in this case, we use + # a tiny random checkpoint. 
This test is basically to verify + # the whole generation computation end-to-end + + repository_id = "bumblebee-testing/tiny-stable-diffusion" + + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) + {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) + + {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) + + {:ok, controlnet} = Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) + + {:ok, vae} = + Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) + + {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) + + {:ok, featurizer} = + Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) + + {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) + + conditioning_size = + controlnet.spec.sample_size * + 2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1) + + serving = + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3, + safety_checker: safety_checker, + safety_checker_featurizer: featurizer + ) + + prompt = "numbat in forest, detailed, digital art" + + conditioning = + Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3}) + + assert %{ + results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] + } = + Nx.Serving.run(serving, %{ + prompt: prompt, + conditioning: conditioning + }) + + # Without safety checker + + serving = + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3 + ) + + prompt = "numbat in forest, detailed, digital art" + + assert %{results: [%{image: %Nx.Tensor{}}]} = + Nx.Serving.run(serving, %{ + prompt: prompt, + conditioning: conditioning + }) + + # With compilation + + serving = + 
Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3, + safety_checker: safety_checker, + safety_checker_featurizer: featurizer, + compile: [batch_size: 1, sequence_length: 60], + defn_options: [compiler: EXLA] + ) + + prompt = "numbat in forest, detailed, digital art" + + assert %{ + results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] + } = + Nx.Serving.run(serving, %{ + prompt: prompt, + conditioning: conditioning + }) + end + end +end diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index c2266c7f..ce6fc5be 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -34,4 +34,51 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do ]) ) end + + test ":base with additional states for skip connection" do + tiny = "bumblebee-testing/tiny-stable-diffusion" + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, tiny, subdir: "unet"}) + + assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec + + down_block_states = + [ + {1, 32, 32, 32}, + {1, 32, 32, 32}, + {1, 32, 32, 32}, + {1, 16, 16, 32}, + {1, 16, 16, 64}, + {1, 16, 16, 64} + ] + |> Enum.map(&Nx.broadcast(0.5, &1)) + |> List.to_tuple() + + mid_block_state = Nx.broadcast(0.5, {1, 16, 16, 64}) + + inputs = + %{ + "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), + "timestep" => Nx.tensor(1), + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), + "additional_down_block_states" => down_block_states, + "additional_mid_block_state" => mid_block_state + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.sample) == {1, 32, 32, 4} + + assert_all_close( + outputs.sample[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-0.9457, -0.2378, 1.4223], [-0.5736, -0.2456, 0.7603], [-0.4346, -1.1370, -0.1988]], + 
[[-0.5274, -1.0902, 0.5937], [-1.2290, -0.7996, 0.0264], [-0.3006, -0.1181, 0.7059]], + [[-0.8336, -1.1615, -0.1906], [-1.0489, -0.3815, -0.5497], [-0.6255, 0.0863, 0.3285]] + ] + ]) + ) + end end