From e115fa6520f5ec608ccfb32f0d573ef5eb2b490a Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 26 Sep 2024 15:38:37 +0000 Subject: [PATCH 1/4] add support for dbrx moe --- auto_fp8/quantize.py | 90 ++++++++++++++++++++++++++++++++++++++++ examples/example_dbrx.py | 26 ++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 examples/example_dbrx.py diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index d327d3a..8bab310 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -7,6 +7,7 @@ import tqdm import transformers from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU from .config import BaseQuantizeConfig @@ -105,6 +106,81 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): ) return output +class FP8DbrxExpertGLU(torch.nn.Module): + def __init__( + self, + original_module: DbrxExpertGLU, + ): + super().__init__() + self.hidden_size = original_module.hidden_size + self.ffn_hidden_size = original_module.ffn_hidden_size + self.moe_num_experts = original_module.moe_num_experts + self.activation_fn = original_module.activation_fn + self.cnt = 0 + self.w1 = torch.empty_like(original_module.w1, dtype=torch.float8_e4m3fn) + self.v1 = torch.empty_like(original_module.v1, dtype=torch.float8_e4m3fn) + self.w2 = torch.empty_like(original_module.w2, dtype=torch.float8_e4m3fn) + + self.w1_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) + self.v1_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) + self.w2_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) + + self.w1_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) + self.v1_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) + self.w2_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) + + self._quantize_weights(original_module) + + def _quantize_weights(self, + original_module: DbrxExpertGLU): + + w1_ = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + v1_ = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + w2_ = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + + ow1_ = original_module.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + ov1_ = original_module.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + ow2_ = original_module.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + + # quantize each expert's weight + for expert_id in range(self.moe_num_experts): + w1_[expert_id], self.w1_weight_scale[expert_id] = per_tensor_quantize(ow1_[expert_id]) + v1_[expert_id], self.v1_weight_scale[expert_id] = per_tensor_quantize(ov1_[expert_id]) + w2_[expert_id], self.w2_weight_scale[expert_id] = per_tensor_quantize(ow2_[expert_id]) + + # register the parameter + self.w1_weight = torch.nn.Parameter(self.w1, requires_grad=False) + self.v1_weight = torch.nn.Parameter(self.v1, requires_grad=False) + self.w2_weight = torch.nn.Parameter(self.w2, requires_grad=False) + + self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale, requires_grad=False) + self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale, requires_grad=False) + self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale, requires_grad=False) + + self.w1_weight_scale = torch.nn.Parameter(self.w1_weight_scale, requires_grad=False) + self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale, requires_grad=False) + 
self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale, requires_grad=False) + + def forward(self, + x: torch.Tensor, + expert_w1: torch.Tensor, + expert_v1: torch.Tensor, + expert_w2: torch.Tensor): + qinput, x_scale = per_tensor_quantize(x) + self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt], x_scale) + self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt], x_scale) + gate_proj = fp8_gemm(qinput, x_scale, expert_w1, self.w1_weight_scale[self.cnt], None, x.dtype) + up_proj = fp8_gemm(qinput, x_scale, expert_v1, self.v1_weight_scale[self.cnt], None, x.dtype) + gate_proj = self.activation_fn(gate_proj) + intermediate_states = gate_proj * up_proj + + qinput, x_scale = per_tensor_quantize(intermediate_states) + self.w2_input_scale[self.cnt] = max(self.w2_input_scale[self.cnt], x_scale) + down_proj = fp8_gemm(qinput, x_scale, expert_w2.t(), self.w2_weight_scale[self.cnt], None, x.dtype) + + self.cnt = ((self.cnt + 1) % self.moe_num_experts) + return down_proj + # Class responsible for quantizing weights class FP8DynamicLinear(torch.nn.Module): @@ -245,6 +321,20 @@ def quantize_weights( del linear cleanup_memory() + # For dbrx moe + for name, module in tqdm.tqdm(named_modules, desc="Quantizing weights"): + if ( + not isinstance(module, DbrxExpertGLU) + or name in quantize_config.ignored_layers + ): + continue + quant_module = FP8DbrxExpertGLU(module) + replace_module(model, name, quant_module) + del module.w1 + del module.v1 + del module.w2 + del module + cleanup_memory() def quantize_activations( model: AutoModelForCausalLM, diff --git a/examples/example_dbrx.py b/examples/example_dbrx.py new file mode 100644 index 0000000..ea4a418 --- /dev/null +++ b/examples/example_dbrx.py @@ -0,0 +1,26 @@ +from datasets import load_dataset +from transformers import AutoTokenizer + +from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig + +pretrained_model_dir = "/models/databrix/" +quantized_model_dir = "./output" + +tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +tokenizer.pad_token = tokenizer.eos_token + +ds = load_dataset("mgoin/ultrachat_2k", split="train_sft") +examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] +examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") + +quantize_config = BaseQuantizeConfig( + quant_method="fp8", + activation_scheme="static", + ignore_patterns=["re:.*lm_head", "re:.*router"], +) + +model = AutoFP8ForCausalLM.from_pretrained( + pretrained_model_dir, quantize_config=quantize_config +) +model.quantize(examples) +model.save_quantized(quantized_model_dir) From e16aa5c3e34195b37d114088a2871f6837eb6728 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 26 Sep 2024 15:54:57 +0000 Subject: [PATCH 2/4] add support dynamic quant --- auto_fp8/quantize.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 8bab310..2d78e26 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -153,19 +153,23 @@ def _quantize_weights(self, self.v1_weight = torch.nn.Parameter(self.v1, requires_grad=False) self.w2_weight = torch.nn.Parameter(self.w2, requires_grad=False) - self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale, requires_grad=False) - self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale, requires_grad=False) - self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale, requires_grad=False) - self.w1_weight_scale = 
torch.nn.Parameter(self.w1_weight_scale, requires_grad=False) self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale, requires_grad=False) self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale, requires_grad=False) + + # For static scheme + def register_input_scale(self): + + self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale, requires_grad=False) + self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale, requires_grad=False) + self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale, requires_grad=False) def forward(self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor): + qinput, x_scale = per_tensor_quantize(x) self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt], x_scale) self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt], x_scale) @@ -343,6 +347,9 @@ def quantize_activations( ): # Replace weight quantizer with a dynamic activation quantizer observer for name, dynamic_quant_linear in model.named_modules(): + if isinstance(dynamic_quant_linear, FP8DbrxExpertGLU): + dynamic_quant_linear.register_input_scale() + continue if ( not isinstance(dynamic_quant_linear, FP8DynamicLinear) or name in quantize_config.ignored_layers From 8e349383df730feeee9f5620fac2915599b4f60e Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 26 Sep 2024 16:01:34 +0000 Subject: [PATCH 3/4] format --- auto_fp8/quantize.py | 104 +++++++++++++++++++++++++++++-------------- 1 file changed, 70 insertions(+), 34 deletions(-) diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 2d78e26..fef2afa 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -117,52 +117,79 @@ def __init__( self.moe_num_experts = original_module.moe_num_experts self.activation_fn = original_module.activation_fn self.cnt = 0 - self.w1 = torch.empty_like(original_module.w1, dtype=torch.float8_e4m3fn) - self.v1 = torch.empty_like(original_module.v1, dtype=torch.float8_e4m3fn) - self.w2 = torch.empty_like(original_module.w2, dtype=torch.float8_e4m3fn) + self.w1 = torch.empty_like(original_module.w1, + dtype=torch.float8_e4m3fn) + self.v1 = torch.empty_like(original_module.v1, + dtype=torch.float8_e4m3fn) + self.w2 = torch.empty_like(original_module.w2, + dtype=torch.float8_e4m3fn) - self.w1_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) - self.v1_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) - self.w2_weight_scale = torch.ones(self.moe_num_experts, dtype=torch.float32) + self.w1_weight_scale = torch.ones(self.moe_num_experts, + dtype=torch.float32) + self.v1_weight_scale = torch.ones(self.moe_num_experts, + dtype=torch.float32) + self.w2_weight_scale = torch.ones(self.moe_num_experts, + dtype=torch.float32) - self.w1_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) - self.v1_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) - self.w2_input_scale = torch.zeros(self.moe_num_experts, dtype=torch.float32) + self.w1_input_scale = torch.zeros(self.moe_num_experts, + dtype=torch.float32) + self.v1_input_scale = torch.zeros(self.moe_num_experts, + dtype=torch.float32) + self.w2_input_scale = torch.zeros(self.moe_num_experts, + dtype=torch.float32) self._quantize_weights(original_module) def _quantize_weights(self, original_module: DbrxExpertGLU): - w1_ = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) - v1_ = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) - w2_ = self.w2.view(self.moe_num_experts, 
self.ffn_hidden_size, self.hidden_size) - - ow1_ = original_module.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) - ov1_ = original_module.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) - ow2_ = original_module.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size) + w1_ = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size) + v1_ = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size) + w2_ = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size) + + ow1_ = original_module.w1.view(self.moe_num_experts, + self.ffn_hidden_size, self.hidden_size) + ov1_ = original_module.v1.view(self.moe_num_experts, + self.ffn_hidden_size, self.hidden_size) + ow2_ = original_module.w2.view(self.moe_num_experts, + self.ffn_hidden_size, self.hidden_size) # quantize each expert's weight for expert_id in range(self.moe_num_experts): - w1_[expert_id], self.w1_weight_scale[expert_id] = per_tensor_quantize(ow1_[expert_id]) - v1_[expert_id], self.v1_weight_scale[expert_id] = per_tensor_quantize(ov1_[expert_id]) - w2_[expert_id], self.w2_weight_scale[expert_id] = per_tensor_quantize(ow2_[expert_id]) + w1_[expert_id], self.w1_weight_scale[expert_id] = \ + per_tensor_quantize(ow1_[expert_id]) + v1_[expert_id], self.v1_weight_scale[expert_id] = \ + per_tensor_quantize(ov1_[expert_id]) + w2_[expert_id], self.w2_weight_scale[expert_id] = \ + per_tensor_quantize(ow2_[expert_id]) # register the parameter - self.w1_weight = torch.nn.Parameter(self.w1, requires_grad=False) - self.v1_weight = torch.nn.Parameter(self.v1, requires_grad=False) - self.w2_weight = torch.nn.Parameter(self.w2, requires_grad=False) + self.w1_weight = torch.nn.Parameter(self.w1, + requires_grad=False) + self.v1_weight = torch.nn.Parameter(self.v1, + requires_grad=False) + self.w2_weight = torch.nn.Parameter(self.w2, + requires_grad=False) - self.w1_weight_scale = torch.nn.Parameter(self.w1_weight_scale, requires_grad=False) - self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale, requires_grad=False) - self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale, requires_grad=False) + self.w1_weight_scale = torch.nn.Parameter(self.w1_weight_scale, + requires_grad=False) + self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale, + requires_grad=False) + self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale, + requires_grad=False) # For static scheme def register_input_scale(self): - self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale, requires_grad=False) - self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale, requires_grad=False) - self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale, requires_grad=False) + self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale, + requires_grad=False) + self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale, + requires_grad=False) + self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale, + requires_grad=False) def forward(self, x: torch.Tensor, @@ -171,17 +198,26 @@ def forward(self, expert_w2: torch.Tensor): qinput, x_scale = per_tensor_quantize(x) - self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt], x_scale) - self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt], x_scale) - gate_proj = fp8_gemm(qinput, x_scale, expert_w1, self.w1_weight_scale[self.cnt], None, x.dtype) - up_proj = fp8_gemm(qinput, x_scale, expert_v1, self.v1_weight_scale[self.cnt], None, x.dtype) + 
self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt],
+                                            x_scale)
+        self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt],
+                                            x_scale)
+        gate_proj = fp8_gemm(qinput, x_scale, expert_w1,
+                             self.w1_weight_scale[self.cnt], None, x.dtype)
+        up_proj = fp8_gemm(qinput, x_scale, expert_v1,
+                           self.v1_weight_scale[self.cnt], None, x.dtype)
         gate_proj = self.activation_fn(gate_proj)
         intermediate_states = gate_proj * up_proj
 
         qinput, x_scale = per_tensor_quantize(intermediate_states)
-        self.w2_input_scale[self.cnt] = max(self.w2_input_scale[self.cnt], x_scale)
-        down_proj = fp8_gemm(qinput, x_scale, expert_w2.t(), self.w2_weight_scale[self.cnt], None, x.dtype)
+        self.w2_input_scale[self.cnt] = max(self.w2_input_scale[self.cnt],
+                                            x_scale)
+        down_proj = fp8_gemm(qinput, x_scale, expert_w2.t(),
+                             self.w2_weight_scale[self.cnt], None, x.dtype)
 
+        # DbrxExpert's forward does not pass the expert id when calling
+        # DbrxExpertGLU's forward, so self.cnt tracks which expert is
+        # currently being processed.
         self.cnt = ((self.cnt + 1) % self.moe_num_experts)
         return down_proj
 

From 19c2734ef57ce290e81b1428dec6e92a1d9a9deb Mon Sep 17 00:00:00 2001
From: charlifu
Date: Thu, 26 Sep 2024 16:06:16 +0000
Subject: [PATCH 4/4] naming

---
 examples/example_dbrx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/example_dbrx.py b/examples/example_dbrx.py
index ea4a418..3469976 100644
--- a/examples/example_dbrx.py
+++ b/examples/example_dbrx.py
@@ -3,8 +3,8 @@
 
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
-pretrained_model_dir = "/models/databrix/"
-quantized_model_dir = "./output"
+pretrained_model_dir = "databricks/dbrx-instruct"
+quantized_model_dir = "dbrx-instruct-fp8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token
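
Note on the helpers used above (illustrative, not part of the diffs): FP8DbrxExpertGLU builds on per_tensor_quantize() and fp8_gemm(), which already exist in auto_fp8/quantize.py; only their signatures appear in the hunk context. The sketch below restates the scaling convention the new module appears to assume, namely that the returned scale is a dequantization factor, so the per-expert weight scales and the running max over activation scales tracked in forward() can be stored directly as the serialized FP8 scales. The function name and body here are a minimal sketch under that assumption, not the repository's exact implementation.

    import torch

    def per_tensor_quantize_sketch(tensor: torch.Tensor):
        # Symmetric per-tensor quantization to float8_e4m3fn.
        # Returns (fp8_tensor, scale) where `scale` is the dequantization
        # factor: tensor ~= fp8_tensor.to(torch.float32) * scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = tensor.abs().max().clamp(min=1e-12)
        scale = amax / finfo.max
        qtensor = (tensor / scale).clamp(min=finfo.min,
                                         max=finfo.max).to(torch.float8_e4m3fn)
        return qtensor, scale.to(torch.float32)

    # Under the same convention, fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype)
    # is assumed to compute (A * A_scale) @ (B * B_scale).t() in out_dtype, which
    # is why forward() passes expert_w1/expert_v1 as-is but expert_w2.t() for the
    # down projection, mirroring DbrxExpertGLU's x @ w1.t() and x @ w2 matmuls.

With the static activation scheme, the calibration passes run inside quantize_activations() exercise forward() above, which keeps a running max of the per-expert activation scales; register_input_scale() then wraps those running maxima as parameters so they are saved alongside the per-expert weight scales by save_quantized().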