Skip to content

Commit

Permalink
[torchlib] Use traced function param schema to process inputs (#1916)
Browse files Browse the repository at this point in the history
The first step of #1914,
this is setting up onnxscript CI to test whether traced_only function
has enough information to process inputs to tensors.
  • Loading branch information
titaiwangms authored Oct 25, 2024
1 parent 2b60939 commit 561a600
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
15 changes: 15 additions & 0 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ def tag_arguments_with_param_schemas(
raise TypeError(f"Required input/attribute '{param}' was not provided")

return tagged_args, tagged_kwargs


def turn_to_kwargs_to_avoid_ordering(
    param_schemas: "Sequence[values.ParamSchema]",
    inputs: "list[Any]",
    attributes: "dict[str, Any]",
) -> "dict[str, Any]":
    """Convert positional inputs into keyword arguments ordered by the function signature.

    Walks ``param_schemas`` in signature order and, for every parameter that is
    not already present in ``attributes``, consumes the next positional value
    from ``inputs``. A variadic input parameter absorbs all inputs that have
    not been consumed yet.

    Note:
        Both ``inputs`` and ``attributes`` are mutated in place; ``inputs`` is
        (partially or fully) consumed.

    Args:
        param_schemas: Parameter schemas in signature order.
        inputs: Positional input values; consumed in place.
        attributes: Already-known keyword arguments/attributes; updated in place.

    Returns:
        ``attributes``, now containing an entry for every matched parameter.
    """
    for param in param_schemas:
        if param.name in attributes:
            # Already provided as a keyword/attribute; do not consume an input.
            continue
        if param.is_variadic_input:
            # A variadic parameter takes *all* remaining inputs. (A slice by
            # the schema index would be wrong here: earlier pop(0) calls and
            # attribute-provided params shift/skip positions, so the schema
            # index no longer aligns with the inputs list.)
            attributes[param.name] = list(inputs)
            del inputs[:]
        elif inputs:
            attributes[param.name] = inputs.pop(0)
    return attributes
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

# args/kwargs are TorchScriptTensor/python built-in based
param_schemas = function.param_schemas()
Expand Down Expand Up @@ -422,6 +419,15 @@ def eval_function( # type: ignore[override]
value, float
):
attributes[name] = (value,)
if function.traceable:
inputs = self._graph.preprocess_inputs(inputs)
inputs = _wrap_torch_value_to_tensor(inputs) # type: ignore[assignment]
# The args and kwargs matters, as it's traced onnx function
kwargs = param_manipulation.turn_to_kwargs_to_avoid_ordering(
param_schemas, inputs, attributes
)
# Trace the function call instead of adding the function as a node
return function.function(**kwargs)
return self._graph.add_function_call(function, inputs, attributes)


Expand Down Expand Up @@ -730,14 +736,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
value.setDebugName(_rename_intermediate_value(value.debugName()))
return value

@runtime_typing.checked
def _add_torchscript_op_call(
self,
name: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
n_outputs: int,
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]:
unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs)
graph_inputs = []
assert isinstance(unwrapped_inputs, Sequence)
Expand All @@ -761,6 +760,17 @@ def _add_torchscript_op_call(
graph_inputs.append(self._add_constant_to_graph(input))
else:
graph_inputs.append(input)
return graph_inputs

@runtime_typing.checked
def _add_torchscript_op_call(
self,
name: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
n_outputs: int,
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
graph_inputs = self.preprocess_inputs(onnx_inputs)
for key, value in onnx_attributes.items():
assert not isinstance(
value, TorchScriptTensor
Expand Down
11 changes: 3 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5752,14 +5752,9 @@ def aten_nansum(
def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor:
    """narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"""

    # NOTE(review): the conditional IsScalar reshapes below and the
    # unconditional reshapes that follow them appear to be the removed (old)
    # and added (new) halves of a diff pasted together; as written,
    # dim/start/length are reshaped twice. Confirm only one variant is the
    # intended final code.
    if IsScalar(dim):
        dim = op.Reshape(dim, op.Constant(value_ints=[-1]))

    if IsScalar(start):
        start = op.Reshape(start, op.Constant(value_ints=[-1]))

    if IsScalar(length):
        length = op.Reshape(length, op.Constant(value_ints=[-1]))
    # Reshape to 1-D tensors so they are valid starts/ends/axes inputs to Slice.
    dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
    start = op.Reshape(start, op.Constant(value_ints=[-1]))
    length = op.Reshape(length, op.Constant(value_ints=[-1]))

    # The narrowed window along `dim` is [start, start + length).
    end = op.Add(start, length)
    return op.Slice(self, start, end, dim)
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,7 @@ def _where_input_wrangler(
.xfail(
variant_name="decimals_0",
reason="This variant does not accept decimals",
test_class_name="TestOutputConsistencyEager",
)
.xfail(
variant_name="decimals_3",
Expand Down

0 comments on commit 561a600

Please sign in to comment.