diff --git a/coremltools/converters/mil/frontend/torch/test/test_passes.py b/coremltools/converters/mil/frontend/torch/test/test_passes.py index 4401ebbfb..f2c83fe9e 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_passes.py +++ b/coremltools/converters/mil/frontend/torch/test/test_passes.py @@ -405,3 +405,25 @@ def forward(self, x): y_cm = ct_model.predict({'x': x})['y'] assert((y_cm == np.zeros(shape)).all()) + + @staticmethod + def test_inpace_op_from_add(): + class Net(torch.nn.Module): + def forward(self, x): + y = torch.empty(x.shape) + 1 + y.fill_(0) + return y + + shape = (2, 3) + x = torch.rand(*shape) + traced_fn = torch.jit.trace(Net(), x).eval() + + ct_model = ct.convert( + traced_fn, + inputs=[ct.TensorType(shape=shape)], + outputs=[ct.TensorType(name="y", dtype=np.int32)], + source="pytorch", + ) + y_cm = ct_model.predict({'x': x})['y'] + + assert((y_cm == np.zeros(shape)).all()) diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index d066d9a9d..100fe66a1 100644 --- a/coremltools/converters/mil/frontend/torch/torchir_passes.py +++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py @@ -135,17 +135,7 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse): tensor_to_node_sequence_mapping.pop(node_input) node_sequence.append(node) tensor_to_node_sequence_mapping[node_output] = node_sequence - - if node.kind == "to": - node_input = node.inputs[0] - if node_input in tensor_to_node_sequence_mapping: - # update the mapping - node_output = node.outputs[0] - val = tensor_to_node_sequence_mapping[node_input] - del tensor_to_node_sequence_mapping[node_input] - tensor_to_node_sequence_mapping[node_output] = val - - if node.kind in ("copy_", "fill_"): + elif node.kind in ("copy_", "fill_"): node_input = node.inputs[0] if node_input not in tensor_to_node_sequence_mapping: raise ValueError("No matching select or slice.") @@ -176,6 +166,13 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse): blocks=[], ) graph.nodes[i] = tensor_assign_node + elif node.inputs: + node_input = node.inputs[0] + if node_input in tensor_to_node_sequence_mapping: + # update the mapping + node_output = node.outputs[0] + val = tensor_to_node_sequence_mapping[node_input] + tensor_to_node_sequence_mapping[node_output] = val # modify the graph outputs if it is effected by this graph pass for idx in range(len(graph.outputs)):