### What problem does this PR solve? #724 #162 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) — tags/v0.7.0
| @@ -39,17 +39,18 @@ def factories(): | |||
| def set_api_key(): | |||
| req = request.json | |||
| # test if api key works | |||
| chat_passed = False | |||
| chat_passed, embd_passed, rerank_passed = False, False, False | |||
| factory = req["llm_factory"] | |||
| msg = "" | |||
| for llm in LLMService.query(fid=factory): | |||
| if llm.model_type == LLMType.EMBEDDING.value: | |||
| if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: | |||
| mdl = EmbeddingModel[factory]( | |||
| req["api_key"], llm.llm_name, base_url=req.get("base_url")) | |||
| try: | |||
| arr, tc = mdl.encode(["Test if the api key is available"]) | |||
| if len(arr[0]) == 0 or tc == 0: | |||
| raise Exception("Fail") | |||
| embd_passed = True | |||
| except Exception as e: | |||
| msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) | |||
| elif not chat_passed and llm.model_type == LLMType.CHAT.value: | |||
| @@ -60,20 +61,21 @@ def set_api_key(): | |||
| "temperature": 0.9}) | |||
| if not tc: | |||
| raise Exception(m) | |||
| chat_passed = True | |||
| except Exception as e: | |||
| msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( | |||
| e) | |||
| elif llm.model_type == LLMType.RERANK: | |||
| chat_passed = True | |||
| elif not rerank_passed and llm.model_type == LLMType.RERANK: | |||
| mdl = RerankModel[factory]( | |||
| req["api_key"], llm.llm_name, base_url=req.get("base_url")) | |||
| try: | |||
| m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) | |||
| if len(arr[0]) == 0 or tc == 0: | |||
| arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) | |||
| if len(arr) == 0 or tc == 0: | |||
| raise Exception("Fail") | |||
| except Exception as e: | |||
| msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( | |||
| e) | |||
| rerank_passed = True | |||
| if msg: | |||
| return get_data_error_result(retmsg=msg) | |||
| @@ -147,7 +147,6 @@ class TenantLLMService(CommonService): | |||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ | |||
| .execute() | |||
| except Exception as e: | |||
| print(e) | |||
| pass | |||
| return num | |||
| @@ -28,6 +28,7 @@ EmbeddingModel = { | |||
| "FastEmbed": FastEmbed, | |||
| "Youdao": YoudaoEmbed, | |||
| "BaiChuan": BaiChuanEmbed, | |||
| "Jina": JinaEmbed, | |||
| "BAAI": DefaultEmbedding | |||
| } | |||
| @@ -291,7 +291,7 @@ class JinaEmbed(Base): | |||
| "input": texts, | |||
| 'encoding_type': 'float' | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=data) | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"] | |||
| def encode_queries(self, text): | |||
| @@ -91,7 +91,7 @@ class JinaRerank(Base): | |||
| "documents": texts, | |||
| "top_n": len(texts) | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=data) | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"] | |||
| @@ -44,7 +44,7 @@ class EsQueryer: | |||
| def question(self, txt, tbl="qa", min_match="60%"): | |||
| txt = re.sub( | |||
| r"[ \r\n\t,,。??/`!!&\^%%]+", | |||
| r"[ :\r\n\t,,。??/`!!&\^%%]+", | |||
| " ", | |||
| rag_tokenizer.tradi2simp( | |||
| rag_tokenizer.strQ2B( | |||
| @@ -104,7 +104,7 @@ class Dealer: | |||
| while i < len(tks): | |||
| j = i | |||
| if i == 0 and oneTerm(tks[i]) and len( | |||
| tks) > 1 and len(tks[i + 1]) > 1: # 多 工位 | |||
| tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 | |||
| res.append(" ".join(tks[0:2])) | |||
| i = 2 | |||
| continue | |||