Fix GroupNorm dynamic input (#1922)
fix groupnorm dynamic input

sercand authored Jul 31, 2023
1 parent 3b6b963 commit ccd1231
Showing 2 changed files with 39 additions and 2 deletions.
11 changes: 9 additions & 2 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -1828,10 +1828,17 @@ def group_norm(context, node):
     bias = inputs[3]
     eps = inputs[4]
     n,c = x.shape[0],x.shape[1] # at minimum (N, C) required
-    input_shape = [*x.shape] # n, c, *
     num_groups = builtins.min(num_groups,c)
     new_shape = [n, num_groups, c//num_groups]
-    new_shape += [*x.shape[2:]] # adds remaining dims
+    # Optimization for non-symbolic shapes: avoids the 3 MIL ops that dynamic shapes require.
+    if not any_symbolic(x.shape[2:]):
+        new_shape += [*x.shape[2:]] # adds remaining dims
+        input_shape = [*x.shape] # n, c, *
+    else:
+        input_shape = mb.shape(x=x)
+        input_shape_sliced = mb.slice_by_size(x=input_shape, begin=[2], size=[-1]) # x_shape[2:]
+        new_shape = mb.concat(values=[new_shape, input_shape_sliced], axis=0)
+
     num_extra_axes = len(x.shape[2:])
     axes_ = [int(i) for i in range(2, 2 + num_extra_axes + 1)]
     weight_shape, bias_shape = [1,c], [1,c]
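For intuition, the reshape this branch constructs is the standard GroupNorm trick: split the channel axis into (num_groups, c // num_groups), normalize over the per-group channel and spatial axes, then restore the original shape. A minimal NumPy sketch of the same computation (illustrative only, not part of this commit; the function name and the omitted affine step are my own):

import numpy as np

def group_norm_reference(x, num_groups, eps=1e-5):
    # x: (N, C, *spatial). Split channels into (num_groups, C // num_groups),
    # mirroring new_shape = [n, num_groups, c // num_groups, *x.shape[2:]] above.
    n, c = x.shape[0], x.shape[1]
    grouped = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
    # Normalize over every axis after the group axis -- the same axes the
    # converter computes as range(2, 2 + num_extra_axes + 1).
    axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)
    normalized = (grouped - mean) / np.sqrt(var + eps)
    # Restore the original layout; an affine weight/bias would be applied per channel here.
    return normalized.reshape(x.shape)

The dynamic-shape branch builds the same new_shape, but at runtime: mb.shape reads the input's shape as a tensor, mb.slice_by_size keeps everything from index 2 onward, and mb.concat appends those trailing dims to [n, num_groups, c // num_groups].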
30 changes: 30 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -1285,6 +1285,36 @@ def test_groupnorm_rank2_input(
             compute_unit=compute_unit,
         )
 
+    @pytest.mark.parametrize(
+        "compute_unit, backend, group_features, eps, affine",
+        itertools.product(
+            compute_units, backends, [(16, 32), (1, 1)], [0.1, 1e-05], [True, False]
+        ),
+    )
+    def test_groupnorm_dynamic(self, compute_unit, backend, group_features, eps, affine):
+        model = nn.GroupNorm(
+            group_features[0], group_features[1], eps=eps, affine=affine
+        )
+        dim_upper_bound = 30 if backend[0] == "mlprogram" else -1
+        converter_input_type = [
+            TensorType(
+                shape=(
+                    6,
+                    group_features[1],
+                    RangeDim(default=10, lower_bound=5, upper_bound=dim_upper_bound),
+                    RangeDim(default=10, lower_bound=5, upper_bound=dim_upper_bound),
+                ),
+                dtype=np.float32,
+            )
+        ]
+        self.run_compare_torch(
+            (6, group_features[1], 10, 10),
+            model,
+            backend=backend,
+            compute_unit=compute_unit,
+            converter_input_type=converter_input_type,
+        )
+
 
 class TestLinear(TorchBaseTest):
     @pytest.mark.parametrize(
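The test drives this path through run_compare_torch, but the same scenario can be reproduced with the public conversion API. A sketch under assumed model sizes and bounds (the variable names and values are illustrative; ct.convert, ct.TensorType, and ct.RangeDim are the actual coremltools entry points):

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn

# Hypothetical repro: a GroupNorm model converted with flexible (dynamic)
# spatial dimensions, the case this commit fixes.
model = nn.GroupNorm(16, 32).eval()
traced = torch.jit.trace(model, torch.rand(6, 32, 10, 10))

mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(
            shape=(
                6,
                32,
                ct.RangeDim(lower_bound=5, upper_bound=30, default=10),
                ct.RangeDim(lower_bound=5, upper_bound=30, default=10),
            ),
            dtype=np.float32,
        )
    ],
    convert_to="mlprogram",
)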
