Add bulk transformations functionality #49

Merged · 4 commits · Aug 9, 2023
70 changes: 46 additions & 24 deletions lib/tokenizers/encoding.ex
@@ -180,46 +180,54 @@ defmodule Tokenizers.Encoding do
to: Tokenizers.Native,
as: :encoding_char_to_word

@doc """
Pad the encoding to the given length.
@typedoc """
Padding configuration.

## Options
* `:direction` - the padding direction. Defaults to `:right`

* `direction` (default `:right`) - the padding direction
* `:pad_id` - the id corresponding to the padding token. Defaults
to `0`

* `pad_id` (default `0`) - the id corresponding to the padding
token
* `:pad_type_id` - the type ID corresponding to the padding token.
Defaults to `0`

* `pad_type_id` (default `0`) - the type ID corresponding to the
padding token
* `:pad_token` - the padding token to use. Defaults to `"[PAD]"`

* `pad_token` (default `[PAD]`) - the padding token to use
"""
@type padding_opts :: [
pad_id: non_neg_integer(),
pad_type_id: non_neg_integer(),
pad_token: String.t(),
direction: :left | :right
]

@doc """
Pad the encoding to the given length.

For available options see `t:padding_opts/0`.
"""
@spec pad(t(), non_neg_integer(), opts) :: t()
when opts: [
pad_id: non_neg_integer(),
pad_type_id: non_neg_integer(),
pad_token: String.t(),
direction: :left | :right
]
@spec pad(t(), non_neg_integer(), opts :: padding_opts()) :: t()
defdelegate pad(encoding, target_length, opts \\ []),
to: Tokenizers.Native,
as: :encoding_pad

@doc """
Truncate the encoding to the given length.
@typedoc """
Truncation configuration.

## Options
* `:stride` - the length of previous content to be included in each
overflowing piece. Defaults to `0`

* `stride` (default `0`) - the length of previous content to be
included in each overflowing piece
* `:direction` - the truncation direction. Defaults to `:right`

* `direction` (default `:right`) - the truncation direction
"""
@type truncation_opts :: [stride: non_neg_integer(), direction: :left | :right]

@doc """
Truncate the encoding to the given length.

For available options see `t:truncation_opts/0`.
"""
@spec truncate(t(), non_neg_integer(), opts) :: t()
when opts: [stride: non_neg_integer(), direction: :left | :right]
@spec truncate(t(), non_neg_integer(), opts :: truncation_opts()) :: t()
defdelegate truncate(encoding, max_length, opts \\ []),
to: Tokenizers.Native,
as: :encoding_truncate
@@ -229,6 +237,20 @@ defmodule Tokenizers.Encoding do
"""
@spec n_tokens(encoding :: t()) :: non_neg_integer()
defdelegate n_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_length

@doc """
Performs a set of transformations on the given encoding, creating a new one.
Transformations are applied in the order they are given.

While all of these transformations can be applied one by one, this function
is more efficient, as it avoids multiple allocations and garbage collection
of intermediate encodings.

Check the `Tokenizers.Encoding.Transformation` module for handy functions
that can be used to build the transformations list.
You can also build this list manually, as long as it follows the expected format.
"""
defdelegate transform(encoding, transformations), to: Tokenizers.Native, as: :encoding_transform
end

defimpl Inspect, for: Tokenizers.Encoding do
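For reference, here is a rough usage sketch of the API touched in this diff. It is not from the PR itself: the `encoding` value, lengths, and option values are made-up placeholders; only `pad/3`, `truncate/3`, and `transform/2` from the diff above are assumed.

    # Sketch (Elixir), assuming `encoding` is a Tokenizers.Encoding.t()
    # produced by Tokenizers.Tokenizer.encode/3.
    alias Tokenizers.Encoding

    # Step-by-step calls allocate an intermediate encoding per step:
    padded =
      encoding
      |> Encoding.truncate(64, direction: :right)
      |> Encoding.pad(64, pad_id: 0, pad_token: "[PAD]")

    # The bulk variant applies the same steps in a single native call:
    padded =
      Encoding.transform(encoding, [
        {:truncate, {64, [direction: :right]}},
        {:pad, {64, [pad_id: 0, pad_token: "[PAD]"]}}
      ])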
46 changes: 46 additions & 0 deletions lib/tokenizers/encoding/transformation.ex
@@ -0,0 +1,46 @@
defmodule Tokenizers.Encoding.Transformation do
@moduledoc """
Module containing handy functions to build the transformations list.

This list is applied to an encoding using `Tokenizers.Encoding.transform/2`.
"""

@type t :: [
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}},
{:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}},
{:set_sequence_id, non_neg_integer()}
]

@doc """
Generates the padding transformation.

Check `Tokenizers.Encoding.pad/3` for more information.
"""
@spec pad(non_neg_integer(), Tokenizers.Encoding.padding_opts()) ::
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}
Comment on lines +19 to +20

jonatanklosko (Member):

I moved option types directly to the relevant functions, so we no longer have that type. We can bring it back, but it's also fine to just say keyword() since we point the user to Encoding.pad/3 anyway :)

Suggested change:

-@spec pad(non_neg_integer(), Tokenizers.Encoding.padding_opts()) ::
-        {:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}
+@spec pad(non_neg_integer(), keyword()) :: {:pad, {non_neg_integer(), keyword()}}

Virviil (Contributor, Author), Aug 8, 2023:

@jonatanklosko I'm not sure that having the opts inside the function spec rather than as a named type is a good idea (at least here). Using keyword() removes ElixirLS's ability to autosuggest the options, while repeating the full list of opts in every function is duplication and can lead to inconsistency when a future commit updates one place but not the other. WDYT?

jonatanklosko (Member):

Ah, I wasn't aware ElixirLS does that. I don't have a strong opinion, but that's a fair argument. So in this case we can bring the type back to share it :)

Virviil (Contributor, Author), Aug 8, 2023:

Yep, this is how it looks:

[screenshot omitted]

It takes it from the spec, so keyword() is less expressive.

I'll take a look at how to return types at least for these functions.

def pad(target_length, opts \\ []) do
{:pad, {target_length, opts}}
end

@doc """
Generates the truncation transformation.

Check `Tokenizers.Encoding.truncate/3` for more information.
"""
@spec truncate(non_neg_integer(), Tokenizers.Encoding.truncation_opts()) ::
{:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}}
def truncate(max_length, opts \\ []) do
{:truncate, {max_length, opts}}
end

@doc """
Generates the set_sequence_id transformation.

Check `Tokenizers.Encoding.set_sequence_id/2` for more information.
"""
@spec set_sequence_id(non_neg_integer()) ::
{:set_sequence_id, non_neg_integer()}
def set_sequence_id(id) do
{:set_sequence_id, id}
end
end
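A short sketch of how these helpers can be combined (illustrative only; the lengths, options, and the `encoding` variable are placeholders, not part of the PR):

    # Sketch (Elixir): build the transformation list with the helpers above
    # and apply it in one pass via Tokenizers.Encoding.transform/2.
    alias Tokenizers.Encoding
    alias Tokenizers.Encoding.Transformation

    transformations = [
      Transformation.truncate(64, direction: :right),
      Transformation.pad(64, pad_id: 0, pad_token: "[PAD]"),
      Transformation.set_sequence_id(0)
    ]

    Encoding.transform(encoding, transformations)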
2 changes: 2 additions & 0 deletions lib/tokenizers/native.ex
@@ -59,6 +59,8 @@ defmodule Tokenizers.Native do
def encoding_char_to_word(_encoding, _position, _seq_id), do: err()
def encoding_pad(_encoding, _target_length, _opts), do: err()
def encoding_truncate(_encoding, _max_length, _opts), do: err()
#
def encoding_transform(_encoding, _transformers), do: err()

# Models
def models_save(_model, _folder, _opts), do: err()
4 changes: 4 additions & 0 deletions lib/tokenizers/tokenizer.ex
@@ -452,6 +452,10 @@ defmodule Tokenizers.Tokenizer do
* `:add_special_tokens` - whether to add special tokens to the
sequence. Defaults to `true`

* `:encoding_transformations` - a list of `t:Tokenizers.Encoding.Transformation.t/0`
to apply to the encoding. Check `Tokenizers.Encoding.transform/2`
for more information. Defaults to `[]`

"""
@doc type: :inference
@spec encode(t(), encode_input(), keyword()) :: {:ok, Encoding.t()} | {:error, term()}
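A hedged sketch of the new option in use (the pretrained model name, input text, and lengths below are placeholders, not values from the PR):

    # Sketch (Elixir): apply bulk transformations directly at encode time.
    alias Tokenizers.Encoding.Transformation

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")

    {:ok, encoding} =
      Tokenizers.Tokenizer.encode(tokenizer, "Hello, world!",
        add_special_tokens: true,
        encoding_transformations: [
          Transformation.truncate(16),
          Transformation.pad(16)
        ]
      )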
116 changes: 87 additions & 29 deletions native/ex_tokenizers/src/encoding.rs
@@ -187,33 +187,38 @@ pub enum PadOption {
Direction(Direction),
}

#[rustler::nif]
pub fn encoding_pad(
encoding: ExTokenizersEncoding,
target_length: usize,
opts: Vec<PadOption>,
) -> ExTokenizersEncoding {
struct Padding {
pad_id: u32,
pad_type_id: u32,
pad_token: String,
direction: Direction,
}
struct Padding {
pad_id: u32,
pad_type_id: u32,
pad_token: String,
direction: Direction,
}

fn parse_pad_options(opts: &Vec<PadOption>) -> Padding {
let mut default = Padding {
pad_id: 0,
pad_type_id: 0,
pad_token: "[PAD]".to_string(),
direction: Direction::Right,
};

for opt in opts {
match opt {
PadOption::PadId(id) => default.pad_id = id,
PadOption::PadTypeId(id) => default.pad_type_id = id,
PadOption::PadToken(token) => default.pad_token = token,
PadOption::Direction(direction) => default.direction = direction,
PadOption::PadId(id) => default.pad_id = *id,
PadOption::PadTypeId(id) => default.pad_type_id = *id,
PadOption::PadToken(token) => default.pad_token = token.clone(),
PadOption::Direction(direction) => default.direction = direction.clone(),
}
}
default
}

#[rustler::nif]
pub fn encoding_pad(
encoding: ExTokenizersEncoding,
target_length: usize,
opts: Vec<PadOption>,
) -> ExTokenizersEncoding {
let default = parse_pad_options(&opts);

let mut encoding = encoding.resource.0.clone();
encoding.pad(
@@ -232,27 +237,33 @@ pub enum TruncationOption {
Direction(Direction),
}

#[rustler::nif]
pub fn encoding_truncate(
encoding: ExTokenizersEncoding,
max_len: usize,
opts: Vec<TruncationOption>,
) -> ExTokenizersEncoding {
struct Truncation {
stride: usize,
direction: Direction,
}
struct Truncation {
stride: usize,
direction: Direction,
}

fn parse_truncation_options(opts: &Vec<TruncationOption>) -> Truncation {
let mut default = Truncation {
stride: 0,
direction: Direction::Right,
};

for opt in opts {
match opt {
TruncationOption::Stride(stride) => default.stride = stride,
TruncationOption::Direction(direction) => default.direction = direction,
TruncationOption::Stride(stride) => default.stride = *stride,
TruncationOption::Direction(direction) => default.direction = direction.clone(),
}
}
default
}

#[rustler::nif]
pub fn encoding_truncate(
encoding: ExTokenizersEncoding,
max_len: usize,
opts: Vec<TruncationOption>,
) -> ExTokenizersEncoding {
let default = parse_truncation_options(&opts);

let mut encoding = encoding.resource.0.clone();

@@ -263,3 +274,50 @@ pub fn encoding_truncate(
fn slice_u32_to_u8(slice: &[u32]) -> &[u8] {
unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * 4) }
}

///////////////////////////////////////////////////////////////////////////////
/// Encoding transformations
///////////////////////////////////////////////////////////////////////////////

#[derive(NifTaggedEnum)]
pub enum TransformationElement {
Pad((usize, Vec<PadOption>)), // {:pad, {target_length, opts}}
Truncate((usize, Vec<TruncationOption>)), // {:truncate, {max_len, opts}}
SetSequenceId(usize), // {:set_sequence_id, seq_id}
}

#[rustler::nif]
pub fn encoding_transform(
encoding: ExTokenizersEncoding,
transformations: Vec<TransformationElement>,
) -> ExTokenizersEncoding {
let mut encoding = encoding.resource.0.clone();
apply_transformations(&mut encoding, &transformations);
encoding.into()
}

pub fn apply_transformations(
encoding: &mut Encoding,
transformations: &Vec<TransformationElement>,
) {
for transformation in transformations {
match transformation {
TransformationElement::Pad((target_length, opts)) => {
let default = parse_pad_options(opts);

encoding.pad(
*target_length,
default.pad_id,
default.pad_type_id,
&default.pad_token,
default.direction.into(),
)
}
TransformationElement::Truncate((max_len, opts)) => {
let default = parse_truncation_options(opts);
encoding.truncate(*max_len, default.stride, default.direction.into())
}
TransformationElement::SetSequenceId(seq_id) => encoding.set_sequence_id(*seq_id),
}
}
}
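For reference, the Elixir terms that decode into this tagged enum look roughly as follows — a hand-built transformation list in the tuple format noted in the comments above (the concrete lengths and option values are arbitrary):

    # Sketch (Elixir): each tuple mirrors a TransformationElement variant;
    # the keyword options map onto PadOption / TruncationOption values.
    transformations = [
      {:pad, {16, [pad_id: 0, pad_type_id: 0, pad_token: "[PAD]", direction: :right]}},
      {:truncate, {8, [stride: 0, direction: :right]}},
      {:set_sequence_id, 0}
    ]

    Tokenizers.Encoding.transform(encoding, transformations)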
2 changes: 2 additions & 0 deletions native/ex_tokenizers/src/lib.rs
@@ -84,6 +84,8 @@ rustler::init!(
encoding_char_to_word,
encoding_pad,
encoding_truncate,
//
encoding_transform,
// Models
models_save,
//