Source code and datasets for "Explicitly disentangling image content from translation and rotation with spatial-VAE", presented at NeurIPS 2019.
Example learned motions:
- Learned hinge motion of 5HDB (1-d latent variable)
- Learned arm motion of CODH/ACS (2-d latent variable)
- Learned antibody conformations (2-d latent variable)
Citation:
@incollection{bepler2019spatialvae,
title = {Explicitly disentangling image content from translation and rotation with spatial-VAE},
author = {Bepler, Tristan and Zhong, Ellen and Kelley, Kotaro and Brignole, Edward and Berger, Bonnie},
booktitle = {Advances in Neural Information Processing Systems 32},
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
pages = {15409--15419},
year = {2019},
publisher = {Curran Associates, Inc.},
url = {http://papers.nips.cc/paper/9677-explicitly-disentangling-image-content-from-translation-and-rotation-with-spatial-vae.pdf}
}
Dependencies:
- python 3
- pytorch >= 0.4
- torchvision
- numpy
- pillow
- topaz (for loading MRC files)
Tarballs of the following datasets are available for download:
- Rotated MNIST
- Rotated & Translated MNIST
- 5HDB simulated EM images
- CODH/ACS EM images
- Antibody EM images
- Galaxy zoo
The scripts train_mnist.py, train_particles.py, and train_galaxy.py train spatial-VAE models on the MNIST, single-particle EM, and Galaxy Zoo datasets, respectively.
For example, to train a spatial-VAE model on the CODH/ACS dataset:
python train_particles.py data/codhacs/processed_train.npy data/codhacs/processed_test.npy --num-epochs=1000 --augment-rotation
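Conceptually, the spatial generator decodes an image one pixel at a time: an MLP takes a 2-d pixel coordinate together with the unstructured latent vector z and returns the pixel intensity at that coordinate. The sketch below illustrates this idea; the class and argument names are illustrative and do not match the repo's actual modules.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a spatial generator network: an MLP mapping a
# 2-d pixel coordinate plus the latent vector z to a pixel intensity.
class SpatialGenerator(nn.Module):
    def __init__(self, z_dim=2, hidden_dim=128, num_layers=2):
        super().__init__()
        layers = [nn.Linear(2 + z_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, 1)]  # one intensity per coordinate
        self.mlp = nn.Sequential(*layers)

    def forward(self, coords, z):
        # coords: (batch, n_pixels, 2); z: (batch, z_dim)
        n = coords.size(1)
        z = z.unsqueeze(1).expand(-1, n, -1)  # broadcast z to every pixel
        return self.mlp(torch.cat([coords, z], dim=2)).squeeze(-1)

# Decode a 28x28 image from a fixed coordinate grid and a latent sample.
xs = torch.linspace(-1, 1, 28)
grid = torch.stack(torch.meshgrid(xs, xs, indexing='ij'), dim=-1).reshape(1, -1, 2)
gen = SpatialGenerator(z_dim=2)
img = gen(grid, torch.randn(1, 2)).reshape(1, 28, 28)
```

Because the generator is a function of continuous coordinates rather than a fixed pixel array, rigid transformations can be applied to the coordinates themselves before decoding.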
Some script options include:
--z-dim: dimension of the unstructured latent variable (default: 2)
--p-hidden-dim and --p-num-layers: the number of layers and number of units per layer in the spatial generator network
--q-hidden-dim and --q-num-layers: the number of layers and number of units per layer in the approximate inference network
--dx-prior, --theta-prior: standard deviations of the translation prior (as a fraction of image size) and of the rotation prior
--no-rotate, --no-translate: flags to disable rotation and translation inference
--normalize: normalize the images before training (subtract mean, divide by standard deviation)
--ctf-train, --ctf-test: path to tables containing CTF parameters for the train and test images, used to perform CTF correction if provided
--fit-noise: also output the standard deviation of each pixel from the spatial generator network, sometimes called a colored noise model
--save-prefix: save model parameters every few epochs to this path prefix
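The --no-rotate and --no-translate flags above control whether rotation and translation are inferred per image. When they are inferred, the transformation is applied to the coordinate grid itself before the generator decodes it, so the generator only has to model content in a canonical frame. A minimal sketch of that coordinate transform (illustrative names, not the repo's actual function):

```python
import torch

def transform_coords(coords, theta, dx):
    # coords: (batch, n_pixels, 2); theta: (batch,); dx: (batch, 2)
    # Rotate each coordinate by theta, then translate by dx.
    cos, sin = torch.cos(theta), torch.sin(theta)
    rot = torch.stack([torch.stack([cos, -sin], dim=-1),
                       torch.stack([sin,  cos], dim=-1)], dim=-2)  # (batch, 2, 2)
    return torch.bmm(coords, rot.transpose(1, 2)) + dx.unsqueeze(1)

coords = torch.randn(1, 784, 2)
# zero rotation and zero translation should leave the grid unchanged
same = transform_coords(coords, torch.zeros(1), torch.zeros(1, 2))
```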
See --help for the complete list of arguments.
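The --dx-prior and --theta-prior standard deviations enter the training objective through KL terms that penalize the approximate posterior over translation and rotation for drifting from a zero-mean Gaussian prior. A sketch of that closed-form Gaussian KL (an illustration, not the repo's exact loss code):

```python
import torch

def gaussian_kl(mu, log_sigma, prior_sigma):
    # KL( N(mu, sigma^2) || N(0, prior_sigma^2) ), elementwise, in closed form.
    sigma2 = torch.exp(2 * log_sigma)
    return (torch.log(torch.tensor(prior_sigma)) - log_sigma
            + (sigma2 + mu ** 2) / (2 * prior_sigma ** 2) - 0.5)

# A posterior identical to the prior incurs zero penalty.
kl = gaussian_kl(torch.zeros(1), torch.log(torch.tensor([0.1])), 0.1)
```

A larger prior standard deviation weakens this penalty, allowing the model to infer larger translations or rotations per image.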
This source code is provided under the MIT License.