Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace softmax with stable version + small corrections #8

Merged
merged 9 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_l.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.5, 0.5, 0.5]),
Nx.tensor([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_m.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_s.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do
image
|> ExVision.Utils.resize({384, 384})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 1 addition & 3 deletions lib/ex_vision/classification/generic_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ defmodule ExVision.Classification.GenericClassifier do
# Contains a default implementation of post processing for TorchVision classifiers
# To use: `use ExVision.Classification.GenericClassifier`

alias ExVision.Utils

alias ExVision.Types.ImageMetadata

@typep output_t() :: %{atom() => number()}
Expand All @@ -15,7 +13,7 @@ defmodule ExVision.Classification.GenericClassifier do
scores
|> Nx.backend_transfer()
|> Nx.flatten()
|> Utils.softmax()
|> Axon.Activations.softmax(axis: [0])
|> Nx.to_flat_list()
|> then(&Enum.zip(categories, &1))
|> Map.new()
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/squeezenet1_1.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
categories: "priv/categories/coco_categories.json"

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithMask
Expand Down Expand Up @@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

masks =
msluszniak marked this conversation as resolved.
Show resolved Hide resolved
masks
Expand Down
30 changes: 10 additions & 20 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx",
categories: "priv/categories/no_person_or_person.json"

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithKeypoints
Expand Down Expand Up @@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()

keypoints_list =
keypoints_list
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))

keypoints_scores_list = squeeze_and_listify(keypoints_scores_list)

[bboxes, scores, labels, keypoints_list, keypoints_scores_list]
|> Enum.zip()
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do
# Contains a default implementation of pre and post processing for TorchVision detectors
# To use: `use ExVision.ObjectDetection.GenericDetector`

import ExVision.Utils

require Logger

alias ExVision.Types.{BBox, ImageMetadata}
Expand All @@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

[bboxes, scores, labels]
|> Enum.zip()
Expand Down
23 changes: 16 additions & 7 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
defmodule ExVision.Utils do
@moduledoc false

import Nx.Defn
require Nx
require Image
alias ExVision.Types
Expand Down Expand Up @@ -86,8 +85,7 @@ defmodule ExVision.Utils do

defp ensure_grad_3(tensor) do
tensor
|> Nx.shape()
|> tuple_size()
|> Nx.rank()
|> case do
3 -> [tensor]
4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()
Expand Down Expand Up @@ -149,10 +147,6 @@ defmodule ExVision.Utils do
Enum.map(outputs, fn {name, _type, _shape} -> name end)
end

defn softmax(x) do
Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
end

@spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
def batched_run(process_name, input) when is_list(input) do
Nx.Serving.batched_run(process_name, input)
Expand All @@ -161,4 +155,19 @@ defmodule ExVision.Utils do
def batched_run(process_name, input) do
process_name |> batched_run([input]) |> hd()
end

@spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
def scale_and_listify_bbox(bbox, scales) do
bbox
|> Nx.squeeze(axes: [0])
|> Nx.multiply(scales)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
end

@spec squeeze_and_listify(Nx.Tensor.t()) :: [number()]
def squeeze_and_listify(batched_value) do
batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list()
end
end
Loading