Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rewriter] Remove redundant op.Slice and op.ScatterND #1925

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
gemm_to_matmul_add,
no_op,
)
Expand All @@ -21,6 +22,7 @@
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule,
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
]


Expand Down
114 changes: 114 additions & 0 deletions onnxscript/rewriter/collapse_slices.py
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging
import sys

from onnxscript import ir
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)


def _check_if_redundant_slice(
context,
data: ir.Value,
starts: ir.Value,
ends: ir.Value,
axes: ir.Value,
steps: ir.Value,
**_,
):
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
"""If the starts is 0, and the ends is equal to or grater than the shape of the specified axis, then the slice is redundant."""
del context # Reserved for future extensions

starts_const = starts.const_value
ends_const = ends.const_value
axes_const = axes.const_value
steps_const = steps.const_value

if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
return False

Check warning on line 33 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L32-L33

Added lines #L32 - L33 were not covered by tests
if steps_const.numpy().item() != 1:
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
return False

Check warning on line 35 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L35

Added line #L35 was not covered by tests
# starts is 0
if starts_const.numpy().item() != 0:
return False
# In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize
if ends_const.numpy().item() == sys.maxsize:
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
return True
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
if data.shape is None:
logger.info("The value 'data' shape is not statically known.")
return False

Check warning on line 44 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L43-L44

Added lines #L43 - L44 were not covered by tests
if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
return False

return True


def _identity_to_itself(op, data, **_):
return op.Identity(data)


def _identity_to_updates(op, data, indices, updates, **_):
return op.Identity(updates)


def _potential_redundant_slice(op, data, starts, ends, axes, steps):
return op.Slice(data, starts, ends, axes, steps)


def _potential_redundant_scatternd(op, data, indices, updates):
return op.ScatterND(data, indices, updates)
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved


def _check_if_redundant_scatternd(
context,
data: ir.Value,
indices: ir.Value,
updates: ir.Value,
**_,
):
"""If the indices is the same length as the first dim of data, and the shape of updates is equal to data, we can simply swap the whole value."""
del context # Reserved for future extensions

# To validate data can be replaced directly by updates, we need to check the following:
# 1. they have the same shape
if data.shape is None:
logger.info("The value 'data' shape is not statically known.")
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
return False

Check warning on line 81 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L80-L81

Added lines #L80 - L81 were not covered by tests
if updates.shape is None:
logger.info("The value 'updates' shape is not statically known.")
return False

Check warning on line 84 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L83-L84

Added lines #L83 - L84 were not covered by tests
if data.shape != updates.shape:
logger.info("The shape of 'data' and 'updates' are different.")
return False

# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
if indices.const_value is None:
logger.info("The value 'indices' is not statically known.")
return False

Check warning on line 92 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L91-L92

Added lines #L91 - L92 were not covered by tests
if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]: # type: ignore[arg-type]
logger.info("The 'indices' is not referring to the whole data.")
return False

Check warning on line 95 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L94-L95

Added lines #L94 - L95 were not covered by tests

return True


# Register the rewrite rules
# Slice(data, starts=0, ends>=dim (or INT64_MAX), steps=1) -> Identity(data)
remove_redundant_slice = pattern.RewriteRule(
    _potential_redundant_slice,  # target pattern
    _identity_to_itself,  # replacement
    _check_if_redundant_slice,  # match condition
)

# ScatterND(data, indices=[[0]..[n-1]], updates with data's shape) -> Identity(updates)
remove_redundant_scatternd = pattern.RewriteRule(
    _potential_redundant_scatternd,  # target pattern
    _identity_to_updates,  # replacement
    _check_if_redundant_scatternd,  # match condition
)

# NOTE: The order of the rules is important. Larger pattern should be checked first.
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
78 changes: 78 additions & 0 deletions onnxscript/rewriter/collapse_slices_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnx.parser
import onnx.shape_inference

from onnxscript import ir
from onnxscript.rewriter import collapse_slices

_INT64_MAX = 9223372036854775807


class TwoReshapesMatMulReshapeTest(unittest.TestCase):
    # NOTE(review): the class name appears to be copy-pasted from another test
    # module; these tests cover collapse_slices, not reshape/matmul fusion.
    # Consider renaming to e.g. CollapseSlicesTest.

    def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self):
        """A Slice with starts=0, steps=1 and ends past the dim size is removed."""
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112, 112] data) => (float[512, 16, 112, 112] output)
            {
                starts = Constant<value: tensor = int64[1] {0}>()
                ends = Constant<value: tensor = int64[1] {9999}>()
                axes = Constant<value: tensor = int64[1] {0}>()
                steps = Constant<value: tensor = int64[1] {1}>()
                output = Slice (data, starts, ends, axes, steps)
            }
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)

    def test_slice_is_redundant_when_ends_reaches_int64_max(self):
        """A Slice whose ends is INT64_MAX is removed even without a static shape check."""
        model_proto = onnx.parser.parse_model(
            f"""
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112, 112] data) => (float[512, 16, 112, 112] output)
            {{
                starts = Constant<value: tensor = int64[1] {{0}}>()
                ends = Constant<value: tensor = int64[1] {{{_INT64_MAX}}}>()
                axes = Constant<value: tensor = int64[1] {{0}}>()
                steps = Constant<value: tensor = int64[1] {{1}}>()
                output = Slice (data, starts, ends, axes, steps)
            }}
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)

    def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
        """A ScatterND that rewrites every row of data, in order, is removed."""
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[112, 16, 112, 512] data, float[112, 16, 112, 512] updates) => (float[112, 16, 112, 512] output)
            {
                output = ScatterND (data, indices, updates)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants.
        # dtype is pinned to int64: np.arange defaults to int32 on some
        # platforms, and ONNX requires ScatterND indices to be int64.
        indices = np.arange(112, dtype=np.int64).reshape(112, 1)
        model_proto.graph.initializer.extend(
            [
                onnx.helper.make_tensor(
                    "indices",
                    # BUG FIX: was onnx.TensorProto.FLOAT16; ScatterND indices
                    # must be an int64 tensor per the ONNX spec.
                    onnx.TensorProto.INT64,
                    indices.shape,
                    indices,
                ),
            ]
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def dropout_inference(op, x):


# Replacement
def identity(op, x):
def identity(op, x, **_):
    """Replacement builder: forward *x* unchanged through an Identity op.

    Accepts **_ so the rewriter can pass extra match bindings without error.
    """
    return op.Identity(x)


Expand Down
Loading