[DRAFT] First version of fusion optimizations for transformers #1938
Draft · gramalingam wants to merge 25 commits into main from rama/fusions
+543 −0

Commits (25, all by gramalingam)
d751d37 Fusions
8c4dff5 MultiHeadAttention fusion
5d3c9af Merge branch 'main' into rama/fusions
4d3ff90 Move transformers optimization into onnxruntime folder
4a667f9 Support some SDPA variations
33c3753 Add variations of rules for SDPA
404e5c3 Add attention scale validation
e98682f Add validation conditions for rotary embedding
001bb59 Add tests
40b9052 Move into new xformers folder
94ce2f3 Add dropout to optimizer
bf3b64a Run lint
0491366 Undo dropout rewrite rule change
3fb7cd1 Add concat test
f25b669 Merge with main
73723f0 Add expand identity optimization
bb977ec Some cleanup
7f1606f Fix dropout optimization
a3e0d1d Some more cleanup
0879934 Cleanup
a8ac3ee Minor fixes
044a638 Add ort check to test
b6f0071 Testing changes
b985bb1 Merge branch 'main' into rama/fusions
0f35c45 Merge with main
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
    mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
    rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
    rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
    skip_normalization_rules as skip_normalization_rules,
)
38 changes: 38 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py

Code scanning / lintrunner warning (RUFF-FORMAT/format): run lintrunner -a to apply the formatting patch.

@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
    mha_rules,
    rms_normalization_rules,
    rotary_embedding_rules,
    sdpa_rules,
    skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {count}")


def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    def apply(rulename: str, rule):
        count = rule.apply_to_model(irmodel, verbose=verbose)
        print(f"{rulename} count: {count}")

    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
    remove_unused_nodes(irmodel)

    apply("RMS Normalization", rms_normalization_rules)
    apply("Skip Normalization", skip_normalization_rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    apply("SDPA-Attention", sdpa_rules)
    apply("RotaryEmbedding", rotary_embedding_rules)
    apply("Multi-Head-Attention", mha_rules)

    remove_unused_nodes(irmodel)
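For context, optimize mutates the model in place, so a typical invocation would load, optimize, and save. A minimal sketch, assuming the standard ir.load/ir.save helpers; the paths are placeholders:

import onnxscript.ir as ir
from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers

# Placeholder paths for an exported transformer model and its optimized copy.
model = ir.load("exported_model.onnx")
_optimize_transformers.optimize(model, verbose=0)  # prints per-rule fusion counts
ir.save(model, "exported_model_optimized.onnx")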
152 changes: 152 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py

@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
#     model = llama_attention_class(_config, 0)
#     if with_mask:
#         inputs = (_hidden_states, _attention_mask, _position_ids)
#     else:
#         inputs = (_hidden_states, None, _position_ids)
#     exported = torch.onnx.export(model, inputs, dynamo=True)
#     # ORT Transformer optimizations are applied after basic optimization.
#     onnxscript.optimizer.optimize(exported.model)
#     return exported.model


class _TestData:
    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask

    def get_torch_model(self):
        return self.attention_class(_config, 0)

    def get_onnx_model(self):
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        else:
            return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        inputs = self.get_inputs()
        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}


_test_cases = [
    _TestData("attention", modeling_llama.LlamaAttention, False),
    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
]

_test_case_tuples = [(t,) for t in _test_cases]


def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run optimized model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise


class TestOptimizeTransformers(unittest.TestCase):
    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        model = test_data.get_onnx_model()
        # model.display()
        # print("======>")
        optimize_transformers.fuse_rotary_embedding(model)
        # model.display()
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
    unittest.main()
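The onnxruntime numerical check is currently commented out in the test above. Once re-enabled, it would be exercised roughly as follows, reusing the helpers defined in this file; this is a sketch, not part of the diff:

# Sketch: run one test case end-to-end against onnxruntime.
test_data = _test_cases[0]  # the unmasked "attention" case
model = test_data.get_onnx_model()
optimize_transformers.fuse_rotary_embedding(model)
_ort_check(
    test_data.name,
    model,
    test_data.get_ort_inputs(),
    test_data.get_torch_outputs(),
)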
Review comments

Is this IR? If so

Thanks. But when I look at the signature here, it is not clear that this is supported. The example illustrates it, though. I see that it eventually calls the np.array constructor if nothing else works, so I understand it now.

Good catch. We can update the signature.

This is actually covered by npt.ArrayLike (the first).

Actually, I tried it and it failed, rejecting a list. By the way, I have moved the independent parts of this PR into a separate PR: #1947

Hmm. I need to fix that then.
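For reference, the thread above concerns whether a plain Python list is accepted by a parameter annotated as npt.ArrayLike; the actual function under discussion is not shown in this view, so the following is only a hypothetical illustration of the pattern:

import numpy as np
import numpy.typing as npt

def to_array(value: npt.ArrayLike) -> np.ndarray:
    # Falls back to the np.array/np.asarray constructor when the input is not
    # already an ndarray, which is why plain lists work at runtime.
    return np.asarray(value)

print(to_array([1.0, 2.0, 3.0]))  # a list is a valid ArrayLike at runtime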