-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[IR] Implement save/load functions in IR and handle external data pro…
…perly (#1801) Implement efficient save/load and handle loading external data properly in the IR. Before this change, when a ModelProto containing external data is converted to IR, the external tensor objects will load the data from a path relative to the working directory, not the ONNX file. This is because we do not store the onnx file path and thus have no way to look for the external data file. With the change, a `base_dir` property is added to ExternalTensor that we can set, in a separate pass when the directory is available, so the object has full information to find the data file on disk. The base_dir is not serialized to the proto to maintain a relative path in the "location" field in TensorProto. #1701, #1792 Example: ``` >>> m.graph.initializers["model.model.decoder.layers.2.encoder_attn.v_proj.weight"].const_value.display() ExternalTensor<FLOAT,[512,512]>(path='model.onnx.data', name='model.model.decoder.layers.2.encoder_attn.v_proj.weight', offset=245864448, length=1048576, base_dir='/home/justinchu/dev/ONNXConverter/docker/dump_bash_bench/BlenderbotSmallForConditionalGeneration-torch -onnx-detailed-cpu-') Min: -0.08586505800485611, Max: 0.09103105217218399, NaN count: 0, Inf count: 0 Sparsity (abs<1e-06): 0.00 Histogram: 11504 ┼ 10226 ┤ ╭───────╮ 8948 ┤ ╭─╯ ╰─╮ 7670 ┤ ╭─╯ ╰─╮ 6392 ┤ ╭─╯ ╰─╮ 5113 ┤ ╭─╯ ╰─╮ 3835 ┤ ╭─╯ ╰─╮ 2557 ┤ ╭──╯ ╰─╮ 1279 ┤ ╭────╯ ╰────╮ 1 ┼────────────────╯ ╰─────────────────── -0.0859 -0.0682 -0.0505 -0.0306 -0.0129 0.0070 0.0225 0.0402 0.0557 0.0733 0.0910 ```
- Loading branch information
1 parent
87aee66
commit 87d7c4f
Showing
7 changed files
with
216 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""External data related utilities.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ["set_base_dir"] | ||
|
||
import os | ||
from typing import Iterator | ||
|
||
from onnxscript.ir import _core, _enums, _protocols, traversal | ||
|
||
|
||
def _all_tensors( | ||
graph: _core.Graph | _core.GraphView, include_attributes: bool = False | ||
) -> Iterator[_protocols.TensorProtocol]: | ||
"""Iterate over all tensors in the graph. | ||
Args: | ||
graph: The graph to traverse tensors on. | ||
include_attributes: Whether to include tensors in attributes. | ||
Yields: | ||
Tensors in the graph. | ||
""" | ||
# Yield all tensors in initializers | ||
for value in graph.initializers.values(): | ||
if value.const_value is not None: | ||
yield value.const_value | ||
if not include_attributes: | ||
return | ||
# Look at constant attributes in nodes | ||
for node in traversal.RecursiveGraphIterator(graph): | ||
for attr in node.attributes.values(): | ||
if isinstance(attr, _core.RefAttr): | ||
continue | ||
if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: | ||
yield attr.value | ||
elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: | ||
yield from attr.value | ||
|
||
|
||
def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: | ||
"""Set the base directory for external data in a graph. | ||
Args: | ||
graph: The graph to traverse tensors on. | ||
base_dir: The base directory. This is the directory where the ONNX file is. | ||
""" | ||
for tensor in _all_tensors(graph, include_attributes=True): | ||
if isinstance(tensor, _core.ExternalTensor): | ||
tensor.base_dir = base_dir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
import unittest | ||
|
||
import onnx | ||
import onnx.external_data_helper | ||
|
||
from onnxscript import ir | ||
from onnxscript.ir import _external_data | ||
|
||
|
||
class ExternalDataTest(unittest.TestCase): | ||
def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): | ||
attr_tensor = onnx.helper.make_tensor( | ||
name="test_constant", | ||
data_type=onnx.TensorProto.FLOAT, | ||
dims=[1], | ||
vals=b"\x01\x00\x00\x00", | ||
raw=True, | ||
) | ||
graph = onnx.helper.make_graph( | ||
nodes=[ | ||
onnx.helper.make_node( | ||
"Constant", | ||
[], | ||
["test"], | ||
value=attr_tensor, | ||
) | ||
], | ||
name="test", | ||
inputs=[], | ||
outputs=[], | ||
initializer=[ | ||
onnx.helper.make_tensor( | ||
name="test_tensor", | ||
data_type=onnx.TensorProto.FLOAT, | ||
dims=[1], | ||
vals=b"\x01\x00\x00\x00", | ||
raw=True, | ||
), | ||
], | ||
) | ||
model_proto = onnx.helper.make_model(graph) | ||
onnx.external_data_helper.convert_model_to_external_data( | ||
model_proto, location="tempdir", size_threshold=0, convert_attribute=True | ||
) | ||
model = ir.serde.deserialize_model(model_proto) | ||
expected_dir = "something_else" | ||
_external_data.set_base_dir(model.graph, expected_dir) | ||
|
||
initializer_tensor = model.graph.initializers["test_tensor"].const_value | ||
assert isinstance(initializer_tensor, ir.ExternalTensor) | ||
self.assertEqual(initializer_tensor.base_dir, expected_dir) | ||
attr_tensor = model.graph.node(0).attributes["value"].value | ||
self.assertEqual(attr_tensor.base_dir, expected_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Load and save ONNX models.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ["load", "save"] | ||
|
||
import os | ||
|
||
import onnx | ||
|
||
from onnxscript.ir import _core, _external_data, serde | ||
|
||
|
||
def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: | ||
"""Load an ONNX model from a file. | ||
Args: | ||
path: The path to the ONNX file. | ||
format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
If None, the format is inferred from the file extension. | ||
Returns: | ||
The loaded model. | ||
""" | ||
# Do not use ONNX to load external data because the IR handles external data | ||
# by doing memory mapping directly. | ||
proto = onnx.load(path, format=format, load_external_data=False) | ||
model = serde.deserialize_model(proto) | ||
base_dir = os.path.dirname(path) | ||
# Set the base directory for external data to the directory of the ONNX file | ||
# so that relative paths are resolved correctly. | ||
_external_data.set_base_dir(model.graph, base_dir) | ||
return model | ||
|
||
|
||
def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: | ||
"""Save an ONNX model to a file. | ||
Args: | ||
model: The model to save. | ||
path: The path to save the model to. | ||
format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
If None, the format is inferred from the file extension. | ||
""" | ||
proto = serde.serialize_model(model) | ||
onnx.save(proto, path, format=format) | ||
# TODO(justinchuby): Handle external data when the relative path has changed | ||
# TODO(justinchuby): Handle off loading external data to disk when saving |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters