Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Style transfer #16

Merged
merged 15 commits into from
Aug 21, 2024
12 changes: 6 additions & 6 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ defmodule ExVision.Model.Definition do
Application.ensure_all_started(:req)

options =
Keyword.validate!(options, [
:categories,
Keyword.validate!(options,
categories: nil,
name: module_to_name(__CALLER__.module)
])
)

quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end
# conditional definition based on whether the `categories` option is present has to be moved inside the __using__ macro
# here is an explanation: https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)

@behaviour ExVision.Model.Definition

Expand Down
29 changes: 16 additions & 13 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec
unless is_nil(options |> Keyword.fetch!(:categories)) do
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)
quote do
require Bunch.Typespec

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
75 changes: 75 additions & 0 deletions lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
defmodule Configuration do
  @moduledoc false

  # Resolution assigned to every `*Fast` variant below.
  @low_resolution {400, 300}
  # Resolution assigned to every non-`Fast` variant below.
  @high_resolution {640, 480}

  @doc """
  Returns the compile-time configuration of the style transfer models.

  The result maps each model module name to a keyword list with:

    * `:model` - the file name of the ONNX model,
    * `:resolution` - the two-element resolution tuple assigned to that model.
  """
  # The previous spec was `%{}`, which in typespecs denotes the *empty* map and
  # therefore did not describe the returned value.
  @spec configuration() :: %{module() => keyword()}
  def configuration do
    %{
      ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution],
      ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.PrincessFast => [
        model: "princess_fast.onnx",
        resolution: @low_resolution
      ],
      ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution],
      ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.MosaicFast => [
        model: "mosaic_fast.onnx",
        resolution: @low_resolution
      ]
    }
  end
end

# Generates one model module per entry of `Configuration.configuration/0`.
# `module` is the module name to define, `opts` carries its `:model` file name
# and `:resolution` tuple.
for {module, opts} <- Configuration.configuration() do
  defmodule module do
    @moduledoc """
    #{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference.
    """
    use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

    require Logger

    @typedoc """
    A type consisting of output tensor (stylized image tensor) from style transfer models of shape {#{Enum.join(Tuple.to_list(opts[:resolution]) ++ [3], ", ")}}.
    """
    @type output_t() :: Nx.Tensor.t()

    # Loads the model with the batch size forced to 1. A caller-supplied
    # `:batch_size` is overridden, and a warning is logged when that happens.
    @impl true
    def load(options \\ []) do
      if Keyword.has_key?(options, :batch_size) do
        # The message names `:batch_size`, the option actually checked and
        # overridden here (it previously referred to `:max_batch_size`).
        Logger.warning(
          "`:batch_size` was given, but this model can only process batch of size 1. Overriding"
        )
      end

      options
      |> Keyword.put(:batch_size, 1)
      |> default_model_load()
    end

    # Resizes the image to the resolution configured for this model and divides
    # every value by 255.0. The metadata argument is not used.
    @impl true
    def preprocessing(img, _metadata) do
      img |> ExVision.Utils.resize(unquote(opts[:resolution])) |> Nx.divide(255.0)
    end

    # Converts the raw model output back into an image tensor:
    # reshape to {3, h, w}, resize to `metadata.original_size`, clip to 0..255,
    # cast to u8 and move the channel axis last.
    @impl true
    def postprocessing(
          stylized_frame,
          metadata
        ) do
      {h, w} = unquote(opts[:resolution])

      # NOTE(review): "55" is presumably the name of the ONNX graph output -
      # confirm against the exported models.
      stylized_frame["55"]
      |> Nx.reshape({3, h, w}, names: [:channel, :height, :width])
      |> NxImage.resize(metadata.original_size, channels: :first, method: :bilinear)
      |> Nx.clip(0.0, 255.0)
      |> Nx.as_type(:u8)
      |> Nx.transpose(axes: [1, 2, 0])
    end
  end
end
9 changes: 9 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ defmodule ExVision.Mixfile do
ExVision.Classification.EfficientNet_V2_L,
ExVision.Classification.SqueezeNet1_1,
ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
ExVision.StyleTransfer.Candy,
ExVision.StyleTransfer.CandyFast,
ExVision.StyleTransfer.Udnie,
ExVision.StyleTransfer.UdnieFast,
ExVision.StyleTransfer.Mosaic,
ExVision.StyleTransfer.MosaicFast,
ExVision.StyleTransfer.Princess,
ExVision.StyleTransfer.PrincessFast,
ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN,
Expand All @@ -123,6 +131,7 @@ defmodule ExVision.Mixfile do
ExVision.Types,
ExVision.Classification,
ExVision.SemanticSegmentation,
ExVision.StyleTransfer,
ExVision.InstanceSegmentation,
ExVision.ObjectDetection,
ExVision.KeypointDetection
Expand Down
6 changes: 3 additions & 3 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
"cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"},
"coerce": {:hex, :coerce, "1.0.1", "211c27386315dc2894ac11bc1f413a0e38505d808153367bd5c6e75a4003d096", [:mix], [], "hexpm", "b44a691700f7a1a15b4b7e2ff1fa30bebd669929ac8aa43cffe9e2f8bf051cf1"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"credo": {:hex, :credo, "1.7.5", "643213503b1c766ec0496d828c90c424471ea54da77c8a168c725686377b9545", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "f799e9b5cd1891577d8c773d245668aa74a2fcd15eb277f51a0131690ebfb3fd"},
"credo": {:hex, :credo, "1.7.7", "771445037228f763f9b2afd612b6aa2fd8e28432a95dbbc60d8e03ce71ba4446", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "8bc87496c9aaacdc3f90f01b7b0582467b69b4bd2441fe8aae3109d843cc2f2e"},
"dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"},
"evision": {:hex, :evision, "0.1.38", "f8b23ad685c3ebd70969a3457027b5c74b5bc8dc51588661c516098c3240b92d", [:make, :mix, :rebar3], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0 or ~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: true]}], "hexpm", "f9302547d76c5e4ad7022ffdc76be13e33c990fdd67ad2af203f24ab5d3aee20"},
"ex_doc": {:hex, :ex_doc, "0.32.1", "21e40f939515373bcdc9cffe65f3b3543f05015ac6c3d01d991874129d173420", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5142c9db521f106d61ff33250f779807ed2a88620e472ac95dc7d59c380113da"},
"exla": {:hex, :exla, "0.7.2", "8ac573093df8e5e6b36845beeb3f5a0ea92b05082bf2fa4678f80170cfc887f6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d061ea87858415e5585cbd4b7bdae5489000339519a2c6a7f51eb0defd73b588"},
"file_system": {:hex, :file_system, "1.0.0", "b689cc7dcee665f774de94b5a832e578bd7963c8e637ef940cd44327db7de2cd", [:mix], [], "hexpm", "6752092d66aec5a10e662aefeed8ddb9531d79db0bc145bb8c40325ca1d8536d"},
"file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"},
"finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"},
"hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"},
"image": {:hex, :image, "0.44.0", "e8eea9398abbed12b7784e786f26a5c839a00bcddd8f2f8ba12adf7e227beb9f", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.23", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "cd00a3de4d7a40a2cb1ca72b9852b0d81701793414af8babf4d33dbeb6de0f6f"},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"},
Expand Down
Binary file added test/assets/results/style_transfer/cat_candy.gt
Binary file not shown.
Binary file not shown.
Binary file added test/assets/results/style_transfer/cat_mosaic.gt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/assets/results/style_transfer/cat_udnie.gt
Binary file not shown.
Binary file not shown.
56 changes: 56 additions & 0 deletions test/ex_vision/style_transfer/style_transfer_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule TestConfiguration do
  @moduledoc false

  @doc """
  Returns the configuration of the generated style transfer test modules.

  The result maps each test module name to a keyword list with:

    * `:module` - the model module under test,
    * `:gt_file` - the name of the file holding the expected (ground truth) result.
  """
  # The previous spec was `%{}`, which in typespecs denotes the *empty* map and
  # therefore did not describe the returned value.
  @spec configuration() :: %{module() => keyword()}
  def configuration do
    %{
      ExVision.StyleTransfer.CandyTest => [
        module: ExVision.StyleTransfer.Candy,
        gt_file: "cat_candy.gt"
      ],
      ExVision.StyleTransfer.CandyFastTest => [
        module: ExVision.StyleTransfer.CandyFast,
        gt_file: "cat_candy_fast.gt"
      ],
      ExVision.StyleTransfer.PrincessTest => [
        module: ExVision.StyleTransfer.Princess,
        gt_file: "cat_princess.gt"
      ],
      ExVision.StyleTransfer.PrincessFastTest => [
        module: ExVision.StyleTransfer.PrincessFast,
        gt_file: "cat_princess_fast.gt"
      ],
      ExVision.StyleTransfer.UdnieTest => [
        module: ExVision.StyleTransfer.Udnie,
        gt_file: "cat_udnie.gt"
      ],
      ExVision.StyleTransfer.UdnieFastTest => [
        module: ExVision.StyleTransfer.UdnieFast,
        gt_file: "cat_udnie_fast.gt"
      ],
      ExVision.StyleTransfer.MosaicTest => [
        module: ExVision.StyleTransfer.Mosaic,
        gt_file: "cat_mosaic.gt"
      ],
      ExVision.StyleTransfer.MosaicFastTest => [
        module: ExVision.StyleTransfer.MosaicFast,
        gt_file: "cat_mosaic_fast.gt"
      ]
    }
  end
end

# Defines one test module per entry of `TestConfiguration.configuration/0`.
for {test_module, test_opts} <- TestConfiguration.configuration() do
  defmodule test_module do
    use ExVision.Model.Case, module: unquote(test_opts[:module])
    use ExVision.TestUtils

    # Compares the inference result with the tensor deserialized from this
    # model's ground truth file.
    @impl true
    def test_inference_result(result) do
      serialized =
        File.read!("test/assets/results/style_transfer/#{unquote(test_opts[:gt_file])}")

      ground_truth = Nx.deserialize(serialized)

      assert_tensors_equal(result, ground_truth, 5, 0.05)
    end
  end
end
22 changes: 22 additions & 0 deletions test/support/exvision/test_utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ defmodule ExVision.TestUtils do
end
end

@doc """
Asserts that tensors `a` and `b` are equal within a tolerance.

The assertion passes when at least one of the following holds:

  * every element of `a` is close to the matching element of `b`, as decided
    by `Nx.all_close/3` with `atol: delta` and `rtol: relative_delta`,
  * more than 99% of the elements are exactly equal.

Each of `a` and `b` is evaluated exactly once.
"""
defmacro assert_tensors_equal(a, b, delta \\ @default_delta, relative_delta \\ 0.0) do
  quote do
    # Bind the arguments once: unquoting them at every use site would
    # re-evaluate the caller's expressions several times.
    actual = unquote(a)
    expected = unquote(b)

    # `Nx.all_close/3` already returns a scalar tensor holding 1 or 0.
    value_condition =
      actual
      |> Nx.all_close(expected, atol: unquote(delta), rtol: unquote(relative_delta))
      |> Nx.to_number() == 1

    # Number of positions at which both tensors hold exactly the same value.
    equal_on_count =
      actual
      |> Nx.equal(expected)
      |> Nx.as_type(:u64)
      |> Nx.sum()
      |> Nx.to_number()

    number_count = Nx.size(actual)
    proportional_condition = equal_on_count / number_count > 0.99

    assert value_condition or proportional_condition,
           "tensors differ: not all elements are within tolerance and only " <>
             "#{equal_on_count} of #{number_count} elements are exactly equal"
  end
end

defmacro __using__(_opts) do
quote do
import ExVision.TestUtils, only: :macros
Expand Down
Loading