Skip to content

Commit

Permalink
Add support for int32 beta in torch.baddbmm (#1925)
Browse files Browse the repository at this point in the history
  • Loading branch information
comeweber authored Jul 28, 2023
1 parent e0f8918 commit 3b6b963
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
27 changes: 19 additions & 8 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5439,21 +5439,32 @@ def baddbmm(context, node):
inputs = _get_inputs(context, node, expected=5)
bias, batch1, batch2, beta, alpha = inputs

if beta.val != 1.0:
# Apply scaling factor beta to the bias.
bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled")
context.add(bias)

if alpha.val != 1.0:
# Apply scaling factor alpha to the input.
batch1 = mb.mul(x=alpha, y=batch1, name=batch1.name + "_scaled")
context.add(batch1)

bmm_node = mb.matmul(x=batch1, y=batch2, name=node.name + "_bmm")
context.add(bmm_node)

baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name)
context.add(baddbmm_node)
if beta.val != 0.0 or bias.shape != bmm_node.shape:
context.add(bmm_node)
if beta.val != 1.0:
# Torch allows `beta` to be an integer, so cast it to the bias dtype before multiplying
if beta.dtype != bias.dtype:
logger.warning(
f"Casting the `beta`(value={beta.val}) argument of `baddbmm` op {node.name} "
f"from {beta.dtype} to {bias.dtype} dtype")
beta = mb.cast(x=beta, dtype=TYPE_TO_DTYPE_STRING[bias.dtype])
# Apply scaling factor beta to the bias.
bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled")
context.add(bias)

baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name)
context.add(baddbmm_node)
else:
bmm_node.name = node.name
context.add(bmm_node)



@register_torch_op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8400,14 +8400,15 @@ def forward(self, x):

class TestBaddbmm(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, shapes",
"compute_unit, backend, shapes, beta",
itertools.product(
compute_units,
backends,
[(2, 4, 6, 8), (4, 12, 6, 16)],
[0.0, 0.5, 1.0, 2],
),
)
def test_baddbmm(self, compute_unit, backend, shapes):
def test_baddbmm(self, compute_unit, backend, shapes, beta):
B, N, M, P = shapes

# input shape: any shape broadcastable to (B, N, P)
Expand All @@ -8421,7 +8422,7 @@ def __init__(self):
self.batch2 = torch.randn(B, M, P)

def forward(self, x):
return torch.baddbmm(x, self.batch1, self.batch2)
return torch.baddbmm(x, self.batch1, self.batch2, beta=beta)

model = BaddbmmModel()
# Makes it broadcastable to (B, N, P).
Expand Down

0 comments on commit 3b6b963

Please sign in to comment.