Skip to content

Commit

Permalink
Fix torch.compile error for PyTorch v2.3 (microsoft#5463)
Browse files Browse the repository at this point in the history
PyTorch v2.3 throws an error when it tries to compile `iter_params` used
for ZeRO3.
This PR excludes the function from the compilation targets.

After this PR is merged, we can [unpin the torch version for unit
tests](microsoft#5459).
  • Loading branch information
tohtana authored and umchand committed May 20, 2024
1 parent a10ebfd commit ecffd7d
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False):
return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())


@compiler.disable
def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
return map(lambda pair: pair[1], get_all_parameters(module, recurse))

Expand Down

0 comments on commit ecffd7d

Please sign in to comment.