Преглед изворни кода

fix bugs of rerank model with xinference (#1481)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.9.0
Kevin Hu пре 1 година
родитељ
комит
99f7bbaaa2
No account linked to committer's email address
2 измењених фајлова са 20 додато и 7 уклоњено
  1. 11
    0
      api/apps/llm_app.py
  2. 9
    7
      rag/llm/rerank_model.py

+ 11
- 0
api/apps/llm_app.py Прегледај датотеку

@@ -165,6 +165,17 @@ def add_llm():
except Exception as e:
msg += f"\nFail to access model({llm['llm_name']})." + str(
e)
elif llm["model_type"] == LLMType.RERANK:
mdl = RerankModel[factory](
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
)
try:
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
if len(arr) == 0 or tc == 0:
raise Exception("Not known.")
except Exception as e:
msg += f"\nFail to access model({llm['llm_name']})." + str(
e)
else:
# TODO: check other type of models
pass

+ 9
- 7
rag/llm/rerank_model.py Прегледај датотеку

@@ -136,10 +136,11 @@ class YoudaoRerank(DefaultRerank):
else: res.extend(scores)
return np.array(res), token_count

class XInferenceRerank(Base):
def __init__(self,model_name="",base_url=""):
self.model_name=model_name
self.base_url=base_url
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json",
"accept": "application/json"
@@ -147,11 +148,12 @@ class XInferenceRerank(Base):

def similarity(self, query: str, texts: list):
data = {
"model":self.model_name,
"query":query,
"model": self.model_name,
"query": query,
"return_documents": "true",
"return_len": "true",
"documents":texts
"documents": texts
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"]["output_tokens"]
return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"][
"output_tokens"]

Loading…
Откажи
Сачувај