Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Use traced function param schema to process inputs #1915

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,3 @@ def _load_boolean_flag(
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
default=True,
)
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_USE_IR",
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
deprecated=True,
)
20 changes: 6 additions & 14 deletions onnxscript/function_libs/torch_lib/graph_building/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,9 @@
"TorchScriptTracingEvaluator",
]

from onnxscript.function_libs.torch_lib import _flags

if _flags.EXPERIMENTAL_USE_IR:
from ._graph_building_ir import (
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)
else:
from ._graph_building_torch import ( # type: ignore[assignment]
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)

from ._graph_building_ir import (
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,6 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

# args/kwargs are TorchScriptTensor/python built-in based
param_schemas = function.param_schemas()
Expand Down Expand Up @@ -269,6 +266,10 @@ def eval_function( # type: ignore[override]
value, float
):
attributes[name] = (value,)
if function.traceable:
inputs = self._graph.precprocess_inputs(inputs, attributes)
# Trace the function call instead of adding the function as a node
return function.function(*inputs, **attributes)
return self._graph.add_function_call(function, inputs, attributes)


Expand Down Expand Up @@ -522,15 +523,11 @@ def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
)
return value

def _add_ir_graph_op_call(
def precprocess_inputs(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
) -> list[TorchScriptTensor]:
graph_inputs: list[TorchScriptTensor] = []
assert isinstance(onnx_inputs, Sequence)
for input in onnx_inputs:
Expand Down Expand Up @@ -559,6 +556,18 @@ def _add_ir_graph_op_call(
assert not isinstance(
value, TorchScriptTensor
), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
return graph_inputs

def _add_ir_graph_op_call(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
graph_inputs = self.precprocess_inputs(onnx_inputs, onnx_attributes)
tensors = _create_op_call_in_graph(
self._graph,
domain,
Expand Down
Loading