Merge branch 'main' into rama/fusions
gramalingam committed Nov 15, 2024
2 parents b6f0071 + d81480b commit b985bb1
Showing 8 changed files with 256 additions and 16 deletions.
111 changes: 111 additions & 0 deletions onnxscript/ir/_core.py
@@ -1824,6 +1824,42 @@ def outputs(self) -> list[Value]:
    def initializers(self) -> dict[str, Value]:
        return self._initializers

    def register_initializer(self, value: Value) -> None:
        """Register an initializer to the graph.

        This is a convenience method to register an initializer to the graph
        with checks.

        Args:
            value: The :class:`Value` to register as an initializer of the graph.
                It must have its ``.const_value`` set.

        Raises:
            ValueError: If a value of the same name that is not this value
                is already registered.
            ValueError: If the value does not have a name.
            ValueError: If the initializer is produced by a node.
            ValueError: If the value does not have its ``.const_value`` set.
        """
        if value.name in self._initializers:
            if self._initializers[value.name] is not value:
                raise ValueError(
                    f"Initializer '{value.name}' is already registered, but"
                    f" it is not the same object: existing={self._initializers[value.name]!r},"
                    f" new={value!r}"
                )
        if not value.name:
            raise ValueError(f"Initializer must have a name: {value!r}")
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value!r}' is produced by a node and cannot be an initializer."
            )
        if value.const_value is None:
            raise ValueError(
                f"Value '{value!r}' must have its const_value set to be an initializer."
            )
        self._initializers[value.name] = value
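
A minimal usage sketch of the new method (illustrative, not part of the diff; the `ir.Graph`, `ir.Value`, and `ir.tensor` calls mirror the tests added below, and the keyword arguments are assumptions about the public `onnxscript.ir` API):

    import onnxscript.ir as ir

    # Hypothetical example: register a constant weight as a graph initializer.
    graph = ir.Graph((), (), nodes=())
    weight = ir.Value(name="weight")
    weight.const_value = ir.tensor([1.0, 2.0, 3.0])  # must be set before registering
    graph.register_initializer(weight)

    # Registering the same object again is a no-op; a different Value under the
    # same name, a nameless Value, a node-produced Value, or a Value without
    # const_value raises ValueError.
    graph.register_initializer(weight)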

    @property
    def doc_string(self) -> str | None:
        return self._doc_string
@@ -2715,6 +2751,81 @@ def __str__(self) -> str:
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"

    # Well-typed getters
    def as_float(self) -> float:
        """Get the attribute value as a float."""
        # Do not use an isinstance check because it may prevent np.float32 etc. from being used
        return float(self.value)

    def as_int(self) -> int:
        """Get the attribute value as an int."""
        # Do not use an isinstance check because it may prevent np.int32 etc. from being used
        return int(self.value)

    def as_string(self) -> str:
        """Get the attribute value as a string."""
        if not isinstance(self.value, str):
            raise TypeError(f"Value of attribute '{self!r}' is not a string.")
        return self.value

    def as_tensor(self) -> _protocols.TensorProtocol:
        """Get the attribute value as a tensor."""
        if not isinstance(self.value, _protocols.TensorProtocol):
            raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
        return self.value

    def as_graph(self) -> Graph:
        """Get the attribute value as a graph."""
        if not isinstance(self.value, Graph):
            raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
        return self.value

    def as_floats(self) -> Sequence[float]:
        """Get the attribute value as a sequence of floats."""
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        # Do not use an isinstance check on elements because it may prevent np.float32 etc. from being used
        # Create a copy of the list to prevent mutation
        return [float(v) for v in self.value]

    def as_ints(self) -> Sequence[int]:
        """Get the attribute value as a sequence of ints."""
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        # Do not use an isinstance check on elements because it may prevent np.int32 etc. from being used
        # Create a copy of the list to prevent mutation
        return [int(v) for v in self.value]

    def as_strings(self) -> Sequence[str]:
        """Get the attribute value as a sequence of strings."""
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnxscript.DEBUG:
            if not all(isinstance(x, str) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
        # Create a copy of the list to prevent mutation
        return list(self.value)

    def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
        """Get the attribute value as a sequence of tensors."""
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnxscript.DEBUG:
            if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
        # Create a copy of the list to prevent mutation
        return list(self.value)

    def as_graphs(self) -> Sequence[Graph]:
        """Get the attribute value as a sequence of graphs."""
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnxscript.DEBUG:
            if not all(isinstance(x, Graph) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
        # Create a copy of the list to prevent mutation
        return list(self.value)
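
For illustration, a short sketch of how the typed getters can be used (hypothetical attribute values, mirroring the tests added below):

    import onnxscript.ir as ir
    from onnxscript.ir import _core

    perm = _core.Attr("perm", ir.AttributeType.INTS, [0, 2, 1])
    assert perm.as_ints() == [0, 2, 1]     # a copy: mutating the result leaves the Attr intact

    alpha = _core.Attr("alpha", ir.AttributeType.FLOAT, 1)
    assert alpha.as_float() == 1.0         # numeric coercion instead of an isinstance check

    mode = _core.Attr("mode", ir.AttributeType.STRING, "constant")
    assert mode.as_string() == "constant"  # raises TypeError if the value is not a str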


# NOTE: The following functions are just for convenience
def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
78 changes: 78 additions & 0 deletions onnxscript/ir/_core_test.py
@@ -843,6 +843,30 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
        self.assertEqual(tuple(graph), (sub_node, identity_node))
        self.assertEqual(add_node.inputs, (None, None))

    def test_register_initializer(self):
        self.v1.const_value = ir.tensor([1, 2, 3])
        self.graph.register_initializer(self.v1)
        self.assertEqual(self.graph.initializers, {self.v1.name: self.v1})

    def test_register_initializer_raises_when_value_is_not_constant(self):
        with self.assertRaises(ValueError):
            self.graph.register_initializer(self.v0)

    def test_register_initializer_raises_when_a_different_value_is_already_registered(self):
        self.v1.const_value = ir.tensor([1, 2, 3])
        self.graph.register_initializer(self.v1)
        # Registering the same object again is fine
        self.graph.register_initializer(self.v1)
        self.v0.name = "v1"
        with self.assertRaisesRegex(ValueError, "already registered"):
            # Registering a different value with the same name should raise
            self.graph.register_initializer(self.v0)

    def test_register_initializer_raises_when_value_does_not_have_a_name(self):
        self.v1.name = None
        with self.assertRaises(ValueError):
            self.graph.register_initializer(self.v1)

    # TODO(justinchuby): Test graph mutation methods

    # Test topological sort.
@@ -1061,5 +1085,59 @@ def test_composite_type_is_comparable(self, _: str, type_: ir.TypeProtocol):
        self.assertEqual(type_, copy.deepcopy(type_))


class AttrTest(unittest.TestCase):
    """Test the Attr class."""

    def test_init(self):
        attr = _core.Attr("test", ir.AttributeType.INT, 42, doc_string="test string")
        self.assertEqual(attr.name, "test")
        self.assertEqual(attr.value, 42)
        self.assertEqual(attr.type, ir.AttributeType.INT)
        self.assertEqual(attr.doc_string, "test string")

    def test_as_float(self):
        attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
        self.assertEqual(attr.as_float(), 42.0)

        attr_int_value = _core.Attr("test", ir.AttributeType.FLOAT, 42)
        self.assertEqual(attr_int_value.as_float(), 42.0)

    def test_as_int(self):
        attr = _core.Attr("test", ir.AttributeType.INT, 0)
        self.assertEqual(attr.as_int(), 0)

    def test_as_string(self):
        attr = _core.Attr("test", ir.AttributeType.STRING, "test string")
        self.assertEqual(attr.as_string(), "test string")

    def test_as_tensor(self):
        attr = _core.Attr("test", ir.AttributeType.TENSOR, ir.tensor([42.0]))
        np.testing.assert_equal(attr.as_tensor().numpy(), np.array([42.0]))

    def test_as_graph(self):
        attr = _core.Attr("test", ir.AttributeType.GRAPH, _core.Graph((), (), nodes=()))
        self.assertIsInstance(attr.as_graph(), _core.Graph)

    def test_as_floats(self):
        attr = _core.Attr("test", ir.AttributeType.FLOATS, [42.0])
        self.assertEqual(attr.as_floats(), [42.0])

    def test_as_ints(self):
        attr = _core.Attr("test", ir.AttributeType.INTS, [42])
        self.assertEqual(attr.as_ints(), [42])

    def test_as_strings(self):
        attr = _core.Attr("test", ir.AttributeType.STRINGS, ["test string", ""])
        self.assertEqual(attr.as_strings(), ["test string", ""])

    def test_as_tensors(self):
        attr = _core.Attr("test", ir.AttributeType.TENSORS, [ir.tensor([42.0])])
        np.testing.assert_equal(attr.as_tensors()[0].numpy(), np.array([42.0]))

    def test_as_graphs(self):
        attr = _core.Attr("test", ir.AttributeType.GRAPHS, [_core.Graph((), (), nodes=())])
        self.assertIsInstance(attr.as_graphs()[0], _core.Graph)


if __name__ == "__main__":
    unittest.main()
34 changes: 30 additions & 4 deletions onnxscript/optimizer/_constant_folding.py
@@ -138,6 +138,7 @@ class Replacement:
class OptimizerState:
    def __init__(self):
        self._sym_value_map: dict[ir.Value, Any] = {}
        self._initializer_inputs: list[set[ir.Value]] = []

    def get_sym_value(self, value: ir.Value | None) -> Any:
        if value is None:
@@ -147,6 +148,19 @@ def get_sym_value(self, value: ir.Value | None) -> Any:
    def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
        self._sym_value_map[value] = sym_value

    def push_initializer_inputs(self) -> None:
        self._initializer_inputs.append(set())

    def pop_initializer_inputs(self) -> None:
        self._initializer_inputs.pop()

    def add_initializer_input(self, value: ir.Value) -> None:
        assert self._initializer_inputs
        self._initializer_inputs[-1].add(value)

    def is_initializer_input(self, value: ir.Value) -> bool:
        return any(value in inputs for inputs in self._initializer_inputs)
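
The stack makes this tracking scope-aware: `is_initializer_input` searches every enclosing scope, so a value marked in an outer graph is still recognized while a nested subgraph is being visited. A small illustration (hypothetical value; mirrors how `visit_graph` below drives these methods):

    state = OptimizerState()
    v = ir.Value(name="c")                  # e.g. a graph input carrying a default const_value

    state.push_initializer_inputs()         # enter the outer graph
    state.add_initializer_input(v)
    state.push_initializer_inputs()         # enter a subgraph (e.g. an If branch)
    assert state.is_initializer_input(v)    # still visible from the inner scope
    state.pop_initializer_inputs()
    state.pop_initializer_inputs()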


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
@@ -377,7 +391,7 @@ def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
        if graph_attr.type != ir.AttributeType.GRAPH:
            return None
        assert isinstance(graph_attr, ir.Attr)
-        graph: ir.Graph = graph_attr.value
+        graph = graph_attr.as_graph()
        formal_outs = graph.outputs
        actual_outs = node.outputs
        renamings = {
@@ -784,6 +798,9 @@ def process_node(self, node: ir.Node):
        if any(x is None for x in input_values):
            return None

        if any(self._state.is_initializer_input(x) for x in node.inputs):  # type: ignore[arg-type]
            return None

        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
            if logger.isEnabledFor(logging.DEBUG):
                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
@@ -831,10 +848,10 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
    def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
        if isinstance(attr, ir.Attr):
            if attr.type == ir.AttributeType.GRAPH:
-                self.visit_graph(attr.value)  # type: ignore[arg-type]
+                self.visit_graph(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
-                for graph in attr.value:
-                    self.visit_graph(graph)  # type: ignore[arg-type]
+                for graph in attr.as_graphs():
+                    self.visit_graph(graph)

    def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
        replacement = self.process_node(node)
@@ -847,9 +864,18 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
            self.replace_node(node, replacement, root)

    def visit_graph(self, graph: ir.Graph) -> None:
        # Track inputs that have a const_value (which is really a default value and
        # should not be used for constant folding).
        self._state.push_initializer_inputs()
        for input in graph.inputs:
            if input.const_value is not None:
                self._state.add_initializer_input(input)

        for node in graph:
            self.visit_node(node, graph)

        self._state.pop_initializer_inputs()
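
Why these inputs are excluded: a graph input with a `const_value` carries a default that a caller may override at run time, so folding against it would bake the default in. A plain-Python analogy (not the optimizer's API):

    def run_graph(x, c=1.0):              # c plays the role of "float[1] c = {1.0}" in the test below
        two_c = c + c                     # folding this to a constant 2.0 would be unsound
        return x * two_c

    assert run_graph(5.0) == 10.0         # default used: folding would happen to agree
    assert run_graph(5.0, c=3.0) == 30.0  # overridden: a baked-in 2.0 would give 10.0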

    def visit_function(self, function: ir.Function) -> None:
        for node in function:
            self.visit_node(node, function)
24 changes: 24 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -6,6 +6,7 @@
import parameterized
import pytest

import onnxscript.ir as ir
import onnxscript.optimizer as optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding
@@ -450,5 +451,28 @@ def test_expand_identity(self):
        self.assertEqual(optimized.graph.node[-1].op_type, "Identity")


class FoldConstantsIrTest(unittest.TestCase):
    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
        model_proto = onnx.parser.parse_model(model_text)
        model = serde.deserialize_model(model_proto)
        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
        optimizer.remove_unused_nodes(model)
        return model

    def test_initializer_input_not_folded(self):
        model_text = """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
            {
                # c is not a constant, so the following should not be folded.
                two_c = Add (c, c)
                z = Mul (x, two_c)
            }
        """
        optimized = self._fold(model_text)
        self.assertEqual(len(optimized.graph), 2)
        self.assertEqual(optimized.graph.node(0).op_type, "Add")

if __name__ == "__main__":
    unittest.main()
12 changes: 6 additions & 6 deletions onnxscript/optimizer/_inliner.py
@@ -81,10 +81,10 @@ def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
    def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr | None:
        if isinstance(attr, ir.Attr):
            if attr.type == ir.AttributeType.GRAPH:
-                graph = self.clone_graph(attr.value)
+                graph = self.clone_graph(attr.as_graph())
                return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string)
            elif attr.type == ir.AttributeType.GRAPHS:
-                graphs = [self.clone_graph(graph) for graph in attr.value]
+                graphs = [self.clone_graph(graph) for graph in attr.as_graphs()]
                return ir.Attr(
                    key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
                )
@@ -236,9 +236,9 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
-        call_stack.append(call_site_id)
+        new_call_stack = [*call_stack, call_site_id]

-        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, call_stack)
+        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

        # iterate over the nodes in the function, creating a copy of each node
        # and replacing inputs with the corresponding values in the value map.
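
The switch from `append` to building a new list avoids mutating the list object cached in `node_context`. A standalone sketch of the aliasing hazard this prevents (illustrative names only):

    # The cached call-stack list is shared; appending mutates the cache in place.
    context: dict[str, list[str]] = {"n": ["call0"]}

    stack = context.get("n", [])
    stack.append("call1")
    assert context["n"] == ["call0", "call1"]   # the cached entry silently changed

    # Building a fresh list leaves the cache untouched.
    context = {"n": ["call0"]}
    new_stack = [*context.get("n", []), "call1"]
    assert context["n"] == ["call0"]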
@@ -297,9 +297,9 @@ def inline_calls_in(self, graph: ir.Graph) -> None:
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
-                self.inline_calls_in(attr.value)
+                self.inline_calls_in(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
-                for graph in attr.value:
+                for graph in attr.as_graphs():
                    self.inline_calls_in(graph)


4 changes: 2 additions & 2 deletions onnxscript/optimizer/_remove_unused.py
@@ -75,9 +75,9 @@ def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
-                count += process_function_or_graph(attr.value)
+                count += process_function_or_graph(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
-                for graph in attr.value:
+                for graph in attr.as_graphs():
                    count += process_function_or_graph(graph)
    return count
