| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -21,17 +21,17 @@ from flask import Blueprint, Flask, request | |||
| from werkzeug.wrappers.request import Request | |||
| from flask_cors import CORS | |||
| from web_server.db import StatusEnum | |||
| from web_server.db.services import UserService | |||
| from web_server.utils import CustomJSONEncoder | |||
| from api.db import StatusEnum | |||
| from api.db.services import UserService | |||
| from api.utils import CustomJSONEncoder | |||
| from flask_session import Session | |||
| from flask_login import LoginManager | |||
| from web_server.settings import RetCode, SECRET_KEY, stat_logger | |||
| from web_server.hook import HookManager | |||
| from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters | |||
| from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger | |||
| from web_server.utils.api_utils import get_json_result, server_error_response | |||
| from api.settings import RetCode, SECRET_KEY, stat_logger | |||
| from api.hook import HookManager | |||
| from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters | |||
| from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger | |||
| from api.utils.api_utils import get_json_result, server_error_response | |||
| from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer | |||
| __all__ = ['app'] | |||
| @@ -68,7 +68,7 @@ def search_pages_path(pages_dir): | |||
| def register_page(page_path): | |||
| page_name = page_path.stem.rstrip('_app') | |||
| module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, )) | |||
| module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name, )) | |||
| spec = spec_from_file_location(module_name, page_path) | |||
| page = module_from_spec(spec) | |||
| @@ -86,7 +86,7 @@ def register_page(page_path): | |||
| pages_dir = [ | |||
| Path(__file__).parent, | |||
| Path(__file__).parent.parent / 'web_server' / 'apps', | |||
| Path(__file__).parent.parent / 'api' / 'apps', | |||
| ] | |||
| client_urls_prefix = [ | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,31 +13,26 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import base64 | |||
| import hashlib | |||
| import pathlib | |||
| import re | |||
| from elasticsearch_dsl import Q | |||
| import numpy as np | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from rag.nlp import search, huqie | |||
| from rag.utils import ELASTICSEARCH, rmSpace | |||
| from web_server.db import LLMType | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.db.services.user_service import UserTenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.settings import RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| from rag.utils.minio_conn import MINIO | |||
| from web_server.utils.file_utils import filename_type | |||
| retrival = search.Dealer(ELASTICSEARCH, None) | |||
| from api.db import LLMType | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| retrival = search.Dealer(ELASTICSEARCH) | |||
| @manager.route('/list', methods=['POST']) | |||
| @login_required | |||
| @@ -45,16 +40,29 @@ retrival = search.Dealer(ELASTICSEARCH, None) | |||
| def list(): | |||
| req = request.json | |||
| doc_id = req["doc_id"] | |||
| page = req.get("page", 1) | |||
| size = req.get("size", 30) | |||
| page = int(req.get("page", 1)) | |||
| size = int(req.get("size", 30)) | |||
| question = req.get("keywords", "") | |||
| try: | |||
| tenants = UserTenantService.query(user_id=current_user.id) | |||
| if not tenants: | |||
| return get_data_error_result(retmsg="Tenant not found!") | |||
| res = retrival.search({ | |||
| tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | |||
| if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") | |||
| query = { | |||
| "doc_ids": [doc_id], "page": page, "size": size, "question": question | |||
| }, search.index_name(tenants[0].tenant_id)) | |||
| } | |||
| if "available_int" in req: query["available_int"] = int(req["available_int"]) | |||
| sres = retrival.search(query, search.index_name(tenant_id)) | |||
| res = {"total": sres.total, "chunks": []} | |||
| for id in sres.ids: | |||
| d = { | |||
| "chunk_id": id, | |||
| "content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"], | |||
| "doc_id": sres.field[id]["doc_id"], | |||
| "docnm_kwd": sres.field[id]["docnm_kwd"], | |||
| "important_kwd": sres.field[id].get("important_kwd", []), | |||
| "img_id": sres.field[id].get("img_id", ""), | |||
| "available_int": sres.field[id].get("available_int", 1), | |||
| } | |||
| res["chunks"].append(d) | |||
| return get_json_result(data=res) | |||
| except Exception as e: | |||
| if str(e).find("not_found") > 0: | |||
| @@ -102,6 +110,7 @@ def set(): | |||
| d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) | |||
| d["important_kwd"] = req["important_kwd"] | |||
| d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) | |||
| if "available_int" in req: d["available_int"] = req["available_int"] | |||
| try: | |||
| tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | |||
| @@ -116,10 +125,27 @@ def set(): | |||
| return server_error_response(e) | |||
@manager.route('/switch', methods=['POST'])
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
    """Set the ``available_int`` flag on a batch of chunks.

    Reads ``doc_id``, ``chunk_ids`` and ``available_int`` from the JSON
    request body. The tenant is resolved from ``doc_id`` and one partial
    document per chunk id is upserted into that tenant's search index.

    Returns:
        A JSON result with ``data=True`` on success, a data-error result
        when the tenant cannot be resolved or the upsert reports failure,
        and a server-error response for any raised exception.
    """
    payload = request.json
    try:
        tenant_id = DocumentService.get_tenant_id(payload["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        # One partial document per chunk; only the availability flag is sent.
        updates = []
        for chunk_id in payload["chunk_ids"]:
            updates.append({"id": chunk_id, "available_int": int(payload["available_int"])})

        index_name = search.index_name(tenant_id)
        if not ELASTICSEARCH.upsert(updates, index_name):
            return get_data_error_result(retmsg="Index updating failure")
        return get_json_result(data=True)
    except Exception as err:
        return server_error_response(err)
| @manager.route('/create', methods=['POST']) | |||
| @login_required | |||
| @validate_request("doc_id", "content_ltks", "important_kwd") | |||
| def set(): | |||
| def create(): | |||
| req = request.json | |||
| md5 = hashlib.md5() | |||
| md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8")) | |||
| @@ -148,3 +174,64 @@ def set(): | |||
| return get_json_result(data={"chunk_id": chunck_id}) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
@manager.route('/retrieval_test', methods=['POST'])
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
    """Run a test retrieval against one knowledge base and return ranked chunks.

    Request JSON fields:
        kb_id, question: required.
        page (default 1), size (default 30): pagination of the ranked list.
        doc_ids (default []): forwarded to the search as a document filter.
        similarity_threshold (default 0.4): chunks scoring below it are dropped.
        vector_similarity_weight (default 0.3): passed to ``rerank`` together
            with its complement ``1 - weight``.
        top (default 1024): number of candidates requested from the search.

    Returns:
        A JSON result whose data is ``{"total", "chunks", "doc_aggs"}``:
        ``total`` is the number of chunks at or above the threshold,
        ``chunks`` is the requested page, and ``doc_aggs`` counts the
        returned chunks per document name. A data-error result is returned
        when the knowledge base is missing or the error text contains
        ``not_found``; any other exception yields a server-error response.
    """
    req = request.json
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_id = req["kb_id"]
    doc_ids = req.get("doc_ids", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.4))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top", 1024))
    try:
        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if not found:
            return get_data_error_result(retmsg="Knowledgebase not found!")

        embd_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.EMBEDDING.value)
        sres = retrival.search({"kb_ids": [kb_id], "doc_ids": doc_ids, "size": top,
                                "question": question, "vector": True,
                                "similarity": similarity_threshold},
                               search.index_name(kb.tenant_id),
                               embd_mdl)

        sim, tsim, vsim = retrival.rerank(sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
        # Negating before argsort yields indices in descending similarity.
        # NOTE(review): assumes `sim` is a numpy array - confirm against rerank().
        order = np.argsort(sim * -1)
        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
        to_skip = (page - 1) * size  # ranked hits that belong to earlier pages
        for i in order:
            # Scores are descending, so the first miss ends the scan. The
            # check comes before counting so that `total` only includes
            # chunks that actually pass the threshold.
            if sim[i] < similarity_threshold:
                break
            ranks["total"] += 1
            to_skip -= 1
            if to_skip >= 0:
                continue
            # Page is full: keep iterating only to finish counting `total`.
            if len(ranks["chunks"]) == size:
                continue
            chunk_id = sres.ids[i]
            fields = sres.field[chunk_id]
            dnm = fields["docnm_kwd"]
            ranks["chunks"].append({
                "chunk_id": chunk_id,
                "content_ltks": fields["content_ltks"],
                "doc_id": fields["doc_id"],
                "docnm_kwd": dnm,
                "kb_id": fields["kb_id"],
                "important_kwd": fields.get("important_kwd", []),
                "img_id": fields.get("img_id", ""),
                "similarity": sim[i],
                "vector_similarity": vsim[i],
                "term_similarity": tsim[i]
            })
            ranks["doc_aggs"][dnm] = ranks["doc_aggs"].get(dnm, 0) + 1

        return get_json_result(data=ranks)
    except Exception as err:
        # Substring test instead of `find(...) > 0`, which ignored a match
        # at position 0 of the message.
        if "not_found" in str(err):
            return get_json_result(data=False, retmsg='Index not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(err)
| @@ -0,0 +1,163 @@ | |||
| # | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import hashlib | |||
| import re | |||
| import numpy as np | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from api.db.services.dialog_service import DialogService | |||
| from rag.nlp import search, huqie | |||
| from rag.utils import ELASTICSEARCH, rmSpace | |||
| from api.db import LLMType, StatusEnum | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import UserTenantService, TenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import RetCode, stat_logger | |||
| from api.utils.api_utils import get_json_result | |||
| from rag.utils.minio_conn import MINIO | |||
| from api.utils.file_utils import filename_type | |||
@manager.route('/set', methods=['POST'])
@login_required
def set():
    """Create a dialog, or update an existing one when ``dialog_id`` is sent.

    The prompt configuration (from the request, or the built-in default) is
    validated first: it must declare at least one parameter, and every
    declared parameter key must appear as ``{key}`` in the system prompt.

    Without ``dialog_id`` a new dialog is saved for the current user, with
    request values falling back to the defaults below, and the stored record
    is returned. With ``dialog_id`` the remaining request fields (minus
    ``kb_names``) are applied to that dialog, and the updated record is
    returned with ``kb_ids``/``kb_names`` recomputed by ``get_kb_names``.

    Returns:
        A JSON result with the dialog data, a data-error result for
        validation or lookup failures, or a server-error response when an
        exception is raised inside the persistence section.
    """
    req = request.json
    dialog_id = req.get("dialog_id")

    # Sampling presets, keyed by setting type, used when none are supplied.
    default_llm_setting = {
        "Creative": {
            "temperature": 0.9,
            "top_p": 0.9,
            "frequency_penalty": 0.2,
            "presence_penalty": 0.4,
            "max_tokens": 512
        },
        "Precise": {
            "temperature": 0.1,
            "top_p": 0.3,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
            "max_tokens": 215
        },
        "Evenly": {
            "temperature": 0.5,
            "top_p": 0.5,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
            "max_tokens": 215
        },
        "Custom": {
            "temperature": 0.2,
            "top_p": 0.3,
            "frequency_penalty": 0.6,
            "presence_penalty": 0.3,
            "max_tokens": 215
        },
    }
    # Prompt used when the request carries no prompt_config of its own.
    default_prompt_config = {
        "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
以下是知识库:
{knowledge}
以上是知识库。""",
        "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
        "parameters": [
            {"key": "knowledge", "optional": False}
        ],
        "empty_response": "Sorry! 知识库中未找到相关内容!"
    }
    llm_setting = req.get("llm_setting", default_llm_setting)
    prompt_config = req.get("prompt_config", default_prompt_config)

    parameters = prompt_config["parameters"]
    if len(parameters) == 0:
        return get_data_error_result(retmsg="'knowledge' should be in parameters")

    # Every declared parameter has to be referenced by the system prompt.
    system_prompt = prompt_config["system"]
    for param in parameters:
        if system_prompt.find("{%s}" % param["key"]) == -1:
            return get_data_error_result(retmsg="Parameter '{}' is not used".format(param["key"]))

    try:
        found, tenant = TenantService.get_by_id(current_user.id)
        if not found:
            return get_data_error_result(retmsg="Tenant not found!")
        llm_id = req.get("llm_id", tenant.llm_id)

        if not dialog_id:
            # Create path: build the record from request values and defaults.
            new_dialog = {
                "id": get_uuid(),
                "tenant_id": current_user.id,
                "name": req.get("name", "New Dialog"),
                "description": req.get("description", "A helpful Dialog"),
                "language": req.get("language", "Chinese"),
                "llm_id": llm_id,
                "llm_setting_type": req.get("llm_setting_type", "Precise"),
                "llm_setting": llm_setting,
                "prompt_config": prompt_config
            }
            if not DialogService.save(**new_dialog):
                return get_data_error_result(retmsg="Fail to new a dialog!")
            found, stored = DialogService.get_by_id(new_dialog["id"])
            if not found:
                return get_data_error_result(retmsg="Fail to new a dialog!")
            return get_json_result(data=stored.to_json())

        # Update path: the remaining request fields are applied as sent.
        req.pop("dialog_id")
        req.pop("kb_names", None)
        if not DialogService.update_by_id(dialog_id, req):
            return get_data_error_result(retmsg="Dialog not found!")
        found, stored = DialogService.get_by_id(dialog_id)
        if not found:
            return get_data_error_result(retmsg="Fail to update a dialog!")
        result = stored.to_dict()
        result["kb_ids"], result["kb_names"] = get_kb_names(result["kb_ids"])
        return get_json_result(data=result)
    except Exception as err:
        return server_error_response(err)
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return the dialog named by the ``dialog_id`` query argument.

    The dialog's ``kb_ids`` are passed through ``get_kb_names`` and the
    response carries the resulting ``kb_ids`` and ``kb_names`` lists.

    Returns:
        A JSON result with the dialog as a dict, a data-error result when
        no dialog has that id, or a server-error response on an exception.
    """
    requested_id = request.args["dialog_id"]
    try:
        found, dialog = DialogService.get_by_id(requested_id)
        if not found:
            return get_data_error_result(retmsg="Dialog not found!")

        payload = dialog.to_dict()
        kb_ids, kb_names = get_kb_names(payload["kb_ids"])
        payload["kb_ids"] = kb_ids
        payload["kb_names"] = kb_names
        return get_json_result(data=payload)
    except Exception as err:
        return server_error_response(err)
def get_kb_names(kb_ids):
    """Resolve knowledge-base ids to names, keeping only valid ones.

    Each id is looked up individually; ids that are not found, or whose
    knowledge base does not have the ``VALID`` status, are dropped.

    Args:
        kb_ids: Iterable of knowledge-base ids.

    Returns:
        A ``(ids, names)`` tuple of two parallel lists, in input order.
    """
    valid_ids = []
    valid_names = []
    for kb_id in kb_ids:
        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if found and kb.status == StatusEnum.VALID.value:
            valid_ids.append(kb_id)
            valid_names.append(kb.name)
    return valid_ids, valid_names
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return every valid dialog owned by the current user.

    Dialogs are queried by ``tenant_id=current_user.id`` and ``VALID``
    status; each one is converted to a dict and its ``kb_ids`` and
    ``kb_names`` are replaced with the output of ``get_kb_names``.

    Returns:
        A JSON result with the list of dialog dicts, or a server-error
        response when an exception is raised.
    """
    try:
        rows = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)

        # Convert all rows first, then enrich them, in two separate passes.
        dialogs = []
        for row in rows:
            dialogs.append(row.to_dict())
        for dialog in dialogs:
            kb_ids, kb_names = get_kb_names(dialog["kb_ids"])
            dialog["kb_ids"] = kb_ids
            dialog["kb_names"] = kb_names
        return get_json_result(data=dialogs)
    except Exception as err:
        return server_error_response(err)
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,22 +16,23 @@ | |||
| import base64 | |||
| import pathlib | |||
| import flask | |||
| from elasticsearch_dsl import Q | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from rag.nlp import search | |||
| from rag.utils import ELASTICSEARCH | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid | |||
| from web_server.db import FileType | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.settings import RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.db import FileType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| from rag.utils.minio_conn import MINIO | |||
| from web_server.utils.file_utils import filename_type | |||
| from api.utils.file_utils import filename_type | |||
| @manager.route('/upload', methods=['POST']) | |||
| @@ -163,21 +164,13 @@ def change_status(): | |||
| if str(req["status"]) == "0": | |||
| ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), | |||
| scripts=""" | |||
| if(ctx._source.kb_id.contains('%s')) | |||
| ctx._source.kb_id.remove( | |||
| ctx._source.kb_id.indexOf('%s') | |||
| ); | |||
| """ % (doc.kb_id, doc.kb_id), | |||
| scripts="ctx._source.available_int=0;", | |||
| idxnm=search.index_name( | |||
| kb.tenant_id) | |||
| ) | |||
| else: | |||
| ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), | |||
| scripts=""" | |||
| if(!ctx._source.kb_id.contains('%s')) | |||
| ctx._source.kb_id.add('%s'); | |||
| """ % (doc.kb_id, doc.kb_id), | |||
| scripts="ctx._source.available_int=1;", | |||
| idxnm=search.index_name( | |||
| kb.tenant_id) | |||
| ) | |||
| @@ -195,8 +188,7 @@ def rm(): | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): | |||
| return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) | |||
| ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)) | |||
| DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) | |||
| if not DocumentService.delete_by_id(req["doc_id"]): | |||
| @@ -277,3 +269,15 @@ def change_parser(): | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
@manager.route('/image/<image_id>', methods=['GET'])
@login_required
def get_image(image_id):
    """Serve a stored image identified by ``<bucket>-<name>``.

    Args:
        image_id: Path segment split on ``-`` into exactly two parts, which
            are passed to ``MINIO.get`` as bucket and object name.

    Returns:
        A response wrapping the fetched object with the ``Content-Type``
        header set to ``image/JPEG``, or a server-error response when
        anything raises (including an id that does not split in two).
    """
    try:
        bucket, object_name = image_id.split("-")
        blob = MINIO.get(bucket, object_name)
        resp = flask.make_response(blob)
        resp.headers.set('Content-Type', 'image/JPEG')
        return resp
    except Exception as err:
        return server_error_response(err)
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,15 +16,15 @@ | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.user_service import TenantService, UserTenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db import StatusEnum, UserTenantRole | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.db_models import Knowledgebase | |||
| from web_server.settings import stat_logger, RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.db import StatusEnum, UserTenantRole | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.db_models import Knowledgebase | |||
| from api.settings import stat_logger, RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| @manager.route('/create', methods=['post']) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,16 +16,16 @@ | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||
| from web_server.db.services.user_service import TenantService, UserTenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db import StatusEnum, UserTenantRole | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.db_models import Knowledgebase, TenantLLM | |||
| from web_server.settings import stat_logger, RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.db import StatusEnum, UserTenantRole | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.db_models import Knowledgebase, TenantLLM | |||
| from api.settings import stat_logger, RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| @manager.route('/factories', methods=['GET']) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,15 +17,15 @@ from flask import request, session, redirect, url_for | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from flask_login import login_required, current_user, login_user, logout_user | |||
| from web_server.db.db_models import TenantLLM | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.utils.api_utils import server_error_response, validate_request | |||
| from web_server.utils import get_uuid, get_format_time, decrypt, download_img | |||
| from web_server.db import UserTenantRole, LLMType | |||
| from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS | |||
| from web_server.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from web_server.settings import stat_logger | |||
| from web_server.utils.api_utils import get_json_result, cors_reponse | |||
| from api.db.db_models import TenantLLM | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.utils.api_utils import server_error_response, validate_request | |||
| from api.utils import get_uuid, get_format_time, decrypt, download_img | |||
| from api.db import UserTenantRole, LLMType | |||
| from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS | |||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from api.settings import stat_logger | |||
| from api.utils.api_utils import get_json_result, cors_reponse | |||
| @manager.route('/login', methods=['POST', 'GET']) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -51,4 +51,11 @@ class LLMType(StrEnum): | |||
| CHAT = 'chat' | |||
| EMBEDDING = 'embedding' | |||
| SPEECH2TEXT = 'speech2text' | |||
| IMAGE2TEXT = 'image2text' | |||
| IMAGE2TEXT = 'image2text' | |||
| class ChatStyle(StrEnum): | |||
| CREATIVE = 'Creative' | |||
| PRECISE = 'Precise' | |||
| EVENLY = 'Evenly' | |||
| CUSTOM = 'Custom' | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -29,10 +29,10 @@ from peewee import ( | |||
| ) | |||
| from playhouse.pool import PooledMySQLDatabase | |||
| from web_server.db import SerializedType | |||
| from web_server.settings import DATABASE, stat_logger, SECRET_KEY | |||
| from web_server.utils.log_utils import getLogger | |||
| from web_server import utils | |||
| from api.db import SerializedType | |||
| from api.settings import DATABASE, stat_logger, SECRET_KEY | |||
| from api.utils.log_utils import getLogger | |||
| from api import utils | |||
| LOGGER = getLogger() | |||
| @@ -467,6 +467,8 @@ class Knowledgebase(DataBaseModel): | |||
| doc_num = IntegerField(default=0) | |||
| token_num = IntegerField(default=0) | |||
| chunk_num = IntegerField(default=0) | |||
| similarity_threshold = FloatField(default=0.4) | |||
| vector_similarity_weight = FloatField(default=0.3) | |||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID") | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| @@ -516,19 +518,20 @@ class Dialog(DataBaseModel): | |||
| prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") | |||
| prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", | |||
| "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) | |||
| kb_ids = JSONField(null=False, default=[]) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| class Meta: | |||
| db_table = "dialog" | |||
| class DialogKb(DataBaseModel): | |||
| dialog_id = CharField(max_length=32, null=False, index=True) | |||
| kb_id = CharField(max_length=32, null=False) | |||
| class Meta: | |||
| db_table = "dialog_kb" | |||
| primary_key = CompositeKey('dialog_id', 'kb_id') | |||
| # class DialogKb(DataBaseModel): | |||
| # dialog_id = CharField(max_length=32, null=False, index=True) | |||
| # kb_id = CharField(max_length=32, null=False) | |||
| # | |||
| # class Meta: | |||
| # db_table = "dialog_kb" | |||
| # primary_key = CompositeKey('dialog_id', 'kb_id') | |||
| class Conversation(DataBaseModel): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2021 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2021 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,10 +19,10 @@ import time | |||
| from functools import wraps | |||
| from shortuuid import ShortUUID | |||
| from web_server.versions import get_rag_version | |||
| from api.versions import get_rag_version | |||
| from web_server.errors.error_services import * | |||
| from web_server.settings import ( | |||
| from api.errors.error_services import * | |||
| from api.settings import ( | |||
| GRPC_PORT, HOST, HTTP_PORT, | |||
| RANDOM_INSTANCE_ID, stat_logger, | |||
| ) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,11 +17,11 @@ import operator | |||
| from functools import reduce | |||
| from typing import Dict, Type, Union | |||
| from web_server.utils import current_timestamp, timestamp_to_date | |||
| from api.utils import current_timestamp, timestamp_to_date | |||
| from web_server.db.db_models import DB, DataBaseModel | |||
| from web_server.db.runtime_config import RuntimeConfig | |||
| from web_server.utils.log_utils import getLogger | |||
| from api.db.db_models import DB, DataBaseModel | |||
| from api.db.runtime_config import RuntimeConfig | |||
| from api.utils.log_utils import getLogger | |||
| from enum import Enum | |||
| @@ -123,9 +123,3 @@ def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, | |||
| data = data.offset(offset) | |||
| return list(data), count | |||
| class StatusEnum(Enum): | |||
| # 样本可用状态 | |||
| VALID = "1" | |||
| IN_VALID = "0" | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,10 +16,10 @@ | |||
| import time | |||
| import uuid | |||
| from web_server.db import LLMType | |||
| from web_server.db.db_models import init_database_tables as init_web_db | |||
| from web_server.db.services import UserService | |||
| from web_server.db.services.llm_service import LLMFactoriesService, LLMService | |||
| from api.db import LLMType | |||
| from api.db.db_models import init_database_tables as init_web_db | |||
| from api.db.services import UserService | |||
| from api.db.services.llm_service import LLMFactoriesService, LLMService | |||
| def init_superuser(): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,5 +17,5 @@ | |||
| import operator | |||
| import time | |||
| import typing | |||
| from web_server.utils.log_utils import sql_logger | |||
| from api.utils.log_utils import sql_logger | |||
| import peewee | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,7 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from web_server.versions import get_versions | |||
| from api.versions import get_versions | |||
| from .reload_config_base import ReloadConfigBase | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,8 +17,8 @@ from datetime import datetime | |||
| import peewee | |||
| from web_server.db.db_models import DB | |||
| from web_server.utils import datetime_format | |||
| from api.db.db_models import DB | |||
| from api.utils import datetime_format | |||
| class CommonService: | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,14 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import Dialog, Conversation, DialogKb | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db.db_models import Dialog, Conversation | |||
| from api.db.services.common_service import CommonService | |||
| class DialogService(CommonService): | |||
| @@ -29,7 +23,3 @@ class DialogService(CommonService): | |||
| class ConversationService(CommonService): | |||
| model = Conversation | |||
| class DialogKbService(CommonService): | |||
| model = DialogKb | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,12 +15,12 @@ | |||
| # | |||
| from peewee import Expression | |||
| from web_server.db import TenantPermission, FileType | |||
| from web_server.db.db_models import DB, Knowledgebase, Tenant | |||
| from web_server.db.db_models import Document | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db import TenantPermission, FileType | |||
| from api.db.db_models import DB, Knowledgebase, Tenant | |||
| from api.db.db_models import Document | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db import StatusEnum | |||
| class DocumentService(CommonService): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,15 +13,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from web_server.db import TenantPermission | |||
| from web_server.db.db_models import DB, UserTenant, Tenant | |||
| from web_server.db.db_models import Knowledgebase | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db import TenantPermission | |||
| from api.db.db_models import DB, Tenant | |||
| from api.db.db_models import Knowledgebase | |||
| from api.db.services.common_service import CommonService | |||
| from api.db import StatusEnum | |||
| class KnowledgebaseService(CommonService): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,14 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import Knowledgebase, Document | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db.db_models import Knowledgebase, Document | |||
| from api.db.services.common_service import CommonService | |||
| class KnowledgebaseService(CommonService): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,15 +13,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from rag.llm import EmbeddingModel, CvModel | |||
| from web_server.db import LLMType | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import LLMFactories, LLM, TenantLLM | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import LLMFactories, LLM, TenantLLM | |||
| from api.db.services.common_service import CommonService | |||
| from api.db import StatusEnum | |||
| class LLMFactoriesService(CommonService): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,12 +16,12 @@ | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from web_server.db import UserTenantRole | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import User, Tenant | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| from api.db import UserTenantRole | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import User, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.db import StatusEnum | |||
| class UserService(CommonService): | |||
| @@ -1,4 +1,4 @@ | |||
| from web_server.errors import RagFlowError | |||
| from api.errors import RagFlowError | |||
| __all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', | |||
| 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,8 +1,8 @@ | |||
| import importlib | |||
| from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ | |||
| from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ | |||
| SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters | |||
| from web_server.settings import HOOK_MODULE, stat_logger,RetCode | |||
| from api.settings import HOOK_MODULE, stat_logger,RetCode | |||
| class HookManager: | |||
| @@ -1,10 +1,10 @@ | |||
| import requests | |||
| from web_server.db.service_registry import ServiceRegistry | |||
| from web_server.settings import RegistryServiceName | |||
| from web_server.hook import HookManager | |||
| from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn | |||
| from web_server.settings import HOOK_SERVER_NAME | |||
| from api.db.service_registry import ServiceRegistry | |||
| from api.settings import RegistryServiceName | |||
| from api.hook import HookManager | |||
| from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn | |||
| from api.settings import HOOK_SERVER_NAME | |||
| @HookManager.register_client_authentication_hook | |||
| @@ -1,10 +1,10 @@ | |||
| import requests | |||
| from web_server.db.service_registry import ServiceRegistry | |||
| from web_server.settings import RegistryServiceName | |||
| from web_server.hook import HookManager | |||
| from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn | |||
| from web_server.settings import HOOK_SERVER_NAME | |||
| from api.db.service_registry import ServiceRegistry | |||
| from api.settings import RegistryServiceName | |||
| from api.hook import HookManager | |||
| from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn | |||
| from api.settings import HOOK_SERVER_NAME | |||
| @HookManager.register_permission_check_hook | |||
| @@ -1,11 +1,11 @@ | |||
| import requests | |||
| from web_server.db.service_registry import ServiceRegistry | |||
| from web_server.settings import RegistryServiceName | |||
| from web_server.hook import HookManager | |||
| from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ | |||
| from api.db.service_registry import ServiceRegistry | |||
| from api.settings import RegistryServiceName | |||
| from api.hook import HookManager | |||
| from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ | |||
| SignatureReturn | |||
| from web_server.settings import HOOK_SERVER_NAME, PARTY_ID | |||
| from api.settings import HOOK_SERVER_NAME, PARTY_ID | |||
| @HookManager.register_site_signature_hook | |||
| @@ -1,4 +1,4 @@ | |||
| from web_server.settings import RetCode | |||
| from api.settings import RetCode | |||
| class ParametersBase: | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -23,17 +23,17 @@ import traceback | |||
| from werkzeug.serving import run_simple | |||
| from web_server.apps import app | |||
| from web_server.db.runtime_config import RuntimeConfig | |||
| from web_server.hook import HookManager | |||
| from web_server.settings import ( | |||
| from api.apps import app | |||
| from api.db.runtime_config import RuntimeConfig | |||
| from api.hook import HookManager | |||
| from api.settings import ( | |||
| HOST, HTTP_PORT, access_logger, database_logger, stat_logger, | |||
| ) | |||
| from web_server import utils | |||
| from api import utils | |||
| from web_server.db.db_models import init_database_tables as init_web_db | |||
| from web_server.db.init_data import init_web_data | |||
| from web_server.versions import get_versions | |||
| from api.db.db_models import init_database_tables as init_web_db | |||
| from api.db.init_data import init_web_data | |||
| from api.versions import get_versions | |||
| if __name__ == '__main__': | |||
| stat_logger.info( | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,9 +17,9 @@ import os | |||
| from enum import IntEnum, Enum | |||
| from web_server.utils import get_base_config,decrypt_database_config | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from web_server.utils.log_utils import LoggerFactory, getLogger | |||
| from api.utils import get_base_config,decrypt_database_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from api.utils.log_utils import LoggerFactory, getLogger | |||
| # Server | |||
| @@ -71,7 +71,7 @@ PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") | |||
| DATABASE = decrypt_database_config() | |||
| # Logger | |||
| LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server")) | |||
| LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api")) | |||
| # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} | |||
| LoggerFactory.LEVEL = 10 | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,16 +24,16 @@ from flask import ( | |||
| ) | |||
| from werkzeug.http import HTTP_STATUS_CODES | |||
| from web_server.utils import json_dumps | |||
| from web_server.versions import get_rag_version | |||
| from web_server.settings import RetCode | |||
| from web_server.settings import ( | |||
| from api.utils import json_dumps | |||
| from api.versions import get_rag_version | |||
| from api.settings import RetCode | |||
| from api.settings import ( | |||
| REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, | |||
| stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY | |||
| ) | |||
| import requests | |||
| import functools | |||
| from web_server.utils import CustomJSONEncoder | |||
| from api.utils import CustomJSONEncoder | |||
| from uuid import uuid1 | |||
| from base64 import b64encode | |||
| from hmac import HMAC | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -21,7 +21,7 @@ import re | |||
| from cachetools import LRUCache, cached | |||
| from ruamel.yaml import YAML | |||
| from web_server.db import FileType | |||
| from api.db import FileType | |||
| PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") | |||
| RAG_BASE = os.getenv("RAG_BASE") | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -21,7 +21,7 @@ import inspect | |||
| from logging.handlers import TimedRotatingFileHandler | |||
| from threading import RLock | |||
| from web_server.utils import file_utils | |||
| from api.utils import file_utils | |||
| class LoggerFactory(object): | |||
| TYPE = "FILE" | |||
| @@ -1,7 +1,7 @@ | |||
| import base64, os, sys | |||
| from Cryptodome.PublicKey import RSA | |||
| from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 | |||
| from web_server.utils import decrypt, file_utils | |||
| from api.utils import decrypt, file_utils | |||
| def crypt(line): | |||
| file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,7 +18,7 @@ import os | |||
| import dotenv | |||
| import typing | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from api.utils.file_utils import get_project_base_directory | |||
| def get_versions() -> typing.Mapping[str, typing.Any]: | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -60,6 +60,10 @@ class HuEmbedding(Base): | |||
| res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) | |||
| return np.array(res), token_count | |||
| def encode_queries(self, text: str): | |||
| token_count = num_tokens_from_string(text) | |||
| return self.model.encode_queries([text]).tolist()[0], token_count | |||
| class OpenAIEmbed(Base): | |||
| def __init__(self, key, model_name="text-embedding-ada-002"): | |||
| @@ -9,7 +9,7 @@ import string | |||
| import sys | |||
| from hanziconv import HanziConv | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from api.utils.file_utils import get_project_base_directory | |||
| class Huqie: | |||
| @@ -147,7 +147,7 @@ class EsQueryer: | |||
| atks = toDict(atks) | |||
| btkss = [toDict(tks) for tks in btkss] | |||
| tksim = [self.similarity(atks, btks) for btks in btkss] | |||
| return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight | |||
| return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, sims[0], tksim | |||
| def similarity(self, qtwt, dtwt): | |||
| if isinstance(dtwt, type("")): | |||
| @@ -15,7 +15,7 @@ def index_name(uid): return f"ragflow_{uid}" | |||
| class Dealer: | |||
| def __init__(self, es, emb_mdl): | |||
| def __init__(self, es): | |||
| self.qryr = query.EsQueryer(es) | |||
| self.qryr.flds = [ | |||
| "title_tks^10", | |||
| @@ -23,7 +23,6 @@ class Dealer: | |||
| "content_ltks^2", | |||
| "content_sm_ltks"] | |||
| self.es = es | |||
| self.emb_mdl = emb_mdl | |||
| @dataclass | |||
| class SearchResult: | |||
| @@ -36,23 +35,26 @@ class Dealer: | |||
| keywords: Optional[List[str]] = None | |||
| group_docs: List[List] = None | |||
| def _vector(self, txt, sim=0.8, topk=10): | |||
| qv = self.emb_mdl.encode_queries(txt) | |||
| def _vector(self, txt, emb_mdl, sim=0.8, topk=10): | |||
| qv, c = emb_mdl.encode_queries(txt) | |||
| return { | |||
| "field": "q_%d_vec"%len(qv), | |||
| "k": topk, | |||
| "similarity": sim, | |||
| "num_candidates": 1000, | |||
| "num_candidates": topk*2, | |||
| "query_vector": qv | |||
| } | |||
| def search(self, req, idxnm, tks_num=3): | |||
| def search(self, req, idxnm, emb_mdl=None): | |||
| qst = req.get("question", "") | |||
| bqry, keywords = self.qryr.question(qst) | |||
| if req.get("kb_ids"): | |||
| bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) | |||
| if req.get("doc_ids"): | |||
| bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) | |||
| if "available_int" in req: | |||
| if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) | |||
| else: bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||
| bqry.boost = 0.05 | |||
| s = Search() | |||
| @@ -60,7 +62,7 @@ class Dealer: | |||
| ps = int(req.get("size", 1000)) | |||
| src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id", | |||
| "image_id", "doc_id", "q_512_vec", "q_768_vec", | |||
| "q_1024_vec", "q_1536_vec"]) | |||
| "q_1024_vec", "q_1536_vec", "available_int"]) | |||
| s = s.query(bqry)[pg * ps:(pg + 1) * ps] | |||
| s = s.highlight("content_ltks") | |||
| @@ -80,7 +82,8 @@ class Dealer: | |||
| s = s.to_dict() | |||
| q_vec = [] | |||
| if req.get("vector"): | |||
| s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps) | |||
| assert emb_mdl, "No embedding model selected" | |||
| s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps) | |||
| s["knn"]["filter"] = bqry.to_dict() | |||
| if "highlight" in s: del s["highlight"] | |||
| q_vec = s["knn"]["query_vector"] | |||
| @@ -168,7 +171,7 @@ class Dealer: | |||
| def trans2floats(txt): | |||
| return [float(t) for t in txt.split("\t")] | |||
| def insert_citations(self, ans, top_idx, sres, | |||
| def insert_citations(self, ans, top_idx, sres, emb_mdl, | |||
| vfield="q_vec", cfield="content_ltks"): | |||
| ins_embd = [Dealer.trans2floats( | |||
| @@ -179,15 +182,14 @@ class Dealer: | |||
| res = "" | |||
| def citeit(): | |||
| nonlocal s, e, ans, res | |||
| nonlocal s, e, ans, res, emb_mdl | |||
| if not ins_embd: | |||
| return | |||
| embd = self.emb_mdl.encode(ans[s: e]) | |||
| embd = emb_mdl.encode(ans[s: e]) | |||
| sim = self.qryr.hybrid_similarity(embd, | |||
| ins_embd, | |||
| huqie.qie(ans[s:e]).split(" "), | |||
| ins_tw) | |||
| print(ans[s: e], sim) | |||
| mx = np.max(sim) * 0.99 | |||
| if mx < 0.55: | |||
| return | |||
| @@ -225,20 +227,18 @@ class Dealer: | |||
| return res | |||
| def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, | |||
| vfield="q_vec", cfield="content_ltks"): | |||
| def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"): | |||
| ins_embd = [ | |||
| Dealer.trans2floats( | |||
| sres.field[i]["q_vec"]) for i in sres.ids] | |||
| sres.field[i]["q_%d_vec"%len(sres.query_vector)]) for i in sres.ids] | |||
| if not ins_embd: | |||
| return [] | |||
| ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids] | |||
| # return CosineSimilarity([sres.query_vector], ins_embd)[0] | |||
| sim = self.qryr.hybrid_similarity(sres.query_vector, | |||
| sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, | |||
| ins_embd, | |||
| huqie.qie(query).split(" "), | |||
| ins_tw, tkweight, vtweight) | |||
| return sim | |||
| return sim, tksim, vtsim | |||
| @@ -4,7 +4,7 @@ import time | |||
| import logging | |||
| import re | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from api.utils.file_utils import get_project_base_directory | |||
| class Dealer: | |||
| @@ -5,7 +5,7 @@ import re | |||
| import os | |||
| import numpy as np | |||
| from rag.nlp import huqie | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from api.utils.file_utils import get_project_base_directory | |||
| class Dealer: | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,9 +14,9 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from web_server.utils import get_base_config,decrypt_database_config | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from web_server.utils.log_utils import LoggerFactory, getLogger | |||
| from api.utils import get_base_config,decrypt_database_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from api.utils.log_utils import LoggerFactory, getLogger | |||
| # Server | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # Copyright 2019 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -47,12 +47,12 @@ from rag.nlp.huchunk import ( | |||
| PptChunker, | |||
| TextChunker | |||
| ) | |||
| from web_server.db import LLMType | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.settings import database_logger | |||
| from web_server.utils import get_format_time | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| from api.db import LLMType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.settings import database_logger | |||
| from api.utils import get_format_time | |||
| from api.utils.file_utils import get_project_base_directory | |||
| BATCH_SIZE = 64 | |||
| @@ -257,7 +257,6 @@ def main(comm, mod): | |||
| cron_logger.error(str(e)) | |||
| continue | |||
| set_progress(r["id"], random.randint(70, 95) / 100., | |||
| "Finished embedding! Start to build index!") | |||
| init_kb(r) | |||
| @@ -66,7 +66,6 @@ class HuEs: | |||
| body=d, | |||
| id=id, | |||
| refresh=False, | |||
| doc_type="_doc", | |||
| retry_on_conflict=100) | |||
| es_logger.info("Successfully upsert: %s" % id) | |||
| T = True | |||