Mixtral 8*22B Quantization Failed with 2 issues #35
Comments
Hey @qingquansong, thanks for trying out llm-compressor! To address your first issue with SmoothQuantModifier: we recently added a default mapping that gets used if none is provided (see https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/smoothquant/base.py#L16). You'll need to set the mappings argument in the modifier manually to work with Mixtral. As for your OOM issue, the mappings set by device_map="auto" do not take into account the hessians allocated during GPTQ; to reduce the memory usage you can add the device-map setup sketched below. Additionally, we have support for CPU offloading currently in PR #34, which also adds support for accounting for the GPTQ and quantization memory needs on model load.
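A minimal sketch of that device-map setup, assuming the calculate_offload_device_map helper from llmcompressor (the helper name, its signature, and the SparseAutoModelForCausalLM class may differ between versions, so treat this as illustrative rather than the exact snippet from the original comment):

```python
import torch

from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

MODEL_ID = "mistralai/Mixtral-8x22B-Instruct-v0.1"  # illustrative model stub

# Build a device map that reserves GPU memory for the GPTQ hessians,
# instead of device_map="auto", which only budgets for the weights.
device_map = calculate_offload_device_map(
    MODEL_ID,
    reserve_for_hessians=True,
    num_gpus=8,
    torch_dtype=torch.bfloat16,
)

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
)
```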
@Satrat Thank you for the response! One quick question I have is: for the mappings, I saw the default one is defined as:
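(quoting the default from the linked base.py as of this discussion; check the file itself for the authoritative, current definition)

```python
DEFAULT_SMOOTHQUANT_MAPPINGS = [
    [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
    [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"],
]
```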
How are the two levels of the mappings decided? I can understand why we put q/k/v together, but what confuses me is why "re:.*input_layernorm" has to sit as the second element of that first list (based on the block structure) while the others go into a separate list. Could we just do [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm", ["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] without separating out a second list on the second line? And, using Mixtral as an example, should we do the following, or do we need to separate or merge some of the lists here?
And it seems there's another issue even if I do:
Thank you!
The mappings should be of the form [[weights_to_balance...], activation_to_smooth]. For instance, in the default we want to smooth the input activations that feed into q/k/v proj, and these activations come out of input_layernorm. Same for the second item in the list: we want to smooth the activations coming into gate/up proj, and they come out of post_attention_layernorm. The diagram on page 4 of https://arxiv.org/pdf/2211.10438 is useful for visualizing this. The Mixtral example suggested would not work, as the mapping must be formatted in groups of 2, where activation_to_smooth is a single element and [weights_to_balance] can be a list.
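To illustrate that shape only (not a vetted recipe), a Mixtral-style mapping that keeps the groups-of-2 structure could look like the sketch below, assuming the Hugging Face MixtralForCausalLM module names (self_attn q/k/v projections, block_sparse_moe experts with w1/w2/w3, and the usual input/post-attention layernorms):

```python
# Each entry is [ [weights_to_balance, ...], activation_to_smooth ]
mixtral_mappings = [
    # the attention projections read the output of input_layernorm
    [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
    # the experts' w1/w3 projections read the output of post_attention_layernorm
    # (the MoE router, block_sparse_moe.gate, reads the same activation and
    # could optionally be added to this group)
    [["re:.*w1", "re:.*w3"], "re:.*post_attention_layernorm"],
]
```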
@Satrat Thank you!
Interesting, this is an issue with the recipe parsing. I was able to make a barebones example that reproduced the issue and filed it as a bug here: #37. To unblock, you can try defining the recipe as a string or YAML file rather than programmatically (explained in the linked issue). I also checked with our research team, and running SmoothQuant on Mixtral is not something we have tried before, so unfortunately I don't have a specific recipe to suggest. I think the general flow of matching up each linear layer with the activation before it is a good place to start. There is also no requirement that every activation be smoothed, so for instance you could try only smoothing into q/k/v/o and gate and leaving the experts as is.
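A sketch of that string-based workaround, with illustrative stage names, smoothing strength, and scheme (not a validated Mixtral recipe; the W8A8 preset shorthand assumes it is available in your llm-compressor version):

```python
# Pass this string directly as `recipe=` to oneshot() instead of building
# modifier objects programmatically.
recipe = """
quant_stage:
    quant_modifiers:
        SmoothQuantModifier:
            smoothing_strength: 0.8
            # mappings: <Mixtral mappings as sketched earlier>
        GPTQModifier:
            targets: ["Linear"]
            ignore: ["lm_head"]
            scheme: "W8A8"
"""
```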
Makes sense. Thank you for the quick response! Let me explore a bit, and I can share some findings/experience or provide a recipe for Mixtral later if needed.
Hey @Satrat, I'm able to make SmoothQuant and GPTQ work for Mixtral 8*7b with the following recipe (shared here for other users as a reference), and I'm waiting on 8*22b +
@qingquansong great! One correction though: SmoothQuant should always come before the GPTQModifier. You won't see an error applying it afterwards, but you won't get any benefit. The intended usage is to run SmoothQuant to "squish" the outliers, then run quantization/GPTQ afterwards on the squashed dynamic range.
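In other words, whether the recipe is written in YAML or built programmatically, SmoothQuant is listed first so it runs first. A minimal programmatic sketch with illustrative arguments (modulo the recipe-parsing issue mentioned earlier in this thread):

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),                        # squish outliers first
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),  # then quantize
]
```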
Oh @Satrat, yeah, you're absolutely right, sorry, I forgot that. Btw, I faced some issues related to the Hessian matrix computation being ill-conditioned. I think it's partially due to (1) the small number of samples I used, (2) the dampening_frac, and (3) I think adding act order would help a lot, based on our previous analysis of GPTQ. I remember there's a PR (branch) for checking it; do we plan to add it soon? It should be fairly easy.
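For anyone hitting the same ill-conditioning, raising the GPTQ dampening is the usual first knob to try; a minimal sketch, assuming the dampening_frac argument on GPTQModifier (the value here is purely illustrative):

```python
from llmcompressor.modifiers.quantization import GPTQModifier

# A larger dampening fraction adds more diagonal regularization to the Hessian,
# which helps when it is ill-conditioned due to limited calibration data.
gptq = GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"], dampening_frac=0.1)
```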
Yeah, for an MoE model especially, the amount of calibration data will be an issue. As for activation reordering, it's currently being worked on: @bfineran, do we have an estimate on when that PR will be ready?
Hey @Satrat, I think there's another issue here. I followed the sample yaml at llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py (lines 60 to 61 in 29cb10d), but it should be "group" or "channel", I guess? Or maybe we should add an option to accept wNa16 for the first case, where input_quant is None for some layers, so we can keep the strategy as "tensor" for the w8a8 case. Another thing that is weird: when I use the previous yaml file to quantize the model and then load the layers, it gives me all ... Changing to ...
I think I probably know the reason and want to confirm with you (I'm also testing it myself now). Several things to change (a sketch of a config reflecting these points is below):
(1) We cannot use the group option for w8a8; only channel and tensor are available, so we need to remove group_size and change to channel or tensor.
(2) input_activations should be changed to 8 bits rather than None.
(3) We need to add lm_head to the ignore list for GPTQ (SmoothQuant is fine).
(4) I added gate to the mappings, but since it's only used for MoE routing, maybe it's better not to add it, to protect performance; still, since the hidden states are shared for q/k/v and gate, I've added it for now to make sure everything is 8 bits.
(5) I'm not sure whether output_activations needs to be 8 bits, but I don't think we need it.
Llama works, and Mixtral is still being tested.
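A hedged sketch of the GPTQ portion of such a config, with illustrative values (see the compressed-tensors quant_args definitions for the full option set):

```python
# The GPTQ section of a recipe YAML, expressed as a Python string.
gptq_config = """
GPTQModifier:
    ignore: ["lm_head"]              # (3) keep lm_head out of GPTQ
    config_groups:
        group_0:
            targets: ["Linear"]
            weights:
                num_bits: 8
                type: "int"
                symmetric: true
                strategy: "channel"  # (1) "channel" (or "tensor"), no group_size, for w8a8
            input_activations:
                num_bits: 8          # (2) quantize input activations to 8 bits, not None
                type: "int"
                symmetric: true
                strategy: "tensor"
            output_activations: null # (5) leave output activations unquantized
"""
```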
The model can be quantized for the MoE, but there is still an issue with vLLM 0.5.2 loading it. Since I saw @robertgshaw2-neuralmagic has made some recent changes there, maybe you know the root cause? I tested Llama and the weight loader is:
However, for Mixtral with the same quantization setup, I got:
I think it's related to the ...
Hey @qingquansong, thanks for the update! To address your questions:
For the vLLM-side issues I'll pass it over to @robertgshaw2-neuralmagic :)
Thanks for the response! Let me take a look. As for the vLLM issue, I think I figured out why; I put a quick fix here: vllm-project/vllm#6793. @Satrat @robertgshaw2-neuralmagic please help take a look and see if it makes sense.
@Satrat I think there's probably some misunderstanding here. I didn't set group for the input, I think (please correct me if I set it up wrong)? Only the GPTQ one has the group_size set currently. If I don't set the input activation quantization, then input_quant will be None when loading with vLLM 0.5.2, even if I use SmoothQuant, causing issues with vLLM loading. Also, the group option seems to block vLLM loading in this case too, since there's no grouped option available for w8a8 from here. The reason I want to quantize both is to use the W8A8 tensor cores for faster inference. I changed to channel and removed the group_size, which then works fine. For the weight-loading issue, after the fix in the vLLM PR vllm-project/vllm#6793 it seems fine to load and run, but I'm not sure about the accuracy, and Mixtral 8*22b becomes much slower in this case and requires more memory. I'm not sure if it's related to the channel-wise quantization, but ideally it should be much faster, since I only have the prefill stage with 1 token generated, right? (Quick update: for the speed issue, it should be because mixtral_quant in vLLM does not have a fused MoE layer for int8 quantization; I'll create one.)
So any of the quantization options for weights are also available for input_activations and output_activations. You can see all of the options here: https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_args.py. As for the vLLM side, I didn't realize we didn't have grouped weight support for w8a8 yet. @robertgshaw2-neuralmagic is working on documenting what is and isn't supported, which should hopefully clear up the confusion in the future.
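As a small sketch of that symmetry (the import path and field names follow quant_args.py, but may vary by compressed-tensors version):

```python
from compressed_tensors.quantization import QuantizationArgs

# channel-wise symmetric int8 for weights
weight_args = QuantizationArgs(num_bits=8, type="int", strategy="channel", symmetric=True)

# per-tensor int8 for input activations -- the same option set applies here
input_args = QuantizationArgs(num_bits=8, type="int", strategy="tensor", symmetric=True)
```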
Update from my side: I corrected my setup to remove ... A related question: since I do SmoothQuant for ... (see llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py, lines 16 to 19 in 780256c)
@Satrat is this good to close?
Hey @robertgshaw2-neuralmagic, I think we can close it now. The only thing left is maybe taking a look at this PR: vllm-project/vllm#6978, related to the fused w8a8 MoE, for the speedup and performance issue. Other than that it's all good. One more thing related to fp8 quantization: since vLLM directly supports dynamic fp8 quantization (without using llm-compressor), is it suggested to use that, or is llm-compressor meant to provide more fp8 quant strategies? I feel we can probably remove the mixtral_quant.py script later, after the above PR is checked in, so that all Mixtral variants can use the same script. Thank you!
I'm dealing with a few critical issues in v0.5.4; I will review your PR after this. Even with in-place quantization, having a checkpoint is better, since it's half the disk space. So we still want to have quantized checkpoints.
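For reference, a hedged sketch of producing such a quantized checkpoint ahead of time with llm-compressor's dynamic FP8 scheme (the class names, the FP8_DYNAMIC preset, and the ignore patterns are assumptions based on the examples available around these versions; check the current docs):

```python
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # illustrative
model = SparseAutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")

# FP8 weights + dynamic per-token FP8 activations; no calibration data needed.
# The MoE router (block_sparse_moe.gate) and lm_head are left unquantized here.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head", "re:.*block_sparse_moe.gate"],
)

oneshot(model=model, recipe=recipe)
model.save_pretrained("Mixtral-8x7B-Instruct-v0.1-FP8-Dynamic", save_compressed=True)
```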
Describe the bug
Hey Team, I'm trying to quantize Mixtral 8*22b with a W8A8 recipe and it failed with two issues on different versions:
This issue happens when using the latest main branch; I think there's some regex issue. When using the main branch 1~2 weeks ago I didn't see the issue. Has anything changed?
I think I can fix it by manually changing the default SmoothQuant mapping to a Mixtral one, but I'm wondering if there's a better solution here and why it didn't happen previously.
I used device_map="auto", which seems to work well for llama3 70b but not Mixtral 8*22b (which is larger, though), so probably CPU offloading or a better block clean-up scheme is needed.

Expected behavior
Expect it to finish on a 1-node, 8xA100 setup.
Environment
f7245c8: latest branch (end of 2024-07-23)

To Reproduce
Errors
Another OOM issue; I cannot find the log, but it happens after 4 layers.
Additional context