[DRAFT] First version of fusion optimizations for transformers #1938

Draft · wants to merge 25 commits into main · Changes from 22 commits
15 changes: 13 additions & 2 deletions onnxscript/optimizer/__init__.py
@@ -4,13 +4,16 @@

import onnx

+import onnxscript.optimizer._constant_folding as constant_folding
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
+import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
from onnxscript import ir
-from onnxscript.optimizer._constant_folding import basic_constant_propagation
-from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import optimize_ir
from onnxscript.optimizer._remove_unused import remove_unused_nodes

+basic_constant_propagation = constant_folding.basic_constant_propagation
+fold_constants_ir = constant_folding.fold_constants


def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
    if isinstance(model, ir.Model):
@@ -19,8 +22,16 @@ def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
        return legacy_optimizer.optimize(model, *args, **kwargs)


+def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs):
+    if isinstance(model, ir.Model):
+        return constant_folding.fold_constants(model, *args, **kwargs)
+    else:
+        return legacy_constant_folding.fold_constants(model, *args, **kwargs)


__all__ = [
    "fold_constants",
+    "fold_constants_ir",
    "remove_unused_nodes",
    "optimize",
    "optimize_ir",
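For context, a minimal sketch of how the new dispatch is intended to be used; the model path is hypothetical, and ir.serde.deserialize_model is assumed as the proto-to-IR conversion entry point:

import onnx

import onnxscript.ir as ir
from onnxscript import optimizer

proto = onnx.load("model.onnx")  # hypothetical path

# A ModelProto input dispatches to the legacy implementation.
optimizer.fold_constants(proto)

# An ir.Model input dispatches to the new IR-based implementation
# (also reachable directly as fold_constants_ir).
ir_model = ir.serde.deserialize_model(proto)
optimizer.fold_constants(ir_model)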
44 changes: 37 additions & 7 deletions onnxscript/optimizer/_constant_folding.py
@@ -13,6 +13,7 @@

import numpy as np
import onnx
+import onnx.helper
import onnx.reference.ops

import onnxscript.ir as ir
@@ -434,25 +435,54 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
@register("Dropout", version=(12, None))
def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Dropout by Identity when applicable."""
if len(node.outputs) != 1:
# If output mask is requested, optimization is more complex.
# TODO: handle this case. But unlikely to be needed in practice.
return None

def optimized_dropout():
input = node.inputs[0]
output = op.Identity(input)
if len(node.outputs) == 1:
return output
else:
true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])
[Review thread on the true_tensor line]

Collaborator: Is this IR? If so:

Suggested change:
-            true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])
+            true_tensor = ir.tensor([True])

gramalingam (Collaborator, Author, Nov 11, 2024): Thanks. But when I look at the signature here, it is not clear this is supported. The example illustrates it, though. I see it eventually calls the np.array constructor if nothing else works, so I understand it now.

Collaborator: Good catch. We can update the signature.

justinchuby (Collaborator, Nov 15, 2024): This is actually covered by npt.ArrayLike (the first).

gramalingam (Collaborator, Author): Actually, I tried and it failed, rejecting a list. BTW, I have moved the independent parts of this PR into a separate PR: #1947

Collaborator: Hmm. I need to fix that then.
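For reference, a small sketch of the two construction paths discussed in this thread; ir.tensor is the onnxscript.ir helper under discussion, and since the plain-list overload was reported to fail at the time, the numpy-array form is shown as the workaround:

import numpy as np
import onnx.helper

import onnxscript.ir as ir

# Current PR code: build a TensorProto via onnx.helper.
true_proto = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])

# Suggested IR-native form. Per the thread, ir.tensor([True]) was
# rejected at the time, so a numpy array is passed instead.
true_ir = ir.tensor(np.array([True]))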

+            input_shape = op.Shape(input)
+            mask = op.ConstantOfShape(input_shape, value=true_tensor)
+            return output, mask

    inputs = node.inputs
    if (len(inputs) <= 2) or inputs[2] is None:
        # No training_mode specified:
-        return op.Identity(inputs[0])
+        return optimized_dropout()
    if _get_bool_value(inputs[2]) is False:
        # training_mode is False: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
    ratio = _get_numpy_value(inputs[1])
    if ratio is None:
        return None
    if ratio.size != 1:  # Only scalar dropout ratio is supported.
        return None
    if ratio.item() == 0:
        # dropout ratio is 0: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
    return None


@register("Expand")
def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace an Expand node by Identity when applicable."""
if len(node.inputs) != 2:
return None
if (input := node.inputs[0]) is None:
return None
if (input_shape := input.shape) is None:
# Input shape is not known.
return None
if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
# Target shape is not known.
return None
if expanded_shape.ndim != 1:
# Target shape must be a 1D tensor. Erroneous model.
return None
if input_shape.dims == tuple(expanded_shape.tolist()):
return op.Identity(input)
return None


16 changes: 16 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -433,6 +433,22 @@ def test_concat_identity(self):
        self.assertEqual(len(optimized.graph.node), 1)
        self.assertEqual(optimized.graph.node[0].op_type, "Identity")

    def test_expand_identity(self):
        if not self.using_ir:
            self.skipTest("New optimizations not supported for legacy optimizer")
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[128, 256] x) => (float[128, 256] z)
            {
                shape = Constant <value_ints=[128, 256]> ()
                z = Expand (x, shape)
            }
            """
        )
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node[-1].op_type, "Identity")
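A hypothetical companion test for the new Dropout mask branch, in the same style as the tests above; it assumes the rewrite in _constant_folding.py replaces an inference-mode Dropout with Identity plus Shape/ConstantOfShape when the mask output is requested (test name and assertions are illustrative):

    def test_dropout_identity_with_mask(self):
        if not self.using_ir:
            self.skipTest("New optimizations not supported for legacy optimizer")
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] y, bool[N] mask)
            {
                y, mask = Dropout (x)
            }
            """
        )
        optimized = self._fold(model)
        op_types = [n.op_type for n in optimized.graph.node]
        # The Dropout itself must be gone; y is produced by an Identity.
        self.assertNotIn("Dropout", op_types)
        self.assertIn("Identity", op_types)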


if __name__ == "__main__":
    unittest.main()
7 changes: 4 additions & 3 deletions onnxscript/optimizer/_inliner.py
@@ -305,6 +305,7 @@ def inline_calls_in(self, graph: ir.Graph) -> None:

def inline(model: ir.Model) -> None:
    """Inline all function calls (recursively) in the model."""
-    inliner = _Inliner(model)
-    inliner.inline_calls_in(model.graph)
-    model.functions.clear()
+    if model.functions:
+        inliner = _Inliner(model)
+        inliner.inline_calls_in(model.graph)
+        model.functions.clear()
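A small illustration of what the new guard buys (the call site is hypothetical): inline is now a no-op for models without functions, so callers can invoke it unconditionally.

import onnxscript.ir as ir
from onnxscript.optimizer._inliner import inline

def prepare(model: ir.Model) -> None:
    inline(model)  # no-op when model.functions is empty
    inline(model)  # idempotent: a second call is also safe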
23 changes: 23 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,8 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation

@@ -11,3 +13,24 @@ def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
    if node is not None:
        basic_constant_propagation([node])
    return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        try:
            return const_value.numpy()
        except FileNotFoundError:
            # External data is not available.
            return None
    return None


def get_singleton_value(val: ir.Value | None):
    """Returns the element of a single-element tensor constant value, or None otherwise."""
    np_val = get_numpy_value(val)
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None
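A hypothetical use of these helpers inside a rewrite-rule condition (the function and parameter names are illustrative):

from onnxscript.rewriter._ir_utils import get_singleton_value

def is_zero_ratio(ratio) -> bool:  # ratio: ir.Value | None
    # True only when ratio is a known single-element constant equal to 0;
    # non-constant or multi-element values yield None and fail the check.
    return get_singleton_value(ratio) == 0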
17 changes: 17 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
    mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
    rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
    rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
    skip_normalization_rules as skip_normalization_rules,
)
38 changes: 38 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

[lintrunner warning (RUFF-FORMAT/format): run lintrunner -a to apply this patch.]

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
    mha_rules,
    rms_normalization_rules,
    rotary_embedding_rules,
    sdpa_rules,
    skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {count}")


def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    def apply(rulename: str, rule):
        count = rule.apply_to_model(irmodel, verbose=verbose)
        print(f"{rulename} count: {count}")

    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
    remove_unused_nodes(irmodel)

    apply("RMS Normalization", rms_normalization_rules)
    apply("Skip Normalization", skip_normalization_rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    apply("SDPA-Attention", sdpa_rules)
    apply("RotaryEmbedding", rotary_embedding_rules)
    apply("Multi-Head-Attention", mha_rules)

    remove_unused_nodes(irmodel)
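A minimal driver sketch for these entry points; the file paths are hypothetical, and io.load is assumed to be the loading counterpart of the io.save used in the test below:

import onnxscript.ir._io as io
from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers

model = io.load("llama_attention.onnx")  # hypothetical exported model
_optimize_transformers.optimize(model, verbose=0)
io.save(model, "llama_attention_fused.onnx")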
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

[lintrunner warning (RUFF-FORMAT/format): run lintrunner -a to apply this patch.]

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
#     model = llama_attention_class(_config, 0)
#     if with_mask:
#         inputs = (_hidden_states, _attention_mask, _position_ids)
#     else:
#         inputs = (_hidden_states, None, _position_ids)
#     exported = torch.onnx.export(model, inputs, dynamo=True)
#     # ORT Transformer optimizations are applied after basic optimization.
#     onnxscript.optimizer.optimize(exported.model)
#     return exported.model

class _TestData:
    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask

    def get_torch_model(self):
        return self.attention_class(_config, 0)

    def get_onnx_model(self):
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        else:
            return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        inputs = self.get_inputs()
        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}

_test_cases = [
    _TestData("attention", modeling_llama.LlamaAttention, False),
    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
]

_test_case_tuples = [(t,) for t in _test_cases]

def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run optimized model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise

class TestOptimizeTransformers(unittest.TestCase):
    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        model = test_data.get_onnx_model()
        # model.display()
        # print("======>")
        optimize_transformers.fuse_rotary_embedding(model)
        # model.display()
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
    unittest.main()