TorchLib function authoring guide
Updated: November 2023
Authors: @titaiwangms @justinchuby
TorchLib functions are pure data. This means we avoid defining runtime behavior as code in the functions.
The primary goal of torchlib is to translate a PyTorch model into an ONNX model. To do this, you first need to understand the signature of the PyTorch function being translated, specifically the ATen operator. PyTorch defines all of these native functions in the native_functions.yaml file, where each entry has the following form:
```yaml
- func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> Return
  variants: function, method
  dispatch:
    CPU: func_cpu
    CUDA: func_cuda
```
Pay close attention to each ArgType: every distinct ArgType corresponds to a different TypeVar in torchlib.
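To make the structure of a schema string concrete, here is a minimal sketch that splits one into its name, arguments (ArgType, name, default, keyword-only flag) and return type. It is a hypothetical helper, not PyTorch's schema parser: `parse_schema` and `Arg` are names invented for illustration, and it only handles simple schemas like those quoted in this guide (for example, no commas inside default values).

```python
from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass
class Arg:
    arg_type: str         # e.g. "Tensor", "int[]", "Tensor?"
    name: str             # e.g. "self", "dim"
    default: str | None   # raw default text such as "False", or None
    kwarg_only: bool      # True for arguments after the bare "*"


# name(args) -> return ; args is matched greedily up to the ")" before "->"
_SCHEMA_RE = re.compile(r"^(?P<name>[\w.]+)\((?P<args>.*)\)\s*->\s*(?P<ret>.+)$")


def parse_schema(schema: str) -> tuple[str, list[Arg], str]:
    """Hypothetical, simplified parser for native_functions.yaml schema strings."""
    match = _SCHEMA_RE.match(schema.strip())
    if match is None:
        raise ValueError(f"Unrecognized schema: {schema!r}")
    args: list[Arg] = []
    kwarg_only = False
    for piece in match.group("args").split(","):
        piece = piece.strip()
        if not piece:
            continue
        if piece == "*":
            kwarg_only = True  # everything after "*" is keyword-only
            continue
        decl, _, default = piece.partition("=")
        arg_type, name = decl.rsplit(maxsplit=1)
        args.append(Arg(arg_type, name, default or None, kwarg_only))
    return match.group("name"), args, match.group("ret").strip()
```

For the `gather` schema used in the `aten_gather` docstring, this yields `self`, `dim` and `index` as positional arguments and `sparse_grad` as a keyword-only `bool` with default `False`.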
The `torch_op` decorator formally registers a function with torchlib. Its signature is:
```python
def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function supports complex.
    """
    ...  # signature only; the body is elided
```
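The following toy decorator illustrates, in plain Python, what registering "by qualified name" means: one function can be registered under several overload names, the default overload is written without `.default`, and the function lands in a default registry unless another is passed. Everything here (`ToyRegistry`, `toy_torch_op`, `default_registry`) is hypothetical and much simpler than torchlib's real `torch_op`, which also compiles or traces the function.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Callable


class ToyRegistry:
    """Hypothetical registry: qualified op name -> list of (function, private) pairs."""

    def __init__(self) -> None:
        self.overloads: defaultdict[str, list[tuple[Callable, bool]]] = defaultdict(list)


default_registry = ToyRegistry()


def toy_torch_op(
    name: str | tuple[str, ...],
    *,
    registry: ToyRegistry | None = None,
    private: bool = False,
) -> Callable[[Callable], Callable]:
    """Hypothetical torch_op-like decorator (not torchlib's implementation)."""
    names = (name,) if isinstance(name, str) else name
    for qualified_name in names:
        # Mirror the docstring's convention: default overloads omit ".default".
        if qualified_name.endswith(".default"):
            short = qualified_name.removesuffix(".default")
            raise ValueError(f"Use {short!r} instead of {qualified_name!r}")
    target = default_registry if registry is None else registry

    def decorator(func: Callable) -> Callable:
        for qualified_name in names:
            target.overloads[qualified_name].append((func, private))
        # The real decorator returns an OnnxFunction/TracedOnnxFunction instead.
        return func

    return decorator
```

Passing a tuple such as `("aten::add.Scalar", "aten::add.Tensor")` registers the same function under both overloads, which matches the tuple form described in the docstring.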
The `trace_only` option extends what `script()` can handle by supporting complex control flow through the `TracedOnnxFunction` class. Instead of compiling the entire function, control flow included, into an `OnnxFunction`, a `TracedOnnxFunction` simply traces it as a regular Python function. This makes it possible to handle control flow that scripting does not support.
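As a rough illustration of why tracing can handle arbitrary Python logic, the toy below "traces" a layer-norm-like function by calling it with symbolic names. The `len()` call and the `is None` check run once, in Python, at trace time, so only the ops actually executed are recorded and no ONNX `If` node is needed. `ToyGraph` and `traced_layer_norm` are hypothetical; this is not onnxscript's tracer.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    """Records (op_type, inputs) pairs; a stand-in for a real ONNX graph builder."""

    nodes: list[tuple[str, tuple[str, ...]]] = field(default_factory=list)

    def add(self, op_type: str, *inputs: str) -> str:
        self.nodes.append((op_type, inputs))
        return f"{op_type}_out{len(self.nodes)}"  # symbolic name of the node's output


def traced_layer_norm(
    graph: ToyGraph, x: str, normalized_shape: list[int], weight: str | None = None
) -> str:
    start_axis = -len(normalized_shape)  # len() runs in Python while tracing
    if weight is None:  # decided while tracing: no If node is emitted
        weight = graph.add("Expand", "one", x)
    return graph.add(f"LayerNormalization[axis={start_axis}]", x, weight)
```

Tracing with `weight=None` records an `Expand` followed by the normalization; tracing with a real weight records only the normalization. The branch exists in Python, not in the resulting graph.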
- Name a function with the namespace it comes from as a prefix, for example `aten_abs` or `prims_abs`.
- Annotate the inputs and attributes correctly according to `native_functions.yaml`.
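The naming rule above, together with the `private` rule from the `torch_op` docstring, can be expressed as a tiny lint check. This is a hypothetical helper (`naming_problems` is not part of torchlib), sketched only to make the conventions explicit.

```python
from __future__ import annotations


def naming_problems(func_name: str, qualified_name: str, private: bool = False) -> list[str]:
    """Hypothetical lint: return naming problems for a torchlib-style function (empty if none)."""
    problems: list[str] = []
    namespace, sep, _ = qualified_name.partition("::")
    if not sep:
        problems.append("qualified name should look like 'namespace::op', e.g. 'aten::abs'")
    elif not func_name.lstrip("_").startswith(f"{namespace}_"):
        # e.g. "aten::abs" -> "aten_abs", "prims::abs" -> "prims_abs"
        problems.append(f"function name should start with '{namespace}_'")
    if func_name.startswith("_") and not private:
        # torch_op docstring: private should be true for names starting with "_"
        problems.append("functions whose names start with '_' should use private=True")
    return problems
```

For example, `naming_problems("_aten_layer_norm_onnx", "aten::layer_norm", private=True)` returns an empty list.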
Use, or define if necessary, a single TypeVar in `tensor_typing` that corresponds to the ArgType given in `native_functions.yaml`. Typically, inputs are tensor types and attributes are primitive types. However, the right choice is made case by case, because the annotations must satisfy the requirements of the ONNX operators used inside the function.
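The "inputs are tensors, attributes are primitives" rule of thumb can be written down as a small lookup. This is a hypothetical sketch derived only from the annotations in the `aten_gather` and `aten_layer_norm` examples in this guide; `suggest_annotation` is an invented name, the default `"TTensor"` is just a placeholder string, and real annotations are still decided case by case.

```python
from __future__ import annotations

# Primitive ArgTypes usually become ONNX attributes annotated with Python types.
_ATTRIBUTE_TYPES = {"int": "int", "float": "float", "bool": "bool", "str": "str"}


def suggest_annotation(arg_type: str, tensor_typevar: str = "TTensor") -> str:
    """Hypothetical rule of thumb mapping a native_functions.yaml ArgType to an annotation."""
    optional = arg_type.endswith("?")  # "Tensor?" means the argument may be None
    base = arg_type.removesuffix("?")
    if base in _ATTRIBUTE_TYPES:
        annotation = _ATTRIBUTE_TYPES[base]  # attribute
    elif base.startswith("int["):
        annotation = "INT64"  # as for int[] normalized_shape in aten_layer_norm
    elif base.startswith("Tensor") and not base.endswith("[]"):
        annotation = tensor_typevar  # input: pick the matching TypeVar from tensor_typing
    else:
        raise NotImplementedError(f"No rule for ArgType {arg_type!r}; decide case by case")
    return f"Optional[{annotation}]" if optional else annotation
```

With `tensor_typevar="TReal"`, `Tensor? weight` maps to `Optional[TReal]`, matching the `aten_layer_norm` signature.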
When a function is scripted, every computation inside it must be expressed with ONNX operators. An `opset{version}` prefix (for example, `opset18`) indicates which opset an operator comes from. OnnxFunction also partially supports control flow and offers conveniences such as `if` statements, `for` loops, automatic wrapping of constants, and automatic wrapping of basic arithmetic operations.
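To show what "automatic wrapping" produces, the toy below turns `(x + 1) == 0` into `Constant`, `Add`, `Constant` and `Equal` nodes. It only mimics the effect: onnxscript's scripting works by compiling the function's Python source, not by operator overloading as done here. `ToyBuilder` and `Sym` are hypothetical names.

```python
from __future__ import annotations


class ToyBuilder:
    """Collects (op_type, input_names, attributes) triples."""

    def __init__(self) -> None:
        self.nodes: list[tuple[str, tuple[str, ...], dict]] = []

    def emit(self, op_type: str, *inputs: str, **attrs) -> Sym:
        self.nodes.append((op_type, inputs, attrs))
        return Sym(self, f"t{len(self.nodes)}")


class Sym:
    """A symbolic tensor value; Python operators on it emit ONNX-style nodes."""

    def __init__(self, builder: ToyBuilder, name: str) -> None:
        self.builder = builder
        self.name = name

    def _wrap(self, value: object) -> Sym:
        # Python literals are wrapped in a Constant node automatically.
        return value if isinstance(value, Sym) else self.builder.emit("Constant", value=value)

    def __add__(self, other: object) -> Sym:
        return self.builder.emit("Add", self.name, self._wrap(other).name)

    def __eq__(self, other: object) -> Sym:  # type: ignore[override]
        return self.builder.emit("Equal", self.name, self._wrap(other).name)
```

Writing `(x + 1) == 0` for `x = Sym(ToyBuilder(), "x")` records four nodes, in the same spirit as `op.Size(index) == 0` becoming an `Equal` against a constant in the scripted example that follows.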
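```python
# Imports assumed from onnxscript's torch_lib package layout at the time of writing.
from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TInt, TReal
from onnxscript.onnx_opset import opset18 as op


@torch_op("aten::gather")
def aten_gather(
    self: TReal,
    dim: int,
    index: TInt,
    sparse_grad: bool = False,  # pylint: disable=unused-argument
) -> TReal:
    """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""
    # When index is a scalar (rank 0), return self
    if op.Size(op.Shape(index)) == 0:  # A Python `if` is compiled to ONNX control flow
        result = self
    else:
        if op.Size(op.Shape(self)) == 0:  # The literal 0 is automatically wrapped as a constant
            # Reshape a scalar self to 1-D so that GatherElements can consume it
            self = op.Reshape(self, op.Constant(value_ints=[-1]))
        if op.Size(index) == 0:  # `==` is automatically translated to op.Equal()
            # index has no elements: return an empty tensor with self's dtype
            result = op.CastLike(index, self)
        else:
            index = op.Cast(index, to=INT64.dtype)
            result = op.GatherElements(self, index, axis=dim)
    return result
```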
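When checking results by hand, a NumPy reference that mirrors the branches of `aten_gather` can help. This is a sketch under one stated assumption: `np.take_along_axis` matches ONNX `GatherElements` only when `index` has the same size as `self` in every dimension other than `dim` (otherwise NumPy broadcasts). `gather_reference` is a hypothetical name, and its scalar and empty branches reproduce the function above rather than make claims about PyTorch semantics.

```python
import numpy as np


def gather_reference(self: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of the branches in aten_gather."""
    if index.ndim == 0:  # scalar index: aten_gather returns self unchanged
        return self
    if self.ndim == 0:  # scalar self: reshape to 1-D, like op.Reshape(self, [-1])
        self = self.reshape(-1)
    if index.size == 0:  # empty index: empty result with self's dtype (like CastLike)
        return index.astype(self.dtype)
    # Same result as GatherElements(self, index, axis=dim) when the non-`dim` sizes match.
    return np.take_along_axis(self, index.astype(np.int64), axis=dim)
```

For `self = [[1., 2.], [3., 4.]]`, `dim = 1` and `index = [[0, 0], [1, 0]]`, the reference returns `[[1., 1.], [4., 3.]]`.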
A trace-only function (`TracedOnnxFunction`) is essentially a plain Python function that calls OnnxFunctions. It is needed because the restricted syntax of OnnxFunction cannot express some of the more complex logic an operator requires, such as dictionaries, `len()`, and `None` checks.
```python
# Imports assumed from onnxscript's torch_lib package layout at the time of writing.
from typing import Optional

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TReal
from onnxscript.onnx_opset import opset18 as op


@torch_op("aten::layer_norm", trace_only=True)
def aten_layer_norm(
    input: TReal,
    normalized_shape: INT64,
    weight: Optional[TReal] = None,
    bias: Optional[TReal] = None,
    eps: float = 1e-05,
) -> TReal:
    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
    # trace_only so that Python can compute start_axis with len()
    start_axis = -len(normalized_shape)
    if weight is None:  # `is None` checks are not supported in scripted functions
        # CastLike keeps the default weight in the same dtype as input
        one = op.CastLike(op.Constant(value_float=1.0), input)
        weight = op.Expand(one, op.Shape(input, start=start_axis))
    if bias is None:  # `is None` checks are not supported in scripted functions
        zero = op.CastLike(op.Constant(value_float=0.0), input)
        bias = op.Expand(zero, op.Shape(input, start=start_axis))
    # Delegate the actual computation to a private OnnxFunction
    return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps)


@torch_op("aten::layer_norm", private=True)
def _aten_layer_norm_onnx(
    input: TReal,
    weight: TReal,
    bias: TReal,
    axis: int,
    eps: float = 1e-05,
) -> TReal:
    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
    # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
    result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
    return result
```
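A NumPy reference for the same computation can be handy when comparing exported outputs. This sketch assumes the usual layer normalization formula (biased variance over the trailing `len(normalized_shape)` axes, which is what `axis=start_axis` selects) and fills in ones and zeros for a missing weight and bias, as the trace-only wrapper does. `layer_norm_reference` is a hypothetical name.

```python
import numpy as np


def layer_norm_reference(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Hypothetical NumPy reference for aten_layer_norm."""
    start_axis = -len(normalized_shape)  # same start_axis as the trace-only function
    axes = tuple(range(x.ndim + start_axis, x.ndim))  # normalize over trailing dims
    if weight is None:
        weight = np.ones(x.shape[start_axis:], dtype=x.dtype)
    if bias is None:
        bias = np.zeros(x.shape[start_axis:], dtype=x.dtype)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # biased variance (ddof=0)
    return (x - mean) / np.sqrt(var + eps) * weight + bias
```

With the default weight and bias, every normalized slice has mean 0 and variance close to 1; passing a weight and bias scales and shifts that result elementwise.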
To make sure an OnnxFunction or TracedOnnxFunction is implemented correctly, we provide op-level correctness tests.
These tests use PyTorch's OpInfo mechanism to generate test cases for each operator. You can find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804
1. To enable test cases for an operator, add a `TorchLibOpInfo` entry to `TORCH_LIB_OPINFO` in `ops_test_data.py`. Explicitly specify `trace_only` if the op is trace-only. Specify `complex` if the function is designed for complex inputs. The `op_info_name` in `TorchLibOpInfo` must be unique in the `TORCH_LIB_OPINFO` list, but `complex=True` ops can share a name with non-complex ops because they are tested separately.
2. Add `.skip` and/or `.xfail` to skip or xfail tests. Prefer xfail over skip when possible, because xfail lets us monitor the behavior and update the test once it passes.
   2a. If a test now fails because of an unexpected pass (xpass), meaning a previous error has been fixed, remove the corresponding xfail.
3. If the sample inputs of the OpInfo need to be adjusted to fit the ATen signature, create an input wrangler function. See `_mean_input_wrangler` for an example.
4. To test different ONNX functions that are registered as overloads of the same op, use `ops_test_common.duplicate_opinfo` to create new OpInfos with new names and map each one to an overload.
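The uniqueness rule in step 1 and the chained `.xfail`/`.skip` style in step 2 can be sketched with a stand-in class. `ToyOpInfo` and `check_unique_names` are hypothetical; the field names only echo the parameters mentioned above and are not a copy of the real `TorchLibOpInfo`.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyOpInfo:
    """Hypothetical stand-in for a TorchLibOpInfo entry."""

    op_info_name: str
    trace_only: bool = False
    complex: bool = False
    expected_failures: list[str] = field(default_factory=list)
    skipped: list[str] = field(default_factory=list)

    def xfail(self, reason: str) -> ToyOpInfo:
        self.expected_failures.append(reason)
        return self  # returning self allows chaining .xfail(...).skip(...)

    def skip(self, reason: str) -> ToyOpInfo:
        self.skipped.append(reason)
        return self


def check_unique_names(entries: list[ToyOpInfo]) -> None:
    """Raise if two entries share op_info_name, unless one of them is the complex variant."""
    seen: set[tuple[str, bool]] = set()
    for entry in entries:
        key = (entry.op_info_name, entry.complex)  # complex variants are tested separately
        if key in seen:
            raise ValueError(
                f"duplicate op_info_name {entry.op_info_name!r} (complex={entry.complex})"
            )
        seen.add(key)
```

A list containing `ToyOpInfo("gather")` and `ToyOpInfo("gather", complex=True)` passes the check, while a second non-complex `"gather"` entry raises.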
Use https://github.com/microsoft/onnxscript/pull/1260/files and https://github.com/microsoft/onnxscript/pull/1284 as examples for implementing an operator and creating tests for it.