diff --git a/lib/ex_vision/model/definition.ex b/lib/ex_vision/model/definition.ex index 3ae2641..79d4469 100644 --- a/lib/ex_vision/model/definition.ex +++ b/lib/ex_vision/model/definition.ex @@ -25,15 +25,15 @@ defmodule ExVision.Model.Definition do Application.ensure_all_started(:req) options = - Keyword.validate!(options, [ - :categories, + Keyword.validate!(options, + categories: nil, name: module_to_name(__CALLER__.module) - ]) + ) quote do - unless is_nil(unquote(options[:categories])) do - use ExVision.Model.Definition.Parts.WithCategories, unquote(options) - end + # conditional definition based on whether `categories` option is present has to be moved inside __using__ macro + # here is explanation https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/ + use ExVision.Model.Definition.Parts.WithCategories, unquote(options) @behaviour ExVision.Model.Definition diff --git a/lib/ex_vision/model/definition/parts/with_categories.ex b/lib/ex_vision/model/definition/parts/with_categories.ex index 81d552b..2b4481a 100644 --- a/lib/ex_vision/model/definition/parts/with_categories.ex +++ b/lib/ex_vision/model/definition/parts/with_categories.ex @@ -5,22 +5,25 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do defmacro __using__(options) do options = Keyword.validate!(options, [:name, :categories]) - categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() - spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() - quote do - require Bunch.Typespec + unless is_nil(options |> Keyword.fetch!(:categories)) do + categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() + spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() - @typedoc """ - Type describing all categories recognised by #{unquote(options[:name])} - """ - @type category_t() :: unquote(spec) + quote do + require Bunch.Typespec - @doc """ - Returns a list of all categories recognised by 
#{unquote(options[:name])} - """ - @spec categories() :: [category_t()] - def categories(), do: unquote(categories) + @typedoc """ + Type describing all categories recognised by #{unquote(options[:name])} + """ + @type category_t() :: unquote(spec) + + @doc """ + Returns a list of all categories recognised by #{unquote(options[:name])} + """ + @spec categories() :: [category_t()] + def categories(), do: unquote(categories) + end end end end diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex new file mode 100644 index 0000000..297dc9c --- /dev/null +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -0,0 +1,75 @@ +defmodule Configuration do + @moduledoc false + + @low_resolution {400, 300} + @high_resolution {640, 480} + + @spec configuration() :: %{} + def configuration do + %{ + ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution], + ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.PrincessFast => [ + model: "princess_fast.onnx", + resolution: @low_resolution + ], + ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution], + ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.MosaicFast => [ + model: "mosaic_fast.onnx", + resolution: @low_resolution + ] + } + end +end + +for {module, opts} <- Configuration.configuration() do + defmodule module do + @moduledoc """ + #{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference. 
+ """ + use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]) + + require Logger + + @typedoc """ + A type consisting of output tensor (stylized image tensor) from style transfer models of shape {#{Enum.join(Tuple.to_list(opts[:resolution]) ++ [3], ", ")}}. + """ + @type output_t() :: Nx.Tensor.t() + + @impl true + def load(options \\ []) do + if Keyword.has_key?(options, :batch_size) do + Logger.warning( + "`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" + ) + end + + options + |> Keyword.put(:batch_size, 1) + |> default_model_load() + end + + @impl true + def preprocessing(img, _metdata) do + img |> ExVision.Utils.resize(unquote(opts[:resolution])) |> Nx.divide(255.0) + end + + @impl true + def postprocessing( + stylized_frame, + metadata + ) do + {h, w} = unquote(opts[:resolution]) + + stylized_frame["55"] + |> Nx.reshape({3, h, w}, names: [:channel, :height, :width]) + |> NxImage.resize(metadata.original_size, channels: :first, method: :bilinear) + |> Nx.clip(0.0, 255.0) + |> Nx.as_type(:u8) + |> Nx.transpose(axes: [1, 2, 0]) + end + end +end diff --git a/mix.exs b/mix.exs index 2e355fe..ed3fa08 100644 --- a/mix.exs +++ b/mix.exs @@ -99,6 +99,14 @@ defmodule ExVision.Mixfile do ExVision.Classification.EfficientNet_V2_L, ExVision.Classification.SqueezeNet1_1, ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, + ExVision.StyleTransfer.Candy, + ExVision.StyleTransfer.CandyFast, + ExVision.StyleTransfer.Udnie, + ExVision.StyleTransfer.UdnieFast, + ExVision.StyleTransfer.Mosaic, + ExVision.StyleTransfer.MosaicFast, + ExVision.StyleTransfer.Princess, + ExVision.StyleTransfer.PrincessFast, ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, ExVision.ObjectDetection.Ssdlite320_MobileNetv3, ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN, @@ -123,6 +131,7 @@ defmodule ExVision.Mixfile do ExVision.Types, ExVision.Classification, ExVision.SemanticSegmentation, + ExVision.StyleTransfer, 
ExVision.InstanceSegmentation, ExVision.ObjectDetection, ExVision.KeypointDetection diff --git a/mix.lock b/mix.lock index 9931294..c45bba9 100644 --- a/mix.lock +++ b/mix.lock @@ -6,7 +6,7 @@ "cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"}, "coerce": {:hex, :coerce, "1.0.1", "211c27386315dc2894ac11bc1f413a0e38505d808153367bd5c6e75a4003d096", [:mix], [], "hexpm", "b44a691700f7a1a15b4b7e2ff1fa30bebd669929ac8aa43cffe9e2f8bf051cf1"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, - "credo": {:hex, :credo, "1.7.5", "643213503b1c766ec0496d828c90c424471ea54da77c8a168c725686377b9545", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "f799e9b5cd1891577d8c773d245668aa74a2fcd15eb277f51a0131690ebfb3fd"}, + "credo": {:hex, :credo, "1.7.7", "771445037228f763f9b2afd612b6aa2fd8e28432a95dbbc60d8e03ce71ba4446", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "8bc87496c9aaacdc3f90f01b7b0582467b69b4bd2441fe8aae3109d843cc2f2e"}, "dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"}, "earmark_parser": 
{:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, @@ -14,11 +14,11 @@ "evision": {:hex, :evision, "0.1.38", "f8b23ad685c3ebd70969a3457027b5c74b5bc8dc51588661c516098c3240b92d", [:make, :mix, :rebar3], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0 or ~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: true]}], "hexpm", "f9302547d76c5e4ad7022ffdc76be13e33c990fdd67ad2af203f24ab5d3aee20"}, "ex_doc": {:hex, :ex_doc, "0.32.1", "21e40f939515373bcdc9cffe65f3b3543f05015ac6c3d01d991874129d173420", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5142c9db521f106d61ff33250f779807ed2a88620e472ac95dc7d59c380113da"}, "exla": {:hex, :exla, "0.7.2", "8ac573093df8e5e6b36845beeb3f5a0ea92b05082bf2fa4678f80170cfc887f6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: 
false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d061ea87858415e5585cbd4b7bdae5489000339519a2c6a7f51eb0defd73b588"}, - "file_system": {:hex, :file_system, "1.0.0", "b689cc7dcee665f774de94b5a832e578bd7963c8e637ef940cd44327db7de2cd", [:mix], [], "hexpm", "6752092d66aec5a10e662aefeed8ddb9531d79db0bc145bb8c40325ca1d8536d"}, + "file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"}, "finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"}, "hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"}, "image": {:hex, :image, "0.44.0", "e8eea9398abbed12b7784e786f26a5c839a00bcddd8f2f8ba12adf7e227beb9f", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: 
:nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.23", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "cd00a3de4d7a40a2cb1ca72b9852b0d81701793414af8babf4d33dbeb6de0f6f"}, - "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: 
"hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"}, diff --git a/test/assets/results/style_transfer/cat_candy.gt b/test/assets/results/style_transfer/cat_candy.gt new file mode 100644 index 0000000..c2afd55 Binary files /dev/null and b/test/assets/results/style_transfer/cat_candy.gt differ diff --git a/test/assets/results/style_transfer/cat_candy_fast.gt b/test/assets/results/style_transfer/cat_candy_fast.gt new file mode 100644 index 0000000..6182c9c Binary files /dev/null and b/test/assets/results/style_transfer/cat_candy_fast.gt differ diff --git a/test/assets/results/style_transfer/cat_mosaic.gt b/test/assets/results/style_transfer/cat_mosaic.gt new file mode 100644 index 0000000..f56190e Binary files /dev/null and b/test/assets/results/style_transfer/cat_mosaic.gt differ diff --git a/test/assets/results/style_transfer/cat_mosaic_fast.gt b/test/assets/results/style_transfer/cat_mosaic_fast.gt new file mode 100644 index 0000000..9d951d3 Binary files /dev/null and b/test/assets/results/style_transfer/cat_mosaic_fast.gt differ diff --git a/test/assets/results/style_transfer/cat_princess.gt b/test/assets/results/style_transfer/cat_princess.gt new file mode 100644 index 0000000..dc36a3c Binary files /dev/null and b/test/assets/results/style_transfer/cat_princess.gt differ diff --git a/test/assets/results/style_transfer/cat_princess_fast.gt b/test/assets/results/style_transfer/cat_princess_fast.gt new file mode 100644 index 0000000..10e0938 Binary files /dev/null and b/test/assets/results/style_transfer/cat_princess_fast.gt differ diff --git a/test/assets/results/style_transfer/cat_udnie.gt b/test/assets/results/style_transfer/cat_udnie.gt new file mode 100644 index 0000000..23a300c Binary files /dev/null and b/test/assets/results/style_transfer/cat_udnie.gt differ diff --git a/test/assets/results/style_transfer/cat_udnie_fast.gt b/test/assets/results/style_transfer/cat_udnie_fast.gt new file mode 100644 index 
0000000..be67867 Binary files /dev/null and b/test/assets/results/style_transfer/cat_udnie_fast.gt differ diff --git a/test/ex_vision/style_transfer/style_transfer_test.exs b/test/ex_vision/style_transfer/style_transfer_test.exs new file mode 100644 index 0000000..f9c197e --- /dev/null +++ b/test/ex_vision/style_transfer/style_transfer_test.exs @@ -0,0 +1,56 @@ +defmodule TestConfiguration do + @spec configuration() :: %{} + def configuration do + %{ + ExVision.StyleTransfer.CandyTest => [ + module: ExVision.StyleTransfer.Candy, + gt_file: "cat_candy.gt" + ], + ExVision.StyleTransfer.CandyFastTest => [ + module: ExVision.StyleTransfer.CandyFast, + gt_file: "cat_candy_fast.gt" + ], + ExVision.StyleTransfer.PrincessTest => [ + module: ExVision.StyleTransfer.Princess, + gt_file: "cat_princess.gt" + ], + ExVision.StyleTransfer.PrincessFastTest => [ + module: ExVision.StyleTransfer.PrincessFast, + gt_file: "cat_princess_fast.gt" + ], + ExVision.StyleTransfer.UdnieTest => [ + module: ExVision.StyleTransfer.Udnie, + gt_file: "cat_udnie.gt" + ], + ExVision.StyleTransfer.UdnieFastTest => [ + module: ExVision.StyleTransfer.UdnieFast, + gt_file: "cat_udnie_fast.gt" + ], + ExVision.StyleTransfer.MosaicTest => [ + module: ExVision.StyleTransfer.Mosaic, + gt_file: "cat_mosaic.gt" + ], + ExVision.StyleTransfer.MosaicFastTest => [ + module: ExVision.StyleTransfer.MosaicFast, + gt_file: "cat_mosaic_fast.gt" + ] + } + end +end + +for {module, opts} <- TestConfiguration.configuration() do + defmodule module do + use ExVision.Model.Case, module: unquote(opts[:module]) + use ExVision.TestUtils + + @impl true + def test_inference_result(result) do + expected_result = + "test/assets/results/style_transfer/#{unquote(opts[:gt_file])}" + |> File.read!() + |> Nx.deserialize() + + assert_tensors_equal(result, expected_result, 5, 0.05) + end + end +end diff --git a/test/support/exvision/test_utils.ex b/test/support/exvision/test_utils.ex index 7213968..8e95244 100644 --- 
a/test/support/exvision/test_utils.ex +++ b/test/support/exvision/test_utils.ex @@ -42,6 +42,28 @@ defmodule ExVision.TestUtils do end end + defmacro assert_tensors_equal(a, b, delta \\ @default_delta, relative_delta \\ 0.0) do + quote do + value_condition = + unquote(a) + |> Nx.all_close(unquote(b), atol: unquote(delta), rtol: unquote(relative_delta)) + |> Nx.reduce_min() + |> Nx.to_number() == 1 + + equal_on_count = + unquote(a) + |> Nx.equal(unquote(b)) + |> Nx.as_type(:u64) + |> Nx.reduce(0, fn x, y -> Nx.add(x, y) end) + |> Nx.to_number() + + number_count = unquote(a) |> Nx.shape() |> Tuple.product() + proportional_condition = equal_on_count / number_count > 0.99 + + assert value_condition or proportional_condition + end + end + defmacro __using__(_opts) do quote do import ExVision.TestUtils, only: :macros