Why does the speed not increase after compressing the model? #852

Open
liho00 opened this issue Oct 18, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@liho00

liho00 commented Oct 18, 2024

https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct

https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_fp8

https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_int8

https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w4a16

I tried these examples to generate a new compressed checkpoint and loaded it with vLLM 0.6.3:

python -m vllm.entrypoints.openai.api_server --served-model-name /home/llm-compressor/examples/quantization_w8a8_fp8/Llama-3.1-8B-Instruct-FP8 --model meta-llama/Llama-3.1-8B-Instruct --port 8000 --host 0.0.0.0 --tensor-parallel-size 8 --gpu-memory-utilization 0.98

base model: 215 tok/s
compressed model: 205 tok/s
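
For reference, the quantization script I ran follows the linked w8a8_fp8 example, roughly like this (a sketch from memory; exact imports and arguments may differ slightly between llm-compressor versions):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# FP8 dynamic scheme: per-channel weight + dynamic per-token activation quantization
# of all Linear layers, skipping the LM head (as in the example recipe).
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# No calibration data is needed for the dynamic FP8 scheme.
oneshot(model=model, recipe=recipe)

SAVE_DIR = "Llama-3.1-8B-Instruct-FP8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```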

@liho00 liho00 added the bug Something isn't working label Oct 18, 2024
@robertgshaw2-neuralmagic
Collaborator

It looks like you are running the FP16 model in your launch command

That being said, you are running a 3b model with tp=8. I do not think you will see much performance benefit from fp8 in this regime since the linear layers are very small in this setup
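
In other words, --model should point at the quantized checkpoint directory (the path you are currently passing to --served-model-name), roughly:

python -m vllm.entrypoints.openai.api_server --model /home/llm-compressor/examples/quantization_w8a8_fp8/Llama-3.1-8B-Instruct-FP8 --served-model-name Llama-3.1-8B-Instruct-FP8 --port 8000 --host 0.0.0.0 --tensor-parallel-size 8 --gpu-memory-utilization 0.98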

@liho00
Author

liho00 commented Oct 18, 2024

It looks like you are running the FP16 model in your launch command

That being said, you are running a 3b model with tp=8. I do not think you will see much performance benefit from fp8 in this regime since the linear layers are very small in this setup

Sorry for the typo, it should be the 8B model, Llama-3.1-8B-Instruct-FP8.

python -m vllm.entrypoints.openai.api_server --served-model-name Llama-3.1-8B-Instruct-FP8 --model /root/Meta-Llama-3.1-8B-Instruct-quantized.w4a16 --port 8000 --host 0.0.0.0 --tensor-parallel-size 8 --gpu-memory-utilization 0.95 --dtype bfloat16 --quantization compressed-tensors

Any idea how to speed up the compressed model with vLLM?
Ideally, we want low latency for the first token.

Or does it only speed things up for large models, like 70B?
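
For context, this is roughly how I measure first-token latency against the running server (a quick sketch; the served model name, prompt, and token counts are placeholders):

```python
import time
from openai import OpenAI

# Points at the vLLM OpenAI-compatible server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

start = time.perf_counter()
stream = client.chat.completions.create(
    model="Llama-3.1-8B-Instruct-FP8",  # must match --served-model-name
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    max_tokens=64,
    stream=True,
)

ttft = None
for chunk in stream:
    if ttft is None and chunk.choices and chunk.choices[0].delta.content:
        ttft = time.perf_counter() - start  # time until the first generated token arrives
print(f"time to first token: {ttft:.3f} s")
```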

@robertgshaw2-neuralmagic
Collaborator

One last question - is this running on an H100?

@liho00
Author

liho00 commented Oct 18, 2024

One last question - is this running on an H100?

Yep, 8x H100 SXM5.

Can I add you on Discord to share more details?

@robertgshaw2-neuralmagic
Collaborator

One last question - is this running on an H100?

Yep, 8x H100 SXM5.

Can I add you on Discord to share more details?

With 8x H100, your system is very overpowered for running an 8B-parameter model, so the end-to-end speedup from quantization is small (and we have not really tuned the FP8 kernels for matrices that are so skinny).

I would expect to see speedups on a single H100 at 8B-parameter scale, though.

@piamo

piamo commented Oct 23, 2024

Same problem. I ran an FP8-quantized MiniCPM3 (4B) on an L40 and saw less than a 10% speedup.

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented Oct 23, 2024

Same problem. I ran an FP8-quantized MiniCPM3 (4B) on an L40 and saw less than a 10% speedup.

Could you share any more details on your workload? For an L40 at 8B model scale, I have measured a ~30-50% speedup for offline batch workloads.

@piamo

piamo commented Oct 24, 2024

Same problem. I ran an FP8-quantized MiniCPM3 (4B) on an L40 and saw less than a 10% speedup.

Could you share any more details on your workload? For an L40 at 8B model scale, I have measured a ~30-50% speedup for offline batch workloads.

Here's my test case:

origin model: https://huggingface.co/openbmb/MiniCPM3-4B
quantization method: FP8 W8A8, following the official example
vllm version: 0.6.3
max_model_len: 2048
input token length: 625
output token length: 55

I tested the same request 10 times at each batch size (bs); below is the average time cost:

| bs | origin (s/req) | quantized (s/req) |
|----|----------------|-------------------|
| 1  | 1.16           | 1.06              |
| 2  | 1.36           | 1.16              |
| 4  | 1.55           | 1.42              |
| 8  | 2.00           | 1.78              |

Besides, I found that if I set max_model_len larger (2048 -> 8192), the time cost can be slightly lower, e.g. 1.78 -> 1.72 s/req at bs=8 (quantized). Interesting.
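
For completeness, the timing loop is roughly the following (a sketch, not my exact script; the checkpoint path and prompt are placeholders, and each configuration is repeated 10 times and averaged):

```python
import time
from vllm import LLM, SamplingParams

# Run once with the original checkpoint and once with the FP8-quantized one.
llm = LLM(model="MiniCPM3-4B-FP8", max_model_len=2048, trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=55)  # matches the 55 output tokens above

prompt = "..."  # the ~625-token prompt used in the test

for bs in (1, 2, 4, 8):
    latencies = []
    for _ in range(10):
        start = time.perf_counter()
        llm.generate([prompt] * bs, params)
        # All bs requests run in one batch, so the wall time is the latency each request sees.
        latencies.append(time.perf_counter() - start)
    print(f"bs={bs}: {sum(latencies) / len(latencies):.2f} s/req")
```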
