About Train #36
Comments
Epoch 1 cannot generate any reasonable image because the values are likely to be NaN. The model should generate reasonable images after about 160 epochs (with EMA) or 80 epochs (without EMA).
Thanks for your reply!!!
Epochs 40-80 should generate non-black images.
@LTH14 Awesome job! By the way, why not conduct the masking before computing the diffusion loss? Considering that the denoising is conducted token by token, and the unmasked tokens do not contribute to the diffusion loss, removing them may increase the training speed slightly (though the mask ratio is high 😂).
@Andy1621 The masking needs to be on the input; otherwise the transformer can see the masked tokens and easily predict them.
@LTH14
I just noticed you use only 2 GPUs. This results in a batch size of 64, but we use 2048 in all our experiments. I'm not sure whether this is the cause of your problem, but I would recommend a batch size of at least 512 to train the model.
Also, regarding --blr 1e-6: blr should not be scaled with the batch size, as we scale the actual learning rate according to the batch size in our code: https://github.com/LTH14/mar/blob/main/main_mar.py#L221.
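For reference, the usual linear learning-rate scaling rule (as in the MAE-style codebases this repo follows) computes the actual lr from blr and the effective batch size. A minimal sketch, assuming the conventional base batch size of 256 (the function name and constant here are illustrative, not copied from the repo):

```python
# Sketch of linear lr scaling; the divisor 256 follows the common
# MAE/DeiT convention and is an assumption, not quoted from main_mar.py.
def effective_lr(blr: float, batch_size_per_gpu: int, num_gpus: int) -> float:
    eff_batch = batch_size_per_gpu * num_gpus
    return blr * eff_batch / 256

# With the setup discussed in this thread: 2 GPUs x 32 per GPU = 64 effective batch.
lr = effective_lr(1e-4, 32, 2)
print(lr)  # a small actual lr, since the effective batch is only 64
```

This is why blr itself should stay at 1e-4: the code already shrinks (or grows) the actual lr with the batch size.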
@LTH14 I am training the MAR model on 8 NVIDIA 4090s, and the current epoch is 47.
@pokameng As I mentioned, your --blr should be 1e-4 instead of 1e-6; 1e-6 is too small.
OK, thanks! I will modify the blr.
--diffloss_d 3 should be enough; --diffloss_d 6 might give you slightly better performance.
@LTH14 Thanks for your quick response. What I mean is to modify the forward function like this (note that `target` must be subset the same way as `z`, so the gather is done before sampling `t`):

```python
def forward(self, target, z, mask=None):
    if mask is not None:
        # only keep the masked tokens; target is masked consistently with z
        target = target[mask]
        z = z[mask]
    t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
    model_kwargs = dict(c=z)
    loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
    loss = loss_dict["loss"]
    # original: if mask is not None: loss = (loss * mask).sum() / mask.sum()
    return loss.mean()
```

Since SimpleMLPAdaLN operates on each token individually, should it be faster?
@Andy1621 Oh, I got it -- yes, it will be slightly faster and equivalent to masking after computing the loss, but not by much, since the masking ratio is quite high.
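The equivalence claimed above is easy to check numerically: averaging a per-token loss over only the gathered masked tokens gives the same value as computing the loss everywhere and then taking the mask-weighted mean. A small sketch with random stand-in values (the tensors here are toy data, not the actual diffusion loss):

```python
import torch

torch.manual_seed(0)
B, L = 2, 8
per_token_loss = torch.rand(B, L)           # stand-in for the per-token diffusion loss
mask = (torch.rand(B, L) > 0.3).float()     # 1 = masked (predicted) token

# (a) mask after computing the loss, as in the original code
loss_after = (per_token_loss * mask).sum() / mask.sum()

# (b) gather only the masked tokens first, then average
loss_before = per_token_loss[mask.bool()].mean()

assert torch.allclose(loss_after, loss_before)
```

The speedup from (b) comes only from skipping the MLP forward pass on unmasked tokens, which is small when most tokens are masked anyway.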
Thanks for your quick response!
@LTH14
@pokameng Autoregression is a concept that decomposes the distribution via the chain rule, so the model predicts future tokens based on existing tokens. To enable autoregression, we employ mask-based training, which predicts masked tokens given unmasked tokens. Both our training and our inference are designed to enable the autoregression of our model.
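The decomposition described above can be sketched as a toy loop: at each step, a subset of the still-unknown tokens is "predicted" conditioned on everything known so far, following the chain rule p(x) = Π_steps p(x_subset | x_known). Everything below (names, the uniform step schedule, the placeholder prediction) is illustrative and not taken from the MAR code:

```python
import random

def generate(num_tokens=16, num_steps=4, seed=0):
    """Toy random-order (masked) autoregressive generation loop."""
    random.seed(seed)
    known = {}                          # position -> "predicted" token value
    unknown = list(range(num_tokens))
    random.shuffle(unknown)             # random order, as in masked autoregression
    per_step = num_tokens // num_steps
    for _ in range(num_steps):
        # tokens predicted at this step: a subset of the still-unknown tokens
        subset, unknown = unknown[:per_step], unknown[per_step:]
        for pos in subset:
            # placeholder "prediction", standing in for sampling p(x_pos | x_known)
            known[pos] = len(known)
    return [known[i] for i in range(num_tokens)]

tokens = generate()
```

In the real model the "prediction" is the transformer plus the diffusion MLP head; the sketch only shows the random-order, subset-at-a-time control flow.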
Thank you for your reply!!
Yes -- in the paper we call it "random order" autoregression, or masked autoregression.
Yes -- the job of the small diffloss MLP is to model p(x|z).
Thank you for your quick reply!!!
@LTH14
This is to make the diffusion process cheaper -- the small diffusion model can model the distribution in a very efficient way.
@LTH14
So the MLP diffusion model is meant to model the distribution p(x|z), right?
z is not a reconstructed image; x is the token used to reconstruct the image.
Yes, the MLP diffusion model models the distribution p(x|z).
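Concretely, "modeling p(x|z) with an MLP" means a small per-token denoising network that takes a noisy token x_t, a timestep t, and the condition z, and runs on each token independently. A minimal shape-level sketch (this is an illustrative stand-in, not the actual SimpleMLPAdaLN implementation; all names and dimensions here are assumptions):

```python
import torch
import torch.nn as nn

class TokenDenoiser(nn.Module):
    """Toy per-token denoising MLP conditioned on z (illustrative only)."""
    def __init__(self, token_dim=16, cond_dim=768, width=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(token_dim + cond_dim + 1, width),
            nn.SiLU(),
            nn.Linear(width, token_dim),
        )

    def forward(self, x_t, t, z):
        # x_t: (N, token_dim) noisy tokens; t: (N,) timesteps; z: (N, cond_dim)
        h = torch.cat([x_t, z, t.float().unsqueeze(-1)], dim=-1)
        return self.net(h)  # per-token noise prediction, no cross-token attention

net = TokenDenoiser()
x_t, t, z = torch.randn(5, 16), torch.randint(0, 1000, (5,)), torch.randn(5, 768)
eps = net(x_t, t, z)
```

Because the network sees one token at a time, the diffusion process stays cheap: all cross-token dependencies are carried by z, which the transformer produced.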
The tokens to be predicted at each step are a subset of the unknown tokens.
@LTH14
Hello @LTH14,
I am re-training MAR on the ImageNet dataset and evaluated the checkpoint from epoch 1. However, the images sampled from the epoch-1 checkpoint are black. I want to know why. Does it mean the checkpoint I used is too early?
My config:

```shell
CUDA_VISIBLE_DEVICES=6,7 torchrun --nproc_per_node=2 --nnodes=1 \
    --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
    main_mar.py \
    --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
    --model mar_base --diffloss_d 3 --diffloss_w 1024 \
    --epochs 400 --warmup_epochs 100 --batch_size 32 --blr 1e-6 --diffusion_batch_mul 4 \
    --output_dir ${OUTPUT_DIR} \
    --data_path ${IMAGENET_PATH}
```