### What problem does this PR solve? Fix keys of Xinference deployed models, especially has the same model name with public hosted models. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: 0000sir <0000sir@gmail.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>tags/v0.13.0
| @@ -343,10 +343,10 @@ def list_app(): | |||
| for m in llms: | |||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied | |||
| llm_set = set([m["llm_name"] for m in llms]) | |||
| llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms]) | |||
| for o in objs: | |||
| if not o.api_key:continue | |||
| if o.llm_name in llm_set:continue | |||
| if o.llm_name+"@"+o.llm_factory in llm_set:continue | |||
| llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True}) | |||
| res = {} | |||
| @@ -494,25 +494,24 @@ def set(tenant_id,dataset_id,document_id,chunk_id): | |||
| @manager.route('/retrieval', methods=['GET']) | |||
| @manager.route('/retrieval', methods=['POST']) | |||
| @token_required | |||
| def retrieval_test(tenant_id): | |||
| req = request.args | |||
| req_json = request.json | |||
| if not req_json.get("datasets"): | |||
| req = request.json | |||
| if not req.get("datasets"): | |||
| return get_error_data_result("`datasets` is required.") | |||
| for id in req_json.get("datasets"): | |||
| kb_id = req["datasets"] | |||
| if isinstance(kb_id, str): kb_id = [kb_id] | |||
| for id in kb_id: | |||
| if not KnowledgebaseService.query(id=id,tenant_id=tenant_id): | |||
| return get_error_data_result(f"You don't own the dataset {id}.") | |||
| if "question" not in req_json: | |||
| return get_error_data_result("`question` is required.") | |||
| page = int(req.get("offset", 1)) | |||
| size = int(req.get("limit", 30)) | |||
| question = req_json["question"] | |||
| kb_id = req_json["datasets"] | |||
| if isinstance(kb_id, str): kb_id = [kb_id] | |||
| doc_ids = req_json.get("documents", []) | |||
| similarity_threshold = float(req.get("similarity_threshold", 0.0)) | |||
| question = req["question"] | |||
| doc_ids = req.get("documents", []) | |||
| similarity_threshold = float(req.get("similarity_threshold", 0.2)) | |||
| vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | |||
| top = int(req.get("top_k", 1024)) | |||
| if req.get("highlight")=="False" or req.get("highlight")=="false": | |||
| @@ -453,7 +453,7 @@ class XinferenceCV(Base): | |||
| def __init__(self, key, model_name="", lang="Chinese", base_url=""): | |||
| if base_url.split("/")[-1] != "v1": | |||
| base_url = os.path.join(base_url, "v1") | |||
| self.client = OpenAI(api_key="xxx", base_url=base_url) | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| @@ -274,7 +274,7 @@ class XinferenceEmbed(Base): | |||
| def __init__(self, key, model_name="", base_url=""): | |||
| if base_url.split("/")[-1] != "v1": | |||
| base_url = os.path.join(base_url, "v1") | |||
| self.client = OpenAI(api_key="xxx", base_url=base_url) | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| def encode(self, texts: list, batch_size=32): | |||
| @@ -162,7 +162,8 @@ class XInferenceRerank(Base): | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "accept": "application/json" | |||
| "accept": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| def similarity(self, query: str, texts: list): | |||
| @@ -90,6 +90,7 @@ class XinferenceSeq2txt(Base): | |||
| def __init__(self,key,model_name="whisper-small",**kwargs): | |||
| self.base_url = kwargs.get('base_url', None) | |||
| self.model_name = model_name | |||
| self.key = key | |||
| def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7): | |||
| if isinstance(audio, str): | |||
| @@ -74,6 +74,12 @@ class RAGFlow: | |||
| if res.get("code") != 0: | |||
| raise Exception(res["message"]) | |||
| def get_dataset(self,name: str): | |||
| _list = self.list_datasets(name=name) | |||
| if len(_list) > 0: | |||
| return _list[0] | |||
| raise Exception("Dataset %s not found" % name) | |||
| def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True, | |||
| id: str = None, name: str = None) -> \ | |||
| List[DataSet]: | |||