fix: Adapt bge reranker and bge m3 to FlagEmbedding>1.3.0 (#47)
Signed-off-by: wxywb <[email protected]>
wxywb authored Nov 19, 2024
1 parent f35dce6 commit 1b9549e
Showing 3 changed files with 31 additions and 51 deletions.
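Note: FlagEmbedding 1.3.0 and later constructs rerankers through FlagAutoReranker.from_finetuned and replaces the single device argument with devices. Below is a minimal sketch of the new entry point as this commit configures it; the parameter names mirror the new _model_config in bgereranker.py, while the model name and the compute_score call are illustrative assumptions rather than part of the diff.

# Sketch of the FlagEmbedding >= 1.3.0 entry point this commit adapts to.
# Parameter names follow the new _model_config below; model name and
# compute_score usage are assumptions for illustration only.
from FlagEmbedding import FlagAutoReranker

reranker = FlagAutoReranker.from_finetuned(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    use_fp16=True,
    devices=["cuda:0"],        # `devices` (str or list) replaces the old `device`
    normalize=True,
    max_length=512,
    query_max_length=256,
    batch_size=32,
)
scores = reranker.compute_score([["what is a vector database", "Milvus is a vector database"]])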
7 changes: 5 additions & 2 deletions milvus_model/hybrid/bge_m3.py
@@ -48,10 +48,14 @@ def __init__(
"Using fp16 with CPU can lead to runtime errors such as 'LayerNormKernelImpl', It's recommended to set 'use_fp16 = False' when using cpu. "
)

if "devices" in kwargs:
device = devices
kwargs.pop(device)

_model_config = dict(
{
"model_name_or_path": model_name,
"device": device,
"devices": device,
"normalize_embeddings": normalize_embeddings,
"use_fp16": use_fp16,
},
@@ -80,7 +84,6 @@ def dim(self) -> Dict:
        }

    def _encode(self, texts: List[str]) -> Dict:
        # Change 'sentences' to 'queries' to match the expected parameter
        output = self.model.encode(queries=texts, **self._encode_config)
        results = {}
        if self._encode_config["return_dense"] is True:
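For context, a minimal usage sketch of the updated hybrid embedder; it assumes the class defined in bge_m3.py is BGEM3EmbeddingFunction, as in the published milvus_model package, and the encode_documents call and model name are illustrative rather than part of this diff.

# Usage sketch (assumptions: class name BGEM3EmbeddingFunction and the import
# path follow the published milvus_model package; model name is illustrative).
from milvus_model.hybrid import BGEM3EmbeddingFunction

ef = BGEM3EmbeddingFunction(
    model_name="BAAI/bge-m3",
    device="cpu",
    use_fp16=False,            # recommended on CPU, per the warning in __init__ above
)
embeddings = ef.encode_documents(["Milvus is a vector database"])
print(ef.dim)                  # dict of output dimensions, per the dim property above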
73 changes: 25 additions & 48 deletions milvus_model/reranker/bgereranker.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Any, List, Optional, Union

import torch

@@ -7,7 +7,7 @@

import_FlagEmbedding()
import_transformers()
from FlagEmbedding import FlagReranker
from FlagEmbedding import FlagAutoReranker
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
@@ -21,14 +21,35 @@ def __init__(
        use_fp16: bool = True,
        batch_size: int = 32,
        normalize: bool = True,
        device: Optional[str] = None,
        device: Optional[Union[str, List]] = None,
        query_max_length: int = 256,
        max_length: int = 512,
        **kwargs: Any,
    ):

        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device
        self.reranker = _FlagReranker(model_name, use_fp16=use_fp16, device=device)

        # Allow a FlagEmbedding-style `devices` kwarg to override `device`.
        if "devices" in kwargs:
            device = kwargs["devices"]
            kwargs.pop("devices")

        _model_config = dict(
            {
                "model_name_or_path": model_name,
                "batch_size": batch_size,
                "use_fp16": use_fp16,
                "devices": device,
                "max_length": max_length,
                "query_max_length": query_max_length,
                "normalize": normalize,
            },
            **kwargs,
        )
        self.reranker = FlagAutoReranker.from_finetuned(**_model_config)


    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
@@ -53,47 +74,3 @@ def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[Rer
            results.append(RerankResult(text=documents[index], score=scores[index], index=index))
        return results


class _FlagReranker(FlagReranker):
    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        use_fp16: bool = False,
        cache_dir: Optional[str] = None,
        device: Optional[Union[str, int]] = None,
    ) -> None:

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, cache_dir=cache_dir
        )

        if device and isinstance(device, str):
            self.device = torch.device(device)
            if device == "cpu":
                use_fp16 = False
        elif torch.cuda.is_available():
            if device is not None:
                self.device = torch.device(f"cuda:{device}")
            else:
                self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)

        self.model.eval()

        if device is None:
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1
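A minimal usage sketch of the reranker wrapper after this change; it assumes the class defined in bgereranker.py is BGERerankFunction and that the import path matches the published milvus_model package, while the __call__ signature and RerankResult fields come from the diff above.

# Usage sketch (assumptions: class name BGERerankFunction and the import path
# follow the published milvus_model package; model name is illustrative).
from milvus_model.reranker import BGERerankFunction

rf = BGERerankFunction(
    model_name="BAAI/bge-reranker-v2-m3",
    device="cuda:0",           # forwarded to FlagAutoReranker.from_finetuned as `devices`
    use_fp16=True,
    batch_size=32,
)
results = rf("what is a vector database", ["Milvus is a vector database", "The sky is blue"], top_k=1)
for r in results:
    print(r.index, r.score, r.text)   # fields per RerankResult in __call__ above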
2 changes: 1 addition & 1 deletion milvus_model/utils/__init__.py
@@ -33,7 +33,7 @@ def import_sentence_transformers():

def import_FlagEmbedding():
    _check_library("peft", package="peft")
    _check_library("FlagEmbedding", package="FlagEmbedding~=1.2.11")
    _check_library("FlagEmbedding", package="FlagEmbedding")

def import_nltk():
    _check_library("nltk", package="nltk>=3.9.1")
