Skip to content

Commit

Permalink
fix rope for cross attention
Browse files (browse the repository at this point in the history)
  • Loading branch information
Guillaume "Vermeille" Sanchez committed May 17, 2024
1 parent 5b76f4b commit 17675da
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions torchelie/nn/llm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,11 +21,11 @@ def __init__(self, dim, base=10000):
self.cos_cached = None
self.sin_cached = None

def forward(self, q, k, v, seq_dim=-2):
seq_len = q.shape[seq_dim]
def forward(self, q, k, v):
seq_len = max(q.shape[-2], k.shape[-2])
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(q.shape[seq_dim],
t = torch.arange(seq_len,
device=q.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(q.device)
Expand All @@ -37,14 +37,14 @@ def forward(self, q, k, v, seq_dim=-2):
# rotary pos emb helpers:

def rotate_half(self, x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat(
(-x2, x1),
dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=x1.ndim - 1)

def apply_rotary_pos_emb(self, q, k, v, cos, sin):
return (q * cos) + (self.rotate_half(q) *
sin), (k * cos) + (self.rotate_half(k) * sin), v
q_len = q.shape[-2]
k_len = k.shape[-2]
return (q * cos[:q_len]) + (self.rotate_half(q) * sin[:q_len]), (
k * cos[:k_len]) + (self.rotate_half(k) * sin[:k_len]), v


class SelfAttention(nn.Module):
Expand Down

0 comments on commit 17675da

Please sign in to comment.