[wip] comparing vanilla torch model and clipping with OLMo FSDP, no_shard and OLMo clipping #577
base: main
Conversation
Can you make this a draft PR? I don't think we'll merge this?
Find out which one of these makes the difference?
tests/grad_norm_test.py
Outdated
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
),
auto_wrap_policy=None,
Did you check how it wraps the model?
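One quick way to check the wrapping (a sketch, assuming `fsdp_model` is the FSDP-wrapped model built in this test):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Printing the wrapped model shows which submodules ended up inside FSDP units.
print(fsdp_model)

# With auto_wrap_policy=None, only the root module should be an FSDP unit.
for name, module in fsdp_model.named_modules():
    if isinstance(module, FSDP):
        print("FSDP unit:", name or "<root>")
```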
tests/grad_norm_test.py
Outdated
# use same model, data, optimizer, fsdp_model and send to trainer and compare gradient clip

# olmo optimizer
model = OLMo(cfg.model).to('cuda')
Can you make double sure the initialization is the same here?
From the looks of it, this model will definitely start with different weights than the one above. The best way to be sure would be to load the starting state dict of the first model into the second model.
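For example, something like this (a sketch; `first_model` stands in for whichever model is constructed first in the test):

```python
import copy

# Capture the first model's weights right after initialization ...
initial_state = copy.deepcopy(first_model.state_dict())

# ... and load them into the second model so both start from identical weights.
model = OLMo(cfg.model).to('cuda')
model.load_state_dict(initial_state)
```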
tests/grad_norm_test.py
Outdated
# olmo optimizer
model = OLMo(cfg.model).to('cuda')
olmo_optimizer = build_optimizer(cfg, model)
data_loader = build_train_dataloader(cfg)
Make sure the data loader is the same too? By the time you get here you might have consumed some random state, so you might get different data. I would go so far as to pre-load the data into a List, and then use that, so you can be 100% sure it's the same.
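Something along these lines (a sketch, reusing the test's `build_train_dataloader`; the batch count is arbitrary):

```python
from itertools import islice

# Materialize a fixed set of batches up front so both training loops see
# identical data, no matter how much random state is consumed in between.
data_loader = build_train_dataloader(cfg)
batches = list(islice(iter(data_loader), 3))

# Both loops then iterate over `batches` instead of pulling from the loader.
```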
# Now reduce metrics over all ranks.
total_grad_norm: torch.Tensor
per_param_avg_metrics: List[torch.Tensor] = []
- if is_distributed():  # TODO (epwalsh): skip for non-sharded params
+ if is_distributed() and param_group_sharded:
We'll still need to reduce gradient metrics with non-sharded params, since each rank will have different gradients.
Won't the gradients sync after the loss.backward() call?
Ah, yes, you're right. My mistake.
In DDP, loss.backward() syncs grads, and optimizer.step() updates each copy with the synced grads! Or am I getting this wrong? So, if loss.backward() syncs grads, then every rank must have the same grads.
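One way to confirm this empirically (a sketch, not part of the PR; assumes the process group is already initialized):

```python
import torch
import torch.distributed as dist

def assert_grads_synced(model: torch.nn.Module) -> None:
    # Broadcast rank 0's gradients and check that every rank holds the same values.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        reference = param.grad.detach().clone()
        dist.broadcast(reference, src=0)
        if not torch.allclose(param.grad, reference):
            raise AssertionError(f"Gradient mismatch on rank {dist.get_rank()}: {name}")
```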
@epwalsh does this apply to FSDP no_shard as well?
There is an order-of-magnitude difference between the losses of the two setups. @dirkgr @epwalsh, can you sanity-check the OLMo grad_clipping code for FSDP no_shard/DDP?
With the same 3 batches sent to the model again and again, the model should overfit and the loss should go to 0. We can see that in the screenshot for the vanilla PyTorch run.
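For reference, the overfit check amounts to something like this (a sketch; `compute_loss` is a hypothetical helper standing in for the test's forward pass, and `batches` is the pre-loaded list from above):

```python
for step in range(500):
    batch = batches[step % len(batches)]  # cycle the same 3 batches
    optimizer.zero_grad(set_to_none=True)
    loss = compute_loss(model, batch)     # hypothetical forward + loss helper
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())          # should trend toward 0
```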
Comparing the two runs: