From 30933a55fa474ef4fe4728dace5f1cda40cdeda7 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Wed, 23 Oct 2024 10:44:18 +0200
Subject: [PATCH] fix: don't rely on reported vocab size, log warning if inconsistent

---
 model2vec/distill/inference.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py
index fcc80d0..a8dbfcc 100644
--- a/model2vec/distill/inference.py
+++ b/model2vec/distill/inference.py
@@ -115,7 +115,15 @@ def create_output_embeddings_from_model_name(
     :return: The tokens and output embeddings.
     """
     model = model.to(device)
-    ids = torch.arange(tokenizer.vocab_size)
+
+    # Quick check to see if the tokenizer is consistent.
+    vocab_length = len(tokenizer.get_vocab())
+    if vocab_length != tokenizer.vocab_size:
+        logger.warning(
+            f"Reported vocab size {tokenizer.vocab_size} is inconsistent with the vocab size {vocab_length}."
+        )
+
+    ids = torch.arange(vocab_length)
     # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
     dummy_encoding = tokenizer.encode("A")
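
Note on the fix: a minimal sketch of why the two counts can diverge, assuming a
Hugging Face tokenizer; the model name "bert-base-uncased" and the added token
below are hypothetical illustrations, not part of the patch. vocab_size reports
only the base vocabulary, while get_vocab() also covers added tokens, so an
arange over the reported size could miss the ids of added tokens:

    from transformers import AutoTokenizer

    # Hypothetical model; any tokenizer with added tokens shows the gap.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["<my_new_token>"])  # hypothetical added token

    # Reported size covers the base vocabulary only.
    print(tokenizer.vocab_size)        # e.g. 30522
    # get_vocab() includes added tokens, so it can be larger.
    print(len(tokenizer.get_vocab()))  # e.g. 30523 after adding one token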