This is an open-source online test-time adaptation repository based on PyTorch. It is joint work by Robert A. Marsden and Mario Döbler, and it is also the official repository for the following works:
- Introducing Intermediate Domains for Effective Self-Training during Test-Time
- Robust Mean Teacher for Continual and Gradual Test-Time Adaptation (CVPR2023)
- Universal Test-time Adaptation through Weight Ensembling, Diversity Weighting, and Prior Correction (WACV2024)
- A Lost Opportunity for Vision-Language Models: A Comparative Study of Online Test-time Adaptation for Vision-Language Models (CVPR2024 MAT Workshop Community Track)
Cite
@article{marsden2022gradual,
title={Gradual test-time adaptation by self-training and style transfer},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
journal={arXiv preprint arXiv:2208.07736},
year={2022}
}
@inproceedings{dobler2023robust,
title={Robust mean teacher for continual and gradual test-time adaptation},
author={D{\"o}bler, Mario and Marsden, Robert A and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={7704--7714},
year={2023}
}
@inproceedings{marsden2024universal,
title={Universal Test-time Adaptation through Weight Ensembling, Diversity Weighting, and Prior Correction},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
pages={2555--2565},
year={2024}
}
@article{dobler2024lost,
title={A Lost Opportunity for Vision-Language Models: A Comparative Study of Online Test-time Adaptation for Vision-Language Models},
author={D{\"o}bler, Mario and Marsden, Robert A and Raichle, Tobias and Yang, Bin},
journal={arXiv preprint arXiv:2405.14977},
year={2024}
}
We encourage contributions! Pull requests to add methods are very welcome and appreciated.
To use the repository, we provide a conda environment.
conda update conda
conda env create -f environment.yml
conda activate tta
Features
This repository contains an extensive collection of different methods, datasets, models, and settings, which we evaluate in a comprehensive benchmark (see below). We also provide a tutorial on how to use this repository in combination with CLIP-like models here. A brief overview of the repository's main features is provided below:
- Datasets
  - `cifar10_c` (CIFAR10-C)
  - `cifar100_c` (CIFAR100-C)
  - `imagenet_c` (ImageNet-C)
  - `imagenet_a` (ImageNet-A)
  - `imagenet_r` (ImageNet-R)
  - `imagenet_v2` (ImageNet-V2)
  - `imagenet_k` (ImageNet-Sketch)
  - `imagenet_d` (ImageNet-D)
  - `imagenet_d109`
  - `domainnet126` (DomainNet, cleaned)
  - `ccc` (Continually Changing Corruptions, CCC)
- Models
  - For adapting to ImageNet variations, all pre-trained models available in Torchvision or timm can be used.
  - For the corruption benchmarks, pre-trained models from RobustBench can be used.
  - For the DomainNet-126 benchmark, there is a pre-trained model for each domain.
  - Further models include ResNet-26 GN.
  - It is also possible to use the models provided by OpenCLIP.
- Settings
  - `reset_each_shift`: Reset the model state after the adaptation to a domain.
  - `continual`: Train the model on a sequence of domains without knowing when a domain shift occurs.
  - `gradual`: Train the model on a sequence of gradually increasing/decreasing domain shifts without knowing when a domain shift occurs.
  - `mixed_domains`: Train the model on one long test sequence where consecutive test samples are likely to originate from different domains.
  - `correlated`: Same as the continual setting, but the samples of each domain are further sorted by class label.
  - `mixed_domains_correlated`: Mixed domains and sorted by class label.
  - Combinations like `gradual_correlated` or `reset_each_shift_correlated` are also possible.
- Methods
  - The supported methods correspond to the config files listed further below (e.g., TENT, EATA, SAR, CoTTA, RoTTA, AdaContrast, LAME, GTTA, RMT, ROID, TPT).
- Mixed Precision Training
  - Almost all of the aforementioned methods (except SAR and GTTA) can be trained with mixed precision. This greatly speeds up your experiments and requires less memory. However, all benchmark results are generated with fp32.
- Modular Design
  - Adding new methods should be rather simple, thanks to the modular design (a minimal sketch follows this list).
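To illustrate what a test-time adaptation method boils down to, below is a minimal, self-contained sketch of a Tent-style entropy-minimization method. The class name, constructor signature, and parameter selection are purely illustrative assumptions and do not mirror the repository's actual base-class interface.

```python
import torch
import torch.nn as nn

class EntropyMinimization(nn.Module):
    """Illustrative Tent-style method: adapt the affine parameters of the
    normalization layers online by minimizing the prediction entropy of
    each incoming test batch."""

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        super().__init__()
        self.model = model
        # collect only the parameters of normalization layers (a common TTA choice)
        norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)
        params = [p for m in model.modules() if isinstance(m, norm_types)
                  for p in m.parameters() if p.requires_grad]
        self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)

    @torch.enable_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.model(x)
        # mean entropy of the softmax predictions over the batch
        loss = -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(dim=1).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return logits

# usage sketch: wrap any classifier and call it on test batches as usual, e.g.
# tta_model = EntropyMinimization(torchvision.models.resnet50(weights="IMAGENET1K_V1"))
```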
To run one of the following benchmarks, the corresponding datasets need to be downloaded.
- CIFAR10-to-CIFAR10-C: the data is automatically downloaded.
- CIFAR100-to-CIFAR100-C: the data is automatically downloaded.
- ImageNet-to-ImageNet-C: for non source-free methods, download ImageNet and ImageNet-C.
- ImageNet-to-ImageNet-A: for non source-free methods, download ImageNet and ImageNet-A.
- ImageNet-to-ImageNet-R: for non source-free methods, download ImageNet and ImageNet-R.
- ImageNet-to-ImageNet-V2: for non source-free methods, download ImageNet and ImageNet-V2.
- ImageNet-to-ImageNet-Sketch: for non source-free methods, download ImageNet and ImageNet-Sketch.
- ImageNet-to-ImageNet-D: for non source-free methods, download ImageNet. For ImageNet-D, see the download instructions for DomainNet-126 below. ImageNet-D is created via symlinks, which are set up on first use.
- ImageNet-to-ImageNet-D109: see instructions for DomainNet-126 below.
- DomainNet-126: download the 6 splits of the cleaned version. Following MME, DomainNet-126 only uses a subset that contains 126 classes from 4 domains.
- ImageNet-to-CCC: for non source-free methods, download ImageNet. CCC is integrated as a webdataset and does not need to be downloaded; a minimal streaming sketch follows this list. Please note that it cannot be combined with settings such as correlated.
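Streaming a webdataset typically looks like the following minimal sketch. The shard URL and sample keys are placeholder assumptions; the actual CCC shard locations are resolved inside the repository.

```python
import webdataset as wds

# placeholder shard pattern; the real CCC shards are resolved by the repository
shards = "https://example.com/ccc/shard-{000000..000099}.tar"

dataset = (
    wds.WebDataset(shards)   # stream tar shards over HTTP, no local download needed
    .decode("pil")           # decode the stored images to PIL
    .to_tuple("jpg", "cls")  # assumed keys for (image, label) pairs
)

for image, label in dataset:  # iterate over the stream like a regular dataset
    break
```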
After downloading the missing datasets, you may need to adapt the path to the root directory `_C.DATA_DIR = "./data"` located in the file `conf.py`. For the individual datasets, the directory names are specified in `conf.py` as a dictionary (see the function `complete_data_dir_path`). In case your directory names deviate from the ones specified in the mapping dictionary, you can simply modify them.
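As a rough sketch of how the data root and the per-dataset directory names fit together (the actual values live in `conf.py`; the names and the helper below are assumed examples, not the repository's code):

```python
import os

# corresponds to _C.DATA_DIR in conf.py; point it to wherever your datasets live
DATA_DIR = "/path/to/datasets"

# assumed example of the kind of mapping used by complete_data_dir_path;
# the real directory names are defined in conf.py and may differ
DIR_NAMES = {
    "imagenet_c": "ImageNet-C",
    "cifar10_c": "CIFAR-10-C",
    "domainnet126": "DomainNet-126",
}

def resolve_data_dir(root: str, dataset_name: str) -> str:
    """Map a dataset key to its directory below the data root."""
    return os.path.join(root, DIR_NAMES.get(dataset_name, dataset_name))

print(resolve_data_dir(DATA_DIR, "imagenet_c"))  # /path/to/datasets/ImageNet-C
```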
We provide config files for all experiments and methods. Simply run the following Python file with the corresponding config file.
python test_time.py --cfg cfgs/[ccc/cifar10_c/cifar100_c/imagenet_c/imagenet_others/domainnet126]/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml
For imagenet_others, the argument CORRUPTION.DATASET
has to be passed:
python test_time.py --cfg cfgs/imagenet_others/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml CORRUPTION.DATASET [imagenet_a/imagenet_r/imagenet_k/imagenet_v2/imagenet_d109]
E.g., to run ROID for the ImageNet-to-ImageNet-R benchmark, run the following command.
python test_time.py --cfg cfgs/imagenet_others/roid.yaml CORRUPTION.DATASET imagenet_r
Alternatively, you can reproduce our experiments by running the `run.sh` in the subdirectory `classification/scripts`. For the different settings, modify `setting` within `run.sh`.
To run the different continual DomainNet-126 sequences, you have to pass the `MODEL.CKPT_PATH` argument. When not specifying a `CKPT_PATH`, the sequence using the real domain as the source domain will be used. The checkpoints are provided by AdaContrast and can be downloaded here. Structurally, it is best to download them into the directory `./ckpt/domainnet126`.
python test_time.py --cfg cfgs/domainnet126/rmt.yaml MODEL.CKPT_PATH ./ckpt/domainnet126/best_clipart_2020.pth
For GTTA, we provide checkpoint files for the style transfer network. The checkpoints are provided on Google-Drive (download); extract the zip file within the `classification` subdirectory.
Changing the evaluation configuration is extremely easy. For example, to run TENT on ImageNet-to-ImageNet-C in the `reset_each_shift` setting with a ResNet-50 and the `IMAGENET1K_V1` initialization, the arguments below have to be passed. Further models and initializations can be found here (torchvision) or here (timm).
python test_time.py --cfg cfgs/imagenet_c/tent.yaml MODEL.ARCH resnet50 MODEL.WEIGHTS IMAGENET1K_V1 SETTING reset_each_shift
For ImageNet-C, the default image list provided by RobustBench considers 5000 samples per domain (see here). If you are interested in running experiments on the full 50,000 test samples, simply set `CORRUPTION.NUM_EX 50000`, e.g.:
python test_time.py --cfg cfgs/imagenet_c/roid.yaml CORRUPTION.NUM_EX 50000
For most methods, we support automatic mixed-precision updates with loss scaling. By default, mixed precision is disabled; to enable it, set the argument `MIXED_PRECISION True`.
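Under the hood, mixed-precision updates with loss scaling follow the standard PyTorch autocast/GradScaler pattern sketched below; the model, test batch, and entropy loss are placeholder assumptions, and the repository applies this internally when `MIXED_PRECISION True` is set (a CUDA device is assumed).

```python
import torch
import torch.nn as nn

# placeholders standing in for the adapted model and an incoming test batch
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(8, 10)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(4, 3, 32, 32, device="cuda")

with torch.cuda.amp.autocast():      # run the forward pass in fp16 where safe
    logits = model(x)
    # entropy minimization as a stand-in for a method's adaptation loss
    loss = -(logits.softmax(1) * logits.log_softmax(1)).sum(1).mean()

scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)               # unscale the gradients, then update the parameters
scaler.update()                      # adjust the scale factor for the next step
optimizer.zero_grad()
```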
We provide detailed results for each method using different models and settings here. The benchmark is updated regularly as new methods, datasets, or settings are added to the repository. Further information on the settings or models can also be found in our paper.
Acknowledgements
- Robustbench official
- CoTTA official
- TENT official
- AdaContrast official
- EATA official
- LAME official
- MEMO official
- RoTTA official
- SAR official
- RDumb official
- CMF official
- DeYO official
- TPT official
For running the experiments based on CarlaTTA, you first have to download the dataset splits as provided below. Again, you probably have to change the data directory `_C.DATA_DIR = "./data"` in `conf.py`. Further, you have to download the pre-trained source checkpoints (download) and extract the zip file within the `segmentation` subdirectory.
E.g., to run GTTA, use the config file provided in the directory `cfgs` and run:
python test_time.py --cfg cfgs/gtta.yaml
You can also change the test sequences by setting `LIST_NAME_TEST` to:
- day2night: `day_night_1200.txt`
- clear2fog: `clear_fog_1200.txt`
- clear2rain: `clear_rain_1200.txt`
- dynamic: `dynamic_1200.txt`
- highway: `town04_dynamic_1200.txt`
If you choose highway as the test sequence, you have to change the source list and the corresponding checkpoint paths.
python test_time.py --cfg cfgs/gtta.yaml LIST_NAME_SRC clear_highway_train.txt LIST_NAME_TEST town04_dynamic_1200.txt CKPT_PATH_SEG ./ckpt/clear_highway/ckpt_seg.pth CKPT_PATH_ADAIN_DEC ./ckpt/clear_highway/ckpt_adain.pth
We provide the different datasets of CarlaTTA as individual zip-files on Google-Drive: