Skip to content

Commit

Permalink
masked_fill: fill value type must match tensor type (#1915)
Browse the repository at this point in the history
Fix a bug in torch masked_fill op translation
  • Loading branch information
pcuenca authored Jul 25, 2023
1 parent d9434fc commit e0f8918
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
4 changes: 4 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from coremltools.converters.mil.mil.types import is_bool, nptype_from_builtin
from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic
from coremltools.converters.mil.mil.types.type_mapping import builtin_to_string
from coremltools.converters.mil.mil.var import ListVar, Var

from .._utils import build_einsum_mil, value_at
Expand Down Expand Up @@ -4251,6 +4252,9 @@ def masked_fill(context, node):
# cond must be bool type
mask = mb.cast(x=mask, dtype="bool")

if value.dtype != x.dtype:
value = mb.cast(x=value, dtype=builtin_to_string(x.dtype))

res = mb.select(cond=mask, a=value, b=x, name=node.name)
context.add(res)

Expand Down
21 changes: 15 additions & 6 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7997,21 +7997,30 @@ def test_constant_pad_3d(self, compute_unit, backend):

class TestMaskedFill(TorchBaseTest):
    """Conversion tests for ``torch.masked_fill``.

    Exercises the op across integer and float input dtypes and across fill
    values of different Python types (float, int, zero) to verify the
    converter casts the fill value to the input tensor's dtype when they
    disagree.
    """

    @pytest.mark.parametrize(
        "compute_unit, backend, dtype, value",
        itertools.product(
            compute_units,
            backends,
            [np.int32, np.float32],
            # Mix float and int fill values so at least one case mismatches
            # each tensor dtype and forces the cast path in the converter.
            [10.3, 7, 0],
        ),
    )
    def test_masked_fill(self, compute_unit, backend, dtype, value):
        SHAPE = (2, 3)
        # Random boolean mask broadcast along the last dimension.
        MASK = torch.bernoulli(torch.rand(SHAPE[-1])).to(torch.bool)

        input_data = np.random.randint(-100, 100, SHAPE).astype(dtype)
        input_data = torch.from_numpy(input_data)
        model = ModuleWrapper(torch.masked_fill, {"mask": MASK, "value": value})
        # Pin the converter-side input dtype so it matches the torch input.
        converter_input_type = [TensorType(shape=SHAPE, dtype=dtype)]

        TorchBaseTest.run_compare_torch(
            input_data,
            model,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,  # input_data is a concrete tensor, not a shape
            converter_input_type=converter_input_type,
        )


Expand Down

0 comments on commit e0f8918

Please sign in to comment.