Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add framework for version converter API #1926

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

__all__ = [
# Functions
"convert_version",
"inline",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"inline",

inline is not part of the converter

]

from onnxscript import ir
from onnxscript.optimizer._inliner import inline
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import modules only

Suggested change
from onnxscript.optimizer._inliner import inline
from onnxscript.optimizer import _inliner

from onnxscript.version_converter.version_converter import version_convert
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Perfer calling it convert_version. Prefer using a verb as a function name



def convert_version(model: ir.Model, target_version: int) -> None:
"""Convert the model to the specified ONNX opset version."""

# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
# Hence, we inline all the functions.
inline(model)
version_convert(model, target_version)
310 changes: 310 additions & 0 deletions onnxscript/version_converter/version_converter.py
Copy link
Collaborator

@justinchuby justinchuby Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module can be private. All public functions are exposed in __init__

Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Convert the model to the specified ONNX opset version."""

from __future__ import annotations

import dataclasses
import logging
from typing import Callable, Sequence, Union

import onnxscript.ir._convenience as _convenience
Copy link
Collaborator

@justinchuby justinchuby Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import onnxscript.ir._convenience as _convenience
import onnxscript.ir.convenience as ir_convenience

conventional naming; import the public module

import onnxscript.rewriter.pattern as orp
from onnxscript import ir

logger = logging.getLogger(__name__)


CURRENT_MAX_ONNX_OPSET = 23


@dataclasses.dataclass
class Replacement:
"""A replacement for a node in the graph."""

new_outputs: Sequence[ir.Value]
new_nodes: Sequence[ir.Node]


# A version-adapter function takes a node, a RewriterContext and returns
# a Replacement for the node or None (if no replacement is needed).

ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would a function return a Replacement over a Value?

AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]


@dataclasses.dataclass
class VersionAdapter:
"""A class that represents a version checker for a particular op.
It is applicable for a specific version upgrade (orignal_version -> original_version + 1)
or downgrade (orignal_version -> original_version - 1)of the op.
"""

node_version: int
up_conversion: bool
function: AdapterFunction


class AdapterRegistry:
"""A class that maintains a registry of adapters for ops."""

def __init__(self):
self.op_adapters: dict[tuple[str, str, int, bool], VersionAdapter] = {}

def lookup_adapters(
self,
domain: str,
opname: str,
original_version: int,
up_conversion: bool = True,
) -> Union[AdapterFunction, None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> Union[AdapterFunction, None]:
) -> AdapterFunction | None:

adapter = self.op_adapters.get((domain, opname, original_version, up_conversion), None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion), None)
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion))

default is None

if adapter is not None:
return adapter.function
Fixed Show fixed Hide fixed
else:
return None
Comment on lines +65 to +66
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
else:
return None
return None


def register(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a docstring?

justinchuby marked this conversation as resolved.
Show resolved Hide resolved
self, opname: str, domain: str = "", node_version=None, up_conversion=True
) -> Callable[[AdapterFunction], AdapterFunction]:
def decorator(function: AdapterFunction) -> AdapterFunction:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decorate with functools.wraps so the typing information is preserved

self.op_adapters[(domain, opname, node_version, up_conversion)] = VersionAdapter(
node_version, up_conversion, function
)
return function

return decorator


registry: AdapterRegistry = AdapterRegistry()

register = registry.register


def _get_input(node: ir.Node, index: int) -> ir.Value | None:
if index < len(node.inputs):
return node.inputs[index]
return None

Check warning on line 88 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L88

Added line #L88 was not covered by tests


def _get_output(node: ir.Node, index: int) -> ir.Value | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this used?

if index < len(node.outputs):
return node.outputs[index]
return None

Check warning on line 94 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L93-L94

Added lines #L93 - L94 were not covered by tests


def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
if name in node.attributes:
attr = node.attributes[name]
if not isinstance(attr, ir.Attr):
return None

Check warning on line 101 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L101

Added line #L101 was not covered by tests
attr_val = attr.value
if isinstance(attr_val, int):
return attr_val
# This is an invalid model: attribute has invalid/unexpected type.
# For now, we just return None. We could raise an error too.
return None

Check warning on line 107 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L107

Added line #L107 was not covered by tests
return default


def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> str | None:
if name in node.attributes:
attr = node.attributes[name]
if not isinstance(attr, ir.Attr):
return None

Check warning on line 115 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L115

Added line #L115 was not covered by tests
attr_val = attr.value
if isinstance(attr_val, str):
return attr_val
# This is an invalid model: attribute has invalid/unexpected type.
# For now, we just return None. We could raise an error too.
return None

Check warning on line 121 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L121

Added line #L121 was not covered by tests
return default


## Op-specific adapters

# Opset 19 -> 20


@register("DFT", node_version=19, up_conversion=True)
def dft_19_20(node: ir.Node, op):
input = node.inputs[0]
inverse = _get_int_attribute(node, "inverse", 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a complication here, relating to handling of functions. In functions, we can have attribute-parameters, which mean we don't know the value of the attribute. The version-conversion, as written, will be wrong in such cases.

For now, I think it is okay to inline all functions (see relevant comment below) and ignore this issue.

Copy link
Collaborator

@justinchuby justinchuby Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to just promote the attribute to input and use them in the upgraded op?

onesided = _get_int_attribute(node, "onesided", 0)
axis = _get_int_attribute(node, "axis", None)
if axis is not None:
axis_value = op.Constant(value_int=axis)
return op.DFT(input, axis_value, inverse=inverse, onesided=onesided)
return None


@register("GridSample", node_version=19, up_conversion=True)
def gridsample_19_20(node: ir.Node, op):
x = node.inputs[0]
grid = node.inputs[1]
align_corners = _get_int_attribute(node, "align_corners", 0)
mode = _get_str_attribute(node, "mode", "linear")
padding_mode = _get_str_attribute(node, "padding_mode", "zeros")
if mode == "bilinear":
return op.GridSample(
x, grid, align_corners=align_corners, mode="linear", padding_mode=padding_mode
)
elif mode == "bicubic":
return op.GridSample(
x, grid, align_corners=align_corners, mode="cubic", padding_mode=padding_mode
)
return None


# Opset 20 -> 21


@register("GroupNormalization", node_version=20, up_conversion=True)
def groupnormalization_20_21(node: ir.Node, op):
x = _get_input(node, 0)
scale = _get_input(node, 1)
bias = _get_input(node, 2)
if x is None or scale is None or bias is None:
return None

Check warning on line 169 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L169

Added line #L169 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some branches are missing coverage


x_shape = x.shape
if x_shape is None:
return None

Check warning on line 173 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L173

Added line #L173 was not covered by tests
num_channels = x_shape[1]
if not isinstance(num_channels, int):
return None

Check warning on line 176 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L176

Added line #L176 was not covered by tests

scale_shape = scale.shape
bias_shape = bias.shape
if scale_shape is None or bias_shape is None:
return None
if not isinstance(scale_shape[0], int) or not isinstance(bias_shape[0], int):
return None

Check warning on line 183 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L183

Added line #L183 was not covered by tests

num_groups = _get_int_attribute(node, "num_groups", None)
if num_groups is None:
return None

Check warning on line 187 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L187

Added line #L187 was not covered by tests
if (
num_groups != num_channels
and num_groups == scale_shape[0]
and num_groups == bias_shape[0]
):
reshape_1_sizes = op.Constant(value_ints=[-1, 1])
reshape_2_sizes = op.Constant(value_ints=[-1])
c_div = int(num_channels / num_groups)
expand_sizes = op.Constant(value_ints=[1, c_div])

# Modify scale input
scale_reshape_1 = op.Reshape(scale, reshape_1_sizes)
scale_expand = op.Expand(scale_reshape_1, expand_sizes)
scale_reshape_2 = op.Reshape(scale_expand, reshape_2_sizes)

# Modify bias input
bias_reshape_1 = op.Reshape(bias, reshape_1_sizes)
bias_expand = op.Expand(bias_reshape_1, expand_sizes)
bias_reshape_2 = op.Reshape(bias_expand, reshape_2_sizes)

return op.GroupNormalization(x, scale_reshape_2, bias_reshape_2, num_groups=num_groups)
return None

Check warning on line 209 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L209

Added line #L209 was not covered by tests


class _VersionConverter:
opset_imports: dict[str, int]

def __init__(self, target_version: int):
self.target_version = target_version

def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.

Check notice

Code scanning / lintrunner

PYLINT/R1710 Note

Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements)
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
Copy link
Collaborator

@justinchuby justinchuby Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True):
def process_node(self, node: ir.Node, opset_version: int, up_conversion: bool = True) -> Replacement | None:

if node.domain not in self.opset_imports:
return None

Check warning on line 220 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L220

Added line #L220 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create a test to cover this line?

adapter = registry.lookup_adapters(
node.domain, node.op_type, opset_version, up_conversion
)
if adapter is None:
return None
context = orp.RewriterContext()
output = adapter(node, context)
Fixed Show fixed Hide fixed
if output is not None:
if isinstance(output, Replacement):
return output

Check warning on line 230 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L230

Added line #L230 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create a test to cover this line?

if isinstance(output, ir.Value):
output = [output]
return Replacement(output, context.nodes)

def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:

logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

_convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)

def visit_attribute(
self, attr: ir.Attr | ir.RefAttr, opset_version: int, up_conversion: bool = True
) -> None:
if isinstance(attr, ir.Attr):
if attr.type == ir.AttributeType.GRAPH:
self.visit_graph(attr.value, opset_version, up_conversion) # type: ignore[arg-type]

Check warning on line 247 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L247

Added line #L247 was not covered by tests
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.value:
self.visit_graph(graph, opset_version, up_conversion) # type: ignore[arg-type]

Check warning on line 250 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L250

Added line #L250 was not covered by tests

def visit_node(

Check notice

Code scanning / lintrunner

PYLINT/R1710 Note

Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements)
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
self,
node: ir.Node,
root: ir.Graph | ir.Function,
opset_version: int,
up_conversion: bool = True,
):
replacement = self.process_node(node, opset_version, up_conversion)
if replacement is None:
# No change. Process attributes.
for attr in node.attributes.values():
self.visit_attribute(attr, opset_version, up_conversion)
return None
else:
self.replace_node(node, replacement, root)

def visit_graph(
self, graph: ir.Graph, opset_version: int, up_conversion: bool = True
) -> None:
for node in graph:
self.visit_node(node, graph, opset_version, up_conversion)
node.version = self.target_version

def visit_model(self, model: ir.Model) -> None:

Check notice

Code scanning / lintrunner

PYLINT/R1710 Note

Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements)
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
self.opset_imports = model.opset_imports
model_version = model.opset_imports.get("")
if model_version is None:
return None

Check warning on line 279 in onnxscript/version_converter/version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/version_converter.py#L279

Added line #L279 was not covered by tests

up_conversion = True
if self.target_version < model_version:
up_conversion = False
Fixed Show fixed Hide fixed
# TODO (shubhambhokare1) : Remove once down-conversion adapters are supoorted
if up_conversion is False:
logger.warning(
"Target opset: %s less than %s, downstream version conversion not currently handled.",
self.target_version,
model_version,
)
return None
# Iterate from current model version -> target version
# Updating each node based on the correct adapter
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
for opset_version in range(model_version, self.target_version):
if up_conversion is True and opset_version == CURRENT_MAX_ONNX_OPSET:
logger.warning(
"Conversion from opset: %s to target opset: %s not currently supported.",
opset_version,
opset_version + 1,
)
return None

self.visit_graph(model.graph, opset_version, up_conversion)


def version_convert(model: ir.Model, target_version: int) -> None:
"""Convert the model to the specified ONNX opset version."""
version_converter = _VersionConverter(target_version=target_version)
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
version_converter.visit_model(model)
Loading
Loading