This is the PyTorch code repository for the paper "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation". If you use any content of this repo in your work, please cite the following BibTeX entry:
@inproceedings{DBLP:conf/interspeech/Zhang0YZ22,
  author    = {Yi{-}Kai Zhang and
               Da{-}Wei Zhou and
               Han{-}Jia Ye and
               De{-}Chuan Zhan},
  editor    = {Hanseok Ko and
               John H. L. Hansen},
  title     = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  booktitle = {Interspeech 2022, 23rd Annual Conference of the International Speech
               Communication Association, Incheon, Korea, 18-22 September 2022},
  pages     = {531--535},
  publisher = {{ISCA}},
  year      = {2022},
  url       = {https://doi.org/10.21437/Interspeech.2022-652},
  doi       = {10.21437/Interspeech.2022-652},
  timestamp = {Tue, 11 Oct 2022 19:11:50 +0200},
  biburl    = {https://dblp.org/rec/conf/interspeech/Zhang0YZ22.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
Illustration of Proto-CAT. The model transforms the classification space based on two kinds of audio-visual prototypes (class centers): (1) those of the base training categories (colored blue, green, and pink); and (2) those of the additional novel test categories (colored with a burning transition). Proto-CAT learns and generalizes to novel test categories from limited labeled examples while maintaining performance on the base training categories. It includes audio-visual-level and category-level prototype-based co-adaptation. From left to right, wider coverage and brighter colors indicate a more reliable classification space.
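To make the notion of a prototype (class center) concrete, the sketch below builds prototypes as the mean embedding of each class and assigns a query to the nearest one. It is only a plain nearest-prototype classifier under simple assumptions (embeddings are lists of floats, Euclidean distance); it does not implement Proto-CAT's audio-visual-level or category-level co-adaptation, and the function names are hypothetical.

```python
import math
from collections import defaultdict


def compute_prototypes(embeddings, labels):
    """Return {label: mean embedding}, i.e. one class center per label."""
    sums = {}
    counts = defaultdict(int)
    for emb, lab in zip(embeddings, labels):
        if lab not in sums:
            sums[lab] = [0.0] * len(emb)
        # Accumulate the embedding into its class sum.
        sums[lab] = [s + e for s, e in zip(sums[lab], emb)]
        counts[lab] += 1
    return {lab: [s / counts[lab] for s in vec] for lab, vec in sums.items()}


def nearest_prototype(query, prototypes):
    """Assign the query to the label whose prototype is closest (Euclidean)."""
    return min(prototypes, key=lambda lab: math.dist(query, prototypes[lab]))
```

Because base and novel prototypes live in the same dictionary, a novel class joins the same classification space as the base classes as soon as a few labeled examples give it a center.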
| Method | LRW Audio H-mean | LRW Video H-mean | LRW Audio-Video Base | LRW Audio-Video Novel | LRW Audio-Video H-mean | LRW-1000 Audio-Video Base | LRW-1000 Audio-Video Novel | LRW-1000 Audio-Video H-mean |
|---|---|---|---|---|---|---|---|---|
| LSTM-based | 32.20 | 8.00 | 97.09 | 23.76 | 37.22 | 71.34 | 0.03 | 0.07 |
| GRU-based | 37.01 | 10.58 | 97.44 | 27.35 | 41.71 | 71.34 | 0.05 | 0.09 |
| MS-TCN-based | 62.29 | 19.06 | 80.96 | 51.28 | 61.76 | 71.55 | 0.33 | 0.63 |
| MAML | 35.49 | 10.25 | 40.09 | 66.70 | 49.20 | 29.40 | 23.21 | 25.83 |
| BootstrappedMAML | 33.75 | 6.52 | 35.29 | 64.20 | 45.17 | 28.15 | 27.98 | 28.09 |
| ProtoNet | 39.95 | 14.40 | 96.33 | 39.23 | 54.79 | 69.33 | 0.76 | 1.47 |
| MatchingNet | 36.76 | 12.09 | 94.54 | 36.57 | 52.31 | 68.42 | 0.95 | 1.89 |
| MetaOptNet | 43.81 | 19.59 | 88.20 | 47.06 | 60.73 | 69.01 | 1.79 | 3.44 |
| DeepEMD | -- | 27.02 | 82.53 | 16.43 | 27.02 | 64.54 | 0.80 | 1.56 |
| FEAT | 49.90 | 25.75 | 96.26 | 54.52 | 68.83 | 71.69 | 2.62 | 4.89 |
| DFSL | 72.13 | 42.56 | 66.10 | 84.62 | 73.81 | 31.68 | 68.72 | 42.56 |
| CASTLE | 75.48 | 34.68 | 73.50 | 90.20 | 80.74 | 11.13 | 54.07 | 17.84 |
| Proto-CAT (Ours) | 84.18 | 74.55 | 93.37 | 91.20 | 92.13 | 49.70 | 38.27 | 42.25 |
| Proto-CAT+ (Ours) | -- | -- | 93.18 | 90.16 | 91.49 | 54.55 | 38.16 | 43.88 |
Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on the LRW and LRW-1000 datasets. The best result of each scenario is in bold font. On both base and novel classes (the Base and Novel columns), the performance measure is mean accuracy. The harmonic mean (H-mean) of the two is a better measure of generalized few-shot learning performance, because it stays low unless a model does well on both base and novel classes.
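The H-mean in the table can be computed as below. This is a sketch with hypothetical function names; whether the repo averages per-round H-means or takes the H-mean of the averaged accuracies is an assumption not confirmed by this README, so both are shown.

```python
def harmonic_mean(base_acc, novel_acc):
    """H-mean of base and novel accuracy (in %); defined as 0 when both are 0."""
    if base_acc + novel_acc == 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


def mean_h_over_rounds(rounds):
    """Average of per-round H-means; `rounds` is a list of (base, novel) pairs."""
    return sum(harmonic_mean(b, n) for b, n in rounds) / len(rounds)
```

For Proto-CAT on LRW, `harmonic_mean(93.37, 91.20)` gives about 92.27, slightly above the reported 92.13. That gap is consistent with averaging per-round H-means (the harmonic mean is concave, so the average of per-round values cannot exceed the H-mean of the averaged accuracies), though the README does not state which is used. The measure also shows why H-mean is preferred: for the LSTM-based baseline on LRW-1000, the arithmetic mean of 71.34 and 0.03 is about 35.7, while `harmonic_mean(71.34, 0.03)` is about 0.06.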
Please refer to `requirements.txt` and run:
`pip install -r requirements.txt`
- Use preprocessed data (suggested): LRW and LRW-1000 forbid sharing the preprocessed data directly.
- Use raw data and preprocess it yourself: Download the LRW dataset and unzip it, like:
  /your data_path set in .sh file
  ├── lipread_mp4
  │   ├── [ALL CLASS FOLDER]
  │   ├── ...
  Run `prepare_lrw_audio.py` and `prepare_lrw_video.py` to preprocess the audio and video modalities, respectively. Please modify the data path in these preprocessing scripts in advance. Similarly, download the LRW-1000 dataset, unzip it, and run `prepare_lrw1000_audio.py` and `prepare_lrw1000_video.py` to preprocess it.
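Before running the preprocessing scripts, it can help to confirm that the unzipped LRW data is where the scripts expect it. The helper below is hypothetical and is not part of `prepare_lrw_audio.py` or `prepare_lrw_video.py`; it assumes the standard LRW layout of `lipread_mp4/<CLASS>/<split>/<CLASS>_<id>.mp4` with `train`, `val`, and `test` splits, so check that against your own copy.

```python
from pathlib import Path

# Assumed LRW split folder names inside each class folder.
SPLITS = ("train", "val", "test")


def list_lrw_samples(data_path):
    """Hypothetical helper: map each split to a sorted list of (class_name, mp4_path)."""
    root = Path(data_path) / "lipread_mp4"
    samples = {split: [] for split in SPLITS}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for split in SPLITS:
            split_dir = class_dir / split
            if not split_dir.is_dir():
                continue  # tolerate partially downloaded classes
            # Only video clips are collected; other files (e.g. metadata) are skipped.
            for mp4 in sorted(split_dir.glob("*.mp4")):
                samples[split].append((class_dir.name, mp4))
    return samples
```

A quick count such as `{k: len(v) for k, v in list_lrw_samples(path).items()}` shows whether every split was unpacked before spending time on preprocessing.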
We provide pretrained weights on the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and place them as follows:
/your init_weights set in .sh file
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
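A small check like the one below can catch a misplaced weight file before training starts. It is a hypothetical helper, not part of the repo; the only thing taken from this README is the four file names, and the dataset filter assumes the `_<DATASET>-pre.pth` naming shown above.

```python
from pathlib import Path

# File names as listed in this README.
EXPECTED_WEIGHTS = (
    "Conv1dResNetGRU_LRW-pre.pth",
    "Conv3dResNetLSTM_LRW-pre.pth",
    "Conv1dResNetGRU_LRW1000-pre.pth",
    "Conv3dResNetLSTM_LRW1000-pre.pth",
)


def missing_weights(init_weights_dir, dataset=None):
    """Return expected .pth files that are absent from init_weights_dir.

    If dataset is given (e.g. "LRW" or "LRW1000"), only that dataset's files are checked.
    """
    root = Path(init_weights_dir)
    wanted = [
        name for name in EXPECTED_WEIGHTS
        if dataset is None or name.endswith(f"_{dataset}-pre.pth")
    ]
    return [name for name in wanted if not (root / name).is_file()]
```

Matching on the full `_LRW-pre.pth` suffix keeps the `LRW` filter from also matching the `LRW1000` files.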
For the LRW dataset, adjust the parameters in `run/protocat_lrw.sh` and run:
`cd ./Proto-CAT/run`
`bash protocat_lrw.sh`
Similarly, run `bash protocat_lrw1000.sh` for the LRW-1000 dataset.
Run `bash protocat_plus_lrw.sh` or `bash protocat_plus_lrw1000.sh` to train Proto-CAT+.
Download the trained models from Google Drive or Baidu Yun (password: swzd) and run `bash test_protocat_lrw.sh` to evaluate Proto-CAT on LRW.
Run `bash test_protocat_lrw1000.sh`, `bash test_protocat_plus_lrw.sh`, or `bash test_protocat_plus_lrw1000.sh` to evaluate the other models.
Proto-CAT's entry function is in `main.py`. It calls the manager `Trainer` in `models/train.py`, which contains the main training logic. Within `Trainer`, `prepare_handle.prepare_dataloader` together with `train_prepare_batch` loads and preprocesses data in the generalized few-shot style; `fit_handle` controls forward and backward propagation; and `callbacks` handle the behavior at each stage.
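The kind of generalized few-shot episode that the data-preparation step produces can be pictured with the sampler below. It is a minimal sketch under stated assumptions, not the repo's code: the function name, the dict-of-lists sample pools, and the convention of giving base classes the first labels and the sampled novel classes the labels after them are all hypothetical. Setting `base_query=0` drops the additional base-class data, which roughly corresponds to turning off `gfsl_train`/`gfsl_test`.

```python
import random


def sample_gfsl_episode(novel_pool, base_pool, way, shot, query, base_query, rng=random):
    """Hypothetical generalized few-shot episode sampler.

    novel_pool / base_pool: dict mapping class name -> list of sample ids.
    Returns (support, queries), each a list of (sample_id, label).
    Base classes take labels 0..B-1 and the sampled novel classes take B..B+way-1,
    so queries are classified over one joint base + novel label space.
    """
    base_classes = sorted(base_pool)
    novel_classes = rng.sample(sorted(novel_pool), way)
    offset = len(base_classes)
    support, queries = [], []
    for i, cls in enumerate(novel_classes):
        # Draw shot + query distinct samples so support and query never overlap.
        picks = rng.sample(novel_pool[cls], shot + query)
        support += [(s, offset + i) for s in picks[:shot]]
        queries += [(s, offset + i) for s in picks[shot:]]
    for j, cls in enumerate(base_classes):
        # Extra base-class queries test that base performance is maintained.
        queries += [(s, j) for s in rng.sample(base_pool[cls], base_query)]
    return support, queries
```

Passing a seeded `random.Random` instance as `rng` makes episodes reproducible, in the same spirit as the seed arguments listed below.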
All parameters are defined in `models/utils.py`. We list the main ones below:
- `do_train`, `do_test`: Store-true switches for whether to train or test.
- `data_path`: Data directory to be set.
- `model_save_path`: Directory for saving the best model, to be set.
- `init_weights`: Pretrained weights to be set.
- `dataset`: Option for the dataset.
- `model_class`: Option for the top model.
- `backend_type`: Option list for the backend type.
- `train_way`, `val_way`, `test_way`, `train_shot`, `val_shot`, `test_shot`, `train_query`, `val_query`, `test_query`: Task settings of generalized few-shot learning.
- `gfsl_train`, `gfsl_test`: Switches for whether to train or test in the generalized few-shot learning way, i.e., whether additional base-class data is included.
- `mm_list`: Participating modalities.
- `lr_scheduler`: List of learning rate schedulers.
- `loss_fn`: Option for the loss function.
- `max_epoch`: Maximum number of training epochs.
- `episodes_per_train_epoch`, `episodes_per_val_epoch`, `episodes_per_test_epoch`: Number of sampled episodes per epoch.
- `num_tasks`: Number of tasks per episode.
- `meta_batch_size`: Batch size of each task.
- `test_model_filepath`: Path to the trained weights (`.pth` file) when testing a model.
- `gpu`: Multi-GPU option, e.g., `--gpu 0,1,2,3`.
- `logger_filename`: Logger file save directory.
- `time_str`: Token for each run; generated automatically if empty.
- `acc_per_class`: Switch for whether to measure the per-class accuracy with base, novel, and harmonic mean.
- `verbose`, `epoch_verbose`: Switches for whether to output messages or a progress bar.
- `torch_seed`, `cuda_seed`, `np_seed`, `random_seed`: Seeds for random number generation.
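For readers who want to see how such options map onto a command line, here is an `argparse` sketch covering a subset of the parameters above. It is not the contents of `models/utils.py`: the argument names come from the list, but every type and default (for example 15 queries, and `"audio"`/`"video"` as `mm_list` values) is a guess labeled as such in the code.

```python
import argparse
import time


def build_parser():
    # Hypothetical subset of the documented arguments; types and defaults are guesses.
    parser = argparse.ArgumentParser(description="Proto-CAT argument sketch")
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_test", action="store_true")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--init_weights", type=str, default="")
    parser.add_argument("--dataset", type=str, default="LRW")
    # Task settings: {train,val,test}_{way,shot,query}; default values are assumed.
    for split in ("train", "val", "test"):
        parser.add_argument(f"--{split}_way", type=int, default=5)
        parser.add_argument(f"--{split}_shot", type=int, default=1)
        parser.add_argument(f"--{split}_query", type=int, default=15)
    parser.add_argument("--gfsl_train", action="store_true")
    parser.add_argument("--gfsl_test", action="store_true")
    # Modality names are assumed, not taken from the repo.
    parser.add_argument("--mm_list", nargs="+", default=["audio", "video"])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--time_str", type=str, default="")
    return parser


def parse_args(argv=None):
    args = build_parser().parse_args(argv)
    # "--gpu 0,1,2,3" -> [0, 1, 2, 3]
    args.gpu_ids = [int(g) for g in args.gpu.split(",") if g.strip()]
    # Generate a run token when none is given, as described for time_str.
    if not args.time_str:
        args.time_str = time.strftime("%Y%m%d-%H%M%S")
    return args
```

For example, `parse_args(["--do_train", "--gfsl_train", "--gpu", "0,1,2,3"])` yields `gpu_ids == [0, 1, 2, 3]` and a freshly generated `time_str`.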
We thank the following repos for providing helpful components/functions used in our work.