Skip to content

Commit

Permalink
feat: Add nomic and mistralai depencencies, remove the vertexai tempo… (
Browse files Browse the repository at this point in the history
#29)

…rarily.

Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored Aug 7, 2024
1 parent 692dc6c commit a450490
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 97 deletions.
6 changes: 0 additions & 6 deletions milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"JinaEmbeddingFunction",
"OnnxEmbeddingFunction",
"CohereEmbeddingFunction",
"VertexAIEmbeddingFunction",
"MistralAIEmbeddingFunction",
"NomicEmbeddingFunction",
]
Expand All @@ -20,7 +19,6 @@
voyageai = LazyImport("voyageai", globals(), "milvus_model.dense.voyageai")
onnx = LazyImport("onnx", globals(), "milvus_model.dense.onnx")
cohere = LazyImport("cohere", globals(), "milvus_model.dense.cohere")
vertexai = LazyImport("vertexai", globals(), "milvus_model.dense.vertexai")
mistralai = LazyImport("mistralai", globals(), "milvus_model.dense.mistralai")
nomic = LazyImport("nomic", globals(), "milvus_model.dense.nomic")

Expand Down Expand Up @@ -49,10 +47,6 @@ def CohereEmbeddingFunction(*args, **kwargs):
return cohere.CohereEmbeddingFunction(*args, **kwargs)


def VertexAIEmbeddingFunction(*args, **kwargs):
return vertexai.VertexAIEmbeddingFunction(*args, **kwargs)


def MistralAIEmbeddingFunction(*args, **kwargs):
return mistralai.MistralAIEmbeddingFunction(*args, **kwargs)

Expand Down
10 changes: 7 additions & 3 deletions milvus_model/dense/mistralai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import List, Optional
import os
import numpy as np
from collections import defaultdict
from mistralai.client import MistralClient
import os

from milvus_model.base import BaseEmbeddingFunction
from milvus_model.utils import import_mistralai

import_mistralai()
from mistralai.client import MistralClient

class MistralAIEmbeddingFunction:
class MistralAIEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
self,
api_key: str,
Expand Down
22 changes: 15 additions & 7 deletions milvus_model/dense/nomic.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from typing import List
import numpy as np
from collections import defaultdict
from nomic import embed
import os
from collections import defaultdict

from milvus_model.base import BaseEmbeddingFunction
from milvus_model.utils import import_nomic

class NomicEmbeddingFunction:
import_nomic()
from nomic import embed

class NomicEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
self,
api_key: str,
model_name: str = "nomic-embed-text-v1.5",
task_type: str = "search_document",
dimensionality: int = 768,
dimensions: int = 768,
**kwargs,
):
self._nomic_model_meta_info = defaultdict(dict)
self._nomic_model_meta_info[model_name]["dim"] = dimensionality # set the dimension
self._nomic_model_meta_info[model_name]["dim"] = dimensions # set the dimension

if api_key is None:
if "NOMIC_API_KEY" in os.environ and os.environ["NOMIC_API_KEY"]:
Expand All @@ -31,11 +35,15 @@ def __init__(
self.api_key = api_key
self.model_name = model_name
self.task_type = task_type
self.dimensionality = dimensionality
self.dimensionality = dimensions
if "dimensionality" in kwargs:
self.dimensionality = kwargs["dimensionality"]
kwargs.pop("dimensionality")

self._encode_config = {
"model": model_name,
"task_type": task_type,
"dimensionality": dimensionality,
"dimensionality": self.dimensionality,
**kwargs,
}

Expand Down
79 changes: 0 additions & 79 deletions milvus_model/dense/vertexai.py

This file was deleted.

12 changes: 10 additions & 2 deletions milvus_model/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
"import_protobuf",
"import_unidic_lite",
"import_cohere",
"import_voyageai"
"import_voyageai",
"import_torch",
"import_huggingface_hub"
"import_huggingface_hub",
"import_mistralai",
"import_nomic"
]

import importlib.util
Expand Down Expand Up @@ -66,6 +68,12 @@ def import_torch():
def import_huggingface_hub():
_check_library("huggingface_hub", package="huggingface-hub")

def import_mistralai():
_check_library("mistralai", package="mistralai")

def import_nomic():
_check_library("nomic", package="nomic")

def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None):
is_avail = False
if importlib.util.find_spec(libname):
Expand Down

0 comments on commit a450490

Please sign in to comment.