Skip to content

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Apr 4, 2024
1 parent 51d15af commit 077e274
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/test/distributed/fsdp/flat_param_handle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_flat_param_handle_collate_flat_params():
dist.all_reduce(og_data)

flat_params = [ShardedFlatParameter.shard(og_data) for og_data in all_og_data]
handle = FlatParamHandle.collate_flat_params(flat_params, ["x", "y", "z"])
handle = FlatParamHandle.collate_flat_params(flat_params, ["x", "y", "z"], device=get_default_device())
for param in handle.params:
assert same_storage(param, handle.params_data)

Expand Down

0 comments on commit 077e274

Please sign in to comment.