diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index d327d3a..fef2afa 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -7,6 +7,7 @@
 import tqdm
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU
 
 from .config import BaseQuantizeConfig
 
@@ -105,6 +106,121 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         )
     return output
 
+class FP8DbrxExpertGLU(torch.nn.Module):
+    def __init__(
+        self,
+        original_module: DbrxExpertGLU,
+    ):
+        super().__init__()
+        self.hidden_size = original_module.hidden_size
+        self.ffn_hidden_size = original_module.ffn_hidden_size
+        self.moe_num_experts = original_module.moe_num_experts
+        self.activation_fn = original_module.activation_fn
+        self.cnt = 0
+        self.w1 = torch.empty_like(original_module.w1,
+                                   dtype=torch.float8_e4m3fn)
+        self.v1 = torch.empty_like(original_module.v1,
+                                   dtype=torch.float8_e4m3fn)
+        self.w2 = torch.empty_like(original_module.w2,
+                                   dtype=torch.float8_e4m3fn)
+
+        self.w1_weight_scale = torch.ones(self.moe_num_experts,
+                                          dtype=torch.float32)
+        self.v1_weight_scale = torch.ones(self.moe_num_experts,
+                                          dtype=torch.float32)
+        self.w2_weight_scale = torch.ones(self.moe_num_experts,
+                                          dtype=torch.float32)
+
+        self.w1_input_scale = torch.zeros(self.moe_num_experts,
+                                          dtype=torch.float32)
+        self.v1_input_scale = torch.zeros(self.moe_num_experts,
+                                          dtype=torch.float32)
+        self.w2_input_scale = torch.zeros(self.moe_num_experts,
+                                          dtype=torch.float32)
+
+        self._quantize_weights(original_module)
+
+    def _quantize_weights(self,
+                          original_module: DbrxExpertGLU):
+
+        w1_ = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
+                           self.hidden_size)
+        v1_ = self.v1.view(self.moe_num_experts, self.ffn_hidden_size,
+                           self.hidden_size)
+        w2_ = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
+                           self.hidden_size)
+
+        ow1_ = original_module.w1.view(self.moe_num_experts,
+                                       self.ffn_hidden_size, self.hidden_size)
+        ov1_ = original_module.v1.view(self.moe_num_experts,
+                                       self.ffn_hidden_size, self.hidden_size)
+        ow2_ = original_module.w2.view(self.moe_num_experts,
+                                       self.ffn_hidden_size, self.hidden_size)
+
+        # Quantize each expert's weights to FP8 with per-tensor scales
+        for expert_id in range(self.moe_num_experts):
+            w1_[expert_id], self.w1_weight_scale[expert_id] = \
+                per_tensor_quantize(ow1_[expert_id])
+            v1_[expert_id], self.v1_weight_scale[expert_id] = \
+                per_tensor_quantize(ov1_[expert_id])
+            w2_[expert_id], self.w2_weight_scale[expert_id] = \
+                per_tensor_quantize(ow2_[expert_id])
+
+        # Register the quantized weights and their scales as parameters
+        self.w1_weight = torch.nn.Parameter(self.w1,
+                                            requires_grad=False)
+        self.v1_weight = torch.nn.Parameter(self.v1,
+                                            requires_grad=False)
+        self.w2_weight = torch.nn.Parameter(self.w2,
+                                            requires_grad=False)
+
+        self.w1_weight_scale = torch.nn.Parameter(self.w1_weight_scale,
+                                                  requires_grad=False)
+        self.v1_weight_scale = torch.nn.Parameter(self.v1_weight_scale,
+                                                  requires_grad=False)
+        self.w2_weight_scale = torch.nn.Parameter(self.w2_weight_scale,
+                                                  requires_grad=False)
+
+    # For the static activation scheme
+    def register_input_scale(self):
+
+        self.w1_input_scale = torch.nn.Parameter(self.w1_input_scale,
+                                                 requires_grad=False)
+        self.v1_input_scale = torch.nn.Parameter(self.v1_input_scale,
+                                                 requires_grad=False)
+        self.w2_input_scale = torch.nn.Parameter(self.w2_input_scale,
+                                                 requires_grad=False)
+
+    def forward(self,
+                x: torch.Tensor,
+                expert_w1: torch.Tensor,
+                expert_v1: torch.Tensor,
+                expert_w2: torch.Tensor):
+
+        qinput, x_scale = per_tensor_quantize(x)
+        self.w1_input_scale[self.cnt] = max(self.w1_input_scale[self.cnt],
+                                            x_scale)
+        self.v1_input_scale[self.cnt] = max(self.v1_input_scale[self.cnt],
+                                            x_scale)
+        gate_proj = fp8_gemm(qinput, x_scale, expert_w1,
+                             self.w1_weight_scale[self.cnt], None, x.dtype)
+        up_proj = fp8_gemm(qinput, x_scale, expert_v1,
+                           self.v1_weight_scale[self.cnt], None, x.dtype)
+        gate_proj = self.activation_fn(gate_proj)
+        intermediate_states = gate_proj * up_proj
+
+        qinput, x_scale = per_tensor_quantize(intermediate_states)
+        self.w2_input_scale[self.cnt] = max(self.w2_input_scale[self.cnt],
+                                            x_scale)
+        down_proj = fp8_gemm(qinput, x_scale, expert_w2.t(),
+                             self.w2_weight_scale[self.cnt], None, x.dtype)
+
+        # Since DbrxExperts' forward does not pass the expert id when it
+        # calls DbrxExpertGLU's forward, use self.cnt to track which
+        # expert is currently being used.
+        self.cnt = (self.cnt + 1) % self.moe_num_experts
+        return down_proj
+
 
 # Class responsible for quantizing weights
 class FP8DynamicLinear(torch.nn.Module):
@@ -245,6 +361,20 @@ def quantize_weights(
         del linear
     cleanup_memory()
+    # For DBRX MoE (DbrxExpertGLU) modules
+    for name, module in tqdm.tqdm(named_modules, desc="Quantizing weights"):
+        if (
+            not isinstance(module, DbrxExpertGLU)
+            or name in quantize_config.ignored_layers
+        ):
+            continue
+        quant_module = FP8DbrxExpertGLU(module)
+        replace_module(model, name, quant_module)
+        del module.w1
+        del module.v1
+        del module.w2
+        del module
+    cleanup_memory()
 
 
 def quantize_activations(
     model: AutoModelForCausalLM,
@@ -253,6 +383,9 @@
 ):
     # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
+        if isinstance(dynamic_quant_linear, FP8DbrxExpertGLU):
+            dynamic_quant_linear.register_input_scale()
+            continue
         if (
             not isinstance(dynamic_quant_linear, FP8DynamicLinear)
             or name in quantize_config.ignored_layers
diff --git a/examples/example_dbrx.py b/examples/example_dbrx.py
new file mode 100644
index 0000000..3469976
--- /dev/null
+++ b/examples/example_dbrx.py
@@ -0,0 +1,26 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
+
+pretrained_model_dir = "databricks/dbrx-instruct"
+quantized_model_dir = "dbrx-instruct-fp8"
+
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
+examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
+examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8",
+    activation_scheme="static",
+    ignore_patterns=["re:.*lm_head", "re:.*router"],
+)
+
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
+model.quantize(examples)
+model.save_quantized(quantized_model_dir)