torch.fill_ cannot be applied after a general function #1924

Open · wants to merge 2 commits into base: main
22 changes: 22 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_passes.py
@@ -405,3 +405,25 @@ def forward(self, x):
         y_cm = ct_model.predict({'x': x})['y']
 
         assert((y_cm == np.zeros(shape)).all())
+
+    @staticmethod
+    def test_inplace_op_from_add():
+        class Net(torch.nn.Module):
+            def forward(self, x):
+                y = torch.empty(x.shape) + 1
+                y.fill_(0)
+                return y
+
+        shape = (2, 3)
+        x = torch.rand(*shape)
+        traced_fn = torch.jit.trace(Net(), x).eval()
+
+        ct_model = ct.convert(
+            traced_fn,
+            inputs=[ct.TensorType(shape=shape)],
+            outputs=[ct.TensorType(name="y", dtype=np.int32)],
+            source="pytorch",
+        )
+        y_cm = ct_model.predict({'x': x})['y']
+
+        assert((y_cm == np.zeros(shape)).all())
Collaborator:
Maybe remove the outermost parentheses to be `assert all(y_cm == np.zeros(shape))`?
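
As a side note on this suggestion, here is a minimal sketch of the comparison styles involved; `y_cm` below is just a stand-in array, not the converted model's actual output:

```python
import numpy as np

shape = (2, 3)
y_cm = np.zeros(shape, dtype=np.int32)  # stand-in for the converted model output

# Style used in the test: element-wise comparison reduced with ndarray.all().
assert (y_cm == np.zeros(shape)).all()

# Caution: the builtin all() iterates a 2-D boolean array row by row, and
# bool(row) raises "truth value of an array ... is ambiguous", so a
# parenthesis-free form would need np.all or a flattened comparison.

# A common alternative in test code: numpy's dedicated assertion helper,
# which also prints a readable diff on failure.
np.testing.assert_array_equal(y_cm, np.zeros(shape))
```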

19 changes: 8 additions & 11 deletions coremltools/converters/mil/frontend/torch/torchir_passes.py
@@ -135,17 +135,7 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse):
                 tensor_to_node_sequence_mapping.pop(node_input)
             node_sequence.append(node)
             tensor_to_node_sequence_mapping[node_output] = node_sequence
-
-        if node.kind == "to":
-            node_input = node.inputs[0]
-            if node_input in tensor_to_node_sequence_mapping:
-                # update the mapping
-                node_output = node.outputs[0]
-                val = tensor_to_node_sequence_mapping[node_input]
-                del tensor_to_node_sequence_mapping[node_input]
-                tensor_to_node_sequence_mapping[node_output] = val
-
-        if node.kind in ("copy_", "fill_"):
+        elif node.kind in ("copy_", "fill_"):
             node_input = node.inputs[0]
             if node_input not in tensor_to_node_sequence_mapping:
                 raise ValueError("No matching select or slice.")
@@ -176,6 +166,13 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse):
                 blocks=[],
             )
             graph.nodes[i] = tensor_assign_node
+        elif node.inputs:
+            node_input = node.inputs[0]
+            if node_input in tensor_to_node_sequence_mapping:
+                # update the mapping
+                node_output = node.outputs[0]
+                val = tensor_to_node_sequence_mapping[node_input]
+                tensor_to_node_sequence_mapping[node_output] = val
Collaborator:

Could you please share your reasoning on why this would be generally applicable for nodes in addition to `to`?

Also, what is the reason to delete the line `del tensor_to_node_sequence_mapping[node_input]`? With that line, `tensor_to_node_sequence_mapping[node_output]` ends up ordered as the latest entry in the mapping.

Contributor Author (@fukatani), Jul 29, 2023:

> What is the reason to delete the line `del tensor_to_node_sequence_mapping[node_input]`?

With `del tensor_to_node_sequence_mapping[node_input]`, a ValueError is raised if the cast source is referenced more than once, for example:

class Net(torch.nn.Module):
    def forward(self, x):
        y = torch.empty(x.shape)
        z = y.to(torch.int32)
        w = y.to(torch.int32)
        return z, w
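
For context, a conversion call along the following lines triggers the error; this is reconstructed from the stack trace below, so the exact input specification (a dynamic-shape `ct.TensorType` named "x") is an assumption rather than the verbatim repro script:

```python
import coremltools as ct
import torch

x = torch.rand(2, 3)
traced = torch.jit.trace(Net().eval(), x)

# Dynamic input shape, as suggested by line 21 of the stack trace below.
ct.convert(
    traced,
    inputs=[ct.TensorType("x", shape=(ct.RangeDim(), ct.RangeDim()))],
    source="pytorch",
)
# -> ValueError: No matching select or slice.
#    (raised while the old `del tensor_to_node_sequence_mapping[node_input]` line is kept)
```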

Stack trace:

Traceback (most recent call last):
  File "/Users/ryosukefukatani/work/coremltools/onth11.py", line 21, in <module>
    ct.TensorType("x", shape=(ct.RangeDim(), ct.RangeDim())),
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/_converters_entry.py", line 542, in convert
    main_pipeline=pass_pipeline,
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 217, in _mil_convert
    **kwargs
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 61, in load
    specification_version,
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 335, in __init__
    p(self.graph)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/torchir_passes.py", line 141, in generate_tensor_assignment_ops
    raise ValueError("No matching select or slice.")
ValueError: No matching select or slice.

Contributor Author (@fukatani):

> Could you please share your reasoning on why this would be generally applicable for nodes in addition to `to`?

We need a general solution for #1917.

For example:

class Net(torch.nn.Module):
    def forward(self, x):
        y = torch.empty(x.shape) + 1
        y.fill_(0.0)
        return y

class Net(torch.nn.Module):
    def forward(self, x):
        y = torch.empty(x.shape) - 1
        y.fill_(0.0)
        return y

class Net(torch.nn.Module):
    def forward(self, x):
        y = torch.empty(x.shape).flatten()
        y.fill_(0.0)
        return y

and likewise for all other operations between `empty` and `fill_`.
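
To illustrate the intent of the `elif node.inputs:` branch above, here is a minimal standalone sketch (not coremltools code; the `Node` type and op names are simplified stand-ins) of how the tensor-to-node-sequence mapping can be propagated through arbitrary intermediate ops so that a later `fill_`/`copy_` still finds its source:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Node:
    # Simplified stand-in for a TorchIR node: op kind, input names, output names.
    kind: str
    inputs: List[str]
    outputs: List[str]

def propagate_mapping(nodes: List[Node]) -> Dict[str, List[Node]]:
    """Track which tensors originate from an `empty` node, even when other
    ops (add, sub, flatten, to, ...) sit between `empty` and `fill_`."""
    mapping: Dict[str, List[Node]] = {}  # tensor name -> producing node sequence
    for node in nodes:
        if node.kind == "empty":
            mapping[node.outputs[0]] = [node]
        elif node.kind in ("copy_", "fill_"):
            if node.inputs[0] not in mapping:
                raise ValueError("No matching select or slice.")
            # ... here the real pass fuses the sequence into a tensor assignment.
        elif node.inputs and node.inputs[0] in mapping:
            # Any other op: forward the mapping from its input to its output,
            # keeping the original entry alive instead of deleting it (no `del`).
            mapping[node.outputs[0]] = mapping[node.inputs[0]]
    return mapping

# `y = torch.empty(x.shape) + 1; y.fill_(0)` roughly lowers to:
nodes = [
    Node("empty", ["x_shape"], ["t0"]),
    Node("add", ["t0", "one"], ["t1"]),
    Node("fill_", ["t1", "zero"], ["t2"]),
]
print(propagate_mapping(nodes))  # both 't0' and 't1' map back to the empty node
```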


     # modify the graph outputs if it is effected by this graph pass
     for idx in range(len(graph.outputs)):