### What problem does this PR solve?

1. Module initialization no longer connects to the database.
2. Config values in `api.settings` must now be accessed through the module, e.g. `settings.CONFIG_NAME` after `from api import settings`, instead of importing individual names.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
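The change is mechanical, but the motivation deserves a note: `from api.settings import retrievaler` binds the name at import time, so any value assigned to `api.settings.retrievaler` *after* that import is invisible to the importer. Once importing the settings module stops connecting to the database (point 1), most settings are still unset at import time, so consumers have to go through the module object instead (point 2). A minimal sketch of the difference, using a stand-in module rather than RAGFlow's real `api.settings`:

```python
import types

# Stand-in for api/settings.py: importing it does no work, so nothing
# (database, doc store, retrievaler) is connected yet.
settings = types.ModuleType("settings")
settings.retrievaler = None

# Old pattern -- `from api.settings import retrievaler` copies the binding once.
retrievaler = settings.retrievaler  # captures None, permanently

def init_settings():
    # Stand-in for the explicit initialization that now happens after all
    # modules are imported (e.g. when the server process actually starts).
    settings.retrievaler = object()  # pretend this is a real retrieval client

init_settings()

print(retrievaler)           # None -- the early binding never saw the init
print(settings.retrievaler)  # the live object -- attribute lookup is late-bound
```

The hunks below apply that one pattern across the agent components and the API apps (plus some incidental PEP 8 whitespace and line-wrapping cleanups).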
Agent "Generate" component:

```diff
 from api.db import LLMType
 from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
 ...
     component_name = "Generate"

     def get_dependent_components(self):
-        cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
+        cpnts = [para["component_id"] for para in self._param.parameters if
+                 para.get("component_id") and para["component_id"].lower().find("answer") < 0]
        return cpnts

     def set_cite(self, retrieval_res, answer):
         retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
         if "empty_response" in retrieval_res.columns:
             retrieval_res["empty_response"].fillna("", inplace=True)
-        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
-                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
-                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                             self._canvas.get_embedding_model()), tkweight=0.7,
-                                                   vtweight=0.3)
+        answer, idx = settings.retrievaler.insert_citations(answer,
+                                                            [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                            [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                            LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                                      self._canvas.get_embedding_model()), tkweight=0.7,
+                                                            vtweight=0.3)
         doc_ids = set([])
         recall_docs = []
         for i in idx:
 ...
             else:
                 if cpn.component_name.lower() == "retrieval":
                     retrieval_res.append(out)
-                kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
+                kwargs[para["key"]] = " - " + "\n - ".join(
+                    [o if isinstance(o, str) else str(o) for o in out["content"]])
             self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})

         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else: retrieval_res = pd.DataFrame([])
+        else:
+            retrieval_res = pd.DataFrame([])

         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
```
Agent "Retrieval" component:

```diff
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
 ...
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

-        kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
+        kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                         1, self._param.top_n,
                                         self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
                                         aggs=False, rerank_mdl=rerank_mdl)
```
Flask app bootstrap (blueprint registration and request loader):

```diff
 from flask_session import Session
 from flask_login import LoginManager
-from api.settings import SECRET_KEY
-from api.settings import API_VERSION
+from api import settings
 from api.utils.api_utils import server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 ...
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)

 ## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 ...
     page_name = page_path.stem.rstrip("_app")
     module_name = ".".join(
-        page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
+        page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
     )
     spec = spec_from_file_location(module_name, page_path)
 ...
     spec.loader.exec_module(page)
     page_name = getattr(page, "page_name", page_name)
     url_prefix = (
-        f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
+        f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
     )
     app.register_blueprint(page.manager, url_prefix=url_prefix)
 ...
 @login_manager.request_loader
 def load_user(web_request):
-    jwt = Serializer(secret_key=SECRET_KEY)
+    jwt = Serializer(secret_key=settings.SECRET_KEY)
     authorization = web_request.headers.get("Authorization")
     if authorization:
         try:
```
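One detail worth noting in `load_user`: the serializer is constructed inside the function body, so `settings.SECRET_KEY` is read per request rather than frozen at import. A hedged sketch of the same token round-trip (the `settings` class and key value here are stand-ins for the real config, not RAGFlow code):

```python
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

class settings:
    # Stand-in for api.settings after deferred init has populated it.
    SECRET_KEY = "change-me"

def issue_token(user_id: str) -> str:
    # Build the serializer at call time, mirroring load_user(): it signs
    # with whatever SECRET_KEY holds *now*, not what it held at import time.
    jwt = Serializer(secret_key=settings.SECRET_KEY)
    return jwt.dumps({"user_id": user_id})

def read_token(token: str) -> dict:
    jwt = Serializer(secret_key=settings.SECRET_KEY)
    return jwt.loads(token)  # raises itsdangerous.BadSignature if tampered with

print(read_token(issue_token("u-123")))  # {'user_id': 'u-123'}
```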
Token-authenticated API endpoints:

```diff
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
     generate_confirmation_token
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     try:
         if objs[0].source == "agent":
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
 ...
         API4ConversationService.append_message(conv.id, conv.to_dict())
         rename_field(result)
         return get_json_result(data=result)

-    #******************For dialog******************
+    # ******************For dialog******************
     conv.message.append(msg[-1])
     e, dia = DialogService.get_by_id(conv.dialog_id)
     if not e:
 ...
         resp.headers.add_header("X-Accel-Buffering", "no")
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp

     answer = None
     for ans in chat(dia, msg, **req):
         answer = ans
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         e, conv = API4ConversationService.get_by_id(conversation_id)
         if not e:
 ...
         conv = conv.to_dict()
         if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
             return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
-                                   code=RetCode.AUTHENTICATION_ERROR)
+                                   code=settings.RetCode.AUTHENTICATION_ERROR)

         for referenct_i in conv['reference']:
             if referenct_i is None or len(referenct_i) == 0:
                 continue
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     kb_name = request.form.get("kb_name").strip()
     tenant_id = objs[0].tenant_id
 ...
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file = request.files['file']
     if file.filename == '':
         return get_json_result(
-            data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     root_folder = FileService.get_root_folder(tenant_id)
     pf_id = root_folder["id"]
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
     return get_json_result(data=doc_ids)
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json
 ...
         )
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+        res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
         res = [
             {
                 "content": res_item["content_with_weight"],
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json
     tenant_id = objs[0].tenant_id
 ...
     except Exception as e:
         return server_error_response(e)


 @manager.route('/document/infos', methods=['POST'])
 @validate_request("doc_ids")
 def docinfos():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     doc_ids = req["doc_ids"]
     docs = DocumentService.get_by_ids(doc_ids)
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     tenant_id = objs[0].tenant_id
     req = request.json
 ...
         errors += str(e)

     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
 ...
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json
-    kb_ids = req.get("kb_id",[])
+    kb_ids = req.get("kb_id", [])
     doc_ids = req.get("doc_ids", [])
     question = req.get("question")
     page = int(req.get("page", 1))
 ...
         embd_nms = list(set([kb.embd_id for kb in kbs]))
         if len(embd_nms) != 1:
             return get_json_result(
-                data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
+                data=False, message='Knowledge bases use different embedding models or does not exist."',
+                code=settings.RetCode.AUTHENTICATION_ERROR)

         embd_mdl = TenantLLMService.model_instance(
             kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)

         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
                 kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

         if req.get("keyword", False):
             chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)

-        ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
-                                      similarity_threshold, vector_similarity_weight, top,
-                                      doc_ids, rerank_mdl=rerank_mdl)
+        ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+                                               similarity_threshold, vector_similarity_weight, top,
+                                               doc_ids, rerank_mdl=rerank_mdl)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
```
Canvas (agent flow) endpoints:

```diff
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
 from agent.canvas import Canvas
 ...
 @login_required
 def canvas_list():
     return get_json_result(data=sorted([c.to_dict() for c in \
-        UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
+        UserCanvasService.query(user_id=current_user.id)],
+                           key=lambda x: x["update_time"] * -1)
     )
 ...
 @login_required
 def rm():
     for i in request.json["canvas_ids"]:
-        if not UserCanvasService.query(user_id=current_user.id,id=i):
+        if not UserCanvasService.query(user_id=current_user.id, id=i):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         UserCanvasService.delete_by_id(i)
     return get_json_result(data=True)
 ...
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     UserCanvasService.update_by_id(req["id"], req)
     return get_json_result(data=req)
 ...
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
 ...
     if "message" in req:
         canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
         if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
-            #ten = TenantService.get_info_by(current_user.id)[0]
-            #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
+            # ten = TenantService.get_info_by(current_user.id)[0]
+            # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
             pass
         canvas.add_user_input(req["message"])
 ...
     answer = canvas.run(stream=stream)
     assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."

     if stream:
-        assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
+        assert isinstance(answer,
+                          partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."

         def sse():
             nonlocal answer, cvs
 ...
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
     canvas.reset()
```
Chunk endpoints:

```diff
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 import hashlib
 import re


 @manager.route('/list', methods=['POST'])
 @login_required
 @validate_request("doc_id")
 ...
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
 ...
                 "positions": json.loads(sres.field[id].get("position_list", "[]")),
             }
             assert isinstance(d["positions"], list)
-            assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
+            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
             res["chunks"].append(d)
         return get_json_result(data=res)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 ...
         tenant_id = tenants[0].tenant_id
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response("Chunk not found")
         k = []
 ...
     except Exception as e:
         if str(e).find("NotFoundError") >= 0:
             return get_json_result(data=False, message='Chunk not found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 ...
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
 ...
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
-                                   search.index_name(doc.tenant_id), doc.kb_id):
+        if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
+                                            search.index_name(doc.tenant_id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:
 ...
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
+        if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         deleted_chunk_ids = req["chunk_ids"]
         chunk_number = len(deleted_chunk_ids)
 ...
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
 ...
         else:
             return get_json_result(
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)

         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
 ...
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)

-        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 ...
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
-        "doc_ids":[doc_id],
+        "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
+    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
 ...
         obj[ty] = content_json

     return get_json_result(data=obj)
```
Conversation endpoints:

```diff
 from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
 ...
         else:
             return get_json_result(
                 data=False, message='Only owner of conversation authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         conv = conv.to_dict()
         return get_json_result(data=conv)
     except Exception as e:
 ...
             else:
                 return get_json_result(
                     data=False, message='Only owner of conversation authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             ConversationService.delete_by_id(cid)
         return get_json_result(data=True)
     except Exception as e:
 ...
         if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
             return get_json_result(
                 data=False, message='Only owner of dialog authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         convs = ConversationService.query(
             dialog_id=dialog_id,
             order_by=ConversationService.model.create_time,
 ...
 def ask_about():
     req = request.json
     uid = current_user.id

     def stream():
         nonlocal req, uid
         try:
 ...
         embd_mdl = TenantLLMService.model_instance(
             kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

         chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)

-        ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                      0.3, 0.3, aggs=False)
+        ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                               0.3, 0.3, aggs=False)
         mindmap = MindMapExtractor(chat_mdl)
         mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
         if "error" in mind_map:
```
Dialog endpoints:

```diff
 from api.db import StatusEnum
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 ...
             else:
                 return get_json_result(
                     data=False, message='Only owner of dialog authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
         DialogService.update_many_by_id(dialog_list)
         return get_json_result(data=True)
```
Document endpoints:

```diff
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.settings import RetCode, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 from rag.utils.storage_factory import STORAGE_IMPL
 from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
 ...
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
 ...
     err, _ = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(
-            data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
+            data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)
 ...
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     name = request.form.get("name")
     url = request.form.get("url")
     if not is_valid_url(url):
         return get_json_result(
-            data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
 ...
     kb_id = req["kb_id"]
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)
 ...
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(
 ...
     else:
         return get_json_result(
             data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     keywords = request.args.get("keywords", "")
     page_number = int(request.args.get("page", 1))
 ...
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     docs = DocumentService.get_by_ids(doc_ids)
     return get_json_result(data=list(docs.dicts()))


 @manager.route('/thumbnails', methods=['GET'])
-#@login_required
+# @login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
         return get_json_result(
-            data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
     try:
         docs = DocumentService.get_thumbnails(doc_ids)
 ...
         return get_json_result(
             data=False,
             message='"Status" must be either 0 or 1!',
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)

     if not DocumentService.accessible(req["doc_id"], current_user.id):
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR)
+            code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
 ...
                 message="Database error (Document update)!")

         status = int(req["status"])
-        docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
+                                     search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
 ...
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     root_folder = FileService.get_root_folder(current_user.id)
 ...
         errors += str(e)

     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)
 ...
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         for id in req["doc_ids"]:
             e, doc = DocumentService.get_by_id(id)
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

             if str(req["run"]) == TaskStatus.RUNNING.value:
                 TaskService.filter_delete([Task.doc_id == id])
 ...
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
 ...
         return get_json_result(
             data=False,
             message="The extension of file can't be changed",
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
     for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
         if d.name == req["name"]:
             return get_data_error_result(
 ...
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
 ...
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-        if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
 ...
 def upload_and_parse():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
 ...
     if url:
         if not is_valid_url(url):
             return get_json_result(
-                data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
         download_path = os.path.join(get_project_base_directory(), "logs/downloads")
         os.makedirs(download_path, exist_ok=True)
         from selenium.webdriver import Chrome, ChromeOptions
 ...
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     txt = FileService.parse_docs(file_objs, current_user.id)
```
File-to-document endpoints:

```diff
 from api.utils import get_uuid
 from api.db import FileType
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import get_json_result
 ...
     file_ids = req["file_ids"]
     if not file_ids:
         return get_json_result(
-            data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
     try:
         for file_id in file_ids:
             informs = File2DocumentService.get_by_file_id(file_id)
```
| from api.db import FileType, FileSource | from api.db import FileType, FileSource | ||||
| from api.db.services import duplicate_name | from api.db.services import duplicate_name | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.settings import RetCode | |||||
| from api import settings | |||||
| from api.utils.api_utils import get_json_result | from api.utils.api_utils import get_json_result | ||||
| from api.utils.file_utils import filename_type | from api.utils.file_utils import filename_type | ||||
| from rag.utils.storage_factory import STORAGE_IMPL | from rag.utils.storage_factory import STORAGE_IMPL | ||||
| if 'file' not in request.files: | if 'file' not in request.files: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) | |||||
| data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) | |||||
| file_objs = request.files.getlist('file') | file_objs = request.files.getlist('file') | ||||
| for file_obj in file_objs: | for file_obj in file_objs: | ||||
| if file_obj.filename == '': | if file_obj.filename == '': | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) | |||||
| data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) | |||||
| file_res = [] | file_res = [] | ||||
| try: | try: | ||||
| for file_obj in file_objs: | for file_obj in file_objs: | ||||
| try: | try: | ||||
| if not FileService.is_parent_folder_exist(pf_id): | if not FileService.is_parent_folder_exist(pf_id): | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR) | |||||
| data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR) | |||||
| if FileService.query(name=req["name"], parent_id=pf_id): | if FileService.query(name=req["name"], parent_id=pf_id): | ||||
| return get_data_error_result( | return get_data_error_result( | ||||
| message="Duplicated folder name in the same folder.") | message="Duplicated folder name in the same folder.") | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message="The extension of file can't be changed", | message="The extension of file can't be changed", | ||||
| code=RetCode.ARGUMENT_ERROR) | |||||
| code=settings.RetCode.ARGUMENT_ERROR) | |||||
| for file in FileService.query(name=req["name"], pf_id=file.parent_id): | for file in FileService.query(name=req["name"], pf_id=file.parent_id): | ||||
| if file.name == req["name"]: | if file.name == req["name"]: | ||||
| return get_data_error_result( | return get_data_error_result( | 
| from api.db import StatusEnum, FileSource | from api.db import StatusEnum, FileSource | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.db_models import File | from api.db.db_models import File | ||||
| from api.settings import RetCode | |||||
| from api.utils.api_utils import get_json_result | from api.utils.api_utils import get_json_result | ||||
| from api.settings import docStoreConn | |||||
| from api import settings | |||||
| from rag.nlp import search | from rag.nlp import search | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message='No authorization.', | message='No authorization.', | ||||
| code=RetCode.AUTHENTICATION_ERROR | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR | |||||
| ) | ) | ||||
| try: | try: | ||||
| if not KnowledgebaseService.query( | if not KnowledgebaseService.query( | ||||
| created_by=current_user.id, id=req["kb_id"]): | created_by=current_user.id, id=req["kb_id"]): | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) | |||||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) | |||||
| e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) | e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) | ||||
| if not e: | if not e: | ||||
| else: | else: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='Only owner of knowledgebase authorized for this operation.', | data=False, message='Only owner of knowledgebase authorized for this operation.', | ||||
| code=RetCode.OPERATING_ERROR) | |||||
| code=settings.RetCode.OPERATING_ERROR) | |||||
| kb = KnowledgebaseService.get_detail(kb_id) | kb = KnowledgebaseService.get_detail(kb_id) | ||||
| if not kb: | if not kb: | ||||
| return get_data_error_result( | return get_data_error_result( | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message='No authorization.', | message='No authorization.', | ||||
| code=RetCode.AUTHENTICATION_ERROR | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR | |||||
| ) | ) | ||||
| try: | try: | ||||
| kbs = KnowledgebaseService.query( | kbs = KnowledgebaseService.query( | ||||
| created_by=current_user.id, id=req["kb_id"]) | created_by=current_user.id, id=req["kb_id"]) | ||||
| if not kbs: | if not kbs: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) | |||||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) | |||||
| for doc in DocumentService.query(kb_id=req["kb_id"]): | for doc in DocumentService.query(kb_id=req["kb_id"]): | ||||
| if not DocumentService.remove_document(doc, kbs[0].tenant_id): | if not DocumentService.remove_document(doc, kbs[0].tenant_id): | ||||
| message="Database error (Knowledgebase removal)!") | message="Database error (Knowledgebase removal)!") | ||||
| tenants = UserTenantService.query(user_id=current_user.id) | tenants = UserTenantService.query(user_id=current_user.id) | ||||
| for tenant in tenants: | for tenant in tenants: | ||||
| docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"]) | |||||
| settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"]) | |||||
| return get_json_result(data=True) | return get_json_result(data=True) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | 
| from flask import request | from flask import request | ||||
| from flask_login import login_required, current_user | from flask_login import login_required, current_user | ||||
| from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | ||||
| from api.settings import LIGHTEN | |||||
| from api import settings | |||||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | ||||
| from api.db import StatusEnum, LLMType | from api.db import StatusEnum, LLMType | ||||
| from api.db.db_models import TenantLLM | from api.db.db_models import TenantLLM | ||||
| @login_required | @login_required | ||||
| def list_app(): | def list_app(): | ||||
| self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] | self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] | ||||
| weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else [] | |||||
| weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else [] | |||||
| model_type = request.args.get("model_type") | model_type = request.args.get("model_type") | ||||
| try: | try: | ||||
| objs = TenantLLMService.query(tenant_id=current_user.id) | objs = TenantLLMService.query(tenant_id=current_user.id) | 
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| from flask import request | from flask import request | ||||
| from api.settings import RetCode | |||||
| from api import settings | |||||
| from api.db import StatusEnum | from api.db import StatusEnum | ||||
| from api.db.services.dialog_service import DialogService | from api.db.services.dialog_service import DialogService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| kbs = KnowledgebaseService.get_by_ids(ids) | kbs = KnowledgebaseService.get_by_ids(ids) | ||||
| embd_count = list(set([kb.embd_id for kb in kbs])) | embd_count = list(set([kb.embd_id for kb in kbs])) | ||||
| if len(embd_count) != 1: | if len(embd_count) != 1: | ||||
| return get_result(message='Datasets use different embedding models.', code=RetCode.AUTHENTICATION_ERROR) | |||||
| return get_result(message='Datasets use different embedding models.', code=settings.RetCode.AUTHENTICATION_ERROR) | |||||
| req["kb_ids"] = ids | req["kb_ids"] = ids | ||||
| # llm | # llm | ||||
| llm = req.get("llm") | llm = req.get("llm") | ||||
| if len(embd_count) != 1 : | if len(embd_count) != 1 : | ||||
| return get_result( | return get_result( | ||||
| message='Datasets use different embedding models.', | message='Datasets use different embedding models.', | ||||
| code=RetCode.AUTHENTICATION_ERROR) | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR) | |||||
| req["kb_ids"] = ids | req["kb_ids"] = ids | ||||
| llm = req.get("llm") | llm = req.get("llm") | ||||
| if llm: | if llm: | 
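For context on the guard both chat hunks touch: a chat may only bind datasets whose knowledge bases share one embedding model, checked by deduplicating `embd_id`. A toy version (dicts stand in for KB rows):

kbs = [{"embd_id": "bge-large-zh-v1.5@BAAI"},
       {"embd_id": "text-embedding-v2@Tongyi-Qianwen"}]
embd_count = list(set(kb["embd_id"] for kb in kbs))
assert len(embd_count) != 1    # mixed models, so the endpoint returns an error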
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import TenantLLMService, LLMService | from api.db.services.llm_service import TenantLLMService, LLMService | ||||
| from api.db.services.user_service import TenantService | from api.db.services.user_service import TenantService | ||||
| from api.settings import RetCode | |||||
| from api import settings | |||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from api.utils.api_utils import ( | from api.utils.api_utils import ( | ||||
| get_result, | get_result, | ||||
| File2DocumentService.delete_by_document_id(doc.id) | File2DocumentService.delete_by_document_id(doc.id) | ||||
| if not KnowledgebaseService.delete_by_id(id): | if not KnowledgebaseService.delete_by_id(id): | ||||
| return get_error_data_result(message="Delete dataset error.(Database error)") | return get_error_data_result(message="Delete dataset error.(Database error)") | ||||
| return get_result(code=RetCode.SUCCESS) | |||||
| return get_result(code=settings.RetCode.SUCCESS) | |||||
| @manager.route("/datasets/<dataset_id>", methods=["PUT"]) | @manager.route("/datasets/<dataset_id>", methods=["PUT"]) | ||||
| ) | ) | ||||
| if not KnowledgebaseService.update_by_id(kb.id, req): | if not KnowledgebaseService.update_by_id(kb.id, req): | ||||
| return get_error_data_result(message="Update dataset error.(Database error)") | return get_error_data_result(message="Update dataset error.(Database error)") | ||||
| return get_result(code=RetCode.SUCCESS) | |||||
| return get_result(code=settings.RetCode.SUCCESS) | |||||
| @manager.route("/datasets", methods=["GET"]) | @manager.route("/datasets", methods=["GET"]) | 
| from api.db import LLMType, ParserType | from api.db import LLMType, ParserType | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.settings import retrievaler, kg_retrievaler, RetCode | |||||
| from api import settings | |||||
| from api.utils.api_utils import validate_request, build_error_result, apikey_required | from api.utils.api_utils import validate_request, build_error_result, apikey_required | ||||
| e, kb = KnowledgebaseService.get_by_id(kb_id) | e, kb = KnowledgebaseService.get_by_id(kb_id) | ||||
| if not e: | if not e: | ||||
| return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) | |||||
| return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND) | |||||
| if kb.tenant_id != tenant_id: | if kb.tenant_id != tenant_id: | ||||
| return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) | |||||
| return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND) | |||||
| embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | ||||
| retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler | |||||
| retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler | |||||
| ranks = retr.retrieval( | ranks = retr.retrieval( | ||||
| question, | question, | ||||
| embd_mdl, | embd_mdl, | ||||
| if str(e).find("not_found") > 0: | if str(e).find("not_found") > 0: | ||||
| return build_error_result( | return build_error_result( | ||||
| message='No chunk found! Check the chunk status please!', | message='No chunk found! Check the chunk status please!', | ||||
| code=RetCode.NOT_FOUND | |||||
| code=settings.RetCode.NOT_FOUND | |||||
| ) | ) | ||||
| return build_error_result(message=str(e), code=RetCode.SERVER_ERROR) | |||||
| return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR) | 
| from rag.nlp import rag_tokenizer | from rag.nlp import rag_tokenizer | ||||
| from api.db import LLMType, ParserType | from api.db import LLMType, ParserType | ||||
| from api.db.services.llm_service import TenantLLMService | from api.db.services.llm_service import TenantLLMService | ||||
| from api.settings import kg_retrievaler | |||||
| from api import settings | |||||
| import hashlib | import hashlib | ||||
| import re | import re | ||||
| from api.utils.api_utils import token_required | from api.utils.api_utils import token_required | ||||
| from api.db.services.file2document_service import File2DocumentService | from api.db.services.file2document_service import File2DocumentService | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.settings import RetCode, retrievaler | |||||
| from api import settings | |||||
| from api.utils.api_utils import construct_json_result, get_parser_config | from api.utils.api_utils import construct_json_result, get_parser_config | ||||
| from rag.nlp import search | from rag.nlp import search | ||||
| from rag.utils import rmSpace | from rag.utils import rmSpace | ||||
| from api.settings import docStoreConn | |||||
| from rag.utils.storage_factory import STORAGE_IMPL | from rag.utils.storage_factory import STORAGE_IMPL | ||||
| import os | import os | ||||
| """ | """ | ||||
| if "file" not in request.files: | if "file" not in request.files: | ||||
| return get_error_data_result( | return get_error_data_result( | ||||
| message="No file part!", code=RetCode.ARGUMENT_ERROR | |||||
| message="No file part!", code=settings.RetCode.ARGUMENT_ERROR | |||||
| ) | ) | ||||
| file_objs = request.files.getlist("file") | file_objs = request.files.getlist("file") | ||||
| for file_obj in file_objs: | for file_obj in file_objs: | ||||
| if file_obj.filename == "": | if file_obj.filename == "": | ||||
| return get_result( | return get_result( | ||||
| message="No file selected!", code=RetCode.ARGUMENT_ERROR | |||||
| message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR | |||||
| ) | ) | ||||
| # total size | # total size | ||||
| total_size = 0 | total_size = 0 | ||||
| if total_size > MAX_TOTAL_FILE_SIZE: | if total_size > MAX_TOTAL_FILE_SIZE: | ||||
| return get_result( | return get_result( | ||||
| message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", | message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", | ||||
| code=RetCode.ARGUMENT_ERROR, | |||||
| code=settings.RetCode.ARGUMENT_ERROR, | |||||
| ) | ) | ||||
| e, kb = KnowledgebaseService.get_by_id(dataset_id) | e, kb = KnowledgebaseService.get_by_id(dataset_id) | ||||
| if not e: | if not e: | ||||
| raise LookupError(f"Can't find the dataset with ID {dataset_id}!") | raise LookupError(f"Can't find the dataset with ID {dataset_id}!") | ||||
| err, files = FileService.upload_document(kb, file_objs, tenant_id) | err, files = FileService.upload_document(kb, file_objs, tenant_id) | ||||
| if err: | if err: | ||||
| return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) | |||||
| return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR) | |||||
| # rename key's name | # rename key's name | ||||
| renamed_doc_list = [] | renamed_doc_list = [] | ||||
| for file in files: | for file in files: | ||||
| if "name" in req and req["name"] != doc.name: | if "name" in req and req["name"] != doc.name: | ||||
| if ( | if ( | ||||
| pathlib.Path(req["name"].lower()).suffix | |||||
| != pathlib.Path(doc.name.lower()).suffix | |||||
| pathlib.Path(req["name"].lower()).suffix | |||||
| != pathlib.Path(doc.name.lower()).suffix | |||||
| ): | ): | ||||
| return get_result( | return get_result( | ||||
| message="The extension of file can't be changed", | message="The extension of file can't be changed", | ||||
| code=RetCode.ARGUMENT_ERROR, | |||||
| code=settings.RetCode.ARGUMENT_ERROR, | |||||
| ) | ) | ||||
| for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): | for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): | ||||
| if d.name == req["name"]: | if d.name == req["name"]: | ||||
| ) | ) | ||||
| if not e: | if not e: | ||||
| return get_error_data_result(message="Document not found!") | return get_error_data_result(message="Document not found!") | ||||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||||
| return get_result() | return get_result() | ||||
| file_stream = STORAGE_IMPL.get(doc_id, doc_location) | file_stream = STORAGE_IMPL.get(doc_id, doc_location) | ||||
| if not file_stream: | if not file_stream: | ||||
| return construct_json_result( | return construct_json_result( | ||||
| message="This file is empty.", code=RetCode.DATA_ERROR | |||||
| message="This file is empty.", code=settings.RetCode.DATA_ERROR | |||||
| ) | ) | ||||
| file = BytesIO(file_stream) | file = BytesIO(file_stream) | ||||
| # Use send_file with a proper filename and MIME type | # Use send_file with a proper filename and MIME type | ||||
| errors += str(e) | errors += str(e) | ||||
| if errors: | if errors: | ||||
| return get_result(message=errors, code=RetCode.SERVER_ERROR) | |||||
| return get_result(message=errors, code=settings.RetCode.SERVER_ERROR) | |||||
| return get_result() | return get_result() | ||||
| info["chunk_num"] = 0 | info["chunk_num"] = 0 | ||||
| info["token_num"] = 0 | info["token_num"] = 0 | ||||
| DocumentService.update_by_id(id, info) | DocumentService.update_by_id(id, info) | ||||
| docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) | |||||
| settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) | |||||
| TaskService.filter_delete([Task.doc_id == id]) | TaskService.filter_delete([Task.doc_id == id]) | ||||
| e, doc = DocumentService.get_by_id(id) | e, doc = DocumentService.get_by_id(id) | ||||
| doc = doc.to_dict() | doc = doc.to_dict() | ||||
| ) | ) | ||||
| info = {"run": "2", "progress": 0, "chunk_num": 0} | info = {"run": "2", "progress": 0, "chunk_num": 0} | ||||
| DocumentService.update_by_id(id, info) | DocumentService.update_by_id(id, info) | ||||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||||
| return get_result() | return get_result() | ||||
| res = {"total": 0, "chunks": [], "doc": renamed_doc} | res = {"total": 0, "chunks": [], "doc": renamed_doc} | ||||
| origin_chunks = [] | origin_chunks = [] | ||||
| if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): | |||||
| sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) | |||||
| if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): | |||||
| sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, | |||||
| highlight=True) | |||||
| res["total"] = sres.total | res["total"] = sres.total | ||||
| sign = 0 | sign = 0 | ||||
| for id in sres.ids: | for id in sres.ids: | ||||
| v, c = embd_mdl.encode([doc.name, req["content"]]) | v, c = embd_mdl.encode([doc.name, req["content"]]) | ||||
| v = 0.1 * v[0] + 0.9 * v[1] | v = 0.1 * v[0] + 0.9 * v[1] | ||||
| d["q_%d_vec" % len(v)] = v.tolist() | d["q_%d_vec" % len(v)] = v.tolist() | ||||
| docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) | |||||
| settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) | |||||
| DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) | DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) | ||||
| # rename keys | # rename keys | ||||
| condition = {"doc_id": document_id} | condition = {"doc_id": document_id} | ||||
| if "chunk_ids" in req: | if "chunk_ids" in req: | ||||
| condition["id"] = req["chunk_ids"] | condition["id"] = req["chunk_ids"] | ||||
| chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) | |||||
| chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) | |||||
| if chunk_number != 0: | if chunk_number != 0: | ||||
| DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) | DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) | ||||
| if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): | if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): | ||||
| schema: | schema: | ||||
| type: object | type: object | ||||
| """ | """ | ||||
| chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) | |||||
| chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) | |||||
| if chunk is None: | if chunk is None: | ||||
| return get_error_data_result(f"Can't find this chunk {chunk_id}") | return get_error_data_result(f"Can't find this chunk {chunk_id}") | ||||
| if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): | if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): | ||||
| v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) | v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) | ||||
| v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | ||||
| d["q_%d_vec" % len(v)] = v.tolist() | d["q_%d_vec" % len(v)] = v.tolist() | ||||
| docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) | |||||
| settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) | |||||
| return get_result() | return get_result() | ||||
| if len(embd_nms) != 1: | if len(embd_nms) != 1: | ||||
| return get_result( | return get_result( | ||||
| message='Datasets use different embedding models.', | message='Datasets use different embedding models.', | ||||
| code=RetCode.AUTHENTICATION_ERROR, | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||||
| ) | ) | ||||
| if "question" not in req: | if "question" not in req: | ||||
| return get_error_data_result("`question` is required.") | return get_error_data_result("`question` is required.") | ||||
| chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) | chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) | ||||
| question += keyword_extraction(chat_mdl, question) | question += keyword_extraction(chat_mdl, question) | ||||
| retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler | |||||
| retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler | |||||
| ranks = retr.retrieval( | ranks = retr.retrieval( | ||||
| question, | question, | ||||
| embd_mdl, | embd_mdl, | ||||
| if str(e).find("not_found") > 0: | if str(e).find("not_found") > 0: | ||||
| return get_result( | return get_result( | ||||
| message="No chunk found! Check the chunk status please!", | message="No chunk found! Check the chunk status please!", | ||||
| code=RetCode.DATA_ERROR, | |||||
| code=settings.RetCode.DATA_ERROR, | |||||
| ) | ) | ||||
| return server_error_response(e) | |||||
| return server_error_response(e) | 
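The chunk-update hunk above stores a weighted vector: 10% of the document-title embedding plus 90% of the chunk-content embedding, except QA-parsed chunks, which keep only the content vector. A worked sketch (NumPy stands in for the embedding model output):

import numpy as np

v = np.array([[1.0, 0.0], [0.0, 1.0]])   # v[0] = title embedding, v[1] = content embedding
parser_id, QA = "naive", "qa"            # stand-ins for doc.parser_id and ParserType.QA
v = 0.1 * v[0] + 0.9 * v[1] if parser_id != QA else v[1]
d = {"q_%d_vec" % len(v): v.tolist()}    # the field name encodes the dimension, e.g. q_2_vec
assert d["q_2_vec"] == [0.1, 0.9]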
| from api.db.services.api_service import APITokenService | from api.db.services.api_service import APITokenService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.user_service import UserTenantService | from api.db.services.user_service import UserTenantService | ||||
| from api.settings import DATABASE_TYPE | |||||
| from api import settings | |||||
| from api.utils import current_timestamp, datetime_format | from api.utils import current_timestamp, datetime_format | ||||
| from api.utils.api_utils import ( | from api.utils.api_utils import ( | ||||
| get_json_result, | get_json_result, | ||||
| generate_confirmation_token, | generate_confirmation_token, | ||||
| ) | ) | ||||
| from api.versions import get_ragflow_version | from api.versions import get_ragflow_version | ||||
| from api.settings import docStoreConn | |||||
| from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE | from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| res = {} | res = {} | ||||
| st = timer() | st = timer() | ||||
| try: | try: | ||||
| res["doc_store"] = docStoreConn.health() | |||||
| res["doc_store"] = settings.docStoreConn.health() | |||||
| res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) | res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) | ||||
| except Exception as e: | except Exception as e: | ||||
| res["doc_store"] = { | res["doc_store"] = { | ||||
| try: | try: | ||||
| KnowledgebaseService.get_by_id("x") | KnowledgebaseService.get_by_id("x") | ||||
| res["database"] = { | res["database"] = { | ||||
| "database": DATABASE_TYPE.lower(), | |||||
| "database": settings.DATABASE_TYPE.lower(), | |||||
| "status": "green", | "status": "green", | ||||
| "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | ||||
| } | } | ||||
| except Exception as e: | except Exception as e: | ||||
| res["database"] = { | res["database"] = { | ||||
| "database": DATABASE_TYPE.lower(), | |||||
| "database": settings.DATABASE_TYPE.lower(), | |||||
| "status": "red", | "status": "red", | ||||
| "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | ||||
| "error": str(e), | "error": str(e), | 
| datetime_format, | datetime_format, | ||||
| ) | ) | ||||
| from api.db import UserTenantRole, FileType | from api.db import UserTenantRole, FileType | ||||
| from api.settings import ( | |||||
| RetCode, | |||||
| GITHUB_OAUTH, | |||||
| FEISHU_OAUTH, | |||||
| CHAT_MDL, | |||||
| EMBEDDING_MDL, | |||||
| ASR_MDL, | |||||
| IMAGE2TEXT_MDL, | |||||
| PARSERS, | |||||
| API_KEY, | |||||
| LLM_FACTORY, | |||||
| LLM_BASE_URL, | |||||
| RERANK_MDL, | |||||
| ) | |||||
| from api import settings | |||||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | from api.db.services.user_service import UserService, TenantService, UserTenantService | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.utils.api_utils import get_json_result, construct_response | from api.utils.api_utils import get_json_result, construct_response | ||||
| """ | """ | ||||
| if not request.json: | if not request.json: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" | |||||
| data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" | |||||
| ) | ) | ||||
| email = request.json.get("email", "") | email = request.json.get("email", "") | ||||
| if not users: | if not users: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| code=RetCode.AUTHENTICATION_ERROR, | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||||
| message=f"Email: {email} is not registered!", | message=f"Email: {email} is not registered!", | ||||
| ) | ) | ||||
| password = decrypt(password) | password = decrypt(password) | ||||
| except BaseException: | except BaseException: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, code=RetCode.SERVER_ERROR, message="Failed to decrypt password" | |||||
| data=False, code=settings.RetCode.SERVER_ERROR, message="Failed to decrypt password" | |||||
| ) | ) | ||||
| user = UserService.query_user(email, password) | user = UserService.query_user(email, password) | ||||
| else: | else: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| code=RetCode.AUTHENTICATION_ERROR, | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||||
| message="Email and password do not match!", | message="Email and password do not match!", | ||||
| ) | ) | ||||
| import requests | import requests | ||||
| res = requests.post( | res = requests.post( | ||||
| GITHUB_OAUTH.get("url"), | |||||
| settings.GITHUB_OAUTH.get("url"), | |||||
| data={ | data={ | ||||
| "client_id": GITHUB_OAUTH.get("client_id"), | |||||
| "client_secret": GITHUB_OAUTH.get("secret_key"), | |||||
| "client_id": settings.GITHUB_OAUTH.get("client_id"), | |||||
| "client_secret": settings.GITHUB_OAUTH.get("secret_key"), | |||||
| "code": request.args.get("code"), | "code": request.args.get("code"), | ||||
| }, | }, | ||||
| headers={"Accept": "application/json"}, | headers={"Accept": "application/json"}, | ||||
| import requests | import requests | ||||
| app_access_token_res = requests.post( | app_access_token_res = requests.post( | ||||
| FEISHU_OAUTH.get("app_access_token_url"), | |||||
| settings.FEISHU_OAUTH.get("app_access_token_url"), | |||||
| data=json.dumps( | data=json.dumps( | ||||
| { | { | ||||
| "app_id": FEISHU_OAUTH.get("app_id"), | |||||
| "app_secret": FEISHU_OAUTH.get("app_secret"), | |||||
| "app_id": settings.FEISHU_OAUTH.get("app_id"), | |||||
| "app_secret": settings.FEISHU_OAUTH.get("app_secret"), | |||||
| } | } | ||||
| ), | ), | ||||
| headers={"Content-Type": "application/json; charset=utf-8"}, | headers={"Content-Type": "application/json; charset=utf-8"}, | ||||
| return redirect("/?error=%s" % app_access_token_res) | return redirect("/?error=%s" % app_access_token_res) | ||||
| res = requests.post( | res = requests.post( | ||||
| FEISHU_OAUTH.get("user_access_token_url"), | |||||
| settings.FEISHU_OAUTH.get("user_access_token_url"), | |||||
| data=json.dumps( | data=json.dumps( | ||||
| { | { | ||||
| "grant_type": FEISHU_OAUTH.get("grant_type"), | |||||
| "grant_type": settings.FEISHU_OAUTH.get("grant_type"), | |||||
| "code": request.args.get("code"), | "code": request.args.get("code"), | ||||
| } | } | ||||
| ), | ), | ||||
| if request_data.get("password"): | if request_data.get("password"): | ||||
| new_password = request_data.get("new_password") | new_password = request_data.get("new_password") | ||||
| if not check_password_hash( | if not check_password_hash( | ||||
| current_user.password, decrypt(request_data["password"]) | |||||
| current_user.password, decrypt(request_data["password"]) | |||||
| ): | ): | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| code=RetCode.AUTHENTICATION_ERROR, | |||||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||||
| message="Password error!", | message="Password error!", | ||||
| ) | ) | ||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception(e) | logging.exception(e) | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR | |||||
| data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR | |||||
| ) | ) | ||||
| tenant = { | tenant = { | ||||
| "id": user_id, | "id": user_id, | ||||
| "name": user["nickname"] + "‘s Kingdom", | "name": user["nickname"] + "‘s Kingdom", | ||||
| "llm_id": CHAT_MDL, | |||||
| "embd_id": EMBEDDING_MDL, | |||||
| "asr_id": ASR_MDL, | |||||
| "parser_ids": PARSERS, | |||||
| "img2txt_id": IMAGE2TEXT_MDL, | |||||
| "rerank_id": RERANK_MDL, | |||||
| "llm_id": settings.CHAT_MDL, | |||||
| "embd_id": settings.EMBEDDING_MDL, | |||||
| "asr_id": settings.ASR_MDL, | |||||
| "parser_ids": settings.PARSERS, | |||||
| "img2txt_id": settings.IMAGE2TEXT_MDL, | |||||
| "rerank_id": settings.RERANK_MDL, | |||||
| } | } | ||||
| usr_tenant = { | usr_tenant = { | ||||
| "tenant_id": user_id, | "tenant_id": user_id, | ||||
| "location": "", | "location": "", | ||||
| } | } | ||||
| tenant_llm = [] | tenant_llm = [] | ||||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||||
| for llm in LLMService.query(fid=settings.LLM_FACTORY): | |||||
| tenant_llm.append( | tenant_llm.append( | ||||
| { | { | ||||
| "tenant_id": user_id, | "tenant_id": user_id, | ||||
| "llm_factory": LLM_FACTORY, | |||||
| "llm_factory": settings.LLM_FACTORY, | |||||
| "llm_name": llm.llm_name, | "llm_name": llm.llm_name, | ||||
| "model_type": llm.model_type, | "model_type": llm.model_type, | ||||
| "api_key": API_KEY, | |||||
| "api_base": LLM_BASE_URL, | |||||
| "api_key": settings.API_KEY, | |||||
| "api_base": settings.LLM_BASE_URL, | |||||
| } | } | ||||
| ) | ) | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message=f"Invalid email address: {email_address}!", | message=f"Invalid email address: {email_address}!", | ||||
| code=RetCode.OPERATING_ERROR, | |||||
| code=settings.RetCode.OPERATING_ERROR, | |||||
| ) | ) | ||||
| # Check if the email address is already used | # Check if the email address is already used | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message=f"Email: {email_address} has already registered!", | message=f"Email: {email_address} has already registered!", | ||||
| code=RetCode.OPERATING_ERROR, | |||||
| code=settings.RetCode.OPERATING_ERROR, | |||||
| ) | ) | ||||
| # Construct user info data | # Construct user info data | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, | data=False, | ||||
| message=f"User registration failure, error: {str(e)}", | message=f"User registration failure, error: {str(e)}", | ||||
| code=RetCode.EXCEPTION_ERROR, | |||||
| code=settings.RetCode.EXCEPTION_ERROR, | |||||
| ) | ) | ||||
| ) | ) | ||||
| from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase | from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase | ||||
| from api.db import SerializedType, ParserType | from api.db import SerializedType, ParserType | ||||
| from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE | |||||
| from api import settings | |||||
| from api import utils | from api import utils | ||||
| def singleton(cls, *args, **kw): | def singleton(cls, *args, **kw): | ||||
| class LongTextField(TextField): | class LongTextField(TextField): | ||||
| field_type = TextFieldType[DATABASE_TYPE.upper()].value | |||||
| field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value | |||||
| class JSONField(LongTextField): | class JSONField(LongTextField): | ||||
| @singleton | @singleton | ||||
| class BaseDataBase: | class BaseDataBase: | ||||
| def __init__(self): | def __init__(self): | ||||
| database_config = DATABASE.copy() | |||||
| database_config = settings.DATABASE.copy() | |||||
| db_name = database_config.pop("name") | db_name = database_config.pop("name") | ||||
| self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config) | |||||
| self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config) | |||||
| logging.info('init database on cluster mode successfully') | logging.info('init database on cluster mode successfully') | ||||
| class PostgresDatabaseLock: | class PostgresDatabaseLock: | ||||
| DB = BaseDataBase().database_connection | DB = BaseDataBase().database_connection | ||||
| DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value | |||||
| DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value | |||||
| def close_connection(): | def close_connection(): | ||||
| return self.email | return self.email | ||||
| def get_id(self): | def get_id(self): | ||||
| jwt = Serializer(secret_key=SECRET_KEY) | |||||
| jwt = Serializer(secret_key=settings.SECRET_KEY) | |||||
| return jwt.dumps(str(self.access_token)) | return jwt.dumps(str(self.access_token)) | ||||
| class Meta: | class Meta: | ||||
| def migrate_db(): | def migrate_db(): | ||||
| with DB.transaction(): | with DB.transaction(): | ||||
| migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB) | |||||
| migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) | |||||
| try: | try: | ||||
| migrate( | migrate( | ||||
| migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", | migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", | 
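`db_models.py` keeps its enum-dispatch idiom: the config string `settings.DATABASE_TYPE` names an enum member whose `.value` is the class to instantiate (`PooledDatabase`, `DatabaseMigrator`, and `TextFieldType` all follow it). A hedged sketch with stand-in values, not the real playhouse classes:

from enum import Enum

class PooledDatabase(Enum):
    MYSQL = dict      # the real members hold PooledMySQLDatabase
    POSTGRES = list   # and PooledPostgresqlDatabase from playhouse.pool

DATABASE_TYPE = "mysql"                    # settings.DATABASE_TYPE in the real code
db_cls = PooledDatabase[DATABASE_TYPE.upper()].value
conn = db_cls(name="rag_flow")             # real code: db_cls(db_name, **database_config)
assert conn == {"name": "rag_flow"}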
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | ||||
| from api.db.services.user_service import TenantService, UserTenantService | from api.db.services.user_service import TenantService, UserTenantService | ||||
| from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL | |||||
| from api import settings | |||||
| from api.utils.file_utils import get_project_base_directory | from api.utils.file_utils import get_project_base_directory | ||||
| tenant = { | tenant = { | ||||
| "id": user_info["id"], | "id": user_info["id"], | ||||
| "name": user_info["nickname"] + "‘s Kingdom", | "name": user_info["nickname"] + "‘s Kingdom", | ||||
| "llm_id": CHAT_MDL, | |||||
| "embd_id": EMBEDDING_MDL, | |||||
| "asr_id": ASR_MDL, | |||||
| "parser_ids": PARSERS, | |||||
| "img2txt_id": IMAGE2TEXT_MDL | |||||
| "llm_id": settings.CHAT_MDL, | |||||
| "embd_id": settings.EMBEDDING_MDL, | |||||
| "asr_id": settings.ASR_MDL, | |||||
| "parser_ids": settings.PARSERS, | |||||
| "img2txt_id": settings.IMAGE2TEXT_MDL | |||||
| } | } | ||||
| usr_tenant = { | usr_tenant = { | ||||
| "tenant_id": user_info["id"], | "tenant_id": user_info["id"], | ||||
| "role": UserTenantRole.OWNER | "role": UserTenantRole.OWNER | ||||
| } | } | ||||
| tenant_llm = [] | tenant_llm = [] | ||||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||||
| for llm in LLMService.query(fid=settings.LLM_FACTORY): | |||||
| tenant_llm.append( | tenant_llm.append( | ||||
| {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, | |||||
| "api_key": API_KEY, "api_base": LLM_BASE_URL}) | |||||
| {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name, | |||||
| "model_type": llm.model_type, | |||||
| "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL}) | |||||
| if not UserService.save(**user_info): | if not UserService.save(**user_info): | ||||
| logging.error("can't init admin.") | logging.error("can't init admin.") | ||||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | ||||
| msg = chat_mdl.chat(system="", history=[ | msg = chat_mdl.chat(system="", history=[ | ||||
| {"role": "user", "content": "Hello!"}], gen_conf={}) | |||||
| {"role": "user", "content": "Hello!"}], gen_conf={}) | |||||
| if msg.find("ERROR: ") == 0: | if msg.find("ERROR: ") == 0: | ||||
| logging.error( | logging.error( | ||||
| "'{}' dosen't work. {}".format( | "'{}' dosen't work. {}".format( | ||||
| start_time = time.time() | start_time = time.time() | ||||
| init_llm_factory() | init_llm_factory() | ||||
| #if not UserService.get_all().count(): | |||||
| # if not UserService.get_all().count(): | |||||
| # init_superuser() | # init_superuser() | ||||
| add_graph_templates() | add_graph_templates() | 
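`init_data.py` also smoke-tests the default chat model: it sends a one-message history and treats a reply starting with "ERROR: " as a broken factory configuration. A stand-in sketch of that check (FakeChat replaces `LLMBundle`):

class FakeChat:                            # stand-in for LLMBundle(tenant_id, LLMType.CHAT, ...)
    def chat(self, system, history, gen_conf):
        return "ERROR: invalid api key"

msg = FakeChat().chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
    print("default chat model doesn't work:", msg)   # the real code logs and continues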
| from api.db.services.common_service import CommonService | from api.db.services.common_service import CommonService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle | from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle | ||||
| from api.settings import retrievaler, kg_retrievaler | |||||
| from api import settings | |||||
| from rag.app.resume import forbidden_select_fields4resume | from rag.app.resume import forbidden_select_fields4resume | ||||
| from rag.nlp.search import index_name | from rag.nlp.search import index_name | ||||
| from rag.utils import rmSpace, num_tokens_from_string, encoder | from rag.utils import rmSpace, num_tokens_from_string, encoder | ||||
| return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} | return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} | ||||
| is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | ||||
| retr = retrievaler if not is_kg else kg_retrievaler | |||||
| retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | |||||
| questions = [m["content"] for m in messages if m["role"] == "user"][-3:] | questions = [m["content"] for m in messages if m["role"] == "user"][-3:] | ||||
| attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None | attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None | ||||
| logging.debug(f"{question} get SQL(refined): {sql}") | logging.debug(f"{question} get SQL(refined): {sql}") | ||||
| tried_times += 1 | tried_times += 1 | ||||
| return retrievaler.sql_retrieval(sql, format="json"), sql | |||||
| return settings.retrievaler.sql_retrieval(sql, format="json"), sql | |||||
| tbl, sql = get_table() | tbl, sql = get_table() | ||||
| if tbl is None: | if tbl is None: | ||||
| embd_nms = list(set([kb.embd_id for kb in kbs])) | embd_nms = list(set([kb.embd_id for kb in kbs])) | ||||
| is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | ||||
| retr = retrievaler if not is_kg else kg_retrievaler | |||||
| retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | |||||
| embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) | embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) | ||||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | 
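Both dialog-service hunks repeat the retriever choice: only when every knowledge base is knowledge-graph-parsed does the code route to `settings.kg_retrievaler`; otherwise it uses `settings.retrievaler`. A toy version with stand-in objects:

class KB:
    def __init__(self, parser_id):
        self.parser_id = parser_id

KG = "knowledge_graph"                        # stand-in for ParserType.KG
retrievaler, kg_retrievaler = "dense", "kg"   # stand-ins for the settings singletons
kbs = [KB("naive"), KB("naive")]
is_kg = all(kb.parser_id == KG for kb in kbs)
retr = retrievaler if not is_kg else kg_retrievaler
assert retr == "dense"    # a single non-KG knowledge base keeps the default retriever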
| from peewee import fn | from peewee import fn | ||||
| from api.db.db_utils import bulk_insert_into_db | from api.db.db_utils import bulk_insert_into_db | ||||
| from api.settings import docStoreConn | |||||
| from api import settings | |||||
| from api.utils import current_timestamp, get_format_time, get_uuid | from api.utils import current_timestamp, get_format_time, get_uuid | ||||
| from graphrag.mind_map_extractor import MindMapExtractor | from graphrag.mind_map_extractor import MindMapExtractor | ||||
| from rag.settings import SVR_QUEUE_NAME | from rag.settings import SVR_QUEUE_NAME | ||||
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def remove_document(cls, doc, tenant_id): | def remove_document(cls, doc, tenant_id): | ||||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) | |||||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) | |||||
| cls.clear_chunk_num(doc.id) | cls.clear_chunk_num(doc.id) | ||||
| return cls.delete_by_id(doc.id) | return cls.delete_by_id(doc.id) | ||||
| d["q_%d_vec" % len(v)] = v | d["q_%d_vec" % len(v)] = v | ||||
| for b in range(0, len(cks), es_bulk_size): | for b in range(0, len(cks), es_bulk_size): | ||||
| if try_create_idx: | if try_create_idx: | ||||
| if not docStoreConn.indexExist(idxnm, kb_id): | |||||
| docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) | |||||
| if not settings.docStoreConn.indexExist(idxnm, kb_id): | |||||
| settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) | |||||
| try_create_idx = False | try_create_idx = False | ||||
| docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) | |||||
| settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) | |||||
| DocumentService.increment_chunk_num( | DocumentService.increment_chunk_num( | ||||
| doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) | doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) | 
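The bulk-insert loop above lazily creates the index on first use: before the first batch it checks `indexExist` and calls `createIdx` with the vector dimension, then inserts chunks in `es_bulk_size` batches. A runnable sketch with an in-memory stand-in for `settings.docStoreConn`:

class FakeDocStore:                        # in-memory stand-in for settings.docStoreConn
    def __init__(self):
        self.indexes = {}
    def indexExist(self, idxnm, kb_id):
        return (idxnm, kb_id) in self.indexes
    def createIdx(self, idxnm, kb_id, vector_size):
        self.indexes[(idxnm, kb_id)] = []
    def insert(self, chunks, idxnm, kb_id):
        self.indexes[(idxnm, kb_id)].extend(chunks)

store, idxnm, kb_id = FakeDocStore(), "ragflow_tenant1", "kb1"
cks, es_bulk_size, try_create_idx = [{"id": i} for i in range(10)], 4, True
for b in range(0, len(cks), es_bulk_size):
    if try_create_idx:
        if not store.indexExist(idxnm, kb_id):
            store.createIdx(idxnm, kb_id, 1024)   # 1024 stands in for len(vects[0])
        try_create_idx = False
    store.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
assert len(store.indexes[(idxnm, kb_id)]) == 10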
| from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
| from werkzeug.serving import run_simple | from werkzeug.serving import run_simple | ||||
| from api import settings | |||||
| from api.apps import app | from api.apps import app | ||||
| from api.db.runtime_config import RuntimeConfig | from api.db.runtime_config import RuntimeConfig | ||||
| from api.db.services.document_service import DocumentService | from api.db.services.document_service import DocumentService | ||||
| from api.settings import ( | |||||
| HOST, HTTP_PORT | |||||
| ) | |||||
| from api import utils | from api import utils | ||||
| from api.db.db_models import init_database_tables as init_web_db | from api.db.db_models import init_database_tables as init_web_db | ||||
| f'project base: {utils.file_utils.get_project_base_directory()}' | f'project base: {utils.file_utils.get_project_base_directory()}' | ||||
| ) | ) | ||||
| show_configs() | show_configs() | ||||
| settings.init_settings() | |||||
| # init db | # init db | ||||
| init_web_db() | init_web_db() | ||||
| logging.info("run on debug mode") | logging.info("run on debug mode") | ||||
| RuntimeConfig.init_env() | RuntimeConfig.init_env() | ||||
| RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) | |||||
| RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT) | |||||
| thread = ThreadPoolExecutor(max_workers=1) | thread = ThreadPoolExecutor(max_workers=1) | ||||
| thread.submit(update_progress) | thread.submit(update_progress) | ||||
| try: | try: | ||||
| logging.info("RAGFlow HTTP server start...") | logging.info("RAGFlow HTTP server start...") | ||||
| run_simple( | run_simple( | ||||
| hostname=HOST, | |||||
| port=HTTP_PORT, | |||||
| hostname=settings.HOST_IP, | |||||
| port=settings.HOST_PORT, | |||||
| application=app, | application=app, | ||||
| threaded=True, | threaded=True, | ||||
| use_reloader=RuntimeConfig.DEBUG, | use_reloader=RuntimeConfig.DEBUG, | 
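The server hunk fixes the boot order: `settings.init_settings()` runs right after `show_configs()`, so everything downstream (`init_web_db()`, `RuntimeConfig`, `run_simple`) reads initialized values such as `settings.HOST_IP` and `settings.HOST_PORT`. A minimal runnable sketch of the serving call, with a trivial WSGI app standing in for `api.apps.app`:

from werkzeug.serving import run_simple

def app(environ, start_response):          # trivial WSGI app, stand-in for api.apps.app
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"pong"]

if __name__ == "__main__":
    # host/port would come from settings.HOST_IP / settings.HOST_PORT after init_settings()
    run_simple(hostname="127.0.0.1", port=9380, application=app,
               threaded=True, use_reloader=False)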
| REQUEST_WAIT_SEC = 2 | REQUEST_WAIT_SEC = 2 | ||||
| REQUEST_MAX_WAIT_SEC = 300 | REQUEST_MAX_WAIT_SEC = 300 | ||||
| LLM = get_base_config("user_default_llm", {}) | |||||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||||
| LLM_BASE_URL = LLM.get("base_url") | |||||
| CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = "" | |||||
| if not LIGHTEN: | |||||
| default_llm = { | |||||
| "Tongyi-Qianwen": { | |||||
| "chat_model": "qwen-plus", | |||||
| "embedding_model": "text-embedding-v2", | |||||
| "image2text_model": "qwen-vl-max", | |||||
| "asr_model": "paraformer-realtime-8k-v1", | |||||
| }, | |||||
| "OpenAI": { | |||||
| "chat_model": "gpt-3.5-turbo", | |||||
| "embedding_model": "text-embedding-ada-002", | |||||
| "image2text_model": "gpt-4-vision-preview", | |||||
| "asr_model": "whisper-1", | |||||
| }, | |||||
| "Azure-OpenAI": { | |||||
| "chat_model": "gpt-35-turbo", | |||||
| "embedding_model": "text-embedding-ada-002", | |||||
| "image2text_model": "gpt-4-vision-preview", | |||||
| "asr_model": "whisper-1", | |||||
| }, | |||||
| "ZHIPU-AI": { | |||||
| "chat_model": "glm-3-turbo", | |||||
| "embedding_model": "embedding-2", | |||||
| "image2text_model": "glm-4v", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "Ollama": { | |||||
| "chat_model": "qwen-14B-chat", | |||||
| "embedding_model": "flag-embedding", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "Moonshot": { | |||||
| "chat_model": "moonshot-v1-8k", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "DeepSeek": { | |||||
| "chat_model": "deepseek-chat", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "VolcEngine": { | |||||
| "chat_model": "", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "BAAI": { | |||||
| "chat_model": "", | |||||
| "embedding_model": "BAAI/bge-large-zh-v1.5", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| "rerank_model": "BAAI/bge-reranker-v2-m3", | |||||
| } | |||||
| } | |||||
| if LLM_FACTORY: | |||||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" | |||||
| ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" | |||||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" | |||||
| EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" | |||||
| RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" | |||||
| API_KEY = LLM.get("api_key", "") | |||||
| PARSERS = LLM.get( | |||||
| "parsers", | |||||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") | |||||
| HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||||
| HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||||
| SECRET_KEY = get_base_config( | |||||
| RAG_FLOW_SERVICE_NAME, | |||||
| {}).get("secret_key", str(date.today())) | |||||
| LLM = None | |||||
| LLM_FACTORY = None | |||||
| LLM_BASE_URL = None | |||||
| CHAT_MDL = "" | |||||
| EMBEDDING_MDL = "" | |||||
| RERANK_MDL = "" | |||||
| ASR_MDL = "" | |||||
| IMAGE2TEXT_MDL = "" | |||||
| API_KEY = None | |||||
| PARSERS = None | |||||
| HOST_IP = None | |||||
| HOST_PORT = None | |||||
| SECRET_KEY = None | |||||
| DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') | DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') | ||||
| DATABASE = decrypt_database_config(name=DATABASE_TYPE) | DATABASE = decrypt_database_config(name=DATABASE_TYPE) | ||||
| # authentication | # authentication | ||||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||||
| AUTHENTICATION_CONF = None | |||||
| # client | # client | ||||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||||
| "client", {}).get( | |||||
| "switch", False) | |||||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||||
| FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") | |||||
| DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") | |||||
| if DOC_ENGINE == "elasticsearch": | |||||
| docStoreConn = rag.utils.es_conn.ESConnection() | |||||
| elif DOC_ENGINE == "infinity": | |||||
| docStoreConn = rag.utils.infinity_conn.InfinityConnection() | |||||
| else: | |||||
| raise Exception(f"Not supported doc engine: {DOC_ENGINE}") | |||||
| retrievaler = search.Dealer(docStoreConn) | |||||
| kg_retrievaler = kg_search.KGSearch(docStoreConn) | |||||
| CLIENT_AUTHENTICATION = None | |||||
| HTTP_APP_KEY = None | |||||
| GITHUB_OAUTH = None | |||||
| FEISHU_OAUTH = None | |||||
| DOC_ENGINE = None | |||||
| docStoreConn = None | |||||
| retrievaler = None | |||||
| kg_retrievaler = None | |||||
| def init_settings(): | |||||
| global LLM, LLM_FACTORY, LLM_BASE_URL | |||||
| LLM = get_base_config("user_default_llm", {}) | |||||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||||
| LLM_BASE_URL = LLM.get("base_url") | |||||
| global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL | |||||
| if not LIGHTEN: | |||||
| default_llm = { | |||||
| "Tongyi-Qianwen": { | |||||
| "chat_model": "qwen-plus", | |||||
| "embedding_model": "text-embedding-v2", | |||||
| "image2text_model": "qwen-vl-max", | |||||
| "asr_model": "paraformer-realtime-8k-v1", | |||||
| }, | |||||
| "OpenAI": { | |||||
| "chat_model": "gpt-3.5-turbo", | |||||
| "embedding_model": "text-embedding-ada-002", | |||||
| "image2text_model": "gpt-4-vision-preview", | |||||
| "asr_model": "whisper-1", | |||||
| }, | |||||
| "Azure-OpenAI": { | |||||
| "chat_model": "gpt-35-turbo", | |||||
| "embedding_model": "text-embedding-ada-002", | |||||
| "image2text_model": "gpt-4-vision-preview", | |||||
| "asr_model": "whisper-1", | |||||
| }, | |||||
| "ZHIPU-AI": { | |||||
| "chat_model": "glm-3-turbo", | |||||
| "embedding_model": "embedding-2", | |||||
| "image2text_model": "glm-4v", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "Ollama": { | |||||
| "chat_model": "qwen-14B-chat", | |||||
| "embedding_model": "flag-embedding", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "Moonshot": { | |||||
| "chat_model": "moonshot-v1-8k", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "DeepSeek": { | |||||
| "chat_model": "deepseek-chat", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "VolcEngine": { | |||||
| "chat_model": "", | |||||
| "embedding_model": "", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| }, | |||||
| "BAAI": { | |||||
| "chat_model": "", | |||||
| "embedding_model": "BAAI/bge-large-zh-v1.5", | |||||
| "image2text_model": "", | |||||
| "asr_model": "", | |||||
| "rerank_model": "BAAI/bge-reranker-v2-m3", | |||||
| } | |||||
| } | |||||
| if LLM_FACTORY: | |||||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" | |||||
| ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" | |||||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" | |||||
| EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" | |||||
| RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" | |||||
| global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY | |||||
| API_KEY = LLM.get("api_key", "") | |||||
| PARSERS = LLM.get( | |||||
| "parsers", | |||||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") | |||||
| HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||||
| HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||||
| SECRET_KEY = get_base_config( | |||||
| RAG_FLOW_SERVICE_NAME, | |||||
| {}).get("secret_key", str(date.today())) | |||||
| global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH | |||||
| # authentication | |||||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||||
| # client | |||||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||||
| "client", {}).get( | |||||
| "switch", False) | |||||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||||
| FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") | |||||
| global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler | |||||
| DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") | |||||
| if DOC_ENGINE == "elasticsearch": | |||||
| docStoreConn = rag.utils.es_conn.ESConnection() | |||||
| elif DOC_ENGINE == "infinity": | |||||
| docStoreConn = rag.utils.infinity_conn.InfinityConnection() | |||||
| else: | |||||
| raise Exception(f"Not supported doc engine: {DOC_ENGINE}") | |||||
| retrievaler = search.Dealer(docStoreConn) | |||||
| kg_retrievaler = kg_search.KGSearch(docStoreConn) | |||||
| def get_host_ip(): | |||||
| global HOST_IP | |||||
| return HOST_IP | |||||
| def get_host_port(): | |||||
| global HOST_PORT | |||||
| return HOST_PORT | |||||
| class CustomEnum(Enum): | class CustomEnum(Enum): | 
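This settings.py hunk is the heart of the refactor: the connection-building code that used to run at import time now sits behind `init_settings()`, and the module-level names start as `None`. Importing `api.settings` therefore no longer touches Elasticsearch/Infinity; the server must call `init_settings()` explicitly at startup. A minimal, self-contained sketch of the pattern (the `_FakeConn` class stands in for `ESConnection`/`InfinityConnection`, which are not reproduced here):

```python
import os

# Importing this module is now side-effect free: nothing connects here.
docStoreConn = None
retrievaler = None


class _FakeConn:
    """Stand-in for ESConnection/InfinityConnection (illustrative only)."""
    def __init__(self, engine: str):
        self.engine = engine


def init_settings():
    """Build the heavy objects exactly once, at server startup."""
    global docStoreConn, retrievaler
    doc_engine = os.environ.get("DOC_ENGINE", "elasticsearch")
    if doc_engine not in ("elasticsearch", "infinity"):
        raise Exception(f"Not supported doc engine: {doc_engine}")
    docStoreConn = _FakeConn(doc_engine)
    retrievaler = ("Dealer", docStoreConn)  # stands in for search.Dealer(docStoreConn)


if __name__ == "__main__":
    assert docStoreConn is None    # import alone opened no connection
    init_settings()
    print(docStoreConn.engine)     # "elasticsearch" unless DOC_ENGINE says otherwise
```

The trade-off is that every consumer must now read these names late (hence the `settings.X` rewrites in the hunks below), and the process must guarantee `init_settings()` runs before any request handler does.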
| from werkzeug.http import HTTP_STATUS_CODES | from werkzeug.http import HTTP_STATUS_CODES | ||||
| from api.db.db_models import APIToken | from api.db.db_models import APIToken | ||||
| from api.settings import ( | |||||
| REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, | |||||
| CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY | |||||
| ) | |||||
| from api.settings import RetCode | |||||
| from api import settings | |||||
| from api.utils import CustomJSONEncoder, get_uuid | from api.utils import CustomJSONEncoder, get_uuid | ||||
| from api.utils import json_dumps | from api.utils import json_dumps | ||||
| {}).items()} | {}).items()} | ||||
| prepped = requests.Request(**kwargs).prepare() | prepped = requests.Request(**kwargs).prepare() | ||||
| if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: | |||||
| if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY: | |||||
| timestamp = str(round(time() * 1000)) | timestamp = str(round(time() * 1000)) | ||||
| nonce = str(uuid1()) | nonce = str(uuid1()) | ||||
| signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([ | |||||
| signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([ | |||||
| timestamp.encode('ascii'), | timestamp.encode('ascii'), | ||||
| nonce.encode('ascii'), | nonce.encode('ascii'), | ||||
| HTTP_APP_KEY.encode('ascii'), | |||||
| settings.HTTP_APP_KEY.encode('ascii'), | |||||
| prepped.path_url.encode('ascii'), | prepped.path_url.encode('ascii'), | ||||
| prepped.body if kwargs.get('json') else b'', | prepped.body if kwargs.get('json') else b'', | ||||
| urlencode( | urlencode( | ||||
| prepped.headers.update({ | prepped.headers.update({ | ||||
| 'TIMESTAMP': timestamp, | 'TIMESTAMP': timestamp, | ||||
| 'NONCE': nonce, | 'NONCE': nonce, | ||||
| 'APP-KEY': HTTP_APP_KEY, | |||||
| 'APP-KEY': settings.HTTP_APP_KEY, | |||||
| 'SIGNATURE': signature, | 'SIGNATURE': signature, | ||||
| }) | }) | ||||
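The signing block above only changes where the key material comes from (`settings.*` instead of import-time copies), but it is worth seeing the scheme in one runnable piece. The sketch below reassembles it from the visible lines; the digest algorithm and the exact treatment of query parameters are elided by the diff context, so `"sha1"` and the sorted `urlencode` are assumptions:

```python
from base64 import b64encode
from hmac import HMAC
from time import time
from urllib.parse import urlencode
from uuid import uuid1

SECRET_KEY = "change-me"    # stand-in for settings.SECRET_KEY
HTTP_APP_KEY = "demo-app"   # stand-in for settings.HTTP_APP_KEY


def sign(path_url: str, body: bytes, params: dict) -> dict:
    """Produce the TIMESTAMP/NONCE/APP-KEY/SIGNATURE headers from the diff."""
    timestamp = str(round(time() * 1000))
    nonce = str(uuid1())
    signature = b64encode(HMAC(SECRET_KEY.encode("ascii"), b"\n".join([
        timestamp.encode("ascii"),
        nonce.encode("ascii"),
        HTTP_APP_KEY.encode("ascii"),
        path_url.encode("ascii"),
        body,
        urlencode(sorted(params.items())).encode("ascii"),  # assumed ordering
    ]), "sha1").digest()).decode("ascii")
    return {"TIMESTAMP": timestamp, "NONCE": nonce,
            "APP-KEY": HTTP_APP_KEY, "SIGNATURE": signature}


print(sign("/v1/ping", b"", {"q": "1"}))
```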
| def get_exponential_backoff_interval(retries, full_jitter=False): | def get_exponential_backoff_interval(retries, full_jitter=False): | ||||
| """Calculate the exponential backoff wait time.""" | """Calculate the exponential backoff wait time.""" | ||||
| # Will be zero if factor equals 0 | # Will be zero if factor equals 0 | ||||
| countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries)) | |||||
| countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries)) | |||||
| # Full jitter according to | # Full jitter according to | ||||
| # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ | # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ | ||||
| if full_jitter: | if full_jitter: | ||||
| return max(0, countdown) | return max(0, countdown) | ||||
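For the backoff helper, the only change is again the `settings.` prefix, but note that the jitter draw itself sits outside the visible context. A complete sketch, where the `random.randrange(countdown + 1)` line is an assumption based on the AWS full-jitter recipe the code comment links to:

```python
import random

REQUEST_WAIT_SEC = 2         # stand-ins for settings.REQUEST_WAIT_SEC /
REQUEST_MAX_WAIT_SEC = 300   # settings.REQUEST_MAX_WAIT_SEC; values made up


def get_exponential_backoff_interval(retries: int, full_jitter: bool = False) -> int:
    """Cap the exponential curve, then optionally draw uniformly from [0, cap]."""
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
    if full_jitter:
        countdown = random.randrange(countdown + 1)  # the step the diff elides
    return max(0, countdown)


print([get_exponential_backoff_interval(r) for r in range(6)])  # [2, 4, 8, 16, 32, 64]
print(get_exponential_backoff_interval(5, full_jitter=True))    # some int in [0, 64]
```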
| def get_data_error_result(code=RetCode.DATA_ERROR, | |||||
| def get_data_error_result(code=settings.RetCode.DATA_ERROR, | |||||
| message='Sorry! Data missing!'): | message='Sorry! Data missing!'): | ||||
| import re | import re | ||||
| result_dict = { | result_dict = { | ||||
| pass | pass | ||||
| if len(e.args) > 1: | if len(e.args) > 1: | ||||
| return get_json_result( | return get_json_result( | ||||
| code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||||
| return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) | |||||
| code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||||
| return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) | |||||
| def error_response(response_code, message=None): | def error_response(response_code, message=None): | ||||
| error_string += "required argument values: {}".format( | error_string += "required argument values: {}".format( | ||||
| ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) | ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) | ||||
| return get_json_result( | return get_json_result( | ||||
| code=RetCode.ARGUMENT_ERROR, message=error_string) | |||||
| code=settings.RetCode.ARGUMENT_ERROR, message=error_string) | |||||
| return func(*_args, **_kwargs) | return func(*_args, **_kwargs) | ||||
| return decorated_function | return decorated_function | ||||
| return send_file(f, as_attachment=True, attachment_filename=filename) | return send_file(f, as_attachment=True, attachment_filename=filename) | ||||
| def get_json_result(code=RetCode.SUCCESS, message='success', data=None): | |||||
| def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): | |||||
| response = {"code": code, "message": message, "data": data} | response = {"code": code, "message": message, "data": data} | ||||
| return jsonify(response) | return jsonify(response) | ||||
| objs = APIToken.query(token=token) | objs = APIToken.query(token=token) | ||||
| if not objs: | if not objs: | ||||
| return build_error_result( | return build_error_result( | ||||
| message='API-KEY is invalid!', code=RetCode.FORBIDDEN | |||||
| message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN | |||||
| ) | ) | ||||
| kwargs['tenant_id'] = objs[0].tenant_id | kwargs['tenant_id'] = objs[0].tenant_id | ||||
| return func(*args, **kwargs) | return func(*args, **kwargs) | ||||
| return decorated_function | return decorated_function | ||||
| def build_error_result(code=RetCode.FORBIDDEN, message='success'): | |||||
| def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'): | |||||
| response = {"code": code, "message": message} | response = {"code": code, "message": message} | ||||
| response = jsonify(response) | response = jsonify(response) | ||||
| response.status_code = code | response.status_code = code | ||||
| return response | return response | ||||
| def construct_response(code=RetCode.SUCCESS, | |||||
| def construct_response(code=settings.RetCode.SUCCESS, | |||||
| message='success', data=None, auth=None): | message='success', data=None, auth=None): | ||||
| result_dict = {"code": code, "message": message, "data": data} | result_dict = {"code": code, "message": message, "data": data} | ||||
| response_dict = {} | response_dict = {} | ||||
| return response | return response | ||||
| def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): | |||||
| def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'): | |||||
| import re | import re | ||||
| result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} | result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} | ||||
| response = {} | response = {} | ||||
| return jsonify(response) | return jsonify(response) | ||||
| def construct_json_result(code=RetCode.SUCCESS, message='success', data=None): | |||||
| def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): | |||||
| if data is None: | if data is None: | ||||
| return jsonify({"code": code, "message": message}) | return jsonify({"code": code, "message": message}) | ||||
| else: | else: | ||||
| logging.exception(e) | logging.exception(e) | ||||
| try: | try: | ||||
| if e.code == 401: | if e.code == 401: | ||||
| return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) | |||||
| return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e)) | |||||
| except BaseException: | except BaseException: | ||||
| pass | pass | ||||
| if len(e.args) > 1: | if len(e.args) > 1: | ||||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) | |||||
| return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||||
| return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) | |||||
| def token_required(func): | def token_required(func): | ||||
| objs = APIToken.query(token=token) | objs = APIToken.query(token=token) | ||||
| if not objs: | if not objs: | ||||
| return get_json_result( | return get_json_result( | ||||
| data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR | |||||
| data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR | |||||
| ) | ) | ||||
| kwargs['tenant_id'] = objs[0].tenant_id | kwargs['tenant_id'] = objs[0].tenant_id | ||||
| return func(*args, **kwargs) | return func(*args, **kwargs) | ||||
| return decorated_function | return decorated_function | ||||
| def get_result(code=RetCode.SUCCESS, message="", data=None): | |||||
| def get_result(code=settings.RetCode.SUCCESS, message="", data=None): | |||||
| if code == 0: | if code == 0: | ||||
| if data is not None: | if data is not None: | ||||
| response = {"code": code, "data": data} | response = {"code": code, "data": data} | ||||
| return jsonify(response) | return jsonify(response) | ||||
| def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR, | |||||
| def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR, | |||||
| ): | ): | ||||
| import re | import re | ||||
| result_dict = { | result_dict = { | 
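All of the api_utils.py edits above follow the one rule stated in the PR description: read config as `settings.CONFIG_NAME` instead of `from api.settings import CONFIG_NAME`. The reason is binding time. `from X import Y` copies the value at import, so a name that `init_settings()` later reassigns stays stale in the importer; attribute access looks it up on every use. A minimal demonstration with a synthetic module (no RAGFlow code involved):

```python
import sys
import types

# Fabricate a tiny settings module whose value changes after import.
settings = types.ModuleType("settings")
settings.CLIENT_AUTHENTICATION = None          # module-level default, as in the PR
sys.modules["settings"] = settings

from settings import CLIENT_AUTHENTICATION     # copies the *current* value: None

settings.CLIENT_AUTHENTICATION = True          # what init_settings() effectively does

print(CLIENT_AUTHENTICATION)                   # None -- stale import-time copy
print(settings.CLIENT_AUTHENTICATION)          # True -- live attribute lookup
```

One caveat visible in this hunk: default argument values such as `code=settings.RetCode.DATA_ERROR` are still evaluated once at import, which works here because `RetCode` is a class defined at module scope, not one of the values `init_settings()` fills in.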
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| from pypdf import PdfReader as pdf2_read | from pypdf import PdfReader as pdf2_read | ||||
| from api.settings import LIGHTEN | |||||
| from api import settings | |||||
| from api.utils.file_utils import get_project_base_directory | from api.utils.file_utils import get_project_base_directory | ||||
| from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer | from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer | ||||
| from rag.nlp import rag_tokenizer | from rag.nlp import rag_tokenizer | ||||
| self.tbl_det = TableStructureRecognizer() | self.tbl_det = TableStructureRecognizer() | ||||
| self.updown_cnt_mdl = xgb.Booster() | self.updown_cnt_mdl = xgb.Booster() | ||||
| if not LIGHTEN: | |||||
| if not settings.LIGHTEN: | |||||
| try: | try: | ||||
| import torch | import torch | ||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | 
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.settings import retrievaler | |||||
| from api import settings | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | ||||
| ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | ||||
| docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] | |||||
| docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] | |||||
| info = { | info = { | ||||
| "input_text": docs, | "input_text": docs, | ||||
| "entity_specs": "organization, person", | "entity_specs": "organization, person", | 
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.settings import retrievaler | |||||
| from api import settings | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | ||||
| ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | ||||
| docs = [d["content_with_weight"] for d in | docs = [d["content_with_weight"] for d in | ||||
| retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] | |||||
| settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] | |||||
| graph = ex(docs) | graph = ex(docs) | ||||
| er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) | er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) | 
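The two graphrag smoke tests change the same way: the retriever is now fetched through the module at call time, so they only work after `init_settings()` has populated it. A sketch of the call shape, with the retriever mocked (the row fields and `chunk_list` signature are taken from the visible lines; the mock's behavior is invented):

```python
from types import SimpleNamespace


class _FakeRetriever:
    """Yields rows shaped like the doc store's chunk records."""
    def chunk_list(self, doc_id, tenant_id, kb_ids, max_count=12, fields=None):
        for i in range(max_count):
            yield {"content_with_weight": f"chunk {i} of {doc_id}"}


settings = SimpleNamespace(retrievaler=None)  # stand-in for `from api import settings`
settings.retrievaler = _FakeRetriever()       # what init_settings() effectively does

docs = [d["content_with_weight"]
        for d in settings.retrievaler.chunk_list("doc-1", "tenant-1", ["kb-1"],
                                                 max_count=3,
                                                 fields=["content_with_weight"])]
print(docs)  # ['chunk 0 of doc-1', 'chunk 1 of doc-1', 'chunk 2 of doc-1']
```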
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.settings import retrievaler, docStoreConn | |||||
| from api import settings | |||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from rag.nlp import tokenize, search | from rag.nlp import tokenize, search | ||||
| from ranx import evaluate | from ranx import evaluate | ||||
| run = defaultdict(dict) | run = defaultdict(dict) | ||||
| query_list = list(qrels.keys()) | query_list = list(qrels.keys()) | ||||
| for query in query_list: | for query in query_list: | ||||
| ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, | |||||
| ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, | |||||
| 0.0, self.vector_similarity_weight) | 0.0, self.vector_similarity_weight) | ||||
| if len(ranks["chunks"]) == 0: | if len(ranks["chunks"]) == 0: | ||||
| print(f"deleted query: {query}") | print(f"deleted query: {query}") | ||||
| def init_index(self, vector_size: int): | def init_index(self, vector_size: int): | ||||
| if self.initialized_index: | if self.initialized_index: | ||||
| return | return | ||||
| if docStoreConn.indexExist(self.index_name, self.kb_id): | |||||
| docStoreConn.deleteIdx(self.index_name, self.kb_id) | |||||
| docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) | |||||
| if settings.docStoreConn.indexExist(self.index_name, self.kb_id): | |||||
| settings.docStoreConn.deleteIdx(self.index_name, self.kb_id) | |||||
| settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) | |||||
| self.initialized_index = True | self.initialized_index = True | ||||
| def ms_marco_index(self, file_path, index_name): | def ms_marco_index(self, file_path, index_name): | ||||
| docs_count += len(docs) | docs_count += len(docs) | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs, self.index_name, self.kb_id) | |||||
| settings.docStoreConn.insert(docs, self.index_name, self.kb_id) | |||||
| docs = [] | docs = [] | ||||
| if docs: | if docs: | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs, self.index_name, self.kb_id) | |||||
| settings.docStoreConn.insert(docs, self.index_name, self.kb_id) | |||||
| return qrels, texts | return qrels, texts | ||||
| def trivia_qa_index(self, file_path, index_name): | def trivia_qa_index(self, file_path, index_name): | ||||
| docs_count += len(docs) | docs_count += len(docs) | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs,self.index_name) | |||||
| settings.docStoreConn.insert(docs, self.index_name) | |||||
| docs = [] | docs = [] | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs, self.index_name) | |||||
| settings.docStoreConn.insert(docs, self.index_name) | |||||
| return qrels, texts | return qrels, texts | ||||
| def miracl_index(self, file_path, corpus_path, index_name): | def miracl_index(self, file_path, corpus_path, index_name): | ||||
| docs_count += len(docs) | docs_count += len(docs) | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs, self.index_name) | |||||
| settings.docStoreConn.insert(docs, self.index_name) | |||||
| docs = [] | docs = [] | ||||
| docs, vector_size = self.embedding(docs) | docs, vector_size = self.embedding(docs) | ||||
| self.init_index(vector_size) | self.init_index(vector_size) | ||||
| docStoreConn.insert(docs, self.index_name) | |||||
| settings.docStoreConn.insert(docs, self.index_name) | |||||
| return qrels, texts | return qrels, texts | ||||
| def save_results(self, qrels, run, texts, dataset, file_path): | def save_results(self, qrels, run, texts, dataset, file_path): | 
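In the benchmark, `init_index` plus the repeated `settings.docStoreConn.insert` calls form a drop-recreate-fill cycle guarded by an `initialized_index` flag. A compact sketch against a fake store (method names mirror the diff; the store's internals are invented):

```python
class FakeDocStore:
    """In-memory stand-in for the ES/Infinity connection."""
    def __init__(self):
        self._indices = {}

    def indexExist(self, index_name, kb_id):
        return (index_name, kb_id) in self._indices

    def deleteIdx(self, index_name, kb_id):
        self._indices.pop((index_name, kb_id), None)

    def createIdx(self, index_name, kb_id, vector_size):
        self._indices[(index_name, kb_id)] = {"vector_size": vector_size, "docs": []}

    def insert(self, docs, index_name, kb_id=None):
        self._indices[(index_name, kb_id)]["docs"].extend(docs)
        return ""  # the real connection reports errors via a non-empty result


class Benchmark:
    def __init__(self, store, index_name, kb_id):
        self.store, self.index_name, self.kb_id = store, index_name, kb_id
        self.initialized_index = False

    def init_index(self, vector_size: int):
        # Idempotent: only the first embedding batch pays for the rebuild.
        if self.initialized_index:
            return
        if self.store.indexExist(self.index_name, self.kb_id):
            self.store.deleteIdx(self.index_name, self.kb_id)
        self.store.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True


bench = Benchmark(FakeDocStore(), "ragflow_bench", "kb-1")
bench.init_index(vector_size=1024)
bench.store.insert([{"id": 1}], bench.index_name, bench.kb_id)
print(bench.store._indices)
```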
| import numpy as np | import numpy as np | ||||
| import asyncio | import asyncio | ||||
| from api.settings import LIGHTEN | |||||
| from api import settings | |||||
| from api.utils.file_utils import get_home_cache_dir | from api.utils.file_utils import get_home_cache_dir | ||||
| from rag.utils import num_tokens_from_string, truncate | from rag.utils import num_tokens_from_string, truncate | ||||
| import google.generativeai as genai | import google.generativeai as genai | ||||
| ^_- | ^_- | ||||
| """ | """ | ||||
| if not LIGHTEN and not DefaultEmbedding._model: | |||||
| if not settings.LIGHTEN and not DefaultEmbedding._model: | |||||
| with DefaultEmbedding._model_lock: | with DefaultEmbedding._model_lock: | ||||
| from FlagEmbedding import FlagModel | from FlagEmbedding import FlagModel | ||||
| import torch | import torch | ||||
| threads: Optional[int] = None, | threads: Optional[int] = None, | ||||
| **kwargs, | **kwargs, | ||||
| ): | ): | ||||
| if not LIGHTEN and not FastEmbed._model: | |||||
| if not settings.LIGHTEN and not FastEmbed._model: | |||||
| from fastembed import TextEmbedding | from fastembed import TextEmbedding | ||||
| self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | ||||
| _client = None | _client = None | ||||
| def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | ||||
| if not LIGHTEN and not YoudaoEmbed._client: | |||||
| if not settings.LIGHTEN and not YoudaoEmbed._client: | |||||
| from BCEmbedding import EmbeddingModel as qanthing | from BCEmbedding import EmbeddingModel as qanthing | ||||
| try: | try: | ||||
| logging.info("LOADING BCE...") | logging.info("LOADING BCE...") | 
| from abc import ABC | from abc import ABC | ||||
| import numpy as np | import numpy as np | ||||
| from api.settings import LIGHTEN | |||||
| from api import settings | |||||
| from api.utils.file_utils import get_home_cache_dir | from api.utils.file_utils import get_home_cache_dir | ||||
| from rag.utils import num_tokens_from_string, truncate | from rag.utils import num_tokens_from_string, truncate | ||||
| import json | import json | ||||
| ^_- | ^_- | ||||
| """ | """ | ||||
| if not LIGHTEN and not DefaultRerank._model: | |||||
| if not settings.LIGHTEN and not DefaultRerank._model: | |||||
| import torch | import torch | ||||
| from FlagEmbedding import FlagReranker | from FlagEmbedding import FlagReranker | ||||
| with DefaultRerank._model_lock: | with DefaultRerank._model_lock: | ||||
| _model_lock = threading.Lock() | _model_lock = threading.Lock() | ||||
| def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): | def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): | ||||
| if not LIGHTEN and not YoudaoRerank._model: | |||||
| if not settings.LIGHTEN and not YoudaoRerank._model: | |||||
| from BCEmbedding import RerankerModel | from BCEmbedding import RerankerModel | ||||
| with YoudaoRerank._model_lock: | with YoudaoRerank._model_lock: | ||||
| if not YoudaoRerank._model: | if not YoudaoRerank._model: | 
| import logging | import logging | ||||
| import sys | import sys | ||||
| from api.utils.log_utils import initRootLogger | from api.utils.log_utils import initRootLogger | ||||
| CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] | CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] | ||||
| initRootLogger(f"task_executor_{CONSUMER_NO}") | initRootLogger(f"task_executor_{CONSUMER_NO}") | ||||
| for module in ["pdfminer"]: | for module in ["pdfminer"]: | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.db.services.task_service import TaskService | from api.db.services.task_service import TaskService | ||||
| from api.db.services.file2document_service import File2DocumentService | from api.db.services.file2document_service import File2DocumentService | ||||
| from api.settings import retrievaler, docStoreConn | |||||
| from api import settings | |||||
| from api.db.db_models import close_connection | from api.db.db_models import close_connection | ||||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email | |||||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ | |||||
| knowledge_graph, email | |||||
| from rag.nlp import search, rag_tokenizer | from rag.nlp import search, rag_tokenizer | ||||
| from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor | from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor | ||||
| from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME | from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME | ||||
| HEAD_CREATED_AT = "" | HEAD_CREATED_AT = "" | ||||
| HEAD_DETAIL = "" | HEAD_DETAIL = "" | ||||
| def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): | def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): | ||||
| global PAYLOAD | global PAYLOAD | ||||
| if prog is not None and prog < 0: | if prog is not None and prog < 0: | ||||
| "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) | "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) | ||||
| except TimeoutError: | except TimeoutError: | ||||
| callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") | callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") | ||||
| logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) | |||||
| logging.exception( | |||||
| "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) | |||||
| return | return | ||||
| except Exception as e: | except Exception as e: | ||||
| if re.search("(No such file|not found)", str(e)): | if re.search("(No such file|not found)", str(e)): | ||||
| logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"])) | logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"])) | ||||
| except Exception as e: | except Exception as e: | ||||
| callback(-1, "Internal server error while chunking: %s" % | callback(-1, "Internal server error while chunking: %s" % | ||||
| str(e).replace("'", "")) | |||||
| str(e).replace("'", "")) | |||||
| logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"])) | logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"])) | ||||
| return | return | ||||
| STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) | STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) | ||||
| el += timer() - st | el += timer() - st | ||||
| except Exception: | except Exception: | ||||
| logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) | |||||
| logging.exception( | |||||
| "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) | |||||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) | d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) | ||||
| del d["image"] | del d["image"] | ||||
| d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], | d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], | ||||
| row["parser_config"]["auto_keywords"]).split(",") | row["parser_config"]["auto_keywords"]).split(",") | ||||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) | d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) | ||||
| callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st)) | |||||
| callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st)) | |||||
| if row["parser_config"].get("auto_questions", 0): | if row["parser_config"].get("auto_questions", 0): | ||||
| st = timer() | st = timer() | ||||
| d["content_ltks"] += " " + qst | d["content_ltks"] += " " + qst | ||||
| if "content_sm_ltks" in d: | if "content_sm_ltks" in d: | ||||
| d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst) | d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst) | ||||
| callback(msg="Question generation completed in {:.2f}s".format(timer()-st)) | |||||
| callback(msg="Question generation completed in {:.2f}s".format(timer() - st)) | |||||
| return docs | return docs | ||||
| def init_kb(row, vector_size: int): | def init_kb(row, vector_size: int): | ||||
| idxnm = search.index_name(row["tenant_id"]) | idxnm = search.index_name(row["tenant_id"]) | ||||
| return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) | |||||
| return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) | |||||
| def embedding(docs, mdl, parser_config=None, callback=None): | def embedding(docs, mdl, parser_config=None, callback=None): | ||||
| vector_size = len(vts[0]) | vector_size = len(vts[0]) | ||||
| vctr_nm = "q_%d_vec" % vector_size | vctr_nm = "q_%d_vec" % vector_size | ||||
| chunks = [] | chunks = [] | ||||
| for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]): | |||||
| for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | |||||
| fields=["content_with_weight", vctr_nm]): | |||||
| chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) | chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) | ||||
| raptor = Raptor( | raptor = Raptor( | ||||
| # TODO: exception handler | # TODO: exception handler | ||||
| ## set_progress(r["did"], -1, "ERROR: ") | ## set_progress(r["did"], -1, "ERROR: ") | ||||
| callback( | callback( | ||||
| msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st) | |||||
| msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), | |||||
| timer() - st) | |||||
| ) | ) | ||||
| st = timer() | st = timer() | ||||
| try: | try: | ||||
| es_r = "" | es_r = "" | ||||
| es_bulk_size = 4 | es_bulk_size = 4 | ||||
| for b in range(0, len(cks), es_bulk_size): | for b in range(0, len(cks), es_bulk_size): | ||||
| es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| if b % 128 == 0: | if b % 128 == 0: | ||||
| callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") | callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") | ||||
| logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) | logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) | ||||
| if es_r: | if es_r: | ||||
| callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") | callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") | ||||
| docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| logging.error('Insert chunk error: ' + str(es_r)) | logging.error('Insert chunk error: ' + str(es_r)) | ||||
| else: | else: | ||||
| if TaskService.do_cancel(r["id"]): | if TaskService.do_cancel(r["id"]): | ||||
| docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||||
| continue | continue | ||||
| callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) | callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) | ||||
| callback(1., "Done!") | callback(1., "Done!") | ||||
| if PENDING_TASKS > 0: | if PENDING_TASKS > 0: | ||||
| head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME) | head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME) | ||||
| if head_info is not None: | if head_info is not None: | ||||
| seconds = int(head_info[0].split("-")[0])/1000 | |||||
| seconds = int(head_info[0].split("-")[0]) / 1000 | |||||
| HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat() | HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat() | ||||
| HEAD_DETAIL = head_info[1] | HEAD_DETAIL = head_info[1] | ||||
| REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) | REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) | ||||
| logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") | logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") | ||||
| expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30) | |||||
| expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30) | |||||
| if expired > 0: | if expired > 0: | ||||
| REDIS_CONN.zpopmin(CONSUMER_NAME, expired) | REDIS_CONN.zpopmin(CONSUMER_NAME, expired) | ||||
| except Exception: | except Exception: |
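The heartbeat tail uses a Redis sorted set keyed by consumer name: each beat is a member scored with its timestamp, and members older than 30 minutes are counted and popped. `REDIS_CONN` is RAGFlow's own wrapper, so the sketch below maps the same three calls onto redis-py directly (note redis-py's `zadd` takes a mapping rather than separate member/score arguments, and the payload shape here is illustrative):

```python
import json
import time

import redis

r = redis.Redis()                        # assumes a local Redis for the demo
CONSUMER_NAME = "task_executor_0"

now = time.time()
heartbeat = json.dumps({"ts": now})      # payload shape is illustrative
r.zadd(CONSUMER_NAME, {heartbeat: now})  # score = timestamp of this beat

# Count beats older than 30 minutes, then pop exactly that many (lowest scores).
expired = r.zcount(CONSUMER_NAME, 0, now - 60 * 30)
if expired > 0:
    r.zpopmin(CONSUMER_NAME, expired)
```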