Official implementation of the ICLR 2023 paper "A Learning Based Hypothesis Test for Harmful Covariate Shift".
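We introduce the Detectron, a learning-based hypothesis test for harmful covariate shift. Given a pretrained model and an unlabeled set of test samples, the Detectron tests whether those samples come from a shifted distribution on which the model's performance is likely to degrade.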
The algorithm works in two major steps:
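First, we estimate the distribution of the test statistic under the null hypothesis that the test samples come from the training distribution.
Next, we train another classifier, constrained to agree with the pretrained model on the training data while disagreeing with it on the test samples, and compare the resulting statistic against the null distribution.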
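The two steps follow the pattern of a one-sided, sample-based hypothesis test. In the minimal sketch below, the expensive part (training a classifier and measuring how often it disagrees with the pretrained model) is replaced by a hypothetical `simulated_disagreement_rate` function that only draws a random disagreement rate. The names `empirical_p_value` and `two_step_test`, and the assumption that larger statistics indicate shift, are ours for illustration and not the repository's API.

```python
import random


def empirical_p_value(null_stats, observed):
    # One-sided test: larger statistics are treated as evidence of shift (assumption).
    # Adding 1 to numerator and denominator counts the observed value itself,
    # so the p-value is never exactly zero.
    count = sum(1 for s in null_stats if s >= observed)
    return (1 + count) / (1 + len(null_stats))


def two_step_test(null_stats, observed, alpha=0.05):
    # Reject the null hypothesis (no harmful shift) when p <= alpha.
    p = empirical_p_value(null_stats, observed)
    return p, p <= alpha


def simulated_disagreement_rate(rng, n, rate):
    # Hypothetical stand-in for "train a classifier, then measure how often it
    # disagrees with the pretrained model on n samples".
    return sum(rng.random() < rate for _ in range(n)) / n


if __name__ == "__main__":
    rng = random.Random(0)
    # Step 1: null distribution from repeated in-distribution samples.
    null_stats = [simulated_disagreement_rate(rng, 50, 0.10) for _ in range(100)]
    # Step 2: statistic on the (here simulated, shifted) test set.
    observed = simulated_disagreement_rate(rng, 50, 0.40)
    p, reject = two_step_test(null_stats, observed)
    print(f"observed={observed:.2f} p={p:.3f} reject={reject}")
```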
In our paper, we further show how to boost the power of the test using ensembling and by replacing the disagreement statistic with an entropy-based statistic.
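To make the two kinds of statistic concrete, here is one illustrative way to compute them from hard label predictions of the pretrained model and an ensemble of classifiers. These definitions (`disagreement_rate`, `mean_vote_entropy`, and the rule that a sample counts as a disagreement if any ensemble member differs from the base model) are assumptions for illustration; the exact statistics are defined in the paper.

```python
import math
from collections import Counter


def disagreement_rate(base_preds, member_preds):
    # Illustrative definition: fraction of samples on which at least one
    # ensemble member predicts a different label than the base model.
    n = len(base_preds)
    disagree = sum(
        1 for i in range(n) if any(m[i] != base_preds[i] for m in member_preds)
    )
    return disagree / n


def mean_vote_entropy(base_preds, member_preds):
    # Illustrative definition: average entropy (in nats) of the label-vote
    # distribution formed by the base model plus all ensemble members.
    n = len(base_preds)
    total = 0.0
    for i in range(n):
        votes = Counter([base_preds[i]] + [m[i] for m in member_preds])
        k = sum(votes.values())
        total -= sum((c / k) * math.log(c / k) for c in votes.values())
    return total / n


if __name__ == "__main__":
    base = [0, 1, 2, 1]
    members = [[0, 1, 2, 0], [0, 2, 2, 1]]
    print(disagreement_rate(base, members), mean_vote_entropy(base, members))
```

Unlike the binary disagreement indicator, the vote entropy also distinguishes a single dissenting member from an evenly split ensemble, which is one intuition for why an entropy statistic can be more sensitive.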
Test power at
Method | CIFAR 10.1 [Recht et al.] | Camelyon 17 | UCI Heart Disease
---|---|---|---
Black Box Shift Detection [Lipton et al.] | | |
Rel. Mahalanobis Distance [Ren et al.] | | |
Deep Ensemble (Disagreement) [Ablation] | | |
Deep Ensemble (Entropy) [Ablation] | | |
Classifier Two Sample Test (CTST) [Lopez-Paz et al.] | | |
Deep Kernel MMD [Liu et al.] | | |
H-Divergence [Zhao et al.] | | |
Detectron (Disagreement) [Ours] | | |
Detectron (Entropy) [Ours] | | |
The best result in each column is bolded, results within 2% of the best are underlined, and the best baseline method is italicized.
detectron requires a working build of PyTorch with the cudatoolkit enabled.
A simple environment setup using conda is provided below.
# create and activate conda environment using a python version >= 3.9
conda create -n detectron python=3.9
conda activate detectron
# install the latest stable release of pytorch (tested for >= 1.9.0)
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
# install additional dependencies with pip
pip install -r requirements.txt
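Before running the experiments, it can help to confirm that the installed PyTorch build actually sees a GPU. The helper below (`check_torch_cuda` is a hypothetical name, not part of detectron) uses only standard PyTorch calls: `torch.cuda.is_available()`, `torch.cuda.device_count()`, and `torch.version.cuda`.

```python
def check_torch_cuda() -> str:
    """Return a short, human-readable description of the PyTorch/CUDA setup."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return f"PyTorch {torch.__version__} is installed but no CUDA device is available"
    return (
        f"PyTorch {torch.__version__} with CUDA {torch.version.cuda}, "
        f"{torch.cuda.device_count()} device(s) visible"
    )


if __name__ == "__main__":
    print(check_torch_cuda())
```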
We provide a simple config system for storing dataset path mappings in the file detectron/config.yml:
datasets:
  default: /datasets
  cifar10_1: /datasets/cifar-10-1
  camelyon17: /datasets/camelyon17
For more information on downloading datasets, see detectron/data/sample_data/README.md.
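Code that consumes this file only needs to map a dataset name to a path. A minimal sketch, assuming PyYAML is installed; `load_config` and `dataset_path` are hypothetical helpers, and falling back from an unlisted dataset name to `<default>/<name>` is our assumption about what the `default` entry is for, not documented detectron behavior.

```python
from pathlib import Path


def load_config(path="detectron/config.yml"):
    # PyYAML is imported lazily so the lookup helper works without it.
    import yaml

    with open(path) as f:
        return yaml.safe_load(f)


def dataset_path(config, name):
    # Hypothetical resolution rule: use the explicit mapping if present,
    # otherwise fall back to <default>/<name> (assumption, see lead-in).
    datasets = config["datasets"]
    if name in datasets:
        return Path(datasets[name])
    return Path(datasets["default"]) / name


if __name__ == "__main__":
    cfg = {"datasets": {"default": "/datasets", "cifar10_1": "/datasets/cifar-10-1"}}
    print(dataset_path(cfg, "cifar10_1"), dataset_path(cfg, "uci_heart"))
```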
Work is in progress to package Detectron as a robust, easy-to-deploy system.
For now, all the code needed to reproduce our experiments is located in the experiments directory and can be run as in the following example.
# run the cifar experiment using the standard config
# use python -m experiments.detectron_cifar --help for a documented list of options
❯ python -m experiments.detectron_cifar --run_name cifar
The experiment script writes the output for each seed to a .pt file in a directory named results/<run_name>.
The script experiments/analysis.py will read these files and produce a summary of the results for each test described in the paper.
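To inspect the per-seed files outside experiments/analysis.py, they can be collected with a few lines of Python. This is a minimal sketch: `load_runs` is a hypothetical helper, it defaults to `torch.load` (imported only when needed), and it returns each loaded object unchanged because the contents of the .pt files are not described here.

```python
from pathlib import Path


def load_runs(run_name, results_dir="results", load_fn=None):
    """Load every .pt file in results/<run_name>, keyed by file stem."""
    if load_fn is None:
        import torch

        # Newer PyTorch releases default torch.load to weights_only=True,
        # which may reject some pickled objects; adjust if loading fails.
        load_fn = torch.load
    run_dir = Path(results_dir) / run_name
    return {p.stem: load_fn(p) for p in sorted(run_dir.glob("*.pt"))}


if __name__ == "__main__":
    runs = load_runs("cifar")
    print(f"{len(runs)} runs loaded")
```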
❯ python -m experiments.analysis --run_name cifar
# Output
→ 600 runs loaded
→ Running Disagreement Test
N = 10, 20, 50
TPR: .37 ± .05 AUC: 0.799 | TPR: .54 ± .05 AUC: 0.902 | TPR: .83 ± .04 AUC: 0.981
→ Running Entropy Test
N = 10, 20, 50
TPR: .35 ± .05 AUC: 0.712 | TPR: .56 ± .05 AUC: 0.866 | TPR: .92 ± .03 AUC: 0.981
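Each line reports a true positive rate (test power) and an AUC for a given sample size N. The sketch below shows one standard way to compute such numbers from statistics collected under the null and under shift; it is not necessarily how experiments/analysis.py computes them, and it does not reproduce the ± intervals. All function names are hypothetical.

```python
def empirical_p_value(null_stats, observed):
    # One-sided p-value; larger statistics are assumed to indicate shift.
    count = sum(1 for s in null_stats if s >= observed)
    return (1 + count) / (1 + len(null_stats))


def true_positive_rate(null_stats, shifted_stats, alpha=0.05):
    # Test power: fraction of shifted runs on which the test rejects at level alpha.
    rejections = [empirical_p_value(null_stats, s) <= alpha for s in shifted_stats]
    return sum(rejections) / len(rejections)


def auc(null_stats, shifted_stats):
    # Probability that a shifted statistic exceeds a null one (ties count 1/2),
    # which equals the area under the ROC curve when the statistic is the score.
    wins = 0.0
    for s in shifted_stats:
        for t in null_stats:
            wins += 1.0 if s > t else 0.5 if s == t else 0.0
    return wins / (len(shifted_stats) * len(null_stats))


if __name__ == "__main__":
    null = [0.1, 0.2, 0.3, 0.4]
    shifted = [0.35, 0.5, 0.6]
    print(true_positive_rate(null, shifted, alpha=0.25), auc(null, shifted))
```

With this p-value the smallest achievable value is 1 / (1 + number of null statistics), so at α = 0.05 at least 19 null statistics are needed before any rejection is possible.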
Please use the following citation if you use this code or methods in your own work.
@inproceedings{ginsberg2023a,
  title     = {A Learning Based Hypothesis Test for Harmful Covariate Shift},
  author    = {Tom Ginsberg and Zhongyuan Liang and Rahul G Krishnan},
  booktitle = {The Eleventh International Conference on Learning Representations},
  year      = {2023},
  url       = {https://openreview.net/forum?id=rdfgqiwz7lZ}
}