Add bulk transformations functionality
Dmitry Rubinstein committed Aug 8, 2023
1 parent 777cf46 commit c1c3077
Showing 10 changed files with 207 additions and 43 deletions.
14 changes: 14 additions & 0 deletions lib/tokenizers/encoding.ex
@@ -229,6 +229,20 @@ defmodule Tokenizers.Encoding do
"""
@spec n_tokens(encoding :: t()) :: non_neg_integer()
defdelegate n_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_length

@doc """
Performs a set of transformations on the given encoding, creating a new one.
Transformations are applied in the order they are given.
While all of these transformations can be performed one by one, this function
is more efficient, as it avoids multiple allocations and garbage collection
of intermediate encodings.
See the `Tokenizers.Encoding.Transformation` module for helper functions
that can be used to build the transformations list.
You can also build this list manually, as long as it follows the expected format.
"""
defdelegate transform(encoding, transformations), to: Tokenizers.Native, as: :encoding_transform
end

defimpl Inspect, for: Tokenizers.Encoding do
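For context, a minimal usage sketch of the new function (the tokenizer name, input, and sizes are illustrative, not part of this diff):

```elixir
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello, world!")

# One pass over the encoding; no intermediate encodings are allocated.
encoding =
  Tokenizers.Encoding.transform(encoding, [
    {:truncate, {16, []}},
    {:pad, {16, [pad_token: "[PAD]"]}},
    {:set_sequence_id, 0}
  ])
```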
46 changes: 46 additions & 0 deletions lib/tokenizers/encoding/transformation.ex
@@ -0,0 +1,46 @@
defmodule Tokenizers.Encoding.Transformation do
@moduledoc """
Module containing handy functions to build the transformations list.
This list is applied to an encoding using `Tokenizers.Encoding.transform/2`.
"""

@type t :: [
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}
| {:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}}
| {:set_sequence_id, non_neg_integer()}
]

@doc """
Generates the padding transformation.
Check `Tokenizers.Encoding.pad/3` for more information.
"""
@spec pad(non_neg_integer(), Tokenizers.Encoding.padding_opts()) ::
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}
def pad(target_length, opts \\ []) do
{:pad, {target_length, opts}}
end

@doc """
Generates the truncation transformation.
Check `Tokenizers.Encoding.truncate/3` for more information.
"""
@spec truncate(non_neg_integer(), Tokenizers.Encoding.truncation_opts()) ::
{:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}}
def truncate(max_length, opts \\ []) do
{:truncate, {max_length, opts}}
end

@doc """
Generates the set_sequence_id transformation.
Check `Tokenizers.Encoding.set_sequence_id/2` for more information.
"""
@spec set_sequence_id(non_neg_integer()) ::
{:set_sequence_id, non_neg_integer()}
def set_sequence_id(id) do
{:set_sequence_id, id}
end
end
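A short sketch of building the same list with these helpers instead of raw tuples (assumes an `encoding` from an earlier `Tokenizers.Tokenizer.encode/2` call; values are illustrative):

```elixir
alias Tokenizers.Encoding.Transformation

transformations = [
  Transformation.truncate(16, direction: :left),
  Transformation.pad(16, pad_token: "[PAD]"),
  Transformation.set_sequence_id(0)
]

# Equivalent to passing the raw tuples described by `t()` above.
encoding = Tokenizers.Encoding.transform(encoding, transformations)
```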
2 changes: 2 additions & 0 deletions lib/tokenizers/native.ex
@@ -59,6 +59,8 @@ defmodule Tokenizers.Native do
def encoding_char_to_word(_encoding, _position, _seq_id), do: err()
def encoding_pad(_encoding, _target_length, _opts), do: err()
def encoding_truncate(_encoding, _max_length, _opts), do: err()
#
def encoding_transform(_encoding, _transformers), do: err()

# Models
def models_save(_model, _folder, _opts), do: err()
3 changes: 2 additions & 1 deletion lib/tokenizers/tokenizer.ex
@@ -451,7 +451,8 @@ defmodule Tokenizers.Tokenizer do
* `:add_special_tokens` - whether to add special tokens to the
sequence. Defaults to `true`
* `:encoding_transformations` - a list of transformations to apply to
the resulting encoding (see `Tokenizers.Encoding.Transformation.t()`).
Defaults to `[]`. Check `Tokenizers.Encoding.transform/2` for more information.
"""
@doc type: :inference
@spec encode(t(), encode_input(), keyword()) :: {:ok, Encoding.t()} | {:error, term()}
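A sketch of the new `:encoding_transformations` option in use (assumes a loaded `tokenizer`; input and sizes are illustrative):

```elixir
{:ok, encoding} =
  Tokenizers.Tokenizer.encode(tokenizer, "Hello, world!",
    add_special_tokens: true,
    encoding_transformations: [
      Tokenizers.Encoding.Transformation.truncate(16),
      Tokenizers.Encoding.Transformation.pad(16)
    ]
  )
```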
116 changes: 87 additions & 29 deletions native/ex_tokenizers/src/encoding.rs
@@ -187,33 +187,38 @@ pub enum PadOption {
Direction(Direction),
}

#[rustler::nif]
pub fn encoding_pad(
encoding: ExTokenizersEncoding,
target_length: usize,
opts: Vec<PadOption>,
) -> ExTokenizersEncoding {
struct Padding {
pad_id: u32,
pad_type_id: u32,
pad_token: String,
direction: Direction,
}
struct Padding {
pad_id: u32,
pad_type_id: u32,
pad_token: String,
direction: Direction,
}

fn parse_pad_options(opts: &Vec<PadOption>) -> Padding {
let mut default = Padding {
pad_id: 0,
pad_type_id: 0,
pad_token: "[PAD]".to_string(),
direction: Direction::Right,
};

for opt in opts {
match opt {
PadOption::PadId(id) => default.pad_id = id,
PadOption::PadTypeId(id) => default.pad_type_id = id,
PadOption::PadToken(token) => default.pad_token = token,
PadOption::Direction(direction) => default.direction = direction,
PadOption::PadId(id) => default.pad_id = *id,
PadOption::PadTypeId(id) => default.pad_type_id = *id,
PadOption::PadToken(token) => default.pad_token = token.clone(),
PadOption::Direction(direction) => default.direction = direction.clone(),
}
}
default
}

#[rustler::nif]
pub fn encoding_pad(
encoding: ExTokenizersEncoding,
target_length: usize,
opts: Vec<PadOption>,
) -> ExTokenizersEncoding {
let default = parse_pad_options(&opts);

let mut encoding = encoding.resource.0.clone();
encoding.pad(
@@ -232,27 +237,33 @@ pub enum TruncationOption {
Direction(Direction),
}

#[rustler::nif]
pub fn encoding_truncate(
encoding: ExTokenizersEncoding,
max_len: usize,
opts: Vec<TruncationOption>,
) -> ExTokenizersEncoding {
struct Truncation {
stride: usize,
direction: Direction,
}
struct Truncation {
stride: usize,
direction: Direction,
}

fn parse_truncation_options(opts: &Vec<TruncationOption>) -> Truncation {
let mut default = Truncation {
stride: 0,
direction: Direction::Right,
};

for opt in opts {
match opt {
TruncationOption::Stride(stride) => default.stride = stride,
TruncationOption::Direction(direction) => default.direction = direction,
TruncationOption::Stride(stride) => default.stride = *stride,
TruncationOption::Direction(direction) => default.direction = direction.clone(),
}
}
default
}

#[rustler::nif]
pub fn encoding_truncate(
encoding: ExTokenizersEncoding,
max_len: usize,
opts: Vec<TruncationOption>,
) -> ExTokenizersEncoding {
let default = parse_truncation_options(&opts);

let mut encoding = encoding.resource.0.clone();

@@ -263,3 +274,50 @@ pub fn encoding_truncate(
fn slice_u32_to_u8(slice: &[u32]) -> &[u8] {
unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * 4) }
}

///////////////////////////////////////////////////////////////////////////////
/// Encoding transformations
///////////////////////////////////////////////////////////////////////////////

#[derive(NifTaggedEnum)]
pub enum TransformationElement {
Pad((usize, Vec<PadOption>)), // {:pad, {target_length, opts}}
Truncate((usize, Vec<TruncationOption>)), // {:truncate, {max_len, opts}}
SetSequenceId(usize), // {:set_sequence_id, seq_id}
}

#[rustler::nif]
pub fn encoding_transform(
encoding: ExTokenizersEncoding,
transformations: Vec<TransformationElement>,
) -> ExTokenizersEncoding {
let mut encoding = encoding.resource.0.clone();
apply_transformations(&mut encoding, &transformations);
encoding.into()
}

pub fn apply_transformations(
encoding: &mut Encoding,
transformations: &Vec<TransformationElement>,
) {
for transformation in transformations {
match transformation {
TransformationElement::Pad((target_length, opts)) => {
let default = parse_pad_options(opts);

encoding.pad(
*target_length,
default.pad_id,
default.pad_type_id,
&default.pad_token,
default.direction.into(),
)
}
TransformationElement::Truncate((max_len, opts)) => {
let default = parse_truncation_options(opts);
encoding.truncate(*max_len, default.stride, default.direction.into())
}
TransformationElement::SetSequenceId(seq_id) => encoding.set_sequence_id(*seq_id),
}
}
}
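For reference, `NifTaggedEnum` decodes these variants from the Elixir tuples noted in the comments above, so a transformations list arriving over the NIF boundary is expected to look like this (values illustrative):

```elixir
[
  {:pad, {128, [pad_id: 0, pad_token: "[PAD]"]}},
  {:truncate, {64, [stride: 8, direction: :left]}},
  {:set_sequence_id, 0}
]
```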
2 changes: 2 additions & 0 deletions native/ex_tokenizers/src/lib.rs
@@ -84,6 +84,8 @@ rustler::init!(
encoding_char_to_word,
encoding_pad,
encoding_truncate,
//
encoding_transform,
// Models
models_save,
//
37 changes: 29 additions & 8 deletions native/ex_tokenizers/src/tokenizer.rs
@@ -12,7 +12,7 @@ use tokenizers::{EncodeInput, TokenizerImpl};

use crate::added_token::{AddedSpecialTokenInput, AddedTokenInput};
use crate::decoders::ExTokenizersDecoder;
use crate::encoding::ExTokenizersEncoding;
use crate::encoding::{apply_transformations, ExTokenizersEncoding, TransformationElement};
use crate::error::ExTokenizersError;
use crate::models::ExTokenizersModel;
use crate::normalizers::ExTokenizersNormalizer;
@@ -428,6 +428,7 @@ fn term_to_encode_input<'a>(term: &'a Term) -> Result<EncodeInput<'a>, ExTokeniz
#[derive(NifTaggedEnum)]
pub enum EncodeOption {
AddSpecialTokens(bool),
EncodingTransformations(Vec<TransformationElement>),
}

#[rustler::nif(schedule = "DirtyCpu")]
@@ -438,21 +439,27 @@ pub fn tokenizer_encode(
) -> Result<ExTokenizersEncoding, ExTokenizersError> {
struct Opts {
add_special_tokens: bool,
encoding_transformations: Vec<TransformationElement>,
}
let mut opts = Opts {
add_special_tokens: true,
encoding_transformations: Vec::new(),
};
options.iter().for_each(|option| match option {
options.into_iter().for_each(|option| match option {
EncodeOption::AddSpecialTokens(add_special_tokens) => {
opts.add_special_tokens = *add_special_tokens
opts.add_special_tokens = add_special_tokens
}
EncodeOption::EncodingTransformations(encoding_transformations) => {
opts.encoding_transformations = encoding_transformations
}
});

let input = term_to_encode_input(&input)?;
let encoding = tokenizer
let mut encoding = tokenizer
.resource
.0
.encode(input, opts.add_special_tokens)?;
apply_transformations(&mut encoding, &opts.encoding_transformations);
Ok(encoding.into())
}

@@ -465,24 +472,38 @@ pub fn tokenizer_encode_batch(
) -> Result<Vec<ExTokenizersEncoding>, ExTokenizersError> {
struct Opts {
add_special_tokens: bool,
encoding_transformations: Vec<TransformationElement>,
}
let mut opts = Opts {
add_special_tokens: true,
encoding_transformations: Vec::new(),
};
options.iter().for_each(|option| match option {
options.into_iter().for_each(|option| match option {
EncodeOption::AddSpecialTokens(add_special_tokens) => {
opts.add_special_tokens = *add_special_tokens
opts.add_special_tokens = add_special_tokens
}
EncodeOption::EncodingTransformations(encoding_transformations) => {
opts.encoding_transformations = encoding_transformations
}
});
let inputs = inputs
.iter()
.map(term_to_encode_input)
.collect::<Result<Vec<EncodeInput>, ExTokenizersError>>()?;
let encodings = tokenizer
let mut encodings = tokenizer
.resource
.0
.encode_batch(inputs, opts.add_special_tokens)?;
let ex_encodings = encodings.into_iter().map(|x| x.into()).collect();

// Applying transformations (if any)
for encoding in encodings.iter_mut() {
apply_transformations(encoding, &opts.encoding_transformations);
}

let ex_encodings = encodings
.into_iter()
.map(|encoding| encoding.into())
.collect();
Ok(ex_encodings)
}

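Since the batch path applies the same transformation list to every encoding, a sketch of the corresponding Elixir call (assumes a loaded `tokenizer`; inputs and sizes are illustrative):

```elixir
{:ok, encodings} =
  Tokenizers.Tokenizer.encode_batch(tokenizer, ["first sentence", "second sentence"],
    encoding_transformations: [
      Tokenizers.Encoding.Transformation.pad(16)
    ]
  )
```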
2 changes: 1 addition & 1 deletion native/ex_tokenizers/src/util.rs
@@ -25,7 +25,7 @@ impl rustler::Encoder for Info {
}
}

#[derive(rustler::NifUnitEnum)]
#[derive(rustler::NifUnitEnum, Clone)]
pub enum Direction {
Left,
Right,
6 changes: 2 additions & 4 deletions test/tokenizers/post_processor_test.exs
@@ -19,8 +19,7 @@ defmodule Tokenizers.PostProcessorTest do
Tokenizers.PostProcessor.bert({"[SEP]", 0}, {"[CLS]", 1})
)

{:ok, output} =
Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"})
{:ok, output} = Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"})

assert Tokenizers.Encoding.get_tokens(output) == [
"[CLS]",
@@ -52,8 +51,7 @@ defmodule Tokenizers.PostProcessorTest do
Tokenizers.PostProcessor.roberta({"</s>", 1}, {"<s>", 0})
)

{:ok, output} =
Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"})
{:ok, output} = Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"})

assert Tokenizers.Encoding.get_tokens(output) == [
"<s>",