From 83a5c9074ff87eea214eaf828bfeaca17605fab4 Mon Sep 17 00:00:00 2001
From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Date: Fri, 25 Oct 2024 23:42:24 +0200
Subject: [PATCH] Fix custom fwd and bwd for older PyTorch versions (#596)

* fix custom fwd and bwd for older torch versions

* forgot to push the new utils file..

* use partial to fix kwargs passing with dec
---
 mamba_ssm/distributed/tensor_parallel.py  |  6 +++---
 mamba_ssm/ops/selective_scan_interface.py |  6 +++---
 mamba_ssm/ops/triton/layer_norm.py        |  6 +++---
 mamba_ssm/ops/triton/ssd_combined.py      |  6 +++---
 mamba_ssm/utils/torch.py                  | 19 +++++++++++++++++++
 5 files changed, 31 insertions(+), 12 deletions(-)
 create mode 100644 mamba_ssm/utils/torch.py

diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py
index 683e01de..2d67b530 100644
--- a/mamba_ssm/distributed/tensor_parallel.py
+++ b/mamba_ssm/distributed/tensor_parallel.py
@@ -6,8 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange
 
@@ -22,7 +22,7 @@
 
 class ParallelLinearFunc(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
@@ -58,7 +58,7 @@ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         return output
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
    def backward(ctx, grad_output):
         grad_output = grad_output.contiguous()
         process_group = ctx.process_group
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 79deb224..d9199806 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -2,7 +2,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange, repeat
 
@@ -160,7 +160,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
 class MambaInnerFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
@@ -236,7 +236,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
         assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py
index af8db0d0..200b415a 100755
--- a/mamba_ssm/ops/triton/layer_norm.py
+++ b/mamba_ssm/ops/triton/layer_norm.py
@@ -11,7 +11,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_fwd, custom_bwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -982,7 +982,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(
         ctx,
         x,
@@ -1041,7 +1041,7 @@ def forward(
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py
index ea24285d..58a6e04a 100644
--- a/mamba_ssm/ops/triton/ssd_combined.py
+++ b/mamba_ssm/ops/triton/ssd_combined.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -754,7 +754,7 @@ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=
 class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")),
                 return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None,
                 headdim=None, ngroups=1, norm_before_gate=True):
@@ -832,7 +832,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
         return out if not return_final_states else (out, final_states)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
         dfinal_states = args[0] if ctx.return_final_states else None
diff --git a/mamba_ssm/utils/torch.py b/mamba_ssm/utils/torch.py
new file mode 100644
index 00000000..afe1dfcf
--- /dev/null
+++ b/mamba_ssm/utils/torch.py
@@ -0,0 +1,19 @@
+import torch
+from functools import partial
+
+
+def custom_amp_decorator(dec, cuda_amp_deprecated):
+    def decorator(func):
+        return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda")
+    return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"):
+    deprecated = True
+    from torch.amp import custom_fwd, custom_bwd
+else:
+    deprecated = False
+    from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
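For context, here is a minimal, self-contained sketch of the compatibility pattern that the new mamba_ssm/utils/torch.py provides. It is illustrative only and not code from the patch: the partial-bound decorator names and the toy _Square autograd.Function are invented for the example. The only PyTorch behavior assumed is that torch.amp.custom_fwd/custom_bwd take a required device_type keyword, while the older torch.cuda.amp.custom_fwd/custom_bwd do not.

# Illustrative sketch (not the patch's implementation): choose whichever AMP
# decorators the installed PyTorch provides and pre-bind device_type="cuda"
# for the newer torch.amp API, so call sites can keep using bare @custom_fwd
# and @custom_bwd, exactly as the hunks above rewrite them.
from functools import partial

import torch

if hasattr(torch.amp, "custom_fwd"):
    # Newer PyTorch: torch.cuda.amp.custom_fwd/custom_bwd are deprecated and
    # the torch.amp replacements require an explicit device_type argument.
    from torch.amp import custom_bwd, custom_fwd

    custom_fwd = partial(custom_fwd, device_type="cuda")
    custom_bwd = partial(custom_bwd, device_type="cuda")
else:
    # Older PyTorch: the CUDA-specific decorators are used unchanged.
    from torch.cuda.amp import custom_bwd, custom_fwd


class _Square(torch.autograd.Function):
    """Toy autograd.Function demonstrating the bare decorators."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


if __name__ == "__main__":
    x = torch.randn(4, requires_grad=True)
    _Square.apply(x).sum().backward()
    print(torch.allclose(x.grad, 2 * x.detach()))  # gradient of x**2 is 2*x

With either branch of the version check, the decorators can be applied without arguments, which matches how every call site in this patch changes from @custom_fwd(device_type="cuda") to plain @custom_fwd.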