| @@ -11,5 +11,5 @@ class RemoteSettingsSource: | |||
| def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: | |||
| raise NotImplementedError | |||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: | |||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool): | |||
| return value | |||
| @@ -33,7 +33,7 @@ class NacosSettingsSource(RemoteSettingsSource): | |||
| logger.exception("[get-access-token] exception occurred") | |||
| raise | |||
| def _parse_config(self, content: str) -> dict: | |||
| def _parse_config(self, content: str): | |||
| if not content: | |||
| return {} | |||
| try: | |||
| @@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self) -> dict: | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("type", type=str, required=True, default=False, location="json") | |||
| args = parser.parse_args() | |||
| @@ -1,5 +1,5 @@ | |||
| import logging | |||
| from typing import Any, NoReturn | |||
| from typing import NoReturn | |||
| from flask import Response | |||
| from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| @@ -30,7 +30,7 @@ from services.workflow_service import WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| def _convert_values_to_json_serializable_object(value: Segment): | |||
| if isinstance(value, FileSegment): | |||
| return value.value.model_dump() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| @@ -41,7 +41,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| return value.value | |||
| def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: | |||
| def _serialize_var_value(variable: WorkflowDraftVariable): | |||
| value = variable.get_value() | |||
| # create a copy of the value to avoid affecting the model cache. | |||
| value = value.model_copy(deep=True) | |||
| @@ -99,7 +99,7 @@ class MCPAppApi(Resource): | |||
| return mcp_server, app | |||
| def _validate_server_status(self, mcp_server: AppMCPServer) -> None: | |||
| def _validate_server_status(self, mcp_server: AppMCPServer): | |||
| """Validate MCP server status""" | |||
| if mcp_server.status != AppMCPServerStatus.ACTIVE: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") | |||
| @@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner): | |||
| model_instance: ModelInstance, | |||
| memory: Optional[TokenBufferMemory] = None, | |||
| prompt_messages: Optional[list[PromptMessage]] = None, | |||
| ) -> None: | |||
| ): | |||
| self.tenant_id = tenant_id | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| @@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| return instruction | |||
| def _init_react_state(self, query) -> None: | |||
| def _init_react_state(self, query): | |||
| """ | |||
| init agent scratchpad | |||
| """ | |||
| @@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel): | |||
| action_name: str | |||
| action_input: Union[dict, str] | |||
| def to_dict(self) -> dict: | |||
| def to_dict(self): | |||
| """ | |||
| Convert to dictionary. | |||
| """ | |||
| @@ -158,7 +158,7 @@ class DatasetConfigManager: | |||
| return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] | |||
| @classmethod | |||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: | |||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict): | |||
| """ | |||
| Extract dataset config for legacy compatibility | |||
| @@ -105,7 +105,7 @@ class ModelConfigManager: | |||
| return dict(config), ["model"] | |||
| @classmethod | |||
| def validate_model_completion_params(cls, cp: dict) -> dict: | |||
| def validate_model_completion_params(cls, cp: dict): | |||
| # model.completion_params | |||
| if not isinstance(cp, dict): | |||
| raise ValueError("model.completion_params must be of object type") | |||
| @@ -122,7 +122,7 @@ class PromptTemplateConfigManager: | |||
| return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] | |||
| @classmethod | |||
| def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: | |||
| def validate_post_prompt_and_set_defaults(cls, config: dict): | |||
| """ | |||
| Validate post_prompt and set defaults for prompt feature | |||
| @@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||
| """ | |||
| Validate for advanced chat app model config | |||
| @@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id: str, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -55,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| app: App, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| @@ -69,7 +69,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| self.system_user_id = system_user_id | |||
| self._app = app | |||
| def run(self) -> None: | |||
| def run(self): | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| @@ -238,7 +238,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| return False | |||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: | |||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy): | |||
| """ | |||
| Direct output | |||
| """ | |||
| @@ -96,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| draft_var_saver_factory: DraftVariableSaverFactory, | |||
| ) -> None: | |||
| ): | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -284,7 +284,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| session.rollback() | |||
| raise | |||
| def _ensure_workflow_initialized(self) -> None: | |||
| def _ensure_workflow_initialized(self): | |||
| """Fluent validation for workflow state.""" | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| @@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): | |||
| message = self._get_message(session=session) | |||
| # If there are assistant files, remove markdown image links from answer | |||
| @@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): | |||
| """ | |||
| Validate for agent chat app model config | |||
| @@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner): | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run assistant application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = ChatbotAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC): | |||
| return metadata | |||
| @classmethod | |||
| def _error_to_stream_response(cls, e: Exception) -> dict: | |||
| def _error_to_stream_response(cls, e: Exception): | |||
| """ | |||
| Error to stream response. | |||
| :param e: exception | |||
| @@ -157,7 +157,7 @@ class BaseAppGenerator: | |||
| return value | |||
| def _sanitize_value(self, value: Any) -> Any: | |||
| def _sanitize_value(self, value: Any): | |||
| if isinstance(value, str): | |||
| return value.replace("\x00", "") | |||
| return value | |||
| @@ -25,7 +25,7 @@ class PublishFrom(IntEnum): | |||
| class AppQueueManager: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom): | |||
| if not user_id: | |||
| raise ValueError("user is required") | |||
| @@ -73,14 +73,14 @@ class AppQueueManager: | |||
| self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) | |||
| last_ping_time = elapsed_time // 10 | |||
| def stop_listen(self) -> None: | |||
| def stop_listen(self): | |||
| """ | |||
| Stop listen to queue | |||
| :return: | |||
| """ | |||
| self._q.put(None) | |||
| def publish_error(self, e, pub_from: PublishFrom) -> None: | |||
| def publish_error(self, e, pub_from: PublishFrom): | |||
| """ | |||
| Publish error | |||
| :param e: error | |||
| @@ -89,7 +89,7 @@ class AppQueueManager: | |||
| """ | |||
| self.publish(QueueErrorEvent(error=e), pub_from) | |||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -100,7 +100,7 @@ class AppQueueManager: | |||
| self._publish(event, pub_from) | |||
| @abstractmethod | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -110,7 +110,7 @@ class AppQueueManager: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: | |||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str): | |||
| """ | |||
| Set task stop flag | |||
| :return: | |||
| @@ -162,7 +162,7 @@ class AppRunner: | |||
| text: str, | |||
| stream: bool, | |||
| usage: Optional[LLMUsage] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Direct output | |||
| :param queue_manager: application queue manager | |||
| @@ -204,7 +204,7 @@ class AppRunner: | |||
| queue_manager: AppQueueManager, | |||
| stream: bool, | |||
| agent: bool = False, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| @@ -220,9 +220,7 @@ class AppRunner: | |||
| else: | |||
| raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") | |||
| def _handle_invoke_result_direct( | |||
| self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool | |||
| ) -> None: | |||
| def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): | |||
| """ | |||
| Handle invoke result direct | |||
| :param invoke_result: invoke result | |||
| @@ -239,7 +237,7 @@ class AppRunner: | |||
| def _handle_invoke_result_stream( | |||
| self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| @@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate for chat app model config | |||
| @@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner): | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = ChatbotAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -55,7 +55,7 @@ class WorkflowResponseConverter: | |||
| *, | |||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||
| user: Union[Account, EndUser], | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self._user = user | |||
| @@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate for completion app model config | |||
| @@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| application_generate_entity: CompletionAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner): | |||
| def run( | |||
| self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = CompletionAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -14,14 +14,14 @@ from core.app.entities.queue_entities import ( | |||
| class MessageBasedAppQueueManager(AppQueueManager): | |||
| def __init__( | |||
| self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str | |||
| ) -> None: | |||
| ): | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._conversation_id = str(conversation_id) | |||
| self._app_mode = app_mode | |||
| self._message_id = str(message_id) | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||
| """ | |||
| Validate for workflow app model config | |||
| @@ -14,12 +14,12 @@ from core.app.entities.queue_entities import ( | |||
| class WorkflowAppQueueManager(AppQueueManager): | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str): | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._app_mode = app_mode | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| variable_loader: VariableLoader, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| @@ -44,7 +44,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| self._workflow = workflow | |||
| self._sys_user_id = system_user_id | |||
| def run(self) -> None: | |||
| def run(self): | |||
| """ | |||
| Run application | |||
| """ | |||
| @@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = WorkflowAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return dict(blocking_response.to_dict()) | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -88,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| draft_var_saver_factory: DraftVariableSaverFactory, | |||
| ) -> None: | |||
| ): | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -259,7 +259,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| session.rollback() | |||
| raise | |||
| def _ensure_workflow_initialized(self) -> None: | |||
| def _ensure_workflow_initialized(self): | |||
| """Fluent validation for workflow state.""" | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| @@ -697,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution): | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | |||
| @@ -67,7 +67,7 @@ class WorkflowBasedAppRunner: | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| app_id: str, | |||
| ) -> None: | |||
| ): | |||
| self._queue_manager = queue_manager | |||
| self._variable_loader = variable_loader | |||
| self._app_id = app_id | |||
| @@ -348,7 +348,7 @@ class WorkflowBasedAppRunner: | |||
| return graph, variable_pool | |||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: | |||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): | |||
| """ | |||
| Handle event | |||
| :param workflow_entry: workflow entry | |||
| @@ -580,5 +580,5 @@ class WorkflowBasedAppRunner: | |||
| ) | |||
| ) | |||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||
| def _publish_event(self, event: AppQueueEvent): | |||
| self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline: | |||
| application_generate_entity: AppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| stream: bool, | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self._start_at = time.perf_counter() | |||
| @@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| conversation: Conversation, | |||
| message: Message, | |||
| stream: bool, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: | |||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): | |||
| """ | |||
| Save message. | |||
| :return: | |||
| @@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| application_generate_entity=self._application_generate_entity, | |||
| ) | |||
| def _handle_stop(self, event: QueueStopEvent) -> None: | |||
| def _handle_stop(self, event: QueueStopEvent): | |||
| """ | |||
| Handle stop. | |||
| :return: | |||
| @@ -48,7 +48,7 @@ class MessageCycleManager: | |||
| AdvancedChatAppGenerateEntity, | |||
| ], | |||
| task_state: Union[EasyUITaskState, WorkflowTaskState], | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self._task_state = task_state | |||
| @@ -132,7 +132,7 @@ class MessageCycleManager: | |||
| return None | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent): | |||
| """ | |||
| Handle retriever resources. | |||
| :param event: event | |||
| @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: | |||
| return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" | |||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: | |||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): | |||
| """Print text with highlighting and no end characters.""" | |||
| text_to_print = get_colored_text(text, color) if color else text | |||
| print(text_to_print, end=end, file=file) | |||
| @@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| color: Optional[str] = "" | |||
| current_loop: int = 1 | |||
| def __init__(self, color: Optional[str] = None) -> None: | |||
| def __init__(self, color: Optional[str] = None): | |||
| super().__init__() | |||
| """Initialize callback handler.""" | |||
| # use a specific color is not specified | |||
| @@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Mapping[str, Any], | |||
| ) -> None: | |||
| ): | |||
| """Do nothing.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) | |||
| @@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| message_id: Optional[str] = None, | |||
| timer: Optional[Any] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> None: | |||
| ): | |||
| """If not the final action, print out observation.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_end]\n", color=self.color) | |||
| @@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| ) | |||
| ) | |||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: | |||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): | |||
| """Do nothing.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") | |||
| def on_agent_start(self, thought: str) -> None: | |||
| def on_agent_start(self, thought: str): | |||
| """Run on agent start.""" | |||
| if dify_config.DEBUG: | |||
| if thought: | |||
| @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| else: | |||
| print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) | |||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: | |||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): | |||
| """Run on agent end.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) | |||
| @@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler: | |||
| def __init__( | |||
| self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom | |||
| ) -> None: | |||
| ): | |||
| self._queue_manager = queue_manager | |||
| self._app_id = app_id | |||
| self._message_id = message_id | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| def on_query(self, query: str, dataset_id: str) -> None: | |||
| def on_query(self, query: str, dataset_id: str): | |||
| """ | |||
| Handle query. | |||
| """ | |||
| @@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def on_tool_end(self, documents: list[Document]) -> None: | |||
| def on_tool_end(self, documents: list[Document]): | |||
| """Handle tool end.""" | |||
| for document in documents: | |||
| if document.metadata is not None: | |||
| @@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel): | |||
| icon_large: Optional[I18nObject] = None | |||
| supported_model_types: list[ModelType] | |||
| def __init__(self, provider_entity: ProviderEntity) -> None: | |||
| def __init__(self, provider_entity: ProviderEntity): | |||
| """ | |||
| Init simple provider. | |||
| @@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel): | |||
| load_balancing_enabled: bool = False | |||
| has_invalid_load_balancing_configs: bool = False | |||
| def raise_for_status(self) -> None: | |||
| def raise_for_status(self): | |||
| """ | |||
| Check model status and raise ValueError if not active. | |||
| @@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel): | |||
| else [], | |||
| ) | |||
| def validate_provider_credentials( | |||
| self, credentials: dict, credential_id: str = "", session: Session | None = None | |||
| ) -> dict: | |||
| def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| @@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| def _validate(s: Session) -> dict: | |||
| def _validate(s: Session): | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| self.provider.provider_credential_schema.credential_form_schemas | |||
| @@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel): | |||
| logger.warning("Error generating next credential name: %s", str(e)) | |||
| return "API KEY 1" | |||
| def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: | |||
| def create_provider_credential(self, credentials: dict, credential_name: str | None): | |||
| """ | |||
| Add custom provider credentials. | |||
| :param credentials: provider credentials | |||
| @@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel): | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str | None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| update a saved provider credential (by credential_id). | |||
| @@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel): | |||
| credential_record: ProviderCredential | ProviderModelCredential, | |||
| credential_source: str, | |||
| session: Session, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Update load balancing configurations that reference the given credential_id. | |||
| @@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.commit() | |||
| def delete_provider_credential(self, credential_id: str) -> None: | |||
| def delete_provider_credential(self, credential_id: str): | |||
| """ | |||
| Delete a saved provider credential (by credential_id). | |||
| @@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def switch_active_provider_credential(self, credential_id: str) -> None: | |||
| def switch_active_provider_credential(self, credential_id: str): | |||
| """ | |||
| Switch active provider credential (copy the selected one into current active snapshot). | |||
| @@ -815,7 +813,7 @@ class ProviderConfiguration(BaseModel): | |||
| credentials: dict, | |||
| credential_id: str = "", | |||
| session: Session | None = None, | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Validate custom model credentials. | |||
| @@ -826,7 +824,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| def _validate(s: Session) -> dict: | |||
| def _validate(s: Session): | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| self.provider.model_credential_schema.credential_form_schemas | |||
| @@ -1010,7 +1008,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| Delete a saved provider credential (by credential_id). | |||
| @@ -1080,7 +1078,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| if model list exist this custom model, switch the custom model credential. | |||
| if model list not exist this custom model, use the credential to add a new custom model record. | |||
| @@ -1123,7 +1121,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.add(provider_model_record) | |||
| session.commit() | |||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| switch the custom model credential. | |||
| @@ -1153,7 +1151,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.add(provider_model_record) | |||
| session.commit() | |||
| def delete_custom_model(self, model_type: ModelType, model: str) -> None: | |||
| def delete_custom_model(self, model_type: ModelType, model: str): | |||
| """ | |||
| Delete custom model. | |||
| :param model_type: model type | |||
| @@ -1350,7 +1348,7 @@ class ProviderConfiguration(BaseModel): | |||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | |||
| ) | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): | |||
| """ | |||
| Switch preferred provider type. | |||
| :param provider_type: | |||
| @@ -1362,7 +1360,7 @@ class ProviderConfiguration(BaseModel): | |||
| if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | |||
| return | |||
| def _switch(s: Session) -> None: | |||
| def _switch(s: Session): | |||
| # get preferred provider | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| @@ -1406,7 +1404,7 @@ class ProviderConfiguration(BaseModel): | |||
| return secret_input_form_variables | |||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: | |||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): | |||
| """ | |||
| Obfuscated credentials. | |||
| @@ -6,7 +6,7 @@ class LLMError(ValueError): | |||
| description: Optional[str] = None | |||
| def __init__(self, description: Optional[str] = None) -> None: | |||
| def __init__(self, description: Optional[str] = None): | |||
| self.description = description | |||
| @@ -10,11 +10,11 @@ class APIBasedExtensionRequestor: | |||
| timeout: tuple[int, int] = (5, 60) | |||
| """timeout for request connect and read""" | |||
| def __init__(self, api_endpoint: str, api_key: str) -> None: | |||
| def __init__(self, api_endpoint: str, api_key: str): | |||
| self.api_endpoint = api_endpoint | |||
| self.api_key = api_key | |||
| def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: | |||
| def request(self, point: APIBasedExtensionPoint, params: dict): | |||
| """ | |||
| Request the api. | |||
| @@ -34,7 +34,7 @@ class Extensible: | |||
| tenant_id: str | |||
| config: Optional[dict] = None | |||
| def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: | |||
| def __init__(self, tenant_id: str, config: Optional[dict] = None): | |||
| self.tenant_id = tenant_id | |||
| self.config = config | |||
| @@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| """the unique name of external data tool""" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC): | |||
| variable: str | |||
| """the tool variable name of app tool""" | |||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: | |||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): | |||
| super().__init__(tenant_id, config) | |||
| self.app_id = app_id | |||
| self.variable = variable | |||
| @classmethod | |||
| @abstractmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension | |||
| class ExternalDataToolFactory: | |||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: | |||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): | |||
| extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) | |||
| self.__extension_instance = extension_class( | |||
| tenant_id=tenant_id, app_id=app_id, variable=variable, config=config | |||
| ) | |||
| @classmethod | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -7,6 +7,6 @@ if TYPE_CHECKING: | |||
| _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None | |||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: | |||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]): | |||
| global _tool_file_manager_factory | |||
| _tool_file_manager_factory = factory | |||
| @@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel): | |||
| pass | |||
| @classmethod | |||
| def get_default_config(cls) -> dict: | |||
| def get_default_config(cls): | |||
| return { | |||
| "type": "code", | |||
| "config": { | |||
| @@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer | |||
| class Jinja2TemplateTransformer(TemplateTransformer): | |||
| @classmethod | |||
| def transform_response(cls, response: str) -> dict: | |||
| def transform_response(cls, response: str): | |||
| """ | |||
| Transform response to dict | |||
| :param response: response | |||
| @@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider): | |||
| def get_default_code(cls) -> str: | |||
| return dedent( | |||
| """ | |||
| def main(arg1: str, arg2: str) -> dict: | |||
| def main(arg1: str, arg2: str): | |||
| return { | |||
| "result": arg1 + arg2, | |||
| } | |||
| @@ -34,7 +34,7 @@ class ProviderCredentialsCache: | |||
| else: | |||
| return None | |||
| def set(self, credentials: dict) -> None: | |||
| def set(self, credentials: dict): | |||
| """ | |||
| Cache model provider credentials. | |||
| @@ -43,7 +43,7 @@ class ProviderCredentialsCache: | |||
| """ | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """ | |||
| Delete cached model provider credentials. | |||
| @@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC): | |||
| return None | |||
| return None | |||
| def set(self, config: dict[str, Any]) -> None: | |||
| def set(self, config: dict[str, Any]): | |||
| """Cache provider credentials""" | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(config)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """Delete cached provider credentials""" | |||
| redis_client.delete(self.cache_key) | |||
| @@ -75,10 +75,10 @@ class NoOpProviderCredentialCache: | |||
| """Get cached provider credentials""" | |||
| return None | |||
| def set(self, config: dict[str, Any]) -> None: | |||
| def set(self, config: dict[str, Any]): | |||
| """Cache provider credentials""" | |||
| pass | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """Delete cached provider credentials""" | |||
| pass | |||
| @@ -37,11 +37,11 @@ class ToolParameterCache: | |||
| else: | |||
| return None | |||
| def set(self, parameters: dict) -> None: | |||
| def set(self, parameters: dict): | |||
| """Cache model provider credentials.""" | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """ | |||
| Delete cached model provider credentials. | |||
| @@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]: | |||
| return None | |||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: | |||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]): | |||
| """ | |||
| Extract 'external_trace_id' from args. | |||
| @@ -44,11 +44,11 @@ class HostingConfiguration: | |||
| provider_map: dict[str, HostingProvider] | |||
| moderation_config: Optional[HostedModerationConfig] = None | |||
| def __init__(self) -> None: | |||
| def __init__(self): | |||
| self.provider_map = {} | |||
| self.moderation_config = None | |||
| def init_app(self, app: Flask) -> None: | |||
| def init_app(self, app: Flask): | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| @@ -512,7 +512,7 @@ class IndexingRunner: | |||
| dataset: Dataset, | |||
| dataset_document: DatasetDocument, | |||
| documents: list[Document], | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| insert index and update document/segment status to completed | |||
| """ | |||
| @@ -651,7 +651,7 @@ class IndexingRunner: | |||
| @staticmethod | |||
| def _update_document_index_status( | |||
| document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Update the document indexing status. | |||
| """ | |||
| @@ -670,7 +670,7 @@ class IndexingRunner: | |||
| db.session.commit() | |||
| @staticmethod | |||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: | |||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict): | |||
| """ | |||
| Update the document segment by document id. | |||
| """ | |||
| @@ -128,7 +128,7 @@ class LLMGenerator: | |||
| return questions | |||
| @classmethod | |||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: | |||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): | |||
| output_parser = RuleConfigGeneratorOutputParser() | |||
| error = "" | |||
| @@ -263,9 +263,7 @@ class LLMGenerator: | |||
| return rule_config | |||
| @classmethod | |||
| def generate_code( | |||
| cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" | |||
| ) -> dict: | |||
| def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): | |||
| if code_language == "python": | |||
| prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) | |||
| else: | |||
| @@ -374,7 +372,7 @@ class LLMGenerator: | |||
| @staticmethod | |||
| def instruction_modify_legacy( | |||
| tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None | |||
| ) -> dict: | |||
| ): | |||
| last_run: Message | None = ( | |||
| db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() | |||
| ) | |||
| @@ -414,7 +412,7 @@ class LLMGenerator: | |||
| instruction: str, | |||
| model_config: dict, | |||
| ideal_output: str | None, | |||
| ) -> dict: | |||
| ): | |||
| from services.workflow_service import WorkflowService | |||
| app: App | None = db.session.query(App).where(App.id == flow_id).first() | |||
| @@ -452,7 +450,7 @@ class LLMGenerator: | |||
| return [] | |||
| parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) | |||
| def dict_of_event(event: AgentLogEvent) -> dict: | |||
| def dict_of_event(event: AgentLogEvent): | |||
| return { | |||
| "status": event.status, | |||
| "error": event.error, | |||
| @@ -489,7 +487,7 @@ class LLMGenerator: | |||
| instruction: str, | |||
| node_type: str, | |||
| ideal_output: str | None, | |||
| ) -> dict: | |||
| ): | |||
| LAST_RUN = "{{#last_run#}}" | |||
| CURRENT = "{{#current#}}" | |||
| ERROR_MESSAGE = "{{#error_message#}}" | |||
| @@ -1,5 +1,3 @@ | |||
| from typing import Any | |||
| from core.llm_generator.output_parser.errors import OutputParserError | |||
| from core.llm_generator.prompts import ( | |||
| RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, | |||
| @@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser: | |||
| RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, | |||
| ) | |||
| def parse(self, text: str) -> Any: | |||
| def parse(self, text: str): | |||
| try: | |||
| expected_keys = ["prompt", "variables", "opening_statement"] | |||
| parsed = parse_and_check_json_markdown(text, expected_keys) | |||
| @@ -210,7 +210,7 @@ def _handle_native_json_schema( | |||
| structured_output_schema: Mapping, | |||
| model_parameters: dict, | |||
| rules: list[ParameterRule], | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Handle structured output for models with native JSON schema support. | |||
| @@ -232,7 +232,7 @@ def _handle_native_json_schema( | |||
| return model_parameters | |||
| def _set_response_format(model_parameters: dict, rules: list) -> None: | |||
| def _set_response_format(model_parameters: dict, rules: list): | |||
| """ | |||
| Set the appropriate response format parameter based on model rules. | |||
| @@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]: | |||
| return structured_output | |||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: | |||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping): | |||
| """ | |||
| Prepare JSON schema based on model requirements. | |||
| @@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema | |||
| return {"schema": processed_schema, "name": "llm_response"} | |||
| def remove_additional_properties(schema: dict) -> None: | |||
| def remove_additional_properties(schema: dict): | |||
| """ | |||
| Remove additionalProperties fields from JSON schema. | |||
| Used for models like Gemini that don't support this property. | |||
| @@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None: | |||
| remove_additional_properties(item) | |||
| def convert_boolean_to_string(schema: dict) -> None: | |||
| def convert_boolean_to_string(schema: dict): | |||
| """ | |||
| Convert boolean type specifications to string in JSON schema. | |||
| @@ -1,6 +1,5 @@ | |||
| import json | |||
| import re | |||
| from typing import Any | |||
| from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | |||
| @@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: | |||
| def get_format_instructions(self) -> str: | |||
| return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | |||
| def parse(self, text: str) -> Any: | |||
| def parse(self, text: str): | |||
| action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) | |||
| if action_match is not None: | |||
| json_obj = json.loads(action_match.group(0).strip()) | |||
| @@ -44,7 +44,7 @@ class OAuthClientProvider: | |||
| return None | |||
| return OAuthClientInformation.model_validate(client_information) | |||
| def save_client_information(self, client_information: OAuthClientInformationFull) -> None: | |||
| def save_client_information(self, client_information: OAuthClientInformationFull): | |||
| """Saves client information after dynamic registration.""" | |||
| MCPToolManageService.update_mcp_provider_credentials( | |||
| self.mcp_provider, | |||
| @@ -63,13 +63,13 @@ class OAuthClientProvider: | |||
| refresh_token=credentials.get("refresh_token", ""), | |||
| ) | |||
| def save_tokens(self, tokens: OAuthTokens) -> None: | |||
| def save_tokens(self, tokens: OAuthTokens): | |||
| """Stores new OAuth tokens for the current session.""" | |||
| # update mcp provider credentials | |||
| token_dict = tokens.model_dump() | |||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) | |||
| def save_code_verifier(self, code_verifier: str) -> None: | |||
| def save_code_verifier(self, code_verifier: str): | |||
| """Saves a PKCE code verifier for the current session.""" | |||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) | |||
| @@ -47,7 +47,7 @@ class SSETransport: | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: float = 5.0, | |||
| sse_read_timeout: float = 5 * 60, | |||
| ) -> None: | |||
| ): | |||
| """Initialize the SSE transport. | |||
| Args: | |||
| @@ -76,7 +76,7 @@ class SSETransport: | |||
| return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme | |||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: | |||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue): | |||
| """Handle an 'endpoint' SSE event. | |||
| Args: | |||
| @@ -94,7 +94,7 @@ class SSETransport: | |||
| status_queue.put(_StatusReady(endpoint_url)) | |||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: | |||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue): | |||
| """Handle a 'message' SSE event. | |||
| Args: | |||
| @@ -110,7 +110,7 @@ class SSETransport: | |||
| logger.exception("Error parsing server message") | |||
| read_queue.put(exc) | |||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue): | |||
| """Handle a single SSE event. | |||
| Args: | |||
| @@ -126,7 +126,7 @@ class SSETransport: | |||
| case _: | |||
| logger.warning("Unknown SSE event: %s", sse.event) | |||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue): | |||
| """Read and process SSE events. | |||
| Args: | |||
| @@ -144,7 +144,7 @@ class SSETransport: | |||
| finally: | |||
| read_queue.put(None) | |||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: | |||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage): | |||
| """Send a single message to the server. | |||
| Args: | |||
| @@ -163,7 +163,7 @@ class SSETransport: | |||
| response.raise_for_status() | |||
| logger.debug("Client message sent successfully: %s", response.status_code) | |||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: | |||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue): | |||
| """Handle writing messages to the server. | |||
| Args: | |||
| @@ -303,7 +303,7 @@ def sse_client( | |||
| write_queue.put(None) | |||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: | |||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage): | |||
| """ | |||
| Send a message to the server using the provided HTTP client. | |||
| @@ -82,7 +82,7 @@ class StreamableHTTPTransport: | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: float | timedelta = 30, | |||
| sse_read_timeout: float | timedelta = 60 * 5, | |||
| ) -> None: | |||
| ): | |||
| """Initialize the StreamableHTTP transport. | |||
| Args: | |||
| @@ -122,7 +122,7 @@ class StreamableHTTPTransport: | |||
| def _maybe_extract_session_id_from_response( | |||
| self, | |||
| response: httpx.Response, | |||
| ) -> None: | |||
| ): | |||
| """Extract and store session ID from response headers.""" | |||
| new_session_id = response.headers.get(MCP_SESSION_ID) | |||
| if new_session_id: | |||
| @@ -173,7 +173,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| client: httpx.Client, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle GET stream for server-initiated messages.""" | |||
| try: | |||
| if not self.session_id: | |||
| @@ -197,7 +197,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| logger.debug("GET stream error (non-fatal): %s", exc) | |||
| def _handle_resumption_request(self, ctx: RequestContext) -> None: | |||
| def _handle_resumption_request(self, ctx: RequestContext): | |||
| """Handle a resumption request using GET with SSE.""" | |||
| headers = self._update_headers_with_session(ctx.headers) | |||
| if ctx.metadata and ctx.metadata.resumption_token: | |||
| @@ -230,7 +230,7 @@ class StreamableHTTPTransport: | |||
| if is_complete: | |||
| break | |||
| def _handle_post_request(self, ctx: RequestContext) -> None: | |||
| def _handle_post_request(self, ctx: RequestContext): | |||
| """Handle a POST request with response processing.""" | |||
| headers = self._update_headers_with_session(ctx.headers) | |||
| message = ctx.session_message.message | |||
| @@ -278,7 +278,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| response: httpx.Response, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle JSON response from the server.""" | |||
| try: | |||
| content = response.read() | |||
| @@ -288,7 +288,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| server_to_client_queue.put(exc) | |||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: | |||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): | |||
| """Handle SSE response from the server.""" | |||
| try: | |||
| event_source = EventSource(response) | |||
| @@ -307,7 +307,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| content_type: str, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle unexpected content type in response.""" | |||
| error_msg = f"Unexpected content type: {content_type}" | |||
| logger.error(error_msg) | |||
| @@ -317,7 +317,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| request_id: RequestId, | |||
| ) -> None: | |||
| ): | |||
| """Send a session terminated error response.""" | |||
| jsonrpc_error = JSONRPCError( | |||
| jsonrpc="2.0", | |||
| @@ -333,7 +333,7 @@ class StreamableHTTPTransport: | |||
| client_to_server_queue: ClientToServerQueue, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| start_get_stream: Callable[[], None], | |||
| ) -> None: | |||
| ): | |||
| """Handle writing requests to the server. | |||
| This method processes messages from the client_to_server_queue and sends them to the server. | |||
| @@ -379,7 +379,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| server_to_client_queue.put(exc) | |||
| def terminate_session(self, client: httpx.Client) -> None: | |||
| def terminate_session(self, client: httpx.Client): | |||
| """Terminate the session by sending a DELETE request.""" | |||
| if not self.session_id: | |||
| return | |||
| @@ -441,7 +441,7 @@ def streamablehttp_client( | |||
| timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | |||
| ) as client: | |||
| # Define callbacks that need access to thread pool | |||
| def start_get_stream() -> None: | |||
| def start_get_stream(): | |||
| """Start a worker thread to handle server-initiated messages.""" | |||
| executor.submit(transport.handle_get_stream, client, server_to_client_queue) | |||
| @@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| ReceiveNotificationT | |||
| ]""", | |||
| on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], | |||
| ) -> None: | |||
| ): | |||
| self.request_id = request_id | |||
| self.request_meta = request_meta | |||
| self.request = request | |||
| @@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| exc_type: type[BaseException] | None, | |||
| exc_val: BaseException | None, | |||
| exc_tb: TracebackType | None, | |||
| ) -> None: | |||
| ): | |||
| """Exit the context manager, performing cleanup and notifying completion.""" | |||
| try: | |||
| if self._completed: | |||
| @@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| finally: | |||
| self._entered = False | |||
| def respond(self, response: SendResultT | ErrorData) -> None: | |||
| def respond(self, response: SendResultT | ErrorData): | |||
| """Send a response for this request. | |||
| Must be called within a context manager block. | |||
| @@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| self._session._send_response(request_id=self.request_id, response=response) | |||
| def cancel(self) -> None: | |||
| def cancel(self): | |||
| """Cancel this request and mark it as completed.""" | |||
| if not self._entered: | |||
| raise RuntimeError("RequestResponder must be used as a context manager") | |||
| @@ -163,7 +163,7 @@ class BaseSession( | |||
| receive_notification_type: type[ReceiveNotificationT], | |||
| # If none, reading will never time out | |||
| read_timeout_seconds: timedelta | None = None, | |||
| ) -> None: | |||
| ): | |||
| self._read_stream = read_stream | |||
| self._write_stream = write_stream | |||
| self._response_streams = {} | |||
| @@ -183,7 +183,7 @@ class BaseSession( | |||
| self._receiver_future = self._executor.submit(self._receive_loop) | |||
| return self | |||
| def check_receiver_status(self) -> None: | |||
| def check_receiver_status(self): | |||
| """`check_receiver_status` ensures that any exceptions raised during the | |||
| execution of `_receive_loop` are retrieved and propagated.""" | |||
| if self._receiver_future and self._receiver_future.done(): | |||
| @@ -191,7 +191,7 @@ class BaseSession( | |||
| def __exit__( | |||
| self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | |||
| ) -> None: | |||
| ): | |||
| self._read_stream.put(None) | |||
| self._write_stream.put(None) | |||
| @@ -277,7 +277,7 @@ class BaseSession( | |||
| self, | |||
| notification: SendNotificationT, | |||
| related_request_id: RequestId | None = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Emits a notification, which is a one-way message that does not expect | |||
| a response. | |||
| @@ -296,7 +296,7 @@ class BaseSession( | |||
| ) | |||
| self._write_stream.put(session_message) | |||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: | |||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData): | |||
| if isinstance(response, ErrorData): | |||
| jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) | |||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) | |||
| @@ -310,7 +310,7 @@ class BaseSession( | |||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) | |||
| self._write_stream.put(session_message) | |||
| def _receive_loop(self) -> None: | |||
| def _receive_loop(self): | |||
| """ | |||
| Main message processing loop. | |||
| In a real synchronous implementation, this would likely run in a separate thread. | |||
| @@ -382,7 +382,7 @@ class BaseSession( | |||
| logger.exception("Error in message processing loop") | |||
| raise | |||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: | |||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]): | |||
| """ | |||
| Can be overridden by subclasses to handle a request without needing to | |||
| listen on the message stream. | |||
| @@ -391,15 +391,13 @@ class BaseSession( | |||
| forwarded on to the message stream. | |||
| """ | |||
| def _received_notification(self, notification: ReceiveNotificationT) -> None: | |||
| def _received_notification(self, notification: ReceiveNotificationT): | |||
| """ | |||
| Can be overridden by subclasses to handle a notification without needing | |||
| to listen on the message stream. | |||
| """ | |||
| def send_progress_notification( | |||
| self, progress_token: str | int, progress: float, total: float | None = None | |||
| ) -> None: | |||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||
| """ | |||
| Sends a progress notification for a request that is currently being | |||
| processed. | |||
| @@ -408,5 +406,5 @@ class BaseSession( | |||
| def _handle_incoming( | |||
| self, | |||
| req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | |||
| ) -> None: | |||
| ): | |||
| """A generic handler for incoming messages. Overwritten by subclasses.""" | |||
| @@ -28,19 +28,19 @@ class LoggingFnT(Protocol): | |||
| def __call__( | |||
| self, | |||
| params: types.LoggingMessageNotificationParams, | |||
| ) -> None: ... | |||
| ): ... | |||
| class MessageHandlerFnT(Protocol): | |||
| def __call__( | |||
| self, | |||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: ... | |||
| ): ... | |||
| def _default_message_handler( | |||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: | |||
| ): | |||
| if isinstance(message, Exception): | |||
| raise ValueError(str(message)) | |||
| elif isinstance(message, (types.ServerNotification | RequestResponder)): | |||
| @@ -68,7 +68,7 @@ def _default_list_roots_callback( | |||
| def _default_logging_callback( | |||
| params: types.LoggingMessageNotificationParams, | |||
| ) -> None: | |||
| ): | |||
| pass | |||
| @@ -94,7 +94,7 @@ class ClientSession( | |||
| logging_callback: LoggingFnT | None = None, | |||
| message_handler: MessageHandlerFnT | None = None, | |||
| client_info: types.Implementation | None = None, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| read_stream, | |||
| write_stream, | |||
| @@ -155,9 +155,7 @@ class ClientSession( | |||
| types.EmptyResult, | |||
| ) | |||
| def send_progress_notification( | |||
| self, progress_token: str | int, progress: float, total: float | None = None | |||
| ) -> None: | |||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||
| """Send a progress notification.""" | |||
| self.send_notification( | |||
| types.ClientNotification( | |||
| @@ -314,7 +312,7 @@ class ClientSession( | |||
| types.ListToolsResult, | |||
| ) | |||
| def send_roots_list_changed(self) -> None: | |||
| def send_roots_list_changed(self): | |||
| """Send a roots/list_changed notification.""" | |||
| self.send_notification( | |||
| types.ClientNotification( | |||
| @@ -324,7 +322,7 @@ class ClientSession( | |||
| ) | |||
| ) | |||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: | |||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): | |||
| ctx = RequestContext[ClientSession, Any]( | |||
| request_id=responder.request_id, | |||
| meta=responder.request_meta, | |||
| @@ -352,11 +350,11 @@ class ClientSession( | |||
| def _handle_incoming( | |||
| self, | |||
| req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: | |||
| ): | |||
| """Handle incoming messages by forwarding to the message handler.""" | |||
| self._message_handler(req) | |||
| def _received_notification(self, notification: types.ServerNotification) -> None: | |||
| def _received_notification(self, notification: types.ServerNotification): | |||
| """Handle notifications from the server.""" | |||
| # Process specific notification types | |||
| match notification.root: | |||
| @@ -27,7 +27,7 @@ class TokenBufferMemory: | |||
| self, | |||
| conversation: Conversation, | |||
| model_instance: ModelInstance, | |||
| ) -> None: | |||
| ): | |||
| self.conversation = conversation | |||
| self.model_instance = model_instance | |||
| @@ -32,7 +32,7 @@ class ModelInstance: | |||
| Model instance class | |||
| """ | |||
| def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: | |||
| def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): | |||
| self.provider_model_bundle = provider_model_bundle | |||
| self.model = model | |||
| self.provider = provider_model_bundle.configuration.provider.provider | |||
| @@ -46,7 +46,7 @@ class ModelInstance: | |||
| ) | |||
| @staticmethod | |||
| def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: | |||
| def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): | |||
| """ | |||
| Fetch credentials from provider model bundle | |||
| :param provider_model_bundle: provider model bundle | |||
| @@ -342,7 +342,7 @@ class ModelInstance: | |||
| ), | |||
| ) | |||
| def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: | |||
| def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): | |||
| """ | |||
| Round-robin invoke | |||
| :param function: function to invoke | |||
| @@ -379,7 +379,7 @@ class ModelInstance: | |||
| except Exception as e: | |||
| raise e | |||
| def get_tts_voices(self, language: Optional[str] = None) -> list: | |||
| def get_tts_voices(self, language: Optional[str] = None): | |||
| """ | |||
| Invoke large language tts model voices | |||
| @@ -394,7 +394,7 @@ class ModelInstance: | |||
| class ModelManager: | |||
| def __init__(self) -> None: | |||
| def __init__(self): | |||
| self._provider_manager = ProviderManager() | |||
| def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: | |||
| @@ -453,7 +453,7 @@ class LBModelManager: | |||
| model: str, | |||
| load_balancing_configs: list[ModelLoadBalancingConfiguration], | |||
| managed_credentials: Optional[dict] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Load balancing model manager | |||
| :param tenant_id: tenant_id | |||
| @@ -534,7 +534,7 @@ model: %s""", | |||
| return config | |||
| def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: | |||
| def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60): | |||
| """ | |||
| Cooldown model load balancing config | |||
| :param config: model load balancing config | |||
| @@ -35,7 +35,7 @@ class Callback(ABC): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Before invoke callback | |||
| @@ -94,7 +94,7 @@ class Callback(ABC): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| After invoke callback | |||
| @@ -124,7 +124,7 @@ class Callback(ABC): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Invoke error callback | |||
| @@ -141,7 +141,7 @@ class Callback(ABC): | |||
| """ | |||
| raise NotImplementedError() | |||
| def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: | |||
| def print_text(self, text: str, color: Optional[str] = None, end: str = ""): | |||
| """Print text with highlighting and no end characters.""" | |||
| text_to_print = self._get_colored_text(text, color) if color else text | |||
| print(text_to_print, end=end) | |||
| @@ -24,7 +24,7 @@ class LoggingCallback(Callback): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Before invoke callback | |||
| @@ -110,7 +110,7 @@ class LoggingCallback(Callback): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| After invoke callback | |||
| @@ -151,7 +151,7 @@ class LoggingCallback(Callback): | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Invoke error callback | |||
| @@ -6,7 +6,7 @@ class InvokeError(ValueError): | |||
| description: Optional[str] = None | |||
| def __init__(self, description: Optional[str] = None) -> None: | |||
| def __init__(self, description: Optional[str] = None): | |||
| self.description = description | |||
| def __str__(self): | |||
| @@ -242,7 +242,7 @@ class AIModel(BaseModel): | |||
| """ | |||
| return None | |||
| def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: | |||
| def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): | |||
| """ | |||
| Get default parameter rule for given name | |||
| @@ -411,7 +411,7 @@ class LargeLanguageModel(AIModel): | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Trigger before invoke callbacks | |||
| @@ -459,7 +459,7 @@ class LargeLanguageModel(AIModel): | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Trigger new chunk callbacks | |||
| @@ -506,7 +506,7 @@ class LargeLanguageModel(AIModel): | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Trigger after invoke callbacks | |||
| @@ -556,7 +556,7 @@ class LargeLanguageModel(AIModel): | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Trigger invoke error callbacks | |||
| @@ -28,7 +28,7 @@ class GPT2Tokenizer: | |||
| return GPT2Tokenizer._get_num_tokens_by_gpt2(text) | |||
| @staticmethod | |||
| def get_encoder() -> Any: | |||
| def get_encoder(): | |||
| global _tokenizer, _lock | |||
| if _tokenizer is not None: | |||
| return _tokenizer | |||
| @@ -57,7 +57,7 @@ class TTSModel(AIModel): | |||
| except Exception as e: | |||
| raise self._transform_invoke_error(e) | |||
| def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: | |||
| def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): | |||
| """ | |||
| Retrieves the list of voices supported by a given text-to-speech (TTS) model. | |||
| @@ -132,7 +132,7 @@ class ModelProviderFactory: | |||
| return plugin_model_provider_entity | |||
| def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: | |||
| def provider_credentials_validate(self, *, provider: str, credentials: dict): | |||
| """ | |||
| Validate provider credentials | |||
| @@ -163,9 +163,7 @@ class ModelProviderFactory: | |||
| return filtered_credentials | |||
| def model_credentials_validate( | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||
| ) -> dict: | |||
| def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): | |||
| """ | |||
| Validate model credentials | |||
| @@ -6,7 +6,7 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema, | |||
| class CommonValidator: | |||
| def _validate_and_filter_credential_form_schemas( | |||
| self, credential_form_schemas: list[CredentialFormSchema], credentials: dict | |||
| ) -> dict: | |||
| ): | |||
| need_validate_credential_form_schema_map = {} | |||
| for credential_form_schema in credential_form_schemas: | |||
| if not credential_form_schema.show_on: | |||
| @@ -8,7 +8,7 @@ class ModelCredentialSchemaValidator(CommonValidator): | |||
| self.model_type = model_type | |||
| self.model_credential_schema = model_credential_schema | |||
| def validate_and_filter(self, credentials: dict) -> dict: | |||
| def validate_and_filter(self, credentials: dict): | |||
| """ | |||
| Validate model credentials | |||
| @@ -6,7 +6,7 @@ class ProviderCredentialSchemaValidator(CommonValidator): | |||
| def __init__(self, provider_credential_schema: ProviderCredentialSchema): | |||
| self.provider_credential_schema = provider_credential_schema | |||
| def validate_and_filter(self, credentials: dict) -> dict: | |||
| def validate_and_filter(self, credentials: dict): | |||
| """ | |||
| Validate provider credentials | |||
| @@ -18,7 +18,7 @@ from pydantic_core import Url | |||
| from pydantic_extra_types.color import Color | |||
| def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: | |||
| def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): | |||
| return model.model_dump(mode=mode, **kwargs) | |||
| @@ -100,7 +100,7 @@ def jsonable_encoder( | |||
| exclude_none: bool = False, | |||
| custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, | |||
| sqlalchemy_safe: bool = True, | |||
| ) -> Any: | |||
| ): | |||
| custom_encoder = custom_encoder or {} | |||
| if custom_encoder: | |||
| if type(obj) in custom_encoder: | |||
| @@ -25,7 +25,7 @@ class ApiModeration(Moderation): | |||
| name: str = "api" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -75,7 +75,7 @@ class ApiModeration(Moderation): | |||
| flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response | |||
| ) | |||
| def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: | |||
| def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict): | |||
| if self.config is None: | |||
| raise ValueError("The config is not set.") | |||
| extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) | |||
| @@ -34,13 +34,13 @@ class Moderation(Extensible, ABC): | |||
| module: ExtensionModule = ExtensionModule.MODERATION | |||
| def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: | |||
| def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): | |||
| super().__init__(tenant_id, config) | |||
| self.app_id = app_id | |||
| @classmethod | |||
| @abstractmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -76,7 +76,7 @@ class Moderation(Extensible, ABC): | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: | |||
| def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool): | |||
| # inputs_config | |||
| inputs_config = config.get("inputs_config") | |||
| if not isinstance(inputs_config, dict): | |||
| @@ -6,12 +6,12 @@ from extensions.ext_code_based_extension import code_based_extension | |||
| class ModerationFactory: | |||
| __extension_instance: Moderation | |||
| def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: | |||
| def __init__(self, name: str, app_id: str, tenant_id: str, config: dict): | |||
| extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) | |||
| self.__extension_instance = extension_class(app_id, tenant_id, config) | |||
| @classmethod | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -8,7 +8,7 @@ class KeywordsModeration(Moderation): | |||
| name: str = "keywords" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -7,7 +7,7 @@ class OpenAIModeration(Moderation): | |||
| name: str = "openai_moderation" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -40,7 +40,7 @@ class OutputModeration(BaseModel): | |||
| def get_final_output(self) -> str: | |||
| return self.final_output or "" | |||
| def append_new_token(self, token: str) -> None: | |||
| def append_new_token(self, token: str): | |||
| self.buffer += token | |||
| if not self.thread: | |||
| @@ -6,7 +6,7 @@ from models.account import Tenant | |||
| class PluginEncrypter: | |||
| @classmethod | |||
| def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: | |||
| def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt): | |||
| encrypter, cache = create_provider_encrypter( | |||
| tenant_id=tenant.id, | |||
| config=payload.config, | |||
| @@ -27,7 +27,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): | |||
| model_config: ParameterExtractorModelConfig, | |||
| instruction: str, | |||
| query: str, | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Invoke parameter extractor node. | |||
| @@ -78,7 +78,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): | |||
| classes: list[ClassConfig], | |||
| instruction: str, | |||
| query: str, | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Invoke question classifier node. | |||
| @@ -115,7 +115,7 @@ class PluginDeclaration(BaseModel): | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_category(cls, values: dict) -> dict: | |||
| def validate_category(cls, values: dict): | |||
| # auto detect category | |||
| if values.get("tool"): | |||
| values["category"] = PluginCategory.Tool | |||
| @@ -17,7 +17,7 @@ class PluginAgentClient(BasePluginClient): | |||
| Fetch agent providers for the given tenant. | |||
| """ | |||
| def transformer(json_response: dict[str, Any]) -> dict: | |||
| def transformer(json_response: dict[str, Any]): | |||
| for provider in json_response.get("data", []): | |||
| declaration = provider.get("declaration", {}) or {} | |||
| provider_name = declaration.get("identity", {}).get("name") | |||
| @@ -49,7 +49,7 @@ class PluginAgentClient(BasePluginClient): | |||
| """ | |||
| agent_provider_id = GenericProviderID(provider) | |||
| def transformer(json_response: dict[str, Any]) -> dict: | |||
| def transformer(json_response: dict[str, Any]): | |||
| # skip if error occurs | |||
| if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: | |||
| return json_response | |||
| @@ -8,7 +8,7 @@ from extensions.ext_logging import get_request_id | |||
| class PluginDaemonError(Exception): | |||
| """Base class for all plugin daemon errors.""" | |||
| def __init__(self, description: str) -> None: | |||
| def __init__(self, description: str): | |||
| self.description = description | |||
| def __str__(self) -> str: | |||
| @@ -415,7 +415,7 @@ class PluginModelClient(BasePluginClient): | |||
| model: str, | |||
| credentials: dict, | |||
| language: Optional[str] = None, | |||
| ) -> list[dict]: | |||
| ): | |||
| """ | |||
| Get tts model voices | |||
| """ | |||
| @@ -16,7 +16,7 @@ class PluginToolManager(BasePluginClient): | |||
| Fetch tool providers for the given tenant. | |||
| """ | |||
| def transformer(json_response: dict[str, Any]) -> dict: | |||
| def transformer(json_response: dict[str, Any]): | |||
| for provider in json_response.get("data", []): | |||
| declaration = provider.get("declaration", {}) or {} | |||
| provider_name = declaration.get("identity", {}).get("name") | |||
| @@ -48,7 +48,7 @@ class PluginToolManager(BasePluginClient): | |||
| """ | |||
| tool_provider_id = ToolProviderID(provider) | |||
| def transformer(json_response: dict[str, Any]) -> dict: | |||
| def transformer(json_response: dict[str, Any]): | |||
| data = json_response.get("data") | |||
| if data: | |||
| for tool in data.get("declaration", {}).get("tools", []): | |||
| @@ -18,7 +18,7 @@ class FileChunk: | |||
| bytes_written: int = field(default=0, init=False) | |||
| data: bytearray = field(init=False) | |||
| def __post_init__(self) -> None: | |||
| def __post_init__(self): | |||
| self.data = bytearray(self.total_length) | |||