diff --git a/src/test/distributed/fsdp/flat_param_handle_test.py b/src/test/distributed/fsdp/flat_param_handle_test.py index a0b912fc..ad80c523 100644 --- a/src/test/distributed/fsdp/flat_param_handle_test.py +++ b/src/test/distributed/fsdp/flat_param_handle_test.py @@ -19,7 +19,7 @@ def run_flat_param_handle_collate_flat_params(): dist.all_reduce(og_data) flat_params = [ShardedFlatParameter.shard(og_data) for og_data in all_og_data] - handle = FlatParamHandle.collate_flat_params(flat_params, ["x", "y", "z"]) + handle = FlatParamHandle.collate_flat_params(flat_params, ["x", "y", "z"], device=get_default_device()) for param in handle.params: assert same_storage(param, handle.params_data)