Distilling Knowledge via Intermediate Classifier Heads (DIH) is a knowledge distillation framework that mitigates the negative impact that the capacity gap, i.e., the difference in model complexity between the teacher and the student, has on knowledge distillation. The approach improves canonical knowledge distillation (KD) with the help of the teacher's intermediate representations (the outputs of some of its hidden layers).
- First, k classifier heads are mounted at various intermediate layers of the teacher (the table of teacher models below lists the value of k used for each model in this repository).
- The added intermediate classifier heads then undergo a cheap and efficient fine-tuning step while the main teacher stays frozen. This is far cheaper than training the corresponding sub-network (a fraction of the teacher plus the added head) from scratch, because the backbone is frozen and only the added head needs to be trained (a minimal sketch of this step follows this list).
- The cohort of classifiers (all the mounted heads plus the final main classifier) then co-teaches the student simultaneously through knowledge distillation.
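The snippet below is a minimal PyTorch sketch of the first two steps: a single fully connected intermediate head is mounted on top of a frozen teacher stage and fine-tuned with cross-entropy. All names, shapes, and channel counts (`IntermediateHead`, `features`, etc.) are illustrative assumptions, not the repository's exact modules; see `models_repo` for the real implementations.

```python
import torch
import torch.nn as nn

class IntermediateHead(nn.Module):
    """A lightweight intermediate classifier head: global average pooling
    followed by a single fully connected layer (DIH keeps its heads simple,
    in contrast to the deeper heads used by TOFD and MHKD)."""
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        return self.fc(self.pool(feature_map).flatten(1))

# Fine-tune one head on features produced by a frozen teacher stage.
# In the real pipeline the features come from the teacher's hidden layers;
# here a random tensor stands in for them.
head = IntermediateHead(in_channels=32, num_classes=100)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)

features = torch.randn(8, 32, 16, 16)     # stand-in for a hidden-layer output
labels = torch.randint(0, 100, (8,))
loss = nn.functional.cross_entropy(head(features), labels)
loss.backward()                           # gradients reach only the head
optimizer.step()
```

During the distillation step, the softened outputs of every head in this cohort, together with the teacher's final classifier, supervise the student.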
- torch 1.7.1: the project is built with PyTorch.
- torchvision 0.8.2: used for the datasets and data preprocessing.
- tqdm 4.48.2: for better visualization of the training process.
- torchsummary: for investigating the models' architectures.
- numpy 1.19.4: used in preprocessing the dataset and showing examples.
- argparse: passing the input arguments for easy reproducibility.
- os: reading and writing the trained models' weights.
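Install the dependencies with: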
pip3 install -r requirements.txt
Tiny-ImageNet contains 64x64 pixel RGB images for 200 classes, subsampled from the ImageNet dataset. The dataset is composed of 100,000 training and 10,000 testing images. Both the training and testing sets are balanced (i.e., every class has the same number of images). We follow data augmentation similar to that used for CIFAR-10 and CIFAR-100: images are augmented with random horizontal flips, 4-pixel padding, and random 64x64 crops, and are normalized by their mean and standard deviation.
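As a rough illustration of that pipeline, here is what the training transforms could look like with torchvision. The per-channel statistics below are placeholder values, not necessarily the ones used in this repository.

```python
import torchvision.transforms as T

# Assumed per-channel statistics for Tiny-ImageNet; replace with the values
# computed for your copy of the dataset.
MEAN = (0.4802, 0.4481, 0.3975)
STD = (0.2770, 0.2691, 0.2821)

train_transform = T.Compose([
    T.RandomHorizontalFlip(),          # random horizontal flips
    T.RandomCrop(64, padding=4),       # 4-pixel padding + random 64x64 crop
    T.ToTensor(),
    T.Normalize(MEAN, STD),            # normalize by mean and std
])

test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(MEAN, STD),
])
```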
- Canonical Knowledge Distillation (KD): As one of the baselines, we use conventional KD (referred to simply as KD throughout the discussion and experiments), with the same temperature (τ=5) and the same α weight (α=0.1) as DIH (a sketch of the KD objective follows this list).
- FitNets: FitNets first transfers the knowledge of a fraction of the trained teacher, up to a selected layer (the hint layer), to a fraction of the student, up to a selected intermediate layer (the guided layer), by optimizing that fraction of the student with an L2 loss. The second step is canonical KD from the complete teacher to the entire student. We trained the selected fraction of the student for 40 epochs with the L2 loss in the first step; in the second step, we used the same KD setting and trained the complete student for 200 epochs (a sketch of the first stage follows this list).
- Knowledge Distillation with Teacher Assistants (TAKD): We limited the number of teacher assistants to one for the experiments in Table 3 of the paper. The teacher assistant and the final student are trained with identical settings (the same settings as KD).
- Attention Distillation (AT): AT transfers the teacher's attention maps (i.e., channel-wise averaged activation maps) to the student's equivalent layers (a sketch follows this list).
- Contrastive Representation Distillation (CRD): CRD improves canonical KD using contrastive learning. Its loss objective maximizes a lower bound on the teacher-student mutual information. With this framework, the student learns to generate representations that are close to the teacher's for positive sample pairs and far apart for negative pairs.
- Task-Oriented Feature Distillation (TOFD): Like our approach, TOFD tries to improve canonical KD with the help of intermediate classifier heads. However, TOFD equips both the teacher and the student with deep, complex classifier modules containing multiple convolutional, batch-normalization, and fully connected layers. Each classifier module mirrors the rest of the teacher backbone from the attachment point to the end of the model; for example, in a residual model with four residual stages, the classifier module attached to the first stage comprises the remaining three residual stages followed by the final fully connected layer. Besides the different classifier architectures, TOFD also uses a different set of loss objectives: each student classifier is optimized with regular cross-entropy, canonical KD on the soft probabilities generated by the teacher's same-stage classifier, an L2 loss matching same-stage intermediate representations, and an orthogonal loss that reduces information loss (applied only to the feature-resizing layers).
- Multi-head Knowledge Distillation for Model Compression (MHKD): MHKD is similar to our approach, but, like TOFD, it equips both the teacher and the student with multiple classifier heads built from convolutional, batch-normalization, and ReLU layers followed by fully connected layers. Unlike TOFD, MHKD uses a fixed architecture for these heads: two convolutional layers with batch normalization and ReLU followed by two fully connected layers (a sketch of such a head follows this list). In contrast, we use simpler classifier modules consisting only of fully connected layers. MHKD also differs from TOFD in its loss function: it optimizes the student's classifier heads with regular cross-entropy and canonical KD on the same-stage teacher classifier's soft labels.
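The sketches below illustrate a few of these baselines in PyTorch. They are simplified readings of the descriptions above, not the implementations shipped in this repository (see `KD_Loss.py`, `CRD/`, `TOFD/`, and `MHKD/` for those); every function name, shape, and default value is an illustrative assumption.

A canonical KD objective, as used by the KD baseline (TAKD applies the same objective twice: teacher to assistant, then assistant to student):

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, alpha=0.1, temperature=5.0):
    """Canonical KD: a temperature-softened KL term plus a hard cross-entropy
    term. Which of the two terms `alpha` weights is an assumption here; check
    the `kd_alpha` convention in KD_Loss.py."""
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature ** 2
    hard = F.cross_entropy(student_logits, targets)
    return alpha * hard + (1.0 - alpha) * soft
```

FitNets stage 1, regressing the student's guided-layer features onto the teacher's hint-layer features with an L2 loss (the 1x1 convolution adapts the student's channel count to the teacher's; all shapes are placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hint = torch.randn(8, 64, 16, 16)     # teacher hint-layer output (placeholder)
guided = torch.randn(8, 16, 16, 16)   # student guided-layer output (placeholder)

regressor = nn.Conv2d(16, 64, kernel_size=1)  # maps student channels to teacher's
stage1_loss = F.mse_loss(regressor(guided), hint)
```

A simplified reading of AT, matching normalized, channel-wise aggregated activation maps between equivalent layers:

```python
import torch.nn.functional as F

def attention_map(feature_map):
    """Channel-wise aggregated (squared) activation map, flattened and
    L2-normalized; a simplification of the original AT formulation."""
    amap = feature_map.pow(2).mean(dim=1)        # (N, H, W)
    return F.normalize(amap.flatten(1), dim=1)   # (N, H*W)

def at_loss(teacher_feat, student_feat):
    # Match attention maps of equivalent teacher and student layers.
    return (attention_map(teacher_feat) - attention_map(student_feat)).pow(2).mean()
```

An MHKD-style classifier head as described above (channel and hidden sizes are placeholders):

```python
import torch.nn as nn

def mhkd_style_head(in_channels: int, num_classes: int, hidden: int = 128):
    """Two convolution + batch-norm + ReLU blocks followed by two fully
    connected layers, per the fixed MHKD head architecture described above."""
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(in_channels, hidden),
        nn.Linear(hidden, num_classes),
    )
```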
| Teacher Model | # Intermediate heads (k) |
|---|---|
| ResNet-34 | 4 |
| ResNet-18 | 4 |
| VGG-11 | 4 |
| WR-28-2 | 3 |
| ResNet-110 | 3 |
| ResNet-20 | 3 |
| ResNet-14 | 3 |
| ResNet-8 | 3 |
- dataload.py: builds the data loaders for training, validation, and testing for both datasets (CIFAR-10 and CIFAR-100).
- models_repo: contains the model classes (two families of ResNets, VGG, and the intermediate classifier module).
- KD_Loss.py: the canonical knowledge distillation loss function.
- dih_utlis.py: includes the function for loading the trained intermediate heads.
- train_dih.py: contains the function for distillation via intermediate heads (DIH).
- train_funcs.py: regular cross-entropy training and the intermediate heads' fine-tuning functions.
- CRD/train_student.py: training function for contrastive representation distillation (CRD).
- TOFD/tofd_train.py: training function for task-oriented feature distillation (TOFD).
- MHKD/mhkd_training.py: training function for multi-head knowledge distillation (MHKD).
- test.py: testing console for running the functions above.
| Hyper-parameter | args tag | Default value |
|---|---|---|
| student model | student | res8 |
| teacher model | teacher | res110 |
| learning rate | lr | 0.1 |
| weight decay | wd | 5e-4 |
| epochs | epochs | 200 |
| dataset | dataset | cifar100 |
| schedule | schedule | [60,120,180] |
| γ | schedule_gamma | 0.1 |
| temperature τ (KD) | kd_temperature | 5 |
| α (KD) | kd_alpha | 0.1 |
| batch size | batch_size | 64 |
| training type | training_type | dih |
| seed | seed | [30,50,67] |
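Training a teacher (e.g., ResNet-110) from scratch with regular cross-entropy: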
python3 final_test.py --training_type ce --teacher res110 --path_to_save /home/teacher.pth --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
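Fine-tuning the mounted intermediate classifier heads on a saved (frozen) teacher: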
python3 final_test.py --training_type fine_tune --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/headers --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
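Training a student with FitNets (stage-1 hint training followed by KD):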
python3 final_test.py --student res8 --training_type fitnets --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/stage_1.pth --epochs_fitnets_1 40 --nesterov_fitnets_1 True --momentum_fitnets_1 0.9 --lr_fitnets_1 0.1 --wd_fitnets_1 0.0005 --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
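Training a student with canonical KD: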
python3 final_test.py --student res8 --training_type kd --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --kd_alpha 0.1 --seed 3 --kd_temperature 5
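Training a student with CRD: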
python3 train_student.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5
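Training a student with TOFD: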
python3 tofd_train.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5
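Training a student with MHKD: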
python3 train_mhkd.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5
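Training a student with DIH (the default training type), using the saved teacher and its fine-tuned intermediate heads: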
python3 final_test.py --student res8 --teacher res110 --saved_path /home/teacher.pth --saved_intermediates_directory /home/saved_headers/ --alpha 0.1 --temperature 5 --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3 --path_to_save /home/dih_model.pth
arXiv link: http://arxiv.org/abs/2103.00497
If you found this library useful in your research, please consider citing:
@misc{asadian2021distilling,
title={Distilling Knowledge via Intermediate Classifier Heads},
author={Aryan Asadian and Amirali Salehi-Abari},
year={2021},
eprint={2103.00497},
archivePrefix={arXiv},
primaryClass={cs.LG}
}