### What problem does this PR solve?

Fix some issues in the API.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>

from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
@manager.route('/chat', methods=['POST'])
@token_required
def create(tenant_id):
    req = request.json
    ids = req.get("datasets")
    if not ids:
        return get_error_data_result(retmsg="`datasets` is required")
    for kb_id in ids:
        kbs = KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id)
        if not kbs:

    if llm:
        if "model_name" in llm:
            req["llm_id"] = llm.pop("model_name")
            if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"):
                return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
| req["llm_setting"] = req.pop("llm") | req["llm_setting"] = req.pop("llm") | ||||
| e, tenant = TenantService.get_by_id(tenant_id) | e, tenant = TenantService.get_by_id(tenant_id) | ||||
| if not e: | if not e: | ||||
| req["top_n"] = req.get("top_n", 6) | req["top_n"] = req.get("top_n", 6) | ||||
| req["top_k"] = req.get("top_k", 1024) | req["top_k"] = req.get("top_k", 1024) | ||||
| req["rerank_id"] = req.get("rerank_id", "") | req["rerank_id"] = req.get("rerank_id", "") | ||||
| if req.get("llm_id"): | |||||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||||
| return get_error_data_result(retmsg="the model_name does not exist.") | |||||
| else: | |||||
| if req.get("rerank_id"): | |||||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_id"),model_type="rerank"): | |||||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") | |||||
| if not req.get("llm_id"): | |||||
| req["llm_id"] = tenant.llm_id | req["llm_id"] = tenant.llm_id | ||||
| if not req.get("name"): | if not req.get("name"): | ||||
| return get_error_data_result(retmsg="`name` is required.") | return get_error_data_result(retmsg="`name` is required.") | ||||
| res["llm"] = res.pop("llm_setting") | res["llm"] = res.pop("llm_setting") | ||||
| res["llm"]["model_name"] = res.pop("llm_id") | res["llm"]["model_name"] = res.pop("llm_id") | ||||
| del res["kb_ids"] | del res["kb_ids"] | ||||
| res["knowledgebases"] = req["knowledgebases"] | |||||
| res["datasets"] = req["datasets"] | |||||
| res["avatar"] = res.pop("icon") | res["avatar"] = res.pop("icon") | ||||
| return get_result(data=res) | return get_result(data=res) | ||||
    if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
        return get_error_data_result(retmsg='You do not own the chat')
    req = request.json
    ids = req.get("datasets")
    if "datasets" in req:
        if not ids:
            return get_error_data_result("`datasets` can't be empty")
    if ids:
        for kb_id in ids:
            kbs = KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id)
            if not kbs:
                return get_error_data_result(f"You don't own the dataset {kb_id}")
            kb = kbs[0]
            if kb.chunk_num == 0:
                return get_error_data_result(f"The dataset {kb_id} has no parsed files")
        req["kb_ids"] = ids
| llm = req.get("llm") | llm = req.get("llm") | ||||
| if llm: | if llm: | ||||
| if "model_name" in llm: | if "model_name" in llm: | ||||
| req["llm_id"] = llm.pop("model_name") | req["llm_id"] = llm.pop("model_name") | ||||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"): | |||||
| return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") | |||||
| req["llm_setting"] = req.pop("llm") | req["llm_setting"] = req.pop("llm") | ||||
| e, tenant = TenantService.get_by_id(tenant_id) | e, tenant = TenantService.get_by_id(tenant_id) | ||||
| if not e: | if not e: | ||||
| return get_error_data_result(retmsg="Tenant not found!") | return get_error_data_result(retmsg="Tenant not found!") | ||||
| if req.get("rerank_model"): | |||||
| if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_model"),model_type="rerank"): | |||||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_model')} doesn't exist") | |||||
| # prompt | # prompt | ||||
| prompt = req.get("prompt") | prompt = req.get("prompt") | ||||
| key_mapping = {"parameters": "variables", | key_mapping = {"parameters": "variables", | ||||
| req["prompt_config"] = req.pop("prompt") | req["prompt_config"] = req.pop("prompt") | ||||
| e, res = DialogService.get_by_id(chat_id) | e, res = DialogService.get_by_id(chat_id) | ||||
| res = res.to_json() | res = res.to_json() | ||||
| if "llm_id" in req: | |||||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||||
| return get_error_data_result(retmsg="The `model_name` does not exist.") | |||||
| if "name" in req: | if "name" in req: | ||||
| if not req.get("name"): | if not req.get("name"): | ||||
| return get_error_data_result(retmsg="`name` is not empty.") | return get_error_data_result(retmsg="`name` is not empty.") | ||||
    # avatar
    if "avatar" in req:
        req["icon"] = req.pop("avatar")
    if "datasets" in req:
        req.pop("datasets")
    if not DialogService.update_by_id(chat_id, req):
        return get_error_data_result(retmsg="Chat not found!")
    return get_result()
| return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}") | return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}") | ||||
| kb_list.append(kb[0].to_json()) | kb_list.append(kb[0].to_json()) | ||||
| del res["kb_ids"] | del res["kb_ids"] | ||||
| res["knowledgebases"] = kb_list | |||||
| res["datasets"] = kb_list | |||||
| res["avatar"] = res.pop("icon") | res["avatar"] = res.pop("icon") | ||||
| list_assts.append(res) | list_assts.append(res) | ||||
| return get_result(data=list_assts) | return get_result(data=list_assts) |
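
# --- Illustrative client call (not part of the diff) ---------------------
# A minimal sketch of the renamed `datasets` contract for POST /chat. The
# base URL, /api/v1 prefix, API_KEY, and dataset ID below are assumptions,
# not values taken from this PR.
import requests

resp = requests.post(
    "http://localhost:9380/api/v1/chat",
    headers={"Authorization": "Bearer API_KEY"},
    json={"name": "my_assistant", "datasets": ["<dataset_id>"]},  # was `knowledgebases`
)
print(resp.json())  # {"code": 0, "data": {..., "datasets": [...], "llm": {...}}}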
#
from flask import request

from api.db import StatusEnum, FileSource
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_result, token_required, get_error_data_result, valid, get_parser_config
@manager.route('/dataset', methods=['POST'])

    permission = req.get("permission")
    language = req.get("language")
    chunk_method = req.get("chunk_method")
    parser_config = req.get("parser_config")
    valid_permission = {"me", "team"}
    valid_language = {"Chinese", "English"}
    valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email"}
    check_validation = valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method)
    if check_validation:
        return check_validation
    req["parser_config"] = get_parser_config(chunk_method, parser_config)
    if "tenant_id" in req:
        return get_error_data_result(
            retmsg="`tenant_id` must not be provided")
    chunk_count = req.get("chunk_count")
    document_count = req.get("document_count")
    if chunk_count or document_count:

            retmsg="`name` can't be an empty string!")
    if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
        return get_error_data_result(
            retmsg="Duplicated dataset name in creating dataset.")
    req["tenant_id"] = req['created_by'] = tenant_id
    if not req.get("embedding_model"):
        req['embedding_model'] = t.embd_id
    else:
        if not TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model")):
            return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
    key_mapping = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
    permission = req.get("permission")
    language = req.get("language")
    chunk_method = req.get("chunk_method")
    parser_config = req.get("parser_config")
    valid_permission = {"me", "team"}
    valid_language = {"Chinese", "English"}
    valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
                          "knowledge_graph", "email"}
    check_validation = valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method)
    if check_validation:
        return check_validation
    if "tenant_id" in req:

            return get_error_data_result(
                retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable.")
        req['parser_id'] = req.pop('chunk_method')
        if req['parser_id'] != kb.parser_id:
            req["parser_config"] = get_parser_config(chunk_method, parser_config)
| if "embedding_model" in req: | if "embedding_model" in req: | ||||
| if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id: | if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id: | ||||
| return get_error_data_result( | return get_error_data_result( | ||||
| retmsg="If `chunk_count` is not 0, `embedding_method` is not changeable.") | retmsg="If `chunk_count` is not 0, `embedding_method` is not changeable.") | ||||
| if not req.get("embedding_model"): | |||||
| return get_error_data_result("`embedding_model` can't be empty") | |||||
| if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")): | |||||
| return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") | |||||
| req['embd_id'] = req.pop('embedding_model') | req['embd_id'] = req.pop('embedding_model') | ||||
| if "name" in req: | if "name" in req: | ||||
| req["name"] = req["name"].strip() | req["name"] = req["name"].strip() | ||||
| and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | ||||
| status=StatusEnum.VALID.value)) > 0: | status=StatusEnum.VALID.value)) > 0: | ||||
| return get_error_data_result( | return get_error_data_result( | ||||
| retmsg="Duplicated knowledgebase name in updating dataset.") | |||||
| retmsg="Duplicated dataset name in updating dataset.") | |||||
| if not KnowledgebaseService.update_by_id(kb.id, req): | if not KnowledgebaseService.update_by_id(kb.id, req): | ||||
| return get_error_data_result(retmsg="Update dataset error.(Database error)") | return get_error_data_result(retmsg="Update dataset error.(Database error)") | ||||
| return get_result(retcode=RetCode.SUCCESS) | return get_result(retcode=RetCode.SUCCESS) |
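
# --- Illustrative client call (not part of the diff) ---------------------
# Sketch of the validation surface for POST /dataset, under the same assumed
# base URL and API key as above. `permission`, `language`, and `chunk_method`
# must come from the valid_* sets; `tenant_id` must not be sent, and a missing
# `parser_config` is filled in server-side by get_parser_config().
import requests

payload = {
    "name": "contracts",
    "language": "English",    # one of {"Chinese", "English"}
    "permission": "me",       # one of {"me", "team"}
    "chunk_method": "naive",  # one of the twelve valid_chunk_method values
}
resp = requests.post("http://localhost:9380/api/v1/dataset",
                     headers={"Authorization": "Bearer API_KEY"}, json=payload)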
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler
from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH

MAXIMUM_OF_UPLOADING_FILES = 256
@manager.route('/dataset/<dataset_id>/document', methods=['POST'])
@token_required

    if file_obj.filename == '':
        return get_result(
            retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
    # total size
    total_size = 0
    for file_obj in file_objs:
        file_obj.seek(0, os.SEEK_END)
        total_size += file_obj.tell()
        file_obj.seek(0)
    MAX_TOTAL_FILE_SIZE = 10 * 1024 * 1024
    if total_size > MAX_TOTAL_FILE_SIZE:
        return get_result(
            retmsg=f'Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)',
            retcode=RetCode.ARGUMENT_ERROR)
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
    err, files = FileService.upload_document(kb, file_objs, tenant_id)
    if err:
        return get_result(
            retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
    # rename keys
    renamed_doc_list = []
    for file in files:
        doc = file[0]
        key_mapping = {
            "chunk_num": "chunk_count",
            "kb_id": "dataset_id",
            "token_num": "token_count",
            "parser_id": "chunk_method"
        }
        renamed_doc = {}
        for key, value in doc.items():
            new_key = key_mapping.get(key, key)
            renamed_doc[new_key] = value
        renamed_doc["run"] = "UNSTART"
        renamed_doc_list.append(renamed_doc)
    return get_result(data=renamed_doc_list)
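
# --- Illustrative response shape (not part of the diff) ------------------
# One entry of renamed_doc_list after the rename above; the values are made
# up and the field list is not exhaustive.
example_uploaded_doc = {
    "id": "<doc_id>",
    "dataset_id": "<kb_id>",   # renamed from kb_id
    "chunk_count": 0,          # renamed from chunk_num
    "token_count": 0,          # renamed from token_num
    "chunk_method": "naive",   # renamed from parser_id
    "run": "UNSTART",
}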
@manager.route('/dataset/<dataset_id>/info/<document_id>', methods=['PUT'])

    for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
        if d.name == req["name"]:
            return get_error_data_result(
                retmsg="Duplicated document name in the same dataset.")
    if not DocumentService.update_by_id(
            document_id, {"name": req["name"]}):
        return get_error_data_result(

    if "parser_config" in req:
        DocumentService.update_parser_config(doc.id, req["parser_config"])
    if "chunk_method" in req:
        valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email"}
        if req.get("chunk_method") not in valid_chunk_method:
            return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist")
        if doc.parser_id.lower() == req["chunk_method"].lower():
            return get_result()

                               "run": TaskStatus.UNSTART.value})
        if not e:
            return get_error_data_result(retmsg="Document not found!")
        req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config"))
        if doc.token_num > 0:
            e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                    doc.process_duation * -1)

    for doc in docs:
        key_mapping = {
            "chunk_num": "chunk_count",
            "kb_id": "dataset_id",
            "token_num": "token_count",
            "parser_id": "chunk_method"
        }
        run_mapping = {
            "0": "UNSTART",
            "1": "RUNNING",
            "2": "CANCEL",
            "3": "DONE",
            "4": "FAIL"
        }
        renamed_doc = {}
        for key, value in doc.items():
            if key == "run":
                renamed_doc["run"] = run_mapping.get(str(value))
                continue  # keep the mapped label; don't overwrite it with the raw value
            new_key = key_mapping.get(key, key)
            renamed_doc[new_key] = value
        renamed_doc_list.append(renamed_doc)
    return get_result(data=res)
| @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST']) | @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST']) | ||||
| @token_required | @token_required | ||||
| def create(tenant_id,dataset_id,document_id): | |||||
| def add_chunk(tenant_id,dataset_id,document_id): | |||||
| if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): | if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): | ||||
| return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") | return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") | ||||
| doc = DocumentService.query(id=document_id, kb_id=dataset_id) | doc = DocumentService.query(id=document_id, kb_id=dataset_id) | ||||
| return get_result() | return get_result() | ||||
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
@token_required
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):

    d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if "important_keywords" in req:
        if not isinstance(req["important_keywords"], list):
            return get_error_data_result("`important_keywords` should be a list")
        d["important_kwd"] = req.get("important_keywords")
        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
    if "available" in req:
        d["available_int"] = int(req["available"])
    embd_id = DocumentService.get_embd_id(document_id)
    embd_mdl = TenantLLMService.model_instance(
        tenant_id, LLMType.EMBEDDING.value, embd_id)

    return get_result()
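
# --- Illustrative client call (not part of the diff) ---------------------
# Updating a chunk through the endpoint above, under the same assumed base
# URL and API key. A boolean `available` is coerced server-side via
# int(req["available"]), so False becomes available_int = 0.
import requests

requests.put(
    "http://localhost:9380/api/v1/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>",
    headers={"Authorization": "Bearer API_KEY"},
    json={"important_keywords": ["keyword"], "available": False},
)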
@manager.route('/retrieval', methods=['POST'])
@token_required
def retrieval_test(tenant_id):

    if not req.get("datasets"):
        return get_error_data_result("`datasets` is required.")
    kb_ids = req["datasets"]
    if not isinstance(kb_ids, list):
        return get_error_data_result("`datasets` should be a list")
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))
    if len(embd_nms) != 1:

    if "question" not in req:
        return get_error_data_result("`question` is required.")
    page = int(req.get("offset", 1))
    size = int(req.get("limit", 1024))
    question = req["question"]
    doc_ids = req.get("documents", [])
    if not isinstance(doc_ids, list):
        return get_error_data_result("`documents` should be a list")
    doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
    for doc_id in doc_ids:
        if doc_id not in doc_ids_list:
            return get_error_data_result(f"You don't own the document {doc_id}")
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    try:
        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
            return get_error_data_result(retmsg="Dataset not found!")
        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

            "content": chunk["content_with_weight"],
            "document_id": chunk["doc_id"],
            "document_name": chunk["docnm_kwd"],
            "dataset_id": chunk["kb_id"],
            "image_id": chunk["img_id"],
            "similarity": chunk["similarity"],
            "vector_similarity": chunk["vector_similarity"],
# limitations under the License.
#
from api.db import StatusEnum, TenantPermission
from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant, Document
from api.db.services.common_service import CommonService


class KnowledgebaseService(CommonService):
    model = Knowledgebase

    @classmethod
    @DB.connection_context()
    def list_documents_by_ids(cls, kb_ids):
        doc_ids = cls.model.select(Document.id.alias("document_id")) \
            .join(Document, on=(cls.model.id == Document.kb_id)) \
            .where(cls.model.id.in_(kb_ids))
        doc_ids = list(doc_ids.dicts())
        doc_ids = [doc["document_id"] for doc in doc_ids]
        return doc_ids
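
    # --- Illustrative usage (not part of the diff) -----------------------
    # list_documents_by_ids flattens document IDs across datasets; roughly
    # the SQL the peewee query above compiles to (table/column names assumed
    # from the models):
    #   SELECT document.id AS document_id
    #   FROM knowledgebase JOIN document ON knowledgebase.id = document.kb_id
    #   WHERE knowledgebase.id IN (<kb_ids>)
    # e.g. KnowledgebaseService.list_documents_by_ids(["kb_a", "kb_b"])
    #   -> ["doc_1", "doc_2", ...]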
    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
def valid_parameter(parameter, valid_values):
    if parameter and parameter not in valid_values:
        return get_error_data_result(f"{parameter} not in {valid_values}")


def get_parser_config(chunk_method, parser_config):
    if parser_config:
        return parser_config
    if not chunk_method:
        chunk_method = "naive"
    key_mapping = {
        "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,
                  "layout_recognize": True, "raptor": {"use_raptor": False}},
        "qa": {"raptor": {"use_raptor": False}},
        "resume": None,
        "manual": {"raptor": {"use_raptor": False}},
        "table": None,
        "paper": {"raptor": {"use_raptor": False}},
        "book": {"raptor": {"use_raptor": False}},
        "laws": {"raptor": {"use_raptor": False}},
        "presentation": {"raptor": {"use_raptor": False}},
        "one": None,
        "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
                            "entity_types": ["organization", "person", "location", "event", "time"]}
    }
    parser_config = key_mapping[chunk_method]
    return parser_config
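
# --- Illustrative usage (not part of the diff) ---------------------------
# Quick sanity checks of get_parser_config: an explicit config wins, a known
# chunk_method gets its default, and a missing one falls back to "naive".
assert get_parser_config("table", {"x": 1}) == {"x": 1}
assert get_parser_config("qa", None) == {"raptor": {"use_raptor": False}}
assert get_parser_config(None, None)["chunk_token_num"] == 128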
| self.id = "" | self.id = "" | ||||
| self.name = "assistant" | self.name = "assistant" | ||||
| self.avatar = "path/to/avatar" | self.avatar = "path/to/avatar" | ||||
| self.knowledgebases = ["kb1"] | |||||
| self.datasets = ["kb1"] | |||||
| self.llm = Chat.LLM(rag, {}) | self.llm = Chat.LLM(rag, {}) | ||||
| self.prompt = Chat.Prompt(rag, {}) | self.prompt = Chat.Prompt(rag, {}) | ||||
| super().__init__(rag, res_dict) | super().__init__(rag, res_dict) |
        self.important_keywords = []
        self.create_time = ""
        self.create_timestamp = 0.0
        self.dataset_id = None
        self.document_name = ""
        self.document_id = ""
        self.available = True
        for k in list(res_dict.keys()):
            if k not in self.__dict__:
                res_dict.pop(k)

    def update(self, update_message: dict):
        res = self.put(f"/dataset/{self.dataset_id}/document/{self.document_id}/chunk/{self.id}", update_message)
        res = res.json()
        if res.get("code") != 0:
            raise Exception(res["message"])
class DataSet(Base):
    class ParserConfig(Base):
        def __init__(self, rag, res_dict):
            super().__init__(rag, res_dict)

    def __init__(self, rag, res_dict):

    def upload_documents(self, document_list: List[dict]):
        url = f"/dataset/{self.id}/document"
        files = [("file", (ele["displayed_name"], ele["blob"])) for ele in document_list]
        res = self.post(path=url, json=None, files=files)
        res = res.json()
        if res.get("code") == 0:
            doc_list = []
            for doc in res["data"]:
                document = Document(self.rag, doc)
                doc_list.append(document)
            return doc_list
        raise Exception(res.get("message"))

    def list_documents(self, id: str = None, keywords: str = None, offset: int = 1, limit: int = 1024, orderby: str = "create_time", desc: bool = True):
        res = self.get(f"/dataset/{self.id}/info", params={"id": id, "keywords": keywords, "offset": offset, "limit": limit, "orderby": orderby, "desc": desc})
class Document(Base):
    class ParserConfig(Base):
        def __init__(self, rag, res_dict):
            super().__init__(rag, res_dict)

    def __init__(self, rag, res_dict):
        self.id = ""
        self.name = ""
        self.thumbnail = None
        self.dataset_id = None
        self.chunk_method = "naive"
        self.parser_config = {"pages": [[1, 1000000]]}
        self.source_type = "local"
        self.type = ""

    def update(self, update_message: dict):
        res = self.put(f'/dataset/{self.dataset_id}/info/{self.id}',
                       update_message)
        res = res.json()
        if res.get("code") != 0:
            raise Exception(res["message"])

    def download(self):
        res = self.get(f"/dataset/{self.dataset_id}/document/{self.id}")
        try:
            res = res.json()
            raise Exception(res.get("message"))

    def list_chunks(self, offset=0, limit=30, keywords="", id: str = None):
        data = {"document_id": self.id, "keywords": keywords, "offset": offset, "limit": limit, "id": id}
        res = self.get(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', data)
        res = res.json()
        if res.get("code") == 0:
            chunks = []

        raise Exception(res.get("message"))

    def add_chunk(self, content: str, important_keywords: List[str] = []):
        res = self.post(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', {"content": content, "important_keywords": important_keywords})
        res = res.json()
        if res.get("code") == 0:
            return Chunk(self.rag, res["data"].get("chunk"))
        raise Exception(res.get("message"))

    def delete_chunks(self, ids: List[str]):
        res = self.rm(f"dataset/{self.dataset_id}/document/{self.id}/chunk", {"ids": ids})
        res = res.json()
        if res.get("code") != 0:
            raise Exception(res.get("message"))
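
    # --- Illustrative chunk round-trip (not part of the diff) ------------
    # `doc` below is a Document instance; add_chunk now also takes
    # important_keywords, and the server coerces `available` to an int.
    # chunk = doc.add_chunk(content="Paris is the capital of France.",
    #                       important_keywords=["Paris", "France"])
    # chunk.update({"available": False})   # stored as available_int = 0
    # doc.delete_chunks(ids=[chunk.id])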
| "content": chunk["content_with_weight"], | "content": chunk["content_with_weight"], | ||||
| "document_id": chunk["doc_id"], | "document_id": chunk["doc_id"], | ||||
| "document_name": chunk["docnm_kwd"], | "document_name": chunk["docnm_kwd"], | ||||
| "knowledgebase_id": chunk["kb_id"], | |||||
| "dataset_id": chunk["kb_id"], | |||||
| "image_id": chunk["img_id"], | "image_id": chunk["img_id"], | ||||
| "similarity": chunk["similarity"], | "similarity": chunk["similarity"], | ||||
| "vector_similarity": chunk["vector_similarity"], | "vector_similarity": chunk["vector_similarity"], | ||||
| self.content = None | self.content = None | ||||
| self.document_id = "" | self.document_id = "" | ||||
| self.document_name = "" | self.document_name = "" | ||||
| self.knowledgebase_id = "" | |||||
| self.dataset_id = "" | |||||
| self.image_id = "" | self.image_id = "" | ||||
| self.similarity = None | self.similarity = None | ||||
| self.vector_similarity = None | self.vector_similarity = None |
        return res

    def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
                       permission: str = "me", chunk_method: str = "naive",
                       parser_config: DataSet.ParserConfig = None) -> DataSet:
        res = self.post("/dataset",
                        {"name": name, "avatar": avatar, "description": description, "language": language,
                         "permission": permission, "chunk_method": chunk_method,
                         "parser_config": parser_config
                         }
                        )
            return result_list
        raise Exception(res["message"])

    def create_chat(self, name: str, avatar: str = "", datasets: List[DataSet] = [],
                    llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat:
        dataset_list = []
        for dataset in datasets:
            dataset_list.append(dataset.to_json())
        if llm is None:
            llm = Chat.LLM(self, {"model_name": None,

        temp_dict = {"name": name,
                     "avatar": avatar,
                     "datasets": dataset_list,
                     "llm": llm.to_json(),
                     "prompt": prompt.to_json()}
        res = self.post("/chat", temp_dict)

        raise Exception(res["message"])
| def retrieve(self, question="",datasets=None,documents=None, offset=1, limit=30, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,): | |||||
| data_params = { | |||||
| def retrieve(self, datasets,documents,question="", offset=1, limit=1024, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,): | |||||
| data_json ={ | |||||
| "offset": offset, | "offset": offset, | ||||
| "limit": limit, | "limit": limit, | ||||
| "similarity_threshold": similarity_threshold, | "similarity_threshold": similarity_threshold, | ||||
| "vector_similarity_weight": vector_similarity_weight, | "vector_similarity_weight": vector_similarity_weight, | ||||
| "top_k": top_k, | "top_k": top_k, | ||||
| "knowledgebase_id": datasets, | |||||
| "rerank_id":rerank_id, | |||||
| "keyword":keyword | |||||
| } | |||||
| data_json ={ | |||||
| "rerank_id": rerank_id, | |||||
| "keyword": keyword, | |||||
| "question": question, | "question": question, | ||||
| "datasets": datasets, | "datasets": datasets, | ||||
| "documents": documents | "documents": documents | ||||
| } | } | ||||
| # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) | # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) | ||||
| res = self.get(f'/retrieval', data_params,data_json) | |||||
| res = self.post(f'/retrieval',json=data_json) | |||||
| res = res.json() | res = res.json() | ||||
| if res.get("code") ==0: | if res.get("code") ==0: | ||||
| chunks=[] | chunks=[] |
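
    # --- Illustrative usage (not part of the diff) -----------------------
    # The new signature makes `datasets` and `documents` required positional
    # parameters (lists of dataset/document IDs):
    # chunks = rag.retrieve(datasets=[kb.id], documents=[],
    #                       question="What is AI?", limit=1024)
    # for c in chunks:
    #     print(c.document_name, c.similarity)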
from ragflow import RAGFlow, Chat
from common import API_KEY, HOST_ADDRESS
from test_sdkbase import TestSdk
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.create_dataset(name="test_create_chat") | kb = rag.create_dataset(name="test_create_chat") | ||||
| chat = rag.create_chat("test_create", knowledgebases=[kb]) | |||||
| chat = rag.create_chat("test_create", datasets=[kb]) | |||||
| if isinstance(chat, Chat): | if isinstance(chat, Chat): | ||||
| assert chat.name == "test_create", "Name does not match." | assert chat.name == "test_create", "Name does not match." | ||||
| else: | else: | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.create_dataset(name="test_update_chat") | kb = rag.create_dataset(name="test_update_chat") | ||||
| chat = rag.create_chat("test_update", knowledgebases=[kb]) | |||||
| chat = rag.create_chat("test_update", datasets=[kb]) | |||||
| if isinstance(chat, Chat): | if isinstance(chat, Chat): | ||||
| assert chat.name == "test_update", "Name does not match." | assert chat.name == "test_update", "Name does not match." | ||||
| res=chat.update({"name":"new_chat"}) | res=chat.update({"name":"new_chat"}) | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.create_dataset(name="test_delete_chat") | kb = rag.create_dataset(name="test_delete_chat") | ||||
| chat = rag.create_chat("test_delete", knowledgebases=[kb]) | |||||
| chat = rag.create_chat("test_delete", datasets=[kb]) | |||||
| if isinstance(chat, Chat): | if isinstance(chat, Chat): | ||||
| assert chat.name == "test_delete", "Name does not match." | assert chat.name == "test_delete", "Name does not match." | ||||
| res = rag.delete_chats(ids=[chat.id]) | res = rag.delete_chats(ids=[chat.id]) |
    def test_create_session(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        kb = rag.create_dataset(name="test_create_session")
        assistant = rag.create_chat(name="test_create_session", datasets=[kb])
        session = assistant.create_session()
        assert isinstance(session, Session), "Failed to create a session."

    def test_create_chat_with_success(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        kb = rag.create_dataset(name="test_create_chat")
        assistant = rag.create_chat(name="test_create_chat", datasets=[kb])
        session = assistant.create_session()
        question = "What is AI"
        for ans in session.ask(question, stream=True):
    def test_delete_sessions_with_success(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        kb = rag.create_dataset(name="test_delete_session")
        assistant = rag.create_chat(name="test_delete_session", datasets=[kb])
        session = assistant.create_session()
        res = assistant.delete_sessions(ids=[session.id])
        assert res is None, "Failed to delete the session."

    def test_update_session_with_success(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        kb = rag.create_dataset(name="test_update_session")
        assistant = rag.create_chat(name="test_update_session", datasets=[kb])
        session = assistant.create_session(name="old session")
        res = session.update({"name": "new session"})
        assert res is None, "Failed to update the session"

    def test_list_sessions_with_success(self):
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        kb = rag.create_dataset(name="test_list_session")
        assistant = rag.create_chat(name="test_list_session", datasets=[kb])
        assistant.create_session("test_1")
        assistant.create_session("test_2")
        sessions = assistant.list_sessions()