Fix Op (scaled_dot_product_attention) | feat(torchlib) (#1800)
Fix #1799 

Adds an extra argument, `enable_gqa`, to unblock the export. The real implementation is tracked in #1802.
titaiwangms authored Aug 13, 2024
1 parent 87d7c4f commit af69f4d
Showing 1 changed file with 12 additions and 2 deletions.
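
For context, `enable_gqa` is PyTorch's flag for grouped-query attention, where `key`/`value` carry fewer heads than `query`. A minimal eager-mode sketch of the call (illustrative shapes; note this commit only adds the argument to the torchlib signatures and still rejects `enable_gqa=True`):

```python
import torch
import torch.nn.functional as F

# Grouped-query attention: 8 query heads share 2 key/value heads.
query = torch.randn(1, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
key = torch.randn(1, 2, 128, 64)
value = torch.randn(1, 2, 128, 64)

out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```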
14 changes: 12 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1769,8 +1769,9 @@ def aten_scaled_dot_product_attention(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> TFloat:
-    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
 
     Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -1790,6 +1791,10 @@ def aten_scaled_dot_product_attention(
         is_causal and attn_mask is None
     ), "is_causal and attn_mask cannot be set at the same time"
 
+    assert (
+        not enable_gqa
+    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     if scale is None:
         scale = _attention_scale(query)
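
The new assert makes the converter fail fast instead of emitting an incorrect graph. As a rough sketch of what the real lowering (deferred to #1802) has to do, GQA reduces to standard multi-head attention by repeating each key/value head across its query group. The helper below is hypothetical plain-PyTorch pseudocode, not the torchlib implementation:

```python
import torch

def expand_kv_heads(key: torch.Tensor, value: torch.Tensor, num_query_heads: int):
    # key/value: (batch, num_kv_heads, seq_len, head_dim)
    num_kv_heads = key.shape[1]
    assert num_query_heads % num_kv_heads == 0, "KV heads must divide query heads"
    group_size = num_query_heads // num_kv_heads
    # Repeat each KV head group_size times so the head counts match the query.
    return (
        key.repeat_interleave(group_size, dim=1),
        value.repeat_interleave(group_size, dim=1),
    )
```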
@@ -1982,8 +1987,9 @@ def aten_scaled_dot_product_attention_bool_mask(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> TFloat:
-    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
 
     Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -2003,6 +2009,10 @@ def aten_scaled_dot_product_attention_bool_mask(
         is_causal and attn_mask is None
     ), "is_causal and attn_mask cannot be set at the same time"
 
+    assert (
+        not enable_gqa
+    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+
     if scale is None:
         scale = _attention_scale(query)
         scale = op.CastLike(scale, query)
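
With the torchlib signatures now matching the ATen schema, exporting a model that calls scaled_dot_product_attention with the default `enable_gqa=False` should go through. A minimal smoke test along these lines, assuming a PyTorch build whose ATen op carries the `enable_gqa` argument:

```python
import torch

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 32, 64)
onnx_program = torch.onnx.dynamo_export(Attn(), q, k, v)
onnx_program.save("sdpa.onnx")
```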