Skip to content

Commit

Permalink
[IR] Implement save/load functions in IR and handle external data pro…
Browse files Browse the repository at this point in the history
…perly (#1801)

Implement efficient save/load and handle loading external data properly
in the IR.

Before this change, when a ModelProto containing external data was
converted to IR, the external tensor objects would load the data from a
path relative to the working directory rather than relative to the ONNX
file. This was because the path of the ONNX file was not stored, so there
was no way to locate the external data file.

With the change, a `base_dir` property is added to ExternalTensor that
we can set, in a separate pass when the directory is available, so the
object has full information to find the data file on disk. The base_dir
is not serialized to the proto to maintain a relative path in the
"location" field in TensorProto.

#1701,
#1792

Example:

```
>>> m.graph.initializers["model.model.decoder.layers.2.encoder_attn.v_proj.weight"].const_value.display()
ExternalTensor<FLOAT,[512,512]>(path='model.onnx.data', 
name='model.model.decoder.layers.2.encoder_attn.v_proj.weight', offset=245864448, length=1048576, 
base_dir='/home/justinchu/dev/ONNXConverter/docker/dump_bash_bench/BlenderbotSmallForConditionalGeneration-torch
-onnx-detailed-cpu-')

Min: -0.08586505800485611, Max: 0.09103105217218399, NaN count: 0, Inf count: 0
Sparsity (abs<1e-06): 0.00
Histogram:
   11504 ┼
   10226 ┤                                  ╭───────╮
    8948 ┤                                ╭─╯       ╰─╮
    7670 ┤                              ╭─╯           ╰─╮
    6392 ┤                            ╭─╯               ╰─╮
    5113 ┤                          ╭─╯                   ╰─╮
    3835 ┤                        ╭─╯                       ╰─╮
    2557 ┤                     ╭──╯                           ╰─╮
    1279 ┤                ╭────╯                                ╰────╮
       1 ┼────────────────╯                                          ╰───────────────────
    -0.0859  -0.0682  -0.0505  -0.0306  -0.0129  0.0070  0.0225  0.0402  0.0557  0.0733  0.0910
```
  • Loading branch information
justinchuby authored Aug 13, 2024
1 parent 87aee66 commit 87d7c4f
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 9 deletions.
4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import passes, serde, traversal
Expand Down Expand Up @@ -114,6 +117,7 @@
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
ArrayCompatible,
AttributeProtocol,
Expand Down
30 changes: 26 additions & 4 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
Attributes:
path: The path to the data file. This can be a relative path or an absolute path.
base_dir: The base directory for the external data. It is used to resolve relative paths.
At serialization, only the ``path`` is serialized into the "location" field of the TensorProto.
offset: The offset in bytes from the start of the file.
length: The length of the data in bytes.
dtype: The data type of the tensor.
Expand Down Expand Up @@ -509,8 +511,15 @@ def __init__(
name: str,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
base_dir: os.PathLike | str = "",
) -> None:
self._path = path
if os.path.isabs(path):
self._base_dir = os.path.dirname(path)
self._path = os.path.basename(path)
else:
self._base_dir = base_dir
self._path = path

self._offset: int | None = offset
self._length: int | None = length
self._dtype: _enums.DataType = dtype
Expand All @@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike:
# Immutable
return self._path

@property
def base_dir(self) -> str | os.PathLike:
    """The base directory that a relative ``path`` is joined onto when the data file is opened."""
    # Mutable
    return self._base_dir

@base_dir.setter
def base_dir(self, value: str | os.PathLike) -> None:
    """Replace the base directory used to locate the external data file."""
    self._base_dir = value

@property
def offset(self) -> int | None:
# Immutable
Expand Down Expand Up @@ -556,7 +574,8 @@ def _load(self):
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
file_path = os.path.join(self._base_dir, self._path)
with open(file_path, "rb") as f:
self.raw = mmap.mmap(
f.fileno(),
0,
Expand Down Expand Up @@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]:
)

def __repr__(self) -> str:
return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
return (
f"{self._repr_base()}(path='{self._path}', name={self.name!r}, "
f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})"
)

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
Expand Down Expand Up @@ -2069,7 +2091,7 @@ def __init__(
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.TensorProtocol] = (),
initializers: Sequence[_protocols.ValueProtocol] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
Expand Down
17 changes: 17 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def test_initialize(self):
# Ensure repeated reads are consistent
np.testing.assert_equal(tensor, self.data)

def test_initialize_with_relative_path(self):
    """An ExternalTensor given a relative path plus ``base_dir`` reads the expected data."""
    proto_tensor = self.model.graph.initializer[0]
    info = onnx.external_data_helper.ExternalDataInfo(proto_tensor)
    tensor = _core.ExternalTensor(
        path=info.location,
        offset=info.offset,
        length=info.length,
        dtype=ir.DataType.FLOAT,
        name="input",
        shape=_core.Shape(proto_tensor.dims),
        base_dir=pathlib.Path(self.base_path),
    )
    self.assertEqual(tensor.dtype, ir.DataType.FLOAT)
    # Compare twice so that repeated reads are checked for consistency.
    for _ in range(2):
        np.testing.assert_equal(tensor, self.data)

def test_totypes_returns_correct_data_in(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
Expand Down
53 changes: 53 additions & 0 deletions onnxscript/ir/_external_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""External data related utilities."""

from __future__ import annotations

__all__ = ["set_base_dir"]

import os
from typing import Iterator

from onnxscript.ir import _core, _enums, _protocols, traversal


def _all_tensors(
    graph: _core.Graph | _core.GraphView, include_attributes: bool = False
) -> Iterator[_protocols.TensorProtocol]:
    """Yield the tensors found in ``graph``.

    Initializer tensors are produced first. Tensors held by node attributes
    are produced afterwards, and only when requested.

    Args:
        graph: The graph to traverse tensors on.
        include_attributes: Whether to also yield tensors stored in ``TENSOR``
            and ``TENSORS`` node attributes.

    Yields:
        Tensors in the graph.
    """
    for initializer in graph.initializers.values():
        if initializer.const_value is not None:
            yield initializer.const_value

    if include_attributes:
        # Nodes are visited through the recursive iterator.
        for node in traversal.RecursiveGraphIterator(graph):
            for attribute in node.attributes.values():
                # Reference attributes are skipped.
                if isinstance(attribute, _core.RefAttr):
                    continue
                kind = attribute.type
                if kind == _enums.AttributeType.TENSOR:
                    if attribute.value is not None:
                        yield attribute.value
                elif kind == _enums.AttributeType.TENSORS:
                    if attribute.value is not None:
                        yield from attribute.value


def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None:
    """Assign ``base_dir`` to every external tensor found in ``graph``.

    Both initializer tensors and tensors stored in node attributes are
    considered; tensors that are not ``ExternalTensor`` are left untouched.

    Args:
        graph: The graph to traverse tensors on.
        base_dir: The base directory. This is the directory where the ONNX file is.
    """
    external_tensors = (
        candidate
        for candidate in _all_tensors(graph, include_attributes=True)
        if isinstance(candidate, _core.ExternalTensor)
    )
    for external in external_tensors:
        external.base_dir = base_dir
59 changes: 59 additions & 0 deletions onnxscript/ir/_external_data_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import onnx
import onnx.external_data_helper

from onnxscript import ir
from onnxscript.ir import _external_data


class ExternalDataTest(unittest.TestCase):
    """Tests for the external data utilities."""

    def test_set_base_dir_sets_base_dir_for_all_external_tensors(self):
        """``set_base_dir`` updates both initializer and attribute tensors."""

        def make_float_tensor(tensor_name):
            # One-element FLOAT tensor stored as raw bytes.
            return onnx.helper.make_tensor(
                name=tensor_name,
                data_type=onnx.TensorProto.FLOAT,
                dims=[1],
                vals=b"\x01\x00\x00\x00",
                raw=True,
            )

        constant_node = onnx.helper.make_node(
            "Constant",
            [],
            ["test"],
            value=make_float_tensor("test_constant"),
        )
        graph_proto = onnx.helper.make_graph(
            nodes=[constant_node],
            name="test",
            inputs=[],
            outputs=[],
            initializer=[make_float_tensor("test_tensor")],
        )
        model_proto = onnx.helper.make_model(graph_proto)
        # Convert every tensor, attribute tensors included, to external data.
        onnx.external_data_helper.convert_model_to_external_data(
            model_proto, location="tempdir", size_threshold=0, convert_attribute=True
        )
        model = ir.serde.deserialize_model(model_proto)

        expected_dir = "something_else"
        _external_data.set_base_dir(model.graph, expected_dir)

        from_initializer = model.graph.initializers["test_tensor"].const_value
        assert isinstance(from_initializer, ir.ExternalTensor)
        self.assertEqual(from_initializer.base_dir, expected_dir)

        from_attribute = model.graph.node(0).attributes["value"].value
        self.assertEqual(from_attribute.base_dir, expected_dir)


if __name__ == "__main__":
unittest.main()
50 changes: 50 additions & 0 deletions onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load and save ONNX models."""

from __future__ import annotations

__all__ = ["load", "save"]

import os

import onnx

from onnxscript.ir import _core, _external_data, serde


def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
    """Read an ONNX file from disk and return it as an IR model.

    Args:
        path: The path to the ONNX file.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.

    Returns:
        The loaded model.
    """
    # External data is deliberately not loaded by ONNX here: the IR handles
    # external data itself by memory mapping the data file directly.
    model_proto = onnx.load(path, format=format, load_external_data=False)
    ir_model = serde.deserialize_model(model_proto)
    # Point external tensors at the directory of the ONNX file so that
    # relative paths are resolved correctly.
    _external_data.set_base_dir(ir_model.graph, os.path.dirname(path))
    return ir_model


def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
    """Serialize an IR model and write it to a file.

    Args:
        model: The model to save.
        path: The path to save the model to.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.
    """
    onnx.save(serde.serialize_model(model), path, format=format)
    # TODO(justinchuby): Handle external data when the relative path has changed
    # TODO(justinchuby): Handle off loading external data to disk when saving
12 changes: 7 additions & 5 deletions onnxscript/ir/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,35 @@
"RecursiveGraphIterator",
]

from typing import Callable, Iterator, Reversible
from typing import Callable, Iterator, Reversible, Union

from typing_extensions import Self

from onnxscript.ir import _core, _enums

GraphLike = Union[_core.Graph, _core.Function, _core.GraphView]


class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
def __init__(
self,
graph: _core.Graph | _core.Function | _core.GraphView,
graph_like: GraphLike,
*,
recursive: Callable[[_core.Node], bool] | None = None,
reverse: bool = False,
):
"""Iterate over the nodes in the graph, recursively visiting subgraphs.
Args:
graph: The graph to traverse.
graph_like: The graph to traverse.
recursive: A callback that determines whether to recursively visit the subgraphs
contained in a node. If not provided, all nodes in subgraphs are visited.
reverse: Whether to iterate in reverse order.
"""
self._graph = graph
self._graph = graph_like
self._recursive = recursive
self._reverse = reverse
self._iterator = self._recursive_node_iter(graph)
self._iterator = self._recursive_node_iter(graph_like)

def __iter__(self) -> Self:
self._iterator = self._recursive_node_iter(self._graph)
Expand Down

0 comments on commit 87d7c4f

Please sign in to comment.