Skip to content

Commit

Permalink
Handle input initializers correctly in constant folding (#1944)
Browse files Browse the repository at this point in the history
Values that are both inputs and initializers of a model/graph should not
be treated as constants (and cannot be used for constant-folding).
Unfortunately, the single `const_value` field in class Value is used
both to indicate constant-values of proper constants as well as
initializer values of initializers. Ideally, the IR should provide an
easy way to distinguish this at the value level (with either an extra
boolean flag to indicate the value is an input-value or by using
distinct fields for "initializer_value" and "const_value").

Meanwhile, this PR introduces a workaround to handle the main issue.
  • Loading branch information
gramalingam authored Nov 14, 2024
1 parent 1cfe0ca commit 5a35958
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
26 changes: 26 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class Replacement:
class OptimizerState:
def __init__(self):
self._sym_value_map: dict[ir.Value, Any] = {}
self._initializer_inputs: list[set[ir.Value]] = []

def get_sym_value(self, value: ir.Value | None) -> Any:
if value is None:
Expand All @@ -146,6 +147,19 @@ def get_sym_value(self, value: ir.Value | None) -> Any:
def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
self._sym_value_map[value] = sym_value

def push_initializer_inputs(self) -> None:
self._initializer_inputs.append(set())

def pop_initializer_inputs(self) -> None:
self._initializer_inputs.pop()

def add_initializer_input(self, value: ir.Value) -> None:
assert self._initializer_inputs
self._initializer_inputs[-1].add(value)

def is_initializer_input(self, value: ir.Value) -> bool:
return any(value in inputs for inputs in self._initializer_inputs)


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
Expand Down Expand Up @@ -754,6 +768,9 @@ def process_node(self, node: ir.Node):
if any(x is None for x in input_values):
return None

if any(self._state.is_initializer_input(x) for x in node.inputs): # type: ignore[arg-type]
return None

if any(input.nbytes > self._input_size_limit for input in input_values): # type: ignore[union-attr]
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [input.size for input in input_values] # type: ignore[union-attr]
Expand Down Expand Up @@ -817,9 +834,18 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
self.replace_node(node, replacement, root)

def visit_graph(self, graph: ir.Graph) -> None:
    """Visit every node of *graph*, tracking inputs that carry initializer values.

    Graph inputs that have a ``const_value`` are really inputs with a default
    (initializer) value; such values may be overridden at run time and must
    not be used for constant folding. They are recorded in the optimizer
    state for the duration of this graph's visit.
    """
    # Track inputs that have a const_value (which is really a default-value, and should not
    # be used for constant-folding).
    self._state.push_initializer_inputs()
    try:
        for input in graph.inputs:
            if input.const_value is not None:
                self._state.add_initializer_input(input)

        for node in graph:
            self.visit_node(node, graph)
    finally:
        # Pop even if visiting a node raises, so the scope stack stays
        # balanced for any outer graphs still being visited.
        self._state.pop_initializer_inputs()

def visit_function(self, function: ir.Function) -> None:
for node in function:
self.visit_node(node, function)
Expand Down
24 changes: 24 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import parameterized
import pytest

import onnxscript.ir as ir
import onnxscript.optimizer as optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding
Expand Down Expand Up @@ -434,5 +435,28 @@ def test_concat_identity(self):
self.assertEqual(optimized.graph.node[0].op_type, "Identity")


class FoldConstantsIrTest(unittest.TestCase):
    """Constant-folding tests that operate directly on the IR representation."""

    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
        # Parse the textual model, fold constants on the IR, and prune any
        # nodes that became dead as a result.
        parsed_proto = onnx.parser.parse_model(model_text)
        ir_model = serde.deserialize_model(parsed_proto)
        _constant_folding.fold_constants(ir_model, onnx_shape_inference=onnx_shape_inference)
        optimizer.remove_unused_nodes(ir_model)
        return ir_model

    def test_initializer_input_not_folded(self):
        model_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
{
# c is not a constant, and following should not be folded.
two_c = Add (c, c)
z = Mul (x, two_c)
}
"""
        optimized = self._fold(model_text)
        # Both nodes must survive: c is an input with a default value,
        # not a true constant, so Add(c, c) cannot be folded away.
        self.assertEqual(len(optimized.graph), 2)
        self.assertEqual(optimized.graph.node(0).op_type, "Add")


if __name__ == "__main__":
unittest.main()

0 comments on commit 5a35958

Please sign in to comment.