integrate candle einsum library
tomsanbear committed Mar 12, 2024
1 parent 6819c17 commit d15a4ea
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/utils_tensor.rs
@@ -142,6 +142,7 @@ pub fn scaled_dot_product_gqa(
 
     let similarity = match num_head_groups > 1 || cfg.force_grouped {
         true => {
+            // Original python code:
             // query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
             // similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
             let query = einops!(
@@ -218,7 +219,7 @@ pub fn scaled_dot_product_gqa(
 
     // Move head dimension back to axis 2
     // out = rearrange(out, "b h n d -> b n h d")
-    let out = out.permute([0, 2, 1, 3])?;
+    let out = einops!("b h n d -> b n h d", &out);
 
     let attn_weights = match cfg.need_weights {
         false => None,
@@ -228,7 +229,7 @@ pub fn scaled_dot_product_gqa(
             // output: (b, n, h, d).
             // python code:
             // attn_weights = rearrange(attention, "b h n s -> b n s h")
-            let attn_weights = attention.permute([0, 2, 3, 1])?;
+            let attn_weights = einops!("b h n s -> b n s h", &attention);
             // if average_attn_weights:
            //     attn_weights = attn_weights.mean(dim=1)
            if cfg.average_attn_weights {

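For reference, a minimal sketch (not part of the commit) that checks the rearrangement patterns introduced above against the permute calls they replace. It assumes the einops! macro comes from the candle-einops crate and, as in the diff, takes (pattern, &Tensor) and returns a Tensor directly rather than a Result; the crate name and import path are assumptions.

// Minimal sketch, not from the repository: verifies that the einops patterns
// used in the diff produce the same shapes as the permute calls they replace.
// Assumes the `einops!` macro is exported by the candle-einops crate (assumed
// import path) and returns a Tensor, as its usage in the diff suggests.
use candle_core::{Device, Result, Tensor};
use candle_einops::einops;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (batch=2, heads=4, seq=8, head_dim=16), i.e. "b h n d"
    let out = Tensor::randn(0f32, 1f32, (2, 4, 8, 16), &device)?;

    // The permute the commit removes: b h n d -> b n h d
    let via_permute = out.permute([0, 2, 1, 3])?;
    // The einops pattern the commit adds in its place
    let via_einops = einops!("b h n d -> b n h d", &out);

    assert_eq!(via_permute.dims(), &[2, 8, 4, 16]);
    assert_eq!(via_einops.dims(), via_permute.dims());

    // Same idea for the attention weights: b h n s -> b n s h
    let attention = Tensor::randn(0f32, 1f32, (2, 4, 8, 8), &device)?;
    assert_eq!(
        einops!("b h n s -> b n s h", &attention).dims(),
        attention.permute([0, 2, 3, 1])?.dims()
    );
    Ok(())
}

The check is shape-only: it confirms that each einops pattern performs the same axis reordering as the permute it replaces, which is the behavior-preserving property the commit relies on.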