Ting-Yun Chang, Jesse Thomason, and Robin Jia
Paper: https://arxiv.org/abs/2406.13131
Blog: https://terarachang.github.io/projects/llm-decomp.html
```bash
export HF_TOKEN="YOUR TOKEN"
pip install -r requirements.txt
```
```bash
$ bash scripts/comp_rw.sh
```
- Implementation of model decomposition: `decompose.py`
- Implementation of component reweighting: `train_components.py` (see the sketch below)
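For intuition, here is a minimal PyTorch sketch of the reweighting step. It assumes the decomposition has already cached per-component logits whose sum over the component axis recovers the full model's logits; the tensor names, shapes, and hyperparameters below are illustrative, not the repo's actual interface.

```python
import torch
import torch.nn.functional as F

def train_component_weights(comp_logits, labels, epochs=100, lr=1e-2):
    """Learn one scalar weight per component from a few labeled examples.

    comp_logits: [num_examples, num_components, num_classes]
    """
    num_components = comp_logits.shape[1]
    # Initializing every weight to 1 recovers the original full-model logits.
    weights = torch.ones(num_components, requires_grad=True)
    optimizer = torch.optim.Adam([weights], lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        # Weighted sum over the component axis -> reweighted logits.
        logits = torch.einsum("c,nck->nk", weights, comp_logits)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
    return weights.detach()

# Illustrative usage with random tensors standing in for cached decompositions:
comp_logits = torch.randn(24, 1000, 4)  # 24 examples, 1000 components, 4 choices
labels = torch.randint(0, 4, (24,))
weights = train_component_weights(comp_logits, labels)
```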
```bash
$ bash scripts/standard.sh
```
```bash
$ bash scripts/calibration.sh
```
- Implementation of trainable calibration: `train_calib.py` (see the sketch below)
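As a rough illustration, one common parameterization of trainable calibration is a per-class affine transform of the full model's label logits, fit on a few labeled examples. This is a sketch of the general technique only; `train_calib.py` may parameterize it differently.

```python
import torch
import torch.nn.functional as F

def train_calibration(full_logits, labels, epochs=100, lr=1e-2):
    """full_logits: [num_examples, num_classes] label logits from the full model."""
    num_classes = full_logits.shape[1]
    scale = torch.ones(num_classes, requires_grad=True)   # per-class rescaling
    bias = torch.zeros(num_classes, requires_grad=True)   # per-class offset
    optimizer = torch.optim.Adam([scale, bias], lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        # Affine calibration of the logits, trained with cross-entropy.
        loss = F.cross_entropy(full_logits * scale + bias, labels)
        loss.backward()
        optimizer.step()
    return scale.detach(), bias.detach()
```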
- Our repo supports LLMs in the Llama and Mistral families
- To support new models, please add hooks to the model, following the naming convention of `my_modeling_llama.py`
- If the new model also uses RMSNorm, `decompose.py` is directly applicable. Otherwise, please take care of the layernorms, which can greatly influence model performance (see the sketch after this list)!
- *We do not fully adopt TransformerLens, to avoid numerical issues with Llama-3 and to reduce computation overhead
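On the layernorm point: RMSNorm and LayerNorm are nonlinear in the full residual stream, but once their normalization statistics are computed from the *full* hidden state and frozen, applying them to each component is a linear map, so the per-component logits still sum exactly to the full-model logits. Below is a minimal sketch of this idea; it is our illustration of the technique, not code from `decompose.py`.

```python
import torch

def apply_rmsnorm_per_component(components, gamma, eps=1e-5):
    """components: [num_components, d]; their sum is the full hidden state."""
    full = components.sum(dim=0)
    # One scale factor, computed from the full stream and shared by all parts,
    # so the normalized components still sum to RMSNorm(full).
    rms = torch.sqrt(full.pow(2).mean() + eps)
    return components / rms * gamma

def apply_layernorm_per_component(components, gamma, beta, eps=1e-5):
    """LayerNorm needs extra care: each component sheds its own mean, and the
    additive bias beta belongs to the full stream, not to any single part."""
    full = components.sum(dim=0)
    std = torch.sqrt(full.var(unbiased=False) + eps)
    centered = components - components.mean(dim=-1, keepdim=True)
    out = centered / std * gamma
    out[0] = out[0] + beta  # assign the shared bias to one slot (or track it separately)
    return out
```

With RMSNorm, only the shared scale is needed, which is why `decompose.py` carries over directly; LayerNorm additionally requires the centering and bias handling above, which is the part that needs care when porting to new models.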