Skip to content

Commit

Permalink
Add Stable Diffusion ControlNet (#359)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
joelpaulkoch and jonatanklosko authored Apr 8, 2024
1 parent 2fbe380 commit be8e710
Show file tree
Hide file tree
Showing 11 changed files with 1,362 additions and 33 deletions.
1 change: 1 addition & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ defmodule Bumblebee do
"CLIPModel" => {Bumblebee.Multimodal.Clip, :base},
"CLIPTextModel" => {Bumblebee.Text.ClipText, :base},
"CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base},
"ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base},
"ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification},
"ConvNextModel" => {Bumblebee.Vision.ConvNext, :base},
"DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification},
Expand Down
14 changes: 12 additions & 2 deletions lib/bumblebee/conversion/pytorch_params.ex
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,18 @@ defmodule Bumblebee.Conversion.PyTorchParams do
defp match_template(name, template), do: match_template(name, template, %{})

defp match_template(<<_, _::binary>> = name, <<"{", template::binary>>, substitutes) do
[value, name] = String.split(name, ".", parts: 2)
[key, template] = String.split(template, "}.", parts: 2)
{value, name} =
case String.split(name, ".", parts: 2) do
[value] -> {value, ""}
[value, name] -> {value, name}
end

{key, template} =
case String.split(template, "}", parts: 2) do
[key, ""] -> {key, ""}
[key, "." <> template] -> {key, template}
end

match_template(name, template, put_in(substitutes[key], value))
end

Expand Down
Loading

0 comments on commit be8e710

Please sign in to comment.