### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values in `settings` now have to be accessed as `settings.CONFIG_NAME`.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
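For review, every change below follows the same idea. A minimal sketch of the new access style, assuming an explicit `settings.init_settings()` startup hook (that hook name is illustrative, not part of this diff):

```python
# Old pattern: binding names at import time forced api.settings to build its
# globals (and connect to the database) as a side effect of module init.
# from api.settings import retrievaler, RetCode

# New pattern: import the module object and resolve attributes at call time,
# so importing a module no longer triggers any heavyweight initialization.
from api import settings

# Hypothetical startup hook -- real initialization is done by the server
# entrypoint, which is not shown in this diff.
# settings.init_settings()


def example_handler():
    # Looked up when the handler runs, after settings are initialized.
    return settings.retrievaler, settings.RetCode.SUCCESS
```

Concretely, every `retrievaler`, `kg_retrievaler`, `docStoreConn`, `RetCode`, `SECRET_KEY`, `API_VERSION`, and `LIGHTEN` reference below gains a `settings.` prefix; the remaining hunks are PEP 8 cleanups (spacing, comment style, line wrapping) picked up along the way.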
```diff
@@ -19,7 +19,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -63,18 +63,20 @@ class Generate(ComponentBase):
     component_name = "Generate"
 
     def get_dependent_components(self):
-        cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
+        cpnts = [para["component_id"] for para in self._param.parameters if
+                 para.get("component_id") and para["component_id"].lower().find("answer") < 0]
         return cpnts
 
     def set_cite(self, retrieval_res, answer):
         retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
         if "empty_response" in retrieval_res.columns:
             retrieval_res["empty_response"].fillna("", inplace=True)
-        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
-                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
-                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                             self._canvas.get_embedding_model()), tkweight=0.7,
-                                                   vtweight=0.3)
+        answer, idx = settings.retrievaler.insert_citations(answer,
+                                                            [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                            [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                            LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                                      self._canvas.get_embedding_model()), tkweight=0.7,
+                                                            vtweight=0.3)
         doc_ids = set([])
         recall_docs = []
         for i in idx:
@@ -127,12 +129,14 @@ class Generate(ComponentBase):
             else:
                 if cpn.component_name.lower() == "retrieval":
                     retrieval_res.append(out)
-                kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
+                kwargs[para["key"]] = " - " + "\n - ".join(
+                    [o if isinstance(o, str) else str(o) for o in out["content"]])
                 self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
 
         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else: retrieval_res = pd.DataFrame([])
+        else:
+            retrieval_res = pd.DataFrame([])
 
         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@@ -21,7 +21,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
 
-        kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
+        kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                         1, self._param.top_n,
                                         self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
                                         aggs=False, rerank_mdl=rerank_mdl)
@@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
 
 from flask_session import Session
 from flask_login import LoginManager
-from api.settings import SECRET_KEY
-from api.settings import API_VERSION
+from api import settings
 from api.utils.api_utils import server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
 
-## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
@@ -110,7 +108,7 @@ def register_page(page_path):
     page_name = page_path.stem.rstrip("_app")
     module_name = ".".join(
-        page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
+        page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
     )
     spec = spec_from_file_location(module_name, page_path)
@@ -121,7 +119,7 @@ def register_page(page_path):
     spec.loader.exec_module(page)
     page_name = getattr(page, "page_name", page_name)
     url_prefix = (
-        f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
+        f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
     )
 
     app.register_blueprint(page.manager, url_prefix=url_prefix)
@@ -141,7 +139,7 @@ client_urls_prefix = [
 
 @login_manager.request_loader
 def load_user(web_request):
-    jwt = Serializer(secret_key=SECRET_KEY)
+    jwt = Serializer(secret_key=settings.SECRET_KEY)
     authorization = web_request.headers.get("Authorization")
     if authorization:
         try:
@@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
     generate_confirmation_token
@@ -141,7 +141,7 @@ def set_conversation():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     try:
         if objs[0].source == "agent":
@@ -183,7 +183,7 @@ def completion():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
@@ -290,8 +290,8 @@ def completion():
             API4ConversationService.append_message(conv.id, conv.to_dict())
             rename_field(result)
             return get_json_result(data=result)
 
-        #******************For dialog******************
+        # ******************For dialog******************
         conv.message.append(msg[-1])
         e, dia = DialogService.get_by_id(conv.dialog_id)
         if not e:
@@ -326,7 +326,7 @@ def completion():
             resp.headers.add_header("X-Accel-Buffering", "no")
             resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
             return resp
-
+
         answer = None
         for ans in chat(dia, msg, **req):
             answer = ans
@@ -347,8 +347,8 @@ def get(conversation_id):
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     try:
         e, conv = API4ConversationService.get_by_id(conversation_id)
         if not e:
@@ -357,8 +357,8 @@ def get(conversation_id):
         conv = conv.to_dict()
         if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
             return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
-                                   code=RetCode.AUTHENTICATION_ERROR)
+                                   code=settings.RetCode.AUTHENTICATION_ERROR)
 
         for referenct_i in conv['reference']:
             if referenct_i is None or len(referenct_i) == 0:
                 continue
@@ -378,7 +378,7 @@ def upload():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     kb_name = request.form.get("kb_name").strip()
     tenant_id = objs[0].tenant_id
@@ -394,12 +394,12 @@ def upload():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file = request.files['file']
     if file.filename == '':
         return get_json_result(
-            data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     root_folder = FileService.get_root_folder(tenant_id)
     pf_id = root_folder["id"]
@@ -490,17 +490,17 @@ def upload_parse():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
     return get_json_result(data=doc_ids)
@@ -513,7 +513,7 @@ def list_chunks():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
@@ -531,7 +531,7 @@ def list_chunks():
         )
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
 
-        res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+        res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
         res = [
             {
                 "content": res_item["content_with_weight"],
@@ -553,7 +553,7 @@ def list_kb_docs():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
     tenant_id = objs[0].tenant_id
@@ -585,6 +585,7 @@ def list_kb_docs():
     except Exception as e:
         return server_error_response(e)
 
+
 @manager.route('/document/infos', methods=['POST'])
 @validate_request("doc_ids")
 def docinfos():
@@ -592,7 +593,7 @@ def docinfos():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     doc_ids = req["doc_ids"]
     docs = DocumentService.get_by_ids(doc_ids)
@@ -606,7 +607,7 @@ def document_rm():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     tenant_id = objs[0].tenant_id
     req = request.json
@@ -653,7 +654,7 @@ def document_rm():
             errors += str(e)
 
     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
 
     return get_json_result(data=True)
@@ -668,7 +669,7 @@ def completion_faq():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
@@ -805,10 +806,10 @@ def retrieval():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
-    kb_ids = req.get("kb_id",[])
+    kb_ids = req.get("kb_id", [])
     doc_ids = req.get("doc_ids", [])
     question = req.get("question")
     page = int(req.get("page", 1))
@@ -822,20 +823,21 @@ def retrieval():
         embd_nms = list(set([kb.embd_id for kb in kbs]))
         if len(embd_nms) != 1:
             return get_json_result(
-                data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
+                data=False, message='Knowledge bases use different embedding models or does not exist."',
+                code=settings.RetCode.AUTHENTICATION_ERROR)
         embd_mdl = TenantLLMService.model_instance(
             kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
-                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
         if req.get("keyword", False):
             chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
-                                      similarity_threshold, vector_similarity_weight, top,
-                                      doc_ids, rerank_mdl=rerank_mdl)
+        ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+                                               similarity_threshold, vector_similarity_weight, top,
+                                               doc_ids, rerank_mdl=rerank_mdl)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
@@ -843,5 +845,5 @@ def retrieval():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
@@ -19,7 +19,7 @@ from functools import partial
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
 from agent.canvas import Canvas
@@ -36,7 +36,8 @@ def templates():
 @login_required
 def canvas_list():
     return get_json_result(data=sorted([c.to_dict() for c in \
-        UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
+                                        UserCanvasService.query(user_id=current_user.id)],
+                                       key=lambda x: x["update_time"] * -1)
     )
@@ -45,10 +46,10 @@ def canvas_list():
 @login_required
 def rm():
     for i in request.json["canvas_ids"]:
-        if not UserCanvasService.query(user_id=current_user.id,id=i):
+        if not UserCanvasService.query(user_id=current_user.id, id=i):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         UserCanvasService.delete_by_id(i)
     return get_json_result(data=True)
@@ -72,7 +73,7 @@ def save():
         if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         UserCanvasService.update_by_id(req["id"], req)
     return get_json_result(data=req)
@@ -98,7 +99,7 @@ def run():
         if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
 
         if not isinstance(cvs.dsl, str):
             cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@@ -110,8 +111,8 @@ def run():
         if "message" in req:
             canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
             if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
-                #ten = TenantService.get_info_by(current_user.id)[0]
-                #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
+                # ten = TenantService.get_info_by(current_user.id)[0]
+                # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
                 pass
             canvas.add_user_input(req["message"])
         answer = canvas.run(stream=stream)
@@ -122,7 +123,8 @@ def run():
         assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
         if stream:
-            assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
+            assert isinstance(answer,
+                              partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
 
             def sse():
                 nonlocal answer, cvs
@@ -173,7 +175,7 @@ def reset():
         if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
 
         canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
         canvas.reset()
@@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
 
+
 @manager.route('/list', methods=['POST'])
 @login_required
 @validate_request("doc_id")
@@ -56,7 +57,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -72,13 +73,13 @@ def list_chunk():
                 "positions": json.loads(sres.field[id].get("position_list", "[]")),
             }
             assert isinstance(d["positions"], list)
-            assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
+            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
             res["chunks"].append(d)
         return get_json_result(data=res)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
@@ -93,7 +94,7 @@ def get():
         tenant_id = tenants[0].tenant_id
 
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
            return server_error_response("Chunk not found")
         k = []
@@ -107,7 +108,7 @@ def get():
     except Exception as e:
         if str(e).find("NotFoundError") >= 0:
             return get_json_result(data=False, message='Chunk not found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
@@ -154,7 +155,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -169,8 +170,8 @@ def switch():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
-                                   search.index_name(doc.tenant_id), doc.kb_id):
+        if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
+                                            search.index_name(doc.tenant_id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:
@@ -186,7 +187,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
+        if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         deleted_chunk_ids = req["chunk_ids"]
         chunk_number = len(deleted_chunk_ids)
@@ -230,7 +231,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
 
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
@@ -265,7 +266,7 @@ def retrieval_test():
         else:
             return get_json_result(
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
 
         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
@@ -281,7 +282,7 @@ def retrieval_test():
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
-        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
@@ -293,7 +294,7 @@ def retrieval_test():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
@@ -304,10 +305,10 @@ def knowledge_graph():
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
-        "doc_ids":[doc_id],
+        "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
+    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
@@ -336,4 +337,3 @@ def knowledge_graph():
         obj[ty] = content_json
 
     return get_json_result(data=obj)
-
@@ -25,7 +25,7 @@ from api.db import LLMType
 from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
@@ -87,7 +87,7 @@ def get():
         else:
             return get_json_result(
                 data=False, message='Only owner of conversation authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         conv = conv.to_dict()
         return get_json_result(data=conv)
     except Exception as e:
@@ -110,7 +110,7 @@ def rm():
             else:
                 return get_json_result(
                     data=False, message='Only owner of conversation authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             ConversationService.delete_by_id(cid)
         return get_json_result(data=True)
     except Exception as e:
@@ -125,7 +125,7 @@ def list_convsersation():
         if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
             return get_json_result(
                 data=False, message='Only owner of dialog authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         convs = ConversationService.query(
             dialog_id=dialog_id,
             order_by=ConversationService.model.create_time,
@@ -297,6 +297,7 @@ def thumbup():
 def ask_about():
     req = request.json
     uid = current_user.id
+
     def stream():
         nonlocal req, uid
         try:
@@ -329,8 +330,8 @@ def mindmap():
     embd_mdl = TenantLLMService.model_instance(
         kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
-    ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                  0.3, 0.3, aggs=False)
+    ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                           0.3, 0.3, aggs=False)
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
     if "error" in mind_map:
@@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
 from api.db import StatusEnum
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -175,7 +175,7 @@ def rm():
             else:
                 return get_json_result(
                     data=False, message='Only owner of dialog authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
     DialogService.update_many_by_id(dialog_list)
     return get_json_result(data=True)
@@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.settings import RetCode, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 from rag.utils.storage_factory import STORAGE_IMPL
 from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
@@ -49,16 +49,16 @@ def upload():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
@@ -67,7 +67,7 @@ def upload():
     err, _ = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(
-            data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
+            data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)
@@ -78,12 +78,12 @@ def web_crawl():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     name = request.form.get("name")
     url = request.form.get("url")
     if not is_valid_url(url):
         return get_json_result(
-            data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
@@ -145,7 +145,7 @@ def create():
     kb_id = req["kb_id"]
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
 
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -179,7 +179,7 @@ def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(
@@ -188,7 +188,7 @@ def list_docs():
     else:
         return get_json_result(
             data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     keywords = request.args.get("keywords", "")
 
     page_number = int(request.args.get("page", 1))
@@ -218,19 +218,19 @@ def docinfos():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     docs = DocumentService.get_by_ids(doc_ids)
     return get_json_result(data=list(docs.dicts()))
 
 
 @manager.route('/thumbnails', methods=['GET'])
-#@login_required
+# @login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
         return get_json_result(
-            data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
 
     try:
         docs = DocumentService.get_thumbnails(doc_ids)
@@ -253,13 +253,13 @@ def change_status():
         return get_json_result(
             data=False,
             message='"Status" must be either 0 or 1!',
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
 
     if not DocumentService.accessible(req["doc_id"], current_user.id):
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR)
+            code=settings.RetCode.AUTHENTICATION_ERROR)
 
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -276,7 +276,8 @@ def change_status():
                 message="Database error (Document update)!")
 
         status = int(req["status"])
-        docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
+                                     search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -295,7 +296,7 @@ def rm():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
 
     root_folder = FileService.get_root_folder(current_user.id)
@@ -326,7 +327,7 @@ def rm():
             errors += str(e)
 
     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
 
     return get_json_result(data=True)
@@ -341,7 +342,7 @@ def run():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         for id in req["doc_ids"]:
@@ -358,8 +359,8 @@ def run():
             e, doc = DocumentService.get_by_id(id)
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
 
             if str(req["run"]) == TaskStatus.RUNNING.value:
                 TaskService.filter_delete([Task.doc_id == id])
@@ -383,7 +384,7 @@ def rename():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -394,7 +395,7 @@ def rename():
         return get_json_result(
             data=False,
             message="The extension of file can't be changed",
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
     for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
         if d.name == req["name"]:
             return get_data_error_result(
@@ -450,7 +451,7 @@ def change_parser():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -483,8 +484,8 @@ def change_parser():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(message="Tenant not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
 
         return get_json_result(data=True)
     except Exception as e:
@@ -509,13 +510,13 @@ def get_image(image_id):
 def upload_and_parse():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
@@ -529,7 +530,7 @@ def parse():
     if url:
         if not is_valid_url(url):
             return get_json_result(
-                data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
         download_path = os.path.join(get_project_base_directory(), "logs/downloads")
         os.makedirs(download_path, exist_ok=True)
         from selenium.webdriver import Chrome, ChromeOptions
@@ -553,7 +554,7 @@ def parse():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     txt = FileService.parse_docs(file_objs, current_user.id)
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.utils import get_uuid
 from api.db import FileType
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import get_json_result
@@ -100,7 +100,7 @@ def rm():
     file_ids = req["file_ids"]
     if not file_ids:
         return get_json_result(
-            data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
     try:
         for file_id in file_ids:
             informs = File2DocumentService.get_by_file_id(file_id)
@@ -28,7 +28,7 @@ from api.utils import get_uuid
 from api.db import FileType, FileSource
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.file_utils import filename_type
 from rag.utils.storage_factory import STORAGE_IMPL
@@ -46,13 +46,13 @@ def upload():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
     file_res = []
     try:
         for file_obj in file_objs:
@@ -134,7 +134,7 @@ def create():
     try:
         if not FileService.is_parent_folder_exist(pf_id):
             return get_json_result(
-                data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
+                data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
         if FileService.query(name=req["name"], parent_id=pf_id):
             return get_data_error_result(
                 message="Duplicated folder name in the same folder.")
@@ -299,7 +299,7 @@ def rename():
         return get_json_result(
             data=False,
             message="The extension of file can't be changed",
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
     for file in FileService.query(name=req["name"], pf_id=file.parent_id):
         if file.name == req["name"]:
             return get_data_error_result(
@@ -26,9 +26,8 @@ from api.utils import get_uuid
 from api.db import StatusEnum, FileSource
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
-from api.settings import RetCode
 from api.utils.api_utils import get_json_result
-from api.settings import docStoreConn
+from api import settings
 from rag.nlp import search
@@ -68,13 +67,13 @@ def update():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         if not KnowledgebaseService.query(
                 created_by=current_user.id, id=req["kb_id"]):
             return get_json_result(
-                data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
+                data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
 
         e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
         if not e:
@@ -113,7 +112,7 @@ def detail():
         else:
             return get_json_result(
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         kb = KnowledgebaseService.get_detail(kb_id)
         if not kb:
             return get_data_error_result(
@@ -148,14 +147,14 @@ def rm():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         kbs = KnowledgebaseService.query(
             created_by=current_user.id, id=req["kb_id"])
         if not kbs:
             return get_json_result(
-                data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
+                data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
 
         for doc in DocumentService.query(kb_id=req["kb_id"]):
             if not DocumentService.remove_document(doc, kbs[0].tenant_id):
@@ -170,7 +169,7 @@ def rm():
                     message="Database error (Knowledgebase removal)!")
         tenants = UserTenantService.query(user_id=current_user.id)
         for tenant in tenants:
-            docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
+            settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -19,7 +19,7 @@ import json
 from flask import request
 from flask_login import login_required, current_user
 from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
-from api.settings import LIGHTEN
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
@@ -333,7 +333,7 @@ def my_llms():
 @login_required
 def list_app():
     self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
-    weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
+    weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 from flask import request
-from api.settings import RetCode
+from api import settings
 from api.db import StatusEnum
 from api.db.services.dialog_service import DialogService
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -44,7 +44,7 @@ def create(tenant_id):
     kbs = KnowledgebaseService.get_by_ids(ids)
     embd_count = list(set([kb.embd_id for kb in kbs]))
     if len(embd_count) != 1:
-        return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
+        return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
     req["kb_ids"] = ids
     # llm
     llm = req.get("llm")
@@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
         if len(embd_count) != 1 :
             return get_result(
                 message='Datasets use different embedding models."',
-                code=RetCode.AUTHENTICATION_ERROR)
+                code=settings.RetCode.AUTHENTICATION_ERROR)
         req["kb_ids"] = ids
         llm = req.get("llm")
         if llm:
@@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import TenantLLMService, LLMService
 from api.db.services.user_service import TenantService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import (
     get_result,
@@ -255,7 +255,7 @@ def delete(tenant_id):
             File2DocumentService.delete_by_document_id(doc.id)
         if not KnowledgebaseService.delete_by_id(id):
             return get_error_data_result(message="Delete dataset error.(Database error)")
-    return get_result(code=RetCode.SUCCESS)
+    return get_result(code=settings.RetCode.SUCCESS)
 
 
 @manager.route("/datasets/<dataset_id>", methods=["PUT"])
@@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
             )
         if not KnowledgebaseService.update_by_id(kb.id, req):
             return get_error_data_result(message="Update dataset error.(Database error)")
-    return get_result(code=RetCode.SUCCESS)
+    return get_result(code=settings.RetCode.SUCCESS)
 
 
 @manager.route("/datasets", methods=["GET"])
@@ -18,7 +18,7 @@ from flask import request, jsonify
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler, kg_retrievaler, RetCode
+from api import settings
 from api.utils.api_utils import validate_request, build_error_result, apikey_required
@@ -37,14 +37,14 @@ def retrieval(tenant_id):
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
-        return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
+        return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
 
     if kb.tenant_id != tenant_id:
-        return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
+        return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
 
     embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
-    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+    retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
     ranks = retr.retrieval(
         question,
         embd_mdl,
@@ -72,6 +72,6 @@ def retrieval(tenant_id):
         if str(e).find("not_found") > 0:
             return build_error_result(
                 message='No chunk found! Check the chunk status please!',
-                code=RetCode.NOT_FOUND
+                code=settings.RetCode.NOT_FOUND
             )
-        return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
+        return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
| @@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc | |||
| from rag.nlp import rag_tokenizer | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.settings import kg_retrievaler | |||
| from api import settings | |||
| import hashlib | |||
| import re | |||
| from api.utils.api_utils import token_required | |||
| @@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.settings import RetCode, retrievaler | |||
| from api import settings | |||
| from api.utils.api_utils import construct_json_result, get_parser_config | |||
| from rag.nlp import search | |||
| from rag.utils import rmSpace | |||
| from api.settings import docStoreConn | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| import os | |||
| @@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id): | |||
| """ | |||
| if "file" not in request.files: | |||
| return get_error_data_result( | |||
| message="No file part!", code=RetCode.ARGUMENT_ERROR | |||
| message="No file part!", code=settings.RetCode.ARGUMENT_ERROR | |||
| ) | |||
| file_objs = request.files.getlist("file") | |||
| for file_obj in file_objs: | |||
| if file_obj.filename == "": | |||
| return get_result( | |||
| message="No file selected!", code=RetCode.ARGUMENT_ERROR | |||
| message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR | |||
| ) | |||
| # total size | |||
| total_size = 0 | |||
| @@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id): | |||
| if total_size > MAX_TOTAL_FILE_SIZE: | |||
| return get_result( | |||
| message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", | |||
| code=RetCode.ARGUMENT_ERROR, | |||
| code=settings.RetCode.ARGUMENT_ERROR, | |||
| ) | |||
| e, kb = KnowledgebaseService.get_by_id(dataset_id) | |||
| if not e: | |||
| raise LookupError(f"Can't find the dataset with ID {dataset_id}!") | |||
| err, files = FileService.upload_document(kb, file_objs, tenant_id) | |||
| if err: | |||
| return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) | |||
| return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR) | |||
| # rename key's name | |||
| renamed_doc_list = [] | |||
| for file in files: | |||
| @@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id): | |||
| if "name" in req and req["name"] != doc.name: | |||
| if ( | |||
| pathlib.Path(req["name"].lower()).suffix | |||
| != pathlib.Path(doc.name.lower()).suffix | |||
| pathlib.Path(req["name"].lower()).suffix | |||
| != pathlib.Path(doc.name.lower()).suffix | |||
| ): | |||
| return get_result( | |||
| message="The extension of file can't be changed", | |||
| code=RetCode.ARGUMENT_ERROR, | |||
| code=settings.RetCode.ARGUMENT_ERROR, | |||
| ) | |||
| for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): | |||
| if d.name == req["name"]: | |||
| @@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id): | |||
| ) | |||
| if not e: | |||
| return get_error_data_result(message="Document not found!") | |||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||
| return get_result() | |||
| @@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id): | |||
| file_stream = STORAGE_IMPL.get(doc_id, doc_location) | |||
| if not file_stream: | |||
| return construct_json_result( | |||
| message="This file is empty.", code=RetCode.DATA_ERROR | |||
| message="This file is empty.", code=settings.RetCode.DATA_ERROR | |||
| ) | |||
| file = BytesIO(file_stream) | |||
| # Use send_file with a proper filename and MIME type | |||
| @@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id): | |||
| errors += str(e) | |||
| if errors: | |||
| return get_result(message=errors, code=RetCode.SERVER_ERROR) | |||
| return get_result(message=errors, code=settings.RetCode.SERVER_ERROR) | |||
| return get_result() | |||
| @@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id): | |||
| info["chunk_num"] = 0 | |||
| info["token_num"] = 0 | |||
| DocumentService.update_by_id(id, info) | |||
| docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) | |||
| settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) | |||
| TaskService.filter_delete([Task.doc_id == id]) | |||
| e, doc = DocumentService.get_by_id(id) | |||
| doc = doc.to_dict() | |||
| @@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id): | |||
| ) | |||
| info = {"run": "2", "progress": 0, "chunk_num": 0} | |||
| DocumentService.update_by_id(id, info) | |||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) | |||
| return get_result() | |||
| @@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id): | |||
| res = {"total": 0, "chunks": [], "doc": renamed_doc} | |||
| origin_chunks = [] | |||
| if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): | |||
| sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) | |||
| if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): | |||
| sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, | |||
| highlight=True) | |||
| res["total"] = sres.total | |||
| sign = 0 | |||
| for id in sres.ids: | |||
| @@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id): | |||
| v, c = embd_mdl.encode([doc.name, req["content"]]) | |||
| v = 0.1 * v[0] + 0.9 * v[1] | |||
| d["q_%d_vec" % len(v)] = v.tolist() | |||
| docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) | |||
| settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) | |||
| DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) | |||
| # rename keys | |||
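For context on the hunk above: the vector stored for a new chunk is a fixed 0.1/0.9 blend of the document-title embedding and the chunk-content embedding, and the field name encodes the dimensionality. A tiny worked example:

```python
import numpy as np

# Stand-ins for embd_mdl.encode([doc.name, req["content"]]): two unit vectors.
v = np.array([[1.0, 0.0], [0.0, 1.0]])
v = 0.1 * v[0] + 0.9 * v[1]           # title weighted 0.1, content weighted 0.9
d = {"q_%d_vec" % len(v): v.tolist()}
print(d)                              # {'q_2_vec': [0.1, 0.9]}
```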
| @@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): | |||
| condition = {"doc_id": document_id} | |||
| if "chunk_ids" in req: | |||
| condition["id"] = req["chunk_ids"] | |||
| chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) | |||
| chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) | |||
| if chunk_number != 0: | |||
| DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) | |||
| if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): | |||
| @@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): | |||
| schema: | |||
| type: object | |||
| """ | |||
| chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) | |||
| chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) | |||
| if chunk is None: | |||
| return get_error_data_result(f"Can't find this chunk {chunk_id}") | |||
| if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): | |||
| @@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): | |||
| v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) | |||
| v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | |||
| d["q_%d_vec" % len(v)] = v.tolist() | |||
| docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) | |||
| settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) | |||
| return get_result() | |||
| @@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id): | |||
| if len(embd_nms) != 1: | |||
| return get_result( | |||
| message='Datasets use different embedding models.', | |||
| code=RetCode.AUTHENTICATION_ERROR, | |||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||
| ) | |||
| if "question" not in req: | |||
| return get_error_data_result("`question` is required.") | |||
| @@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id): | |||
| chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) | |||
| question += keyword_extraction(chat_mdl, question) | |||
| retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler | |||
| retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler | |||
| ranks = retr.retrieval( | |||
| question, | |||
| embd_mdl, | |||
| @@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id): | |||
| if str(e).find("not_found") > 0: | |||
| return get_result( | |||
| message="No chunk found! Check the chunk status please!", | |||
| code=RetCode.DATA_ERROR, | |||
| code=settings.RetCode.DATA_ERROR, | |||
| ) | |||
| return server_error_response(e) | |||
| return server_error_response(e) | |||
| @@ -22,7 +22,7 @@ from api.db.db_models import APIToken | |||
| from api.db.services.api_service import APITokenService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.settings import DATABASE_TYPE | |||
| from api import settings | |||
| from api.utils import current_timestamp, datetime_format | |||
| from api.utils.api_utils import ( | |||
| get_json_result, | |||
| @@ -31,7 +31,6 @@ from api.utils.api_utils import ( | |||
| generate_confirmation_token, | |||
| ) | |||
| from api.versions import get_ragflow_version | |||
| from api.settings import docStoreConn | |||
| from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE | |||
| from timeit import default_timer as timer | |||
| @@ -98,7 +97,7 @@ def status(): | |||
| res = {} | |||
| st = timer() | |||
| try: | |||
| res["doc_store"] = docStoreConn.health() | |||
| res["doc_store"] = settings.docStoreConn.health() | |||
| res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) | |||
| except Exception as e: | |||
| res["doc_store"] = { | |||
| @@ -128,13 +127,13 @@ def status(): | |||
| try: | |||
| KnowledgebaseService.get_by_id("x") | |||
| res["database"] = { | |||
| "database": DATABASE_TYPE.lower(), | |||
| "database": settings.DATABASE_TYPE.lower(), | |||
| "status": "green", | |||
| "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | |||
| } | |||
| except Exception as e: | |||
| res["database"] = { | |||
| "database": DATABASE_TYPE.lower(), | |||
| "database": settings.DATABASE_TYPE.lower(), | |||
| "status": "red", | |||
| "elapsed": "{:.1f}".format((timer() - st) * 1000.0), | |||
| "error": str(e), | |||
| @@ -38,20 +38,7 @@ from api.utils import ( | |||
| datetime_format, | |||
| ) | |||
| from api.db import UserTenantRole, FileType | |||
| from api.settings import ( | |||
| RetCode, | |||
| GITHUB_OAUTH, | |||
| FEISHU_OAUTH, | |||
| CHAT_MDL, | |||
| EMBEDDING_MDL, | |||
| ASR_MDL, | |||
| IMAGE2TEXT_MDL, | |||
| PARSERS, | |||
| API_KEY, | |||
| LLM_FACTORY, | |||
| LLM_BASE_URL, | |||
| RERANK_MDL, | |||
| ) | |||
| from api import settings | |||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from api.db.services.file_service import FileService | |||
| from api.utils.api_utils import get_json_result, construct_response | |||
| @@ -90,7 +77,7 @@ def login(): | |||
| """ | |||
| if not request.json: | |||
| return get_json_result( | |||
| data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" | |||
| data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" | |||
| ) | |||
| email = request.json.get("email", "") | |||
| @@ -98,7 +85,7 @@ def login(): | |||
| if not users: | |||
| return get_json_result( | |||
| data=False, | |||
| code=RetCode.AUTHENTICATION_ERROR, | |||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||
| message=f"Email: {email} is not registered!", | |||
| ) | |||
| @@ -107,7 +94,7 @@ def login(): | |||
| password = decrypt(password) | |||
| except BaseException: | |||
| return get_json_result( | |||
| data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password" | |||
| data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password" | |||
| ) | |||
| user = UserService.query_user(email, password) | |||
| @@ -123,7 +110,7 @@ def login(): | |||
| else: | |||
| return get_json_result( | |||
| data=False, | |||
| code=RetCode.AUTHENTICATION_ERROR, | |||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||
| message="Email and password do not match!", | |||
| ) | |||
| @@ -150,10 +137,10 @@ def github_callback(): | |||
| import requests | |||
| res = requests.post( | |||
| GITHUB_OAUTH.get("url"), | |||
| settings.GITHUB_OAUTH.get("url"), | |||
| data={ | |||
| "client_id": GITHUB_OAUTH.get("client_id"), | |||
| "client_secret": GITHUB_OAUTH.get("secret_key"), | |||
| "client_id": settings.GITHUB_OAUTH.get("client_id"), | |||
| "client_secret": settings.GITHUB_OAUTH.get("secret_key"), | |||
| "code": request.args.get("code"), | |||
| }, | |||
| headers={"Accept": "application/json"}, | |||
| @@ -235,11 +222,11 @@ def feishu_callback(): | |||
| import requests | |||
| app_access_token_res = requests.post( | |||
| FEISHU_OAUTH.get("app_access_token_url"), | |||
| settings.FEISHU_OAUTH.get("app_access_token_url"), | |||
| data=json.dumps( | |||
| { | |||
| "app_id": FEISHU_OAUTH.get("app_id"), | |||
| "app_secret": FEISHU_OAUTH.get("app_secret"), | |||
| "app_id": settings.FEISHU_OAUTH.get("app_id"), | |||
| "app_secret": settings.FEISHU_OAUTH.get("app_secret"), | |||
| } | |||
| ), | |||
| headers={"Content-Type": "application/json; charset=utf-8"}, | |||
| @@ -249,10 +236,10 @@ def feishu_callback(): | |||
| return redirect("/?error=%s" % app_access_token_res) | |||
| res = requests.post( | |||
| FEISHU_OAUTH.get("user_access_token_url"), | |||
| settings.FEISHU_OAUTH.get("user_access_token_url"), | |||
| data=json.dumps( | |||
| { | |||
| "grant_type": FEISHU_OAUTH.get("grant_type"), | |||
| "grant_type": settings.FEISHU_OAUTH.get("grant_type"), | |||
| "code": request.args.get("code"), | |||
| } | |||
| ), | |||
| @@ -405,11 +392,11 @@ def setting_user(): | |||
| if request_data.get("password"): | |||
| new_password = request_data.get("new_password") | |||
| if not check_password_hash( | |||
| current_user.password, decrypt(request_data["password"]) | |||
| current_user.password, decrypt(request_data["password"]) | |||
| ): | |||
| return get_json_result( | |||
| data=False, | |||
| code=RetCode.AUTHENTICATION_ERROR, | |||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||
| message="Password error!", | |||
| ) | |||
| @@ -438,7 +425,7 @@ def setting_user(): | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| return get_json_result( | |||
| data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR | |||
| data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR | |||
| ) | |||
| @@ -497,12 +484,12 @@ def user_register(user_id, user): | |||
| tenant = { | |||
| "id": user_id, | |||
| "name": user["nickname"] + "‘s Kingdom", | |||
| "llm_id": CHAT_MDL, | |||
| "embd_id": EMBEDDING_MDL, | |||
| "asr_id": ASR_MDL, | |||
| "parser_ids": PARSERS, | |||
| "img2txt_id": IMAGE2TEXT_MDL, | |||
| "rerank_id": RERANK_MDL, | |||
| "llm_id": settings.CHAT_MDL, | |||
| "embd_id": settings.EMBEDDING_MDL, | |||
| "asr_id": settings.ASR_MDL, | |||
| "parser_ids": settings.PARSERS, | |||
| "img2txt_id": settings.IMAGE2TEXT_MDL, | |||
| "rerank_id": settings.RERANK_MDL, | |||
| } | |||
| usr_tenant = { | |||
| "tenant_id": user_id, | |||
| @@ -522,15 +509,15 @@ def user_register(user_id, user): | |||
| "location": "", | |||
| } | |||
| tenant_llm = [] | |||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||
| for llm in LLMService.query(fid=settings.LLM_FACTORY): | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": LLM_FACTORY, | |||
| "llm_factory": settings.LLM_FACTORY, | |||
| "llm_name": llm.llm_name, | |||
| "model_type": llm.model_type, | |||
| "api_key": API_KEY, | |||
| "api_base": LLM_BASE_URL, | |||
| "api_key": settings.API_KEY, | |||
| "api_base": settings.LLM_BASE_URL, | |||
| } | |||
| ) | |||
| @@ -582,7 +569,7 @@ def user_add(): | |||
| return get_json_result( | |||
| data=False, | |||
| message=f"Invalid email address: {email_address}!", | |||
| code=RetCode.OPERATING_ERROR, | |||
| code=settings.RetCode.OPERATING_ERROR, | |||
| ) | |||
| # Check if the email address is already used | |||
| @@ -590,7 +577,7 @@ def user_add(): | |||
| return get_json_result( | |||
| data=False, | |||
| message=f"Email: {email_address} has already registered!", | |||
| code=RetCode.OPERATING_ERROR, | |||
| code=settings.RetCode.OPERATING_ERROR, | |||
| ) | |||
| # Construct user info data | |||
| @@ -625,7 +612,7 @@ def user_add(): | |||
| return get_json_result( | |||
| data=False, | |||
| message=f"User registration failure, error: {str(e)}", | |||
| code=RetCode.EXCEPTION_ERROR, | |||
| code=settings.RetCode.EXCEPTION_ERROR, | |||
| ) | |||
| @@ -31,7 +31,7 @@ from peewee import ( | |||
| ) | |||
| from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase | |||
| from api.db import SerializedType, ParserType | |||
| from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE | |||
| from api import settings | |||
| from api import utils | |||
| def singleton(cls, *args, **kw): | |||
| @@ -62,7 +62,7 @@ class TextFieldType(Enum): | |||
| class LongTextField(TextField): | |||
| field_type = TextFieldType[DATABASE_TYPE.upper()].value | |||
| field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value | |||
| class JSONField(LongTextField): | |||
| @@ -282,9 +282,9 @@ class DatabaseMigrator(Enum): | |||
| @singleton | |||
| class BaseDataBase: | |||
| def __init__(self): | |||
| database_config = DATABASE.copy() | |||
| database_config = settings.DATABASE.copy() | |||
| db_name = database_config.pop("name") | |||
| self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config) | |||
| self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config) | |||
| logging.info('init database on cluster mode successfully') | |||
| class PostgresDatabaseLock: | |||
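Since `settings.DATABASE` is only filled in by `init_settings()`, the connection pool above can no longer be built as an import-time side effect, which is exactly point 1 of this PR: it is created on first construction, and the `singleton` decorator defined earlier in this module guarantees that happens once per process. A reduced sketch of the idea (the class body is a stand-in):

```python
def singleton(cls, *args, **kw):
    instances = {}

    def _instance():
        if cls not in instances:
            instances[cls] = cls(*args, **kw)
        return instances[cls]

    return _instance

@singleton
class FakeDataBase:
    def __init__(self):
        print("opening connection pool...")  # runs once, on first use

a = FakeDataBase()   # prints the message
b = FakeDataBase()   # reuses the instance; no second connection
assert a is b
```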
| @@ -385,7 +385,7 @@ class DatabaseLock(Enum): | |||
| DB = BaseDataBase().database_connection | |||
| DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value | |||
| DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value | |||
| def close_connection(): | |||
| @@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin): | |||
| return self.email | |||
| def get_id(self): | |||
| jwt = Serializer(secret_key=SECRET_KEY) | |||
| jwt = Serializer(secret_key=settings.SECRET_KEY) | |||
| return jwt.dumps(str(self.access_token)) | |||
| class Meta: | |||
| @@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel): | |||
| def migrate_db(): | |||
| with DB.transaction(): | |||
| migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB) | |||
| migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) | |||
| try: | |||
| migrate( | |||
| migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", | |||
| @@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL | |||
| from api import settings | |||
| from api.utils.file_utils import get_project_base_directory | |||
| @@ -51,11 +51,11 @@ def init_superuser(): | |||
| tenant = { | |||
| "id": user_info["id"], | |||
| "name": user_info["nickname"] + "‘s Kingdom", | |||
| "llm_id": CHAT_MDL, | |||
| "embd_id": EMBEDDING_MDL, | |||
| "asr_id": ASR_MDL, | |||
| "parser_ids": PARSERS, | |||
| "img2txt_id": IMAGE2TEXT_MDL | |||
| "llm_id": settings.CHAT_MDL, | |||
| "embd_id": settings.EMBEDDING_MDL, | |||
| "asr_id": settings.ASR_MDL, | |||
| "parser_ids": settings.PARSERS, | |||
| "img2txt_id": settings.IMAGE2TEXT_MDL | |||
| } | |||
| usr_tenant = { | |||
| "tenant_id": user_info["id"], | |||
| @@ -64,10 +64,11 @@ def init_superuser(): | |||
| "role": UserTenantRole.OWNER | |||
| } | |||
| tenant_llm = [] | |||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||
| for llm in LLMService.query(fid=settings.LLM_FACTORY): | |||
| tenant_llm.append( | |||
| {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, | |||
| "api_key": API_KEY, "api_base": LLM_BASE_URL}) | |||
| {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name, | |||
| "model_type": llm.model_type, | |||
| "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL}) | |||
| if not UserService.save(**user_info): | |||
| logging.error("can't init admin.") | |||
| @@ -80,7 +81,7 @@ def init_superuser(): | |||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | |||
| msg = chat_mdl.chat(system="", history=[ | |||
| {"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| {"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| if msg.find("ERROR: ") == 0: | |||
| logging.error( | |||
| "'{}' dosen't work. {}".format( | |||
| @@ -179,7 +180,7 @@ def init_web_data(): | |||
| start_time = time.time() | |||
| init_llm_factory() | |||
| #if not UserService.get_all().count(): | |||
| # if not UserService.get_all().count(): | |||
| # init_superuser() | |||
| add_graph_templates() | |||
| @@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle | |||
| from api.settings import retrievaler, kg_retrievaler | |||
| from api import settings | |||
| from rag.app.resume import forbidden_select_fields4resume | |||
| from rag.nlp.search import index_name | |||
| from rag.utils import rmSpace, num_tokens_from_string, encoder | |||
| @@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs): | |||
| return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} | |||
| is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | |||
| retr = retrievaler if not is_kg else kg_retrievaler | |||
| retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | |||
| questions = [m["content"] for m in messages if m["role"] == "user"][-3:] | |||
| attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None | |||
| @@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | |||
| logging.debug(f"{question} get SQL(refined): {sql}") | |||
| tried_times += 1 | |||
| return retrievaler.sql_retrieval(sql, format="json"), sql | |||
| return settings.retrievaler.sql_retrieval(sql, format="json"), sql | |||
| tbl, sql = get_table() | |||
| if tbl is None: | |||
| @@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id): | |||
| embd_nms = list(set([kb.embd_id for kb in kbs])) | |||
| is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | |||
| retr = retrievaler if not is_kg else kg_retrievaler | |||
| retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | |||
| embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | |||
| @@ -26,7 +26,7 @@ from io import BytesIO | |||
| from peewee import fn | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.settings import docStoreConn | |||
| from api import settings | |||
| from api.utils import current_timestamp, get_format_time, get_uuid | |||
| from graphrag.mind_map_extractor import MindMapExtractor | |||
| from rag.settings import SVR_QUEUE_NAME | |||
| @@ -108,7 +108,7 @@ class DocumentService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def remove_document(cls, doc, tenant_id): | |||
| docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) | |||
| settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) | |||
| cls.clear_chunk_num(doc.id) | |||
| return cls.delete_by_id(doc.id) | |||
| @@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): | |||
| d["q_%d_vec" % len(v)] = v | |||
| for b in range(0, len(cks), es_bulk_size): | |||
| if try_create_idx: | |||
| if not docStoreConn.indexExist(idxnm, kb_id): | |||
| docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) | |||
| if not settings.docStoreConn.indexExist(idxnm, kb_id): | |||
| settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) | |||
| try_create_idx = False | |||
| docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) | |||
| settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) | |||
| DocumentService.increment_chunk_num( | |||
| doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) | |||
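`doc_upload_and_parse` pairs lazy index creation with fixed-size bulk inserts: the index is created at most once, sized to the first embedding batch, and chunks are written in slices of `es_bulk_size`. A compact runnable sketch with an in-memory stand-in for `settings.docStoreConn`:

```python
class FakeStore:  # stand-in for settings.docStoreConn
    def __init__(self):
        self.indexes, self.rows = set(), []
    def indexExist(self, idxnm, kb_id):
        return (idxnm, kb_id) in self.indexes
    def createIdx(self, idxnm, kb_id, vector_size):
        self.indexes.add((idxnm, kb_id))
    def insert(self, docs, idxnm, kb_id):
        self.rows.extend(docs)

store, cks, es_bulk_size = FakeStore(), list(range(10)), 4
try_create_idx = True
for b in range(0, len(cks), es_bulk_size):
    if try_create_idx:
        if not store.indexExist("idx", "kb1"):
            store.createIdx("idx", "kb1", vector_size=1024)
        try_create_idx = False
    store.insert(cks[b:b + es_bulk_size], "idx", "kb1")
assert len(store.rows) == 10
```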
| @@ -33,12 +33,10 @@ import traceback | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from werkzeug.serving import run_simple | |||
| from api import settings | |||
| from api.apps import app | |||
| from api.db.runtime_config import RuntimeConfig | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import ( | |||
| HOST, HTTP_PORT | |||
| ) | |||
| from api import utils | |||
| from api.db.db_models import init_database_tables as init_web_db | |||
| @@ -72,6 +70,7 @@ if __name__ == '__main__': | |||
| f'project base: {utils.file_utils.get_project_base_directory()}' | |||
| ) | |||
| show_configs() | |||
| settings.init_settings() | |||
| # init db | |||
| init_web_db() | |||
| @@ -96,7 +95,7 @@ if __name__ == '__main__': | |||
| logging.info("run on debug mode") | |||
| RuntimeConfig.init_env() | |||
| RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) | |||
| RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT) | |||
| thread = ThreadPoolExecutor(max_workers=1) | |||
| thread.submit(update_progress) | |||
| @@ -105,8 +104,8 @@ if __name__ == '__main__': | |||
| try: | |||
| logging.info("RAGFlow HTTP server start...") | |||
| run_simple( | |||
| hostname=HOST, | |||
| port=HTTP_PORT, | |||
| hostname=settings.HOST_IP, | |||
| port=settings.HOST_PORT, | |||
| application=app, | |||
| threaded=True, | |||
| use_reloader=RuntimeConfig.DEBUG, | |||
| @@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0")) | |||
| REQUEST_WAIT_SEC = 2 | |||
| REQUEST_MAX_WAIT_SEC = 300 | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||
| LLM_BASE_URL = LLM.get("base_url") | |||
| CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = "" | |||
| if not LIGHTEN: | |||
| default_llm = { | |||
| "Tongyi-Qianwen": { | |||
| "chat_model": "qwen-plus", | |||
| "embedding_model": "text-embedding-v2", | |||
| "image2text_model": "qwen-vl-max", | |||
| "asr_model": "paraformer-realtime-8k-v1", | |||
| }, | |||
| "OpenAI": { | |||
| "chat_model": "gpt-3.5-turbo", | |||
| "embedding_model": "text-embedding-ada-002", | |||
| "image2text_model": "gpt-4-vision-preview", | |||
| "asr_model": "whisper-1", | |||
| }, | |||
| "Azure-OpenAI": { | |||
| "chat_model": "gpt-35-turbo", | |||
| "embedding_model": "text-embedding-ada-002", | |||
| "image2text_model": "gpt-4-vision-preview", | |||
| "asr_model": "whisper-1", | |||
| }, | |||
| "ZHIPU-AI": { | |||
| "chat_model": "glm-3-turbo", | |||
| "embedding_model": "embedding-2", | |||
| "image2text_model": "glm-4v", | |||
| "asr_model": "", | |||
| }, | |||
| "Ollama": { | |||
| "chat_model": "qwen-14B-chat", | |||
| "embedding_model": "flag-embedding", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "Moonshot": { | |||
| "chat_model": "moonshot-v1-8k", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "DeepSeek": { | |||
| "chat_model": "deepseek-chat", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "VolcEngine": { | |||
| "chat_model": "", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "BAAI": { | |||
| "chat_model": "", | |||
| "embedding_model": "BAAI/bge-large-zh-v1.5", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| "rerank_model": "BAAI/bge-reranker-v2-m3", | |||
| } | |||
| } | |||
| if LLM_FACTORY: | |||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" | |||
| ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" | |||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" | |||
| EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" | |||
| RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" | |||
| API_KEY = LLM.get("api_key", "") | |||
| PARSERS = LLM.get( | |||
| "parsers", | |||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") | |||
| HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| SECRET_KEY = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, | |||
| {}).get("secret_key", str(date.today())) | |||
| LLM = None | |||
| LLM_FACTORY = None | |||
| LLM_BASE_URL = None | |||
| CHAT_MDL = "" | |||
| EMBEDDING_MDL = "" | |||
| RERANK_MDL = "" | |||
| ASR_MDL = "" | |||
| IMAGE2TEXT_MDL = "" | |||
| API_KEY = None | |||
| PARSERS = None | |||
| HOST_IP = None | |||
| HOST_PORT = None | |||
| SECRET_KEY = None | |||
| DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') | |||
| DATABASE = decrypt_database_config(name=DATABASE_TYPE) | |||
| # authentication | |||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||
| AUTHENTICATION_CONF = None | |||
| # client | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||
| "client", {}).get( | |||
| "switch", False) | |||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||
| FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") | |||
| DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") | |||
| if DOC_ENGINE == "elasticsearch": | |||
| docStoreConn = rag.utils.es_conn.ESConnection() | |||
| elif DOC_ENGINE == "infinity": | |||
| docStoreConn = rag.utils.infinity_conn.InfinityConnection() | |||
| else: | |||
| raise Exception(f"Not supported doc engine: {DOC_ENGINE}") | |||
| retrievaler = search.Dealer(docStoreConn) | |||
| kg_retrievaler = kg_search.KGSearch(docStoreConn) | |||
| CLIENT_AUTHENTICATION = None | |||
| HTTP_APP_KEY = None | |||
| GITHUB_OAUTH = None | |||
| FEISHU_OAUTH = None | |||
| DOC_ENGINE = None | |||
| docStoreConn = None | |||
| retrievaler = None | |||
| kg_retrievaler = None | |||
| def init_settings(): | |||
| global LLM, LLM_FACTORY, LLM_BASE_URL | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||
| LLM_BASE_URL = LLM.get("base_url") | |||
| global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL | |||
| if not LIGHTEN: | |||
| default_llm = { | |||
| "Tongyi-Qianwen": { | |||
| "chat_model": "qwen-plus", | |||
| "embedding_model": "text-embedding-v2", | |||
| "image2text_model": "qwen-vl-max", | |||
| "asr_model": "paraformer-realtime-8k-v1", | |||
| }, | |||
| "OpenAI": { | |||
| "chat_model": "gpt-3.5-turbo", | |||
| "embedding_model": "text-embedding-ada-002", | |||
| "image2text_model": "gpt-4-vision-preview", | |||
| "asr_model": "whisper-1", | |||
| }, | |||
| "Azure-OpenAI": { | |||
| "chat_model": "gpt-35-turbo", | |||
| "embedding_model": "text-embedding-ada-002", | |||
| "image2text_model": "gpt-4-vision-preview", | |||
| "asr_model": "whisper-1", | |||
| }, | |||
| "ZHIPU-AI": { | |||
| "chat_model": "glm-3-turbo", | |||
| "embedding_model": "embedding-2", | |||
| "image2text_model": "glm-4v", | |||
| "asr_model": "", | |||
| }, | |||
| "Ollama": { | |||
| "chat_model": "qwen-14B-chat", | |||
| "embedding_model": "flag-embedding", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "Moonshot": { | |||
| "chat_model": "moonshot-v1-8k", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "DeepSeek": { | |||
| "chat_model": "deepseek-chat", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "VolcEngine": { | |||
| "chat_model": "", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "BAAI": { | |||
| "chat_model": "", | |||
| "embedding_model": "BAAI/bge-large-zh-v1.5", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| "rerank_model": "BAAI/bge-reranker-v2-m3", | |||
| } | |||
| } | |||
| if LLM_FACTORY: | |||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" | |||
| ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" | |||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" | |||
| EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" | |||
| RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" | |||
| global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY | |||
| API_KEY = LLM.get("api_key", "") | |||
| PARSERS = LLM.get( | |||
| "parsers", | |||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") | |||
| HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| SECRET_KEY = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, | |||
| {}).get("secret_key", str(date.today())) | |||
| global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH | |||
| # authentication | |||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||
| # client | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||
| "client", {}).get( | |||
| "switch", False) | |||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||
| FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") | |||
| global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler | |||
| DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") | |||
| if DOC_ENGINE == "elasticsearch": | |||
| docStoreConn = rag.utils.es_conn.ESConnection() | |||
| elif DOC_ENGINE == "infinity": | |||
| docStoreConn = rag.utils.infinity_conn.InfinityConnection() | |||
| else: | |||
| raise Exception(f"Not supported doc engine: {DOC_ENGINE}") | |||
| retrievaler = search.Dealer(docStoreConn) | |||
| kg_retrievaler = kg_search.KGSearch(docStoreConn) | |||
| def get_host_ip(): | |||
| global HOST_IP | |||
| return HOST_IP | |||
| def get_host_port(): | |||
| global HOST_PORT | |||
| return HOST_PORT | |||
| class CustomEnum(Enum): | |||
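Net effect of this hunk: `api/settings.py` now only declares placeholders at import time, and everything that reads config or has a side effect, from the default-LLM table to the doc-store connection, moves into `init_settings()`. One detail worth a worked example is the default model naming: chat, ASR, and image2text ids get a `@<factory>` suffix from the configured factory, while embedding and rerank are always pinned to the BAAI defaults:

```python
# Worked example of the "<model>@<factory>" naming in init_settings() above.
LLM_FACTORY = "OpenAI"
default_llm = {
    "OpenAI": {"chat_model": "gpt-3.5-turbo", "asr_model": "whisper-1"},
    "BAAI": {"embedding_model": "BAAI/bge-large-zh-v1.5",
             "rerank_model": "BAAI/bge-reranker-v2-m3"},
}

CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"

print(CHAT_MDL)       # gpt-3.5-turbo@OpenAI
print(EMBEDDING_MDL)  # BAAI/bge-large-zh-v1.5@BAAI
```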
| @@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer | |||
| from werkzeug.http import HTTP_STATUS_CODES | |||
| from api.db.db_models import APIToken | |||
| from api.settings import ( | |||
| REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, | |||
| CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY | |||
| ) | |||
| from api.settings import RetCode | |||
| from api import settings | |||
| from api.utils import CustomJSONEncoder, get_uuid | |||
| from api.utils import json_dumps | |||
| @@ -59,13 +57,13 @@ def request(**kwargs): | |||
| {}).items()} | |||
| prepped = requests.Request(**kwargs).prepare() | |||
| if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: | |||
| if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY: | |||
| timestamp = str(round(time() * 1000)) | |||
| nonce = str(uuid1()) | |||
| signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([ | |||
| signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([ | |||
| timestamp.encode('ascii'), | |||
| nonce.encode('ascii'), | |||
| HTTP_APP_KEY.encode('ascii'), | |||
| settings.HTTP_APP_KEY.encode('ascii'), | |||
| prepped.path_url.encode('ascii'), | |||
| prepped.body if kwargs.get('json') else b'', | |||
| urlencode( | |||
| @@ -79,7 +77,7 @@ def request(**kwargs): | |||
| prepped.headers.update({ | |||
| 'TIMESTAMP': timestamp, | |||
| 'NONCE': nonce, | |||
| 'APP-KEY': HTTP_APP_KEY, | |||
| 'APP-KEY': settings.HTTP_APP_KEY, | |||
| 'SIGNATURE': signature, | |||
| }) | |||
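The signing block above HMACs a newline-joined sequence of request fields with `settings.SECRET_KEY`; the verifying side has to rebuild the exact same byte string, so field order is part of the contract. A hedged sketch of the scheme: the real call also folds in the url-encoded form data, which is truncated in this diff, and the hash/encoding tail of the call falls outside the hunk, so HMAC-SHA1 over the raw digest is an assumption here:

```python
from base64 import b64encode
from hmac import HMAC

def sign(secret_key, timestamp, nonce, app_key, path_url, body=b""):
    # Field order must match the client exactly; any drift breaks verification.
    msg = b"\n".join([
        timestamp.encode("ascii"),
        nonce.encode("ascii"),
        app_key.encode("ascii"),
        path_url.encode("ascii"),
        body,  # assumption: raw request body, empty when not a JSON request
    ])
    return b64encode(HMAC(secret_key.encode("ascii"), msg, "sha1").digest()).decode("ascii")

print(sign("secret", "1700000000000", "nonce-1", "app-key", "/v1/system/status"))
```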
| @@ -89,7 +87,7 @@ def request(**kwargs): | |||
| def get_exponential_backoff_interval(retries, full_jitter=False): | |||
| """Calculate the exponential backoff wait time.""" | |||
| # Will be zero if factor equals 0 | |||
| countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries)) | |||
| countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries)) | |||
| # Full jitter according to | |||
| # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ | |||
| if full_jitter: | |||
| @@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False): | |||
| return max(0, countdown) | |||
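`get_exponential_backoff_interval` doubles the base wait per retry and caps it at `REQUEST_MAX_WAIT_SEC`; with `full_jitter` the wait is drawn at random below that cap, per the AWS article cited in the code. The jitter line itself is elided between the two hunks, so the uniform draw below is an assumption about that line, not a quote of it:

```python
import random

REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC = 2, 300

def backoff(retries, full_jitter=False):
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
    if full_jitter:
        countdown = random.uniform(0, countdown)  # assumed jitter implementation
    return max(0, countdown)

print([backoff(r) for r in range(9)])
# [2, 4, 8, 16, 32, 64, 128, 256, 300] -- capped at REQUEST_MAX_WAIT_SEC
```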
| def get_data_error_result(code=RetCode.DATA_ERROR, | |||
| def get_data_error_result(code=settings.RetCode.DATA_ERROR, | |||
| message='Sorry! Data missing!'): | |||
| import re | |||
| result_dict = { | |||
| @@ -126,8 +124,8 @@ def server_error_response(e): | |||
| pass | |||
| if len(e.args) > 1: | |||
| return get_json_result( | |||
| code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||
| return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) | |||
| code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||
| return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) | |||
| def error_response(response_code, message=None): | |||
| @@ -168,7 +166,7 @@ def validate_request(*args, **kwargs): | |||
| error_string += "required argument values: {}".format( | |||
| ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) | |||
| return get_json_result( | |||
| code=RetCode.ARGUMENT_ERROR, message=error_string) | |||
| code=settings.RetCode.ARGUMENT_ERROR, message=error_string) | |||
| return func(*_args, **_kwargs) | |||
| return decorated_function | |||
| @@ -193,7 +191,7 @@ def send_file_in_mem(data, filename): | |||
| return send_file(f, as_attachment=True, attachment_filename=filename) | |||
| def get_json_result(code=RetCode.SUCCESS, message='success', data=None): | |||
| def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): | |||
| response = {"code": code, "message": message, "data": data} | |||
| return jsonify(response) | |||
| @@ -204,7 +202,7 @@ def apikey_required(func): | |||
| objs = APIToken.query(token=token) | |||
| if not objs: | |||
| return build_error_result( | |||
| message='API-KEY is invalid!', code=RetCode.FORBIDDEN | |||
| message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN | |||
| ) | |||
| kwargs['tenant_id'] = objs[0].tenant_id | |||
| return func(*args, **kwargs) | |||
| @@ -212,14 +210,14 @@ def apikey_required(func): | |||
| return decorated_function | |||
| def build_error_result(code=RetCode.FORBIDDEN, message='success'): | |||
| def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'): | |||
| response = {"code": code, "message": message} | |||
| response = jsonify(response) | |||
| response.status_code = code | |||
| return response | |||
| def construct_response(code=RetCode.SUCCESS, | |||
| def construct_response(code=settings.RetCode.SUCCESS, | |||
| message='success', data=None, auth=None): | |||
| result_dict = {"code": code, "message": message, "data": data} | |||
| response_dict = {} | |||
| @@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS, | |||
| return response | |||
| def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): | |||
| def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'): | |||
| import re | |||
| result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} | |||
| response = {} | |||
| @@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): | |||
| return jsonify(response) | |||
| def construct_json_result(code=RetCode.SUCCESS, message='success', data=None): | |||
| def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): | |||
| if data is None: | |||
| return jsonify({"code": code, "message": message}) | |||
| else: | |||
| @@ -262,12 +260,12 @@ def construct_error_response(e): | |||
| logging.exception(e) | |||
| try: | |||
| if e.code == 401: | |||
| return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) | |||
| return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e)) | |||
| except BaseException: | |||
| pass | |||
| if len(e.args) > 1: | |||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) | |||
| return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||
| return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) | |||
| def token_required(func): | |||
| @@ -280,7 +278,7 @@ def token_required(func): | |||
| objs = APIToken.query(token=token) | |||
| if not objs: | |||
| return get_json_result( | |||
| data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR | |||
| data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR | |||
| ) | |||
| kwargs['tenant_id'] = objs[0].tenant_id | |||
| return func(*args, **kwargs) | |||
| @@ -288,7 +286,7 @@ def token_required(func): | |||
| return decorated_function | |||
| def get_result(code=RetCode.SUCCESS, message="", data=None): | |||
| def get_result(code=settings.RetCode.SUCCESS, message="", data=None): | |||
| if code == 0: | |||
| if data is not None: | |||
| response = {"code": code, "data": data} | |||
| @@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None): | |||
| return jsonify(response) | |||
| def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR, | |||
| def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR, | |||
| ): | |||
| import re | |||
| result_dict = { | |||
| @@ -24,7 +24,7 @@ import numpy as np | |||
| from timeit import default_timer as timer | |||
| from pypdf import PdfReader as pdf2_read | |||
| from api.settings import LIGHTEN | |||
| from api import settings | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer | |||
| from rag.nlp import rag_tokenizer | |||
| @@ -41,7 +41,7 @@ class RAGFlowPdfParser: | |||
| self.tbl_det = TableStructureRecognizer() | |||
| self.updown_cnt_mdl = xgb.Booster() | |||
| if not LIGHTEN: | |||
| if not settings.LIGHTEN: | |||
| try: | |||
| import torch | |||
| if torch.cuda.is_available(): | |||
| @@ -252,13 +252,13 @@ if __name__ == "__main__": | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.settings import retrievaler | |||
| from api import settings | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | |||
| ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | |||
| docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] | |||
| docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] | |||
| info = { | |||
| "input_text": docs, | |||
| "entity_specs": "organization, person", | |||
| @@ -30,14 +30,14 @@ if __name__ == "__main__": | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.settings import retrievaler | |||
| from api import settings | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) | |||
| ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) | |||
| docs = [d["content_with_weight"] for d in | |||
| retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] | |||
| settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] | |||
| graph = ex(docs) | |||
| er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) | |||
| @@ -23,7 +23,7 @@ from collections import defaultdict | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.settings import retrievaler, docStoreConn | |||
| from api import settings | |||
| from api.utils import get_uuid | |||
| from rag.nlp import tokenize, search | |||
| from ranx import evaluate | |||
| @@ -52,7 +52,7 @@ class Benchmark: | |||
| run = defaultdict(dict) | |||
| query_list = list(qrels.keys()) | |||
| for query in query_list: | |||
| ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, | |||
| ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, | |||
| 0.0, self.vector_similarity_weight) | |||
| if len(ranks["chunks"]) == 0: | |||
| print(f"deleted query: {query}") | |||
| @@ -81,9 +81,9 @@ class Benchmark: | |||
| def init_index(self, vector_size: int): | |||
| if self.initialized_index: | |||
| return | |||
| if docStoreConn.indexExist(self.index_name, self.kb_id): | |||
| docStoreConn.deleteIdx(self.index_name, self.kb_id) | |||
| docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) | |||
| if settings.docStoreConn.indexExist(self.index_name, self.kb_id): | |||
| settings.docStoreConn.deleteIdx(self.index_name, self.kb_id) | |||
| settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) | |||
| self.initialized_index = True | |||
| def ms_marco_index(self, file_path, index_name): | |||
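Each `*_index` loader in `Benchmark` calls `init_index` once per embedding batch and relies on it being idempotent: a stale index is dropped, a fresh one is created with the observed vector size, and the `initialized_index` latch turns every later call into a no-op. A minimal sketch of that latch, with a plain set standing in for the doc store:

```python
indexes, initialized = set(), False

def init_index(name, kb_id, vector_size):
    global initialized
    if initialized:                     # latch: only the first call does work
        return
    if (name, kb_id) in indexes:        # stands in for docStoreConn.indexExist
        indexes.discard((name, kb_id))  # ... deleteIdx: drop the stale index
    indexes.add((name, kb_id))          # ... createIdx sized to vector_size
    initialized = True

init_index("ragflow_idx", "kb0", 1024)  # rebuilds the index
init_index("ragflow_idx", "kb0", 1024)  # no-op thereafter
```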
| @@ -118,13 +118,13 @@ class Benchmark: | |||
| docs_count += len(docs) | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs, self.index_name, self.kb_id) | |||
| settings.docStoreConn.insert(docs, self.index_name, self.kb_id) | |||
| docs = [] | |||
| if docs: | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs, self.index_name, self.kb_id) | |||
| settings.docStoreConn.insert(docs, self.index_name, self.kb_id) | |||
| return qrels, texts | |||
| def trivia_qa_index(self, file_path, index_name): | |||
| @@ -159,12 +159,12 @@ class Benchmark: | |||
| docs_count += len(docs) | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs,self.index_name) | |||
| settings.docStoreConn.insert(docs,self.index_name) | |||
| docs = [] | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs, self.index_name) | |||
| settings.docStoreConn.insert(docs, self.index_name) | |||
| return qrels, texts | |||
| def miracl_index(self, file_path, corpus_path, index_name): | |||
| @@ -214,12 +214,12 @@ class Benchmark: | |||
| docs_count += len(docs) | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs, self.index_name) | |||
| settings.docStoreConn.insert(docs, self.index_name) | |||
| docs = [] | |||
| docs, vector_size = self.embedding(docs) | |||
| self.init_index(vector_size) | |||
| docStoreConn.insert(docs, self.index_name) | |||
| settings.docStoreConn.insert(docs, self.index_name) | |||
| return qrels, texts | |||
| def save_results(self, qrels, run, texts, dataset, file_path): | |||
| @@ -28,7 +28,7 @@ from openai import OpenAI | |||
| import numpy as np | |||
| import asyncio | |||
| from api.settings import LIGHTEN | |||
| from api import settings | |||
| from api.utils.file_utils import get_home_cache_dir | |||
| from rag.utils import num_tokens_from_string, truncate | |||
| import google.generativeai as genai | |||
| @@ -60,7 +60,7 @@ class DefaultEmbedding(Base): | |||
| ^_- | |||
| """ | |||
| if not LIGHTEN and not DefaultEmbedding._model: | |||
| if not settings.LIGHTEN and not DefaultEmbedding._model: | |||
| with DefaultEmbedding._model_lock: | |||
| from FlagEmbedding import FlagModel | |||
| import torch | |||
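`DefaultEmbedding` and the variants below all gate model loading the same way: skip the heavyweight import entirely under `LIGHTEN`, otherwise load lazily behind a class-level lock; the Youdao reranker hunk further down shows the re-check inside the lock explicitly. Distilled sketch of that double-checked pattern:

```python
import threading

LIGHTEN = 0
_model, _model_lock = None, threading.Lock()

def get_model():
    global _model
    if not LIGHTEN and not _model:
        with _model_lock:
            if not _model:         # re-check after acquiring the lock
                _model = object()  # stand-in for the FlagModel(...) load
    return _model

assert get_model() is get_model()  # concurrent first calls load exactly once
```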
| @@ -248,7 +248,7 @@ class FastEmbed(Base): | |||
| threads: Optional[int] = None, | |||
| **kwargs, | |||
| ): | |||
| if not LIGHTEN and not FastEmbed._model: | |||
| if not settings.LIGHTEN and not FastEmbed._model: | |||
| from fastembed import TextEmbedding | |||
| self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | |||
| @@ -294,7 +294,7 @@ class YoudaoEmbed(Base): | |||
| _client = None | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | |||
| if not LIGHTEN and not YoudaoEmbed._client: | |||
| if not settings.LIGHTEN and not YoudaoEmbed._client: | |||
| from BCEmbedding import EmbeddingModel as qanthing | |||
| try: | |||
| logging.info("LOADING BCE...") | |||
| @@ -23,7 +23,7 @@ import os | |||
| from abc import ABC | |||
| import numpy as np | |||
| from api.settings import LIGHTEN | |||
| from api import settings | |||
| from api.utils.file_utils import get_home_cache_dir | |||
| from rag.utils import num_tokens_from_string, truncate | |||
| import json | |||
| @@ -57,7 +57,7 @@ class DefaultRerank(Base): | |||
| ^_- | |||
| """ | |||
| if not LIGHTEN and not DefaultRerank._model: | |||
| if not settings.LIGHTEN and not DefaultRerank._model: | |||
| import torch | |||
| from FlagEmbedding import FlagReranker | |||
| with DefaultRerank._model_lock: | |||
| @@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank): | |||
| _model_lock = threading.Lock() | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): | |||
| if not LIGHTEN and not YoudaoRerank._model: | |||
| if not settings.LIGHTEN and not YoudaoRerank._model: | |||
| from BCEmbedding import RerankerModel | |||
| with YoudaoRerank._model_lock: | |||
| if not YoudaoRerank._model: | |||
| @@ -16,6 +16,7 @@ | |||
| import logging | |||
| import sys | |||
| from api.utils.log_utils import initRootLogger | |||
| CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] | |||
| initRootLogger(f"task_executor_{CONSUMER_NO}") | |||
| for module in ["pdfminer"]: | |||
| @@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.task_service import TaskService | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.settings import retrievaler, docStoreConn | |||
| from api import settings | |||
| from api.db.db_models import close_connection | |||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email | |||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ | |||
| knowledge_graph, email | |||
| from rag.nlp import search, rag_tokenizer | |||
| from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor | |||
| from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME | |||
| @@ -88,6 +90,7 @@ PENDING_TASKS = 0 | |||
| HEAD_CREATED_AT = "" | |||
| HEAD_DETAIL = "" | |||
| def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): | |||
| global PAYLOAD | |||
| if prog is not None and prog < 0: | |||
| @@ -171,7 +174,8 @@ def build(row): | |||
| "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) | |||
| except TimeoutError: | |||
| callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") | |||
| logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) | |||
| logging.exception( | |||
| "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) | |||
| return | |||
| except Exception as e: | |||
| if re.search("(No such file|not found)", str(e)): | |||
| @@ -188,7 +192,7 @@ def build(row): | |||
| logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"])) | |||
| except Exception as e: | |||
| callback(-1, "Internal server error while chunking: %s" % | |||
| str(e).replace("'", "")) | |||
| str(e).replace("'", "")) | |||
| logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"])) | |||
| return | |||
| @@ -226,7 +230,8 @@ def build(row): | |||
| STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) | |||
| el += timer() - st | |||
| except Exception: | |||
| logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) | |||
| logging.exception( | |||
| "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) | |||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) | |||
| del d["image"] | |||
| @@ -241,7 +246,7 @@ def build(row): | |||
| d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], | |||
| row["parser_config"]["auto_keywords"]).split(",") | |||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) | |||
| callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st)) | |||
| callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st)) | |||
| if row["parser_config"].get("auto_questions", 0): | |||
| st = timer() | |||
| @@ -255,14 +260,14 @@ def build(row): | |||
| d["content_ltks"] += " " + qst | |||
| if "content_sm_ltks" in d: | |||
| d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst) | |||
| callback(msg="Question generation completed in {:.2f}s".format(timer()-st)) | |||
| callback(msg="Question generation completed in {:.2f}s".format(timer() - st)) | |||
| return docs | |||
| def init_kb(row, vector_size: int): | |||
| idxnm = search.index_name(row["tenant_id"]) | |||
| return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) | |||
| return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) | |||
| def embedding(docs, mdl, parser_config=None, callback=None): | |||
| @@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): | |||
| vector_size = len(vts[0]) | |||
| vctr_nm = "q_%d_vec" % vector_size | |||
| chunks = [] | |||
| for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]): | |||
| for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | |||
| fields=["content_with_weight", vctr_nm]): | |||
| chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) | |||
| raptor = Raptor( | |||
| @@ -384,7 +390,8 @@ def main(): | |||
| # TODO: exception handler | |||
| ## set_progress(r["did"], -1, "ERROR: ") | |||
| callback( | |||
| msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st) | |||
| msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), | |||
| timer() - st) | |||
| ) | |||
| st = timer() | |||
| try: | |||
| @@ -403,18 +410,18 @@ def main(): | |||
| es_r = "" | |||
| es_bulk_size = 4 | |||
| for b in range(0, len(cks), es_bulk_size): | |||
| es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| if b % 128 == 0: | |||
| callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") | |||
| logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) | |||
| if es_r: | |||
| callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") | |||
| docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| logging.error('Insert chunk error: ' + str(es_r)) | |||
| else: | |||
| if TaskService.do_cancel(r["id"]): | |||
| docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) | |||
| continue | |||
| callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) | |||
| callback(1., "Done!") | |||
| @@ -435,7 +442,7 @@ def report_status(): | |||
| if PENDING_TASKS > 0: | |||
| head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME) | |||
| if head_info is not None: | |||
| seconds = int(head_info[0].split("-")[0])/1000 | |||
| seconds = int(head_info[0].split("-")[0]) / 1000 | |||
| HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat() | |||
| HEAD_DETAIL = head_info[1] | |||
| @@ -452,7 +459,7 @@ def report_status(): | |||
| REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) | |||
| logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") | |||
| expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30) | |||
| expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30) | |||
| if expired > 0: | |||
| REDIS_CONN.zpopmin(CONSUMER_NAME, expired) | |||
| except Exception: | |||