* add front end code * change licence * rename web_server to API * change name to web_server tags/v0.1.0
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -0,0 +1,150 @@ | |||
| # | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import base64 | |||
| import hashlib | |||
| import pathlib | |||
| import re | |||
| from elasticsearch_dsl import Q | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from rag.nlp import search, huqie | |||
| from rag.utils import ELASTICSEARCH, rmSpace | |||
| from web_server.db import LLMType | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.db.services.user_service import UserTenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.settings import RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| from rag.utils.minio_conn import MINIO | |||
| from web_server.utils.file_utils import filename_type | |||
| retrival = search.Dealer(ELASTICSEARCH, None) | |||
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
def list():
    """Return one page of chunks of a document from the tenant's index.

    JSON body: ``doc_id`` (required), ``page`` (default 1), ``size``
    (default 30) and ``keywords`` (default ""), which is forwarded to the
    search dealer as the ``question``.

    Responds with the search result, a data error when the current user
    has no tenant or the index does not exist, or a server error for any
    other failure.
    """
    req = request.json
    doc_id = req["doc_id"]
    page = req.get("page", 1)
    size = req.get("size", 30)
    question = req.get("keywords", "")
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        query = {"doc_ids": [doc_id], "page": page, "size": size, "question": question}
        # The first tenant of the current user selects the index.
        res = retrival.search(query, search.index_name(tenants[0].tenant_id))
        return get_json_result(data=res)
    except Exception as e:
        # Substring test instead of `find(...) > 0`, which silently missed a
        # message that starts with "not_found".
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg='Index not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Fetch a single chunk by id from the current user's tenant index.

    Query string: ``chunk_id`` (required).

    The returned document has vector / small-token fields (names matching
    ``_vec$`` or ``_sm_``) removed, token fields (``_tks`` / ``_ltks``)
    passed through ``rmSpace``, and the document id exposed as ``chunk_id``.
    A missing chunk is reported as a data error in every code path.
    """
    chunk_id = request.args["chunk_id"]
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        res = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id))
        if not res.get("found"):
            # Same response as the NotFoundError branch below; the helper for
            # server errors is only ever given exception objects elsewhere.
            return get_json_result(data=False, retmsg='Chunk not found!',
                                   retcode=RetCode.DATA_ERROR)
        chunk = res["_source"]
        chunk["chunk_id"] = res["_id"]
        # Iterate over a snapshot of the keys so entries can be dropped in place.
        for field in [*chunk]:
            if re.search(r"(_vec$|_sm_)", field):
                del chunk[field]
            elif re.search(r"(_tks|_ltks)", field):
                chunk[field] = rmSpace(chunk[field])
        return get_json_result(data=chunk)
    except Exception as e:
        if str(e).find("NotFoundError") >= 0:
            return get_json_result(data=False, retmsg='Chunk not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd")
def set():
    """Re-tokenize, re-embed and upsert an existing chunk.

    JSON body: ``doc_id``, ``chunk_id``, ``content_ltks``, ``important_kwd``
    and ``docnm_kwd``. The stored vector is a weighted blend of the two
    embeddings returned for the document name (0.1) and the content (0.9).
    """
    payload = request.json
    chunk = {"id": payload["chunk_id"], "content_ltks": huqie.qie(payload["content_ltks"])}
    chunk.update(
        content_sm_ltks=huqie.qieqie(chunk["content_ltks"]),
        important_kwd=payload["important_kwd"],
        important_tks=huqie.qie(" ".join(payload["important_kwd"])),
    )

    try:
        tenant_id = DocumentService.get_tenant_id(payload["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        encoder = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        # encode() yields a pair; only the vectors are needed here.
        vectors, _ = encoder.encode([payload["docnm_kwd"], payload["content_ltks"]])
        blended = 0.1 * vectors[0] + 0.9 * vectors[1]
        # The field name carries the vector dimension.
        chunk["q_%d_vec" % len(blended)] = blended.tolist()
        ELASTICSEARCH.upsert([chunk], search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as err:
        return server_error_response(err)
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks", "important_kwd")
def create():
    """Create a new chunk for a document and index it.

    JSON body: ``doc_id``, ``content_ltks`` and ``important_kwd``.

    The chunk id is the MD5 hex digest of the content concatenated with the
    document id, so identical content for the same document maps to the same
    id. Returns ``{"chunk_id": <id>}`` on success.

    The view is named ``create`` (it was a second ``set``): a duplicate
    function name clashes with the ``/set`` view both as a Flask endpoint
    name and as a module-level attribute.
    """
    req = request.json
    chunk_id = hashlib.md5(
        (req["content_ltks"] + req["doc_id"]).encode("utf-8")).hexdigest()

    d = {"id": chunk_id, "content_ltks": huqie.qie(req["content_ltks"])}
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))

    try:
        found, doc = DocumentService.get_by_id(req["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["doc_id"] = doc.id

        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        # Second value of encode() is forwarded to increment_chunk_num below
        # (presumably a token count -- confirm against the embedding model API).
        v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
        # Blend the document-name and content embeddings (0.1 / 0.9).
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        # Bump the counters only after the index write succeeded, so a failed
        # upsert does not leave the stored counts inflated.
        DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
        return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,10 +24,9 @@ from rag.nlp import search | |||
| from rag.utils import ELASTICSEARCH | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.services.user_service import TenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db import StatusEnum, FileType | |||
| from web_server.utils import get_uuid | |||
| from web_server.db import FileType | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.settings import RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -33,7 +33,7 @@ from web_server.utils.api_utils import get_json_result | |||
| def factories(): | |||
| try: | |||
| fac = LLMFactoriesService.get_all() | |||
| return get_json_result(data=fac.to_json()) | |||
| return get_json_result(data=[f.to_dict() for f in fac]) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -60,9 +60,7 @@ def set_api_key(): | |||
| @login_required | |||
| def my_llms(): | |||
| try: | |||
| objs = TenantLLMService.query(tenant_id=current_user.id) | |||
| objs = [o.to_dict() for o in objs] | |||
| for o in objs: del o["api_key"] | |||
| objs = TenantLLMService.get_my_llms(current_user.id) | |||
| return get_json_result(data=objs) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -252,3 +252,17 @@ def tenant_info(): | |||
| return get_json_result(data=tenants) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
@manager.route("/set_tenant_info", methods=["POST"])
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
    """Update the default model ids of a tenant.

    JSON body: ``tenant_id`` plus ``asr_id``, ``embd_id``, ``img2txt_id``
    and ``llm_id``. Only those four model fields are written; any other key
    in the body is ignored instead of being forwarded to the update.

    NOTE(review): nothing here checks that the current user belongs to
    ``tenant_id`` -- confirm whether authorization is enforced elsewhere.
    """
    req = request.json
    try:
        tid = req["tenant_id"]
        # Explicit whitelist: the request body must not be able to touch
        # arbitrary tenant columns. Also leaves request.json unmodified.
        updates = {key: req[key] for key in ("asr_id", "embd_id", "img2txt_id", "llm_id")}
        TenantService.update_by_id(tid, updates)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2021 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2021 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,7 +19,7 @@ import time | |||
| from functools import wraps | |||
| from shortuuid import ShortUUID | |||
| from web_server.versions import get_fate_version | |||
| from web_server.versions import get_rag_version | |||
| from web_server.errors.error_services import * | |||
| from web_server.settings import ( | |||
| @@ -34,7 +34,7 @@ server_instance = ( | |||
| json.dumps({ | |||
| 'instance_id': instance_id, | |||
| 'timestamp': round(time.time() * 1000), | |||
| 'version': get_fate_version() or '', | |||
| 'version': get_rag_version() or '', | |||
| 'host': HOST, | |||
| 'grpc_port': GRPC_PORT, | |||
| 'http_port': HTTP_PORT, | |||
| @@ -68,7 +68,7 @@ class ServicesDB(abc.ABC): | |||
| @abc.abstractmethod | |||
| def supported_services(self): | |||
| """The names of supported services. | |||
| The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving). | |||
| The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving). | |||
| :return: The service names. | |||
| :rtype: list | |||
| @@ -142,8 +142,8 @@ class ServicesDB(abc.ABC): | |||
| @check_service_supported | |||
| def get_urls(self, service_name, with_values=False): | |||
| """Query service urls from database. The urls may belong to other nodes. | |||
| Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported. | |||
| `fateflow` is a url containing scheme, host, port and path, | |||
| Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported. | |||
| `ragflow` is a url containing scheme, host, port and path, | |||
| while `servings` only contains host and port. | |||
| :param str service_name: The service name. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -87,3 +87,10 @@ class DocumentService(CommonService): | |||
| num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() | |||
| return num | |||
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
    """Return the tenant id owning the document, or None when no row matches.

    Joins the document model to its knowledge base and only considers
    knowledge bases whose status equals ``StatusEnum.VALID.value``.
    """
    joined = cls.model.select(Knowledgebase.tenant_id).join(
        Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
    rows = joined.where(
        cls.model.id == doc_id,
        Knowledgebase.status == StatusEnum.VALID.value,
    ).dicts()
    if rows:
        return rows[0]["tenant_id"]
    return None
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,10 +16,11 @@ | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from rag.llm import EmbeddingModel, CvModel | |||
| from web_server.db import LLMType | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import LLMFactories, LLM, TenantLLM | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| @@ -51,3 +52,25 @@ class TenantLLMService(CommonService): | |||
| if not objs:return | |||
| return objs[0] | |||
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
    """List the tenant's configured models as plain dicts.

    Each entry holds the factory name, the factory's logo and tags, the
    model type and the model name; rows are matched to factories by name.
    """
    columns = (
        cls.model.llm_factory,
        LLMFactories.logo,
        LLMFactories.tags,
        cls.model.model_type,
        cls.model.llm_name,
    )
    query = cls.model.select(*columns)
    query = query.join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name))
    query = query.where(cls.model.tenant_id == tenant_id)
    return list(query.dicts())
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
    """Build the model client a tenant has configured for ``llm_type``.

    :param tenant_id: tenant whose configuration is looked up.
    :param llm_type: an ``LLMType`` member or its ``.value``; embedding and
        image-to-text are supported.
    :return: an instance from ``EmbeddingModel`` / ``CvModel`` for the
        configured factory, or ``None`` when the factory or the type is not
        supported.

    Falls back to the ``"local"`` factory with empty credentials when the
    tenant has no configuration for the requested type.
    """
    # Look up the configuration for the requested type (this used to be
    # hard-coded to the embedding type, whatever the caller asked for).
    model_config = cls.get_api_key(tenant_id, model_type=llm_type)
    if not model_config:
        model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
    else:
        # NOTE(review): assumes get_api_key returns an indexable collection
        # of model rows -- confirm against its implementation.
        model_config = model_config[0].to_dict()

    factory = model_config["llm_factory"]
    # Accept both the enum member and its raw value: callers pass `.value`.
    if llm_type in (LLMType.EMBEDDING, LLMType.EMBEDDING.value):
        if factory not in EmbeddingModel:
            return None
        return EmbeddingModel[factory](model_config["api_key"], model_config["llm_name"])

    if llm_type in (LLMType.IMAGE2TEXT, LLMType.IMAGE2TEXT.value):
        if factory not in CvModel:
            return None
        # model_config is a dict here; attribute access raised AttributeError.
        return CvModel[factory](model_config["api_key"], model_config["llm_name"])

    return None
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,8 +1,8 @@ | |||
| from .general_error import * | |||
| class FateFlowError(Exception): | |||
| message = 'Unknown Fate Flow Error' | |||
| class RagFlowError(Exception): | |||
| message = 'Unknown Rag Flow Error' | |||
| def __init__(self, message=None, *args, **kwargs): | |||
| message = str(message) if message is not None else self.message | |||
| @@ -1,10 +1,10 @@ | |||
| from web_server.errors import FateFlowError | |||
| from web_server.errors import RagFlowError | |||
| __all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', | |||
| 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] | |||
| class ServicesError(FateFlowError): | |||
| class ServicesError(RagFlowError): | |||
| message = 'Unknown services error' | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -46,7 +46,7 @@ if __name__ == '__main__': | |||
| # init runtime config | |||
| import argparse | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--version', default=False, help="fate flow version", action='store_true') | |||
| parser.add_argument('--version', default=False, help="rag flow version", action='store_true') | |||
| parser.add_argument('--debug', default=False, help="debug mode", action='store_true') | |||
| args = parser.parse_args() | |||
| if args.version: | |||
| @@ -64,13 +64,13 @@ if __name__ == '__main__': | |||
| peewee_logger = logging.getLogger('peewee') | |||
| peewee_logger.propagate = False | |||
| # fate_arch.common.log.ROpenHandler | |||
| # rag_arch.common.log.ROpenHandler | |||
| peewee_logger.addHandler(database_logger.handlers[0]) | |||
| peewee_logger.setLevel(database_logger.level) | |||
| # start http server | |||
| try: | |||
| stat_logger.info("FATE Flow http server start...") | |||
| stat_logger.info("RAG Flow http server start...") | |||
| werkzeug_logger = logging.getLogger("werkzeug") | |||
| for h in access_logger.handlers: | |||
| werkzeug_logger.addHandler(h) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,10 +24,10 @@ from web_server.utils.log_utils import LoggerFactory, getLogger | |||
| # Server | |||
| API_VERSION = "v1" | |||
| FATE_FLOW_SERVICE_NAME = "ragflow" | |||
| RAG_FLOW_SERVICE_NAME = "ragflow" | |||
| SERVER_MODULE = "rag_flow_server.py" | |||
| TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp") | |||
| FATE_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf") | |||
| RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf") | |||
| SUBPROCESS_STD_LOG_NAME = "std.log" | |||
| @@ -52,21 +52,21 @@ IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") | |||
| # distribution | |||
| DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) | |||
| FATE_FLOW_UPDATE_CHECK = False | |||
| RAG_FLOW_UPDATE_CHECK = False | |||
| HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| SECRET_KEY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") | |||
| TOKEN_EXPIRE_IN = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) | |||
| SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") | |||
| TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) | |||
| NGINX_HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST | |||
| NGINX_HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT | |||
| NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST | |||
| NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT | |||
| RANDOM_INSTANCE_ID = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) | |||
| RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) | |||
| PROXY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("proxy") | |||
| PROXY_PROTOCOL = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("protocol") | |||
| PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy") | |||
| PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") | |||
| DATABASE = decrypt_database_config() | |||
| @@ -133,7 +133,7 @@ class CustomEnum(Enum): | |||
| class PythonDependenceName(CustomEnum): | |||
| Fate_Source_Code = "python" | |||
| Rag_Source_Code = "python" | |||
| Python_Env = "miniconda" | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import base64 | |||
| from datetime import datetime | |||
| import datetime | |||
| import io | |||
| import json | |||
| import os | |||
| @@ -185,7 +185,7 @@ def deserialize_b64(src): | |||
| safe_module = { | |||
| 'numpy', | |||
| 'fate_flow' | |||
| 'rag_flow' | |||
| } | |||
| @@ -287,16 +287,16 @@ def get_uuid(): | |||
| return uuid.uuid1().hex | |||
| def datetime_format(date_time: datetime) -> datetime: | |||
| return datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) | |||
| def datetime_format(date_time: datetime.datetime) -> datetime.datetime: | |||
| return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) | |||
| def get_format_time() -> datetime: | |||
| return datetime_format(datetime.now()) | |||
| def get_format_time() -> datetime.datetime: | |||
| return datetime_format(datetime.datetime.now()) | |||
| def str2date(date_time: str): | |||
| return datetime.strptime(date_time, '%Y-%m-%d') | |||
| return datetime.datetime.strptime(date_time, '%Y-%m-%d') | |||
| def elapsed2time(elapsed): | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -25,7 +25,7 @@ from flask import ( | |||
| from werkzeug.http import HTTP_STATUS_CODES | |||
| from web_server.utils import json_dumps | |||
| from web_server.versions import get_fate_version | |||
| from web_server.versions import get_rag_version | |||
| from web_server.settings import RetCode | |||
| from web_server.settings import ( | |||
| REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, | |||
| @@ -73,7 +73,7 @@ def request(**kwargs): | |||
| return sess.send(prepped, stream=stream, timeout=timeout) | |||
| fate_version = get_fate_version() or '' | |||
| rag_version = get_rag_version() or '' | |||
| def get_exponential_backoff_interval(retries, full_jitter=False): | |||
| @@ -93,7 +93,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id | |||
| result_dict = { | |||
| "retcode": retcode, | |||
| "retmsg":retmsg, | |||
| # "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE), | |||
| # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE), | |||
| "data": data, | |||
| "jobId": job_id, | |||
| "meta": meta, | |||
| @@ -109,7 +109,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id | |||
| def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'): | |||
| import re | |||
| result_dict = {"retcode": retcode, "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE)} | |||
| result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)} | |||
| response = {} | |||
| for key, value in result_dict.items(): | |||
| if value is None and key != "retcode": | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,7 +24,7 @@ from ruamel.yaml import YAML | |||
| from web_server.db import FileType | |||
| PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") | |||
| FATE_BASE = os.getenv("RAG_BASE") | |||
| RAG_BASE = os.getenv("RAG_BASE") | |||
| def get_project_base_directory(*args): | |||
| global PROJECT_BASE | |||
| @@ -42,10 +42,10 @@ def get_project_base_directory(*args): | |||
| return PROJECT_BASE | |||
| def get_fate_directory(*args): | |||
| global FATE_BASE | |||
| if FATE_BASE is None: | |||
| FATE_BASE = os.path.abspath( | |||
| def get_rag_directory(*args): | |||
| global RAG_BASE | |||
| if RAG_BASE is None: | |||
| RAG_BASE = os.path.abspath( | |||
| os.path.join( | |||
| os.path.dirname(os.path.realpath(__file__)), | |||
| os.pardir, | |||
| @@ -54,12 +54,12 @@ def get_fate_directory(*args): | |||
| ) | |||
| ) | |||
| if args: | |||
| return os.path.join(FATE_BASE, *args) | |||
| return FATE_BASE | |||
| return os.path.join(RAG_BASE, *args) | |||
| return RAG_BASE | |||
| def get_fate_python_directory(*args): | |||
| return get_fate_directory("python", *args) | |||
| def get_rag_python_directory(*args): | |||
| return get_rag_directory("python", *args) | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -127,13 +127,9 @@ class LoggerFactory(object): | |||
| else: | |||
| log_file = os.path.join(log_dir, "{}.log".format(class_name)) | |||
| else: | |||
| log_file = os.path.join(log_dir, "fate_flow_{}.log".format( | |||
| log_type) if level == LoggerFactory.LEVEL else 'fate_flow_{}_error.log'.format(log_type)) | |||
| job_id = job_id or os.getenv("FATE_JOB_ID") | |||
| if job_id: | |||
| formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", job_id)) | |||
| else: | |||
| formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", "Server")) | |||
| log_file = os.path.join(log_dir, "rag_flow_{}.log".format( | |||
| log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type)) | |||
| os.makedirs(os.path.dirname(log_file), exist_ok=True) | |||
| if LoggerFactory.log_share: | |||
| handler = ROpenHandler(log_file, | |||
| @@ -150,7 +146,6 @@ class LoggerFactory(object): | |||
| if level: | |||
| handler.level = level | |||
| handler.setFormatter(formatter) | |||
| return handler | |||
| @staticmethod | |||
| @@ -264,28 +259,28 @@ def exception_to_trace_string(ex): | |||
| def get_logger_base_dir(): | |||
| job_log_dir = file_utils.get_fate_flow_directory('logs') | |||
| job_log_dir = file_utils.get_rag_flow_directory('logs') | |||
| return job_log_dir | |||
| def get_job_logger(job_id, log_type): | |||
| fate_flow_log_dir = file_utils.get_fate_flow_directory('logs', 'fate_flow') | |||
| job_log_dir = file_utils.get_fate_flow_directory('logs', job_id) | |||
| rag_flow_log_dir = file_utils.get_rag_flow_directory('logs', 'rag_flow') | |||
| job_log_dir = file_utils.get_rag_flow_directory('logs', job_id) | |||
| if not job_id: | |||
| log_dirs = [fate_flow_log_dir] | |||
| log_dirs = [rag_flow_log_dir] | |||
| else: | |||
| if log_type == 'audit': | |||
| log_dirs = [job_log_dir, fate_flow_log_dir] | |||
| log_dirs = [job_log_dir, rag_flow_log_dir] | |||
| else: | |||
| log_dirs = [job_log_dir] | |||
| if LoggerFactory.log_share: | |||
| oldmask = os.umask(000) | |||
| os.makedirs(job_log_dir, exist_ok=True) | |||
| os.makedirs(fate_flow_log_dir, exist_ok=True) | |||
| os.makedirs(rag_flow_log_dir, exist_ok=True) | |||
| os.umask(oldmask) | |||
| else: | |||
| os.makedirs(job_log_dir, exist_ok=True) | |||
| os.makedirs(fate_flow_log_dir, exist_ok=True) | |||
| os.makedirs(rag_flow_log_dir, exist_ok=True) | |||
| logger = LoggerFactory.new_logger(f"{job_id}_{log_type}") | |||
| for job_log_dir in log_dirs: | |||
| handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, | |||
| @@ -1,5 +1,5 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # Copyright 2019 The RAG Flow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -26,5 +26,5 @@ def get_versions() -> typing.Mapping[str, typing.Any]: | |||
| dotenv_path=os.path.join(get_project_base_directory(), "rag.env") | |||
| ) | |||
| def get_fate_version() -> typing.Optional[str]: | |||
def get_rag_version() -> typing.Optional[str]:
    """Return the entry stored under the ``"RAG"`` key of ``get_versions()``.

    Returns:
        The mapped value, or ``None`` when the mapping has no ``"RAG"`` key.
    """
    versions = get_versions()
    return versions.get("RAG")
| @@ -1,164 +0,0 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import socket | |||
| from pathlib import Path | |||
| from web_server import utils | |||
| from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo | |||
| from .reload_config_base import ReloadConfigBase | |||
class ServiceRegistry(ReloadConfigBase):
    """Persistence helpers for service endpoints stored in ``ServiceRegistryInfo``."""

    @classmethod
    @DB.connection_context()
    def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
        """Return the ``ServiceRegistryInfo`` records matching ``kwargs`` as a list."""
        return list(ServiceRegistryInfo.query(**kwargs))

    @classmethod
    @DB.connection_context()
    def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
        """Create or update the registry row for one service of a server.

        Args:
            server_name: Name of the server that exposes the service.
            service_name: Name of the service; with ``server_name`` it
                identifies the row.
            uri: Path appended after ``protocol://host:port`` to form the URL.
            method: HTTP method stored with the service.
            server_info: Optional dict with ``host``, ``port`` and optionally
                ``protocol``. When falsy, the first database record returned
                for ``server_name`` supplies these values instead.
            params: Stored as ``f_params``; a falsy value is stored as ``{}``.
            data: Stored as ``f_data``; a falsy value is stored as ``{}``.
            headers: Stored as ``f_headers``; a falsy value is stored as ``{}``.
            protocol: Fallback scheme when ``server_info`` has no ``protocol``.

        Raises:
            Exception: If ``server_info`` is falsy and no server record is
                found for ``server_name``.
        """
        if server_info:
            scheme = server_info.get('protocol', protocol)
            host = server_info.get('host')
            port = server_info.get('port')
        else:
            known_servers = ServerRegistry.query_server_info_from_db(server_name=server_name)
            if not known_servers:
                raise Exception(f"no found server {server_name}")
            record = known_servers[0]
            scheme, host, port = record.f_protocol, record.f_host, record.f_port
        url = f"{scheme}://{host}:{port}{uri}"

        fields = {
            "f_server_name": server_name,
            "f_service_name": service_name,
            "f_url": url,
            "f_method": method,
            "f_params": params or {},
            "f_data": data or {},
            "f_headers": headers or {},
        }
        row, created = ServiceRegistryInfo.get_or_create(
            f_server_name=server_name,
            f_service_name=service_name,
            defaults=fields)
        # An existing row was fetched rather than created: overwrite its columns.
        if created is False:
            for column, value in fields.items():
                setattr(row, column, value)
            row.save(force_insert=False)
class ServerRegistry(ReloadConfigBase):
    """Class-level registry of server connection info.

    Each entry is held as a class attribute and mirrored in the
    ``ServerRegistryInfo`` table.
    """

    # Placeholders for named server entries. They stay ``None`` until a
    # matching entry is loaded from the config file or the database, or is
    # registered/saved at runtime.
    FATEBOARD = None
    FATE_ON_STANDALONE = None
    FATE_ON_EGGROLL = None
    FATE_ON_SPARK = None
    MODEL_STORE_ADDRESS = None
    SERVINGS = None
    FATEMANAGER = None
    STUDIO = None

    @classmethod
    def load(cls):
        """Populate the registry from the config file, then from the database.

        Database entries are applied second, so they replace config entries
        that map to the same attribute name.
        """
        cls.load_server_info_from_conf()
        cls.load_server_info_from_db()

    @classmethod
    def load_server_info_from_conf(cls):
        """Load dict-valued top-level config entries as upper-cased class attributes.

        Reads ``conf/<utils.SERVICE_CONF>`` under the project base directory.
        When a sibling ``local.<utils.SERVICE_CONF>`` file exists, its
        top-level keys are merged over the base config (shallow update).

        Raises:
            ValueError: If either file does not load as a dict.
        """
        path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
        conf = utils.file_utils.load_yaml_conf(path)
        if not isinstance(conf, dict):
            raise ValueError('invalid config file')

        local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
        if local_path.exists():
            local_conf = utils.file_utils.load_yaml_conf(local_path)
            if not isinstance(local_conf, dict):
                raise ValueError('invalid local config file')
            conf.update(local_conf)

        # Non-dict entries are not registered as attributes here.
        for k, v in conf.items():
            if isinstance(v, dict):
                setattr(cls, k.upper(), v)

    @classmethod
    def register(cls, server_name, server_info):
        """Persist one server entry and expose it as a class attribute.

        Args:
            server_name: Attribute name and database key, used exactly as given.
            server_info: Dict with ``host``, ``port`` and optionally
                ``protocol`` (defaults to ``"http"``).
        """
        cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
        # NOTE(review): the name is not upper-cased here, unlike in save() and
        # load_server_info_from_db() -- confirm callers always pass upper case.
        setattr(cls, server_name, server_info)

    @classmethod
    def save(cls, service_config):
        """Validate, persist and register every server in ``service_config``.

        Args:
            service_config: Mapping of server name to a server-info dict. An
                optional ``"api"`` key maps service names to dicts with
                ``uri`` and optionally ``method``. That key is popped from the
                caller's dict before the dict is stored on the class.

        Returns:
            An empty dict.
            NOTE(review): ``update_server`` is never filled in; confirm
            whether callers expect the updated entries here.

        Raises:
            ConnectionRefusedError: If a server with both ``host`` and
                ``port`` cannot be reached.
        """
        update_server = {}
        for server_name, server_info in service_config.items():
            cls.parameter_check(server_info)
            api_info = server_info.pop("api", {})
            for service_name, info in api_info.items():
                ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
            # Store the configured protocol instead of a hard-coded "http", so
            # the server row agrees with the service URLs written just above
            # and with register().
            cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
            setattr(cls, server_name.upper(), server_info)
        return update_server

    @classmethod
    def parameter_check(cls, service_info):
        """Probe the server when ``service_info`` has both ``host`` and ``port``.

        Raises:
            ConnectionRefusedError: If the connection attempt fails.
        """
        if "host" in service_info and "port" in service_info:
            cls.connection_test(service_info.get("host"), service_info.get("port"))

    @classmethod
    def connection_test(cls, ip, port):
        """Check that a TCP connection to ``ip:port`` can be opened.

        Raises:
            ConnectionRefusedError: If ``connect_ex`` returns a non-zero code.
        """
        # The context manager closes the probe socket on both the success and
        # the failure path; it was previously left open.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            result = s.connect_ex((ip, port))
            if result != 0:
                raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")

    @classmethod
    def query(cls, service_name, default=None):
        """Return the entry registered under ``service_name``.

        Falls back to ``utils.get_base_config(service_name, default)`` when
        the class attribute is missing or falsy.
        """
        service_info = getattr(cls, service_name, default)
        if not service_info:
            service_info = utils.get_base_config(service_name, default)
        return service_info

    @classmethod
    @DB.connection_context()
    def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
        """Return server records from the database as a list.

        Args:
            server_name: When truthy, only records whose ``f_server_name``
                equals the upper-cased name are returned; otherwise all
                records are returned.
        """
        # NOTE(review): the filter upper-cases the name, but
        # save_server_info_to_db() stores it as given -- confirm the casing
        # convention.
        if server_name:
            server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name==server_name.upper())
        else:
            server_list = ServerRegistryInfo.select()
        return list(server_list)

    @classmethod
    @DB.connection_context()
    def load_server_info_from_db(cls):
        """Expose every stored server record as an upper-cased class attribute."""
        for server in cls.query_server_info_from_db():
            server_info = {
                "host": server.f_host,
                "port": server.f_port,
                "protocol": server.f_protocol
            }
            setattr(cls, server.f_server_name.upper(), server_info)

    @classmethod
    @DB.connection_context()
    def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
        """Create or update the ``ServerRegistryInfo`` row for ``server_name``.

        Args:
            server_name: Lookup key, stored exactly as given.
            host: Stored as ``f_host``.
            port: Stored as ``f_port``.
            protocol: Stored as ``f_protocol``.
        """
        server_info = {
            "f_server_name": server_name,
            "f_host": host,
            "f_port": port,
            "f_protocol": protocol
        }
        entity_model, status = ServerRegistryInfo.get_or_create(
            f_server_name=server_name,
            defaults=server_info)
        # An existing row was fetched rather than created: overwrite its columns.
        if status is False:
            for key in server_info:
                setattr(entity_model, key, server_info[key])
            entity_model.save(force_insert=False)