Generative Latent Replay

arXiv PyTorch License: MIT conda Avalanche Python

Generative Latent Replay diagram

Method overview

Repo for Generative Latent Replay (GLR), a continual learning method which alleviates catastrophic forgetting through strict regularisation of low-level data representations and synthetic latent replay. Explicitly, GLR:

  1. Freezes the backbone of a network after initial training
  2. Builds generative models of the backbone-output latent representations of each dataset encountered by the model
  3. Samples latent pseudo-examples from these generators for replay during subsequent training (to mitigate catastrophic forgetting)
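The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the repository's implementation: the "backbone" is a fixed random projection standing in for trained lower layers, the generator is one diagonal Gaussian per class (a simplification of a GMM), and all names (`backbone`, `fit_latent_generator`, `sample_latents`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 1: stand-in for a frozen backbone (assumption: a fixed random
# projection followed by ReLU; its weights are never updated again).
W_backbone = rng.normal(size=(20, 8))

def backbone(x):
    """Map raw inputs of shape (n, 20) to latents of shape (n, 8)."""
    return np.maximum(x @ W_backbone, 0.0)

def fit_latent_generator(latents, labels):
    """Step 2: fit one diagonal Gaussian per class to the latent vectors."""
    generator = {}
    for c in np.unique(labels):
        z = latents[labels == c]
        generator[int(c)] = (z.mean(axis=0), z.std(axis=0) + 1e-6)
    return generator

def sample_latents(generator, n_per_class, rng):
    """Step 3: draw latent pseudo-examples and their labels for replay."""
    zs, ys = [], []
    for c, (mu, sigma) in generator.items():
        zs.append(rng.normal(mu, sigma, size=(n_per_class, mu.shape[0])))
        ys.append(np.full(n_per_class, c))
    return np.concatenate(zs), np.concatenate(ys)

# Toy first dataset: two classes with different input means.
x_old = np.concatenate([rng.normal(0.0, 1.0, (100, 20)),
                        rng.normal(2.0, 1.0, (100, 20))])
y_old = np.concatenate([np.zeros(100, dtype=int), np.ones(100, dtype=int)])

generator = fit_latent_generator(backbone(x_old), y_old)
z_replay, y_replay = sample_latents(generator, n_per_class=16, rng=rng)

# When training on a later dataset, replayed latents are mixed with the new
# data's latents and passed to the trainable layers above the backbone.
x_new = rng.normal(-1.0, 1.0, (32, 20))
y_new = np.full(32, 2)
z_batch = np.concatenate([backbone(x_new), z_replay])
y_batch = np.concatenate([y_new, y_replay])
```

Because the backbone is frozen, latents sampled from a generator fitted on an earlier dataset remain valid inputs to the upper layers at any later point in training.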

Features

Generative latent replay overcomes two issues encountered in traditional replay strategies:

  1. High memory footprint:
    • replays can be sampled ad hoc from the generator
    • only [compressed] latent representations are cached
  2. Privacy concerns:
    • replayed data is synthetic
| Continual Learning Method | Replay based | Low memory | Privacy |
| --- | --- | --- | --- |
| Naive | | | |
| Replay | | | |
| Latent Replay | | | |
| Generative Latent Replay | | | |
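To make the memory argument concrete, the following back-of-the-envelope calculation compares three ways of retaining information about a past dataset. The numbers are illustrative assumptions rather than measurements from the experiments: 1000 retained examples, 28×28 single-channel inputs, a 64-dimensional latent space, float32 storage, and a diagonal-covariance GMM with 10 components.

```python
BYTES_PER_FLOAT32 = 4

def raw_buffer_bytes(n_examples, input_dim):
    """Memory needed to store raw inputs for classic replay."""
    return n_examples * input_dim * BYTES_PER_FLOAT32

def latent_buffer_bytes(n_examples, latent_dim):
    """Memory needed to store backbone outputs for latent replay."""
    return n_examples * latent_dim * BYTES_PER_FLOAT32

def diag_gmm_bytes(n_components, latent_dim):
    """Memory for a diagonal GMM: per component, a mean vector,
    a variance vector, and one mixing weight."""
    return n_components * (2 * latent_dim + 1) * BYTES_PER_FLOAT32

# Illustrative sizes (assumptions, see the paragraph above).
n_examples, input_dim, latent_dim, n_components = 1000, 28 * 28, 64, 10

raw = raw_buffer_bytes(n_examples, input_dim)
latent = latent_buffer_bytes(n_examples, latent_dim)
gmm = diag_gmm_bytes(n_components, latent_dim)

print(raw)     # 3136000
print(latent)  # 256000
print(gmm)     # 5160
```

Under these assumptions the generator's footprint does not grow with the number of replayed examples, since samples are drawn on demand rather than stored.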

Experiments

Description

We compare generative latent replay against the above methods on the following datasets:

  • Permuted MNIST
  • Rotated MNIST
  • CORe50
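For readers unfamiliar with Permuted MNIST, the sketch below shows how such a task sequence is commonly constructed: each task applies its own fixed pixel permutation to every image. It uses random arrays in place of real MNIST digits, and `make_permuted_tasks` is a hypothetical helper; the experiments here may instead rely on the ready-made benchmarks that Avalanche provides.

```python
import numpy as np

def make_permuted_tasks(images, n_tasks, seed=0):
    """Return one pixel-permuted copy of `images` per task.

    images: array of shape (n, height, width).
    """
    rng = np.random.default_rng(seed)
    n, h, w = images.shape
    flat = images.reshape(n, h * w)
    tasks = []
    for _ in range(n_tasks):
        # One permutation per task, shared by all images in that task.
        perm = rng.permutation(h * w)
        tasks.append(flat[:, perm].reshape(n, h, w))
    return tasks

# Stand-in data with MNIST-like shape (assumption: 28x28 greyscale images).
images = np.random.default_rng(1).random((5, 28, 28))
tasks = make_permuted_tasks(images, n_tasks=3)
```

Each task contains exactly the same pixel values as the original images, only rearranged, so the tasks are equally hard in isolation but mutually interfering when learned in sequence.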

We also explore the effect of different:

  • generative models (GMM, etc.)
  • network freeze depths
  • replay buffer sizes
  • replay sampling strategies
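As an illustration of what a replay sampling strategy can control, the sketch below computes how a training batch might be split between new examples and replayed latents. Both strategies shown ("fixed" and "proportional") and the function `replay_counts` are hypothetical examples chosen for clarity; they are not necessarily the strategies evaluated in the notebooks.

```python
def replay_counts(batch_size, n_past_experiences, strategy="fixed",
                  replay_fraction=0.5):
    """Return (n_new, n_replay_per_past_experience) for one training batch."""
    if n_past_experiences == 0:
        # Nothing to replay while training on the first dataset.
        return batch_size, 0
    if strategy == "fixed":
        # A constant share of every batch is reserved for replay.
        n_replay_total = int(batch_size * replay_fraction)
    elif strategy == "proportional":
        # Every experience seen so far, past or current, gets an equal share.
        n_replay_total = batch_size * n_past_experiences // (n_past_experiences + 1)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    per_past = n_replay_total // n_past_experiences
    n_new = batch_size - per_past * n_past_experiences
    return n_new, per_past

print(replay_counts(64, 3, "fixed"))         # (34, 10)
print(replay_counts(64, 3, "proportional"))  # (16, 16)
```

With a generative replay source, the per-experience count is simply the number of samples drawn from each stored generator per batch, so it can be changed freely between runs without rebuilding a buffer.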

Reproducing experiments

To run experiments, first create and activate a virtual environment:

conda env create -f environment.yml
conda activate env-glr

Then run the appropriate notebooks detailing the experiments.

Alternatively, you can run the notebook directly in Google Colab:

Benchmark baseline

Porting method

Our implementation is fully compatible with the Avalanche continual learning library, and can be imported as a plugin in the same way as other Avalanche strategies:

from avalanche.training.plugins import StrategyPlugin
from glr.strategies import GenerativeLatentReplay