remove prepare_model_for_8bit_training
SeanLee97 committed Mar 25, 2024
1 parent 75da24a commit fa4fc08
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions angle_emb/angle.py
@@ -29,8 +29,7 @@
 from transformers.utils import PaddingStrategy
 from peft import (
     get_peft_model, LoraConfig, TaskType, PeftModel,
-    prepare_model_for_kbit_training,
-    prepare_model_for_int8_training
+    prepare_model_for_kbit_training
 )
 from peft.tuners.lora import LoraLayer

@@ -1108,13 +1107,13 @@ def __init__(self,
             lora_config['bias'] = "none"
             lora_config['task_type'] = TaskType.CAUSAL_LM
 
-            if load_kbit == 4:
+            if load_kbit in [4, 8]:
                 model = MODEL_CLASS.from_pretrained(
                     model_name_or_path,
-                    load_in_4bit=True,
                     config=None,
                     quantization_config=BitsAndBytesConfig(
-                        load_in_4bit=True,
+                        load_in_4bit=load_kbit == 4,
+                        load_in_8bit=load_kbit == 8,
                         llm_int8_threshold=6.0,
                         llm_int8_has_fp16_weight=False,
                         bnb_4bit_compute_dtype=torch.float32,
@@ -1149,17 +1148,10 @@ def __init__(self,
             if train_mode:
                 model = MODEL_CLASS.from_pretrained(
                     model_name_or_path,
-                    load_in_8bit=load_kbit == 8,
                     torch_dtype=torch.float16 if load_kbit == 16 else torch.float32,
                     device_map=device_map,
                     trust_remote_code=True,
                 )
-                if load_kbit == 8:
-                    model = prepare_model_for_int8_training(model, **kbit_kwargs)
-                    if 'target_modules' not in lora_config or lora_config.get('target_modules', None) is None:
-                        target_modules = find_all_linear_names(model)
-                        lora_config['target_modules'] = target_modules
-                        logger.info(f'lora target modules={target_modules}')
             if pretrained_lora_path is not None:
                 print(f'Load lora weight from {pretrained_lora_path}')
                 model = PeftModel.from_pretrained(
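
Note: peft deprecated prepare_model_for_int8_training in favour of prepare_model_for_kbit_training, which covers both 8-bit and 4-bit quantized models, so the dedicated 8-bit branch removed above becomes redundant. The sketch below illustrates the consolidated loading pattern this change converges on; the checkpoint name, load_kbit value, and LoRA hyperparameters are placeholders, not values taken from angle_emb.

# Sketch only: the combined 4-/8-bit loading path this commit converges on.
# Checkpoint name, load_kbit, and LoRA settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

load_kbit = 4  # 4 or 8; both now flow through a single BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=load_kbit == 4,
    load_in_8bit=load_kbit == 8,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
)

model = AutoModelForCausalLM.from_pretrained(
    'your-org/your-llm',  # placeholder checkpoint
    quantization_config=quant_config,
    device_map='auto',
)

# One call prepares either quantization level for training; it replaces
# the removed prepare_model_for_int8_training.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias='none',
    target_modules=['q_proj', 'v_proj'],  # placeholder; choose modules for your model
)
model = get_peft_model(model, lora_config)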
