
refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166)

tags/0.8.0
-LAN- committed 1 year ago
commit ed37439ef7
No account is linked to the committer's email address
2 changed files with 199 additions and 197 deletions
1. api/core/agent/base_agent_runner.py (+129, -110)
2. api/core/model_manager.py (+70, -87)
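Both files change only formatting and type hints: ruff's formatter normalizes string literals to double quotes, appends trailing commas, and reflows long signatures and call sites to one argument per line. A minimal before/after sketch of the pattern on a hypothetical function (not taken from this diff):

# Before: single quotes, hanging-indent signature
def greet(name: str,
          greeting: str = 'Hello') -> str:
    return f'{greeting}, {name}'

# After ruff format: double quotes, and the magic trailing comma
# keeps the argument list exploded to one parameter per line
def greet(
    name: str,
    greeting: str = "Hello",
) -> str:
    return f"{greeting}, {name}"

The diff below is shown in unified form: lines prefixed with - are the code before the commit, lines prefixed with + are the formatted replacement.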

api/core/agent/base_agent_runner.py (+129, -110)

import json
import logging
import uuid
+from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast




logger = logging.getLogger(__name__)



class BaseAgentRunner(AppRunner):
-    def __init__(self, tenant_id: str,
-                 application_generate_entity: AgentChatAppGenerateEntity,
-                 conversation: Conversation,
-                 app_config: AgentChatAppConfig,
-                 model_config: ModelConfigWithCredentialsEntity,
-                 config: AgentEntity,
-                 queue_manager: AppQueueManager,
-                 message: Message,
-                 user_id: str,
-                 memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
-                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
-                 db_variables: Optional[ToolConversationVariables] = None,
-                 model_instance: ModelInstance = None
-                 ) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        conversation: Conversation,
+        app_config: AgentChatAppConfig,
+        model_config: ModelConfigWithCredentialsEntity,
+        config: AgentEntity,
+        queue_manager: AppQueueManager,
+        message: Message,
+        user_id: str,
+        memory: Optional[TokenBufferMemory] = None,
+        prompt_messages: Optional[list[PromptMessage]] = None,
+        variables_pool: Optional[ToolRuntimeVariablePool] = None,
+        db_variables: Optional[ToolConversationVariables] = None,
+        model_instance: ModelInstance = None,
+    ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        self.message = message
        self.user_id = user_id
        self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
-        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.message_id == self.message.id,
-        ).count()
+        self.agent_thought_count = (
+            db.session.query(MessageAgentThought)
+            .filter(
+                MessageAgentThought.message_id == self.message.id,
+            )
+            .count()
+        )
        db.session.close()
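The agent_thought_count rewrite above shows the formatter's treatment of long method chains: the chain is wrapped in parentheses and broken with one call per line. A runnable sketch of the same pattern against an in-memory SQLite database (hypothetical Thought model, not the project's schema):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Thought(Base):  # stand-in for MessageAgentThought
    __tablename__ = "thoughts"
    id = Column(Integer, primary_key=True)
    message_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Thought(message_id="m1"), Thought(message_id="m1")])
    session.commit()
    # Parenthesized chain, one method call per line, as ruff formats it
    count = (
        session.query(Thought)
        .filter(Thought.message_id == "m1")
        .count()
    )
    assert count == 2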


        # check if model supports stream tool call
        self.query = None
        self._current_thoughts: list[PromptMessage] = []


-    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-            -> AgentChatAppGenerateEntity:
+    def _repack_app_generate_entity(
+        self, app_generate_entity: AgentChatAppGenerateEntity
+    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
-            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
+            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
-            invoke_from=self.application_generate_entity.invoke_from
+            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

                "type": "object",
                "properties": {},
                "required": [],
-            }
+            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

-            message_tool.parameters['properties'][parameter.name] = {
+            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
-                message_tool.parameters['properties'][parameter.name]['enum'] = enum
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
-                message_tool.parameters['required'].append(parameter.name)
+                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
                "type": "object",
                "properties": {},
                "required": [],
-            }
+            },
        )

        for parameter in tool.get_runtime_parameters():
-            parameter_type = 'string'
-            prompt_tool.parameters['properties'][parameter.name] = {
+            parameter_type = "string"
+            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
            }

            if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

-    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """
        Init tools
        """
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]
-            prompt_tool.parameters['properties'][parameter.name] = {
+            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
-                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
+                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool
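_init_prompt_tools now returns Mapping and Sequence rather than dict and list; the method still hands back plain dicts and lists, but the abstract types tell callers to treat the result as read-only. A small sketch of how a type checker enforces this (hypothetical names):

from collections.abc import Mapping, Sequence

def build_tools() -> tuple[Mapping[str, str], Sequence[str]]:
    # Concrete dict/list satisfy the abstract return types,
    # but the signature tells callers not to mutate the result.
    instances = {"search": "SearchTool"}
    prompts = ["search"]
    return instances, prompts

tools, prompts = build_tools()
# A type checker flags mutation attempts such as:
# tools["new"] = "X"      # error: Mapping has no __setitem__
# prompts.append("new")   # error: Sequence has no append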
-    def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: list[str]
-                             ) -> MessageAgentThought:
+    def create_agent_thought(
+        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
+    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
-            thought='',
+            thought="",
            tool=tool_name,
-            tool_labels_str='{}',
-            tool_meta_str='{}',
+            tool_labels_str="{}",
+            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
-            message_files=json.dumps(messages_ids) if messages_ids else '',
-            answer='',
-            observation='',
+            message_files=json.dumps(messages_ids) if messages_ids else "",
+            answer="",
+            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
-            currency='USD',
+            currency="USD",
            latency=0,
-            created_by_role='account',
+            created_by_role="account",
            created_by=self.user_id,
        )

        return thought


-    def save_agent_thought(self,
-                           agent_thought: MessageAgentThought,
-                           tool_name: str,
-                           tool_input: Union[str, dict],
-                           thought: str,
-                           observation: Union[str, dict],
-                           tool_invoke_meta: Union[str, dict],
-                           answer: str,
-                           messages_ids: list[str],
-                           llm_usage: LLMUsage = None) -> MessageAgentThought:
+    def save_agent_thought(
+        self,
+        agent_thought: MessageAgentThought,
+        tool_name: str,
+        tool_input: Union[str, dict],
+        thought: str,
+        observation: Union[str, dict],
+        tool_invoke_meta: Union[str, dict],
+        answer: str,
+        messages_ids: list[str],
+        llm_usage: LLMUsage = None,
+    ) -> MessageAgentThought:
        """
        Save agent thought
        """
-        agent_thought = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.id == agent_thought.id
-        ).first()
+        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()

        if thought is not None:
            agent_thought.thought = thought
                observation = json.dumps(observation, ensure_ascii=False)
            except Exception as e:
                observation = json.dumps(observation)
            agent_thought.observation = observation

        if answer is not None:

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)
        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
-        tools = agent_thought.tool.split(';') if agent_thought.tool else []
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool_label:
                labels[tool] = tool_label.to_dict()
            else:
-                labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+                labels[tool] = {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        db.session.commit()
        db.session.close()
    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
-        db_variables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == self.message.conversation_id,
-        ).first()
+        db_variables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == self.message.conversation_id,
+            )
+            .first()
+        )

        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
+        messages: list[Message] = (
+            db.session.query(Message)
+            .filter(
+                Message.conversation_id == self.message.conversation_id,
+            )
+            .order_by(Message.created_at.asc())
+            .all()
+        )

        for message in messages:
            if message.id == self.message.id:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
-                        tools = tools.split(';')
+                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
-                            tool_inputs = { tool: {} for tool in tools }
+                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception as e:
                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                    name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
-                                )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
+                            tool_calls.append(
+                                AssistantPromptMessage.ToolCall(
+                                    id=tool_call_id,
+                                    type="function",
+                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                        name=tool,
+                                        arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    ),
+                                )
+                            )
+                            tool_call_response.append(
+                                ToolPromptMessage(
+                                    content=tool_responses.get(tool, agent_thought.observation),
+                                    name=tool,
+                                    tool_call_id=tool_call_id,
+                                )
+                            )

-                        result.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
+                        result.extend(
+                            [
+                                AssistantPromptMessage(
+                                    content=agent_thought.thought,
+                                    tool_calls=tool_calls,
+                                ),
+                                *tool_call_response,
+                            ]
+                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

                if file_extra_config:
-                    file_objs = message_file_parser.transform_message_files(
-                        files,
-                        file_extra_config
-                    )
+                    file_objs = message_file_parser.transform_message_files(files, file_extra_config)
                else:
                    file_objs = []



api/core/model_manager.py (+70, -87)

import logging
import os
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast

from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
            configuration=provider_model_bundle.configuration,
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model,
-            credentials=self.credentials
+            credentials=self.credentials,
        )


    @staticmethod
        """
        configuration = provider_model_bundle.configuration
        model_type = provider_model_bundle.model_type_instance.model_type
-        credentials = configuration.get_current_credentials(
-            model_type=model_type,
-            model=model
-        )
+        credentials = configuration.get_current_credentials(model_type=model_type, model=model)

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
        return credentials


    @staticmethod
-    def _get_load_balancing_manager(configuration: ProviderConfiguration,
-                                    model_type: ModelType,
-                                    model: str,
-                                    credentials: dict) -> Optional["LBModelManager"]:
+    def _get_load_balancing_manager(
+        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
+    ) -> Optional["LBModelManager"]:
        """
        Get load balancing model credentials
        :param configuration: provider configuration
        current_model_setting = None
        # check if model is disabled by admin
        for model_setting in configuration.model_settings:
-            if (model_setting.model_type == model_type
-                    and model_setting.model == model):
+            if model_setting.model_type == model_type and model_setting.model == model:
                current_model_setting = model_setting
                break

                model_type=model_type,
                model=model,
                load_balancing_configs=current_model_setting.load_balancing_configs,
-                managed_credentials=credentials if configuration.custom_configuration.provider else None
+                managed_credentials=credentials if configuration.custom_configuration.provider else None,
            )

            return lb_model_manager

        return None


-    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+    def invoke_llm(
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: Optional[dict] = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: Optional[list[Callback]] = None,
+    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

            stop=stop,
            stream=stream,
            user=user,
-            callbacks=callbacks
+            callbacks=callbacks,
        )
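The tools parameter switches to Sequence[PromptMessageTool] | None, mixing PEP 604 union syntax with the Optional[...] spelling used elsewhere in the signature; the two forms are equivalent, and Sequence widens what callers may pass. A quick illustration (hypothetical functions):

from collections.abc import Sequence
from typing import Optional

# Equivalent annotations; the PEP 604 form needs Python 3.10+
# (or `from __future__ import annotations` on older versions).
def f(tools: Optional[Sequence[str]] = None) -> None: ...
def g(tools: Sequence[str] | None = None) -> None: ...

# Accepting Sequence rather than list lets callers pass tuples too:
g(("search", "calculator"))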


-    def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
-                           tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_llm_num_tokens(
+        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
        """
        Get number of tokens for llm

            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
-            tools=tools
+            tools=tools,
        )


-    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-            -> TextEmbeddingResult:
+    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke large language model

            model=self.model,
            credentials=self.credentials,
            texts=texts,
-            user=user
+            user=user,
        )

    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
            function=self.model_type_instance.get_num_tokens,
            model=self.model,
            credentials=self.credentials,
-            texts=texts
+            texts=texts,
        )


-    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
-                      top_n: Optional[int] = None,
-                      user: Optional[str] = None) \
-            -> RerankResult:
+    def invoke_rerank(
+        self,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
        """
        Invoke rerank model

            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
-            user=user
+            user=user,
        )


-    def invoke_moderation(self, text: str, user: Optional[str] = None) \
-            -> bool:
+    def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
        """
        Invoke moderation model

            model=self.model,
            credentials=self.credentials,
            text=text,
-            user=user
+            user=user,
        )


-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-            -> str:
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke large language model

            model=self.model,
            credentials=self.credentials,
            file=file,
-            user=user
+            user=user,
        )


-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-            -> str:
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
        """
        Invoke large language tts model

            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
-            voice=voice
+            voice=voice,
        )


    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
                raise last_exception

            try:
-                if 'credentials' in kwargs:
-                    del kwargs['credentials']
+                if "credentials" in kwargs:
+                    del kwargs["credentials"]
                return function(*args, **kwargs, credentials=lb_config.credentials)
            except InvokeRateLimitError as e:
                # expire in 60 seconds

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return self.model_type_instance.get_tts_model_voices(
-            model=self.model,
-            credentials=self.credentials,
-            language=language
+            model=self.model, credentials=self.credentials, language=language
        )
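Deleting credentials from **kwargs before re-passing it, as _round_robin_invoke does above, avoids a "got multiple values for keyword argument" TypeError when the wrapper injects its own value. A minimal runnable sketch of the collision (call_model and the key values are hypothetical, not the project's API):

def call_model(prompt: str, credentials: dict) -> str:
    return f"{prompt} via {credentials['key']}"

def round_robin_invoke(function, *args, **kwargs):
    lb_credentials = {"key": "lb-key-1"}  # pretend this came from the balancer
    # Without this deletion, a caller-supplied credentials kwarg would
    # collide with the one injected below and raise TypeError.
    if "credentials" in kwargs:
        del kwargs["credentials"]
    return function(*args, **kwargs, credentials=lb_credentials)

print(round_robin_invoke(call_model, "hello", credentials={"key": "stale"}))
# -> hello via lb-key-1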




        return self.get_default_model_instance(tenant_id, model_type)

        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
-            tenant_id=tenant_id,
-            provider=provider,
-            model_type=model_type
+            tenant_id=tenant_id, provider=provider, model_type=model_type
        )

        return ModelInstance(provider_model_bundle, model)
        :param model_type: model type
        :return:
        """
-        default_model_entity = self._provider_manager.get_default_model(
-            tenant_id=tenant_id,
-            model_type=model_type
-        )
+        default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)

        if not default_model_entity:
            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
            tenant_id=tenant_id,
            provider=default_model_entity.provider.provider,
            model_type=model_type,
-            model=default_model_entity.model
+            model=default_model_entity.model,
        )




class LBModelManager:
-    def __init__(self, tenant_id: str,
-                 provider: str,
-                 model_type: ModelType,
-                 model: str,
-                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
-                 managed_credentials: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        provider: str,
+        model_type: ModelType,
+        model: str,
+        load_balancing_configs: list[ModelLoadBalancingConfiguration],
+        managed_credentials: Optional[dict] = None,
+    ) -> None:
        """
        Load balancing model manager
        :param tenant_id: tenant_id
        :return:
        """
        cache_key = "model_lb_index:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model
+            self._tenant_id, self._provider, self._model_type.value, self._model
        )
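The cache key above names a shared round-robin cursor in Redis; picking the next load-balancing config then reduces to an atomic increment modulo the number of configs. A minimal sketch of that pattern with redis-py (an assumption about the mechanism, the project's actual fetch logic may differ; assumes a reachable Redis):

import redis

r = redis.Redis()

def fetch_next_index(cache_key: str, num_configs: int) -> int:
    # INCR is atomic, so concurrent workers advance one shared cursor
    # without stepping on each other.
    cursor = r.incr(cache_key)
    return cursor % num_configs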


        cooldown_load_balancing_configs = []

                continue

-            if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
-                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
-                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
-                            f"model_type: {self._model_type.value}\nmodel: {self._model}")
+            if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+                logger.info(
+                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
+                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
+                )

            return config
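As an aside, the bool(...) wrapper in the DEBUG check is redundant: the == comparison already yields a bool. A leaner version of the same flag check (an observation, not part of this diff):

import os

# The comparison itself is already a bool, so no bool() wrapper is needed.
debug_enabled = os.environ.get("DEBUG", "false").lower() == "true"

if debug_enabled:
    print("debug logging on")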


        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
        )

-        redis_client.setex(cooldown_cache_key, expire, 'true')
+        redis_client.setex(cooldown_cache_key, expire, "true")

    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
        """
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
        )

        res = redis_client.exists(cooldown_cache_key)
        return res
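The cooldown mechanism here is simply a Redis key with a TTL: SETEX marks a config as cooling down, EXISTS checks it, and TTL reports the remaining time. A self-contained sketch with redis-py (the key value is illustrative; assumes a reachable Redis):

import redis

r = redis.Redis()
key = "model_lb_index:cooldown:tenant:openai:llm:gpt-4:config-1"

r.setex(key, 60, "true")          # enter a 60-second cooldown
print(bool(r.exists(key)))        # True while cooling down
print(r.ttl(key))                 # seconds remaining, e.g. 59
# The key expires on its own; no cleanup pass is needed.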


    @staticmethod
-    def get_config_in_cooldown_and_ttl(tenant_id: str,
-                                       provider: str,
-                                       model_type: ModelType,
-                                       model: str,
-                                       config_id: str) -> tuple[bool, int]:
+    def get_config_in_cooldown_and_ttl(
+        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
+    ) -> tuple[bool, int]:
        """
        Get model load balancing config is in cooldown and ttl
        :param tenant_id: workspace id
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            tenant_id,
-            provider,
-            model_type.value,
-            model,
-            config_id
+            tenant_id, provider, model_type.value, model, config_id
        )

        ttl = redis_client.ttl(cooldown_cache_key)
