### What problem does this PR solve?

Adds chunk management (list, add, update, delete) and retrieval support to the Python SDK, together with the corresponding HTTP API endpoints and key renaming for the returned payloads.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
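For context, a minimal usage sketch of the retrieval feature this PR adds, based on the SDK test included in the diff. The API key, server address, and file names are placeholders, and the import path is assumed to be `ragflow`:

```python
from ragflow import RAGFlow  # import path assumed

# Placeholders: use your own API key and server address.
rag = RAGFlow("ragflow-***", "http://127.0.0.1:9380")

ds = rag.create_dataset(name="demo")
doc = rag.get_document(name="ragflow_test.txt")  # assumes the document was uploaded and parsed

# Retrieve chunks relevant to a question, scoped to one dataset and one document.
for chunk in rag.retrieval(question="What's ragflow?",
                           datasets=[ds.id], documents=[doc],
                           offset=0, limit=6,
                           similarity_threshold=0.1,
                           vector_similarity_weight=0.3,
                           top_k=1024):
    print(chunk.content)
```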
| @@ -84,15 +84,28 @@ def upload(dataset_id, tenant_id): | |||
| @token_required | |||
| def docinfos(tenant_id): | |||
| req = request.args | |||
| if "id" not in req and "name" not in req: | |||
| return get_data_error_result( | |||
| retmsg="Id or name should be provided") | |||
| doc_id=None | |||
| if "id" in req: | |||
| doc_id = req["id"] | |||
| e, doc = DocumentService.get_by_id(doc_id) | |||
| return get_json_result(data=doc.to_json()) | |||
| if "name" in req: | |||
| doc_name = req["name"] | |||
| doc_id = DocumentService.get_doc_id_by_doc_name(doc_name) | |||
| e, doc = DocumentService.get_by_id(doc_id) | |||
| return get_json_result(data=doc.to_json()) | |||
| e, doc = DocumentService.get_by_id(doc_id) | |||
| #rename key's name | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "kb_id": "knowledgebase_id", | |||
| "token_num": "token_count", | |||
| } | |||
| renamed_doc = {} | |||
| for key, value in doc.to_dict().items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_doc[new_key] = value | |||
| return get_json_result(data=renamed_doc) | |||
| @manager.route('/save', methods=['POST']) | |||
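The key_mapping loop above is repeated in several handlers below (list_docs, list_chunk, create, retrieval_test). A hypothetical helper, not part of this PR, showing the same renaming in one place:

```python
def rename_keys(record: dict, key_mapping: dict) -> dict:
    """Return a copy of record with keys renamed per key_mapping;
    keys without a mapping are kept unchanged."""
    return {key_mapping.get(k, k): v for k, v in record.items()}

# The document mapping used by docinfos/list_docs in this diff:
doc_key_mapping = {"chunk_num": "chunk_count",
                   "kb_id": "knowledgebase_id",
                   "token_num": "token_count"}

print(rename_keys({"chunk_num": 3, "kb_id": "kb1", "name": "a.pdf"},
                  doc_key_mapping))
# -> {'chunk_count': 3, 'knowledgebase_id': 'kb1', 'name': 'a.pdf'}
```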
| @@ -246,7 +259,7 @@ def rename(): | |||
| req["doc_id"], {"name": req["name"]}): | |||
| return get_data_error_result( | |||
| retmsg="Database error (Document rename)!") | |||
| informs = File2DocumentService.get_by_document_id(req["doc_id"]) | |||
| if informs: | |||
| e, file = FileService.get_by_id(informs[0].file_id) | |||
| @@ -259,7 +272,7 @@ def rename(): | |||
| @manager.route("/<document_id>", methods=["GET"]) | |||
| @token_required | |||
| def download_document(dataset_id, document_id): | |||
| def download_document(dataset_id, document_id,tenant_id): | |||
| try: | |||
| # Check whether there is this document | |||
| exist, document = DocumentService.get_by_id(document_id) | |||
| @@ -313,7 +326,21 @@ def list_docs(dataset_id, tenant_id): | |||
| try: | |||
| docs, tol = DocumentService.get_by_kb_id( | |||
| kb_id, page_number, items_per_page, orderby, desc, keywords) | |||
| return get_json_result(data={"total": tol, "docs": docs}) | |||
| # rename key's name | |||
| renamed_doc_list = [] | |||
| for doc in docs: | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "kb_id": "knowledgebase_id", | |||
| "token_num": "token_count", | |||
| } | |||
| renamed_doc = {} | |||
| for key, value in doc.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_doc[new_key] = value | |||
| renamed_doc_list.append(renamed_doc) | |||
| return get_json_result(data={"total": tol, "docs": renamed_doc_list}) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -436,6 +463,8 @@ def list_chunk(tenant_id): | |||
| query["available_int"] = int(req["available_int"]) | |||
| sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True) | |||
| res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} | |||
| origin_chunks=[] | |||
| for id in sres.ids: | |||
| d = { | |||
| "chunk_id": id, | |||
| @@ -455,7 +484,21 @@ def list_chunk(tenant_id): | |||
| poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]), | |||
| float(d["positions"][i + 3]), float(d["positions"][i + 4])]) | |||
| d["positions"] = poss | |||
| res["chunks"].append(d) | |||
| origin_chunks.append(d) | |||
| ##rename keys | |||
| for chunk in origin_chunks: | |||
| key_mapping = { | |||
| "chunk_id": "id", | |||
| "content_with_weight": "content", | |||
| "doc_id": "document_id", | |||
| "important_kwd": "important_keywords", | |||
| } | |||
| renamed_chunk = {} | |||
| for key, value in chunk.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_chunk[new_key] = value | |||
| res["chunks"].append(renamed_chunk) | |||
| return get_json_result(data=res) | |||
| except Exception as e: | |||
| if str(e).find("not_found") > 0: | |||
| @@ -471,8 +514,9 @@ def create(tenant_id): | |||
| req = request.json | |||
| md5 = hashlib.md5() | |||
| md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8")) | |||
| chunck_id = md5.hexdigest() | |||
| d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), | |||
| chunk_id = md5.hexdigest() | |||
| d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), | |||
| "content_with_weight": req["content_with_weight"]} | |||
| d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | |||
| d["important_kwd"] = req.get("important_kwd", []) | |||
| @@ -503,20 +547,33 @@ def create(tenant_id): | |||
| DocumentService.increment_chunk_num( | |||
| doc.id, doc.kb_id, c, 1, 0) | |||
| return get_json_result(data={"chunk": d}) | |||
| # return get_json_result(data={"chunk_id": chunck_id}) | |||
| d["chunk_id"] = chunk_id | |||
| #rename keys | |||
| key_mapping = { | |||
| "chunk_id": "id", | |||
| "content_with_weight": "content", | |||
| "doc_id": "document_id", | |||
| "important_kwd": "important_keywords", | |||
| "kb_id":"knowledge_base_id", | |||
| } | |||
| renamed_chunk = {} | |||
| for key, value in d.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_chunk[new_key] = value | |||
| return get_json_result(data={"chunk": renamed_chunk}) | |||
| # return get_json_result(data={"chunk_id": chunk_id}) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @manager.route('/chunk/rm', methods=['POST']) | |||
| @token_required | |||
| @validate_request("chunk_ids", "doc_id") | |||
| def rm_chunk(): | |||
| def rm_chunk(tenant_id): | |||
| req = request.json | |||
| try: | |||
| if not ELASTICSEARCH.deleteByQuery( | |||
| Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)): | |||
| Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)): | |||
| return get_data_error_result(retmsg="Index updating failure") | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| if not e: | |||
| @@ -526,4 +583,126 @@ def rm_chunk(): | |||
| DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @manager.route('/chunk/set', methods=['POST']) | |||
| @token_required | |||
| @validate_request("doc_id", "chunk_id", "content_with_weight", | |||
| "important_kwd") | |||
| def set(tenant_id): | |||
| req = request.json | |||
| d = { | |||
| "id": req["chunk_id"], | |||
| "content_with_weight": req["content_with_weight"]} | |||
| d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) | |||
| d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | |||
| d["important_kwd"] = req["important_kwd"] | |||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) | |||
| if "available_int" in req: | |||
| d["available_int"] = req["available_int"] | |||
| try: | |||
| tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | |||
| if not tenant_id: | |||
| return get_data_error_result(retmsg="Tenant not found!") | |||
| embd_id = DocumentService.get_embd_id(req["doc_id"]) | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| tenant_id, LLMType.EMBEDDING.value, embd_id) | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| if doc.parser_id == ParserType.QA: | |||
| arr = [ | |||
| t for t in re.split( | |||
| r"[\n\t]", | |||
| req["content_with_weight"]) if len(t) > 1] | |||
| if len(arr) != 2: | |||
| return get_data_error_result( | |||
| retmsg="Q&A must be separated by TAB/ENTER key.") | |||
| q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) | |||
| d = beAdoc(d, arr[0], arr[1], not any( | |||
| [rag_tokenizer.is_chinese(t) for t in q + a])) | |||
| v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) | |||
| v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | |||
| d["q_%d_vec" % len(v)] = v.tolist() | |||
| ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @manager.route('/retrieval_test', methods=['POST']) | |||
| @token_required | |||
| @validate_request("kb_id", "question") | |||
| def retrieval_test(tenant_id): | |||
| req = request.json | |||
| page = int(req.get("page", 1)) | |||
| size = int(req.get("size", 30)) | |||
| question = req["question"] | |||
| kb_id = req["kb_id"] | |||
| if isinstance(kb_id, str): kb_id = [kb_id] | |||
| doc_ids = req.get("doc_ids", []) | |||
| similarity_threshold = float(req.get("similarity_threshold", 0.2)) | |||
| vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | |||
| top = int(req.get("top_k", 1024)) | |||
| try: | |||
| tenants = UserTenantService.query(user_id=tenant_id) | |||
| for kid in kb_id: | |||
| for tenant in tenants: | |||
| if KnowledgebaseService.query( | |||
| tenant_id=tenant.tenant_id, id=kid): | |||
| break | |||
| else: | |||
| return get_json_result( | |||
| data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', | |||
| retcode=RetCode.OPERATING_ERROR) | |||
| e, kb = KnowledgebaseService.get_by_id(kb_id[0]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Knowledgebase not found!") | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | |||
| rerank_mdl = None | |||
| if req.get("rerank_id"): | |||
| rerank_mdl = TenantLLMService.model_instance( | |||
| kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) | |||
| if req.get("keyword", False): | |||
| chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) | |||
| question += keyword_extraction(chat_mdl, question) | |||
| retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler | |||
| ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size, | |||
| similarity_threshold, vector_similarity_weight, top, | |||
| doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) | |||
| for c in ranks["chunks"]: | |||
| if "vector" in c: | |||
| del c["vector"] | |||
| ##rename keys | |||
| renamed_chunks=[] | |||
| for chunk in ranks["chunks"]: | |||
| key_mapping = { | |||
| "chunk_id": "id", | |||
| "content_with_weight": "content", | |||
| "doc_id": "document_id", | |||
| "important_kwd": "important_keywords", | |||
| } | |||
| rename_chunk={} | |||
| for key, value in chunk.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| rename_chunk[new_key] = value | |||
| renamed_chunks.append(rename_chunk) | |||
| ranks["chunks"] = renamed_chunks | |||
| return get_json_result(data=ranks) | |||
| except Exception as e: | |||
| if str(e).find("not_found") > 0: | |||
| return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', | |||
| retcode=RetCode.DATA_ERROR) | |||
| return server_error_response(e) | |||
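A sketch of calling the new /retrieval_test endpoint directly over HTTP. The URL prefix and the Authorization header format are assumptions (the SDK's base request helpers are not shown in this diff), and the IDs are placeholders. Only kb_id and question are required; the other fields fall back to the defaults in the handler above.

```python
import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"  # assumption: adjust to your deployment
API_KEY = "ragflow-***"                     # placeholder

payload = {
    "kb_id": "<knowledgebase-id>",          # required
    "question": "What's ragflow?",          # required
    "page": 1,
    "size": 30,
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,
}
res = requests.post(f"{BASE_URL}/doc/retrieval_test", json=payload,
                    headers={"Authorization": f"Bearer {API_KEY}"})  # assumption: bearer-token auth

# Keys are already renamed by the handler (chunk_id -> id, content_with_weight -> content).
for chunk in res.json()["data"]["chunks"]:
    print(chunk["id"], chunk["content"][:80])
```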
| @@ -3,32 +3,48 @@ from .base import Base | |||
| class Chunk(Base): | |||
| def __init__(self, rag, res_dict): | |||
| # Initialize the class attributes | |||
| self.id = "" | |||
| self.content_with_weight = "" | |||
| self.content_ltks = [] | |||
| self.content_sm_ltks = [] | |||
| self.important_kwd = [] | |||
| self.important_tks = [] | |||
| self.content = "" | |||
| self.important_keywords = [] | |||
| self.create_time = "" | |||
| self.create_timestamp_flt = 0.0 | |||
| self.kb_id = None | |||
| self.docnm_kwd = "" | |||
| self.doc_id = "" | |||
| self.q_vec = [] | |||
| self.knowledgebase_id = None | |||
| self.document_name = "" | |||
| self.document_id = "" | |||
| self.status = "1" | |||
| for k, v in res_dict.items(): | |||
| if hasattr(self, k): | |||
| setattr(self, k, v) | |||
| for k in list(res_dict.keys()): | |||
| if k not in self.__dict__: | |||
| res_dict.pop(k) | |||
| super().__init__(rag, res_dict) | |||
| def delete(self) -> bool: | |||
| """ | |||
| Delete the chunk in the document. | |||
| """ | |||
| res = self.rm('/doc/chunk/rm', | |||
| {"doc_id": [self.id],""}) | |||
| res = self.post('/doc/chunk/rm', | |||
| {"doc_id": self.document_id, 'chunk_ids': [self.id]}) | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": | |||
| return True | |||
| raise Exception(res["retmsg"]) | |||
| raise Exception(res["retmsg"]) | |||
| def save(self) -> bool: | |||
| """ | |||
| Save the document details to the server. | |||
| """ | |||
| res = self.post('/doc/chunk/set', | |||
| {"chunk_id": self.id, | |||
| "kb_id": self.knowledgebase_id, | |||
| "name": self.document_name, | |||
| "content_with_weight": self.content, | |||
| "important_kwd": self.important_keywords, | |||
| "create_time": self.create_time, | |||
| "create_timestamp_flt": self.create_timestamp_flt, | |||
| "doc_id": self.document_id, | |||
| "status": self.status, | |||
| }) | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": | |||
| return True | |||
| raise Exception(res["retmsg"]) | |||
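Together with the Document changes below, the new Chunk methods give a simple add → update → delete flow. A sketch using names from this diff; the server address, API key, and document name are placeholders, and the import path is assumed:

```python
from ragflow import RAGFlow  # import path assumed

rag = RAGFlow("ragflow-***", "http://127.0.0.1:9380")  # placeholders
doc = rag.get_document(name="story.txt")

# Add a chunk to the document.
chunk = doc.add_chunk(content="RAGFlow is an open-source RAG engine.")

# Update its content and keywords, then persist via /doc/chunk/set.
chunk.content = "RAGFlow is an open-source RAG engine (updated)."
chunk.important_keywords = ["ragflow", "rag"]
chunk.save()

# Remove it again via /doc/chunk/rm.
chunk.delete()
```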
| @@ -9,15 +9,15 @@ class Document(Base): | |||
| self.id = "" | |||
| self.name = "" | |||
| self.thumbnail = None | |||
| self.kb_id = None | |||
| self.knowledgebase_id = None | |||
| self.parser_method = "" | |||
| self.parser_config = {"pages": [[1, 1000000]]} | |||
| self.source_type = "local" | |||
| self.type = "" | |||
| self.created_by = "" | |||
| self.size = 0 | |||
| self.token_num = 0 | |||
| self.chunk_num = 0 | |||
| self.token_count = 0 | |||
| self.chunk_count = 0 | |||
| self.progress = 0.0 | |||
| self.progress_msg = "" | |||
| self.process_begin_at = None | |||
| @@ -34,10 +34,10 @@ class Document(Base): | |||
| Save the document details to the server. | |||
| """ | |||
| res = self.post('/doc/save', | |||
| {"id": self.id, "name": self.name, "thumbnail": self.thumbnail, "kb_id": self.kb_id, | |||
| {"id": self.id, "name": self.name, "thumbnail": self.thumbnail, "kb_id": self.knowledgebase_id, | |||
| "parser_id": self.parser_method, "parser_config": self.parser_config.to_json(), | |||
| "source_type": self.source_type, "type": self.type, "created_by": self.created_by, | |||
| "size": self.size, "token_num": self.token_num, "chunk_num": self.chunk_num, | |||
| "size": self.size, "token_num": self.token_count, "chunk_num": self.chunk_count, | |||
| "progress": self.progress, "progress_msg": self.progress_msg, | |||
| "process_begin_at": self.process_begin_at, "process_duation": self.process_duration | |||
| }) | |||
| @@ -177,8 +177,10 @@ class Document(Base): | |||
| if res.status_code == 200: | |||
| res_data = res.json() | |||
| if res_data.get("retmsg") == "success": | |||
| chunks = res_data["data"]["chunks"] | |||
| self.chunks = chunks # Store the chunks in the document instance | |||
| chunks=[] | |||
| for chunk_data in res_data["data"].get("chunks", []): | |||
| chunk=Chunk(self.rag,chunk_data) | |||
| chunks.append(chunk) | |||
| return chunks | |||
| else: | |||
| raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}") | |||
| @@ -187,10 +189,9 @@ class Document(Base): | |||
| def add_chunk(self, content: str): | |||
| res = self.post('/doc/chunk/create', {"doc_id": self.id, "content_with_weight":content}) | |||
| # Assume the returned response contains the chunk information | |||
| if res.status_code == 200: | |||
| chunk_data = res.json() | |||
| return Chunk(self.rag,chunk_data) # Assume there is a Chunk class to handle the chunk object | |||
| res_data = res.json().get("data") | |||
| chunk_data = res_data.get("chunk") | |||
| return Chunk(self.rag,chunk_data) | |||
| else: | |||
| raise Exception(f"Failed to add chunk: {res.status_code} {res.text}") | |||
| @@ -20,6 +20,8 @@ import requests | |||
| from .modules.assistant import Assistant | |||
| from .modules.dataset import DataSet | |||
| from .modules.document import Document | |||
| from .modules.chunk import Chunk | |||
| class RAGFlow: | |||
| def __init__(self, user_key, base_url, version='v1'): | |||
| @@ -143,7 +145,7 @@ class RAGFlow: | |||
| return result_list | |||
| raise Exception(res["retmsg"]) | |||
| def create_document(self, ds:DataSet, name: str, blob: bytes) -> bool: | |||
| def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool: | |||
| url = f"/doc/dataset/{ds.id}/documents/upload" | |||
| files = { | |||
| 'file': (name, blob) | |||
| @@ -164,6 +166,7 @@ class RAGFlow: | |||
| raise Exception(f"Upload failed: {response.json().get('retmsg')}") | |||
| return False | |||
| def get_document(self, id: str = None, name: str = None) -> Document: | |||
| res = self.get("/doc/infos", {"id": id, "name": name}) | |||
| res = res.json() | |||
| @@ -204,8 +207,6 @@ class RAGFlow: | |||
| if not doc_ids or not isinstance(doc_ids, list): | |||
| raise ValueError("doc_ids must be a non-empty list of document IDs") | |||
| data = {"doc_ids": doc_ids, "run": 2} | |||
| res = self.post(f'/doc/run', data) | |||
| if res.status_code != 200: | |||
| @@ -217,4 +218,61 @@ class RAGFlow: | |||
| print(f"Error occurred during canceling parsing for documents: {str(e)}") | |||
| raise | |||
| def retrieval(self, | |||
| question, | |||
| datasets=None, | |||
| documents=None, | |||
| offset=0, | |||
| limit=6, | |||
| similarity_threshold=0.1, | |||
| vector_similarity_weight=0.3, | |||
| top_k=1024): | |||
| """ | |||
| Perform document retrieval based on the given parameters. | |||
| :param question: The query question. | |||
| :param datasets: A list of datasets (optional, as documents may be provided directly). | |||
| :param documents: A list of documents (if specific documents are provided). | |||
| :param offset: Offset for the retrieval results. | |||
| :param limit: Maximum number of retrieval results. | |||
| :param similarity_threshold: Similarity threshold. | |||
| :param vector_similarity_weight: Weight of vector similarity. | |||
| :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking). | |||
| Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API. | |||
| """ | |||
| try: | |||
| data = { | |||
| "question": question, | |||
| "datasets": datasets if datasets is not None else [], | |||
| "documents": [doc.id if hasattr(doc, 'id') else doc for doc in | |||
| documents] if documents is not None else [], | |||
| "offset": offset, | |||
| "limit": limit, | |||
| "similarity_threshold": similarity_threshold, | |||
| "vector_similarity_weight": vector_similarity_weight, | |||
| "top_k": top_k, | |||
| "kb_id": datasets, | |||
| } | |||
| # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) | |||
| res = self.post(f'/doc/retrieval_test', data) | |||
| # Check the response status code | |||
| if res.status_code == 200: | |||
| res_data = res.json() | |||
| if res_data.get("retmsg") == "success": | |||
| chunks = [] | |||
| for chunk_data in res_data["data"].get("chunks", []): | |||
| chunk = Chunk(self, chunk_data) | |||
| chunks.append(chunk) | |||
| return chunks | |||
| else: | |||
| raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}") | |||
| else: | |||
| raise Exception(f"API request failed with status code {res.status_code}") | |||
| except Exception as e: | |||
| print(f"An error occurred during retrieval: {e}") | |||
| raise | |||
| @@ -41,6 +41,7 @@ class TestDocument(TestSdk): | |||
| def test_update_document_with_success(self): | |||
| """ | |||
| Test updating a document with success. | |||
| Updating the name or parser_method is supported. | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| doc = rag.get_document(name="TestDocument.txt") | |||
| @@ -60,7 +61,7 @@ class TestDocument(TestSdk): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| # Retrieve a document | |||
| doc = rag.get_document(name="TestDocument.txt") | |||
| doc = rag.get_document(name="manual.txt") | |||
| # Check if the retrieved document is of type Document | |||
| if isinstance(doc, Document): | |||
| @@ -147,14 +148,16 @@ class TestDocument(TestSdk): | |||
| ds = rag.create_dataset(name="God4") | |||
| # Define the document name and path | |||
| name3 = 'ai.pdf' | |||
| path = 'test_data/ai.pdf' | |||
| name3 = 'westworld.pdf' | |||
| path = 'test_data/westworld.pdf' | |||
| # Create a document in the dataset using the file path | |||
| rag.create_document(ds, name=name3, blob=open(path, "rb").read()) | |||
| # Retrieve the document by name | |||
| doc = rag.get_document(name="ai.pdf") | |||
| doc = rag.get_document(name="westworld.pdf") | |||
| # Initiate asynchronous parsing | |||
| doc.async_parse() | |||
| @@ -185,9 +188,9 @@ class TestDocument(TestSdk): | |||
| # Prepare a list of file names and paths | |||
| documents = [ | |||
| {'name': 'ai1.pdf', 'path': 'test_data/ai1.pdf'}, | |||
| {'name': 'ai2.pdf', 'path': 'test_data/ai2.pdf'}, | |||
| {'name': 'ai3.pdf', 'path': 'test_data/ai3.pdf'} | |||
| {'name': 'test1.txt', 'path': 'test_data/test1.txt'}, | |||
| {'name': 'test2.txt', 'path': 'test_data/test2.txt'}, | |||
| {'name': 'test3.txt', 'path': 'test_data/test3.txt'} | |||
| ] | |||
| # Create documents in bulk | |||
| @@ -248,6 +251,7 @@ class TestDocument(TestSdk): | |||
| print(c) | |||
| assert c is not None, "Chunk is None" | |||
| assert "rag" in c['content_with_weight'].lower(), f"Keyword 'rag' not found in chunk content: {c.content}" | |||
| def test_add_chunk_to_chunk_list(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| doc = rag.get_document(name='story.txt') | |||
| @@ -258,12 +262,44 @@ class TestDocument(TestSdk): | |||
| def test_delete_chunk_of_chunk_list(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| doc = rag.get_document(name='story.txt') | |||
| chunk = doc.add_chunk(content="assss") | |||
| assert chunk is not None, "Chunk is None" | |||
| assert isinstance(chunk, Chunk), "Chunk was not added to chunk list" | |||
| chunk_num_before=doc.chunk_num | |||
| doc = rag.get_document(name='story.txt') | |||
| chunk_count_before=doc.chunk_count | |||
| chunk.delete() | |||
| assert doc.chunk_num == chunk_num_before-1, "Chunk was not deleted" | |||
| doc = rag.get_document(name='story.txt') | |||
| assert doc.chunk_count == chunk_count_before-1, "Chunk was not deleted" | |||
| def test_update_chunk_content(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| doc = rag.get_document(name='story.txt') | |||
| chunk = doc.add_chunk(content="assssd") | |||
| assert chunk is not None, "Chunk is None" | |||
| assert isinstance(chunk, Chunk), "Chunk was not added to chunk list" | |||
| chunk.content = "ragflow123" | |||
| res=chunk.save() | |||
| assert res is True, f"Failed to update chunk, error: {res}" | |||
| def test_retrieval_chunks(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| ds = rag.create_dataset(name="God8") | |||
| name = 'ragflow_test.txt' | |||
| path = 'test_data/ragflow_test.txt' | |||
| rag.create_document(ds, name=name, blob=open(path, "rb").read()) | |||
| doc = rag.get_document(name=name) | |||
| doc.async_parse() | |||
| # Wait for parsing to complete and get progress updates using join | |||
| for progress, msg in doc.join(interval=5, timeout=30): | |||
| print(progress, msg) | |||
| assert 0 <= progress <= 100, f"Invalid progress: {progress}" | |||
| assert msg, "Message should not be empty" | |||
| for c in rag.retrieval(question="What's ragflow?", | |||
| datasets=[ds.id], documents=[doc], | |||
| offset=0, limit=6, similarity_threshold=0.1, | |||
| vector_similarity_weight=0.3, | |||
| top_k=1024 | |||
| ): | |||
| print(c) | |||
| assert c is not None, "Chunk is None" | |||
| assert "ragflow" in c.content.lower(), f"Keyword 'rag' not found in chunk content: {c.content}" | |||