About Train #36

Open · pokameng opened this issue Sep 11, 2024 · 31 comments

@pokameng commented Sep 11, 2024

Hello @LTH14,
I am re-training MAR on the ImageNet dataset and evaluating the checkpoint from epoch 1. However, the images sampled at epoch 1 are black. I want to know why. Does this mean the epoch I used is too early?

My config:

    CUDA_VISIBLE_DEVICES=6,7 torchrun --nproc_per_node=2 --nnodes=1 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
        main_mar.py \
        --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
        --model mar_base --diffloss_d 3 --diffloss_w 1024 \
        --epochs 400 --warmup_epochs 100 --batch_size 32 --blr 1e-6 --diffusion_batch_mul 4 \
        --output_dir ${OUTPUT_DIR} \
        --data_path ${IMAGENET_PATH}

[attached image: black samples from epoch 1]

@LTH14 (Owner) commented Sep 11, 2024

Epoch 1 cannot generate any reasonable image because the values are likely to be NaN. The model should generate reasonable images after 160 epochs (with EMA) or 80 epochs (without EMA).

@pokameng (Author)

> Epoch 1 cannot generate any reasonable image because the values are likely to be NaN. The model should generate reasonable images after 160 epochs (with EMA) or 80 epochs (without EMA).

Thanks for your reply!!!
But I want to know: after how many epochs can the model generate images at all, even if not reasonable ones? About 40 epochs? And how can I accelerate training?

@LTH14 (Owner) commented Sep 11, 2024

40-80 epochs should generate non-black images.

@Andy1621

@LTH14 Awesome job! By the way, why not apply the masking before `loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)`?

Since the denoising is conducted token by token and the unmasked tokens do not contribute to the diffusion loss, wouldn't removing them increase the training speed slightly? (though the mask ratio is high 😂)

@LTH14 (Owner) commented Sep 13, 2024

@Andy1621 The masking needs to be on the input; otherwise the transformer can see the masked tokens and easily predict them.

@pokameng (Author)

@LTH14
Hello, I provide MAR with additional conditioning information and then train it. However, I have trained for 45 epochs, and when I evaluate, the images are still black.

@LTH14 (Owner) commented Sep 14, 2024

I just noticed you use only 2 GPUs. This results in an effective batch size of 64 (2 GPUs × 32 per GPU), but we use 2048 in all our experiments. I'm not sure whether this is the cause of your problem, but I would recommend a batch size of at least 512 to train the model.

@LTH14 (Owner) commented Sep 14, 2024

Also, regarding blr=1e-6: blr should not be scaled with the batch size, as we already scale the actual learning rate according to the batch size in our code: https://github.com/LTH14/mar/blob/main/main_mar.py#L221.
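
For reference, a minimal sketch of that scaling rule (variable names are illustrative, not copied from main_mar.py):

    # The passed --blr is scaled by the effective batch size inside the code,
    # so --blr itself should stay fixed (e.g. 1e-4) regardless of GPU count.
    blr = 1e-4
    batch_size_per_gpu, num_gpus = 32, 8
    eff_batch_size = batch_size_per_gpu * num_gpus   # 256 in this example
    lr = blr * eff_batch_size / 256                  # -> 1e-4 here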

@pokameng (Author)

@LTH14
Hello, this is my current config:

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
        main_mar_concat.py \
        --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
        --model mar_base --diffloss_d 6 --diffloss_w 1024 \
        --epochs 400 --warmup_epochs 100 --batch_size 32 --blr 1e-6 --diffusion_batch_mul 4 \
        --output_dir ${OUTPUT_DIR} \
        --data_path ${IMAGENET_PATH}

I am training the MAR model on 8 NVIDIA 4090s, and it is now at epoch 47.
I want to know how many epochs it takes before a non-black image is generated for the first time.
When I evaluate, the generated images are still black, which is very strange.

@LTH14 (Owner) commented Sep 14, 2024

@pokameng As I mentioned, your --blr should be 1e-4 instead of 1e-6. 1e-6 is too small.

@pokameng (Author)

> @pokameng As I mentioned, your --blr should be 1e-4 instead of 1e-6. 1e-6 is too small.

OK, thanks! I will modify the blr.
Does the choice of --diffloss_d 6 vs. --diffloss_d 3 have an influence on training?

@LTH14 (Owner) commented Sep 14, 2024

--diffloss_d 3 should be enough. --diffloss_d 6 might give you slightly better performance.

@Andy1621 commented Sep 14, 2024

> @Andy1621 The masking needs to be on the input; otherwise the transformer can see the masked tokens and easily predict them.

@LTH14 Thanks for your quick response. What I meant is to modify the forward function of DiffLoss:

def forward(self, target, z, mask=None):
    if mask is not None:
        keep = mask.bool()  # only keep the masked (to-be-predicted) tokens
        target = target[keep]
        z = z[keep]
    t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
    model_kwargs = dict(c=z)
    loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
    loss = loss_dict["loss"]
    return loss.mean()

Since SimpleMLPAdaLN operates on each token individually, shouldn't this be faster?

@LTH14 (Owner) commented Sep 14, 2024

@Andy1621 Oh, I got it -- yes, it will be slightly faster and is equivalent to masking after computing the loss, but not by much, since the masking ratio is quite high.
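
A tiny hypothetical check of that equivalence (per-token loss values invented for illustration): averaging over only the masked tokens gives the same number whether unmasked tokens are dropped before the loss or masked out afterwards.

    import torch

    loss = torch.tensor([0.5, 1.0, 2.0, 4.0])  # per-token diffusion losses
    mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 1 = masked (to-be-predicted)

    after = (loss * mask).sum() / mask.sum()   # mask applied after the loss
    before = loss[mask.bool()].mean()          # unmasked tokens dropped first
    assert torch.allclose(after, before)       # both equal 1.25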

@Andy1621

Thanks for your quick response!

@pokameng (Author)

@LTH14
Hello bro!!
Does the training part of MAR include autoregression? I didn't find any code details about autoregression during training. In other words, does autoregression only happen at the inference stage?

@LTH14 (Owner) commented Sep 20, 2024

@pokameng Autoregression is a concept that decomposes the distribution through the chain rule, so the model predicts future tokens based on existing tokens. To enable autoregression, we employ mask-based training, which predicts masked tokens given unmasked tokens. Both our training and inference are designed to enable the autoregressive behavior of our model.
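
For reference, the chain-rule decomposition being described, in the notation used later in this thread, is $p(X_1, \cdots, X_K) = \prod_{k=1}^{K} p(X_k \mid X_1, \cdots, X_{k-1})$, where each $X_k$ is the set of tokens predicted at step $k$.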

@pokameng (Author)

> @pokameng Autoregression is a concept that decomposes the distribution through the chain rule, so the model predicts future tokens based on existing tokens. To enable autoregression, we employ mask-based training, which predicts masked tokens given unmasked tokens. Both our training and inference are designed to enable the autoregressive behavior of our model.

Thank you for your reply!!
So my understanding is that MAR's autoregression is actually predicting masked tokens given unmasked tokens.
Is this explanation correct?

@LTH14 (Owner) commented Sep 20, 2024

Yes -- in the paper we call it "random order" autoregression, or masked autoregression.

@pokameng (Author)

> Yes -- in the paper we call it "random order" autoregression, or masked autoregression.

OK!!!
In other words, during training, each image is randomly masked, and at each training step the model learns to predict the masked tokens based on the known tokens.

This training method can be understood as "autoregression", and the loss obtained at each step corresponds to a probability distribution under a different condition z.

Precisely because the mask is different each time, the resulting condition z will be different, and the token x_i to be predicted from z will also be different.

[image: chain-rule factorization formula from the paper]
As the formula in the figure shows, each $X_k$ can represent a different masked set, and $X_1, \cdots, X_{k-1}$ can be regarded as the conditioning from which z is obtained under different maskings.

@LTH14 (Owner) commented Sep 20, 2024

Yes -- the job of the small diffloss MLP is to model $p(X_k \mid z)$, where $z$ is predicted from $X_1, \cdots, X_{k-1}$ by the large transformer, so the entire model models $p(X_k \mid X_1, \cdots, X_{k-1})$.
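
Written out as a worked equation (with $f$ as informal shorthand for the transformer, not notation from the paper): $z = f(X_1, \cdots, X_{k-1})$, and the full model factorizes as $p(X_k \mid X_1, \cdots, X_{k-1}) = p(X_k \mid z)$, with the small MLP handling only the right-hand side.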

@pokameng (Author)

> Yes -- the job of the small diffloss MLP is to model $p(X_k \mid z)$, where $z$ is predicted from $X_1, \cdots, X_{k-1}$ by the large transformer, so the entire model models $p(X_k \mid X_1, \cdots, X_{k-1})$.

Thank you for your quick reply!!!

@pokameng (Author)

@LTH14
Why should we introduce a diffusion model conditioned on z? Just to learn the probability distribution under condition z, i.e., $p(x_i \mid z)$?

@LTH14 (Owner) commented Sep 20, 2024

This is to make the diffusion process cheaper -- the small diffusion model can model the distribution in a very efficient way.

@pokameng (Author)

@LTH14
Hello!
Is the condition Z output by the decoder a reconstructed image?

@pokameng (Author)

> This is to make the diffusion process cheaper -- the small diffusion model can model the distribution in a very efficient way.

So the MLP diffusion model is meant to model the distribution $p(x \mid z)$, right?

@LTH14 (Owner) commented Sep 20, 2024

Z is not a reconstructed image; X is the token used to reconstruct the image.

@LTH14 (Owner) commented Sep 20, 2024

Yes, the MLP diffusion model models the distribution $p(x \mid z)$.

@pokameng (Author)

> Z is not a reconstructed image; X is the token used to reconstruct the image.

[image: sampling diagram with "unknown tokens" and "tokens to predict at this step"]

In that case, how should I understand the "unknown tokens" versus the "tokens to predict at this step" in the figure?

@LTH14 (Owner) commented Sep 20, 2024

The tokens to be predicted at this step are a subset of the unknown tokens.
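
To illustrate the subset relation, here is a minimal sketch of a MAR-style sampling schedule (the cosine schedule and random order follow the paper; all names are hypothetical, not the repo's code):

    import numpy as np

    num_iter, seq_len = 64, 256
    unknown = np.ones(seq_len, dtype=bool)  # all tokens start out unknown

    for step in range(num_iter):
        # cosine schedule: fraction of tokens still masked after this step
        mask_ratio = np.cos(np.pi / 2 * (step + 1) / num_iter)
        num_masked_next = int(np.floor(seq_len * mask_ratio))
        # tokens predicted now: a random subset of the unknown ones
        num_to_predict = int(unknown.sum()) - num_masked_next
        idx = np.random.permutation(np.flatnonzero(unknown))[:num_to_predict]
        unknown[idx] = False  # these become known for later steps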

@pokameng (Author)

@LTH14
Hello bro,
If I replace the MLP with another diffusion model, such as SD, what should I pay attention to when training MAR? Does blr need to be modified? And can diffusion_batch_mul=4 be dropped?
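
(For context on diffusion_batch_mul: as I understand the public code, each token's target and conditioning are repeated that many times so the diffusion loss sees several independent noise/timestep draws per token per step. A self-contained, illustrative sketch, with all names and shapes assumed:)

    import torch

    bsz, seq_len, dim, mul = 2, 256, 16, 4  # mul = --diffusion_batch_mul

    target = torch.randn(bsz, seq_len, dim)  # continuous token targets
    z = torch.randn(bsz, seq_len, dim)       # per-token conditioning
    mask = torch.ones(bsz, seq_len)          # 1 = token to be predicted

    # Repeat everything mul times so one forward pass draws mul independent
    # (noise, timestep) samples per token for the diffusion loss.
    target = target.reshape(bsz * seq_len, -1).repeat(mul, 1)
    z = z.reshape(bsz * seq_len, -1).repeat(mul, 1)
    mask = mask.reshape(bsz * seq_len).repeat(mul)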
