diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 4053bb2a1..fde8ec418 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -13,6 +13,7 @@
 
 import numpy as np
 import onnx
+import onnx.helper
 import onnx.reference.ops
 
 import onnxscript.ir as ir
diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py
index bd353f388..eadb67f0a 100644
--- a/onnxscript/rewriter/_ir_utils.py
+++ b/onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
+import numpy as np
+
 import onnxscript.ir as ir
 from onnxscript.optimizer import basic_constant_propagation
 
@@ -11,3 +13,24 @@ def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
     if node is not None:
         basic_constant_propagation([node])
     return value.const_value
+
+
+def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
+    if val is None:
+        return None
+    const_value = val.const_value
+    if const_value is not None:
+        try:
+            return const_value.numpy()
+        except FileNotFoundError:
+            # External data is not available.
+            return None
+    return None
+
+
+def get_singleton_value(val: ir.Value | None):
+    """Returns the element of a single-element tensor constant value, and None otherwise."""
+    np_val = get_numpy_value(val)
+    if np_val is not None and np_val.size == 1:
+        return np_val.item()
+    return None
diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py
new file mode 100644
index 000000000..dfd0df8da
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
+    mha_rules as mha_rules,
+)
+from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
+    rms_normalization_rules as rms_normalization_rules,
+)
+from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
+    rotary_embedding_rules as rotary_embedding_rules,
+)
+from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
+from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
+    skip_normalization_rules as skip_normalization_rules,
+)
diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
new file mode 100644
index 000000000..7c5852c72
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
@@ -0,0 +1,38 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
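+"""Fusion optimizations for transformer models targeting ONNX Runtime contrib ops.
+
+optimize() interleaves constant folding and removal of unused nodes with the fusion
+rule-sets defined in this package: RMS/Skip normalization first, followed by SDPA,
+rotary-embedding, and multi-head-attention fusion.
+"""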
+from __future__ import annotations
+
+import onnxscript.ir as ir
+from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
+from onnxscript.rewriter.onnxruntime.xformers import (
+    mha_rules,
+    rms_normalization_rules,
+    rotary_embedding_rules,
+    sdpa_rules,
+    skip_normalization_rules,
+)
+
+
+def fuse_rotary_embedding(irmodel: ir.Model) -> None:
+    count = rotary_embedding_rules.apply_to_model(irmodel)
+    print(f"RotaryEmbedding count: {count}")
+
+def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
+    def apply(rulename: str, rule):
+        count = rule.apply_to_model(irmodel, verbose=verbose)
+        print(f"{rulename} count: {count}")
+
+    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
+    remove_unused_nodes(irmodel)
+
+    apply("RMS Normalization", rms_normalization_rules)
+    apply("Skip Normalization", skip_normalization_rules)
+
+    fold_constants_ir(irmodel)
+    remove_unused_nodes(irmodel)
+
+    apply("SDPA-Attention", sdpa_rules)
+    apply("RotaryEmbedding", rotary_embedding_rules)
+    apply("Multi-Head-Attention", mha_rules)
+
+    remove_unused_nodes(irmodel)
diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py
new file mode 100644
index 000000000..fef86f1fe
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py
@@ -0,0 +1,153 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import onnxruntime
+import torch
+import transformers.models.llama.modeling_llama as modeling_llama
+from parameterized import parameterized
+from transformers import LlamaConfig
+
+import onnxscript.ir._io as io
+import onnxscript.optimizer
+from onnxscript.rewriter.onnxruntime.xformers import (
+    _optimize_transformers as optimize_transformers,
+)
+
+# Create a LlamaConfig object with the desired parameters
+_config = LlamaConfig(
+    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
+    architectures=["LlamaForCausalLM"],
+    attention_bias=False,
+    attention_dropout=0.0,
+    bos_token_id=0,
+    eos_token_id=0,
+    hidden_act="silu",
+    hidden_size=2048,
+    initializer_range=0.02,
+    intermediate_size=8192,
+    max_position_embeddings=2048,
+    model_type="llama",
+    num_attention_heads=32,
+    num_hidden_layers=24,
+    num_key_value_heads=32,
+    pretraining_tp=1,
+    rms_norm_eps=1e-05,
+    rope_scaling=None,
+    rope_theta=10000.0,
+    tie_word_embeddings=True,
+    torch_dtype="float32",
+    transformers_version="4.37.2",
+    use_cache=True,
+    vocab_size=49152,
+)
+
+# Dimensions for inputs:
+_batch_size = 1
+_seq_len = 10
+_hidden_size = _config.hidden_size
+_num_attention_heads = _config.num_attention_heads
+dim = _hidden_size // _num_attention_heads
+
+# Generate inputs:
+_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
+_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
+_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
+_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)
+
+# Get model in ONNX format
+# def _get_model(llama_attention_class, with_mask: bool):
+#     model = llama_attention_class(_config, 0)
+#     if with_mask:
+#         inputs = (_hidden_states, _attention_mask, _position_ids)
+#     else:
+#         inputs = (_hidden_states, None, _position_ids)
+#     exported = torch.onnx.export(model, inputs, dynamo=True)
+#     # ORT Transformer optimizations are applied after basic optimization.
+#     onnxscript.optimizer.optimize(exported.model)
+#     return exported.model
+
+class _TestData:
+    def __init__(self, name: str, attention_class, with_mask: bool):
+        self.name = name
+        self.attention_class = attention_class
+        self.with_mask = with_mask
+
+    def get_torch_model(self):
+        return self.attention_class(_config, 0)
+
+    def get_onnx_model(self):
+        model = self.get_torch_model()
+        inputs = self.get_inputs()
+        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
+        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
+        # ORT Transformer optimizations are applied after basic optimization.
+        onnxscript.optimizer.optimize(exported.model)
+        return exported.model
+
+    def get_inputs(self):
+        if self.with_mask:
+            return (_hidden_states, _attention_mask, _position_ids)
+        else:
+            return (_hidden_states, None, _position_ids)
+
+    def get_torch_outputs(self):
+        return self.get_torch_model()(*self.get_inputs())
+
+    def get_ort_inputs(self):
+        inputs = self.get_inputs()
+        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}
+
+_test_cases = [
+    _TestData("attention", modeling_llama.LlamaAttention, False),
+    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
+    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
+    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
+]
+
+_test_case_tuples = [(t,) for t in _test_cases]
+
+def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
+    providers = ["CPUExecutionProvider"]
+    with tempfile.TemporaryDirectory() as temp_dir:
+        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
+        io.save(model, model_path)
+        # Run optimized model
+        session = onnxruntime.InferenceSession(model_path, providers=providers)
+        ort_outputs = session.run(None, inputs)
+
+        for i, (baseline_output, optimized_output) in enumerate(
+            zip(expected_outputs, ort_outputs)
+        ):
+            try:
+                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
+                np.testing.assert_allclose(
+                    baseline_output, optimized_output, rtol=rtol, atol=atol
+                )
+            except AssertionError as e:
+                print(
+                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
+                )
+                raise

+class TestOptimizeTransformers(unittest.TestCase):
+    @parameterized.expand(_test_case_tuples)
+    def test_attention_optimization(self, test_data: _TestData):
+        model = test_data.get_onnx_model()
+        # io.save(model, os.path.join(r"C:\repos\onnxscript\smy\Models", f"{test_data.name}.onnx"))
+        # model.display()
+        # print("======>")
+        optimize_transformers.fuse_rotary_embedding(model)
+        # model.display()
+        op_types = [n.op_type for n in model.graph]
+        self.assertIn("RotaryEmbedding", op_types)
+        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py
new file mode 100644
index 000000000..301606a53
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
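+"""Rules that fuse the Q/K/V-projection + RotaryEmbedding + SDPA subgraph into a
+single MultiHeadAttention op (emitted in the "local" domain); the matched pattern
+is described in detail below."""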
+from __future__ import annotations
+
+from onnxscript.rewriter import pattern
+
+"""
+The MultiHeadAttention pattern:
+
+B: Batch size
+S: Sequence length
+D: input embedding dimension
+H: number of heads
+d_h: head size (usually, D = H * d_h)
+
+Thus, the Q, K, and V weights are usually each of shape (D, D).
+
+For each of Q, K, and V, we have the following pattern:
+    MatMul (Input, W), producing output of shape (B, S, D)
+    Reshape to produce a matrix of shape (B, S, H, d_h)
+    Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)
+
+This is followed by a RotaryEmbedding pattern for Q and K.
+
+The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence).
+
+The dot-product attention is then computed using SDPA.
+
+Finally, the output is transposed and reshaped back to (B, S, D) shape.
+"""
+
+
+def _project_transpose_head(op, input, weight):
+    """Applied to each of Q, K, and V."""
+    projected = op.MatMul(input, weight)
+    # Reshape from (B, S, D) to (B, S, H, D/H)
+    reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True)
+    # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+    transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
+    return transposed
+
+
+def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin):
+    query = _project_transpose_head(op, input, query_weight)
+    query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local")
+    key = _project_transpose_head(op, input, key_weight)
+    key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local")
+    # Transpose last two axes of key_rope to compute dot-product via matmul.
+    key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True)
+    key_reshaped_transposed = op.Transpose(key_reshaped)
+    key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True)
+    value = _project_transpose_head(op, input, value_weight)
+    attention = op.SDPA(
+        query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local"
+    )
+    # Transpose back to (B, S, H, D/H)
+    attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
+    # Reshape back to (B, S, D)
+    attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True)
+    return attention_reshaped, key_rope, value
+
+
+def _multi_head_attention_pattern2(
+    op, input, query_weight, key_weight, value_weight, cos, sin
+):
+    """Variation of first pattern with Reshape omitted."""
+    query = _project_transpose_head(op, input, query_weight)
+    query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local")
+    key = _project_transpose_head(op, input, key_weight)
+    key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local")
+    # Transpose last two axes of key_rope to compute dot-product via matmul.
+    # Reshape omitted here.
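+    # Unlike the first pattern, the transpose of the key is matched directly,
+    # without the surrounding Reshape ops.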
+    key_transposed = op.Transpose(key_rope)
+    # Reshape omitted here
+    value = _project_transpose_head(op, input, value_weight)
+    attention = op.SDPA(
+        query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local"
+    )
+    # Transpose back to (B, S, H, D/H)
+    attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
+    # Reshape back to (B, S, D)
+    attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True)
+    return attention_reshaped, key_rope, value
+
+
+def _multi_head_attention(
+    op,
+    input,
+    query_weight,
+    key_weight,
+    value_weight,
+    cos,
+    sin,
+):
+    # TODO: other checks and concatenation of weights
+    return op.MultiHeadAttention(
+        input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3
+    )
+
+
+_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention)
+_rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention)
+
+mha_rules = pattern.RewriteRuleSet([_rule1, _rule2])
diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
new file mode 100644
index 000000000..b0527111b
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter import _ir_utils, pattern
+
+
+# Pattern to match against
+def _rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype):
+    x_cast = op.Cast(x, to=compute_dtype)
+    x_square = op.Pow(x_cast, 2.0)
+    mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
+    mean_square_plus_epsilon = op.Add(mean_square, epsilon)
+    rms = op.Sqrt(mean_square_plus_epsilon)
+    reciprocal_rms = op.Reciprocal(rms)
+    normalized = op.Mul(x_cast, reciprocal_rms)
+    normalized_cast = op.Cast(normalized, to=target_dtype)
+    return op.Mul(scale, normalized_cast)
+
+
+# Replacement
+def _simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype):
+    epsilon_value = _ir_utils.get_singleton_value(epsilon)
+    if not isinstance(epsilon_value, float):
+        return None
+    source_dtype = x.dtype
+    if source_dtype is None or source_dtype != target_dtype.value:
+        return None
+    return op.SimplifiedLayerNormalization(
+        x,
+        scale,
+        axis=-1,
+        epsilon=epsilon_value,
+        stash_type=compute_dtype.value,
+        _domain="com.microsoft",
+    )
+
+
+_rule = pattern.RewriteRule(_rms_norm_pattern, _simplified_layer_norm)
+rms_normalization_rules = pattern.RewriteRuleSet([_rule])
diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
new file mode 100644
index 000000000..0eadaf280
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
@@ -0,0 +1,44 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
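+"""Rewrites the x*cos + rotate_half(x)*sin subgraph into the com.microsoft
+RotaryEmbedding op, after checking that the two Slice ops split the last axis
+into two equal halves (non-interleaved rotary embedding)."""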
+from __future__ import annotations
+
+from onnxscript.rewriter import _ir_utils, pattern
+
+
+def rotate_half_pattern(op, x, start1, end1, start2, end2):
+    # Slice(input, starts, ends, axes, steps)
+    x1 = op.Slice(x, start1, end1, [3], [1])
+    x2 = op.Slice(x, start2, end2, [3], [1])
+    minus_x2 = op.Neg(x2)
+    rotated_x = op.Concat(minus_x2, x1, axis=-1)
+    return rotated_x
+
+
+def _rotary_embedding_pattern(op, x, cos, sin, start1, end1, start2, end2):
+    return x * cos + rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
+
+
+def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2):
+    # Check that x is being split into two equal halves:
+    start1_val = _ir_utils.get_singleton_value(start1)
+    end1_val = _ir_utils.get_singleton_value(end1)
+    start2_val = _ir_utils.get_singleton_value(start2)
+    end2_val = _ir_utils.get_singleton_value(end2)
+
+    if x is None or x.shape is None or len(x.shape) != 4:
+        return None
+    dim_size = x.shape[3]
+    half_dim_size = dim_size // 2
+    if (
+        start1_val == 0
+        and end1_val == half_dim_size
+        and start2_val == half_dim_size
+        and end2_val >= dim_size
+    ):
+        return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="com.microsoft")
+    return None
+
+
+_rule = pattern.RewriteRule(_rotary_embedding_pattern, _rotary_embedding)
+
+rotary_embedding_rules = pattern.RewriteRuleSet([_rule])
diff --git a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py
new file mode 100644
index 000000000..93d093695
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py
@@ -0,0 +1,85 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import math
+
+from onnxscript.rewriter import _ir_utils, pattern
+
+
+def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask):
+    scaled_query = op.Mul(query, query_scale)
+    scaled_key_transposed = op.Mul(key_transposed, key_scale)
+    attn_score = op.MatMul(scaled_query, scaled_key_transposed)
+    masked_score = op.Add(attn_score, mask)
+    attn_weight = op.Softmax(masked_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+def sdpa(op, query, key_transposed, value, query_scale, key_scale, mask):
+    # Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(dimsize))
+    query_scale_value = _ir_utils.get_singleton_value(query_scale)
+    key_scale_value = _ir_utils.get_singleton_value(key_scale)
+    if not isinstance(query_scale_value, float) or not isinstance(key_scale_value, float):
+        return None
+    scaling_factor = query_scale_value * key_scale_value
+    scaling_factor = 1.0 / (scaling_factor * scaling_factor)
+    # If the dim_size is not statically known, we cannot check if the scale is correct:
+    if query is None or query.shape is None or len(query.shape) < 2:
+        return None
+    dimsize = query.shape[-1]
+    if not isinstance(dimsize, int) or not math.isclose(scaling_factor, dimsize, abs_tol=1e-3):
+        return None
+    return op.SDPA(query, key_transposed, value, mask, _domain="local")
+
+
+def sdpa_pattern2(op, query, key_transposed, value, scale):
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_score = op.Div(attn_score, scale)
+    attn_weight = op.Softmax(scaled_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+def valid_post_scale(scale, query) -> bool:
+    # Checks if scale == sqrt(dimsize)
+    scale_value = _ir_utils.get_singleton_value(scale)
+    if not isinstance(scale_value, float):
+        return False
+    scaling_factor = scale_value * scale_value
+    # If the dim_size is not statically known, we cannot check if the scale is correct:
+    if query is None or query.shape is None or len(query.shape) < 2:
+        return False
+    dimsize = query.shape[-1]
+    if not isinstance(dimsize, int) or not math.isclose(scaling_factor, dimsize, abs_tol=1e-3):
+        return False
+    return True
+
+
+def sdpa2(op, query, key_transposed, value, scale):
+    if not valid_post_scale(scale, query):
+        return None
+    return op.SDPA(query, key_transposed, value, scale, _domain="local")
+
+
+def sdpa_pattern3(op, query, key_transposed, value, scale, mask):
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_score = op.Div(attn_score, scale)
+    masked_score = op.Add(scaled_score, mask)
+    attn_weight = op.Softmax(masked_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+def sdpa3(op, query, key_transposed, value, scale, mask):
+    if not valid_post_scale(scale, query):
+        return None
+    return op.SDPA(query, key_transposed, value, scale, mask, _domain="local")
+
+
+rule = pattern.RewriteRule(sdpa_pattern, sdpa)
+rule2 = pattern.RewriteRule(sdpa_pattern2, sdpa2)
+rule3 = pattern.RewriteRule(sdpa_pattern3, sdpa3)
+
+sdpa_rules = pattern.RewriteRuleSet([rule, rule2, rule3])
diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py
new file mode 100644
index 000000000..38f4281d5
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py
@@ -0,0 +1,38 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter import pattern
+
+
+def _skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
+    skip_sum = op.Add(input, skip)
+    normalized = op.SimplifiedLayerNormalization(
+        skip_sum,
+        gamma,
+        axis=-1,
+        epsilon=epsilon,
+        stash_type=stash_type,
+        _domain="com.microsoft",
+    )
+    return normalized, skip_sum
+
+
+def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
+    normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
+        input,
+        skip,
+        gamma,
+        epsilon=epsilon,
+        stash_type=stash_type,
+        _domain="com.microsoft",
+        _outputs=4,
+    )
+    return normalized, skip_sum
+
+
+_rule = pattern.RewriteRule(
+    _skip_norm_pattern, _skip_normalization, matcher=pattern.SimplePatternMatcher
+)
+
+skip_normalization_rules = pattern.RewriteRuleSet([_rule])