AttributeError: 'MllamaConfig' object has no attribute 'use_cache' #688

Open
mgoin opened this issue Sep 26, 2024 · 0 comments · May be fixed by #834
mgoin commented Sep 26, 2024

Currently we require model configs to have a use_cache attribute when running apply_compression with calibration data, but MllamaConfig does not define one, so calibration fails:

Traceback (most recent call last):
  File "/home/mgoin/code/llm-compressor/examples/quantization_w8a8_int8/llama3.2_vision_example.py", line 70, in <module>
    oneshot(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 76, in oneshot
    main(model_args, data_args, training_args)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 364, in main
    stage_runner.one_shot()
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 171, in one_shot
    self.trainer.one_shot(calibration_data=calib_data, stage=stage)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py", line 401, in one_shot
    apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session_functions.py", line 184, in apply
    return active_session().apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 210, in apply
    self.initialize(**kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 156, in initialize
    mod_data = self._lifecycle.initialize(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/lifecycle.py", line 126, in initialize
    data = mod.initialize(state=self.state, **extras)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/stage.py", line 124, in initialize
    modifier.initialize(state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/modifier.py", line 118, in initialize
    initialized = self.on_initialize(state=state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py", line 202, in on_initialize
    self.apply_compression(calibration_dataloader)
  File "/home/mgoin/venvs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py", line 287, in apply_compression
    forward_pass_use_cache = self.model.config.use_cache
  File "/home/mgoin/code/transformers/src/transformers/configuration_utils.py", line 202, in __getattribute__
    return super().__getattribute__(key)
AttributeError: 'MllamaConfig' object has no attribute 'use_cache'
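
A minimal guard along these lines would avoid the hard attribute access in apply_compression. This is just a sketch, not the fix in the linked PR, and the text_config fallback is an assumption about how MllamaConfig nests its language-model settings:

# Sketch of a defensive lookup in apply_compression (assumption, not the actual fix):
config = self.model.config
if hasattr(config, "use_cache"):
    forward_pass_use_cache = config.use_cache
else:
    # Multimodal configs such as MllamaConfig may nest the text config;
    # fall back to it, or to True, when use_cache is missing (assumption).
    text_config = getattr(config, "text_config", config)
    forward_pass_use_cache = getattr(text_config, "use_cache", True)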

Code to trigger:

from datasets import load_dataset
from transformers import AutoTokenizer, MllamaForConditionalGeneration

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

# Select model and load it.
MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model_class = wrap_hf_model_class(MllamaForConditionalGeneration)
model = model_class.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start for real runs
# and increasing the number of samples can improve accuracy; 4 is enough here
# to reproduce the error quickly.
NUM_CALIBRATION_SAMPLES = 4
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": processor.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return processor(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)
print(ds)

# Configure the quantization algorithm. In this case, we:
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)
# Note: set sequential_update: true in the recipe to reduce memory
ignore = ["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"]
recipe = [
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(processor.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
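
Until this is handled in the library, one possible user-side workaround (a sketch, untested; using text_config as the fallback source is an assumption) is to give the config a use_cache attribute before calling oneshot:

# Hedged workaround sketch: set use_cache on the top-level config before oneshot
# so apply_compression's attribute lookup succeeds (assumption, not a verified fix).
if not hasattr(model.config, "use_cache"):
    model.config.use_cache = getattr(
        getattr(model.config, "text_config", model.config), "use_cache", True
    )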
mgoin added the bug label on Sep 26, 2024
kylesayrs linked a pull request on Oct 9, 2024 that will close this issue
kylesayrs self-assigned this on Oct 16, 2024