[rewriter] Remove redundant op.Slice and op.ScatterND #1925

Merged
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
@@ -10,6 +10,7 @@
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
gemm_to_matmul_add,
no_op,
)
@@ -21,6 +22,7 @@
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule,
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
]


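As a rough sketch (not part of this diff), here is how the newly registered rules get exercised: they join the default rewrite rule list, so a standard optimizer pass picks them up. The entry point `onnxscript.optimizer.optimize` and the model path are assumptions for illustration.

```python
import onnx

from onnxscript import optimizer

# "model.onnx" is a placeholder path; optimize() runs the default rewrite
# rules, which now include the collapse_slices rules registered above.
model_proto = onnx.load("model.onnx")
optimized_proto = optimizer.optimize(model_proto)
onnx.save(optimized_proto, "model_optimized.onnx")
```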
140 changes: 140 additions & 0 deletions onnxscript/rewriter/collapse_slices.py
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging

from onnxscript import ir
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)
_INT64_MAX = 9223372036854775807


def _check_if_redundant_slice(
context,
data: ir.Value,
starts: ir.Value,
ends: ir.Value,
axes: ir.Value,
steps: ir.Value,
**_,
):
    """If `starts` is 0 and `ends` is equal to or greater than the size of the specified axis, the slice is redundant."""
del context # Reserved for future extensions

starts_const = starts.const_value
ends_const = ends.const_value
axes_const = axes.const_value
steps_const = steps.const_value

    # The values must be statically known to reason about the slice
    if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
        logger.info("The values 'start', 'end', 'axis', and 'step' must be statically known.")
        return False

    # Check that the values are scalars
    if starts_const.numpy().size != 1:
        logger.info("The value 'start' is not a scalar.")
        return False
    if ends_const.numpy().size != 1:
        logger.info("The value 'end' is not a scalar.")
        return False
    if axes_const.numpy().size != 1:
        logger.info("The value 'axis' is not a scalar.")
        return False
    if steps_const.numpy().size != 1:
        logger.info("The value 'step' is not a scalar.")
        return False
    if steps_const.numpy().item() != 1:
        logger.info("The value 'step' is not 1.")
        return False
    # starts is 0
    if starts_const.numpy().item() != 0:
        logger.info("The value 'start' is not 0.")
        return False
    # Even when data.shape is not statically known, we can still tell the slice
    # is redundant if ends is INT64_MAX
    if ends_const.numpy().item() == _INT64_MAX:
        return True
    if data.shape is None:
        logger.info("The value 'data' shape is not statically known.")
        return False
    if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
        logger.info("The value 'end' is less than the size of the specified axis.")
        return False

    return True


def _identity_to_itself(op, data, **_):
"""Return the input data as the output."""
return op.Identity(data)


def _identity_to_updates(op, data, indices, updates, **_):
"""Return the updates as the output.

This is used when the ScatterND is redundant in terms of
updating the whole data with the updates.

"""
return op.Identity(updates)


def _potential_redundant_slice(op, data, starts, ends, axes, steps):
"""To identify a slice op"""
return op.Slice(data, starts, ends, axes, steps)


def _potential_redundant_scatternd(op, data, indices, updates):
"""To identify a ScatterND op"""
return op.ScatterND(data, indices, updates)


def _check_if_redundant_scatternd(
context,
data: ir.Value,
indices: ir.Value,
updates: ir.Value,
**_,
):
"""If the indices is the same length as the first dim of data, and the shape of updates is equal to data, we can simply swap the whole value."""
del context # Reserved for future extensions

# To validate data can be replaced directly by updates, we need to check the following:
# 1. they have the same shape
if data.shape is None:
logger.info("The value 'data' shape is not statically known.")
return False
if updates.shape is None:
logger.info("The value 'updates' shape is not statically known.")
return False
if data.shape != updates.shape:
logger.info("The shape of 'data' and 'updates' are different.")
return False

# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
if indices.const_value is None:
logger.info("The value 'indices' is not statically known.")
return False
if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]: # type: ignore[arg-type]
logger.info("The 'indices' is not referring to the whole data.")
return False

return True


# Register the rewrite rules
remove_redundant_slice = pattern.RewriteRule(
_potential_redundant_slice,
_identity_to_itself,
_check_if_redundant_slice,
)

remove_redundant_scatternd = pattern.RewriteRule(
_potential_redundant_scatternd,
_identity_to_updates,
_check_if_redundant_scatternd,
)

# NOTE: The order of the rules is important. Larger patterns should be checked first.
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
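For intuition, a small numpy sketch (not part of the diff) of the equivalences these conditions guard: a Slice with `starts=0`, `steps=1`, and `ends` at least the axis size returns its input unchanged, and a ScatterND whose `indices` enumerate the whole first axis in order returns its `updates`.

```python
import numpy as np

data = np.random.rand(4, 3).astype(np.float32)
updates = np.random.rand(4, 3).astype(np.float32)

# Slice(data, starts=0, ends=INT64_MAX, axes=0, steps=1) is data itself:
assert np.array_equal(data[0:9223372036854775807:1], data)

# ScatterND(data, indices=[[0], [1], ..., [n-1]], updates) overwrites every
# row of data, so the result is exactly `updates`:
indices = np.arange(4).reshape(4, 1)
result = data.copy()
result[indices[:, 0]] = updates
assert np.array_equal(result, updates)
```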
100 changes: 100 additions & 0 deletions onnxscript/rewriter/collapse_slices_test.py
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnx.parser
import onnx.shape_inference

from onnxscript import ir
from onnxscript.rewriter import collapse_slices, testing

_INT64_MAX = 9223372036854775807


class CollapseSlicesTest(unittest.TestCase):
def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self):
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
{
starts = Constant<value: tensor = int64[1] {0}>()
ends = Constant<value: tensor = int64[1] {9999}>()
axes = Constant<value: tensor = int64[1] {0}>()
steps = Constant<value: tensor = int64[1] {1}>()
output = Slice (data, starts, ends, axes, steps)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 5)
self.assertIn("Identity", [node.op_type for node in model.graph])
testing.assert_numerically_equal(
model_proto,
ir.serde.serialize_model(model),
(np.random.rand(512, 16, 112).astype(np.float32),),
)

def test_slice_is_redundant_when_ends_reaches_int64_max(self):
model_proto = onnx.parser.parse_model(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
{{
starts = Constant<value: tensor = int64[1] {{0}}>()
ends = Constant<value: tensor = int64[1] {{{_INT64_MAX}}}>()
axes = Constant<value: tensor = int64[1] {{0}}>()
steps = Constant<value: tensor = int64[1] {{1}}>()
output = Slice (data, starts, ends, axes, steps)
}}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 5)
self.assertIn("Identity", [node.op_type for node in model.graph])
testing.assert_numerically_equal(
model_proto,
ir.serde.serialize_model(model),
(np.random.rand(512, 16, 112).astype(np.float32),),
)

def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
{
output = ScatterND (data, indices, updates)
}
"""
)
        # Use an inserted initializer to avoid hand-writing the large constant
        indices_array = np.arange(112).reshape(112, 1)
        model = ir.serde.deserialize_model(model_proto)
        # Wrap the numpy array in an ir.Tensor
        indices_ir_tensor = ir.Tensor(
            name="indices",
            value=indices_array,
        )
        # Attach the tensor to the graph's `indices` value and register it as an initializer
        indices = model.graph[0].inputs[1]
        indices.const_value = indices_ir_tensor
        model.graph.initializers["indices"] = indices
original_model_proto = ir.serde.serialize_model(model)

count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 1)
self.assertIn("Identity", [node.op_type for node in model.graph])

input = np.random.rand(112, 16, 512).astype(np.float32)
testing.assert_numerically_equal(
original_model_proto, ir.serde.serialize_model(model), (input, input)
)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/no_op.py
@@ -32,7 +32,7 @@ def dropout_inference(op, x):


# Replacement
def identity(op, x):
def identity(op, x, **_):
return op.Identity(x)


52 changes: 52 additions & 0 deletions onnxscript/rewriter/testing.py
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Any

import numpy as np
import onnx
import onnxruntime as ort


def assert_numerically_equal(
    original_model_proto: onnx.ModelProto,
    the_rewritten_model_proto: onnx.ModelProto,
    args: tuple[Any, ...],
    rtol: float = 1,
    atol: float = 1e-3,
):
    """Assert that the two models are numerically equal.

    Args:
        original_model_proto: The original model proto.
        the_rewritten_model_proto: The model proto after the rewrite rules have been applied.
        args: The positional arguments to pass to the models.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
    """
    original_proto_ort_inputs = {
        k.name: v for k, v in zip(original_model_proto.graph.input, args)
    }
    original_proto_ort_inference_session = _ort_session_initializer(
        original_model_proto.SerializeToString()
    )
    run_options = ort.RunOptions()
    run_options.log_severity_level = 3  # 3: Error
    original_outputs = original_proto_ort_inference_session.run(
        None, original_proto_ort_inputs, run_options=run_options
    )

    the_rewritten_proto_ort_inputs = {
        k.name: v for k, v in zip(the_rewritten_model_proto.graph.input, args)
    }
    the_rewritten_proto_ort_inference_session = _ort_session_initializer(
        the_rewritten_model_proto.SerializeToString()
    )
    the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
        None, the_rewritten_proto_ort_inputs, run_options=run_options
    )

    np.testing.assert_allclose(
        original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
    )


def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
    """Initialize an ONNX Runtime inference session with the specified model."""
    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3  # 3: Error
    possible_providers = (
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )
    available_providers = set(ort.get_available_providers())
    providers = [
        provider for provider in possible_providers if provider in available_providers
    ]
    return ort.InferenceSession(model, providers=providers, sess_options=session_options)
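A hedged usage sketch of this helper, mirroring `collapse_slices_test.py`; the model path and the input shape are placeholders.

```python
import numpy as np
import onnx

from onnxscript import ir
from onnxscript.rewriter import collapse_slices, testing

# Placeholder: any model with a single float input of shape (512, 16, 112).
original_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(original_proto)
collapse_slices.rules.apply_to_model(model)

# Run both models on the same random input and compare their outputs.
testing.assert_numerically_equal(
    original_proto,
    ir.serde.serialize_model(model),
    (np.random.rand(512, 16, 112).astype(np.float32),),
)
```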