Commit

Add Starcoder
jonatanklosko committed Nov 27, 2023
1 parent 3cc0ff9 commit c56ac70
Showing 31 changed files with 840 additions and 154 deletions.
6 changes: 6 additions & 0 deletions lib/bumblebee.ex
@@ -128,6 +128,11 @@ defmodule Bumblebee do
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
"GPT2Model" => {BumbleBee.Text.Gpt2, :base},
"GPTBigCodeModel" => {Bumblebee.Text.GptBigCode, :base},
"GPTBigCodeForCausalLM" => {Bumblebee.Text.GptBigCode, :for_causal_language_modeling},
"GPTBigCodeForSequenceClassification" =>
{Bumblebee.Text.GptBigCode, :for_sequence_classification},
"GPTBigCodeForTokenClassification" => {Bumblebee.Text.GptBigCode, :for_token_classification},
"GPTNeoXModel" => {Bumblebee.Text.GptNeoX, :base},
"GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling},
"GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification},
@@ -215,6 +220,7 @@ defmodule Bumblebee do
"clip" => Bumblebee.Text.ClipTokenizer,
"gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mistral" => Bumblebee.Text.LlamaTokenizer,
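With these mappings in place, a Starcoder-family (GPT-BigCode) checkpoint loads like any other causal language model in Bumblebee. A minimal sketch; the checkpoint id `bigcode/starcoderbase-1b` is an assumption used purely for illustration:

```elixir
# Load the GPT-BigCode model, its GPT-2-style tokenizer and generation config.
{:ok, model_info} = Bumblebee.load_model({:hf, "bigcode/starcoderbase-1b"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bigcode/starcoderbase-1b"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "bigcode/starcoderbase-1b"})

# Build a text-generation serving and complete a code prompt.
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "def fibonacci(n):")
```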
4 changes: 2 additions & 2 deletions lib/bumblebee/audio/whisper.ex
@@ -469,7 +469,7 @@ defmodule Bumblebee.Audio.Whisper do
cross_hidden_state: encoder_hidden_state,
cross_attention_head_mask: cross_attention_head_mask,
cache: cache,
causal?: true,
causal: true,
num_blocks: spec.decoder_num_blocks,
num_attention_heads: spec.decoder_num_attention_heads,
hidden_size: spec.hidden_size,
@@ -520,7 +520,7 @@ defmodule Bumblebee.Audio.Whisper do
decoder_num_attention_heads: {"decoder_attention_heads", number()},
encoder_intermediate_size: {"encoder_ffn_dim", number()},
decoder_intermediate_size: {"decoder_ffn_dim", number()},
activation: {"activation_function", atom()},
activation: {"activation_function", activation()},
dropout_rate: {"dropout", number()},
attention_dropout_rate: {"attention_dropout", number()},
activation_dropout_rate: {"activation_dropout", number()},
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/unet_2d_conditional.ex
@@ -397,7 +397,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do
num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])},
cross_attention_size: {"cross_attention_dim", number()},
use_linear_projection: {"use_linear_projection", boolean()},
activation: {"act_fn", atom()},
activation: {"act_fn", activation()},
group_norm_num_groups: {"norm_num_groups", number()},
group_norm_epsilon: {"norm_eps", number()}
)
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/vae_kl.ex
@@ -435,7 +435,7 @@ defmodule Bumblebee.Diffusion.VaeKl do
down_block_types:
{"down_block_types", list(mapping(%{"DownEncoderBlock2D" => :down_block}))},
up_block_types: {"up_block_types", list(mapping(%{"UpDecoderBlock2D" => :up_block}))},
activation: {"act_fn", atom()}
activation: {"act_fn", activation()}
)

@for.config(spec, opts)
37 changes: 25 additions & 12 deletions lib/bumblebee/layers.ex
@@ -3,7 +3,7 @@ defmodule Bumblebee.Layers do

import Nx.Defn

@unsupported_activations [:gelu_new, :quick_gelu]
@unsupported_activations [:gelu_approx_tanh, :gelu_approx_sigmoid]

@pi :math.pi()

@@ -30,17 +30,29 @@
end

@doc """
Implements the GeLU new activation from huggingface/transformers.
Implements the GeLU activation approximated with tanh.
## References
* [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)
"""
defn gelu_new(input, _opts \\ []) do
defn gelu_approx_tanh(input, _opts \\ []) do
0.5 * input *
(1.0 + Nx.tanh(Nx.sqrt(2.0 / @pi) * (input + 0.044715 * Nx.pow(input, 3.0))))
end

@doc """
Implements the GeLU quick activation from huggingface/transformers.
Implements the GeLU activation approximated with sigmoid.
Note that this approximation is less accurate than `gelu_approx_tanh/2`.
## References
* [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)
"""
defn quick_gelu(input, _opts \\ []) do
defn gelu_approx_sigmoid(input, _opts \\ []) do
input * Nx.sigmoid(1.702 * input)
end
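In math notation, the two renamed activations above implement the standard GELU approximations:

$$
\operatorname{gelu\_approx\_tanh}(x) = \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right),
\qquad
\operatorname{gelu\_approx\_sigmoid}(x) = x\,\sigma(1.702\,x)
$$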

@@ -184,28 +196,29 @@ defmodule Bumblebee.Layers do
## Options
* `:scale_query?` - whether to scale the query. Defaults to `true`
* `:scale` - whether to scale the weights. Defaults to `true`
"""
def attention_weights(query, key, bias, opts \\ []) do
Axon.layer(&attention_weights_impl/4, [query, key, bias], opts)
end

defnp attention_weights_impl(query, key, bias, opts \\ []) do
opts = keyword!(opts, mode: :train, scale_query?: true)
opts = keyword!(opts, mode: :train, scale: true)

key = Nx.transpose(key, axes: [0, 2, 1, 3])
query = Nx.transpose(query, axes: [0, 2, 1, 3])

query =
  if opts[:scale_query?] do
    depth = Nx.axis_size(query, -1)
    query / Nx.sqrt(depth)
  else
    query
  end

weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])

weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])

weights =
  if opts[:scale] do
    depth = Nx.axis_size(query, -1)
    weights / Nx.sqrt(depth)
  else
    weights
  end
weights = weights + bias
Axon.Activations.softmax(weights, axis: -1)
end
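The refactor above moves the `1/sqrt(depth)` scaling from the query onto the dot-product scores. Since `(q / sqrt(d)) · kᵀ == (q · kᵀ) / sqrt(d)`, the attention weights are unchanged; only the option name (`:scale_query?` → `:scale`) and the point of scaling differ. A small Nx sketch of that equivalence (shapes and values are illustrative assumptions, not taken from the commit):

```elixir
# {batch, heads, seq, depth} tensors, already transposed as in attention_weights_impl/4.
key = Nx.iota({1, 2, 4, 8}, type: :f32)
query = Nx.multiply(key, 0.5)
depth = Nx.axis_size(query, -1)

# Previous behaviour: scale the query before the dot product.
weights_old = Nx.dot(Nx.divide(query, Nx.sqrt(depth)), [3], [0, 1], key, [3], [0, 1])

# New behaviour: scale the scores after the dot product.
weights_new = Nx.divide(Nx.dot(query, [3], [0, 1], key, [3], [0, 1]), Nx.sqrt(depth))

# Both computations agree up to floating-point error.
Nx.all_close(weights_old, weights_new)
```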
36 changes: 18 additions & 18 deletions lib/bumblebee/layers/transformer.ex
@@ -43,7 +43,7 @@ defmodule Bumblebee.Layers.Transformer do
block_opts_keys = [
:num_attention_heads,
:num_key_value_heads,
:causal?,
:causal,
:hidden_size,
:ffn,
:kernel_initializer,
@@ -56,7 +56,7 @@ defmodule Bumblebee.Layers.Transformer do
:output_use_bias,
:layer_norm,
:block_type,
:scale_query?,
:scale_attention_weights,
:rotary_embedding
]

@@ -216,7 +216,7 @@ defmodule Bumblebee.Layers.Transformer do
* `:offset` - offset in the input sequence during iterative decoding
* `:causal?` - whether the self-attention block should be causal.
* `:causal` - whether the self-attention block should be causal.
Defaults to `false`
* `:kernel_initializer` - initializer for kernel weights. Defaults
@@ -265,7 +265,7 @@ defmodule Bumblebee.Layers.Transformer do
* `:parallel` - block with attention and FFN independently (in parallel).
This type doesn't support cross-attention
* `:scale_query?` - whether to scale query in the traditional style of
* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`
* `:rotary_embedding` - configuration of rotary embedding. If set,
@@ -308,7 +308,7 @@ defmodule Bumblebee.Layers.Transformer do
cross_attention_head_mask: Layers.none(),
block_cache: Layers.none(),
offset: Layers.none(),
causal?: false,
causal: false,
kernel_initializer: :glorot_uniform,
attention_head_size: nil,
dropout_rate: 0.0,
@@ -319,7 +319,7 @@
output_use_bias: true,
block_type: :standard,
layer_norm: [],
scale_query?: true,
scale_attention_weights: true,
rotary_embedding: nil
])

@@ -328,7 +328,7 @@
num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
hidden_size = opts[:hidden_size]
ffn = opts[:ffn]
causal? = opts[:causal?]
causal = opts[:causal]
kernel_initializer = opts[:kernel_initializer]
attention_head_size = opts[:attention_head_size]
dropout_rate = opts[:dropout_rate]
@@ -347,7 +347,7 @@
offset = opts[:offset]
layer_norm = opts[:layer_norm]
block_type = opts[:block_type]
scale_query? = opts[:scale_query?]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]

ffn_fun =
@@ -393,7 +393,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_relative_bias: attention_relative_bias,
attention_cache: self_attention_cache,
offset: offset,
causal?: causal?,
causal: causal,
num_heads: num_attention_heads,
num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
Expand All @@ -404,7 +404,7 @@ defmodule Bumblebee.Layers.Transformer do
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "self_attention")
)
@@ -448,7 +448,7 @@ defmodule Bumblebee.Layers.Transformer do
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "cross_attention")
)
@@ -673,7 +673,7 @@ defmodule Bumblebee.Layers.Transformer do
* `:offset` - offset in the input sequence during iterative decoding
* `:causal?` - whether to apply causal attention mask, so that tokens
* `:causal` - whether to apply causal attention mask, so that tokens
are attended to only in a single direction. Defaults to `false`
* `:kernel_initializer` - initializer for kernel weights. Defaults
@@ -727,8 +727,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_relative_bias: Layers.none(),
attention_cache: Layers.none(),
offset: Layers.none(),
causal?: false,
scale_query?: true,
causal: false,
scale_attention_weights: true,
kernel_initializer: :glorot_uniform,
dropout_rate: 0.0,
attention_head_size: nil,
@@ -749,8 +749,8 @@
num_key_value_heads = opts[:num_key_value_heads] || num_heads
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal? = opts[:causal?]
scale_query? = opts[:scale_query?]
causal = opts[:causal]
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]

@@ -850,7 +850,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_mask = Layers.expand_attention_mask(attention_mask)

attention_mask =
if causal? do
if causal do
Layers.Decoder.apply_causal_mask(attention_mask, query, offset)
else
attention_mask
@@ -884,7 +884,7 @@ defmodule Bumblebee.Layers.Transformer do
end

attention_weights =
Layers.attention_weights(query, key, attention_bias, scale_query?: scale_query?)
Layers.attention_weights(query, key, attention_bias, scale: scale_attention_weights)
|> Axon.dropout(rate: dropout_rate)
|> Layers.apply_attention_head_mask(attention_head_mask)

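The renamed `:causal` option (previously `:causal?`) gates a lower-triangular mask over the attention scores, so each position attends only to itself and to earlier positions. An illustrative Nx construction of such a mask (a sketch, not the exact helper behind `Layers.Decoder.apply_causal_mask/3`):

```elixir
seq_len = 4

# 1 where the key position is at or before the query position, 0 elsewhere.
causal_mask = Nx.greater_equal(Nx.iota({seq_len, 1}), Nx.iota({1, seq_len}))

# The resulting {4, 4} lower-triangular mask is combined with the attention
# mask; masked-out positions receive a large negative bias before the softmax.
```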
2 changes: 1 addition & 1 deletion lib/bumblebee/multimodal/layout_lm.ex
@@ -491,7 +491,7 @@ defmodule Bumblebee.Multimodal.LayoutLm do
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
activation: {"hidden_act", activation()},
dropout_rate: {"hidden_dropout_prob", number()},
attention_dropout_rate: {"attention_probs_dropout_prob", number()},
classifier_dropout_rate: {"classifier_dropout", optional(number())},
45 changes: 45 additions & 0 deletions lib/bumblebee/shared.ex
@@ -512,4 +512,49 @@ defmodule Bumblebee.Shared do
end
end
end

@doc """
Slices a subset of dense layer parameters.
Expects `out_template` to be a tuple representing a "shape" of the
output units. The tuple should include a list in place of the axis
along which the parameters are concatenated. The list should contain
chunk sizes. `chunk_idx` indicates which chunk to slice.
"""
def sliced_dense_params_source(source_layer_name, out_template, chunk_idx) do
out_template = Tuple.to_list(out_template)
chunk_axis = Enum.find_index(out_template, &is_list/1)
chunk_sizes = Enum.at(out_template, chunk_axis)
{prev_chunk_sizes, [chunk_size | _]} = Enum.split(chunk_sizes, chunk_idx)
offset = Enum.sum(prev_chunk_sizes)
out_shape = List.replace_at(out_template, chunk_axis, Enum.sum(chunk_sizes))

%{
"kernel" => {
[{source_layer_name, "weight"}],
fn [kernel] ->
in_size = Nx.axis_size(kernel, -1)

kernel =
kernel
|> Nx.reshape(List.to_tuple(out_shape ++ [in_size]))
|> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
|> Nx.reshape({:auto, in_size})

# Transpose the kernel to the Axon dense layout, {in_features, out_features}
[out_features, in_features] = Nx.axes(kernel)
Nx.transpose(kernel, axes: [in_features, out_features])
end
},
"bias" => {
[{source_layer_name, "bias"}],
fn [bias] ->
bias
|> Nx.reshape(List.to_tuple(out_shape))
|> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
|> Nx.flatten()
end
}
}
end
end
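A quick sketch of how `sliced_dense_params_source/3` behaves, using made-up sizes: a fused dense layer with output chunks of 4, 2 and 2 units (e.g. a combined query/key/value projection), from which chunk 0 is extracted. The `"attn.c_attn"` layer name is illustrative only:

```elixir
%{"kernel" => {_sources, kernel_fun}, "bias" => {_bias_sources, bias_fun}} =
  Bumblebee.Shared.sliced_dense_params_source("attn.c_attn", {[4, 2, 2]}, 0)

# PyTorch-style dense parameters: weight is {out_features, in_features}.
fused_kernel = Nx.iota({8, 3}, type: :f32)
fused_bias = Nx.iota({8}, type: :f32)

kernel_fun.([fused_kernel])
#=> {3, 4} tensor: the first 4 output units, transposed to {in, out} for Axon

bias_fun.([fused_bias])
#=> {4} tensor: the first 4 bias entries
```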
15 changes: 15 additions & 0 deletions lib/bumblebee/shared/converters.ex
@@ -203,4 +203,19 @@ defmodule Bumblebee.Shared.Converters do
end
end
end

def activation() do
mapping = %{
"gelu_new" => :gelu_approx_tanh,
"gelu_pytorch_tanh" => :gelu_approx_tanh,
"quick_gelu" => :gelu_approx_sigmoid
}

fn name, value ->
case Map.fetch(mapping, value) do
{:ok, replacement} -> {:ok, replacement}
:error -> atom().(name, value)
end
end
end
end
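The new converter only rewrites the activation names that changed and falls back to the existing `atom()` converter otherwise, so specs keep accepting any activation atom:

```elixir
convert = Bumblebee.Shared.Converters.activation()

convert.("hidden_act", "gelu_new")
#=> {:ok, :gelu_approx_tanh}

convert.("hidden_act", "gelu_pytorch_tanh")
#=> {:ok, :gelu_approx_tanh}

convert.("hidden_act", "quick_gelu")
#=> {:ok, :gelu_approx_sigmoid}

# Any other value falls through to atom().
```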
2 changes: 1 addition & 1 deletion lib/bumblebee/text/albert.ex
@@ -483,7 +483,7 @@ defmodule Bumblebee.Text.Albert do
block_depth: {"inner_group_num", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
activation: {"hidden_act", activation()},
dropout_rate: {"hidden_dropout_prob", number()},
attention_dropout_rate: {"attention_probs_dropout_prob", number()},
classifier_dropout_rate: {"classifier_dropout_prob", optional(number())},
4 changes: 2 additions & 2 deletions lib/bumblebee/text/bart.ex
@@ -589,7 +589,7 @@ defmodule Bumblebee.Text.Bart do
cross_attention_mask: encoder_attention_mask,
cross_attention_head_mask: cross_attention_head_mask,
cache: cache,
causal?: true,
causal: true,
num_blocks: spec.decoder_num_blocks,
num_attention_heads: spec.decoder_num_attention_heads,
hidden_size: spec.hidden_size,
@@ -639,7 +639,7 @@ defmodule Bumblebee.Text.Bart do
encoder_intermediate_size: {"encoder_ffn_dim", number()},
decoder_intermediate_size: {"decoder_ffn_dim", number()},
scale_embedding: {"scale_embedding", boolean()},
activation: {"activation_function", atom()},
activation: {"activation_function", activation()},
dropout_rate: {"dropout", number()},
attention_dropout_rate: {"attention_dropout", number()},
activation_dropout_rate: {"activation_dropout", number()},
