[DRAFT] First version of fusion optimizations for transformers #1938

Draft · wants to merge 24 commits into main
Conversation

gramalingam (Collaborator) commented:
  • Introduce fusion rules for SdpaAttention, RMS Normalization, Skip Normalization, Rotary Embedding, and Multi Head Attention (a sketch of one such rule follows the TODO list below)
  • Replace Expand with Identity when applicable (in core optimization)
  • Clean up the Dropout-to-Identity replacement for the case where Dropout has a mask output
  • Make repeated (redundant) calls to the inliner efficient

Still TODO:

  • Multi Head Attention requires extra validation conditions
  • Need to clean up the use of "local" sub-patterns
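
For illustration, here is a minimal sketch of what one of these fusion rules can look like in the rewriter's pattern style, using RMS Normalization as the example. The decomposition matched, the attribute handling, and the fused-op domain are assumptions of this sketch, not the PR's actual code:

    # Hedged sketch (not the PR's code): match the decomposed
    # RMS-normalization subgraph and rewrite it to ORT's fused kernel.
    from onnxscript.rewriter import pattern

    def _rms_norm_pattern(op, x, axes, weight, epsilon):
        # x / sqrt(mean(x^2) + epsilon) * weight -- the decomposed form
        variance = op.ReduceMean(op.Pow(x, 2.0), axes, keepdims=1)
        inv_std = op.Reciprocal(op.Sqrt(op.Add(variance, epsilon)))
        return op.Mul(op.Mul(x, inv_std), weight)

    def _rms_norm_replacement(op, x, weight, epsilon):
        # SimplifiedLayerNormalization is ORT's fused RMS-norm op; the
        # domain string and turning the matched epsilon constant into an
        # attribute are glossed over here (assumptions of this sketch).
        return op.SimplifiedLayerNormalization(
            x, weight, axis=-1, epsilon=epsilon, _domain="com.microsoft"
        )

    rms_norm_rule = pattern.RewriteRule(_rms_norm_pattern, _rms_norm_replacement)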

@gramalingam gramalingam marked this pull request as draft November 9, 2024 01:23

codecov bot commented Nov 9, 2024

❌ 10 Tests Failed:

Tests completed   Failed   Passed   Skipped
10004             10       9994     3720
View the top 1 failed tests by shortest run time
::onnxscript.rewriter.onnxruntime.xformers._optimize_transformers_test
Stack Traces | 0s run time
No failure message available
View the full list of 2 ❄️ flaky tests
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_input_and_attribute_by_kwargs_out_of_order

Flake rate in main: 39.26% (Passed 7590 times, Failed 4905 times)

Stack Traces | 0.004s run time
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:91: in run
    res = self._run(x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests\eager_mode_test.py:115: in test_function_input_and_attribute_by_kwargs_out_of_order
    self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0)
onnxscript\values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript\evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests\eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
onnxscript\onnx_opset\_impl\opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript\values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript\evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript\evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_all_input_by_kwargs

Flake rate in main: 39.26% (Passed 7590 times, Failed 4905 times)

Stack Traces | 0.004s run time
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:91: in run
    res = self._run(x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests\eager_mode_test.py:109: in test_function_all_input_by_kwargs
    self.assertEqual(add_with_alpha(this=1.0, other=2.0), 3.0)
onnxscript\values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript\evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests\eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
onnxscript\onnx_opset\_impl\opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript\values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript\evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript\evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
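
For context on these flaky failures: numpy only allows ndarray.view to reinterpret a 0-dimensional array when the new dtype has the same itemsize, which is the rule convert_from_ml_dtypes trips over. A minimal reproduction of the numpy behavior (the exact dtype pair in the test is not shown in the trace):

    import numpy as np

    scalar = np.asarray(1.0, dtype=np.float32)  # 0-d array, itemsize 4
    scalar.view(np.int32)                       # OK: itemsize unchanged
    try:
        scalar.view(np.float16)                 # itemsize 4 -> 2 on a 0-d array
    except ValueError as err:
        print(err)  # "Changing the dtype of a 0d array is only supported ..."

    # Reshaping to 1-d first sidesteps the restriction:
    scalar.reshape(1).view(np.float16)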


The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)

The dot-product attention is then computed using SDPA

Check warning — Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning: trailing whitespace
RUFF/W293 Warning: whitespace on blank line
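
As a reference for the pattern being matched, a plain-numpy sketch of scaled dot-product attention with the key transpose called out (the names and the absence of masking/causal handling are simplifications, not the PR's code):

    import numpy as np

    def sdpa(query, key, value):
        # The Reshape/Transpose/Reshape sequence in the matched subgraph
        # amounts to swapping the key's last two axes so Q @ K^T is formed.
        key_t = np.swapaxes(key, -1, -2)
        scores = query @ key_t / np.sqrt(query.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
        return weights @ value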



def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
    normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(

Check warning — Code scanning / lintrunner
RUFF/F841 Warning: local variable mean is assigned to but never used.
RUFF/F841 Warning: local variable inv_std_var is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
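
The call quoted above is truncated in the diff view; as a hedged guess at its overall shape (assuming the rewriter's _outputs/_domain conventions and ORT's SkipSimplifiedLayerNormalization contrib op; argument names are illustrative):

    def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
        normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
            input,
            skip,
            gamma,
            epsilon=epsilon,
            _outputs=4,
            _domain="com.microsoft",
        )
        # Only two outputs are consumed downstream; mean and inv_std_var
        # are what the RUFF/F841 warnings above flag as unused.
        return normalized, skip_sum
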
if len(node.outputs) == 1:
    return output
else:
    true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])

Collaborator commented:

Is this IR? If so:

Suggested change
true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])
true_tensor = ir.tensor([True])

gramalingam (Collaborator, author) commented Nov 11, 2024:

Thanks. But when I look at the signature here, it is not clear this is supported. The example illustrates it, though. I see it eventually calls the np.array constructor if nothing else works, so I understand it now.

Collaborator commented:

Good catch. We can update the signature.

justinchuby (Collaborator) commented Nov 15, 2024:

This is actually covered by npt.ArrayLike (the first)

gramalingam (Collaborator, author) commented:

Actually, I tried and it failed, rejecting a list. BTW, I have moved the independent parts of this PR into a separate PR: #1947

Collaborator commented:

Hmm. I need to fix that then.
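
For reference, the two spellings discussed in this thread side by side. Whether ir.tensor accepts a plain Python list (via npt.ArrayLike, ultimately falling back to np.array) is exactly the point under discussion above, so treat the second line as the intended behavior rather than a guarantee:

    import onnx
    from onnxscript import ir

    # Verbose spelling via the onnx helper:
    t1 = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])

    # Suggested spelling via the IR; falls back to np.array for a plain list:
    t2 = ir.tensor([True])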

@titaiwangms titaiwangms self-requested a review November 12, 2024 17:56
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.

Check warning — Code scanning / lintrunner
RUFF-FORMAT/format Warning: run lintrunner -a to apply this patch.

@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.

Check warning — Code scanning / lintrunner
RUFF-FORMAT/format Warning: run lintrunner -a to apply this patch.