From 2fbe38073e935a5df8fe3db000735ad6d8800efa Mon Sep 17 00:00:00 2001
From: Simeon Mugisha
Date: Thu, 21 Mar 2024 11:25:08 +0300
Subject: [PATCH] Add Gemma attention head size (#364)

---
 lib/bumblebee/text/gemma.ex | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/lib/bumblebee/text/gemma.ex b/lib/bumblebee/text/gemma.ex
index 567f76bb..fe718dd2 100644
--- a/lib/bumblebee/text/gemma.ex
+++ b/lib/bumblebee/text/gemma.ex
@@ -26,6 +26,10 @@ defmodule Bumblebee.Text.Gemma do
         default: 24576,
         doc: "the dimensionality of intermediate layers"
       ],
+      attention_head_size: [
+        default: 256,
+        doc: "the size of the key, value, and query projection per attention head"
+      ],
       num_blocks: [
         default: 28,
         doc: "the number of Transformer blocks in the model"
@@ -172,6 +176,7 @@ defmodule Bumblebee.Text.Gemma do
   def init_cache(spec, batch_size, max_length, _inputs) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks
     )
@@ -334,6 +339,7 @@ defmodule Bumblebee.Text.Gemma do
     Layers.Transformer.blocks(hidden_state,
       attention_mask: attention_mask,
       attention_head_mask: attention_head_mask,
+      attention_head_size: spec.attention_head_size,
       cache: cache,
       num_blocks: spec.num_blocks,
       num_attention_heads: spec.num_attention_heads,
@@ -419,6 +425,7 @@ defmodule Bumblebee.Text.Gemma do
       num_blocks: {"num_hidden_layers", number()},
       num_attention_heads: {"num_attention_heads", number()},
       num_key_value_heads: {"num_key_value_heads", number()},
+      attention_head_size: {"head_dim", number()},
       intermediate_size: {"intermediate_size", number()},
       activation: {"hidden_act", activation()},
       use_attention_bias: {"attention_bias", boolean()},
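
A minimal usage sketch (not part of the patch) of how the new option can be set
explicitly when loading a Gemma spec with Bumblebee; the Hugging Face repo id and
the explicit override below are illustrative assumptions:

    # Illustrative sketch (assumed API usage): attention_head_size can now be
    # set explicitly instead of being derived from
    # hidden_size / num_attention_heads, which is not guaranteed to hold for
    # Gemma checkpoints (e.g. Gemma 7B uses head_dim = 256 with 16 heads and
    # hidden_size = 3072).
    {:ok, spec} = Bumblebee.load_spec({:hf, "google/gemma-7b"})

    # head_dim is mapped from the checkpoint config ("head_dim" -> attention_head_size),
    # but it can also be overridden via Bumblebee.configure/2.
    spec = Bumblebee.configure(spec, attention_head_size: 256)

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "google/gemma-7b"}, spec: spec)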