Skip to content

Commit

Permalink
Refactor attention implementation (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Dec 12, 2023
1 parent e726e26 commit eca3735
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 161 deletions.
2 changes: 1 addition & 1 deletion .formatter.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Used by "mix format"
[
import_deps: [:nx],
inputs: ["{mix,.formatter}.exs", "{config,lib,test,examples}/**/*.{ex,exs}"]
]
218 changes: 153 additions & 65 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,6 @@ defmodule Bumblebee.Layers do
input * Nx.sigmoid(1.702 * input)
end

@doc """
Expands an attention mask of shape `{batch_size, sequence_length}` to
a full mask.
"""
def expand_attention_mask(attention_mask) do
Axon.nx(attention_mask, fn attention_mask ->
attention_mask
|> Nx.new_axis(-2)
|> Nx.new_axis(-2)
end)
end

@doc """
Converts attention mask to bias.
"""
def attention_bias(attention_mask) do
attention_mask
|> Axon.optional()
|> Axon.nx(fn
%Axon.None{} ->
Nx.tensor(0)

attention_mask ->
Nx.select(Nx.greater(attention_mask, 0), 0, -1.0e10)
end)
end

@doc """
Computes relative attention bias.
"""
Expand Down Expand Up @@ -130,7 +103,8 @@ defmodule Bumblebee.Layers do
end

defnp compute_relative_position_buckets(query, key, attention_cache, opts \\ []) do
opts = keyword!(opts, mode: :train, bidirectional: true, num_buckets: 32, max_distance: 128)
opts =
keyword!(opts, mode: :inference, bidirectional: true, num_buckets: 32, max_distance: 128)

{key_length, query_length} = key_query_lengths(query, key, attention_cache)

Expand Down Expand Up @@ -191,71 +165,185 @@ defmodule Bumblebee.Layers do
end
end

@doc """
Computes attention weights.
@doc ~S"""
Computes scaled dot-product attention for multiple attention heads.
This is the core calculation behind multi-head attention, the projection
layers should be applied on top of this layer.
Given input sequences $Q, K, V \in R^{N \times d}$, where $N$ is the
sequence length and $d$ is the head dimension, the scaled dot-product
attention is defined as:
$$
Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d}})V
$$
This operations is further batched across multiple heads and multiple
input sequences.
Intuitively scaled dot-product attention can be thought of as information
retrieval, where for each sequence element in $Q$ the objective is
to extract relevant context from sequence elements in $V$. In this
analogy, $K$ is the summarization of information, while $V$ is the
actual information. Then, assuming $Q$ and $K$ are embedded into a
common space (which is the job of prior projection layers), the
$QK^T$ dot product is a cosine similarity and gives us relevance
weights for sequence elements in $V$.
In case of self-attention, where $Q, K, V$ originate from the same
sequence, the $QK^T$ weights indicate how much "each word attends
to other words".
## Parameter Shapes
* `query` - `{batch_size, sequence_length, num_heads, head_size}`
* `key` - `{batch_size, kv_sequence_length, num_heads, head_size}`
* `value` - `{batch_size, kv_sequence_length, num_heads, head_size}`
* `key_mask` (optional) - `{batch_size, kv_sequence_length}`
* `head_mask` (optional) - `{num_heads}`
* `bias` (optional) - `{batch_size | 1, num_heads | 1, sequence_length, kv_sequence_length}`
* `offset` (optional) - `{}`
## Output Shape
`{batch_size, sequence_length, num_heads, head_size}`
## Options
* `:scale` - whether to scale the weights. Defaults to `true`
* `:causal` - whether to apply causal mask to attention weights.
This is typically used for next token prediction and it
effectively makes each input token use information exclusively
from prior tokens. Defaults to `false`
* `:scale` - whether to scale attention weights by $\frac{1}{\sqrt{d}}$.
Defaults to `true`
* `:dropout_rate` - the dropout rate for attention weights dropout.
Defaults to `0.0`
## References
* [Attention Is All You Need](https://arxiv.org/abs/1706.03762), Figure 2 (left)
"""
def attention_weights(query, key, bias, opts \\ []) do
Axon.layer(&attention_weights_impl/4, [query, key, bias], opts)
def attention(query, key, value, key_mask, head_mask, bias, offset, opts \\ []) do
opts = Keyword.validate!(opts, causal: false, scale: true, dropout_rate: 0.0)

weights =
Axon.layer(
&attention_weights_impl/7,
[
query,
key,
Axon.optional(key_mask),
Axon.optional(head_mask),
Axon.optional(bias),
Axon.optional(offset)
],
causal: opts[:causal],
scale: opts[:scale]
)
|> Axon.dropout(rate: opts[:dropout_rate])

output = Axon.layer(&attention_output_impl/3, [weights, value], opts)

{output, weights}
end

defnp attention_weights_impl(query, key, bias, opts \\ []) do
opts = keyword!(opts, mode: :train, scale: true)
defnp attention_weights_impl(query, key, key_mask, head_mask, bias, offset, opts \\ []) do
opts = keyword!(opts, mode: :inference, scale: true, causal: false)

key = Nx.transpose(key, axes: [0, 2, 1, 3])
query = Nx.transpose(query, axes: [0, 2, 1, 3])
key = Nx.transpose(key, axes: [0, 2, 1, 3])

weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])

weights =
if opts[:scale] do
depth = Nx.axis_size(query, -1)
weights / Nx.sqrt(depth)
weights / Nx.as_type(Nx.sqrt(depth), Nx.type(query))
else
weights
end

key_mask =
case key_mask do
%Axon.None{} -> Nx.broadcast(1, {1, 1, 1, 1})
key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
end

causal_mask =
if opts[:causal] do
query_sequence_length = Nx.axis_size(query, 2)
key_sequence_length = Nx.axis_size(key, 2)
offset = ensure_offset(offset)

Nx.greater_equal(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)
|> Nx.new_axis(0)
|> Nx.new_axis(0)
else
Nx.broadcast(1, {1, 1, 1, 1})
end

mask = Nx.logical_and(key_mask, causal_mask)

bias =
case bias do
%Axon.None{} ->
Nx.select(
mask,
Nx.tensor(0.0, type: Nx.type(query)),
Nx.Constants.min_finite(Nx.type(query))
)

bias ->
Nx.select(
Nx.broadcast(mask, max_shape(mask, bias)),
bias,
Nx.Constants.min_finite(Nx.type(query))
)
end

weights = weights + bias
Axon.Activations.softmax(weights, axis: -1)
end

@doc """
Computes attention outputs.
"""
def attention_output(attention_weights, value) do
Axon.layer(&attention_output_impl/3, [attention_weights, value])
weights = Axon.Activations.softmax(weights, axis: -1)

case head_mask do
%Axon.None{} ->
weights

head_mask ->
head_mask = Nx.reshape(head_mask, {1, :auto, 1, 1})
Nx.multiply(weights, head_mask)
end
end

defnp attention_output_impl(attention_weights, value, _opts \\ []) do
defnp attention_output_impl(weights, value, _opts \\ []) do
value = Nx.transpose(value, axes: [0, 2, 1, 3])
out = Nx.dot(attention_weights, [3], [0, 1], value, [2], [0, 1])
out = Nx.dot(weights, [3], [0, 1], value, [2], [0, 1])
Nx.transpose(out, axes: [0, 2, 1, 3])
end

@doc """
Applies head mask to the given attention weights.
This layer expects computed attention weights and an optional mask.
If the mask is not specified, it will skip masking altogether.
"""
def apply_attention_head_mask(attention_weights, head_mask) do
if_present head_mask do
Axon.layer(
fn attention_weights, head_mask, _ ->
head_mask = Nx.reshape(head_mask, {1, :auto, 1, 1})
Nx.multiply(attention_weights, head_mask)
end,
[attention_weights, head_mask]
)
else
attention_weights
defnp ensure_offset(offset) do
case offset do
%Axon.None{} -> 0
offset -> offset
end
end

deftransformp max_shape(left, right) do
Enum.zip_with(
Tuple.to_list(Nx.shape(left)),
Tuple.to_list(Nx.shape(right)),
&max/2
)
|> List.to_tuple()
end

@doc """
Adds a dense layer to the network.
Expand Down Expand Up @@ -1063,8 +1151,8 @@ defmodule Bumblebee.Layers do

position_ids = Nx.as_type(position_ids, :s64)

cos = cos |> Nx.take(position_ids) |> Nx.new_axis(2)
sin = sin |> Nx.take(position_ids) |> Nx.new_axis(2)
cos = cos |> Nx.take(position_ids) |> Nx.new_axis(2) |> Nx.as_type(Nx.type(query))
sin = sin |> Nx.take(position_ids) |> Nx.new_axis(2) |> Nx.as_type(Nx.type(query))

rotated_query = query * cos + rotate_half(query) * sin
rotated_key = key * cos + rotate_half(key) * sin
Expand Down
58 changes: 0 additions & 58 deletions lib/bumblebee/layers/decoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -243,62 +243,4 @@ defmodule Bumblebee.Layers.Decoder do
[cache, input_embeddings]
)
end

@doc """
Builds a causal mask and combines it with the given attention mask.
A causal mask is used to mask bidirectional self-attention, such
that it works in a single direction.
Accepts an optional offset, which should be set when passing a
partial query.
"""
def apply_causal_mask(attention_mask, query, offset) do
Axon.layer(
fn
%Axon.None{}, query, %Axon.None{}, _opts ->
# The default attention mask would be all ones (matching
# the batch size and sequence length in query), so we can
# skip it altogether
sequence_length = Nx.axis_size(query, 1)
build_causal_mask(Nx.broadcast(1, {1, sequence_length}))

attention_mask, query, offset, _opts ->
sequence_length = Nx.axis_size(attention_mask, -1)

# We generate a full causal mask, then slice it in case of
# iterative decoding
causal_mask = build_causal_mask(Nx.broadcast(1, {1, sequence_length}))

causal_mask =
case offset do
%Axon.None{} ->
causal_mask

offset ->
mask_shift = offset
query_length = Nx.axis_size(query, 1)
Nx.slice_along_axis(causal_mask, mask_shift, query_length, axis: 2)
end

Nx.logical_and(attention_mask, causal_mask)
end,
[Axon.optional(attention_mask), query, Axon.optional(offset)]
)
end

defnp build_causal_mask(input) do
size = Nx.axis_size(input, -1)
idx = Nx.iota({size}) |> Nx.broadcast(input)
build_attention_mask(idx, idx)
end

# Expects a batched, flat inputs of length corresponding to query
# and key length respectively.
defnp build_attention_mask(query_input, key_input) do
query_input
|> Nx.new_axis(-1)
|> Nx.greater_equal(Nx.new_axis(key_input, -2))
|> Nx.new_axis(-3)
end
end
Loading

0 comments on commit eca3735

Please sign in to comment.