Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API] Create stable APIs for PyTorch 2.5 #1832

Merged
merged 16 commits into from
Aug 30, 2024
3 changes: 3 additions & 0 deletions onnxscript/_framework_apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Semi-private stable APIs for framework-specific usage only."""
Fixed Show fixed Hide fixed
160 changes: 160 additions & 0 deletions onnxscript/_framework_apis/torch_2_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.5."""
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved

from __future__ import annotations

__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]

import dataclasses
Fixed Show fixed Hide fixed
import os
import pathlib
Fixed Show fixed Hide fixed
from typing import Callable

import onnx

from onnxscript import ir
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data

# Internal flag. Will go away.
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") == "1"
)


@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
"""A wrapper of onnx-script function with additional metadata.

qualified_name: The qualified name of the aten operator.
function: The onnx-script function.
domain: The domain of the function.
name: The name of the function.
is_complex: Whether the function is a complex function.
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
"""

qualified_name: str
function: Callable
domain: str
name: str
is_complex: bool = False


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""

# TODO(justinchuby): Use the optimizer
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
return model

Check warning on line 54 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L54

Added line #L54 was not covered by tests


def convert_version(model: ir.Model, target_version: int) -> ir.Model:
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
Fixed Show fixed Hide fixed
"""Convert the model to the specified ONNX opset version."""
# model_version = model.opset_import.get("")
# if model_version == target_version:
# # No conversion needed
# return model

# # FIXME(justinchuby): version_converter does not support functions
# proto = ir.serde.serialize_model(model)
# proto = onnx.version_converter.convert_version(proto, target_version)
# return ir.serde.deserialize_model(proto)
# TODO(justinchuby): This function needs to be carefully implemented
# to handle large models. For now, we just return the model.
del target_version # Unused
return model

Check warning on line 71 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L70-L71

Added lines #L70 - L71 were not covered by tests


def check_model(model: ir.Model) -> None:
Fixed Show fixed Hide fixed
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
"""Check the model."""

del model # Unused yet

Check warning on line 77 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L77

Added line #L77 was not covered by tests

Check warning

Code scanning / CodeQL

Unnecessary delete statement in function Warning

Unnecessary deletion of local variable
model
in function
check_model
.


def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
"""Save the model with external data. The model is unchanged after saving."""
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

# TODO(#1835): Decide if we want to externalize large attributes as well
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
initializer_values = tuple(model.graph.initializers.values())

Check warning on line 85 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L85

Added line #L85 was not covered by tests
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(

Check warning on line 89 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L89

Added line #L89 was not covered by tests
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

Check warning on line 95 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L93-L95

Added lines #L93 - L95 were not covered by tests

external_tensors = _external_data.convert_tensors_to_external(

Check warning on line 97 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L97

Added line #L97 was not covered by tests
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

Check warning on line 106 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L105-L106

Added lines #L105 - L106 were not covered by tests

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor

Check warning on line 110 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L110

Added line #L110 was not covered by tests

else:
destination_path = pathlib.Path(model_path)

Check warning on line 113 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L113

Added line #L113 was not covered by tests
# Create the directory if it does not exist
data_path = f"{destination_path.name}.data"
proto = ir.serde.serialize_model(model)
onnx.save_model(

Check warning on line 117 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L115-L117

Added lines #L115 - L117 were not covered by tests
proto,
model_path,
save_as_external_data=True,
location=data_path,
)


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
# Trigger op registration
from onnxscript.function_libs.torch_lib import ( # pylint: disable=import-outside-toplevel

Check warning on line 127 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L127

Added line #L127 was not covered by tests
ops,
)

del ops # Unused

Check warning on line 131 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L131

Added line #L131 was not covered by tests

torchlib_registry = registration.default_registry
function_metas = []

Check warning on line 134 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L133-L134

Added lines #L133 - L134 were not covered by tests

for qualified_name, aten_overloads_func in torchlib_registry.items():
if qualified_name.startswith("internal::"):
# Skip the custom defined internal functions
continue

Check warning on line 139 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L139

Added line #L139 was not covered by tests

for overload_func in aten_overloads_func.overloads:
function_meta = _OnnxFunctionMeta(

Check warning on line 142 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L142

Added line #L142 was not covered by tests
qualified_name=qualified_name,
function=overload_func,
domain=overload_func.function_ir.domain,
name=overload_func.name,
is_complex=False,
)
function_metas.append(function_meta)

Check warning on line 149 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L149

Added line #L149 was not covered by tests
for complex_func in aten_overloads_func.complex:
function_meta = _OnnxFunctionMeta(

Check warning on line 151 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L151

Added line #L151 was not covered by tests
qualified_name=qualified_name,
function=complex_func,
domain=complex_func.function_ir.domain,
name=complex_func.name,
is_complex=True,
)
function_metas.append(function_meta)

Check warning on line 158 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L158

Added line #L158 was not covered by tests

return function_metas

Check warning on line 160 in onnxscript/_framework_apis/torch_2_5.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_5.py#L160

Added line #L160 was not covered by tests
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ ignore-init-module-imports = true
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["TID252"] # Allow relative imports in init files
"setup.py" = ["TID251"] # pathlib is allowed in supporting code
"**/{examples,tests,docs,tools,utils,opgen}/*" = ["TID251"] # pathlib is allowed in supporting code
"**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code
"**/*_test.py" = ["TID251"] # pathlib is allowed in tests

[tool.ruff.lint.flake8-tidy-imports]
Expand Down
Loading