Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Oct 31, 2024
1 parent 81088ca commit c489664
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Model2Vec is a technique to turn any sentence transformer into a really small st
- [Distillation](#distillation)
- [Inference](#inference)
- [Evaluation](#evaluation)
- [Integrations](#integrations)
- [Model List](#model-list)
- [Results](#results)
- [Related Work](#related-work)
Expand Down Expand Up @@ -356,6 +357,14 @@ print(make_leaderboard(task_scores))
```
</details>

### Integrations
<details>
<summary>Sentence Transformers</summary>
<br>
</details>



## Model List

We provide a number of models that can be used out of the box. These models are available on the [HuggingFace hub](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) and can be loaded using the `from_pretrained` method. The models are listed below.
Expand Down
18 changes: 16 additions & 2 deletions scripts/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None:
:param tokenizer: The tokenizer from the StaticModel.
:param save_directory: The directory to save the tokenizer files.
:raises FileNotFoundError: If config.json is not found in save_directory.
:raises FileNotFoundError: If tokenizer_config.json is not found in save_directory.
:raises ValueError: If tokenizer_name is not found in config.json.
"""
tokenizer_json_path = save_directory / "tokenizer.json"
Expand Down Expand Up @@ -164,15 +165,28 @@ def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None:
special_tokens = original_tokenizer.special_tokens_map
tokenizer_class = original_tokenizer.__class__.__name__

# Load the tokenizer using PreTrainedTokenizerFast with the correct class and special tokens
# Load the tokenizer using PreTrainedTokenizerFast with special tokens
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_file=str(tokenizer_json_path),
tokenizer_class=tokenizer_class,
**special_tokens,
)

# Save the tokenizer files
fast_tokenizer.save_pretrained(str(save_directory))
# Modify tokenizer_config.json to set the correct tokenizer_class
tokenizer_config_path = save_directory / "tokenizer_config.json"
if tokenizer_config_path.exists():
with open(tokenizer_config_path, "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
else:
raise FileNotFoundError(f"tokenizer_config.json not found in {save_directory}")

# Update the tokenizer_class field
tokenizer_config["tokenizer_class"] = tokenizer_class

# Write the updated tokenizer_config.json back to disk
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_config, f, indent=4, sort_keys=True)


if __name__ == "__main__":
Expand Down

0 comments on commit c489664

Please sign in to comment.