Fix custom fwd and bwd for older PyTorch versions #596

Merged (3 commits) on Oct 25, 2024

vasqu (Contributor) commented Oct 18, 2024

Fixes #594
Fixes #496
Fixes #517

#584 broke main for older torch versions, e.g. torch==2.2.0. There have been similar attempts in #501 and #560.
The main difference here is a wrapper around the decorator that passes the new kwarg (device_type) when necessary and switches the import depending on the installed torch version (based on whether torch.amp has custom_(fwd|bwd)). I moved it to the utils folder, but I am open to changing the structure.

P.S. Verified the fix with torch==2.2.0 (an older version that doesn't have custom_(fwd|bwd) in amp) and torch==2.5.0 (a newer version that does have custom_(fwd|bwd) in amp).

@naromero77amd

If possible, it would be good to maintain backwards compatibility with the last two versions of PyTorch. I would like to see this PR land.

tridao merged commit 83a5c90 into state-spaces:main on Oct 25, 2024
vasqu deleted the fix-fwd-bwd-for-older-torch branch on October 25, 2024 22:11
KokeCacao (Contributor) commented Oct 26, 2024

In older versions, this expression is allowed: @custom_fwd(cast_inputs=torch.float32)

Would it be better to change it to the following? See #608

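from typing import Callable

import torch


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    # Forward all arguments so that both @custom_fwd and
    # @custom_fwd(cast_inputs=torch.float32) keep working.
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            # torch.amp.custom_fwd/custom_bwd require an explicit device_type.
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


# Newer torch exposes custom_fwd/custom_bwd in torch.amp and deprecates torch.cuda.amp.
if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
    deprecated = True
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    deprecated = False
    from torch.cuda.amp import custom_fwd, custom_bwd

# Rebind the names so callers use the same decorators on every torch version.
custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)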
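
To illustrate why forwarding `*args, **kwargs` matters, the sketch below applies the same wrapper to a stand-in decorator and uses it both bare and with keyword arguments. `fake_new_custom_fwd` and its `amp_config` attribute are hypothetical: the function only loosely mimics a decorator with a keyword-only `device_type`, and it is not the torch implementation. The string `"float32"` stands in for `torch.float32` so the example runs without torch.

```python
import functools
from typing import Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    # Same wrapper as proposed above: inject device_type only when required.
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


# Hypothetical stand-in for a newer-style decorator (NOT the torch implementation).
def fake_new_custom_fwd(fwd=None, *, device_type, cast_inputs=None):
    if fwd is None:
        # Parameterized use: return a decorator that still expects the function.
        return functools.partial(
            fake_new_custom_fwd, device_type=device_type, cast_inputs=cast_inputs
        )

    @functools.wraps(fwd)
    def wrapper(*args, **kwargs):
        return fwd(*args, **kwargs)

    # Record the received settings so they can be inspected (assumption for the demo).
    wrapper.amp_config = (device_type, cast_inputs)
    return wrapper


custom_fwd = custom_amp_decorator(fake_new_custom_fwd, cuda_amp_deprecated=True)


@custom_fwd  # bare use: the function arrives as the positional argument
def bare(x):
    return x + 1


@custom_fwd(cast_inputs="float32")  # parameterized use: only keyword arguments arrive
def parameterized(x):
    return x * 2


print(bare.amp_config, parameterized.amp_config)
```

In both forms the wrapper adds `device_type="cuda"` before delegating, so call sites written for the older `torch.cuda.amp` decorators should not need to change.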
