Skip to content

Commit

Permalink
feat: make encode_batch_fast optional
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Oct 25, 2024
1 parent 5ed70a6 commit 695a1ee
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 9 additions & 1 deletion model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(
self.config = config or {}
self.base_model_name = base_model_name
self.language = language
if hasattr(self.tokenizer, "encode_batch_fast"):
self._can_encode_fast = True
else:
self._can_encode_fast = False

Check warning on line 65 in model2vec/model.py

View check run for this annotation

Codecov / codecov/patch

model2vec/model.py#L65

Added line #L65 was not covered by tests

if normalize is not None:
self.normalize = normalize
Expand Down Expand Up @@ -121,7 +125,11 @@ def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[
m = max_length * self.median_token_length
sentences = [sentence[:m] for sentence in sentences]

encodings: list[Encoding] = self.tokenizer.encode_batch_fast(sentences, add_special_tokens=False)
if self._can_encode_fast:
encodings: list[Encoding] = self.tokenizer.encode_batch_fast(sentences, add_special_tokens=False)
else:
encodings = self.tokenizer.encode_batch(sentences, add_special_tokens=False)

encodings_ids = [encoding.ids for encoding in encodings]

if self.unk_token_id is not None:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ def test_initialization_token_vector_mismatch(mock_tokenizer: Tokenizer, mock_co
StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)


def test_tokenize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
    """Check that the fast and slow tokenization paths produce identical tokens."""
    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
    sentences = ["word1 word2"]

    # Exercise the fast path first, then the slow fallback, collecting both results.
    results = {}
    for use_fast in (True, False):
        model._can_encode_fast = use_fast
        results[use_fast] = model.tokenize(sentences)

    assert results[True] == results[False]


def test_encode_single_sentence(
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
Expand Down

0 comments on commit 695a1ee

Please sign in to comment.