From ccd123138c7585260a6e18f8a1fc3878b0f977de Mon Sep 17 00:00:00 2001
From: Sercan Değirmenci
Date: Mon, 31 Jul 2023 20:35:01 +0300
Subject: [PATCH] Fix GroupNorm dynamic input (#1922)

Fix GroupNorm conversion for inputs with dynamic (symbolic) spatial
dimensions: when the trailing dims are not known at conversion time,
build the reshape target shape at runtime with mb.shape,
mb.slice_by_size, and mb.concat instead of reading it from the static
shape.
---
 .../converters/mil/frontend/torch/ops.py      | 11 +++++--
 .../mil/frontend/torch/test/test_torch_ops.py | 30 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 50c461c76..3a9dd363c 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -1828,10 +1828,17 @@ def group_norm(context, node):
     bias = inputs[3]
     eps = inputs[4]
     n,c = x.shape[0],x.shape[1] # at minimum (N, C) required
-    input_shape = [*x.shape] # n, c, *
     num_groups = builtins.min(num_groups,c)
     new_shape = [n, num_groups, c//num_groups]
-    new_shape += [*x.shape[2:]] # adds remaining dims
+    # optimization for non-symbolic shapes: avoids the 3 MIL ops that dynamic shapes require
+    if not any_symbolic(x.shape[2:]):
+        new_shape += [*x.shape[2:]] # adds remaining dims
+        input_shape = [*x.shape] # n, c, *
+    else:
+        input_shape = mb.shape(x=x)
+        input_shape_sliced = mb.slice_by_size(x=input_shape, begin=[2], size=[-1]) # x_shape[2:]
+        new_shape = mb.concat(values=[new_shape, input_shape_sliced], axis=0)
+
     num_extra_axes = len(x.shape[2:])
     axes_ = [int(i) for i in range(2, 2 + num_extra_axes + 1)]
     weight_shape, bias_shape = [1,c], [1,c]
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 926f565ae..174a19b04 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -1285,6 +1285,36 @@ def test_groupnorm_rank2_input(
             compute_unit=compute_unit,
         )
 
+    @pytest.mark.parametrize(
+        "compute_unit, backend, group_features, eps, affine",
+        itertools.product(
+            compute_units, backends, [(16, 32), (1, 1)], [0.1, 1e-05], [True, False]
+        ),
+    )
+    def test_groupnorm_dynamic(self, compute_unit, backend, group_features, eps, affine):
+        model = nn.GroupNorm(
+            group_features[0], group_features[1], eps=eps, affine=affine
+        )
+        dim_upper_bound = 30 if backend[0] == "mlprogram" else -1
+        converter_input_type = [
+            TensorType(
+                shape=(
+                    6,
+                    group_features[1],
+                    RangeDim(default=10, lower_bound=5, upper_bound=dim_upper_bound),
+                    RangeDim(default=10, lower_bound=5, upper_bound=dim_upper_bound),
+                ),
+                dtype=np.float32,
+            )
+        ]
+        self.run_compare_torch(
+            (6, group_features[1], 10, 10),
+            model,
+            backend=backend,
+            compute_unit=compute_unit,
+            converter_input_type=converter_input_type,
+        )
+
 
 class TestLinear(TorchBaseTest):
     @pytest.mark.parametrize(
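
For reference, a minimal end-to-end sketch (not part of the patch) of
the user-facing scenario this PR fixes: converting a torch GroupNorm
model whose spatial dimensions are flexible, so x.shape[2:] is symbolic
when group_norm is translated and the new mb.shape / mb.slice_by_size /
mb.concat path is exercised. The channel and group counts, the range
bounds, and the input name "x" below are illustrative choices, not
values taken from the PR.

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn

model = nn.GroupNorm(4, 16).eval()
traced = torch.jit.trace(model, torch.rand(1, 16, 32, 32))

# RangeDim makes H and W symbolic at conversion time, which is exactly
# the case the `else` branch in group_norm above now handles.
flexible_input = ct.TensorType(
    name="x",  # hypothetical input name, for illustration only
    shape=(
        1,
        16,
        ct.RangeDim(lower_bound=16, upper_bound=64, default=32),
        ct.RangeDim(lower_bound=16, upper_bound=64, default=32),
    ),
    dtype=np.float32,
)
mlmodel = ct.convert(traced, inputs=[flexible_input], convert_to="mlprogram")

# Any H/W within the declared range should now convert and run
# (prediction itself requires macOS).
x = np.random.rand(1, 16, 48, 48).astype(np.float32)
print(mlmodel.predict({"x": x}))

Before this change, the group_norm translation always read x.shape[2:]
eagerly, which is why models with symbolic trailing dims failed to
convert; the dynamic branch defers the shape computation to runtime.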