Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Apr 17, 2024
1 parent 71b1aee commit 1a5b59d
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/test/distributed/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,12 @@ def run_save_and_load_fsdp_model(dir, model_factory, model_data_factory, pre_ini

# Check optimizer state.
for p1, p2 in zip(fsdp_model.parameters(), fsdp_model2.parameters()):
torch.testing.assert_close(optim.state[p1], optim2.state[p2])
if p1.numel() > 0:
torch.testing.assert_close(optim.state[p1], optim2.state[p2])
else:
for key in ("exp_avg", "exp_avg_sq"):
assert key not in optim.state or optim.state[p1][key].numel() == 0
assert key not in optim2.state or optim2.state[p2][key].numel() == 0

# Check unsharding model state.
full_model_state = unshard_model_state(dir)
Expand Down

0 comments on commit 1a5b59d

Please sign in to comment.