瀏覽代碼

Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module init won't connect database any more.
2. Config in settings need to be used with settings.CONFIG_NAME

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
tags/v0.14.0
Jin Hai 11 月之前
父節點
當前提交
1e90a1bf36
沒有連結到貢獻者的電子郵件帳戶。

+ 13
- 9
agent/component/generate.py 查看文件

from api.db import LLMType from api.db import LLMType
from api.db.services.dialog_service import message_fit_in from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase




component_name = "Generate" component_name = "Generate"


def get_dependent_components(self): def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts return cpnts


def set_cite(self, retrieval_res, answer): def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True) retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
if "empty_response" in retrieval_res.columns: if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True) retrieval_res["empty_response"].fillna("", inplace=True)
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3)
answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3)
doc_ids = set([]) doc_ids = set([])
recall_docs = [] recall_docs = []
for i in idx: for i in idx:
else: else:
if cpn.component_name.lower() == "retrieval": if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out) retrieval_res.append(out)
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]}) self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})


if retrieval_res: if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True) retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])


for n, v in kwargs.items(): for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt) prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

+ 2
- 2
agent/component/retrieval.py 查看文件

from api.db import LLMType from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase




if self._param.rerank_id: if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)


kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
1, self._param.top_n, 1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl) aggs=False, rerank_mdl=rerank_mdl)

+ 4
- 6
api/apps/__init__.py 查看文件



from flask_session import Session from flask_session import Session
from flask_login import LoginManager from flask_login import LoginManager
from api.settings import SECRET_KEY
from api.settings import API_VERSION
from api import settings
from api.utils.api_utils import server_error_response from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer


app.json_encoder = CustomJSONEncoder app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response) app.errorhandler(Exception)(server_error_response)



## convince for dev and debug ## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True # app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False app.config["SESSION_PERMANENT"] = False


page_name = page_path.stem.rstrip("_app") page_name = page_path.stem.rstrip("_app")
module_name = ".".join( module_name = ".".join(
page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
) )


spec = spec_from_file_location(module_name, page_path) spec = spec_from_file_location(module_name, page_path)
spec.loader.exec_module(page) spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name) page_name = getattr(page, "page_name", page_name)
url_prefix = ( url_prefix = (
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
) )


app.register_blueprint(page.manager, url_prefix=url_prefix) app.register_blueprint(page.manager, url_prefix=url_prefix)


@login_manager.request_loader @login_manager.request_loader
def load_user(web_request): def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization") authorization = web_request.headers.get("Authorization")
if authorization: if authorization:
try: try:

+ 33
- 31
api/apps/api_app.py 查看文件

from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils import get_uuid, current_timestamp, datetime_format from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token generate_confirmation_token
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
try: try:
if objs[0].source == "agent": if objs[0].source == "agent":
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"]) e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
API4ConversationService.append_message(conv.id, conv.to_dict()) API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result) rename_field(result)
return get_json_result(data=result) return get_json_result(data=result)
#******************For dialog******************
# ******************For dialog******************
conv.message.append(msg[-1]) conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id) e, dia = DialogService.get_by_id(conv.dialog_id)
if not e: if not e:
resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp return resp
answer = None answer = None
for ans in chat(dia, msg, **req): for ans in chat(dia, msg, **req):
answer = ans answer = ans
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
try: try:
e, conv = API4ConversationService.get_by_id(conversation_id) e, conv = API4ConversationService.get_by_id(conversation_id)
if not e: if not e:
conv = conv.to_dict() conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token: if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Token is not valid for this conversation_id!"', return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']: for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0: if referenct_i is None or len(referenct_i) == 0:
continue continue
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


kb_name = request.form.get("kb_name").strip() kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id


if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)


file = request.files['file'] file = request.files['file']
if file.filename == '': if file.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)


root_folder = FileService.get_root_folder(tenant_id) root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)


file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)


doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids) return get_json_result(data=doc_ids)
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


req = request.json req = request.json


) )
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)


res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [ res = [
{ {
"content": res_item["content_with_weight"], "content": res_item["content_with_weight"],
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


req = request.json req = request.json
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)



@manager.route('/document/infos', methods=['POST']) @manager.route('/document/infos', methods=['POST'])
@validate_request("doc_ids") @validate_request("doc_ids")
def docinfos(): def docinfos():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
doc_ids = req["doc_ids"] doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids) docs = DocumentService.get_by_ids(doc_ids)
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
req = request.json req = request.json
errors += str(e) errors += str(e)


if errors: if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)


return get_json_result(data=True) return get_json_result(data=True)


objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


e, conv = API4ConversationService.get_by_id(req["conversation_id"]) e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)


req = request.json req = request.json
kb_ids = req.get("kb_id",[])
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
question = req.get("question") question = req.get("question")
page = int(req.get("page", 1)) page = int(req.get("page", 1))
embd_nms = list(set([kb.embd_id for kb in kbs])) embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1: if len(embd_nms) != 1:
return get_json_result( return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=settings.RetCode.AUTHENTICATION_ERROR)


embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
rerank_mdl = None rerank_mdl = None
if req.get("rerank_id"): if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance( rerank_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]: for c in ranks["chunks"]:
if "vector" in c: if "vector" in c:
del c["vector"] del c["vector"]
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!', return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)

+ 12
- 10
api/apps/canvas_app.py 查看文件

from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
@login_required @login_required
def canvas_list(): def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \ return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
UserCanvasService.query(user_id=current_user.id)],
key=lambda x: x["update_time"] * -1)
) )




@login_required @login_required
def rm(): def rm():
for i in request.json["canvas_ids"]: for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id,id=i):
if not UserCanvasService.query(user_id=current_user.id, id=i):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i) UserCanvasService.delete_by_id(i)
return get_json_result(data=True) return get_json_result(data=True)


if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req) UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req) return get_json_result(data=req)


if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)


if not isinstance(cvs.dsl, str): if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if "message" in req: if "message" in req:
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id}) canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
if len([m for m in canvas.messages if m["role"] == "user"]) > 1: if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
#ten = TenantService.get_info_by(current_user.id)[0]
#req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
# ten = TenantService.get_info_by(current_user.id)[0]
# req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
pass pass
canvas.add_user_input(req["message"]) canvas.add_user_input(req["message"])
answer = canvas.run(stream=stream) answer = canvas.run(stream=stream)
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."


if stream: if stream:
assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
assert isinstance(answer,
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."


def sse(): def sse():
nonlocal answer, cvs nonlocal answer, cvs
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)


canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset() canvas.reset()

+ 17
- 17
api/apps/chunk_app.py 查看文件

from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re



@manager.route('/list', methods=['POST']) @manager.route('/list', methods=['POST'])
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
} }
if "available_int" in req: if "available_int" in req:
query["available_int"] = int(req["available_int"]) query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids: for id in sres.ids:
d = { d = {
"positions": json.loads(sres.field[id].get("position_list", "[]")), "positions": json.loads(sres.field[id].get("position_list", "[]")),
} }
assert isinstance(d["positions"], list) assert isinstance(d["positions"], list)
assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
res["chunks"].append(d) res["chunks"].append(d)
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found!', return get_json_result(data=False, message='No chunk found!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)




tenant_id = tenants[0].tenant_id tenant_id = tenants[0].tenant_id


kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None: if chunk is None:
return server_error_response("Chunk not found") return server_error_response("Chunk not found")
k = [] k = []
except Exception as e: except Exception as e:
if str(e).find("NotFoundError") >= 0: if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message='Chunk not found!', return get_json_result(data=False, message='Chunk not found!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)




v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(doc.tenant_id), doc.kb_id):
if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(doc.tenant_id), doc.kb_id):
return get_data_error_result(message="Index updating failure") return get_data_error_result(message="Index updating failure")
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
return get_data_error_result(message="Index updating failure") return get_data_error_result(message="Index updating failure")
deleted_chunk_ids = req["chunk_ids"] deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids) chunk_number = len(deleted_chunk_ids)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)


DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0) doc.id, doc.kb_id, c, 1, 0)
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)


e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e: if not e:
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)


retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size, ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!', return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)




tenant_id = DocumentService.get_tenant_id(doc_id) tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = { req = {
"doc_ids":[doc_id],
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"] "knowledge_graph_kwd": ["graph", "mind_map"]
} }
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}} obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]: for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"] ty = sres.field[id]["knowledge_graph_kwd"]
obj[ty] = content_json obj[ty] = content_json


return get_json_result(data=obj) return get_json_result(data=obj)


+ 7
- 6
api/apps/conversation_app.py 查看文件

from api.db.services.dialog_service import DialogService, ConversationService, chat, ask from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.', data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
conv = conv.to_dict() conv = conv.to_dict()
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.', data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid) ConversationService.delete_by_id(cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result( return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.', data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
convs = ConversationService.query( convs = ConversationService.query(
dialog_id=dialog_id, dialog_id=dialog_id,
order_by=ConversationService.model.create_time, order_by=ConversationService.model.create_time,
def ask_about(): def ask_about():
req = request.json req = request.json
uid = current_user.id uid = current_user.id

def stream(): def stream():
nonlocal req, uid nonlocal req, uid
try: try:
embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
mindmap = MindMapExtractor(chat_mdl) mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
if "error" in mind_map: if "error" in mind_map:

+ 2
- 2
api/apps/dialog_app.py 查看文件

from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.', data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value}) dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list) DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True) return get_json_result(data=True)

+ 31
- 30
api/apps/document_app.py 查看文件

from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.settings import RetCode, docStoreConn
from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)


file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)


e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
err, _ = FileService.upload_document(kb, file_objs, current_user.id) err, _ = FileService.upload_document(kb, file_objs, current_user.id)
if err: if err:
return get_json_result( return get_json_result(
data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True) return get_json_result(data=True)




kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
name = request.form.get("name") name = request.form.get("name")
url = request.form.get("url") url = request.form.get("url")
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
raise LookupError("Can't find this knowledgebase!") raise LookupError("Can't find this knowledgebase!")
kb_id = req["kb_id"] kb_id = req["kb_id"]
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)


try: try:
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
kb_id = request.args.get("kb_id") kb_id = request.args.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query( if KnowledgebaseService.query(
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "") keywords = request.args.get("keywords", "")


page_number = int(request.args.get("page", 1)) page_number = int(request.args.get("page", 1))
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
docs = DocumentService.get_by_ids(doc_ids) docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts())) return get_json_result(data=list(docs.dicts()))




@manager.route('/thumbnails', methods=['GET']) @manager.route('/thumbnails', methods=['GET'])
#@login_required
# @login_required
def thumbnails(): def thumbnails():
doc_ids = request.args.get("doc_ids").split(",") doc_ids = request.args.get("doc_ids").split(",")
if not doc_ids: if not doc_ids:
return get_json_result( return get_json_result(
data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)


try: try:
docs = DocumentService.get_thumbnails(doc_ids) docs = DocumentService.get_thumbnails(doc_ids)
return get_json_result( return get_json_result(
data=False, data=False,
message='"Status" must be either 0 or 1!', message='"Status" must be either 0 or 1!',
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)


if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)


try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
message="Database error (Document update)!") message="Database error (Document update)!")


status = int(req["status"]) status = int(req["status"])
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
search.index_name(kb.tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )


root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
errors += str(e) errors += str(e)


if errors: if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)


return get_json_result(data=True) return get_json_result(data=True)


return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
for id in req["doc_ids"]: for id in req["doc_ids"]:
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)


if str(req["run"]) == TaskStatus.RUNNING.value: if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
return get_json_result( return get_json_result(
data=False, data=False,
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]: if d.name == req["name"]:
return get_data_error_result( return get_data_error_result(
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: if not tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)


return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
def upload_and_parse(): def upload_and_parse():
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)


file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)


doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)


if url: if url:
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
download_path = os.path.join(get_project_base_directory(), "logs/downloads") download_path = os.path.join(get_project_base_directory(), "logs/downloads")
os.makedirs(download_path, exist_ok=True) os.makedirs(download_path, exist_ok=True)
from selenium.webdriver import Chrome, ChromeOptions from selenium.webdriver import Chrome, ChromeOptions


if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)


file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
txt = FileService.parse_docs(file_objs, current_user.id) txt = FileService.parse_docs(file_objs, current_user.id)

+ 2
- 2
api/apps/file2document_app.py 查看文件

from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType from api.db import FileType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result




file_ids = req["file_ids"] file_ids = req["file_ids"]
if not file_ids: if not file_ids:
return get_json_result( return get_json_result(
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
for file_id in file_ids: for file_id in file_ids:
informs = File2DocumentService.get_by_file_id(file_id) informs = File2DocumentService.get_by_file_id(file_id)

+ 5
- 5
api/apps/file_app.py 查看文件

from api.db import FileType, FileSource from api.db import FileType, FileSource
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL


if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')


for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
file_res = [] file_res = []
try: try:
for file_obj in file_objs: for file_obj in file_objs:
try: try:
if not FileService.is_parent_folder_exist(pf_id): if not FileService.is_parent_folder_exist(pf_id):
return get_json_result( return get_json_result(
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
if FileService.query(name=req["name"], parent_id=pf_id): if FileService.query(name=req["name"], parent_id=pf_id):
return get_data_error_result( return get_data_error_result(
message="Duplicated folder name in the same folder.") message="Duplicated folder name in the same folder.")
return get_json_result( return get_json_result(
data=False, data=False,
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req["name"], pf_id=file.parent_id): for file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]: if file.name == req["name"]:
return get_data_error_result( return get_data_error_result(

+ 7
- 8
api/apps/kb_app.py 查看文件

from api.db import StatusEnum, FileSource from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.settings import docStoreConn
from api import settings
from rag.nlp import search from rag.nlp import search




return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
if not KnowledgebaseService.query( if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]): created_by=current_user.id, id=req["kb_id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)


e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e: if not e:
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
kb = KnowledgebaseService.get_detail(kb_id) kb = KnowledgebaseService.get_detail(kb_id)
if not kb: if not kb:
return get_data_error_result( return get_data_error_result(
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
kbs = KnowledgebaseService.query( kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]) created_by=current_user.id, id=req["kb_id"])
if not kbs: if not kbs:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)


for doc in DocumentService.query(kb_id=req["kb_id"]): for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id): if not DocumentService.remove_document(doc, kbs[0].tenant_id):
message="Database error (Knowledgebase removal)!") message="Database error (Knowledgebase removal)!")
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

+ 2
- 2
api/apps/llm_app.py 查看文件

from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.settings import LIGHTEN
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db import StatusEnum, LLMType from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
@login_required @login_required
def list_app(): def list_app():
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type") model_type = request.args.get("model_type")
try: try:
objs = TenantLLMService.query(tenant_id=current_user.id) objs = TenantLLMService.query(tenant_id=current_user.id)

+ 3
- 3
api/apps/sdk/chat.py 查看文件

# limitations under the License. # limitations under the License.
# #
from flask import request from flask import request
from api.settings import RetCode
from api import settings
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
kbs = KnowledgebaseService.get_by_ids(ids) kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs])) embd_count = list(set([kb.embd_id for kb in kbs]))
if len(embd_count) != 1: if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
# llm # llm
llm = req.get("llm") llm = req.get("llm")
if len(embd_count) != 1 : if len(embd_count) != 1 :
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
llm = req.get("llm") llm = req.get("llm")
if llm: if llm:

+ 3
- 3
api/apps/sdk/dataset.py 查看文件

from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMService from api.db.services.llm_service import TenantLLMService, LLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode
from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
get_result, get_result,
File2DocumentService.delete_by_document_id(doc.id) File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id): if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(message="Delete dataset error.(Database error)") return get_error_data_result(message="Delete dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)




@manager.route("/datasets/<dataset_id>", methods=["PUT"]) @manager.route("/datasets/<dataset_id>", methods=["PUT"])
) )
if not KnowledgebaseService.update_by_id(kb.id, req): if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)") return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)




@manager.route("/datasets", methods=["GET"]) @manager.route("/datasets", methods=["GET"])

+ 6
- 6
api/apps/sdk/dify_retrieval.py 查看文件

from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler, kg_retrievaler, RetCode
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required from api.utils.api_utils import validate_request, build_error_result, apikey_required






e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)


if kb.tenant_id != tenant_id: if kb.tenant_id != tenant_id:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)


embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)


retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval( ranks = retr.retrieval(
question, question,
embd_mdl, embd_mdl,
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return build_error_result( return build_error_result(
message='No chunk found! Check the chunk status please!', message='No chunk found! Check the chunk status please!',
code=RetCode.NOT_FOUND
code=settings.RetCode.NOT_FOUND
) )
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)

+ 25
- 25
api/apps/sdk/doc.py 查看文件

from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.settings import kg_retrievaler
from api import settings
import hashlib import hashlib
import re import re
from api.utils.api_utils import token_required from api.utils.api_utils import token_required
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils.api_utils import construct_json_result, get_parser_config from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search from rag.nlp import search
from rag.utils import rmSpace from rag.utils import rmSpace
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
import os import os


""" """
if "file" not in request.files: if "file" not in request.files:
return get_error_data_result( return get_error_data_result(
message="No file part!", code=RetCode.ARGUMENT_ERROR
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
) )
file_objs = request.files.getlist("file") file_objs = request.files.getlist("file")
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == "": if file_obj.filename == "":
return get_result( return get_result(
message="No file selected!", code=RetCode.ARGUMENT_ERROR
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
) )
# total size # total size
total_size = 0 total_size = 0
if total_size > MAX_TOTAL_FILE_SIZE: if total_size > MAX_TOTAL_FILE_SIZE:
return get_result( return get_result(
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
) )
e, kb = KnowledgebaseService.get_by_id(dataset_id) e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e: if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!") raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id) err, files = FileService.upload_document(kb, file_objs, tenant_id)
if err: if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
# rename key's name # rename key's name
renamed_doc_list = [] renamed_doc_list = []
for file in files: for file in files:


if "name" in req and req["name"] != doc.name: if "name" in req and req["name"] != doc.name:
if ( if (
pathlib.Path(req["name"].lower()).suffix
!= pathlib.Path(doc.name.lower()).suffix
pathlib.Path(req["name"].lower()).suffix
!= pathlib.Path(doc.name.lower()).suffix
): ):
return get_result( return get_result(
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
) )
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]: if d.name == req["name"]:
) )
if not e: if not e:
return get_error_data_result(message="Document not found!") return get_error_data_result(message="Document not found!")
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)


return get_result() return get_result()


file_stream = STORAGE_IMPL.get(doc_id, doc_location) file_stream = STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream: if not file_stream:
return construct_json_result( return construct_json_result(
message="This file is empty.", code=RetCode.DATA_ERROR
message="This file is empty.", code=settings.RetCode.DATA_ERROR
) )
file = BytesIO(file_stream) file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type # Use send_file with a proper filename and MIME type
errors += str(e) errors += str(e)


if errors: if errors:
return get_result(message=errors, code=RetCode.SERVER_ERROR)
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)


return get_result() return get_result()


info["chunk_num"] = 0 info["chunk_num"] = 0
info["token_num"] = 0 info["token_num"] = 0
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict() doc = doc.to_dict()
) )
info = {"run": "2", "progress": 0, "chunk_num": 0} info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()






res = {"total": 0, "chunks": [], "doc": renamed_doc} res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = [] origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
highlight=True)
res["total"] = sres.total res["total"] = sres.total
sign = 0 sign = 0
for id in sres.ids: for id in sres.ids:
v, c = embd_mdl.encode([doc.name, req["content"]]) v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)


DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys # rename keys
condition = {"doc_id": document_id} condition = {"doc_id": document_id}
if "chunk_ids" in req: if "chunk_ids" in req:
condition["id"] = req["chunk_ids"] condition["id"] = req["chunk_ids"]
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0: if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
schema: schema:
type: object type: object
""" """
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None: if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}") return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()




if len(embd_nms) != 1: if len(embd_nms) != 1:
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
) )
if "question" not in req: if "question" not in req:
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)


retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval( ranks = retr.retrieval(
question, question,
embd_mdl, embd_mdl,
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_result( return get_result(
message="No chunk found! Check the chunk status please!", message="No chunk found! Check the chunk status please!",
code=RetCode.DATA_ERROR,
code=settings.RetCode.DATA_ERROR,
) )
return server_error_response(e)
return server_error_response(e)

+ 4
- 5
api/apps/system_app.py 查看文件

from api.db.services.api_service import APITokenService from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.settings import DATABASE_TYPE
from api import settings
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
from api.utils.api_utils import ( from api.utils.api_utils import (
get_json_result, get_json_result,
generate_confirmation_token, generate_confirmation_token,
) )
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer from timeit import default_timer as timer


res = {} res = {}
st = timer() st = timer()
try: try:
res["doc_store"] = docStoreConn.health()
res["doc_store"] = settings.docStoreConn.health()
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e: except Exception as e:
res["doc_store"] = { res["doc_store"] = {
try: try:
KnowledgebaseService.get_by_id("x") KnowledgebaseService.get_by_id("x")
res["database"] = { res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "green", "status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0), "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
} }
except Exception as e: except Exception as e:
res["database"] = { res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "red", "status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0), "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e), "error": str(e),

+ 29
- 42
api/apps/user_app.py 查看文件

datetime_format, datetime_format,
) )
from api.db import UserTenantRole, FileType from api.db import UserTenantRole, FileType
from api.settings import (
RetCode,
GITHUB_OAUTH,
FEISHU_OAUTH,
CHAT_MDL,
EMBEDDING_MDL,
ASR_MDL,
IMAGE2TEXT_MDL,
PARSERS,
API_KEY,
LLM_FACTORY,
LLM_BASE_URL,
RERANK_MDL,
)
from api import settings
from api.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result, construct_response from api.utils.api_utils import get_json_result, construct_response
""" """
if not request.json: if not request.json:
return get_json_result( return get_json_result(
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
) )


email = request.json.get("email", "") email = request.json.get("email", "")
if not users: if not users:
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message=f"Email: {email} is not registered!", message=f"Email: {email} is not registered!",
) )


password = decrypt(password) password = decrypt(password)
except BaseException: except BaseException:
return get_json_result( return get_json_result(
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
) )


user = UserService.query_user(email, password) user = UserService.query_user(email, password)
else: else:
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Email and password do not match!", message="Email and password do not match!",
) )


import requests import requests


res = requests.post( res = requests.post(
GITHUB_OAUTH.get("url"),
settings.GITHUB_OAUTH.get("url"),
data={ data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"client_id": settings.GITHUB_OAUTH.get("client_id"),
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
"code": request.args.get("code"), "code": request.args.get("code"),
}, },
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
import requests import requests


app_access_token_res = requests.post( app_access_token_res = requests.post(
FEISHU_OAUTH.get("app_access_token_url"),
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
"app_id": FEISHU_OAUTH.get("app_id"),
"app_secret": FEISHU_OAUTH.get("app_secret"),
"app_id": settings.FEISHU_OAUTH.get("app_id"),
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
} }
), ),
headers={"Content-Type": "application/json; charset=utf-8"}, headers={"Content-Type": "application/json; charset=utf-8"},
return redirect("/?error=%s" % app_access_token_res) return redirect("/?error=%s" % app_access_token_res)


res = requests.post( res = requests.post(
FEISHU_OAUTH.get("user_access_token_url"),
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
"grant_type": FEISHU_OAUTH.get("grant_type"),
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
"code": request.args.get("code"), "code": request.args.get("code"),
} }
), ),
if request_data.get("password"): if request_data.get("password"):
new_password = request_data.get("new_password") new_password = request_data.get("new_password")
if not check_password_hash( if not check_password_hash(
current_user.password, decrypt(request_data["password"])
current_user.password, decrypt(request_data["password"])
): ):
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Password error!", message="Password error!",
) )


except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
return get_json_result( return get_json_result(
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
) )




tenant = { tenant = {
"id": user_id, "id": user_id,
"name": user["nickname"] + "‘s Kingdom", "name": user["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL,
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
} }
usr_tenant = { usr_tenant = {
"tenant_id": user_id, "tenant_id": user_id,
"location": "", "location": "",
} }
tenant_llm = [] tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append( tenant_llm.append(
{ {
"tenant_id": user_id, "tenant_id": user_id,
"llm_factory": LLM_FACTORY,
"llm_factory": settings.LLM_FACTORY,
"llm_name": llm.llm_name, "llm_name": llm.llm_name,
"model_type": llm.model_type, "model_type": llm.model_type,
"api_key": API_KEY,
"api_base": LLM_BASE_URL,
"api_key": settings.API_KEY,
"api_base": settings.LLM_BASE_URL,
} }
) )


return get_json_result( return get_json_result(
data=False, data=False,
message=f"Invalid email address: {email_address}!", message=f"Invalid email address: {email_address}!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
) )


# Check if the email address is already used # Check if the email address is already used
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Email: {email_address} has already registered!", message=f"Email: {email_address} has already registered!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
) )


# Construct user info data # Construct user info data
return get_json_result( return get_json_result(
data=False, data=False,
message=f"User registration failure, error: {str(e)}", message=f"User registration failure, error: {str(e)}",
code=RetCode.EXCEPTION_ERROR,
code=settings.RetCode.EXCEPTION_ERROR,
) )





+ 7
- 7
api/db/db_models.py 查看文件

) )
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType from api.db import SerializedType, ParserType
from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
from api import settings
from api import utils from api import utils


def singleton(cls, *args, **kw): def singleton(cls, *args, **kw):




class LongTextField(TextField): class LongTextField(TextField):
field_type = TextFieldType[DATABASE_TYPE.upper()].value
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value




class JSONField(LongTextField): class JSONField(LongTextField):
@singleton @singleton
class BaseDataBase: class BaseDataBase:
def __init__(self): def __init__(self):
database_config = DATABASE.copy()
database_config = settings.DATABASE.copy()
db_name = database_config.pop("name") db_name = database_config.pop("name")
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully') logging.info('init database on cluster mode successfully')


class PostgresDatabaseLock: class PostgresDatabaseLock:




DB = BaseDataBase().database_connection DB = BaseDataBase().database_connection
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value




def close_connection(): def close_connection():
return self.email return self.email


def get_id(self): def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
return jwt.dumps(str(self.access_token)) return jwt.dumps(str(self.access_token))


class Meta: class Meta:


def migrate_db(): def migrate_db():
with DB.transaction(): with DB.transaction():
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try: try:
migrate( migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",

+ 12
- 11
api/db/init_data.py 查看文件

from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
from api import settings
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory




tenant = { tenant = {
"id": user_info["id"], "id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom", "name": user_info["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL
} }
usr_tenant = { usr_tenant = {
"tenant_id": user_info["id"], "tenant_id": user_info["id"],
"role": UserTenantRole.OWNER "role": UserTenantRole.OWNER
} }
tenant_llm = [] tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append( tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY, "api_base": LLM_BASE_URL})
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})


if not UserService.save(**user_info): if not UserService.save(**user_info):
logging.error("can't init admin.") logging.error("can't init admin.")


chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[ msg = chat_mdl.chat(system="", history=[
{"role": "user", "content": "Hello!"}], gen_conf={})
{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0: if msg.find("ERROR: ") == 0:
logging.error( logging.error(
"'{}' dosen't work. {}".format( "'{}' dosen't work. {}".format(
start_time = time.time() start_time = time.time()


init_llm_factory() init_llm_factory()
#if not UserService.get_all().count():
# if not UserService.get_all().count():
# init_superuser() # init_superuser()


add_graph_templates() add_graph_templates()

+ 4
- 4
api/db/services/dialog_service.py 查看文件

from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import retrievaler, kg_retrievaler
from api import settings
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder from rag.utils import rmSpace, num_tokens_from_string, encoder
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}


is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler


questions = [m["content"] for m in messages if m["role"] == "user"][-3:] questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None


logging.debug(f"{question} get SQL(refined): {sql}") logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1 tried_times += 1
return retrievaler.sql_retrieval(sql, format="json"), sql
return settings.retrievaler.sql_retrieval(sql, format="json"), sql


tbl, sql = get_table() tbl, sql = get_table()
if tbl is None: if tbl is None:
embd_nms = list(set([kb.embd_id for kb in kbs])) embd_nms = list(set([kb.embd_id for kb in kbs]))


is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler


embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)

+ 5
- 5
api/db/services/document_service.py 查看文件

from peewee import fn from peewee import fn


from api.db.db_utils import bulk_insert_into_db from api.db.db_utils import bulk_insert_into_db
from api.settings import docStoreConn
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME from rag.settings import SVR_QUEUE_NAME
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def remove_document(cls, doc, tenant_id): def remove_document(cls, doc, tenant_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id) cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id) return cls.delete_by_id(doc.id)


d["q_%d_vec" % len(v)] = v d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
if try_create_idx: if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)


DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

+ 5
- 6
api/ragflow_server.py 查看文件

from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor


from werkzeug.serving import run_simple from werkzeug.serving import run_simple
from api import settings
from api.apps import app from api.apps import app
from api.db.runtime_config import RuntimeConfig from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import (
HOST, HTTP_PORT
)
from api import utils from api import utils


from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
f'project base: {utils.file_utils.get_project_base_directory()}' f'project base: {utils.file_utils.get_project_base_directory()}'
) )
show_configs() show_configs()
settings.init_settings()


# init db # init db
init_web_db() init_web_db()
logging.info("run on debug mode") logging.info("run on debug mode")


RuntimeConfig.init_env() RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)


thread = ThreadPoolExecutor(max_workers=1) thread = ThreadPoolExecutor(max_workers=1)
thread.submit(update_progress) thread.submit(update_progress)
try: try:
logging.info("RAGFlow HTTP server start...") logging.info("RAGFlow HTTP server start...")
run_simple( run_simple(
hostname=HOST,
port=HTTP_PORT,
hostname=settings.HOST_IP,
port=settings.HOST_PORT,
application=app, application=app,
threaded=True, threaded=True,
use_reloader=RuntimeConfig.DEBUG, use_reloader=RuntimeConfig.DEBUG,

+ 144
- 101
api/settings.py 查看文件



REQUEST_WAIT_SEC = 2 REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300 REQUEST_MAX_WAIT_SEC = 300

LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")

CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
if not LIGHTEN:
default_llm = {
"Tongyi-Qianwen": {
"chat_model": "qwen-plus",
"embedding_model": "text-embedding-v2",
"image2text_model": "qwen-vl-max",
"asr_model": "paraformer-realtime-8k-v1",
},
"OpenAI": {
"chat_model": "gpt-3.5-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"Azure-OpenAI": {
"chat_model": "gpt-35-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"ZHIPU-AI": {
"chat_model": "glm-3-turbo",
"embedding_model": "embedding-2",
"image2text_model": "glm-4v",
"asr_model": "",
},
"Ollama": {
"chat_model": "qwen-14B-chat",
"embedding_model": "flag-embedding",
"image2text_model": "",
"asr_model": "",
},
"Moonshot": {
"chat_model": "moonshot-v1-8k",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"VolcEngine": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"BAAI": {
"chat_model": "",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
"rerank_model": "BAAI/bge-reranker-v2-m3",
}
}

if LLM_FACTORY:
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"

API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")

HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today()))
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None


DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)


# authentication # authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
AUTHENTICATION_CONF = None


# client # client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
"client", {}).get(
"switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")

DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
if DOC_ENGINE == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None

DOC_ENGINE = None
docStoreConn = None

retrievaler = None
kg_retrievaler = None


def init_settings():
    """Populate this module's configuration globals from the service config.

    Deferred initialization: importing ``api.settings`` does no config or
    connection work; this function is called once, after the module-init
    phase.  It reads the ``user_default_llm``, service, ``authentication``
    and ``oauth`` sections via ``get_base_config()``, then instantiates the
    document-store connection and the retrievers built on top of it.

    Raises:
        Exception: if the ``DOC_ENGINE`` environment variable names an
            unsupported document engine.
    """
    global LLM, LLM_FACTORY, LLM_BASE_URL
    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
    global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
    global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
    global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler

    LLM = get_base_config("user_default_llm", {})
    LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
    LLM_BASE_URL = LLM.get("base_url")

    if not LIGHTEN:
        # Built-in default model names for each supported LLM factory.
        default_llm = {
            "Tongyi-Qianwen": {
                "chat_model": "qwen-plus",
                "embedding_model": "text-embedding-v2",
                "image2text_model": "qwen-vl-max",
                "asr_model": "paraformer-realtime-8k-v1",
            },
            "OpenAI": {
                "chat_model": "gpt-3.5-turbo",
                "embedding_model": "text-embedding-ada-002",
                "image2text_model": "gpt-4-vision-preview",
                "asr_model": "whisper-1",
            },
            "Azure-OpenAI": {
                "chat_model": "gpt-35-turbo",
                "embedding_model": "text-embedding-ada-002",
                "image2text_model": "gpt-4-vision-preview",
                "asr_model": "whisper-1",
            },
            "ZHIPU-AI": {
                "chat_model": "glm-3-turbo",
                "embedding_model": "embedding-2",
                "image2text_model": "glm-4v",
                "asr_model": "",
            },
            "Ollama": {
                "chat_model": "qwen-14B-chat",
                "embedding_model": "flag-embedding",
                "image2text_model": "",
                "asr_model": "",
            },
            "Moonshot": {
                "chat_model": "moonshot-v1-8k",
                "embedding_model": "",
                "image2text_model": "",
                "asr_model": "",
            },
            "DeepSeek": {
                "chat_model": "deepseek-chat",
                "embedding_model": "",
                "image2text_model": "",
                "asr_model": "",
            },
            "VolcEngine": {
                "chat_model": "",
                "embedding_model": "",
                "image2text_model": "",
                "asr_model": "",
            },
            "BAAI": {
                "chat_model": "",
                "embedding_model": "BAAI/bge-large-zh-v1.5",
                "image2text_model": "",
                "asr_model": "",
                "rerank_model": "BAAI/bge-reranker-v2-m3",
            }
        }

        if LLM_FACTORY:
            # Model ids carry a "<name>@<factory>" suffix so the bundle
            # layer can resolve the provider.
            factory_defaults = default_llm[LLM_FACTORY]
            CHAT_MDL = factory_defaults["chat_model"] + f"@{LLM_FACTORY}"
            ASR_MDL = factory_defaults["asr_model"] + f"@{LLM_FACTORY}"
            IMAGE2TEXT_MDL = factory_defaults["image2text_model"] + f"@{LLM_FACTORY}"
        # Embedding and rerank always fall back to the bundled BAAI models.
        EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
        RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"

    API_KEY = LLM.get("api_key", "")
    PARSERS = LLM.get(
        "parsers",
        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")

    HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
    HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
    # Note: the secret key default changes daily, which invalidates sessions
    # across restarts unless it is set explicitly in the config.
    SECRET_KEY = get_base_config(
        RAG_FLOW_SERVICE_NAME,
        {}).get("secret_key", str(date.today()))

    # Authentication / client settings.
    AUTHENTICATION_CONF = get_base_config("authentication", {})
    client_conf = AUTHENTICATION_CONF.get("client", {})
    CLIENT_AUTHENTICATION = client_conf.get("switch", False)
    HTTP_APP_KEY = client_conf.get("http_app_key")
    GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
    FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")

    # Document-store backend and the retrievers layered on top of it.
    DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
    if DOC_ENGINE == "elasticsearch":
        docStoreConn = rag.utils.es_conn.ESConnection()
    elif DOC_ENGINE == "infinity":
        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
    else:
        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

    retrievaler = search.Dealer(docStoreConn)
    kg_retrievaler = kg_search.KGSearch(docStoreConn)

def get_host_ip():
    """Return the configured service host IP (set by init_settings()).

    A ``global`` declaration is unnecessary for read-only access to a
    module-level name, so none is used here.
    """
    return HOST_IP


def get_host_port():
    """Return the configured service HTTP port (set by init_settings()).

    A ``global`` declaration is unnecessary for read-only access to a
    module-level name, so none is used here.
    """
    return HOST_PORT




class CustomEnum(Enum): class CustomEnum(Enum):

+ 24
- 26
api/utils/api_utils.py 查看文件

from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES


from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
from api.settings import RetCode
from api import settings

from api import settings
from api.utils import CustomJSONEncoder, get_uuid from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps from api.utils import json_dumps


{}).items()} {}).items()}
prepped = requests.Request(**kwargs).prepare() prepped = requests.Request(**kwargs).prepare()


if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
timestamp = str(round(time() * 1000)) timestamp = str(round(time() * 1000))
nonce = str(uuid1()) nonce = str(uuid1())
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
timestamp.encode('ascii'), timestamp.encode('ascii'),
nonce.encode('ascii'), nonce.encode('ascii'),
HTTP_APP_KEY.encode('ascii'),
settings.HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'), prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'', prepped.body if kwargs.get('json') else b'',
urlencode( urlencode(
prepped.headers.update({ prepped.headers.update({
'TIMESTAMP': timestamp, 'TIMESTAMP': timestamp,
'NONCE': nonce, 'NONCE': nonce,
'APP-KEY': HTTP_APP_KEY,
'APP-KEY': settings.HTTP_APP_KEY,
'SIGNATURE': signature, 'SIGNATURE': signature,
}) })


def get_exponential_backoff_interval(retries, full_jitter=False): def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time.""" """Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0 # Will be zero if factor equals 0
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to # Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter: if full_jitter:
return max(0, countdown) return max(0, countdown)




def get_data_error_result(code=RetCode.DATA_ERROR,
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
message='Sorry! Data missing!'): message='Sorry! Data missing!'):
import re import re
result_dict = { result_dict = {
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return get_json_result( return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))




def error_response(response_code, message=None): def error_response(response_code, message=None):
error_string += "required argument values: {}".format( error_string += "required argument values: {}".format(
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result( return get_json_result(
code=RetCode.ARGUMENT_ERROR, message=error_string)
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs) return func(*_args, **_kwargs)


return decorated_function return decorated_function
return send_file(f, as_attachment=True, attachment_filename=filename) return send_file(f, as_attachment=True, attachment_filename=filename)




def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
response = {"code": code, "message": message, "data": data} response = {"code": code, "message": message, "data": data}
return jsonify(response) return jsonify(response)


objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return build_error_result( return build_error_result(
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
) )
kwargs['tenant_id'] = objs[0].tenant_id kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs) return func(*args, **kwargs)
return decorated_function return decorated_function




def build_error_result(code=RetCode.FORBIDDEN, message='success'):
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
response = {"code": code, "message": message} response = {"code": code, "message": message}
response = jsonify(response) response = jsonify(response)
response.status_code = code response.status_code = code
return response return response




def construct_response(code=RetCode.SUCCESS,
def construct_response(code=settings.RetCode.SUCCESS,
message='success', data=None, auth=None): message='success', data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data} result_dict = {"code": code, "message": message, "data": data}
response_dict = {} response_dict = {}
return response return response




def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
import re import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
response = {} response = {}
return jsonify(response) return jsonify(response)




def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
if data is None: if data is None:
return jsonify({"code": code, "message": message}) return jsonify({"code": code, "message": message})
else: else:
logging.exception(e) logging.exception(e)
try: try:
if e.code == 401: if e.code == 401:
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
except BaseException: except BaseException:
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))




def token_required(func): def token_required(func):
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
) )
kwargs['tenant_id'] = objs[0].tenant_id kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs) return func(*args, **kwargs)
return decorated_function return decorated_function




def get_result(code=RetCode.SUCCESS, message="", data=None):
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
if code == 0: if code == 0:
if data is not None: if data is not None:
response = {"code": code, "data": data} response = {"code": code, "data": data}
return jsonify(response) return jsonify(response)




def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
): ):
import re import re
result_dict = { result_dict = {

+ 2
- 2
deepdoc/parser/pdf_parser.py 查看文件

from timeit import default_timer as timer from timeit import default_timer as timer
from pypdf import PdfReader as pdf2_read from pypdf import PdfReader as pdf2_read


from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
self.tbl_det = TableStructureRecognizer() self.tbl_det = TableStructureRecognizer()


self.updown_cnt_mdl = xgb.Booster() self.updown_cnt_mdl = xgb.Booster()
if not LIGHTEN:
if not settings.LIGHTEN:
try: try:
import torch import torch
if torch.cuda.is_available(): if torch.cuda.is_available():

+ 2
- 2
graphrag/claim_extractor.py 查看文件



from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService


kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)


ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = { info = {
"input_text": docs, "input_text": docs,
"entity_specs": "organization, person", "entity_specs": "organization, person",

+ 2
- 2
graphrag/smoke.py 查看文件



from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService


kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)


ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs) graph = ex(docs)


er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

+ 11
- 11
rag/benchmark.py 查看文件

from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from rag.nlp import tokenize, search from rag.nlp import tokenize, search
from ranx import evaluate from ranx import evaluate
run = defaultdict(dict) run = defaultdict(dict)
query_list = list(qrels.keys()) query_list = list(qrels.keys())
for query in query_list: for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight) 0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0: if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}") print(f"deleted query: {query}")
def init_index(self, vector_size: int): def init_index(self, vector_size: int):
if self.initialized_index: if self.initialized_index:
return return
if docStoreConn.indexExist(self.index_name, self.kb_id):
docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True self.initialized_index = True


def ms_marco_index(self, file_path, index_name): def ms_marco_index(self, file_path, index_name):
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = [] docs = []


if docs: if docs:
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts return qrels, texts


def trivia_qa_index(self, file_path, index_name): def trivia_qa_index(self, file_path, index_name):
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs,self.index_name)
settings.docStoreConn.insert(docs,self.index_name)
docs = [] docs = []


docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts return qrels, texts


def miracl_index(self, file_path, corpus_path, index_name): def miracl_index(self, file_path, corpus_path, index_name):
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
docs = [] docs = []


docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts return qrels, texts


def save_results(self, qrels, run, texts, dataset, file_path): def save_results(self, qrels, run, texts, dataset, file_path):

+ 4
- 4
rag/llm/embedding_model.py 查看文件

import numpy as np import numpy as np
import asyncio import asyncio


from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai import google.generativeai as genai
^_- ^_-


""" """
if not LIGHTEN and not DefaultEmbedding._model:
if not settings.LIGHTEN and not DefaultEmbedding._model:
with DefaultEmbedding._model_lock: with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel from FlagEmbedding import FlagModel
import torch import torch
threads: Optional[int] = None, threads: Optional[int] = None,
**kwargs, **kwargs,
): ):
if not LIGHTEN and not FastEmbed._model:
if not settings.LIGHTEN and not FastEmbed._model:
from fastembed import TextEmbedding from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)


_client = None _client = None


def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not LIGHTEN and not YoudaoEmbed._client:
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing from BCEmbedding import EmbeddingModel as qanthing
try: try:
logging.info("LOADING BCE...") logging.info("LOADING BCE...")

+ 3
- 3
rag/llm/rerank_model.py 查看文件

from abc import ABC from abc import ABC
import numpy as np import numpy as np


from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
import json import json
^_- ^_-


""" """
if not LIGHTEN and not DefaultRerank._model:
if not settings.LIGHTEN and not DefaultRerank._model:
import torch import torch
from FlagEmbedding import FlagReranker from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock: with DefaultRerank._model_lock:
_model_lock = threading.Lock() _model_lock = threading.Lock()


def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not LIGHTEN and not YoudaoRerank._model:
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock: with YoudaoRerank._model_lock:
if not YoudaoRerank._model: if not YoudaoRerank._model:

+ 22
- 15
rag/svr/task_executor.py 查看文件

import logging import logging
import sys import sys
from api.utils.log_utils import initRootLogger from api.utils.log_utils import initRootLogger

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}") initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]: for module in ["pdfminer"]:
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.db.db_models import close_connection from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
HEAD_CREATED_AT = "" HEAD_CREATED_AT = ""
HEAD_DETAIL = "" HEAD_DETAIL = ""



def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD global PAYLOAD
if prog is not None and prog < 0: if prog is not None and prog < 0:
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError: except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return return
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"])) logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
except Exception as e: except Exception as e:
callback(-1, "Internal server error while chunking: %s" % callback(-1, "Internal server error while chunking: %s" %
str(e).replace("'", ""))
str(e).replace("'", ""))
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"])) logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
return return


STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st el += timer() - st
except Exception: except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))


d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"] del d["image"]
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
row["parser_config"]["auto_keywords"]).split(",") row["parser_config"]["auto_keywords"]).split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))


if row["parser_config"].get("auto_questions", 0): if row["parser_config"].get("auto_questions", 0):
st = timer() st = timer()
d["content_ltks"] += " " + qst d["content_ltks"] += " " + qst
if "content_sm_ltks" in d: if "content_sm_ltks" in d:
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst) d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))


return docs return docs




def init_kb(row, vector_size: int): def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"]) idxnm = search.index_name(row["tenant_id"])
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)




def embedding(docs, mdl, parser_config=None, callback=None): def embedding(docs, mdl, parser_config=None, callback=None):
vector_size = len(vts[0]) vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size vctr_nm = "q_%d_vec" % vector_size
chunks = [] chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))


raptor = Raptor( raptor = Raptor(
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
callback( callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
) )
st = timer() st = timer()
try: try:
es_r = "" es_r = ""
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")


logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r)) logging.error('Insert chunk error: ' + str(es_r))
else: else:
if TaskService.do_cancel(r["id"]): if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!") callback(1., "Done!")
if PENDING_TASKS > 0: if PENDING_TASKS > 0:
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME) head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
if head_info is not None: if head_info is not None:
seconds = int(head_info[0].split("-")[0])/1000
seconds = int(head_info[0].split("-")[0]) / 1000
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat() HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
HEAD_DETAIL = head_info[1] HEAD_DETAIL = head_info[1]


REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")


expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
if expired > 0: if expired > 0:
REDIS_CONN.zpopmin(CONSUMER_NAME, expired) REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception: except Exception:

Loading…
取消
儲存