Skip to content

Commit

Permalink
feat: Update rerankers to their latest version. (#42)
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored Sep 28, 2024
1 parent 1f1d4e1 commit 7715504
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
14 changes: 9 additions & 5 deletions milvus_model/reranker/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,23 @@
import cohere

class CohereRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "rerank-english-v2.0", api_key: Optional[str] = None):
def __init__(self, model_name: str = "rerank-english-v3.0", api_key: Optional[str] = None, return_documents=True, **kwargs):
self.model_name = model_name
self.client = cohere.Client(api_key)
self.client = cohere.ClientV2(api_key)
self.rerank_config = {"return_documents": return_documents, **kwargs}


def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
co_results = self.client.rerank(
query=query, documents=documents, top_n=top_k, model="rerank-english-v2.0"
)
query=query, documents=documents, top_n=top_k, model=self.model_name, **self.rerank_config)
results = []
for co_result in co_results.results:
document_text = ""
if self.rerank_config["return_documents"] is True:
document_text = co_result.document.text
results.append(
RerankResult(
text=co_result.document["text"],
text=document_text,
score=co_result.relevance_score,
index=co_result.index,
)
Expand Down
2 changes: 1 addition & 1 deletion milvus_model/reranker/jinaai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class JinaRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Optional[str] = None):
def __init__(self, model_name: str = "jina-reranker-v2-base-multilingual", api_key: Optional[str] = None):
if api_key is None:
if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]:
self.api_key = os.environ["JINAAI_API_KEY"]
Expand Down
2 changes: 1 addition & 1 deletion milvus_model/reranker/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import voyageai

class VoyageRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "rerank-lite-1", api_key: Optional[str] = None):
def __init__(self, model_name: str = "rerank-2", api_key: Optional[str] = None):
self.model_name = model_name
self.client = voyageai.Client(api_key=api_key)

Expand Down
2 changes: 1 addition & 1 deletion milvus_model/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def import_unidic_lite():
_check_library("unidic-lite", package="unidic-lite")

def import_cohere():
_check_library("cohere", "cohere")
_check_library("cohere", "cohere>=5.10.0")

def import_voyageai():
_check_library("voyageai", "voyageai>=0.2.0")
Expand Down

0 comments on commit 7715504

Please sign in to comment.