Fix custom fwd and bwd for older PyTorch versions #596
Fixes #594
Fixes #496
Fixes #517
#584 broke main for older torch versions, e.g. torch==2.2.0. There have been similar attempts in #501 and #560.
The main difference is a wrapper around the decorator that passes the new kwarg when necessary and switches the import depending on the torch version (detected by whether amp has custom_(fwd|bwd)). I moved it to the utils folder; open to changing the structure.
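A minimal sketch of the kind of wrapper described above (names and the CUDA-only `device_type` are illustrative assumptions, not the exact code in this PR): newer torch exposes `custom_fwd`/`custom_bwd` under `torch.amp` and requires a `device_type` kwarg, while older torch only has them under `torch.cuda.amp` without that kwarg.

```python
import torch

if hasattr(torch.amp, "custom_fwd"):  # newer torch, e.g. 2.5.0
    # torch.amp.custom_fwd requires device_type; forward the other kwargs.
    def custom_fwd(fwd=None, *, cast_inputs=None):
        return torch.amp.custom_fwd(fwd, device_type="cuda", cast_inputs=cast_inputs)

    def custom_bwd(bwd):
        return torch.amp.custom_bwd(bwd, device_type="cuda")
else:  # older torch, e.g. 2.2.0
    # The legacy decorators have no device_type kwarg; re-export them as-is.
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
```

Both decorator styles keep working on either torch version: bare `@custom_fwd` as well as `@custom_fwd(cast_inputs=torch.float16)`, since both the new and old implementations return a partial when called without a function.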
P.S. Verified the fix with torch==2.2.0 (an older version that does not have custom_(fwd|bwd) in amp) and torch==2.5.0 (a newer version that does).