浏览代码

Fix: Reduce excessive IO operations by loading LLM factory configurations (#6047)

…ions

### What problem does this PR solve?

This PR fixes an issue where the application was repeatedly reading the
llm_factories.json file from disk in multiple places, which could lead
to "Too many open files" errors under high load conditions. The fix
centralizes the file reading operation in the settings.py module and
stores the data in a global variable that can be accessed by other
modules.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [x] Performance Improvement
- [ ] Other (please describe):
tags/v0.18.0
utopia2077 7 个月前
父节点
当前提交
2d4a60cae6
没有帐户链接到提交者的电子邮件
共有 4 个文件被更改,包括 18 次插入17 次删除
  1. 2
    7
      api/db/init_data.py
  2. 2
    4
      api/db/services/llm_service.py
  3. 10
    1
      api/settings.py
  4. 4
    5
      rag/prompts.py

+ 2
- 7
api/db/init_data.py 查看文件

except Exception: except Exception:
pass pass


factory_llm_infos = json.load(
open(
os.path.join(get_project_base_directory(), "conf", "llm_factories.json"),
"r",
)
)
for factory_llm_info in factory_llm_infos["factory_llm_infos"]:
factory_llm_infos = settings.FACTORY_LLM_INFOS
for factory_llm_info in factory_llm_infos:
llm_infos = factory_llm_info.pop("llm") llm_infos = factory_llm_info.pop("llm")
try: try:
LLMFactoriesService.save(**factory_llm_info) LLMFactoriesService.save(**factory_llm_info)

+ 2
- 4
api/db/services/llm_service.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
import logging import logging
import os


from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.utils.file_utils import get_project_base_directory
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api import settings
from api.db import LLMType from api.db import LLMType
from api.db.db_models import DB from api.db.db_models import DB
from api.db.db_models import LLMFactories, LLM, TenantLLM from api.db.db_models import LLMFactories, LLM, TenantLLM


# model name must be xxx@yyy # model name must be xxx@yyy
try: try:
model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
model_factories = settings.FACTORY_LLM_INFOS
model_providers = set([f["name"] for f in model_factories]) model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers: if arr[-1] not in model_providers:
return model_name, None return model_name, None

+ 10
- 1
api/settings.py 查看文件

import os import os
from datetime import date from datetime import date
from enum import IntEnum, Enum from enum import IntEnum, Enum
import json
import rag.utils.es_conn import rag.utils.es_conn
import rag.utils.infinity_conn import rag.utils.infinity_conn


from graphrag import search as kg_search from graphrag import search as kg_search
from api.utils import get_base_config, decrypt_database_config from api.utils import get_base_config, decrypt_database_config
from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils.file_utils import get_project_base_directory


LIGHTEN = int(os.environ.get('LIGHTEN', "0")) LIGHTEN = int(os.environ.get('LIGHTEN', "0"))


HOST_IP = None HOST_IP = None
HOST_PORT = None HOST_PORT = None
SECRET_KEY = None SECRET_KEY = None
FACTORY_LLM_INFOS = None


DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)




def init_settings(): def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE
global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS
LIGHTEN = int(os.environ.get('LIGHTEN', "0")) LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_DEFAULT_MODELS = LLM.get("default_models", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url") LLM_BASE_URL = LLM.get("base_url")
try:
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
except Exception:
FACTORY_LLM_INFOS = []


global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
if not LIGHTEN: if not LIGHTEN:

+ 4
- 5
rag/prompts.py 查看文件

import datetime import datetime
import json import json
import logging import logging
import os
import re import re
from collections import defaultdict from collections import defaultdict
import json_repair import json_repair
from api import settings
from api.db import LLMType from api.db import LLMType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.utils.file_utils import get_project_base_directory
from rag.settings import TAG_FLD from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder from rag.utils import num_tokens_from_string, encoder




def llm_id2llm_type(llm_id): def llm_id2llm_type(llm_id):
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
fnm = os.path.join(get_project_base_directory(), "conf")
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
for llm_factory in llm_factories["factory_llm_infos"]:
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]: for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]: if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1] return llm["model_type"].strip(",")[-1]

正在加载...
取消
保存