Parcourir la source

Feat: support huggingface re-rank model. (#5684)

### What problem does this PR solve?

#5658

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.17.1
Kevin Hu il y a 8 mois
Parent
révision
b8da2eeb69
Aucun compte lié à l'adresse e-mail de l'auteur
3 fichiers modifiés avec 46 ajouts et 7 suppressions
  1. 1
    1
      conf/llm_factories.json
  2. 5
    0
      rag/llm/__init__.py
  3. 40
    6
      rag/llm/rerank_model.py

+ 1
- 1
conf/llm_factories.json Voir le fichier

@@ -3284,7 +3284,7 @@
{
"name": "HuggingFace",
"logo": "",
"tags": "TEXT EMBEDDING",
"tags": "TEXT EMBEDDING,TEXT RE-RANK",
"status": "1",
"llm": []
},

+ 5
- 0
rag/llm/__init__.py Voir le fichier

@@ -107,6 +107,7 @@ from .cv_model import (
YiCV,
HunyuanCV,
)

from .rerank_model import (
LocalAIRerank,
DefaultRerank,
@@ -123,7 +124,9 @@ from .rerank_model import (
VoyageRerank,
QWenRerank,
GPUStackRerank,
HuggingfaceRerank,
)

from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
@@ -132,6 +135,7 @@ from .sequence2txt_model import (
TencentCloudSeq2txt,
GPUStackSeq2txt,
)

from .tts_model import (
FishAudioTTS,
QwenTTS,
@@ -255,6 +259,7 @@ RerankModel = {
"Voyage AI": VoyageRerank,
"Tongyi-Qianwen": QWenRerank,
"GPUStack": GPUStackRerank,
"HuggingFace": HuggingfaceRerank,
}

Seq2txtModel = {

+ 40
- 6
rag/llm/rerank_model.py Voir le fichier

@@ -31,7 +31,6 @@ from rag.utils import num_tokens_from_string, truncate
import json



def sigmoid(x):
return 1 / (1 + np.exp(-x))

@@ -87,10 +86,9 @@ class DefaultRerank(Base):
local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
self._dynamic_batch_size = 8
self._min_batch_size = 1

def torch_empty_cache(self):
try:
import torch
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i:i+current_batch])
batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
@@ -282,6 +280,7 @@ class LocalAIRerank(Base):

return rank, token_count


class NvidiaRerank(Base):
def __init__(
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
@@ -513,6 +512,40 @@ class QWenRerank(Base):
else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")


class HuggingfaceRerank(DefaultRerank):
    """Rerank documents via a remote HuggingFace text-embeddings-inference
    ``/rerank`` HTTP endpoint.

    Unlike :class:`DefaultRerank`, no local model is loaded — scoring is
    delegated to the service at ``base_url``, so ``__init__`` deliberately
    skips the parent constructor (which would download a FlagReranker).
    """

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1") -> np.ndarray:
        """POST ``texts`` in batches of 8 to ``http://<host>/rerank``.

        ``url`` may be passed with or without an ``http://`` prefix; the
        prefix is stripped so we never build a malformed
        ``http://http://...`` URL (the class default ``base_url`` includes
        the scheme). Failed batches keep a score of 0.0; the last exception,
        if any, is re-raised after every batch has been attempted.
        """
        # Normalize: self.base_url defaults to "http://127.0.0.1" while this
        # function prepends "http://" itself — strip a leading scheme.
        host = url[len("http://"):] if url.startswith("http://") else url
        exc = None
        scores = [0.0] * len(texts)
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    f"http://{host}/rerank",
                    headers={"Content-Type": "application/json"},
                    json={"query": query, "texts": texts[i: i + batch_size],
                          "raw_scores": False, "truncate": True},
                    timeout=60,  # don't hang forever on an unresponsive service
                )
                # Surface HTTP errors instead of mis-parsing an error body
                # as a list of {"index", "score"} objects.
                res.raise_for_status()
                for o in res.json():
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e  # best-effort: keep scoring the remaining batches

        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        # NOTE: intentionally no super().__init__() — the parent would load a
        # local model; this class only talks to the HTTP service.
        self.model_name = model_name
        self.base_url = base_url

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """Return (scores ndarray aligned with ``texts``, total token count).

        Empty ``texts`` short-circuits to an empty array with 0 tokens
        without any network call.
        """
        if not texts:
            return np.array([]), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count


class GPUStackRerank(Base):
def __init__(
self, key, model_name, base_url
@@ -521,7 +554,7 @@ class GPUStackRerank(Base):
raise ValueError("url cannot be None")

self.model_name = model_name
self.base_url = str(URL(base_url)/ "v1" / "rerank")
self.base_url = str(URL(base_url) / "v1" / "rerank")
self.headers = {
"accept": "application/json",
"content-type": "application/json",
@@ -560,5 +593,6 @@ class GPUStackRerank(Base):
)

except httpx.HTTPStatusError as e:
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
raise ValueError(
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")


Chargement…
Annuler
Enregistrer