[IR] Implement register_initializer #1941

Merged · 2 commits · Nov 13, 2024
onnxscript/ir/_core.py (36 additions, 0 deletions)

@@ -1824,6 +1824,42 @@
```python
    def initializers(self) -> dict[str, Value]:
        return self._initializers

    def register_initializer(self, value: Value) -> None:
        """Register an initializer to the graph.

        This is a convenience method to register an initializer to the graph
        with checks.

        Args:
            value: The :class:`Value` to register as an initializer of the graph.
                It must have its ``.const_value`` set.

        Raises:
            ValueError: If a value of the same name that is not this value
                is already registered.
            ValueError: If the value does not have a name.
            ValueError: If the initializer is produced by a node.
            ValueError: If the value does not have its ``.const_value`` set.
        """
        if value.name in self._initializers:
            if self._initializers[value.name] is not value:
                raise ValueError(
                    f"Initializer '{value.name}' is already registered, but"
                    f" it is not the same object: existing={self._initializers[value.name]!r},"
                    f" new={value!r}"
                )
        if not value.name:
            raise ValueError(f"Initializer must have a name: {value!r}")
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value!r}' is produced by a node and cannot be an initializer."
            )
        if value.const_value is None:
            raise ValueError(
                f"Value '{value!r}' must have its const_value set to be an initializer."
            )
        self._initializers[value.name] = value
```

Review discussion on the `is not value` check:

- Collaborator: Not necessarily in this PR, but it would be nice to have a method to reuse initializers, especially for common duplicated constants like 0 or 1, or even shapes like (batchsize, seqlen, hiddensize). On a related note, for very small tensors we could replace the `is` comparison above with a value-equality check. That may not matter if we add a separate utility to reuse initializers as above, but if not, this could serve approximately the same purpose.
- @justinchuby (Author, Nov 13, 2024): I think this can be a graph pass too (dedup initializers).

Review discussion on the `value.producer() is not None` check:

- Collaborator: I wonder if we should support promotion of a Constant node to an initializer (later on, perhaps; it could be an option in this method or a separate one).
- @justinchuby (Author): I think that can be a graph pass?

Codecov / codecov/patch: added line 1854 in onnxscript/ir/_core.py (the `raise ValueError(` in the producer check) was not covered by tests.
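The validation order of the method can be illustrated with a minimal, self-contained sketch. `FakeValue` and `FakeGraph` are hypothetical stand-ins for `ir.Value` and `ir.Graph` (not the real onnxscript classes); only the check logic mirrors the method added in this PR:

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class FakeValue:
    # Stand-in for ir.Value: just enough surface for the checks below.
    name: str | None
    const_value: object | None = None
    _producer: object | None = None

    def producer(self):
        return self._producer


class FakeGraph:
    # Stand-in for ir.Graph holding only the initializer map.
    def __init__(self):
        self._initializers: dict[str, FakeValue] = {}

    def register_initializer(self, value: FakeValue) -> None:
        # Same checks, in the same order, as the method added in this PR.
        if value.name in self._initializers:
            if self._initializers[value.name] is not value:
                raise ValueError(
                    f"Initializer '{value.name}' is already registered, but"
                    f" it is not the same object: existing={self._initializers[value.name]!r},"
                    f" new={value!r}"
                )
        if not value.name:
            raise ValueError(f"Initializer must have a name: {value!r}")
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value!r}' is produced by a node and cannot be an initializer."
            )
        if value.const_value is None:
            raise ValueError(
                f"Value '{value!r}' must have its const_value set to be an initializer."
            )
        self._initializers[value.name] = value


graph = FakeGraph()
w = FakeValue(name="w", const_value=[1, 2, 3])
graph.register_initializer(w)
graph.register_initializer(w)  # re-registering the same object is accepted

clash = FakeValue(name="w", const_value=[4, 5, 6])
try:
    graph.register_initializer(clash)
except ValueError as e:
    print("rejected:", "already registered" in str(e))  # rejected: True
```

Note that the identity (`is`) comparison means a different `Value` carrying identical tensor data is still rejected, which is exactly the behavior the review thread discusses relaxing.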

@property
def doc_string(self) -> str | None:
return self._doc_string
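The "dedup initializers" graph pass suggested in the review thread could look roughly like this. This is a hypothetical sketch that keys plain dicts by raw tensor bytes; a real pass would compare dtype and shape as well, and would rewire `ir.Value` uses rather than rewriting name strings:

```python
def dedup_initializers(
    initializers: dict[str, bytes],
    node_inputs: list[list[str]],
) -> tuple[dict[str, bytes], list[list[str]]]:
    """Collapse initializers with identical bytes onto one canonical name."""
    canonical: dict[bytes, str] = {}  # tensor bytes -> surviving name
    rename: dict[str, str] = {}       # duplicate name -> surviving name
    for name, data in initializers.items():
        if data in canonical:
            rename[name] = canonical[data]  # duplicate: map to the survivor
        else:
            canonical[data] = name
    deduped = {n: d for n, d in initializers.items() if n not in rename}
    # Rewire every node input that referenced a removed duplicate.
    rewired = [[rename.get(i, i) for i in inputs] for inputs in node_inputs]
    return deduped, rewired


inits = {"zero_a": b"\x00", "zero_b": b"\x00", "one": b"\x01"}
inputs = [["zero_a", "one"], ["zero_b"]]
deduped, rewired = dedup_initializers(inits, inputs)
print(deduped)  # {'zero_a': b'\x00', 'one': b'\x01'}
print(rewired)  # [['zero_a', 'one'], ['zero_a']]
```

Keeping this out of `register_initializer` itself, as the author suggests, keeps the registration method a cheap identity check while the pass handles the quadratic-ish value comparisons in one place.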
onnxscript/ir/_core_test.py (24 additions, 0 deletions)

@@ -843,6 +843,30 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
```python
        self.assertEqual(tuple(graph), (sub_node, identity_node))
        self.assertEqual(add_node.inputs, (None, None))

    def test_register_initializer(self):
        self.v1.const_value = ir.tensor([1, 2, 3])
        self.graph.register_initializer(self.v1)
        self.assertEqual(self.graph.initializers, {self.v1.name: self.v1})

    def test_register_initializer_raises_when_value_is_not_constant(self):
        with self.assertRaises(ValueError):
            self.graph.register_initializer(self.v0)

    def test_register_initializer_raises_when_a_different_value_is_already_registered(self):
        self.v1.const_value = ir.tensor([1, 2, 3])
        self.graph.register_initializer(self.v1)
        # Registering the same object again is fine
        self.graph.register_initializer(self.v1)
        self.v0.name = "v1"
        with self.assertRaisesRegex(ValueError, "already registered"):
            # Registering a different value with the same name should raise
            self.graph.register_initializer(self.v0)

    def test_register_initializer_raises_when_value_does_not_have_a_name(self):
        self.v1.name = None
        with self.assertRaises(ValueError):
            self.graph.register_initializer(self.v1)

    # TODO(justinchuby): Test graph mutation methods

    # Test topological sort.
```
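The other idea floated in the review, promoting a `Constant` node's value into the initializer map, could also be written as a standalone pass. The sketch below is a hypothetical illustration over stand-in `(op_type, output_name, value)` tuples, not the real onnxscript IR; a real pass would also clear the value's producer so `register_initializer` accepts it:

```python
def promote_constants(
    nodes: list[tuple[str, str, object]],
    initializers: dict[str, object],
) -> tuple[list[tuple[str, str, object]], dict[str, object]]:
    """Move Constant nodes' values into the initializer map; keep other nodes."""
    kept = []
    for op_type, output_name, value in nodes:
        if op_type == "Constant" and output_name not in initializers:
            # The node's output becomes a graph initializer; the node is dropped.
            initializers[output_name] = value
        else:
            kept.append((op_type, output_name, value))
    return kept, initializers


nodes = [("Constant", "c0", 42), ("Add", "y", None)]
kept, inits = promote_constants(nodes, {})
print(kept)   # [('Add', 'y', None)]
print(inits)  # {'c0': 42}
```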