### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
@token_required
def docinfos(tenant_id):
    req = request.args
    if "id" not in req and "name" not in req:
        return get_data_error_result(
            retmsg="Id or name should be provided")
    doc_id = None
    if "id" in req:
        doc_id = req["id"]
    if "name" in req:
        doc_name = req["name"]
        doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
    e, doc = DocumentService.get_by_id(doc_id)
    # Rename keys in the response
    key_mapping = {
        "chunk_num": "chunk_count",
        "kb_id": "knowledgebase_id",
        "token_num": "token_count",
    }
    renamed_doc = {}
    for key, value in doc.to_dict().items():
        new_key = key_mapping.get(key, key)
        renamed_doc[new_key] = value
    return get_json_result(data=renamed_doc)
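# Illustrative sketch, not part of the PR: the renaming loop above is
# equivalent to a single dict comprehension. The sample record below is
# hypothetical; only the key names come from the mapping in this endpoint.
key_mapping = {"chunk_num": "chunk_count",
               "kb_id": "knowledgebase_id",
               "token_num": "token_count"}
record = {"id": "doc_1", "chunk_num": 12, "kb_id": "kb_1", "token_num": 3408}
renamed = {key_mapping.get(k, k): v for k, v in record.items()}
# renamed == {"id": "doc_1", "chunk_count": 12, "knowledgebase_id": "kb_1", "token_count": 3408}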
@manager.route('/save', methods=['POST'])
                req["doc_id"], {"name": req["name"]}):
            return get_data_error_result(
                retmsg="Database error (Document rename)!")

        informs = File2DocumentService.get_by_document_id(req["doc_id"])
        if informs:
            e, file = FileService.get_by_id(informs[0].file_id)
@manager.route("/<document_id>", methods=["GET"])
@token_required
def download_document(dataset_id, document_id, tenant_id):
    try:
        # Check whether the document exists
        exist, document = DocumentService.get_by_id(document_id)
    try:
        docs, tol = DocumentService.get_by_kb_id(
            kb_id, page_number, items_per_page, orderby, desc, keywords)

        # Rename keys in the response
        renamed_doc_list = []
        for doc in docs:
            key_mapping = {
                "chunk_num": "chunk_count",
                "kb_id": "knowledgebase_id",
                "token_num": "token_count",
            }
            renamed_doc = {}
            for key, value in doc.items():
                new_key = key_mapping.get(key, key)
                renamed_doc[new_key] = value
            renamed_doc_list.append(renamed_doc)
        return get_json_result(data={"total": tol, "docs": renamed_doc_list})
    except Exception as e:
        return server_error_response(e)
            query["available_int"] = int(req["available_int"])
        sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
        origin_chunks = []
        for id in sres.ids:
            d = {
                "chunk_id": id,
                poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
                             float(d["positions"][i + 3]), float(d["positions"][i + 4])])
            d["positions"] = poss
            origin_chunks.append(d)

        # Rename keys in the returned chunks
        for chunk in origin_chunks:
            key_mapping = {
                "chunk_id": "id",
                "content_with_weight": "content",
                "doc_id": "document_id",
                "important_kwd": "important_keywords",
            }
            renamed_chunk = {}
            for key, value in chunk.items():
                new_key = key_mapping.get(key, key)
                renamed_chunk[new_key] = value
            res["chunks"].append(renamed_chunk)
        return get_json_result(data=res)
    except Exception as e:
        if str(e).find("not_found") > 0:
    req = request.json
    md5 = hashlib.md5()
    md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
    chunk_id = md5.hexdigest()
    d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
         "content_with_weight": req["content_with_weight"]}
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = req.get("important_kwd", [])
        DocumentService.increment_chunk_num(
            doc.id, doc.kb_id, c, 1, 0)

        d["chunk_id"] = chunk_id
        # Rename keys in the response
        key_mapping = {
            "chunk_id": "id",
            "content_with_weight": "content",
            "doc_id": "document_id",
            "important_kwd": "important_keywords",
            "kb_id": "knowledge_base_id",
        }
        renamed_chunk = {}
        for key, value in d.items():
            new_key = key_mapping.get(key, key)
            renamed_chunk[new_key] = value
        return get_json_result(data={"chunk": renamed_chunk})
        # return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)
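# Illustrative sketch, not part of the PR: with the key_mapping above, a
# successful /doc/chunk/create response body looks roughly like the dict
# below. The id, content and document id values are hypothetical placeholders;
# unmapped fields such as content_ltks pass through unchanged.
example_response = {
    "retmsg": "success",
    "data": {
        "chunk": {
            "id": "5f0c7c9e0d9a...",   # md5 of content_with_weight + doc_id
            "content": "RAGFlow is an open-source RAG engine.",
            "content_ltks": "...",
            "document_id": "d41d8cd98f00...",
            "important_keywords": [],
            "knowledge_base_id": "kb_123",
        }
    }
}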
@manager.route('/chunk/rm', methods=['POST'])
@token_required
@validate_request("chunk_ids", "doc_id")
def rm_chunk(tenant_id):
    req = request.json
    try:
        if not ELASTICSEARCH.deleteByQuery(
                Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
            return get_data_error_result(retmsg="Index updating failure")
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/chunk/set', methods=['POST'])
@token_required
@validate_request("doc_id", "chunk_id", "content_with_weight",
                  "important_kwd")
def set(tenant_id):
    req = request.json
    d = {
        "id": req["chunk_id"],
        "content_with_weight": req["content_with_weight"]}
    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
    if "available_int" in req:
        d["available_int"] = req["available_int"]

    try:
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        embd_id = DocumentService.get_embd_id(req["doc_id"])
        embd_mdl = TenantLLMService.model_instance(
            tenant_id, LLMType.EMBEDDING.value, embd_id)

        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        if doc.parser_id == ParserType.QA:
            arr = [
                t for t in re.split(
                    r"[\n\t]",
                    req["content_with_weight"]) if len(t) > 1]
            if len(arr) != 2:
                return get_data_error_result(
                    retmsg="Q&A must be separated by TAB/ENTER key.")
            q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
            d = beAdoc(d, arr[0], arr[1], not any(
                [rag_tokenizer.is_chinese(t) for t in q + a]))

        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
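# Illustrative sketch, not part of the PR: the QA branch above requires the
# content to hold exactly one question and one answer separated by a TAB or
# newline; otherwise the endpoint answers "Q&A must be separated by TAB/ENTER
# key." The strings below are hypothetical.
import re

content = "What is RAGFlow?\tAn open-source RAG engine based on deep document understanding."
arr = [t for t in re.split(r"[\n\t]", content) if len(t) > 1]
assert len(arr) == 2, "Q&A must be separated by TAB/ENTER key."
question, answer = arr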
@manager.route('/retrieval_test', methods=['POST'])
@token_required
@validate_request("kb_id", "question")
def retrieval_test(tenant_id):
    req = request.json
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_id = req["kb_id"]
    if isinstance(kb_id, str):
        kb_id = [kb_id]
    doc_ids = req.get("doc_ids", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    try:
        tenants = UserTenantService.query(user_id=tenant_id)
        for kid in kb_id:
            for tenant in tenants:
                if KnowledgebaseService.query(
                        tenant_id=tenant.tenant_id, id=kid):
                    break
            else:
                return get_json_result(
                    data=False, retmsg='Only owner of knowledgebase authorized for this operation.',
                    retcode=RetCode.OPERATING_ERROR)

        e, kb = KnowledgebaseService.get_by_id(kb_id[0])
        if not e:
            return get_data_error_result(retmsg="Knowledgebase not found!")

        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = TenantLLMService.model_instance(
                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

        if req.get("keyword", False):
            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)

        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
                               similarity_threshold, vector_similarity_weight, top,
                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
        for c in ranks["chunks"]:
            if "vector" in c:
                del c["vector"]

        # rename keys
        renamed_chunks = []
        for chunk in ranks["chunks"]:
            key_mapping = {
                "chunk_id": "id",
                "content_with_weight": "content",
                "doc_id": "document_id",
                "important_kwd": "important_keywords",
            }
            rename_chunk = {}
            for key, value in chunk.items():
                new_key = key_mapping.get(key, key)
                rename_chunk[new_key] = value
            renamed_chunks.append(rename_chunk)
        ranks["chunks"] = renamed_chunks
        return get_json_result(data=ranks)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg='No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
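# Illustrative sketch, not part of the PR: the permission check above relies on
# Python's for/else. The else branch runs only when the inner loop finishes
# without hitting break, i.e. when none of the caller's tenants owns the
# knowledge base. "owned" below stands in for KnowledgebaseService.query() and
# is hypothetical.
def owns_knowledgebase(tenants, kid, owned):
    for tenant in tenants:
        if owned(tenant, kid):
            break
    else:
        return False
    return True

assert owns_knowledgebase(["t1", "t2"], "kb1", lambda t, k: t == "t2") is True
assert owns_knowledgebase(["t1", "t2"], "kb1", lambda t, k: False) is False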
class Chunk(Base):
    def __init__(self, rag, res_dict):
        # Initialize the instance attributes
        self.id = ""
        self.content = ""
        self.important_keywords = []
        self.create_time = ""
        self.create_timestamp_flt = 0.0
        self.knowledgebase_id = None
        self.document_name = ""
        self.document_id = ""
        self.status = "1"
        for k in list(res_dict.keys()):
            if k not in self.__dict__:
                res_dict.pop(k)
        super().__init__(rag, res_dict)
    def delete(self) -> bool:
        """
        Delete the chunk in the document.
        """
        res = self.post('/doc/chunk/rm',
                        {"doc_id": self.document_id, 'chunk_ids': [self.id]})
        res = res.json()
        if res.get("retmsg") == "success":
            return True
        raise Exception(res["retmsg"])
    def save(self) -> bool:
        """
        Save the chunk details to the server.
        """
        res = self.post('/doc/chunk/set',
                        {"chunk_id": self.id,
                         "kb_id": self.knowledgebase_id,
                         "name": self.document_name,
                         "content_with_weight": self.content,
                         "important_kwd": self.important_keywords,
                         "create_time": self.create_time,
                         "create_timestamp_flt": self.create_timestamp_flt,
                         "doc_id": self.document_id,
                         "status": self.status,
                         })
        res = res.json()
        if res.get("retmsg") == "success":
            return True
        raise Exception(res["retmsg"])
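# A minimal usage sketch, not part of the PR: editing and removing a chunk
# through the SDK. The API key and host address are placeholders for a
# reachable RAGFlow server, and "story.txt" is assumed to be an already
# uploaded and parsed document.
from ragflow import RAGFlow

rag = RAGFlow("YOUR_API_KEY", "http://YOUR_RAGFLOW_HOST")
doc = rag.get_document(name="story.txt")
chunk = doc.add_chunk(content="RAGFlow is an open-source RAG engine.")
chunk.content = "Updated chunk content."
chunk.save()    # persists the edit via /doc/chunk/set
chunk.delete()  # removes the chunk via /doc/chunk/rm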
        self.id = ""
        self.name = ""
        self.thumbnail = None
        self.knowledgebase_id = None
        self.parser_method = ""
        self.parser_config = {"pages": [[1, 1000000]]}
        self.source_type = "local"
        self.type = ""
        self.created_by = ""
        self.size = 0
        self.token_count = 0
        self.chunk_count = 0
        self.progress = 0.0
        self.progress_msg = ""
        self.process_begin_at = None

        Save the document details to the server.
        """
        res = self.post('/doc/save',
                        {"id": self.id, "name": self.name, "thumbnail": self.thumbnail, "kb_id": self.knowledgebase_id,
                         "parser_id": self.parser_method, "parser_config": self.parser_config.to_json(),
                         "source_type": self.source_type, "type": self.type, "created_by": self.created_by,
                         "size": self.size, "token_num": self.token_count, "chunk_num": self.chunk_count,
                         "progress": self.progress, "progress_msg": self.progress_msg,
                         "process_begin_at": self.process_begin_at, "process_duation": self.process_duration
                         })
        if res.status_code == 200:
            res_data = res.json()
            if res_data.get("retmsg") == "success":
                chunks = []
                for chunk_data in res_data["data"].get("chunks", []):
                    chunk = Chunk(self.rag, chunk_data)
                    chunks.append(chunk)
                return chunks
            else:
                raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")

    def add_chunk(self, content: str):
        res = self.post('/doc/chunk/create', {"doc_id": self.id, "content_with_weight": content})
        if res.status_code == 200:
            res_data = res.json().get("data")
            chunk_data = res_data.get("chunk")
            return Chunk(self.rag, chunk_data)
        else:
            raise Exception(f"Failed to add chunk: {res.status_code} {res.text}")
from .modules.assistant import Assistant
from .modules.dataset import DataSet
from .modules.document import Document
from .modules.chunk import Chunk


class RAGFlow:
    def __init__(self, user_key, base_url, version='v1'):
            return result_list
        raise Exception(res["retmsg"])

    def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
        url = f"/doc/dataset/{ds.id}/documents/upload"
        files = {
            'file': (name, blob)
                raise Exception(f"Upload failed: {response.json().get('retmsg')}")
        return False

    def get_document(self, id: str = None, name: str = None) -> Document:
        res = self.get("/doc/infos", {"id": id, "name": name})
        res = res.json()
        if not doc_ids or not isinstance(doc_ids, list):
            raise ValueError("doc_ids must be a non-empty list of document IDs")
        data = {"doc_ids": doc_ids, "run": 2}
        res = self.post('/doc/run', data)
        if res.status_code != 200:
            print(f"Error occurred during canceling parsing for documents: {str(e)}")
            raise
    def retrieval(self,
                  question,
                  datasets=None,
                  documents=None,
                  offset=0,
                  limit=6,
                  similarity_threshold=0.1,
                  vector_similarity_weight=0.3,
                  top_k=1024):
        """
        Perform document retrieval based on the given parameters.

        :param question: The query question.
        :param datasets: A list of datasets (optional, as documents may be provided directly).
        :param documents: A list of documents (if specific documents are provided).
        :param offset: Offset for the retrieval results.
        :param limit: Maximum number of retrieval results.
        :param similarity_threshold: Similarity threshold.
        :param vector_similarity_weight: Weight of vector similarity.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).

        Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
        """
        try:
            data = {
                "question": question,
                "datasets": datasets if datasets is not None else [],
                "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
                              documents] if documents is not None else [],
                "offset": offset,
                "limit": limit,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "kb_id": datasets,
            }
            # Send a POST request to the backend service (actual implementation may vary)
            res = self.post('/doc/retrieval_test', data)
            # Check the response status code
            if res.status_code == 200:
                res_data = res.json()
                if res_data.get("retmsg") == "success":
                    chunks = []
                    for chunk_data in res_data["data"].get("chunks", []):
                        chunk = Chunk(self, chunk_data)
                        chunks.append(chunk)
                    return chunks
                else:
                    raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
            else:
                raise Exception(f"API request failed with status code {res.status_code}")
        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise
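# A minimal usage sketch, not part of the PR: calling retrieval() against an
# already-parsed document. The API key, host address, and the dataset/document
# names are placeholders.
from ragflow import RAGFlow

rag = RAGFlow("YOUR_API_KEY", "http://YOUR_RAGFLOW_HOST")
ds = rag.create_dataset(name="demo_dataset")
doc = rag.get_document(name="ragflow_test.txt")
for chunk in rag.retrieval(question="What is RAGFlow?",
                           datasets=[ds.id], documents=[doc],
                           offset=0, limit=6,
                           similarity_threshold=0.1,
                           vector_similarity_weight=0.3,
                           top_k=1024):
    print(chunk.content)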
    def test_update_document_with_success(self):
        """
        Test updating a document with success.
        Updating the name or parser_method is supported.
        """
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        doc = rag.get_document(name="TestDocument.txt")
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        # Retrieve a document
        doc = rag.get_document(name="manual.txt")
        # Check if the retrieved document is of type Document
        if isinstance(doc, Document):
        ds = rag.create_dataset(name="God4")
        # Define the document name and path
        name3 = 'westworld.pdf'
        path = 'test_data/westworld.pdf'
        # Create a document in the dataset using the file path
        rag.create_document(ds, name=name3, blob=open(path, "rb").read())
        # Retrieve the document by name
        doc = rag.get_document(name="westworld.pdf")
        # Initiate asynchronous parsing
        doc.async_parse()
        # Prepare a list of file names and paths
        documents = [
            {'name': 'test1.txt', 'path': 'test_data/test1.txt'},
            {'name': 'test2.txt', 'path': 'test_data/test2.txt'},
            {'name': 'test3.txt', 'path': 'test_data/test3.txt'}
        ]
        # Create documents in bulk
            print(c)
            assert c is not None, "Chunk is None"
            assert "rag" in c.content.lower(), f"Keyword 'rag' not found in chunk content: {c.content}"
    def test_add_chunk_to_chunk_list(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        doc = rag.get_document(name='story.txt')

    def test_delete_chunk_of_chunk_list(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        doc = rag.get_document(name='story.txt')
        chunk = doc.add_chunk(content="assss")
        assert chunk is not None, "Chunk is None"
        assert isinstance(chunk, Chunk), "Chunk was not added to chunk list"
        doc = rag.get_document(name='story.txt')
        chunk_count_before = doc.chunk_count
        chunk.delete()
        doc = rag.get_document(name='story.txt')
        assert doc.chunk_count == chunk_count_before - 1, "Chunk was not deleted"
    def test_update_chunk_content(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        doc = rag.get_document(name='story.txt')
        chunk = doc.add_chunk(content="assssd")
        assert chunk is not None, "Chunk is None"
        assert isinstance(chunk, Chunk), "Chunk was not added to chunk list"
        chunk.content = "ragflow123"
        res = chunk.save()
        assert res is True, f"Failed to update chunk, error: {res}"

    def test_retrieval_chunks(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        ds = rag.create_dataset(name="God8")
        name = 'ragflow_test.txt'
        path = 'test_data/ragflow_test.txt'
        rag.create_document(ds, name=name, blob=open(path, "rb").read())
        doc = rag.get_document(name=name)
        doc.async_parse()
        # Wait for parsing to complete and get progress updates using join
        for progress, msg in doc.join(interval=5, timeout=30):
            print(progress, msg)
            assert 0 <= progress <= 100, f"Invalid progress: {progress}"
            assert msg, "Message should not be empty"
        for c in rag.retrieval(question="What's ragflow?",
                               datasets=[ds.id], documents=[doc],
                               offset=0, limit=6, similarity_threshold=0.1,
                               vector_similarity_weight=0.3,
                               top_k=1024
                               ):
            print(c)
            assert c is not None, "Chunk is None"
            assert "ragflow" in c.content.lower(), f"Keyword 'ragflow' not found in chunk content: {c.content}"