A pip-installable evaluator for GANs (IS and FID). Accepts either dataloaders or individual batches. Supports on-the-fly evaluation during training. A working DCGAN SVHN demo script is provided.

ChenLiu-1996/GAN-evaluator

GAN Evaluator for Inception Score (IS) and Frechet Inception Distance (FID) in PyTorch

Please kindly star this repo for better reach if you find it useful. Let's help out the community!

Main Contributions

  1. We created a GAN evaluator for IS and FID that
    • is easy to use,
    • accepts data as either dataloaders or individual batches, and
    • supports on-the-fly evaluation during training.
  2. We provided a simple demo script to demonstrate one common use case.

NEWS

[Feb 18, 2023]

Now available on PyPI! You can pip install it into your desired environment via:

pip install gan-evaluator

And in your Python project, wherever you need the GAN_Evaluator, you can import via:

from gan_evaluator import GAN_Evaluator

NOTE 1: You no longer need to copy any code from this repo in order to use GAN_Evaluator! At this point, the primary purpose of this repo is description and demonstration. That said, you can certainly clone this repo and try out the demo script. You may also find it easier to copy and modify the code if you want slightly different behavior.

NOTE 2: During pip install gan-evaluator, the dependencies of GAN_Evaluator (but not of the demo script) are also installed.

Demo Script: Use DCGAN to generate SVHN digits

The script can be found in src/train_dcgan_svhn.py

  • Usage from the demo script, to give you a taste.

    Declaration

    evaluator = GAN_Evaluator(device=device,
                              num_images_real=len(train_loader.dataset),
                              num_images_fake=len(train_loader.dataset))
    

    Before the training loop

    evaluator.load_all_real_imgs(real_loader=train_loader, idx_in_loader=0)
    

    Inside the training loop

    if shall_plot:
        IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=x_fake)
    else:
        evaluator.fill_fake_img_batch(fake_batch=x_fake, return_results=False)
    

    After each epoch of training

    evaluator.clear_fake_imgs()
    
  • Some visualizations of the demo script:

    • Real (top) and Generated (bottom) images.
    • IS and FID curves.

Details: The Evaluator for IS and FID

Introduction to the Evaluator

More details can be found in the `GAN_Evaluator` class in src/utils/gan_evaluator.py.

This evaluator computes the following metrics:
    - Inception Score (IS)
    - Frechet Inception Distance (FID)
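For reference, these are the standard definitions of the two metrics (written out here for orientation, not copied from the package). Here $p(y \mid x)$ is the Inception v3 class distribution for a generated image $x$, and $(\mu_r, \Sigma_r)$, $(\mu_g, \Sigma_g)$ are the mean and covariance of the Inception activations of the real and generated images. Higher IS and lower FID indicate better samples.

```latex
\mathrm{IS} = \exp\Big( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big( p(y \mid x) \,\|\, p(y) \big) \Big),
\qquad p(y) = \mathbb{E}_{x \sim p_g}\, p(y \mid x)

\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\Big( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \Big)
```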

This evaluator will take in the real images and the fake/generated images.
Then it will compute the activations from the real and fake images as well as the
predictions from the fake images.
The (fake) predictions will be used to compute IS, while
the (real, fake) activations will be used to compute FID.
If input image resolution < 75 x 75, we will upsample the image to accommodate Inception v3.
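To make the "predictions for IS, activations for FID" split concrete, here is a minimal NumPy sketch of both formulas. It is not the package's code. Two choices are assumptions: the IS standard deviation is taken over 10 disjoint splits (a common convention), and the matrix square root uses a symmetric eigendecomposition trick instead of the `scipy.linalg.sqrtm` call that implementations often use.

```python
import numpy as np


def inception_score(preds: np.ndarray, num_splits: int = 10, eps: float = 1e-12):
    """IS from softmax predictions of shape (N, num_classes).

    Returns (mean, std) over `num_splits` disjoint splits; the splitting
    convention is an assumption, not confirmed for GAN_Evaluator.
    """
    scores = []
    for part in np.array_split(preds, num_splits):
        p_y = part.mean(axis=0, keepdims=True)  # marginal p(y) within this split
        kl = part * (np.log(part + eps) - np.log(p_y + eps))  # per-class KL terms
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))


def _sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Symmetric square root of a positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T


def frechet_distance(act_real: np.ndarray, act_fake: np.ndarray) -> float:
    """FID between two activation sets, each of shape (N, D)."""
    mu_r, mu_f = act_real.mean(axis=0), act_fake.mean(axis=0)
    # atleast_2d guards against np.cov squeezing the D == 1 case to a scalar.
    sigma_r = np.atleast_2d(np.cov(act_real, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(act_fake, rowvar=False))
    # Tr((sigma_r sigma_f)^{1/2}) equals Tr((A sigma_f A)^{1/2}) with A = sigma_r^{1/2};
    # the inner matrix is symmetric PSD, so its eigenvalues are real and >= 0.
    root_r = _sqrt_psd(sigma_r)
    inner = root_r @ sigma_f @ root_r
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0, None)).sum()
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_f) - 2.0 * tr_covmean)
```

As sanity checks: identical activation sets give an FID of about 0, and shifting every feature by a constant `c` adds `D * c**2`. Predictions that are identical for every image give an IS of 1.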

The real and fake images can be provided to this evaluator in either of the following formats:
1. dataloader
    `load_all_real_imgs`
    `load_all_fake_imgs`
2. per-batch
    `fill_real_img_batch`
    `fill_fake_img_batch`
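The per-batch format pairs naturally with the `num_images_real` / `num_images_fake` arguments of the constructor. The sketch below shows one way a fixed-capacity buffer can be filled batch by batch. `BatchBuffer` is a hypothetical name, and the code illustrates the general idea only; `GAN_Evaluator`'s actual internals may differ.

```python
import numpy as np


class BatchBuffer:
    """Hypothetical fixed-capacity buffer that is filled one batch at a time."""

    def __init__(self, capacity: int, feature_dim: int):
        self.data = np.zeros((capacity, feature_dim), dtype=np.float64)
        self.count = 0

    def fill(self, batch: np.ndarray) -> int:
        """Copy as many rows of `batch` as still fit; return the rows written."""
        n = min(len(batch), len(self.data) - self.count)
        self.data[self.count:self.count + n] = batch[:n]
        self.count += n
        return n

    def filled(self) -> np.ndarray:
        """View of the rows written so far."""
        return self.data[:self.count]

    def clear(self) -> None:
        """Reset without reallocating (similar in spirit to `clear_fake_imgs`)."""
        self.count = 0
```

Preallocating once and resetting a counter avoids reallocating memory at every epoch, which matters when the buffer holds a whole dataset's worth of features.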

!!! Please note: the latest IS and FID will be returned upon completion of either of the following:
    `load_all_fake_imgs`
    `fill_fake_img_batch`
Return format:
    (IS mean, IS std, FID)
*So please make sure you load real images before the fake images.*

Common Use Cases:
1. For the purpose of on-the-fly evaluation during GAN training:
    We recommend pre-loading the real images using the dataloader format, and
    populating the fake images using the per-batch format as training goes on.
    - At the end of each epoch, you can clear the fake images using:
        `clear_fake_imgs`
    - In *unusual* cases where your real images change (such as in progressive growing GANs),
    you may want to clear the real images. You can do so via:
        `clear_real_imgs`
2. For the purpose of offline evaluation of a saved dataset:
    We recommend pre-loading the real images and fake images.
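Use case 1 can be sketched as a loop that relies only on the calls shown in the demo excerpts above. `train_step` is a hypothetical callable that updates the GAN on one real batch and returns the generated batch. Computing metrics on the last batch of each epoch is our choice; the demo uses a `shall_plot` flag whose condition is not shown here. With `num_images_fake=len(train_loader.dataset)`, as in the demo declaration, one epoch of fake batches presumably fills the evaluator exactly.

```python
def run_epochs(evaluator, train_loader, train_step, num_epochs):
    """On-the-fly evaluation sketch; `evaluator` is assumed to be a GAN_Evaluator."""
    # Load every real image once, before any fake images arrive.
    evaluator.load_all_real_imgs(real_loader=train_loader, idx_in_loader=0)
    history = []
    for _ in range(num_epochs):
        last_idx = len(train_loader) - 1
        for batch_idx, batch in enumerate(train_loader):
            x_fake = train_step(batch)
            if batch_idx == last_idx:
                # All of this epoch's fake images are in: compute IS and FID.
                IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=x_fake)
                history.append((IS_mean, IS_std, FID))
            else:
                # Store the batch without paying for a metric computation.
                evaluator.fill_fake_img_batch(fake_batch=x_fake, return_results=False)
        # Drop this epoch's fake images so the next epoch starts fresh.
        evaluator.clear_fake_imgs()
    return history
```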

Repository Hierarchy

GAN-evaluator
    ├── config
    |   └── `dcgan_svhn.yaml`
    ├── data (*)
    ├── debug_plot (*)
    ├── logs (*)
    └── src
        ├── utils
        |   ├── `gan_evaluator.py`: THIS CONTAINS OUR `GAN_Evaluator`.
        |   └── other utility files.
        └── `train_dcgan_svhn.py`: our demo script.

Folders marked with (*) will be created automatically, if they do not exist, when you run train_dcgan_svhn.py.
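If you adapt the demo to your own project, creating such output folders takes only a few lines. The helper below is hypothetical and is not the demo script's code. Only the folder names come from the hierarchy above.

```python
import os


def ensure_output_dirs(root, names=("data", "debug_plot", "logs")):
    """Create any missing (*)-style output folders under `root`; return those created."""
    created = []
    for name in names:
        path = os.path.join(root, name)
        if not os.path.isdir(path):
            os.makedirs(path, exist_ok=True)
            created.append(name)
    return created
```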

Usage

  • To integrate our evaluator into your existing project, you can simply install via pip!

  • You can refer to the demo script to see how to interface with the evaluator. Briefly, you need to declare the evaluator, feed in the real images, and feed in the fake images. In most cases, you would feed in the real images in one pass, and feed in the fake images on the fly as the model is trained and fake images are generated. Each time you feed in fake images, you can choose whether or not to compute the metrics. Lastly, don't forget to clear the fake images at the end of each epoch.

  • To run our demo script, do the following after activating the proper environment.

git clone git@github.com:ChenLiu-1996/GAN-evaluator.git
cd GAN-evaluator/src
python train_dcgan_svhn.py --config ../config/dcgan_svhn.yaml

Citation

To be added.

Environment Setup

Packages Needed

The GAN_Evaluator module itself only uses numpy, scipy, torch, torchvision, and (for aesthetics) tqdm.

To run the example script, it additionally requires matplotlib, argparse, and yaml.
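A quick way to check which of these packages are missing from the active environment is to look up each top-level module name with `importlib.util.find_spec`. The mapping from package to module name is our assumption for this sketch (for example, pyyaml is imported as `yaml`).

```python
import importlib.util

# Top-level module names for the packages listed above (pyyaml imports as "yaml").
EVALUATOR_DEPS = ("numpy", "scipy", "torch", "torchvision", "tqdm")
DEMO_DEPS = ("matplotlib", "argparse", "yaml")


def missing_modules(names):
    """Return the module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("missing for GAN_Evaluator:", missing_modules(EVALUATOR_DEPS))
    print("missing for demo script:", missing_modules(DEMO_DEPS))
```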

On our Yale Vision Lab server
  • There is a virtualenv ready to use, located at /media/home/chliu/.virtualenv/mondi-image-gen/.

  • Alternatively, you can start from an existing environment "torch191-py38env", and install the following packages:

python3 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
python3 -m pip install wget gdown numpy matplotlib pyyaml click scipy yacs scikit-learn

If you see error messages such as `Failed to build CUDA kernels for bias_act.`, you can fix them with:

python3 -m pip install ninja

Acknowledgements

  1. The code for the GAN_Evaluator (specifically, the computation of IS and FID) is inspired by:
  2. The code for the demo script (specifically, architecture and training of DCGAN) is inspired by: