
fix jina adding issue and term weight refinement (#974)

### What problem does this PR solve?

#724 #162

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh committed 1 year ago · revision 758eb03ccb · tags/v0.7.0

api/apps/llm_app.py  +8 −6

 def set_api_key():
     req = request.json
     # test if api key works
-    chat_passed = False
+    chat_passed, embd_passed, rerank_passed = False, False, False
     factory = req["llm_factory"]
     msg = ""
     for llm in LLMService.query(fid=factory):
-        if llm.model_type == LLMType.EMBEDDING.value:
+        if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             mdl = EmbeddingModel[factory](
                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0 or tc == 0:
                     raise Exception("Fail")
+                embd_passed = True
             except Exception as e:
                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
                                  "temperature": 0.9})
                 if not tc:
                     raise Exception(m)
+                chat_passed = True
             except Exception as e:
                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                     e)
-        elif llm.model_type == LLMType.RERANK:
-            chat_passed = True
+        elif not rerank_passed and llm.model_type == LLMType.RERANK:
             mdl = RerankModel[factory](
                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
-                m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
-                if len(arr[0]) == 0 or tc == 0:
+                arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
+                if len(arr) == 0 or tc == 0:
                     raise Exception("Fail")
             except Exception as e:
                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                     e)
+            rerank_passed = True
     if msg:
         return get_data_error_result(retmsg=msg)
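The net effect is that each model category is probed at most once, and a category is only marked as passed by its own successful test call; previously the chat flag was being set in the rerank branch, so a broken chat model could go unreported. A minimal sketch of that pattern, where `probes` is a hypothetical mapping from model type to a test callable standing in for the real `EmbeddingModel` / `ChatModel` / `RerankModel` calls:

```python
# Minimal sketch of the per-category probing pattern above; `probes` maps a
# model type to a hypothetical test callable that raises on failure.
def check_api_key(models, probes):
    passed = {"embedding": False, "chat": False, "rerank": False}
    errors = []
    for model_type, llm_name in models:
        if passed.get(model_type, True):   # probe each category at most once
            continue
        try:
            probes[model_type](llm_name)   # raises if the key does not work
            passed[model_type] = True      # only the tested category is marked
        except Exception as e:
            errors.append(f"Fail to access model({llm_name}): {e}")
    return passed, errors
```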

api/db/services/llm_service.py  +0 −1

                 .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
                 .execute()
         except Exception as e:
-            print(e)
             pass
         return num

rag/llm/__init__.py  +1 −0

"FastEmbed": FastEmbed, "FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed, "Youdao": YoudaoEmbed,
"BaiChuan": BaiChuanEmbed, "BaiChuan": BaiChuanEmbed,
"Jina": JinaEmbed,
"BAAI": DefaultEmbedding "BAAI": DefaultEmbedding
} }
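With `JinaEmbed` registered here, the string `"Jina"` resolves to the new wrapper through the same factory lookup that `set_api_key` uses above. A hedged usage sketch; the model name is illustrative, not necessarily the identifier RAGFlow configures:

```python
from rag.llm import EmbeddingModel

# "Jina" now maps to JinaEmbed; the constructor signature is assumed to match
# the other wrappers called in set_api_key (api_key, llm_name, base_url).
mdl = EmbeddingModel["Jina"]("YOUR_JINA_API_KEY", "jina-embeddings-v2-base-en")
vectors, token_count = mdl.encode(["Test if the api key is available"])
```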



rag/llm/embedding_model.py  +1 −1

"input": texts, "input": texts,
'encoding_type': 'float' 'encoding_type': 'float'
} }
res = requests.post(self.base_url, headers=self.headers, json=data)
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"] return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]


def encode_queries(self, text): def encode_queries(self, text):
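This one-line fix matters because `requests.post()` returns a `Response` object, so indexing it with `res["data"]` raised a `TypeError` and the Jina embedding check could never succeed; the identical fix is applied to the rerank call below. A standalone sketch of the corrected request/parse step, where the endpoint, model name, and headers are assumptions based on the public Jina API rather than taken from this file:

```python
import numpy as np
import requests

def jina_embed(texts, api_key,
               base_url="https://api.jina.ai/v1/embeddings",
               model="jina-embeddings-v2-base-en"):
    headers = {"Content-Type": "application/json",
               "Authorization": f"Bearer {api_key}"}
    data = {"model": model, "input": texts, "encoding_type": "float"}
    # .json() parses the HTTP body into a dict; without it, res["data"] below
    # would try to index a requests.Response and fail.
    res = requests.post(base_url, headers=headers, json=data).json()
    return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
```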

rag/llm/rerank_model.py  +1 −1

"documents": texts, "documents": texts,
"top_n": len(texts) "top_n": len(texts)
} }
res = requests.post(self.base_url, headers=self.headers, json=data)
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"] return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]





rag/nlp/query.py  +1 −1



     def question(self, txt, tbl="qa", min_match="60%"):
         txt = re.sub(
-            r"[ \r\n\t,,。??/`!!&\^%%]+",
+            r"[ :\r\n\t,,。??/`!!&\^%%]+",
             " ",
             rag_tokenizer.tradi2simp(
                 rag_tokenizer.strQ2B(
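The only change is adding the ASCII colon to the separator class, so prefixes such as `title:` or `Q:` are stripped along with the other punctuation before the query is tokenized. A quick check of the behaviour:

```python
import re

# New separator class from the diff above, now including the ASCII colon.
pattern = r"[ :\r\n\t,,。??/`!!&\^%%]+"
print(re.sub(pattern, " ", "Q: what is RAGFlow, really?"))  # -> "Q what is RAGFlow really "
print(re.sub(pattern, " ", "天气:今天"))                     # -> "天气 今天"
```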

rag/nlp/term_weight.py  +1 −1

         while i < len(tks):
             j = i
             if i == 0 and oneTerm(tks[i]) and len(
-                    tks) > 1 and len(tks[i + 1]) > 1:  # 多 工位
+                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # 多 工位
                 res.append(" ".join(tks[0:2]))
                 i = 2
                 continue
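The refinement keeps the "merge a leading single-character token with the next token" heuristic (e.g. `多 工位`) but no longer applies it when the following token starts with an ASCII letter or digit. A small check of the new condition, with `oneTerm` approximated by a single-character test for illustration only:

```python
import re

def merges_first_two(tks):
    # Mirrors the refined condition; oneTerm() is approximated here as a
    # single-character check, which is an assumption for illustration.
    return (len(tks[0]) == 1 and len(tks) > 1
            and (len(tks[1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[1])))

print(merges_first_two(["多", "工位"]))   # True  -> emitted as the phrase "多 工位"
print(merges_first_two(["多", "gpu"]))    # False -> left as two separate terms
```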
