### What problem does this PR solve?

Fix some issues in API

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
| @@ -18,20 +18,21 @@ from flask import request | |||
| from api.db import StatusEnum | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import TenantService | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_error_data_result, token_required | |||
| from api.utils.api_utils import get_result | |||
| @manager.route('/chat', methods=['POST']) | |||
| @token_required | |||
| def create(tenant_id): | |||
| req=request.json | |||
| ids= req.get("knowledgebases") | |||
| ids= req.get("datasets") | |||
| if not ids: | |||
| return get_error_data_result(retmsg="`knowledgebases` is required") | |||
| return get_error_data_result(retmsg="`datasets` is required") | |||
| for kb_id in ids: | |||
| kbs = KnowledgebaseService.query(id=kb_id,tenant_id=tenant_id) | |||
| if not kbs: | |||
| @@ -45,6 +46,8 @@ def create(tenant_id): | |||
| if llm: | |||
| if "model_name" in llm: | |||
| req["llm_id"] = llm.pop("model_name") | |||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"): | |||
| return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") | |||
| req["llm_setting"] = req.pop("llm") | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| @@ -73,10 +76,10 @@ def create(tenant_id): | |||
| req["top_n"] = req.get("top_n", 6) | |||
| req["top_k"] = req.get("top_k", 1024) | |||
| req["rerank_id"] = req.get("rerank_id", "") | |||
| if req.get("llm_id"): | |||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||
| return get_error_data_result(retmsg="the model_name does not exist.") | |||
| else: | |||
| if req.get("rerank_id"): | |||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_id"),model_type="rerank"): | |||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") | |||
| if not req.get("llm_id"): | |||
| req["llm_id"] = tenant.llm_id | |||
| if not req.get("name"): | |||
| return get_error_data_result(retmsg="`name` is required.") | |||
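Taken together, the hunks above change the chat-creation contract: `datasets` replaces `knowledgebases`, and `model_name`/`rerank_id` are now validated against the tenant's configured models. A minimal sketch of a request under the new contract — the host, the `/api/v1` base path, the bearer-auth header, and the model name are assumptions for illustration, not part of this diff:

```python
import requests

API_KEY = "YOUR_API_KEY"               # hypothetical placeholder
BASE = "http://localhost:9380/api/v1"  # assumed host and mount point

payload = {
    "name": "my_chat",
    "datasets": ["<dataset_id>"],      # was "knowledgebases"; omitting it now fails
    "llm": {"model_name": "glm-4"},    # hypothetical name; checked with model_type="chat"
    "rerank_id": "",                   # if non-empty, checked with model_type="rerank"
}
resp = requests.post(f"{BASE}/chat", json=payload,
                     headers={"Authorization": f"Bearer {API_KEY}"})
print(resp.json())  # a non-zero "code" plus a message such as "`datasets` is required" on a bad request
```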
| @@ -135,7 +138,7 @@ def create(tenant_id): | |||
| res["llm"] = res.pop("llm_setting") | |||
| res["llm"]["model_name"] = res.pop("llm_id") | |||
| del res["kb_ids"] | |||
| res["knowledgebases"] = req["knowledgebases"] | |||
| res["datasets"] = req["datasets"] | |||
| res["avatar"] = res.pop("icon") | |||
| return get_result(data=res) | |||
| @@ -145,27 +148,32 @@ def update(tenant_id,chat_id): | |||
| if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): | |||
| return get_error_data_result(retmsg='You do not own the chat') | |||
| req =request.json | |||
| if "knowledgebases" in req: | |||
| if not req.get("knowledgebases"): | |||
| return get_error_data_result(retmsg="`knowledgebases` can't be empty value") | |||
| kb_list = [] | |||
| for kb in req.get("knowledgebases"): | |||
| if not kb["id"]: | |||
| return get_error_data_result(retmsg="knowledgebase needs id") | |||
| if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id): | |||
| return get_error_data_result(retmsg="you do not own the knowledgebase") | |||
| # if not DocumentService.query(kb_id=kb["id"]): | |||
| # return get_error_data_result(retmsg="There is a invalid knowledgebase") | |||
| kb_list.append(kb["id"]) | |||
| req["kb_ids"] = kb_list | |||
| ids = req.get("datasets") | |||
| if "datasets" in req: | |||
| if not ids: | |||
| return get_error_data_result("`datasets` can't be empty") | |||
| if ids: | |||
| for kb_id in ids: | |||
| kbs = KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id) | |||
| if not kbs: | |||
| return get_error_data_result(f"You don't own the dataset {kb_id}") | |||
| kb = kbs[0] | |||
| if kb.chunk_num == 0: | |||
| return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") | |||
| req["kb_ids"] = ids | |||
| llm = req.get("llm") | |||
| if llm: | |||
| if "model_name" in llm: | |||
| req["llm_id"] = llm.pop("model_name") | |||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"): | |||
| return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") | |||
| req["llm_setting"] = req.pop("llm") | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| return get_error_data_result(retmsg="Tenant not found!") | |||
| if req.get("rerank_model"): | |||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_model"),model_type="rerank"): | |||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_model')} doesn't exist") | |||
| # prompt | |||
| prompt = req.get("prompt") | |||
| key_mapping = {"parameters": "variables", | |||
| @@ -185,9 +193,6 @@ def update(tenant_id,chat_id): | |||
| req["prompt_config"] = req.pop("prompt") | |||
| e, res = DialogService.get_by_id(chat_id) | |||
| res = res.to_json() | |||
| if "llm_id" in req: | |||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||
| return get_error_data_result(retmsg="The `model_name` does not exist.") | |||
| if "name" in req: | |||
| if not req.get("name"): | |||
| return get_error_data_result(retmsg="`name` is not empty.") | |||
| @@ -209,8 +214,8 @@ def update(tenant_id,chat_id): | |||
| # avatar | |||
| if "avatar" in req: | |||
| req["icon"] = req.pop("avatar") | |||
| if "knowledgebases" in req: | |||
| req.pop("knowledgebases") | |||
| if "datasets" in req: | |||
| req.pop("datasets") | |||
| if not DialogService.update_by_id(chat_id, req): | |||
| return get_error_data_result(retmsg="Chat not found!") | |||
| return get_result() | |||
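The update path now mirrors those checks: a present-but-empty `datasets` is rejected, each dataset must be owned and contain parsed files, and `model_name`/`rerank_model` must exist for the tenant. Continuing the sketch above (same `requests` import, `BASE`, and `API_KEY`):

```python
update_payload = {
    "name": "renamed_chat",
    "datasets": ["<dataset_id>"],            # [] would return "`datasets` can't be empty"
    "llm": {"model_name": "glm-4"},          # hypothetical name
    "rerank_model": "bce-reranker-base_v1",  # hypothetical; checked with model_type="rerank"
}
resp = requests.put(f"{BASE}/chat/<chat_id>", json=update_payload,
                    headers={"Authorization": f"Bearer {API_KEY}"})
```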
| @@ -279,7 +284,7 @@ def list_chat(tenant_id): | |||
| return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}") | |||
| kb_list.append(kb[0].to_json()) | |||
| del res["kb_ids"] | |||
| res["knowledgebases"] = kb_list | |||
| res["datasets"] = kb_list | |||
| res["avatar"] = res.pop("icon") | |||
| list_assts.append(res) | |||
| return get_result(data=list_assts) | |||
| @@ -15,17 +15,17 @@ | |||
| # | |||
| from flask import request | |||
| from api.db import StatusEnum, FileSource | |||
| from api.db.db_models import File | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import TenantService | |||
| from api.settings import RetCode | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_result, token_required, get_error_data_result, valid | |||
| from api.utils.api_utils import get_result, token_required, get_error_data_result, valid,get_parser_config | |||
| @manager.route('/dataset', methods=['POST']) | |||
| @@ -36,15 +36,17 @@ def create(tenant_id): | |||
| permission = req.get("permission") | |||
| language = req.get("language") | |||
| chunk_method = req.get("chunk_method") | |||
| valid_permission = ("me", "team") | |||
| valid_language =("Chinese", "English") | |||
| valid_chunk_method = ("naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email") | |||
| parser_config = req.get("parser_config") | |||
| valid_permission = {"me", "team"} | |||
| valid_language ={"Chinese", "English"} | |||
| valid_chunk_method = {"naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"} | |||
| check_validation=valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method) | |||
| if check_validation: | |||
| return check_validation | |||
| if "tenant_id" in req or "embedding_model" in req: | |||
| req["parser_config"]=get_parser_config(chunk_method,parser_config) | |||
| if "tenant_id" in req: | |||
| return get_error_data_result( | |||
| retmsg="`tenant_id` or `embedding_model` must not be provided") | |||
| retmsg="`tenant_id` must not be provided") | |||
| chunk_count=req.get("chunk_count") | |||
| document_count=req.get("document_count") | |||
| if chunk_count or document_count: | |||
| @@ -59,9 +61,13 @@ def create(tenant_id): | |||
| retmsg="`name` is not empty string!") | |||
| if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| return get_error_data_result( | |||
| retmsg="Duplicated knowledgebase name in creating dataset.") | |||
| retmsg="Duplicated dataset name in creating dataset.") | |||
| req["tenant_id"] = req['created_by'] = tenant_id | |||
| req['embedding_model'] = t.embd_id | |||
| if not req.get("embedding_model"): | |||
| req['embedding_model'] = t.embd_id | |||
| else: | |||
| if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")): | |||
| return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "doc_num": "document_count", | |||
| @@ -116,10 +122,12 @@ def update(tenant_id,dataset_id): | |||
| permission = req.get("permission") | |||
| language = req.get("language") | |||
| chunk_method = req.get("chunk_method") | |||
| valid_permission = ("me", "team") | |||
| valid_language =("Chinese", "English") | |||
| valid_chunk_method = ("naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email") | |||
| check_validation=valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method) | |||
| parser_config = req.get("parser_config") | |||
| valid_permission = {"me", "team"} | |||
| valid_language = {"Chinese", "English"} | |||
| valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", | |||
| "knowledge_graph", "email"} | |||
| check_validation = valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method) | |||
| if check_validation: | |||
| return check_validation | |||
| if "tenant_id" in req: | |||
| @@ -142,10 +150,16 @@ def update(tenant_id,dataset_id): | |||
| return get_error_data_result( | |||
| retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable.") | |||
| req['parser_id'] = req.pop('chunk_method') | |||
| if req['parser_id'] != kb.parser_id: | |||
| req["parser_config"] = get_parser_config(chunk_method, parser_config) | |||
| if "embedding_model" in req: | |||
| if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id: | |||
| return get_error_data_result( | |||
| retmsg="If `chunk_count` is not 0, `embedding_method` is not changeable.") | |||
| if not req.get("embedding_model"): | |||
| return get_error_data_result("`embedding_model` can't be empty") | |||
| if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")): | |||
| return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") | |||
| req['embd_id'] = req.pop('embedding_model') | |||
| if "name" in req: | |||
| req["name"] = req["name"].strip() | |||
| @@ -153,7 +167,7 @@ def update(tenant_id,dataset_id): | |||
| and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | |||
| status=StatusEnum.VALID.value)) > 0: | |||
| return get_error_data_result( | |||
| retmsg="Duplicated knowledgebase name in updating dataset.") | |||
| retmsg="Duplicated dataset name in updating dataset.") | |||
| if not KnowledgebaseService.update_by_id(kb.id, req): | |||
| return get_error_data_result(retmsg="Update dataset error.(Database error)") | |||
| return get_result(retcode=RetCode.SUCCESS) | |||
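On update, both `chunk_method` and `embedding_model` are frozen once the dataset has parsed chunks (`chunk_num != 0`), and a supplied `embedding_model` must be non-empty and known to the tenant. Continuing the sketch:

```python
resp = requests.put(f"{BASE}/dataset/<dataset_id>",
                    json={"embedding_model": "BAAI/bge-large-zh-v1.5"},  # hypothetical name
                    headers={"Authorization": f"Bearer {API_KEY}"})
# With chunk_count != 0 and a different model, the server answers:
#   "If `chunk_count` is not 0, `embedding_model` is not changeable."
```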
| @@ -39,7 +39,7 @@ from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.settings import RetCode, retrievaler | |||
| from api.utils.api_utils import construct_json_result | |||
| from api.utils.api_utils import construct_json_result,get_parser_config | |||
| from rag.nlp import search | |||
| from rag.utils import rmSpace | |||
| from rag.utils.es_conn import ELASTICSEARCH | |||
| @@ -49,6 +49,10 @@ MAXIMUM_OF_UPLOADING_FILES = 256 | |||
| MAXIMUM_OF_UPLOADING_FILES = 256 | |||
| @manager.route('/dataset/<dataset_id>/document', methods=['POST']) | |||
| @token_required | |||
| @@ -61,14 +65,41 @@ def upload(dataset_id, tenant_id): | |||
| if file_obj.filename == '': | |||
| return get_result( | |||
| retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) | |||
| # total size | |||
| total_size = 0 | |||
| for file_obj in file_objs: | |||
| file_obj.seek(0, os.SEEK_END) | |||
| total_size += file_obj.tell() | |||
| file_obj.seek(0) | |||
| MAX_TOTAL_FILE_SIZE=10*1024*1024 | |||
| if total_size > MAX_TOTAL_FILE_SIZE: | |||
| return get_result( | |||
| retmsg=f'Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)', | |||
| retcode=RetCode.ARGUMENT_ERROR) | |||
| e, kb = KnowledgebaseService.get_by_id(dataset_id) | |||
| if not e: | |||
| raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!") | |||
| err, _ = FileService.upload_document(kb, file_objs, tenant_id) | |||
| raise LookupError(f"Can't find the dataset with ID {dataset_id}!") | |||
| err, files= FileService.upload_document(kb, file_objs, tenant_id) | |||
| if err: | |||
| return get_result( | |||
| retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) | |||
| return get_result() | |||
| # rename key's name | |||
| renamed_doc_list = [] | |||
| for file in files: | |||
| doc = file[0] | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "kb_id": "dataset_id", | |||
| "token_num": "token_count", | |||
| "parser_id": "chunk_method" | |||
| } | |||
| renamed_doc = {} | |||
| for key, value in doc.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_doc[new_key] = value | |||
| renamed_doc["run"] = "UNSTART" | |||
| renamed_doc_list.append(renamed_doc) | |||
| return get_result(data=renamed_doc_list) | |||
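Upload now enforces a 10 MB cap on the combined size of all files in one request and, on success, returns the created documents with renamed keys (`dataset_id`, `chunk_count`, `token_count`, `chunk_method`) plus a `run` field preset to `"UNSTART"`. Continuing the sketch:

```python
with open("intro.pdf", "rb") as f:  # hypothetical local file under the 10 MB limit
    resp = requests.post(f"{BASE}/dataset/<dataset_id>/document",
                         files=[("file", ("intro.pdf", f))],
                         headers={"Authorization": f"Bearer {API_KEY}"})
doc = resp.json()["data"][0]
print(doc["run"], doc["dataset_id"])  # "UNSTART", the dataset's id
```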
| @manager.route('/dataset/<dataset_id>/info/<document_id>', methods=['PUT']) | |||
| @@ -97,7 +128,7 @@ def update_doc(tenant_id, dataset_id, document_id): | |||
| for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): | |||
| if d.name == req["name"]: | |||
| return get_error_data_result( | |||
| retmsg="Duplicated document name in the same knowledgebase.") | |||
| retmsg="Duplicated document name in the same dataset.") | |||
| if not DocumentService.update_by_id( | |||
| document_id, {"name": req["name"]}): | |||
| return get_error_data_result( | |||
| @@ -110,6 +141,9 @@ def update_doc(tenant_id, dataset_id, document_id): | |||
| if "parser_config" in req: | |||
| DocumentService.update_parser_config(doc.id, req["parser_config"]) | |||
| if "chunk_method" in req: | |||
| valid_chunk_method = {"naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"} | |||
| if req.get("chunk_method") not in valid_chunk_method: | |||
| return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist") | |||
| if doc.parser_id.lower() == req["chunk_method"].lower(): | |||
| return get_result() | |||
| @@ -122,6 +156,7 @@ def update_doc(tenant_id, dataset_id, document_id): | |||
| "run": TaskStatus.UNSTART.value}) | |||
| if not e: | |||
| return get_error_data_result(retmsg="Document not found!") | |||
| req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config")) | |||
| if doc.token_num > 0: | |||
| e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, | |||
| doc.process_duation * -1) | |||
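Changing a document's `chunk_method` is validated against the same whitelist; on success the parse state resets to `UNSTART`, chunk/token counters roll back, and `parser_config` is rebuilt via `get_parser_config`. Continuing the sketch:

```python
resp = requests.put(f"{BASE}/dataset/<dataset_id>/info/<document_id>",
                    json={"chunk_method": "qa"},
                    headers={"Authorization": f"Bearer {API_KEY}"})
# An unknown value such as "chunks" returns "`chunk_method` chunks doesn't exist"
```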
| @@ -182,12 +217,21 @@ def list_docs(dataset_id, tenant_id): | |||
| for doc in docs: | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "kb_id": "knowledgebase_id", | |||
| "kb_id": "dataset_id", | |||
| "token_num": "token_count", | |||
| "parser_id": "chunk_method" | |||
| } | |||
| run_mapping = { | |||
| "0" :"UNSTART", | |||
| "1":"RUNNING", | |||
| "2":"CANCEL", | |||
| "3":"DONE", | |||
| "4":"FAIL" | |||
| } | |||
| renamed_doc = {} | |||
| for key, value in doc.items(): | |||
| if key =="run": | |||
| renamed_doc["run"]=run_mapping.get(str(value)) | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_doc[new_key] = value | |||
| renamed_doc_list.append(renamed_doc) | |||
| @@ -353,9 +397,10 @@ def list_chunks(tenant_id,dataset_id,document_id): | |||
| return get_result(data=res) | |||
| @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST']) | |||
| @token_required | |||
| def create(tenant_id,dataset_id,document_id): | |||
| def add_chunk(tenant_id,dataset_id,document_id): | |||
| if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): | |||
| return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") | |||
| doc = DocumentService.query(id=document_id, kb_id=dataset_id) | |||
| @@ -441,6 +486,7 @@ def rm_chunk(tenant_id,dataset_id,document_id): | |||
| return get_result() | |||
| @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT']) | |||
| @token_required | |||
| def update_chunk(tenant_id,dataset_id,document_id,chunk_id): | |||
| @@ -470,12 +516,12 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id): | |||
| d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"]) | |||
| d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | |||
| if "important_keywords" in req: | |||
| if type(req["important_keywords"]) != list: | |||
| return get_error_data_result("`important_keywords` is required to be a list") | |||
| if not isinstance(req["important_keywords"],list): | |||
| return get_error_data_result("`important_keywords` should be a list") | |||
| d["important_kwd"] = req.get("important_keywords") | |||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"])) | |||
| if "available" in req: | |||
| d["available_int"] = req["available"] | |||
| d["available_int"] = int(req["available"]) | |||
| embd_id = DocumentService.get_embd_id(document_id) | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| tenant_id, LLMType.EMBEDDING.value, embd_id) | |||
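Chunk updates now check `important_keywords` with `isinstance` (so any non-list is rejected, not just a mismatched `type`) and coerce `available` to an int before storing it as `available_int`. Continuing the sketch:

```python
resp = requests.put(f"{BASE}/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>",
                    json={"content": "updated chunk text",
                          "important_keywords": ["ragflow", "api"],  # a plain string is rejected
                          "available": True},                        # stored as int(True) == 1
                    headers={"Authorization": f"Bearer {API_KEY}"})
```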
| @@ -498,6 +544,7 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id): | |||
| return get_result() | |||
| @manager.route('/retrieval', methods=['POST']) | |||
| @token_required | |||
| def retrieval_test(tenant_id): | |||
| @@ -505,6 +552,8 @@ def retrieval_test(tenant_id): | |||
| if not req.get("datasets"): | |||
| return get_error_data_result("`datasets` is required.") | |||
| kb_ids = req["datasets"] | |||
| if not isinstance(kb_ids,list): | |||
| return get_error_data_result("`datasets` should be a list") | |||
| kbs = KnowledgebaseService.get_by_ids(kb_ids) | |||
| embd_nms = list(set([kb.embd_id for kb in kbs])) | |||
| if len(embd_nms) != 1: | |||
| @@ -518,9 +567,15 @@ def retrieval_test(tenant_id): | |||
| if "question" not in req: | |||
| return get_error_data_result("`question` is required.") | |||
| page = int(req.get("offset", 1)) | |||
| size = int(req.get("limit", 30)) | |||
| size = int(req.get("limit", 1024)) | |||
| question = req["question"] | |||
| doc_ids = req.get("documents", []) | |||
| if not isinstance(req.get("documents"),list): | |||
| return get_error_data_result("`documents` should be a list") | |||
| doc_ids_list=KnowledgebaseService.list_documents_by_ids(kb_ids) | |||
| for doc_id in doc_ids: | |||
| if doc_id not in doc_ids_list: | |||
| return get_error_data_result(f"You don't own the document {doc_id}") | |||
| similarity_threshold = float(req.get("similarity_threshold", 0.2)) | |||
| vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | |||
| top = int(req.get("top_k", 1024)) | |||
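Retrieval is validated more strictly: `datasets` and `documents` must be lists, every document id must belong to one of the given datasets, and the default `limit` rises from 30 to 1024. Continuing the sketch:

```python
retrieval_payload = {
    "question": "What is RAGFlow?",
    "datasets": ["<dataset_id>"],    # required, must be a list
    "documents": ["<document_id>"],  # optional; each id must belong to the datasets above
    "similarity_threshold": 0.2,
    "top_k": 1024,
}
resp = requests.post(f"{BASE}/retrieval", json=retrieval_payload,
                     headers={"Authorization": f"Bearer {API_KEY}"})
```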
| @@ -531,7 +586,7 @@ def retrieval_test(tenant_id): | |||
| try: | |||
| e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) | |||
| if not e: | |||
| return get_error_data_result(retmsg="Knowledgebase not found!") | |||
| return get_error_data_result(retmsg="Dataset not found!") | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | |||
| @@ -199,7 +199,7 @@ def list(chat_id,tenant_id): | |||
| "content": chunk["content_with_weight"], | |||
| "document_id": chunk["doc_id"], | |||
| "document_name": chunk["docnm_kwd"], | |||
| "knowledgebase_id": chunk["kb_id"], | |||
| "dataset_id": chunk["kb_id"], | |||
| "image_id": chunk["img_id"], | |||
| "similarity": chunk["similarity"], | |||
| "vector_similarity": chunk["vector_similarity"], | |||
| @@ -14,13 +14,23 @@ | |||
| # limitations under the License. | |||
| # | |||
| from api.db import StatusEnum, TenantPermission | |||
| from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant | |||
| from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant,Document | |||
| from api.db.services.common_service import CommonService | |||
| class KnowledgebaseService(CommonService): | |||
| model = Knowledgebase | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def list_documents_by_ids(cls,kb_ids): | |||
| doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where( | |||
| cls.model.id.in_(kb_ids) | |||
| ) | |||
| doc_ids =list(doc_ids.dicts()) | |||
| doc_ids = [doc["document_id"] for doc in doc_ids] | |||
| return doc_ids | |||
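The new helper joins `Knowledgebase` to `Document` on `kb_id` and flattens the result into a plain list of document ids — roughly `SELECT document.id FROM knowledgebase JOIN document ON knowledgebase.id = document.kb_id WHERE knowledgebase.id IN (...)`. A hypothetical server-side call:

```python
# All document ids across two datasets; /retrieval uses this to verify that
# every user-supplied document id is actually owned by the caller's datasets.
doc_ids = KnowledgebaseService.list_documents_by_ids(["kb_a", "kb_b"])  # hypothetical ids
```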
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, | |||
| @@ -337,4 +337,23 @@ def valid(permission,valid_permission,language,valid_language,chunk_method,valid | |||
| def valid_parameter(parameter,valid_values): | |||
| if parameter and parameter not in valid_values: | |||
| return get_error_data_result(f"{parameter} not in {valid_values}") | |||
| return get_error_data_result(f"{parameter} not in {valid_values}") | |||
| def get_parser_config(chunk_method,parser_config): | |||
| if parser_config: | |||
| return parser_config | |||
| if not chunk_method: | |||
| chunk_method = "naive" | |||
| key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"user_raptor": False}}, | |||
| "qa":{"raptor":{"use_raptor":False}}, | |||
| "resume":None, | |||
| "manual":{"raptor":{"use_raptor":False}}, | |||
| "table":None, | |||
| "paper":{"raptor":{"use_raptor":False}}, | |||
| "book":{"raptor":{"use_raptor":False}}, | |||
| "laws":{"raptor":{"use_raptor":False}}, | |||
| "presentation":{"raptor":{"use_raptor":False}}, | |||
| "one":None, | |||
| "knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]}} | |||
| parser_config=key_mapping[chunk_method] | |||
| return parser_config | |||
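A behavior sketch of `get_parser_config`, with values taken directly from the mapping above: an explicit `parser_config` always wins, a missing `chunk_method` falls back to `"naive"`, and otherwise the per-method default (possibly `None`) is returned.

```python
from api.utils.api_utils import get_parser_config

get_parser_config("naive", {"chunk_token_num": 256})  # -> {"chunk_token_num": 256}
get_parser_config("qa", None)     # -> {"raptor": {"use_raptor": False}}
get_parser_config(None, None)     # -> the "naive" defaults shown above
get_parser_config("table", None)  # -> None (no default for table)
```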
| @@ -9,7 +9,7 @@ class Chat(Base): | |||
| self.id = "" | |||
| self.name = "assistant" | |||
| self.avatar = "path/to/avatar" | |||
| self.knowledgebases = ["kb1"] | |||
| self.datasets = ["kb1"] | |||
| self.llm = Chat.LLM(rag, {}) | |||
| self.prompt = Chat.Prompt(rag, {}) | |||
| super().__init__(rag, res_dict) | |||
| @@ -8,10 +8,10 @@ class Chunk(Base): | |||
| self.important_keywords = [] | |||
| self.create_time = "" | |||
| self.create_timestamp = 0.0 | |||
| self.knowledgebase_id = None | |||
| self.dataset_id = None | |||
| self.document_name = "" | |||
| self.document_id = "" | |||
| self.available = 1 | |||
| self.available = True | |||
| for k in list(res_dict.keys()): | |||
| if k not in self.__dict__: | |||
| res_dict.pop(k) | |||
| @@ -19,7 +19,7 @@ class Chunk(Base): | |||
| def update(self,update_message:dict): | |||
| res = self.put(f"/dataset/{self.knowledgebase_id}/document/{self.document_id}/chunk/{self.id}",update_message) | |||
| res = self.put(f"/dataset/{self.dataset_id}/document/{self.document_id}/chunk/{self.id}",update_message) | |||
| res = res.json() | |||
| if res.get("code") != 0 : | |||
| raise Exception(res["message"]) | |||
| @@ -10,10 +10,6 @@ from .base import Base | |||
| class DataSet(Base): | |||
| class ParserConfig(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.chunk_token_count = 128 | |||
| self.layout_recognize = True | |||
| self.delimiter = '\n!?。;!?' | |||
| self.task_page_size = 12 | |||
| super().__init__(rag, res_dict) | |||
| def __init__(self, rag, res_dict): | |||
| @@ -43,11 +39,16 @@ class DataSet(Base): | |||
| def upload_documents(self,document_list: List[dict]): | |||
| url = f"/dataset/{self.id}/document" | |||
| files = [("file",(ele["name"],ele["blob"])) for ele in document_list] | |||
| files = [("file",(ele["displayed_name"],ele["blob"])) for ele in document_list] | |||
| res = self.post(path=url,json=None,files=files) | |||
| res = res.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| if res.get("code") == 0: | |||
| doc_list=[] | |||
| for doc in res["data"]: | |||
| document = Document(self.rag,doc) | |||
| doc_list.append(document) | |||
| return doc_list | |||
| raise Exception(res.get("message")) | |||
| def list_documents(self, id: str = None, keywords: str = None, offset: int =1, limit: int = 1024, orderby: str = "create_time", desc: bool = True): | |||
| res = self.get(f"/dataset/{self.id}/info",params={"id": id,"keywords": keywords,"offset": offset,"limit": limit,"orderby": orderby,"desc": desc}) | |||
| @@ -5,12 +5,16 @@ from typing import List | |||
| class Document(Base): | |||
| class ParserConfig(Base): | |||
| def __init__(self, rag, res_dict): | |||
| super().__init__(rag, res_dict) | |||
| def __init__(self, rag, res_dict): | |||
| self.id = "" | |||
| self.name = "" | |||
| self.thumbnail = None | |||
| self.knowledgebase_id = None | |||
| self.chunk_method = "" | |||
| self.dataset_id = None | |||
| self.chunk_method = "naive" | |||
| self.parser_config = {"pages": [[1, 1000000]]} | |||
| self.source_type = "local" | |||
| self.type = "" | |||
| @@ -31,14 +35,14 @@ class Document(Base): | |||
| def update(self, update_message: dict): | |||
| res = self.put(f'/dataset/{self.knowledgebase_id}/info/{self.id}', | |||
| res = self.put(f'/dataset/{self.dataset_id}/info/{self.id}', | |||
| update_message) | |||
| res = res.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res["message"]) | |||
| def download(self): | |||
| res = self.get(f"/dataset/{self.knowledgebase_id}/document/{self.id}") | |||
| res = self.get(f"/dataset/{self.dataset_id}/document/{self.id}") | |||
| try: | |||
| res = res.json() | |||
| raise Exception(res.get("message")) | |||
| @@ -48,7 +52,7 @@ class Document(Base): | |||
| def list_chunks(self,offset=0, limit=30, keywords="", id:str=None): | |||
| data={"document_id": self.id,"keywords": keywords,"offset":offset,"limit":limit,"id":id} | |||
| res = self.get(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', data) | |||
| res = self.get(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', data) | |||
| res = res.json() | |||
| if res.get("code") == 0: | |||
| chunks=[] | |||
| @@ -59,15 +63,15 @@ class Document(Base): | |||
| raise Exception(res.get("message")) | |||
| def add_chunk(self, content: str): | |||
| res = self.post(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', {"content":content}) | |||
| def add_chunk(self, content: str,important_keywords:List[str]=[]): | |||
| res = self.post(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', {"content":content,"important_keywords":important_keywords}) | |||
| res = res.json() | |||
| if res.get("code") == 0: | |||
| return Chunk(self.rag,res["data"].get("chunk")) | |||
| raise Exception(res.get("message")) | |||
| def delete_chunks(self,ids:List[str]): | |||
| res = self.rm(f"dataset/{self.knowledgebase_id}/document/{self.id}/chunk",{"ids":ids}) | |||
| res = self.rm(f"dataset/{self.dataset_id}/document/{self.id}/chunk",{"ids":ids}) | |||
| res = res.json() | |||
| if res.get("code")!=0: | |||
| raise Exception(res.get("message")) | |||
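`add_chunk` gains an optional `important_keywords` list, and every chunk route now interpolates `dataset_id`. Continuing the SDK sketch:

```python
chunk = docs[0].add_chunk(content="RAGFlow is a RAG engine.",
                          important_keywords=["ragflow"])
chunk.update({"content": "RAGFlow is an open-source RAG engine.", "available": True})
docs[0].delete_chunks(ids=[chunk.id])
```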
| @@ -40,7 +40,7 @@ class Session(Base): | |||
| "content": chunk["content_with_weight"], | |||
| "document_id": chunk["doc_id"], | |||
| "document_name": chunk["docnm_kwd"], | |||
| "knowledgebase_id": chunk["kb_id"], | |||
| "dataset_id": chunk["kb_id"], | |||
| "image_id": chunk["img_id"], | |||
| "similarity": chunk["similarity"], | |||
| "vector_similarity": chunk["vector_similarity"], | |||
| @@ -75,7 +75,7 @@ class Chunk(Base): | |||
| self.content = None | |||
| self.document_id = "" | |||
| self.document_name = "" | |||
| self.knowledgebase_id = "" | |||
| self.dataset_id = "" | |||
| self.image_id = "" | |||
| self.similarity = None | |||
| self.vector_similarity = None | |||
| @@ -49,17 +49,11 @@ class RAGFlow: | |||
| return res | |||
| def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English", | |||
| permission: str = "me", | |||
| document_count: int = 0, chunk_count: int = 0, chunk_method: str = "naive", | |||
| permission: str = "me",chunk_method: str = "naive", | |||
| parser_config: DataSet.ParserConfig = None) -> DataSet: | |||
| if parser_config is None: | |||
| parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True, | |||
| "delimiter": "\n!?。;!?", "task_page_size": 12}) | |||
| parser_config = parser_config.to_json() | |||
| res = self.post("/dataset", | |||
| {"name": name, "avatar": avatar, "description": description, "language": language, | |||
| "permission": permission, | |||
| "document_count": document_count, "chunk_count": chunk_count, "chunk_method": chunk_method, | |||
| "permission": permission, "chunk_method": chunk_method, | |||
| "parser_config": parser_config | |||
| } | |||
| ) | |||
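`create_dataset` loses the read-only `document_count`/`chunk_count` parameters (the server rejects them anyway) and no longer fabricates a client-side default `ParserConfig`; defaults now come from the server's `get_parser_config`. Continuing the SDK sketch:

```python
papers = rag.create_dataset(name="papers", chunk_method="paper")  # parser_config is optional
```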
| @@ -93,11 +87,11 @@ class RAGFlow: | |||
| return result_list | |||
| raise Exception(res["message"]) | |||
| def create_chat(self, name: str, avatar: str = "", knowledgebases: List[DataSet] = [], | |||
| def create_chat(self, name: str, avatar: str = "", datasets: List[DataSet] = [], | |||
| llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat: | |||
| datasets = [] | |||
| for dataset in knowledgebases: | |||
| datasets.append(dataset.to_json()) | |||
| dataset_list = [] | |||
| for dataset in datasets: | |||
| dataset_list.append(dataset.to_json()) | |||
| if llm is None: | |||
| llm = Chat.LLM(self, {"model_name": None, | |||
| @@ -130,7 +124,7 @@ class RAGFlow: | |||
| temp_dict = {"name": name, | |||
| "avatar": avatar, | |||
| "knowledgebases": datasets, | |||
| "datasets": dataset_list, | |||
| "llm": llm.to_json(), | |||
| "prompt": prompt.to_json()} | |||
| res = self.post("/chat", temp_dict) | |||
| @@ -158,25 +152,22 @@ class RAGFlow: | |||
| raise Exception(res["message"]) | |||
| def retrieve(self, question="",datasets=None,documents=None, offset=1, limit=30, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,): | |||
| data_params = { | |||
| def retrieve(self, datasets,documents,question="", offset=1, limit=1024, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,): | |||
| data_json ={ | |||
| "offset": offset, | |||
| "limit": limit, | |||
| "similarity_threshold": similarity_threshold, | |||
| "vector_similarity_weight": vector_similarity_weight, | |||
| "top_k": top_k, | |||
| "knowledgebase_id": datasets, | |||
| "rerank_id":rerank_id, | |||
| "keyword":keyword | |||
| } | |||
| data_json ={ | |||
| "rerank_id": rerank_id, | |||
| "keyword": keyword, | |||
| "question": question, | |||
| "datasets": datasets, | |||
| "documents": documents | |||
| } | |||
| # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) | |||
| res = self.get(f'/retrieval', data_params,data_json) | |||
| res = self.post(f'/retrieval',json=data_json) | |||
| res = res.json() | |||
| if res.get("code") ==0: | |||
| chunks=[] | |||
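`retrieve` switches from a GET with split parameter dicts to a single POST body, makes `datasets` and `documents` leading parameters, and raises the default `limit` to 1024 to match the server. Continuing the SDK sketch:

```python
chunks = rag.retrieve(datasets=[ds.id], documents=[d.id for d in docs],
                      question="What is RAGFlow?", limit=8)
for c in chunks:
    print(c.document_name, c.similarity)
```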
| @@ -1,4 +1,5 @@ | |||
| from ragflow import RAGFlow, Chat | |||
| from common import API_KEY, HOST_ADDRESS | |||
| from test_sdkbase import TestSdk | |||
| @@ -11,7 +12,7 @@ class TestChat(TestSdk): | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_create_chat") | |||
| chat = rag.create_chat("test_create", knowledgebases=[kb]) | |||
| chat = rag.create_chat("test_create", datasets=[kb]) | |||
| if isinstance(chat, Chat): | |||
| assert chat.name == "test_create", "Name does not match." | |||
| else: | |||
| @@ -23,7 +24,7 @@ class TestChat(TestSdk): | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_update_chat") | |||
| chat = rag.create_chat("test_update", knowledgebases=[kb]) | |||
| chat = rag.create_chat("test_update", datasets=[kb]) | |||
| if isinstance(chat, Chat): | |||
| assert chat.name == "test_update", "Name does not match." | |||
| res=chat.update({"name":"new_chat"}) | |||
| @@ -37,7 +38,7 @@ class TestChat(TestSdk): | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_delete_chat") | |||
| chat = rag.create_chat("test_delete", knowledgebases=[kb]) | |||
| chat = rag.create_chat("test_delete", datasets=[kb]) | |||
| if isinstance(chat, Chat): | |||
| assert chat.name == "test_delete", "Name does not match." | |||
| res = rag.delete_chats(ids=[chat.id]) | |||
| @@ -7,14 +7,14 @@ class TestSession: | |||
| def test_create_session(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_create_session") | |||
| assistant = rag.create_chat(name="test_create_session", knowledgebases=[kb]) | |||
| assistant = rag.create_chat(name="test_create_session", datasets=[kb]) | |||
| session = assistant.create_session() | |||
| assert isinstance(session,Session), "Failed to create a session." | |||
| def test_create_chat_with_success(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_create_chat") | |||
| assistant = rag.create_chat(name="test_create_chat", knowledgebases=[kb]) | |||
| assistant = rag.create_chat(name="test_create_chat", datasets=[kb]) | |||
| session = assistant.create_session() | |||
| question = "What is AI" | |||
| for ans in session.ask(question, stream=True): | |||
| @@ -24,7 +24,7 @@ class TestSession: | |||
| def test_delete_sessions_with_success(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_delete_session") | |||
| assistant = rag.create_chat(name="test_delete_session",knowledgebases=[kb]) | |||
| assistant = rag.create_chat(name="test_delete_session",datasets=[kb]) | |||
| session=assistant.create_session() | |||
| res=assistant.delete_sessions(ids=[session.id]) | |||
| assert res is None, "Failed to delete the dataset." | |||
| @@ -32,7 +32,7 @@ class TestSession: | |||
| def test_update_session_with_success(self): | |||
| rag=RAGFlow(API_KEY,HOST_ADDRESS) | |||
| kb=rag.create_dataset(name="test_update_session") | |||
| assistant = rag.create_chat(name="test_update_session",knowledgebases=[kb]) | |||
| assistant = rag.create_chat(name="test_update_session",datasets=[kb]) | |||
| session=assistant.create_session(name="old session") | |||
| res=session.update({"name":"new session"}) | |||
| assert res is None,"Failed to update the session" | |||
| @@ -41,7 +41,7 @@ class TestSession: | |||
| def test_list_sessions_with_success(self): | |||
| rag=RAGFlow(API_KEY,HOST_ADDRESS) | |||
| kb=rag.create_dataset(name="test_list_session") | |||
| assistant=rag.create_chat(name="test_list_session",knowledgebases=[kb]) | |||
| assistant=rag.create_chat(name="test_list_session",datasets=[kb]) | |||
| assistant.create_session("test_1") | |||
| assistant.create_session("test_2") | |||
| sessions=assistant.list_sessions() | |||