This repository contains code to reproduce the results from our paper From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step.
For multiplication, an online demo is available at gpt2-multiplication. The inference code and model behind this demo can be accessed in the app.py file. Please note that this online demo uses a slightly different data format than the one used in this repository: it uses `=` instead of `<|endoftext|>` as the separator to ensure compatibility with standard Hugging Face transformer generation code, and the `####` symbols in the answer part are removed to provide a cleaner output. Despite these changes, the proposed approach still works effectively, demonstrating the generality of the method. Additionally, the online demo merges data with varying input digit lengths for training, allowing a single model to handle inputs of different digit counts. In contrast, the code in this repository trains one model for each number of digits for controlled experiments.
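To make the formatting difference concrete, here is a tiny sketch. The helper and the exact strings are hypothetical (the demo's actual formatting lives in its app.py); only the `=` separator and the dropped `####` follow the description above.

```python
# Hypothetical illustration of the demo's format vs. this repository's format.
def to_demo_format(line: str) -> str:
    """Turn a repo-format line 'input||cot #### answer' into a
    demo-style 'input = answer' (CoT internalized, no '####')."""
    input_part, rest = line.split("||", 1)
    answer = rest.split("####", 1)[1].strip()
    return f"{input_part.strip()} = {answer}"

print(to_demo_format("2 1 * 3 1||... #### 6 5 1"))  # -> '2 1 * 3 1 = 6 5 1'
```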
For GSM8K, an online demo is available at implicit-cot-math. Note that this online demo internalizes one `</s>` token that separates the CoT from the answer (while keeping the `</s>` token that separates the input from the CoT intact), ensuring compatibility with standard Hugging Face transformer generation code. The final accuracy remains at 52%, nearly the same as the number reported in the paper (51%).
- PyTorch
- transformers (`pip install transformers`)
All dataset files and inference log files are included in this repository, with the exception of large training files, which are maintained under Git LFS. Model checkpoints are stored on Google Drive. The folder containing all model checkpoints and datasets can be found at this link.
- 4 X 4 Mult - GPT-2 (Acc: 1.00): data model log
- 5 X 5 Mult - GPT-2 (Acc: 1.00): data model log
- 7 X 7 Mult - GPT-2 (Acc: 0.95): data model log
- 9 X 9 Mult - GPT-2 (Acc: 0.99): data model log
- 11 X 11 Mult - GPT-2 (Acc: 0.74): data model (partially internalized) log
- GSM8K - GPT-2 (Acc: 0.30): data model log
- GSM8K - GPT-2 Medium (Acc: 0.35): data model log
- GSM8K - Phi-3 3.8B (Acc: 0.31): data model log
- GSM8K - Mistral 7B (Acc: 0.51): data model log
We have included more multiplication datasets than those used in the paper to encourage future research that might yield even better results. The folder containing all datasets can be found at this link.
- 6 X 6 Mult: data
- 8 X 8 Mult: data
- 10 X 10 Mult: data
- 12 X 12 Mult: data
- 13 X 13 Mult: data
- 14 X 14 Mult: data
- 15 X 15 Mult: data
- 16 X 16 Mult: data
- 17 X 17 Mult: data
- 18 X 18 Mult: data
- 19 X 19 Mult: data
- 20 X 20 Mult: data
We use 9 X 9 Mult with GPT-2 as an example. We assume that the working directory is `Internalize_CoT_Step_by_Step` throughout this document.
The format of training, validation, and test files is as follows:
[input 1]||[chain-of-thought 1] #### [output 1]
[input 2]||[chain-of-thought 2] #### [output 2]
[input 3]||[chain-of-thought 3] #### [output 3]
...
For example, the first line from the 4 X 4 Mult test set in data/4_by_4_mult/test_bigbench.txt is:
9 1 7 3 * 9 4 3 3||1 7 4 3 3 + 0 6 7 8 4 1 ( 1 3 2 2 8 1 ) + 0 0 7 5 1 1 1 ( 1 3 9 7 9 2 1 ) + 0 0 0 7 5 1 1 1 #### 1 3 9 4 5 4 2 1
In this example, the input is `9 1 7 3 * 9 4 3 3` (corresponding to `3719*3349`; note that we reversed the digits), the chain-of-thought is `1 7 4 3 3 + 0 6 7 8 4 1 ( 1 3 2 2 8 1 ) + 0 0 7 5 1 1 1 ( 1 3 9 7 9 2 1 ) + 0 0 0 7 5 1 1 1`, and the output is `1 3 9 4 5 4 2 1` (corresponding to `12454931`).
Note that the chain-of-thought steps are only used for training, not for generation.
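To make the format concrete, here is a small parsing sketch. It is illustrative only and not part of the repository's own data pipeline; it also verifies the reversed-digit convention on the example above.

```python
# Illustrative parsing of one dataset line (not the repository's data loader).
def parse_line(line: str):
    input_part, rest = line.strip().split("||", 1)
    cot, output = rest.split("####", 1)
    return input_part.strip(), cot.strip(), output.strip()

line = ("9 1 7 3 * 9 4 3 3||1 7 4 3 3 + 0 6 7 8 4 1 ( 1 3 2 2 8 1 ) "
        "+ 0 0 7 5 1 1 1 ( 1 3 9 7 9 2 1 ) + 0 0 0 7 5 1 1 1 #### 1 3 9 4 5 4 2 1")
inp, cot, out = parse_line(line)

def undo_reverse(digits: str) -> int:
    # Digits are written least-significant first, so reverse before joining.
    return int("".join(reversed(digits.split())))

a, b = (undo_reverse(x) for x in inp.split("*"))
assert a * b == undo_reverse(out)  # 3719 * 3349 == 12454931
```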
To train the model, run the following commands. The example uses 9 X 9 Mult with GPT-2:
export D=9
export FOLDER=data/${D}_by_${D}_mult/
export MODEL=gpt2
export EPOCHS=200
export LR=5e-5
export BSZ=32
export ACCUMULATE=1
export REMOVE_PER_EPOCH=8
export REMOVE_ALL_WHEN_REMOVE_BEYOND=inf
export REMOVAL_SMOOTHING_LAMBDA=4
export REMOVAL_SIDE=left
export PRETRAIN_EPOCHS=0
export SEED=3456
export SAVE=train_models/${D}_by_${D}_mult/gpt2
mkdir -p $SAVE
TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train.py \
--model ${MODEL} \
--train_path ${FOLDER}/train.txt \
--val_path ${FOLDER}/valid.txt \
--epochs ${EPOCHS} \
--lr ${LR} \
--batch_size ${BSZ} \
--accumulate ${ACCUMULATE} \
--remove_per_epoch ${REMOVE_PER_EPOCH} \
--remove_all_when_remove_beyond ${REMOVE_ALL_WHEN_REMOVE_BEYOND} \
--removal_smoothing_lambda ${REMOVAL_SMOOTHING_LAMBDA} \
--removal_side ${REMOVAL_SIDE} \
--pretrain_epochs ${PRETRAIN_EPOCHS} \
--seed ${SEED} \
--reset_optimizer \
--save_model ${SAVE} \
> ${SAVE}/log.train 2>&1
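The removal-related flags control the stepwise internalization schedule: each epoch, roughly `REMOVE_PER_EPOCH` more CoT tokens are removed from the `REMOVAL_SIDE` side, with a small random offset governed by `REMOVAL_SMOOTHING_LAMBDA` (removal smoothing), and the entire CoT is dropped once the count exceeds `REMOVE_ALL_WHEN_REMOVE_BEYOND`. The sketch below is a simplified paraphrase of that schedule, not the actual code; the authoritative implementation is in `src/train.py`.

```python
import math, random

def sample_offset(lam: float, max_offset: int = 10) -> int:
    # Offset sampled with probability proportional to exp(-lam * offset),
    # truncated at max_offset for simplicity.
    weights = [math.exp(-lam * o) for o in range(max_offset + 1)]
    return random.choices(range(max_offset + 1), weights=weights)[0]

def tokens_to_remove(epoch: int, cot_len: int, remove_per_epoch: float = 8,
                     remove_all_beyond: float = float("inf"), lam: float = 4) -> int:
    n = int(epoch * remove_per_epoch) + sample_offset(lam)
    if n > remove_all_beyond:          # past this point, drop the whole CoT
        n = cot_len
    return min(n, cot_len)

# With REMOVE_PER_EPOCH=8, roughly 8 more CoT tokens disappear each epoch:
print([tokens_to_remove(e, cot_len=100) for e in range(5)])
```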
Here we use a pretrained model as an example. Download the folder `models/9_by_9_mult/gpt2`; then the following command will run inference and evaluate both accuracy and throughput, with results logged to `generation_logs/9_by_9_mult/gpt2/log.generate`.
export D=9
export FOLDER=data/${D}_by_${D}_mult/
export MODEL=models/${D}_by_${D}_mult/gpt2
export BSZ=1
export SAVE=generation_logs/${D}_by_${D}_mult/gpt2
mkdir -p $SAVE
TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/generate.py \
--from_pretrained ${MODEL} \
--test_path ${FOLDER}/test_bigbench.txt \
--batch_size ${BSZ} \
> ${SAVE}/log.generate 2>&1&
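Alternatively, if the downloaded checkpoint loads as a standard Hugging Face GPT-2 checkpoint with tokenizer files alongside it (an assumption; `src/generate.py` is the supported inference path and handles the exact separator and answer parsing), a minimal generation sketch looks like the following. The example input is hypothetical, with digits written least-significant first.

```python
# Minimal generation sketch under the assumptions stated above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "models/9_by_9_mult/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir).eval()

# The implicit-CoT model is prompted with the input followed by the
# separator token and then emits the answer directly, with no CoT tokens.
prompt = "1 2 3 4 5 6 7 8 9 * 9 8 7 6 5 4 3 2 1" + tokenizer.eos_token
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:]))
```

The GSM8K training command below is analogous to the multiplication one, but fine-tunes Mistral 7B in bf16, caps training sequence length with `--max_len_train`, and removes the entire CoT once the scheduled removal count exceeds 39 tokens: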
export FOLDER=data/gsm8k
export MODEL=mistralai/Mistral-7B-v0.1
export EPOCHS=80
export LR=1e-5
export BSZ=16
export ACCUMULATE=2
export REMOVE_PER_EPOCH=8
export REMOVE_ALL_WHEN_REMOVE_BEYOND=39
export MAX_LEN_TRAIN=150
export REMOVAL_SMOOTHING_LAMBDA=4
export REMOVAL_SIDE=left
export PRETRAIN_EPOCHS=0
export SEED=1234
export SAVE=train_models/gsm8k
mkdir -p $SAVE
TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train.py \
--model ${MODEL} \
--train_path ${FOLDER}/train.txt \
--val_path ${FOLDER}/valid.txt \
--epochs ${EPOCHS} \
--lr ${LR} \
--batch_size ${BSZ} \
--accumulate ${ACCUMULATE} \
--remove_per_epoch ${REMOVE_PER_EPOCH} \
--remove_all_when_remove_beyond ${REMOVE_ALL_WHEN_REMOVE_BEYOND} \
--removal_smoothing_lambda ${REMOVAL_SMOOTHING_LAMBDA} \
--removal_side ${REMOVAL_SIDE} \
--pretrain_epochs ${PRETRAIN_EPOCHS} \
--seed ${SEED} \
--reset_optimizer \
--bf16 \
--max_len_train ${MAX_LEN_TRAIN} \
--save_model ${SAVE} \
> ${SAVE}/log.train 2>&1
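Under the simplified schedule sketched earlier (an approximation; the exact behavior is in `src/train.py`), `REMOVE_ALL_WHEN_REMOVE_BEYOND=39` means this run switches from gradual removal to dropping the whole CoT once the scheduled count exceeds 39 tokens, roughly from the fifth epoch on:

```python
# Rough estimate under the simplified linear schedule (ignoring the random
# smoothing offset): epoch * REMOVE_PER_EPOCH first exceeds 39 at epoch 5.
REMOVE_PER_EPOCH, REMOVE_ALL_BEYOND = 8, 39
print(REMOVE_ALL_BEYOND // REMOVE_PER_EPOCH + 1)  # 5
```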