-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add framework for version converter API #1926
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
❌ 12 Tests Failed:
View the full list of 3 ❄️ flaky tests
To view more test analytics, go to the Test Analytics Dashboard |
One of the main design question relates to the "adapter" signature: what form should it take? Essentially it is a function that takes a single node as a parameter, and modifies it in some form. The changes are typically a simple mutation of a node along with potentially other changes (such as the insertion of extra nodes). For now, I think it might be fine to follow the pattern used in the optimizer and rewriter, which are based on node-transformers that, given an input node, return a sequence of replacement nodes or None (if no replacement is required). This allows a simple loop over all nodes in the graph that transforms each node in sequence. This can be generalized later if necessary. |
self.custom_adapters = custom_adapter_list | ||
|
||
def graph_version_convert(self, graph: ir.Graph, target_version: int) -> None: | ||
if self.model_version == target_version: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Extension) I think we will need to soon support the case where the incoming model has nodes with different opset versions. At that point, such checks should happen at the node level, not at the model level.
Questions that are related to the design doc that comes to mind
|
An op can stay the same for many versions. For example, Acosh-9 is the same all the way to 21. Would it make sense to mark the adapter for it to support the "base to next" version (e.g. We can also identify if an adapter is usable for a given node with opset version without having to consult with the ONNX defs to know if Acosh-9 is the same as Acosh-21, because from the |
@register("DFT", node_version=19, upgrade_version=20) | ||
def dft_19_20(node: ir.Node, op): | ||
input = node.inputs[0] | ||
inverse = _get_int_attribute(node, "inverse", 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a complication here, relating to handling of functions. In functions, we can have attribute-parameters, which mean we don't know the value of the attribute. The version-conversion, as written, will be wrong in such cases.
For now, I think it is okay to inline all functions (see relevant comment below) and ignore this issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to just promote the attribute to input and use them in the upgraded op?
def __init__(self, target_version: int): | ||
self.target_version = target_version | ||
|
||
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True): |
Check notice
Code scanning / CodeQL
Explicit returns mixed with implicit (fall through) returns Note
def __init__(self, target_version: int): | ||
self.target_version = target_version | ||
|
||
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True): |
Check notice
Code scanning / lintrunner
PYLINT/R1710 Note
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
for graph in attr.value: | ||
self.visit_graph(graph, opset_version, up_conversion) # type: ignore[arg-type] | ||
|
||
def visit_node( |
Check notice
Code scanning / lintrunner
PYLINT/R1710 Note
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
self.visit_node(node, graph, opset_version, up_conversion) | ||
node.version = self.target_version | ||
|
||
def visit_model(self, model: ir.Model) -> None: |
Check notice
Code scanning / lintrunner
PYLINT/R1710 Note
See inconsistent-return-statements. To disable, use # pylint: disable=inconsistent-return-statements
__all__ = [ | ||
# Functions | ||
"convert_version", | ||
"inline", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"inline", |
inline is not part of the converter
] | ||
|
||
from onnxscript import ir | ||
from onnxscript.optimizer._inliner import inline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import modules only
from onnxscript.optimizer._inliner import inline | |
from onnxscript.optimizer import _inliner |
|
||
from onnxscript import ir | ||
from onnxscript.optimizer._inliner import inline | ||
from onnxscript.version_converter.version_converter import version_convert |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Perfer calling it convert_version. Prefer using a verb as a function name
import logging | ||
from typing import Callable, Sequence, Union | ||
|
||
import onnxscript.ir._convenience as _convenience |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import onnxscript.ir._convenience as _convenience | |
import onnxscript.ir.convenience as ir_convenience |
conventional naming; import the public module
# A version-adapter function takes a node, a RewriterContext and returns | ||
# a Replacement for the node or None (if no replacement is needed). | ||
|
||
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When would a function return a Replacement over a Value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This module can be private. All public functions are exposed in __init__
opname: str, | ||
original_version: int, | ||
up_conversion: bool = True, | ||
) -> Union[AdapterFunction, None]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
) -> Union[AdapterFunction, None]: | |
) -> AdapterFunction | None: |
original_version: int, | ||
up_conversion: bool = True, | ||
) -> Union[AdapterFunction, None]: | ||
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion), None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion), None) | |
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion)) |
default is None
else: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
else: | |
return None | |
return None |
else: | ||
return None | ||
|
||
def register( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a docstring?
def register( | ||
self, opname: str, domain: str = "", node_version=None, up_conversion=True | ||
) -> Callable[[AdapterFunction], AdapterFunction]: | ||
def decorator(function: AdapterFunction) -> AdapterFunction: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Decorate with functools.wraps so the typing information is preserved
return None | ||
|
||
|
||
def _get_output(node: ir.Node, index: int) -> ir.Value | None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this used?
def __init__(self, target_version: int): | ||
self.target_version = target_version | ||
|
||
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True): | |
def process_node(self, node: ir.Node, opset_version: int, up_conversion: bool = True) -> Replacement | None: |
|
||
def process_node(self, node: ir.Node, opset_version, up_conversion: bool = True): | ||
if node.domain not in self.opset_imports: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Create a test to cover this line?
output = adapter(node, context) | ||
if output is not None: | ||
if isinstance(output, Replacement): | ||
return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Create a test to cover this line?
output = [output] | ||
return Replacement(output, context.nodes) | ||
|
||
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): | |
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: |
scale = _get_input(node, 1) | ||
bias = _get_input(node, 2) | ||
if x is None or scale is None or bias is None: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some branches are missing coverage
No description provided.