diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index ae71c2ea..5dc73453 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -449,7 +449,12 @@ def run_save_and_load_fsdp_model(dir, model_factory, model_data_factory, pre_ini # Check optimizer state. for p1, p2 in zip(fsdp_model.parameters(), fsdp_model2.parameters()): - torch.testing.assert_close(optim.state[p1], optim2.state[p2]) + if p1.numel() > 0: + torch.testing.assert_close(optim.state[p1], optim2.state[p2]) + else: + for key in ("exp_avg", "exp_avg_sq"): + assert key not in optim.state or optim.state[p1][key].numel() == 0 + assert key not in optim2.state or optim2.state[p2][key].numel() == 0 # Check unsharding model state. full_model_state = unshard_model_state(dir)