Add support for dbrx moe #45

Open · wants to merge 4 commits into base: main
auto_fp8/quantize.py (133 additions, 0 deletions)
@@ -7,6 +7,7 @@
import tqdm
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU

from .config import BaseQuantizeConfig

@@ -105,6 +106,121 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    )
    return output

class FP8DbrxExpertGLU(torch.nn.Module):
    """DbrxExpertGLU with FP8 (e4m3) expert weights and per-expert, per-tensor scales."""

    def __init__(
        self,
        original_module: DbrxExpertGLU,
    ):
        super().__init__()
        self.hidden_size = original_module.hidden_size
        self.ffn_hidden_size = original_module.ffn_hidden_size
        self.moe_num_experts = original_module.moe_num_experts
        self.activation_fn = original_module.activation_fn
        # Tracks which expert is currently being called (see forward()).
        self.cnt = 0

        self.w1 = torch.empty_like(original_module.w1,
                                   dtype=torch.float8_e4m3fn)
        self.v1 = torch.empty_like(original_module.v1,
                                   dtype=torch.float8_e4m3fn)
        self.w2 = torch.empty_like(original_module.w2,
                                   dtype=torch.float8_e4m3fn)

        self.w1_weight_scale = torch.ones(self.moe_num_experts,
                                          dtype=torch.float32)
        self.v1_weight_scale = torch.ones(self.moe_num_experts,
                                          dtype=torch.float32)
        self.w2_weight_scale = torch.ones(self.moe_num_experts,
                                          dtype=torch.float32)

        self.w1_input_scale = torch.zeros(self.moe_num_experts,
                                          dtype=torch.float32)
        self.v1_input_scale = torch.zeros(self.moe_num_experts,
                                          dtype=torch.float32)
        self.w2_input_scale = torch.zeros(self.moe_num_experts,
                                          dtype=torch.float32)

        self._quantize_weights(original_module)

    def _quantize_weights(self, original_module: DbrxExpertGLU):
        # View the stacked (moe_num_experts * ffn_hidden_size, hidden_size)
        # weights as one 3D tensor per projection.
        w1_ = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
                           self.hidden_size)
        v1_ = self.v1.view(self.moe_num_experts, self.ffn_hidden_size,
                           self.hidden_size)
        w2_ = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
                           self.hidden_size)

        ow1_ = original_module.w1.view(self.moe_num_experts,
                                       self.ffn_hidden_size, self.hidden_size)
        ov1_ = original_module.v1.view(self.moe_num_experts,
                                       self.ffn_hidden_size, self.hidden_size)
        ow2_ = original_module.w2.view(self.moe_num_experts,
                                       self.ffn_hidden_size, self.hidden_size)

        # Quantize each expert's weights with a per-tensor scale.
        for expert_id in range(self.moe_num_experts):
            w1_[expert_id], self.w1_weight_scale[expert_id] = \
                per_tensor_quantize(ow1_[expert_id])
            v1_[expert_id], self.v1_weight_scale[expert_id] = \
                per_tensor_quantize(ov1_[expert_id])
            w2_[expert_id], self.w2_weight_scale[expert_id] = \
                per_tensor_quantize(ow2_[expert_id])

        # Register the quantized weights and their scales as parameters.
        self.w1_weight = torch.nn.Parameter(self.w1, requires_grad=False)
        self.v1_weight = torch.nn.Parameter(self.v1, requires_grad=False)
        self.w2_weight = torch.nn.Parameter(self.w2, requires_grad=False)

        self.w1_weight_scale = torch.nn.Parameter(self.w1_weight_scale,
                                                  requires_grad=False)
        self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale,
                                                  requires_grad=False)
        self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale,
                                                  requires_grad=False)

    # For the static activation scheme: register the calibrated input scales.
    def register_input_scale(self):
        self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale,
                                                 requires_grad=False)
        self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale,
                                                 requires_grad=False)
        self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale,
                                                 requires_grad=False)

    def forward(self,
                x: torch.Tensor,
                expert_w1: torch.Tensor,
                expert_v1: torch.Tensor,
                expert_w2: torch.Tensor):
        qinput, x_scale = per_tensor_quantize(x)
        self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt],
                                            x_scale)
        self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt],
                                            x_scale)
        gate_proj = fp8_gemm(qinput, x_scale, expert_w1,
                             self.w1_weight_scale[self.cnt], None, x.dtype)
        up_proj = fp8_gemm(qinput, x_scale, expert_v1,
                           self.v1_weight_scale[self.cnt], None, x.dtype)
        gate_proj = self.activation_fn(gate_proj)
        intermediate_states = gate_proj * up_proj

        qinput, x_scale = per_tensor_quantize(intermediate_states)
        self.w2_input_scale[self.cnt] = max(self.w2_input_scale[self.cnt],
                                            x_scale)
        down_proj = fp8_gemm(qinput, x_scale, expert_w2.t(),
                             self.w2_weight_scale[self.cnt], None, x.dtype)

        # Since DbrxExperts' forward does not pass the expert id when calling
        # DbrxExpertGLU's forward, use self.cnt to track which expert is
        # currently being used.
        self.cnt = (self.cnt + 1) % self.moe_num_experts
        return down_proj


# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
@@ -245,6 +361,20 @@ def quantize_weights(
        del linear
    cleanup_memory()

    # For DBRX MoE: replace each DbrxExpertGLU with its FP8 counterpart
    for name, module in tqdm.tqdm(named_modules, desc="Quantizing weights"):
        if (
            not isinstance(module, DbrxExpertGLU)
            or name in quantize_config.ignored_layers
        ):
            continue
        quant_module = FP8DbrxExpertGLU(module)
        replace_module(model, name, quant_module)
        del module.w1
        del module.v1
        del module.w2
        del module
    cleanup_memory()

def quantize_activations(
    model: AutoModelForCausalLM,
@@ -253,6 +383,9 @@ def quantize_activations(
):
    # Replace weight quantizer with a dynamic activation quantizer observer
    for name, dynamic_quant_linear in model.named_modules():
        if isinstance(dynamic_quant_linear, FP8DbrxExpertGLU):
            dynamic_quant_linear.register_input_scale()
            continue
        if (
            not isinstance(dynamic_quant_linear, FP8DynamicLinear)
            or name in quantize_config.ignored_layers
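
Both the per-expert weight loop above and the running input-scale updates rely on the per_tensor_quantize helper defined earlier in auto_fp8/quantize.py. As a rough orientation for readers of this diff, a minimal sketch of a symmetric per-tensor FP8 (e4m3) quantizer of that kind is shown below; the function name and details are illustrative, not the exact helper from the repository.

import torch

def per_tensor_quantize_sketch(tensor: torch.Tensor):
    # One symmetric scale for the whole tensor, clamped into the e4m3 range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(
        torch.float8_e4m3fn)
    # Return the dequantization scale (reciprocal of the quantization scale).
    return qweight, scale.float().reciprocal()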
examples/example_dbrx.py (26 additions, 0 deletions)
@@ -0,0 +1,26 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "databricks/dbrx-instruct"
quantized_model_dir = "dbrx-instruct-fp8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head", "re:.*router"],
)

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
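
As a quick sanity check that the DBRX expert blocks were actually swapped before saving, something like the following could be appended to the script (a sketch; it assumes the wrapped Hugging Face model is reachable as model.model, which may differ between AutoFP8 versions):

from auto_fp8.quantize import FP8DbrxExpertGLU

# Count how many DbrxExpertGLU modules were replaced by the FP8 wrapper.
num_fp8_experts = sum(
    isinstance(m, FP8DbrxExpertGLU) for _, m in model.model.named_modules()
)
print(f"Replaced {num_fp8_experts} expert modules with FP8DbrxExpertGLU")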