diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py index fcdc00f32..79593f346 100644 --- a/onnxscript/function_libs/torch_lib/_flags.py +++ b/onnxscript/function_libs/torch_lib/_flags.py @@ -51,8 +51,3 @@ def _load_boolean_flag( this_will="trace all traceable functions to fold if branches and collapse constant expressions", default=True, ) -EXPERIMENTAL_USE_IR: bool = _load_boolean_flag( - "TORCHLIB_EXPERIMENTAL_USE_IR", - this_will="use the ONNX IR instead of the PyTorch Graph for graph building", - deprecated=True, -) diff --git a/onnxscript/function_libs/torch_lib/graph_building/__init__.py b/onnxscript/function_libs/torch_lib/graph_building/__init__.py index 58acc6c05..656e24977 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/__init__.py +++ b/onnxscript/function_libs/torch_lib/graph_building/__init__.py @@ -40,17 +40,9 @@ "TorchScriptTracingEvaluator", ] -from onnxscript.function_libs.torch_lib import _flags - -if _flags.EXPERIMENTAL_USE_IR: - from ._graph_building_ir import ( - TorchScriptGraph, - TorchScriptTensor, - TorchScriptTracingEvaluator, - ) -else: - from ._graph_building_torch import ( # type: ignore[assignment] - TorchScriptGraph, - TorchScriptTensor, - TorchScriptTracingEvaluator, - ) + +from ._graph_building_ir import ( + TorchScriptGraph, + TorchScriptTensor, + TorchScriptTracingEvaluator, +) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py index 1270c6376..1d6d8567a 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py @@ -237,9 +237,6 @@ def eval_function( # type: ignore[override] else: # Python constants are scalars return 0 - elif function.traceable: - # Trace the function call instead of adding the function as a node - return function.function(*args, 
**kwargs) # args/kwargs are TorchScriptTensor/python built-in based param_schemas = function.param_schemas() @@ -269,6 +266,10 @@ def eval_function( # type: ignore[override] value, float ): attributes[name] = (value,) + if function.traceable: + inputs = self._graph.preprocess_inputs(inputs, attributes) + # Trace the function call instead of adding the function as a node + return function.function(*inputs, **attributes) return self._graph.add_function_call(function, inputs, attributes) @@ -522,15 +523,11 @@ def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]: ) return value - def _add_ir_graph_op_call( + def preprocess_inputs( self, - *, - domain: str, - op_type: str, onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], - num_outputs: int, - ) -> Sequence[TorchScriptTensor]: + ) -> list[TorchScriptTensor]: graph_inputs: list[TorchScriptTensor] = [] assert isinstance(onnx_inputs, Sequence) for input in onnx_inputs: @@ -559,6 +556,18 @@ def _add_ir_graph_op_call( assert not isinstance( value, TorchScriptTensor ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}." + return graph_inputs + + def _add_ir_graph_op_call( + self, + *, + domain: str, + op_type: str, + onnx_inputs: Sequence[ValidInputType], + onnx_attributes: Mapping[str, ValidArgumentType], + num_outputs: int, + ) -> Sequence[TorchScriptTensor]: + graph_inputs = self.preprocess_inputs(onnx_inputs, onnx_attributes) tensors = _create_op_call_in_graph( self._graph, domain,