* add front end code
* change licence
* rename web_server to API
* change name to web_server

tags/v0.1.0
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import hashlib
import pathlib
import re

from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user

from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace
from web_server.db import LLMType
from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.services.llm_service import TenantLLMService
from web_server.db.services.user_service import UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid
from web_server.db.services.document_service import DocumentService
from web_server.settings import RetCode
from web_server.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type

retrival = search.Dealer(ELASTICSEARCH, None)
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
def list():
    req = request.json
    doc_id = req["doc_id"]
    page = req.get("page", 1)
    size = req.get("size", 30)
    question = req.get("keywords", "")
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        res = retrival.search({
            "doc_ids": [doc_id], "page": page, "size": size, "question": question
        }, search.index_name(tenants[0].tenant_id))
        return get_json_result(data=res)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg=f'Index not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
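For orientation, here is a minimal client-side sketch of calling this endpoint. The `/v1/chunk/list` path, host, and port are assumptions (the blueprint prefix is not part of this diff), and the request must carry an authenticated session cookie because of `@login_required`.

```python
import requests

# Hypothetical base URL; adjust host, port, and prefix to the actual deployment.
BASE_URL = "http://127.0.0.1:9380/v1/chunk"

session = requests.Session()
# ... log in first so the session cookie satisfies @login_required ...

resp = session.post(f"{BASE_URL}/list", json={
    "doc_id": "some-doc-id",    # required by @validate_request("doc_id")
    "page": 1,                  # optional, defaults to 1
    "size": 30,                 # optional, defaults to 30
    "keywords": "query terms",  # optional full-text filter
})
print(resp.json())  # envelope with "retcode", "retmsg", "data"
```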
@manager.route('/get', methods=['GET'])
@login_required
def get():
    chunk_id = request.args["chunk_id"]
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        res = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id))
        if not res.get("found"):
            return server_error_response("Chunk not found")
        id = res["_id"]
        res = res["_source"]
        res["chunk_id"] = id
        k = []
        for n in res.keys():
            if re.search(r"(_vec$|_sm_)", n):
                k.append(n)
            if re.search(r"(_tks|_ltks)", n):
                res[n] = rmSpace(res[n])
        for n in k:
            del res[n]
        return get_json_result(data=res)
    except Exception as e:
        if str(e).find("NotFoundError") >= 0:
            return get_json_result(data=False, retmsg=f'Chunk not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd")
def set():
    req = request.json
    d = {"id": req["chunk_id"]}
    d["content_ltks"] = huqie.qie(req["content_ltks"])
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
    try:
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        v, c = embd_mdl.encode([req["docnm_kwd"], req["content_ltks"]])
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks", "important_kwd")
def create():
    req = request.json
    md5 = hashlib.md5()
    md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
    chunk_id = md5.hexdigest()
    d = {"id": chunk_id, "content_ltks": huqie.qie(req["content_ltks"])}
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["doc_id"] = doc.id
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
        DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)
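As a quick illustration of how the chunk id above is derived, the handler hashes the tokenized content concatenated with the document id; the snippet below reproduces that step in isolation (the input strings are placeholders).

```python
import hashlib

content_ltks = "example tokenized chunk content"  # placeholder
doc_id = "some-doc-id"                            # placeholder

md5 = hashlib.md5()
md5.update((content_ltks + doc_id).encode("utf-8"))
chunk_id = md5.hexdigest()  # same content + doc_id always yields the same id
print(chunk_id)
```

Because the id is a deterministic hash and the handler writes via `ELASTICSEARCH.upsert`, re-posting identical content for the same document should update the existing chunk rather than add a new one.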
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
from rag.utils import ELASTICSEARCH
from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.services.user_service import TenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
-from web_server.utils import get_uuid, get_format_time
-from web_server.db import StatusEnum, FileType
+from web_server.utils import get_uuid
+from web_server.db import FileType
from web_server.db.services.document_service import DocumentService
from web_server.settings import RetCode
from web_server.utils.api_utils import get_json_result
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
def factories():
    try:
        fac = LLMFactoriesService.get_all()
-        return get_json_result(data=fac.to_json())
+        return get_json_result(data=[f.to_dict() for f in fac])
    except Exception as e:
        return server_error_response(e)
@login_required
def my_llms():
    try:
-        objs = TenantLLMService.query(tenant_id=current_user.id)
-        objs = [o.to_dict() for o in objs]
-        for o in objs: del o["api_key"]
+        objs = TenantLLMService.get_my_llms(current_user.id)
        return get_json_result(data=objs)
    except Exception as e:
        return server_error_response(e)
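The rewritten handler delegates to `TenantLLMService.get_my_llms`, which is added later in this diff and selects a fixed set of joined columns. The rows it returns look roughly like the following (the field names come from that query; the values are placeholders):

```python
# Illustrative only: shape of the list returned by get_my_llms(tenant_id).
my_llms = [
    {
        "llm_factory": "OpenAI",               # cls.model.llm_factory
        "logo": "",                            # LLMFactories.logo
        "tags": "LLM,TEXT EMBEDDING",          # LLMFactories.tags
        "model_type": "embedding",             # cls.model.model_type
        "llm_name": "text-embedding-ada-002",  # cls.model.llm_name
    },
]
```

Unlike the old code, nothing strips `api_key` here; it simply is no longer selected.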
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

        return get_json_result(data=tenants)
    except Exception as e:
        return server_error_response(e)


@manager.route("/set_tenant_info", methods=["POST"])
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
    req = request.json
    try:
        tid = req["tenant_id"]
        del req["tenant_id"]
        TenantService.update_by_id(tid, req)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
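A hedged example of exercising this endpoint from a client; the `/v1/user/set_tenant_info` path is an assumption (the blueprint prefix is not part of this diff), and all ids are placeholders.

```python
import requests

session = requests.Session()
# ... authenticate first; the route is guarded by @login_required ...

resp = session.post(
    "http://127.0.0.1:9380/v1/user/set_tenant_info",
    json={
        "tenant_id": "tenant-uuid",           # popped from the payload before the update
        "llm_id": "chat-model-id",            # placeholder model ids below
        "embd_id": "embedding-model-id",
        "asr_id": "asr-model-id",
        "img2txt_id": "vision-model-id",
    },
)
print(resp.json())  # expects "data": true on success
```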
#
-# Copyright 2021 The FATE Authors. All Rights Reserved.
+# Copyright 2021 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
from functools import wraps
from shortuuid import ShortUUID

-from web_server.versions import get_fate_version
+from web_server.versions import get_rag_version

from web_server.errors.error_services import *
from web_server.settings import (

    json.dumps({
        'instance_id': instance_id,
        'timestamp': round(time.time() * 1000),
-        'version': get_fate_version() or '',
+        'version': get_rag_version() or '',
        'host': HOST,
        'grpc_port': GRPC_PORT,
        'http_port': HTTP_PORT,

    @abc.abstractmethod
    def supported_services(self):
        """The names of supported services.
-        The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving).
+        The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
        :return: The service names.
        :rtype: list

    @check_service_supported
    def get_urls(self, service_name, with_values=False):
        """Query service urls from database. The urls may belong to other nodes.
-        Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported.
-        `fateflow` is a url containing scheme, host, port and path,
+        Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
+        `ragflow` is a url containing scheme, host, port and path,
        while `servings` only contains host and port.
        :param str service_name: The service name.
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
        num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]
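A short sketch of how this helper is consumed by the chunk endpoints earlier in the diff; `doc_id` is a placeholder, and the import path mirrors the one used in chunk_app above.

```python
from web_server.db.services.document_service import DocumentService

doc_id = "some-doc-id"  # placeholder

# Joins the document row to its knowledgebase and returns the owning tenant id,
# or None when the document is missing or its knowledgebase is not in a valid status.
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
    raise ValueError("Tenant not found!")  # the API handlers return a data-error result instead
```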
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
import peewee
from werkzeug.security import generate_password_hash, check_password_hash

from rag.llm import EmbeddingModel, CvModel
from web_server.db import LLMType
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum

        if not objs:
            return
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id).dicts()
        return list(objs)

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type):
        model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
        if not model_config:
            model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
        else:
            model_config = model_config[0].to_dict()
        if llm_type == LLMType.EMBEDDING:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
        if llm_type == LLMType.IMAGE2TEXT:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
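A minimal usage sketch for the factory method above, mirroring how the `/set` and `/create` handlers call it; `tenant_id` is a placeholder, and note that the method returns `None` when the configured factory is not present in `EmbeddingModel`/`CvModel`.

```python
from web_server.db import LLMType
from web_server.db.services.llm_service import TenantLLMService

tenant_id = "tenant-uuid"  # placeholder

embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
if embd_mdl is not None:
    # encode() returns the vectors plus a token count, as used in chunk_app above.
    vectors, token_count = embd_mdl.encode(["document title", "chunk body text"])
```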
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
from .general_error import *


-class FateFlowError(Exception):
-    message = 'Unknown Fate Flow Error'
+class RagFlowError(Exception):
+    message = 'Unknown Rag Flow Error'

    def __init__(self, message=None, *args, **kwargs):
        message = str(message) if message is not None else self.message

-from web_server.errors import FateFlowError
+from web_server.errors import RagFlowError

__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
           'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']


-class ServicesError(FateFlowError):
+class ServicesError(RagFlowError):
    message = 'Unknown services error'
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

# init runtime config
import argparse
parser = argparse.ArgumentParser()
-parser.add_argument('--version', default=False, help="fate flow version", action='store_true')
+parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
args = parser.parse_args()
if args.version:

peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
-# fate_arch.common.log.ROpenHandler
+# rag_arch.common.log.ROpenHandler
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)

# start http server
try:
-    stat_logger.info("FATE Flow http server start...")
+    stat_logger.info("RAG Flow http server start...")
    werkzeug_logger = logging.getLogger("werkzeug")
    for h in access_logger.handlers:
        werkzeug_logger.addHandler(h)
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

# Server
API_VERSION = "v1"
-FATE_FLOW_SERVICE_NAME = "ragflow"
+RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
-FATE_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
+RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"

# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
-FATE_FLOW_UPDATE_CHECK = False
+RAG_FLOW_UPDATE_CHECK = False

-HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
-HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("http_port")
+HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
+HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

-SECRET_KEY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
-TOKEN_EXPIRE_IN = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)
+SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
+TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)

-NGINX_HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
-NGINX_HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT
+NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
+NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT

-RANDOM_INSTANCE_ID = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)
+RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)

-PROXY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("proxy")
-PROXY_PROTOCOL = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("protocol")
+PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
+PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")

DATABASE = decrypt_database_config()
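The constants above all read from a single "ragflow" section of the service configuration via `get_base_config(RAG_FLOW_SERVICE_NAME, {})`. As a hedged sketch, that section is expected to look roughly like the dictionary below; every key mirrors a `.get()` call above, but the values are placeholders, not the project's real defaults.

```python
# Assumed shape of the "ragflow" config section; all values are placeholders.
ragflow_section = {
    "host": "127.0.0.1",
    "http_port": 9380,
    "secret_key": "infiniflow",
    "token_expires_in": 3600,
    "nginx": {"host": "127.0.0.1", "http_port": 9380},
    "proxy": None,
    "protocol": None,
    "random_instance_id": False,
}

HOST = ragflow_section.get("host", "127.0.0.1")
HTTP_PORT = ragflow_section.get("http_port")
NGINX_HOST = ragflow_section.get("nginx", {}).get("host") or HOST
```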
class PythonDependenceName(CustomEnum):
-    Fate_Source_Code = "python"
+    Rag_Source_Code = "python"
    Python_Env = "miniconda"
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# limitations under the License.
#
import base64
-from datetime import datetime
+import datetime
import io
import json
import os

safe_module = {
    'numpy',
-    'fate_flow'
+    'rag_flow'
}

    return uuid.uuid1().hex


-def datetime_format(date_time: datetime) -> datetime:
-    return datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
+def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
+    return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)


-def get_format_time() -> datetime:
-    return datetime_format(datetime.now())
+def get_format_time() -> datetime.datetime:
+    return datetime_format(datetime.datetime.now())


def str2date(date_time: str):
-    return datetime.strptime(date_time, '%Y-%m-%d')
+    return datetime.datetime.strptime(date_time, '%Y-%m-%d')
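Since the module now uses `import datetime` instead of `from datetime import datetime`, every reference goes through the module path. A small illustrative use of the helpers above, assuming they live in `web_server.utils` (the import path used elsewhere in this diff):

```python
import datetime

from web_server.utils import get_format_time  # assumed to re-export the helper above

now = get_format_time()  # current time truncated to whole seconds
assert isinstance(now, datetime.datetime) and now.microsecond == 0
```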
def elapsed2time(elapsed):
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

from werkzeug.http import HTTP_STATUS_CODES

from web_server.utils import json_dumps
-from web_server.versions import get_fate_version
+from web_server.versions import get_rag_version
from web_server.settings import RetCode
from web_server.settings import (
    REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,

    return sess.send(prepped, stream=stream, timeout=timeout)

-fate_version = get_fate_version() or ''
+rag_version = get_rag_version() or ''


def get_exponential_backoff_interval(retries, full_jitter=False):

    result_dict = {
        "retcode": retcode,
        "retmsg": retmsg,
-        # "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE),
+        # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
        "data": data,
        "jobId": job_id,
        "meta": meta,


def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
    import re
-    result_dict = {"retcode": retcode, "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE)}
+    result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "retcode":
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

from web_server.db import FileType

PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
-FATE_BASE = os.getenv("RAG_BASE")
+RAG_BASE = os.getenv("RAG_BASE")


def get_project_base_directory(*args):
    global PROJECT_BASE
    return PROJECT_BASE


-def get_fate_directory(*args):
-    global FATE_BASE
-    if FATE_BASE is None:
-        FATE_BASE = os.path.abspath(
+def get_rag_directory(*args):
+    global RAG_BASE
+    if RAG_BASE is None:
+        RAG_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
            )
        )
    if args:
-        return os.path.join(FATE_BASE, *args)
-    return FATE_BASE
+        return os.path.join(RAG_BASE, *args)
+    return RAG_BASE


-def get_fate_python_directory(*args):
-    return get_fate_directory("python", *args)
+def get_rag_python_directory(*args):
+    return get_rag_directory("python", *args)
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

            else:
                log_file = os.path.join(log_dir, "{}.log".format(class_name))
        else:
-            log_file = os.path.join(log_dir, "fate_flow_{}.log".format(
-                log_type) if level == LoggerFactory.LEVEL else 'fate_flow_{}_error.log'.format(log_type))
+            log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
+                log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))
        job_id = job_id or os.getenv("FATE_JOB_ID")
        if job_id:
            formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", job_id))
        else:
            formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", "Server"))
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        if LoggerFactory.log_share:
            handler = ROpenHandler(log_file,
        if level:
            handler.level = level
        handler.setFormatter(formatter)
        return handler

    @staticmethod
    def get_logger_base_dir():
-        job_log_dir = file_utils.get_fate_flow_directory('logs')
+        job_log_dir = file_utils.get_rag_flow_directory('logs')
        return job_log_dir


def get_job_logger(job_id, log_type):
-    fate_flow_log_dir = file_utils.get_fate_flow_directory('logs', 'fate_flow')
-    job_log_dir = file_utils.get_fate_flow_directory('logs', job_id)
+    rag_flow_log_dir = file_utils.get_rag_flow_directory('logs', 'rag_flow')
+    job_log_dir = file_utils.get_rag_flow_directory('logs', job_id)
    if not job_id:
-        log_dirs = [fate_flow_log_dir]
+        log_dirs = [rag_flow_log_dir]
    else:
        if log_type == 'audit':
-            log_dirs = [job_log_dir, fate_flow_log_dir]
+            log_dirs = [job_log_dir, rag_flow_log_dir]
        else:
            log_dirs = [job_log_dir]
    if LoggerFactory.log_share:
        oldmask = os.umask(000)
        os.makedirs(job_log_dir, exist_ok=True)
-        os.makedirs(fate_flow_log_dir, exist_ok=True)
+        os.makedirs(rag_flow_log_dir, exist_ok=True)
        os.umask(oldmask)
    else:
        os.makedirs(job_log_dir, exist_ok=True)
-        os.makedirs(fate_flow_log_dir, exist_ok=True)
+        os.makedirs(rag_flow_log_dir, exist_ok=True)
    logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
    for job_log_dir in log_dirs:
        handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
#
-# Copyright 2019 The FATE Authors. All Rights Reserved.
+# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

    dotenv_path=os.path.join(get_project_base_directory(), "rag.env")
)


-def get_fate_version() -> typing.Optional[str]:
+def get_rag_version() -> typing.Optional[str]:
    return get_versions().get("RAG")
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import socket
from pathlib import Path

from web_server import utils
from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
from .reload_config_base import ReloadConfigBase


class ServiceRegistry(ReloadConfigBase):
    @classmethod
    @DB.connection_context()
    def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
        service_registry_list = ServiceRegistryInfo.query(**kwargs)
        return [service for service in service_registry_list]

    @classmethod
    @DB.connection_context()
    def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
        if not server_info:
            server_list = ServerRegistry.query_server_info_from_db(server_name=server_name)
            if not server_list:
                raise Exception(f"no found server {server_name}")
            server_info = server_list[0]
            url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}"
        else:
            url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}"
        service_info = {
            "f_server_name": server_name,
            "f_service_name": service_name,
            "f_url": url,
            "f_method": method,
            "f_params": params if params else {},
            "f_data": data if data else {},
            "f_headers": headers if headers else {}
        }
        entity_model, status = ServiceRegistryInfo.get_or_create(
            f_server_name=server_name,
            f_service_name=service_name,
            defaults=service_info)
        if status is False:
            for key in service_info:
                setattr(entity_model, key, service_info[key])
            entity_model.save(force_insert=False)
class ServerRegistry(ReloadConfigBase):
    FATEBOARD = None
    FATE_ON_STANDALONE = None
    FATE_ON_EGGROLL = None
    FATE_ON_SPARK = None
    MODEL_STORE_ADDRESS = None
    SERVINGS = None
    FATEMANAGER = None
    STUDIO = None

    @classmethod
    def load(cls):
        cls.load_server_info_from_conf()
        cls.load_server_info_from_db()

    @classmethod
    def load_server_info_from_conf(cls):
        path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
        conf = utils.file_utils.load_yaml_conf(path)
        if not isinstance(conf, dict):
            raise ValueError('invalid config file')
        local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
        if local_path.exists():
            local_conf = utils.file_utils.load_yaml_conf(local_path)
            if not isinstance(local_conf, dict):
                raise ValueError('invalid local config file')
            conf.update(local_conf)
        for k, v in conf.items():
            if isinstance(v, dict):
                setattr(cls, k.upper(), v)

    @classmethod
    def register(cls, server_name, server_info):
        cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
        setattr(cls, server_name, server_info)

    @classmethod
    def save(cls, service_config):
        update_server = {}
        for server_name, server_info in service_config.items():
            cls.parameter_check(server_info)
            api_info = server_info.pop("api", {})
            for service_name, info in api_info.items():
                ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
            cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
            setattr(cls, server_name.upper(), server_info)
        return update_server

    @classmethod
    def parameter_check(cls, service_info):
        if "host" in service_info and "port" in service_info:
            cls.connection_test(service_info.get("host"), service_info.get("port"))

    @classmethod
    def connection_test(cls, ip, port):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        result = s.connect_ex((ip, port))
        if result != 0:
            raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")

    @classmethod
    def query(cls, service_name, default=None):
        service_info = getattr(cls, service_name, default)
        if not service_info:
            service_info = utils.get_base_config(service_name, default)
        return service_info

    @classmethod
    @DB.connection_context()
    def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
        if server_name:
            server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name == server_name.upper())
        else:
            server_list = ServerRegistryInfo.select()
        return [server for server in server_list]

    @classmethod
    @DB.connection_context()
    def load_server_info_from_db(cls):
        for server in cls.query_server_info_from_db():
            server_info = {
                "host": server.f_host,
                "port": server.f_port,
                "protocol": server.f_protocol
            }
            setattr(cls, server.f_server_name.upper(), server_info)

    @classmethod
    @DB.connection_context()
    def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
        server_info = {
            "f_server_name": server_name,
            "f_host": host,
            "f_port": port,
            "f_protocol": protocol
        }
        entity_model, status = ServerRegistryInfo.get_or_create(
            f_server_name=server_name,
            defaults=server_info)
        if status is False:
            for key in server_info:
                setattr(entity_model, key, server_info[key])
            entity_model.save(force_insert=False)
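To tie the two registry classes together, here is a hedged usage sketch. It assumes the backing tables exist and a database connection is configured; the module path, host, port, and URI values are placeholders (the `/v1` prefix matches `API_VERSION` in the settings above, but the concrete route prefix is not shown in this diff).

```python
from web_server.db.service_registry import ServerRegistry, ServiceRegistry  # assumed module path

server_info = {"host": "127.0.0.1", "port": 9380, "protocol": "http"}

# Persist the server row and expose it as an attribute on ServerRegistry.
ServerRegistry.register("ragflow", server_info)

# Record one of its HTTP endpoints so other components can look the URL up later.
ServiceRegistry.save_service_info(
    server_name="ragflow",
    service_name="chunk_list",
    uri="/v1/chunk/list",
    method="POST",
    server_info=server_info,
)

print([s.f_url for s in ServiceRegistry.load_service(f_service_name="chunk_list")])
```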