Jiahang Cao1*, Hanzhong Guo2*, Ziqing Wang3*, Deming Zhou1, Hao Cheng1, Qiang Zhang1, Renjing Xu1†
1The Hong Kong University of Science and Technology (Guangzhou)
2The University of Hong Kong
3Northwestern University
This work, SDM, is an extended version of SDDPM. It introduces several key improvements:
- A New Family of Spiking-based Diffusion Models: This work extends applicability to a wider range of diffusion solvers, including but not limited to DDPM, DDIM, Analytic-DPM, and DPM-Solver.
- Biologically Inspired Temporal-wise Spiking Mechanism (TSM): In biological neurons, the input at each moment fluctuates considerably rather than being predominantly controlled by fixed synaptic weights. Inspired by this, the TSM module enables spiking neurons to capture more dynamic information. It can be combined with the existing modules proposed in SDDPM to further improve image generation quality.
- ANN-SNN Conversion for SDM: To the best of our knowledge, this is the first attempt to use an ANN-SNN conversion approach to implement spiking diffusion models, together with theoretical foundations.
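As a rough illustration of the TSM idea above, the following sketch simulates a single leaky integrate-and-fire (LIF) neuron whose input current is rescaled by a separate gain at every SNN time step. It is a toy under stated assumptions, not the module used in this repository: the function `lif_with_temporal_gain`, the `gains` list, and the decay and threshold values are all hypothetical, and in a trained model the per-step gains would presumably be learned rather than hand-picked.

```python
def lif_with_temporal_gain(currents, gains, decay=0.5, threshold=1.0):
    """Simulate one LIF neuron; currents[t] is scaled by gains[t] before integration.

    Hypothetical toy model: a per-time-step gain stands in for the
    temporal-wise modulation described above.
    """
    if len(currents) != len(gains):
        raise ValueError("currents and gains must have one entry per time step")
    v = 0.0          # membrane potential
    spikes = []
    for current, gain in zip(currents, gains):
        v = decay * v + gain * current   # leaky integration of the modulated input
        if v >= threshold:
            spikes.append(1)
            v = 0.0                      # hard reset after a spike
        else:
            spikes.append(0)
    return spikes


# Same constant input, with and without temporal modulation (illustrative values).
constant_input = [0.6, 0.6, 0.6, 0.6]
fixed = lif_with_temporal_gain(constant_input, gains=[1.0, 1.0, 1.0, 1.0])
modulated = lif_with_temporal_gain(constant_input, gains=[2.0, 0.5, 0.5, 2.0])
print(fixed)
print(modulated)
```

With these illustrative values, the two calls produce different spike trains for the same input, which is the kind of extra temporal variation the TSM description refers to.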
Please see SDDPM.
Here is example code for finetuning the SDM models by inheriting the weights obtained from SDDPM pre-training:
from collections import OrderedDict

import torch

from TSM import Spk_UNet_TSM

# ... (First, pretrain the standard SNN UNet)
# Spk_UNet and args are assumed to come from the SDDPM training code.
pretrained_model = Spk_UNet(
    T=args.T, ch=args.ch, ch_mult=args.ch_mult, attn=args.attn,
    num_res_blocks=args.num_res_blocks, dropout=args.dropout,
    timestep=args.timestep, img_ch=args.img_ch)

# Load the pretrained checkpoint
ckpt = torch.load('/your/pretrained_checkpoint')
pretrained_model.load_state_dict(ckpt['net_model'], strict=True)
pretrained_dict = pretrained_model.state_dict()

# Build the TSM model with the same hyperparameters
net_model = Spk_UNet_TSM(
    T=args.T, ch=args.ch, ch_mult=args.ch_mult, attn=args.attn,
    num_res_blocks=args.num_res_blocks, dropout=args.dropout,
    timestep=args.timestep, img_ch=args.img_ch)
model_dict = net_model.state_dict()

# Copy matching parameters; redirect conv weights and biases to the tsmconv submodule
new_state_dict = OrderedDict()
for name, para in pretrained_dict.items():
    if name in model_dict:
        new_state_dict[name] = para
    elif 'conv' in name and 'weight' in name:
        head = name[:-7]  # strip the trailing '.weight'
        new_name = head + '.tsmconv.weight'
        new_state_dict[new_name] = para
    elif 'conv' in name and 'bias' in name:
        head = name[:-5]  # strip the trailing '.bias'
        new_name = head + '.tsmconv.bias'
        new_state_dict[new_name] = para

# strict=False leaves the TSM-only parameters at their initial values
net_model.load_state_dict(new_state_dict, strict=False)
print('-------Successfully inherit pretrained weights-------')

# ... (Next, finetune the TSM SDM with the same training code from SDDPM)
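The renaming rule in the snippet above can be checked in isolation on plain dictionaries before touching real checkpoints. The sketch below is a minimal stand-alone version: the helper `remap_pretrained_keys` and every parameter name in it (for example `block1.conv1.tsm_param`) are hypothetical and are not taken from the actual `Spk_UNet_TSM` model. Note that in Python an expression such as `'conv' and 'weight' in name` reduces to `'weight' in name`, so the sketch tests each condition explicitly.

```python
from collections import OrderedDict


def remap_pretrained_keys(pretrained_dict, target_keys):
    """Map pretrained parameter names onto a model whose conv layers are wrapped.

    Keys that already exist in the target are copied as-is; conv weights and
    biases that are missing are redirected to '<layer>.tsmconv.<param>'.
    """
    remapped = OrderedDict()
    for name, param in pretrained_dict.items():
        if name in target_keys:
            remapped[name] = param
        elif 'conv' in name and name.endswith('.weight'):
            remapped[name[:-len('.weight')] + '.tsmconv.weight'] = param
        elif 'conv' in name and name.endswith('.bias'):
            remapped[name[:-len('.bias')] + '.tsmconv.bias'] = param
    return remapped


# Hypothetical parameter names, with strings standing in for tensors.
pretrained = OrderedDict([
    ('head.weight', 'w_head'),
    ('block1.conv1.weight', 'w_conv1'),
    ('block1.conv1.bias', 'b_conv1'),
])
target_keys = {
    'head.weight',
    'block1.conv1.tsmconv.weight',
    'block1.conv1.tsmconv.bias',
    'block1.conv1.tsm_param',   # new parameter with no pretrained counterpart
}

remapped = remap_pretrained_keys(pretrained, target_keys)
# Target parameters that received no pretrained value and keep their initialization.
not_loaded = sorted(target_keys - set(remapped))
print(list(remapped))
print(not_loaded)
```

Listing the parameters that were not covered is a cheap way to confirm that only the newly added ones are left at their initial values when loading with `strict=False`.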
Example code for sampling images with the DDIM solver.
The checkpoint of SDM with snn_timesteps=8 on CIFAR-10 is released. You can download the checkpoint through this link.
cd SDM
CUDA_VISIBLE_DEVICES=0 python sample.py
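For readers unfamiliar with the solver, the sketch below shows the standard deterministic DDIM update (eta = 0) on a flat list of values. It is not the content of `sample.py`: the function `ddim_step` and its arguments are hypothetical names, and the noise prediction that would normally come from the spiking UNet is replaced here by the true noise so the step can be checked by hand.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) for a flat list of values.

    alpha_bar_t and alpha_bar_prev are the cumulative noise-schedule products
    at the current and the next (less noisy) step.
    """
    sqrt_ab_t = math.sqrt(alpha_bar_t)
    sqrt_one_minus_ab_t = math.sqrt(1.0 - alpha_bar_t)
    x_prev = []
    for x, eps in zip(x_t, eps_pred):
        # Estimate the clean sample implied by the predicted noise.
        x0_pred = (x - sqrt_one_minus_ab_t * eps) / sqrt_ab_t
        # Re-noise the estimate to the level of the previous step.
        x_prev.append(math.sqrt(alpha_bar_prev) * x0_pred
                      + math.sqrt(1.0 - alpha_bar_prev) * eps)
    return x_prev


# Sanity check: noise a known sample, then step with the true noise down to
# alpha_bar_prev = 1.0 (no noise left), which should recover the sample.
x0 = [0.5, -1.0]
eps = [0.3, -0.2]
alpha_bar_t = 0.4
x_t = [math.sqrt(alpha_bar_t) * a + math.sqrt(1.0 - alpha_bar_t) * e
       for a, e in zip(x0, eps)]
recovered = ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev=1.0)
print(recovered)
```

In an actual sampler this step would be applied repeatedly over a decreasing sequence of time steps, with `eps_pred` produced by the trained network at each one.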
If you find our work useful, please consider citing:
@article{cao2024spiking,
title={Spiking Diffusion Models},
author={Cao, Jiahang and Guo, Hanzhong and Wang, Ziqing and Zhou, Deming and Cheng, Hao and Zhang, Qiang and Xu, Renjing},
journal={arXiv preprint arXiv:2408.16467},
year={2024}
}
We thank the authors of pytorch-ddpm, Fast-SNN, and spikingjelly for open-sourcing their code.
For help or issues with this project, please contact [email protected].