Fix custom fwd and bwd for older PyTorch versions (#596)
* fix custom fwd and bwd for older torch versions

* forgot to push the new utils file..

* use partial to fix kwargs passing with dec
vasqu authored Oct 25, 2024
1 parent bc84fb1 commit 83a5c90
Showing 5 changed files with 31 additions and 12 deletions.
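
For context: `torch.amp.custom_fwd`/`custom_bwd` with a `device_type` argument only exist in newer PyTorch (2.4+); older versions expose them as `torch.cuda.amp.custom_fwd`/`custom_bwd`, which take the function directly. The commit replaces the direct `torch.amp` imports with a shim that picks the right variant at import time. A minimal sketch of that version split, not part of the commit (the names `fwd_decorator`/`bwd_decorator` are illustrative):

# Sketch only: the API split that mamba_ssm/utils/torch.py abstracts over.
import torch

if hasattr(torch.amp, "custom_fwd"):      # PyTorch >= 2.4
    # torch.amp decorators require an explicit device_type
    fwd_decorator = torch.amp.custom_fwd(device_type="cuda")
    bwd_decorator = torch.amp.custom_bwd(device_type="cuda")
else:                                     # older PyTorch
    # torch.cuda.amp decorators are applied to the function directly
    fwd_decorator = torch.cuda.amp.custom_fwd
    bwd_decorator = torch.cuda.amp.custom_bwd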
mamba_ssm/distributed/tensor_parallel.py (3 additions, 3 deletions)
@@ -6,8 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange
 
@@ -22,7 +22,7 @@
 
 class ParallelLinearFunc(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
@@ -58,7 +58,7 @@ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         return output
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, grad_output):
         grad_output = grad_output.contiguous()
         process_group = ctx.process_group
mamba_ssm/ops/selective_scan_interface.py (3 additions, 3 deletions)
@@ -2,7 +2,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange, repeat
 
@@ -160,7 +160,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
 class MambaInnerFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
@@ -236,7 +236,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
         assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
mamba_ssm/ops/triton/layer_norm.py (3 additions, 3 deletions)
@@ -11,7 +11,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_fwd, custom_bwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -982,7 +982,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(
         ctx,
         x,
@@ -1041,7 +1041,7 @@ def forward(
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
mamba_ssm/ops/triton/ssd_combined.py (3 additions, 3 deletions)
@@ -11,7 +11,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -754,7 +754,7 @@ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=
 class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
                 rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
                 ngroups=1, norm_before_gate=True):
@@ -832,7 +832,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
         return out if not return_final_states else (out, final_states)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
         dfinal_states = args[0] if ctx.return_final_states else None
mamba_ssm/utils/torch.py (new file, 19 additions)
@@ -0,0 +1,19 @@
+import torch
+from functools import partial
+
+
+def custom_amp_decorator(dec, cuda_amp_deprecated):
+    def decorator(func):
+        return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda")
+    return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"):
+    deprecated = True
+    from torch.amp import custom_fwd, custom_bwd
+else:
+    deprecated = False
+    from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)

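With the shim in place, call sites use the bare `@custom_fwd`/`@custom_bwd` decorators (as in the diffs above), and the shim is meant to supply `device_type="cuda"` only when the newer `torch.amp` API is present. A minimal usage sketch under that assumption; `ScaleFn` is a hypothetical Function for illustration, not one from this repository:

# Hypothetical example mirroring the call-site pattern used in the changed files.
import torch
from mamba_ssm.utils.torch import custom_bwd, custom_fwd


class ScaleFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # no gradient is returned for the plain Python float `scale`
        return grad_out * ctx.scale, None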