| def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: | def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: | |||||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool): | |||||
| return value | return value |
| logger.exception("[get-access-token] exception occurred") | logger.exception("[get-access-token] exception occurred") | ||||
| raise | raise | ||||
| def _parse_config(self, content: str) -> dict: | |||||
| def _parse_config(self, content: str): | |||||
| if not content: | if not content: | ||||
| return {} | return {} | ||||
| try: | try: |
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self) -> dict: | |||||
| def post(self): | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("type", type=str, required=True, default=False, location="json") | parser.add_argument("type", type=str, required=True, default=False, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() |
| import logging | import logging | ||||
| from typing import Any, NoReturn | |||||
| from typing import NoReturn | |||||
| from flask import Response | from flask import Response | ||||
| from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse | from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||||
| def _convert_values_to_json_serializable_object(value: Segment): | |||||
| if isinstance(value, FileSegment): | if isinstance(value, FileSegment): | ||||
| return value.value.model_dump() | return value.value.model_dump() | ||||
| elif isinstance(value, ArrayFileSegment): | elif isinstance(value, ArrayFileSegment): | ||||
| return value.value | return value.value | ||||
| def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: | |||||
| def _serialize_var_value(variable: WorkflowDraftVariable): | |||||
| value = variable.get_value() | value = variable.get_value() | ||||
| # create a copy of the value to avoid affecting the model cache. | # create a copy of the value to avoid affecting the model cache. | ||||
| value = value.model_copy(deep=True) | value = value.model_copy(deep=True) |
| return mcp_server, app | return mcp_server, app | ||||
| def _validate_server_status(self, mcp_server: AppMCPServer) -> None: | |||||
| def _validate_server_status(self, mcp_server: AppMCPServer): | |||||
| """Validate MCP server status""" | """Validate MCP server status""" | ||||
| if mcp_server.status != AppMCPServerStatus.ACTIVE: | if mcp_server.status != AppMCPServerStatus.ACTIVE: | ||||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") | raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") |
| model_instance: ModelInstance, | model_instance: ModelInstance, | ||||
| memory: Optional[TokenBufferMemory] = None, | memory: Optional[TokenBufferMemory] = None, | ||||
| prompt_messages: Optional[list[PromptMessage]] = None, | prompt_messages: Optional[list[PromptMessage]] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.application_generate_entity = application_generate_entity | self.application_generate_entity = application_generate_entity | ||||
| self.conversation = conversation | self.conversation = conversation |
| return instruction | return instruction | ||||
| def _init_react_state(self, query) -> None: | |||||
| def _init_react_state(self, query): | |||||
| """ | """ | ||||
| init agent scratchpad | init agent scratchpad | ||||
| """ | """ |
| action_name: str | action_name: str | ||||
| action_input: Union[dict, str] | action_input: Union[dict, str] | ||||
| def to_dict(self) -> dict: | |||||
| def to_dict(self): | |||||
| """ | """ | ||||
| Convert to dictionary. | Convert to dictionary. | ||||
| """ | """ |
| return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] | return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] | ||||
| @classmethod | @classmethod | ||||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: | |||||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict): | |||||
| """ | """ | ||||
| Extract dataset config for legacy compatibility | Extract dataset config for legacy compatibility | ||||
| return dict(config), ["model"] | return dict(config), ["model"] | ||||
| @classmethod | @classmethod | ||||
| def validate_model_completion_params(cls, cp: dict) -> dict: | |||||
| def validate_model_completion_params(cls, cp: dict): | |||||
| # model.completion_params | # model.completion_params | ||||
| if not isinstance(cp, dict): | if not isinstance(cp, dict): | ||||
| raise ValueError("model.completion_params must be of object type") | raise ValueError("model.completion_params must be of object type") |
| return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] | return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] | ||||
| @classmethod | @classmethod | ||||
| def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: | |||||
| def validate_post_prompt_and_set_defaults(cls, config: dict): | |||||
| """ | """ | ||||
| Validate post_prompt and set defaults for prompt feature | Validate post_prompt and set defaults for prompt feature | ||||
| return app_config | return app_config | ||||
| @classmethod | @classmethod | ||||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||||
| """ | """ | ||||
| Validate for advanced chat app model config | Validate for advanced chat app model config | ||||
| message_id: str, | message_id: str, | ||||
| context: contextvars.Context, | context: contextvars.Context, | ||||
| variable_loader: VariableLoader, | variable_loader: VariableLoader, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Generate worker in a new thread. | Generate worker in a new thread. | ||||
| :param flask_app: Flask app | :param flask_app: Flask app |
| workflow: Workflow, | workflow: Workflow, | ||||
| system_user_id: str, | system_user_id: str, | ||||
| app: App, | app: App, | ||||
| ) -> None: | |||||
| ): | |||||
| super().__init__( | super().__init__( | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| variable_loader=variable_loader, | variable_loader=variable_loader, | ||||
| self.system_user_id = system_user_id | self.system_user_id = system_user_id | ||||
| self._app = app | self._app = app | ||||
| def run(self) -> None: | |||||
| def run(self): | |||||
| app_config = self.application_generate_entity.app_config | app_config = self.application_generate_entity.app_config | ||||
| app_config = cast(AdvancedChatAppConfig, app_config) | app_config = cast(AdvancedChatAppConfig, app_config) | ||||
| return False | return False | ||||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: | |||||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy): | |||||
| """ | """ | ||||
| Direct output | Direct output | ||||
| """ | """ |
| workflow_execution_repository: WorkflowExecutionRepository, | workflow_execution_repository: WorkflowExecutionRepository, | ||||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | workflow_node_execution_repository: WorkflowNodeExecutionRepository, | ||||
| draft_var_saver_factory: DraftVariableSaverFactory, | draft_var_saver_factory: DraftVariableSaverFactory, | ||||
| ) -> None: | |||||
| ): | |||||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | self._base_task_pipeline = BasedGenerateTaskPipeline( | ||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| session.rollback() | session.rollback() | ||||
| raise | raise | ||||
| def _ensure_workflow_initialized(self) -> None: | |||||
| def _ensure_workflow_initialized(self): | |||||
| """Fluent validation for workflow state.""" | """Fluent validation for workflow state.""" | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| if self._conversation_name_generate_thread: | if self._conversation_name_generate_thread: | ||||
| self._conversation_name_generate_thread.join() | self._conversation_name_generate_thread.join() | ||||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): | |||||
| message = self._get_message(session=session) | message = self._get_message(session=session) | ||||
| # If there are assistant files, remove markdown image links from answer | # If there are assistant files, remove markdown image links from answer |
| return app_config | return app_config | ||||
| @classmethod | @classmethod | ||||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: | |||||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): | |||||
| """ | """ | ||||
| Validate for agent chat app model config | Validate for agent chat app model config | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation_id: str, | conversation_id: str, | ||||
| message_id: str, | message_id: str, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Generate worker in a new thread. | Generate worker in a new thread. | ||||
| :param flask_app: Flask app | :param flask_app: Flask app |
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation: Conversation, | conversation: Conversation, | ||||
| message: Message, | message: Message, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Run assistant application | Run assistant application | ||||
| :param application_generate_entity: application generate entity | :param application_generate_entity: application generate entity |
| _blocking_response_type = ChatbotAppBlockingResponse | _blocking_response_type = ChatbotAppBlockingResponse | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking full response. | Convert blocking full response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response | ||||
| return response | return response | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking simple response. | Convert blocking simple response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response |
| return metadata | return metadata | ||||
| @classmethod | @classmethod | ||||
| def _error_to_stream_response(cls, e: Exception) -> dict: | |||||
| def _error_to_stream_response(cls, e: Exception): | |||||
| """ | """ | ||||
| Error to stream response. | Error to stream response. | ||||
| :param e: exception | :param e: exception |
| return value | return value | ||||
| def _sanitize_value(self, value: Any) -> Any: | |||||
| def _sanitize_value(self, value: Any): | |||||
| if isinstance(value, str): | if isinstance(value, str): | ||||
| return value.replace("\x00", "") | return value.replace("\x00", "") | ||||
| return value | return value |
| class AppQueueManager: | class AppQueueManager: | ||||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: | |||||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom): | |||||
| if not user_id: | if not user_id: | ||||
| raise ValueError("user is required") | raise ValueError("user is required") | ||||
| self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) | self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) | ||||
| last_ping_time = elapsed_time // 10 | last_ping_time = elapsed_time // 10 | ||||
| def stop_listen(self) -> None: | |||||
| def stop_listen(self): | |||||
| """ | """ | ||||
| Stop listen to queue | Stop listen to queue | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| self._q.put(None) | self._q.put(None) | ||||
| def publish_error(self, e, pub_from: PublishFrom) -> None: | |||||
| def publish_error(self, e, pub_from: PublishFrom): | |||||
| """ | """ | ||||
| Publish error | Publish error | ||||
| :param e: error | :param e: error | ||||
| """ | """ | ||||
| self.publish(QueueErrorEvent(error=e), pub_from) | self.publish(QueueErrorEvent(error=e), pub_from) | ||||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||||
| """ | """ | ||||
| Publish event to queue | Publish event to queue | ||||
| :param event: | :param event: | ||||
| self._publish(event, pub_from) | self._publish(event, pub_from) | ||||
| @abstractmethod | @abstractmethod | ||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||||
| """ | """ | ||||
| Publish event to queue | Publish event to queue | ||||
| :param event: | :param event: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @classmethod | @classmethod | ||||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: | |||||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str): | |||||
| """ | """ | ||||
| Set task stop flag | Set task stop flag | ||||
| :return: | :return: |
| text: str, | text: str, | ||||
| stream: bool, | stream: bool, | ||||
| usage: Optional[LLMUsage] = None, | usage: Optional[LLMUsage] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Direct output | Direct output | ||||
| :param queue_manager: application queue manager | :param queue_manager: application queue manager | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| stream: bool, | stream: bool, | ||||
| agent: bool = False, | agent: bool = False, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Handle invoke result | Handle invoke result | ||||
| :param invoke_result: invoke result | :param invoke_result: invoke result | ||||
| else: | else: | ||||
| raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") | raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") | ||||
| def _handle_invoke_result_direct( | |||||
| self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool | |||||
| ) -> None: | |||||
| def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): | |||||
| """ | """ | ||||
| Handle invoke result direct | Handle invoke result direct | ||||
| :param invoke_result: invoke result | :param invoke_result: invoke result | ||||
| def _handle_invoke_result_stream( | def _handle_invoke_result_stream( | ||||
| self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool | self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Handle invoke result | Handle invoke result | ||||
| :param invoke_result: invoke result | :param invoke_result: invoke result |
| return app_config | return app_config | ||||
| @classmethod | @classmethod | ||||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||||
| def config_validate(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate for chat app model config | Validate for chat app model config | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation_id: str, | conversation_id: str, | ||||
| message_id: str, | message_id: str, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Generate worker in a new thread. | Generate worker in a new thread. | ||||
| :param flask_app: Flask app | :param flask_app: Flask app |
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation: Conversation, | conversation: Conversation, | ||||
| message: Message, | message: Message, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Run application | Run application | ||||
| :param application_generate_entity: application generate entity | :param application_generate_entity: application generate entity |
| _blocking_response_type = ChatbotAppBlockingResponse | _blocking_response_type = ChatbotAppBlockingResponse | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking full response. | Convert blocking full response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response | ||||
| return response | return response | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking simple response. | Convert blocking simple response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response |
| *, | *, | ||||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| ) -> None: | |||||
| ): | |||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._user = user | self._user = user | ||||
| return app_config | return app_config | ||||
| @classmethod | @classmethod | ||||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||||
| def config_validate(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate for completion app model config | Validate for completion app model config | ||||
| application_generate_entity: CompletionAppGenerateEntity, | application_generate_entity: CompletionAppGenerateEntity, | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| message_id: str, | message_id: str, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Generate worker in a new thread. | Generate worker in a new thread. | ||||
| :param flask_app: Flask app | :param flask_app: Flask app |
| def run( | def run( | ||||
| self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message | self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Run application | Run application | ||||
| :param application_generate_entity: application generate entity | :param application_generate_entity: application generate entity |
| _blocking_response_type = CompletionAppBlockingResponse | _blocking_response_type = CompletionAppBlockingResponse | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking full response. | Convert blocking full response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response | ||||
| return response | return response | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking simple response. | Convert blocking simple response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response |
| class MessageBasedAppQueueManager(AppQueueManager): | class MessageBasedAppQueueManager(AppQueueManager): | ||||
| def __init__( | def __init__( | ||||
| self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str | self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str | ||||
| ) -> None: | |||||
| ): | |||||
| super().__init__(task_id, user_id, invoke_from) | super().__init__(task_id, user_id, invoke_from) | ||||
| self._conversation_id = str(conversation_id) | self._conversation_id = str(conversation_id) | ||||
| self._app_mode = app_mode | self._app_mode = app_mode | ||||
| self._message_id = str(message_id) | self._message_id = str(message_id) | ||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||||
| """ | """ | ||||
| Publish event to queue | Publish event to queue | ||||
| :param event: | :param event: |
| return app_config | return app_config | ||||
| @classmethod | @classmethod | ||||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||||
| """ | """ | ||||
| Validate for workflow app model config | Validate for workflow app model config | ||||
| class WorkflowAppQueueManager(AppQueueManager): | class WorkflowAppQueueManager(AppQueueManager): | ||||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: | |||||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str): | |||||
| super().__init__(task_id, user_id, invoke_from) | super().__init__(task_id, user_id, invoke_from) | ||||
| self._app_mode = app_mode | self._app_mode = app_mode | ||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||||
| """ | """ | ||||
| Publish event to queue | Publish event to queue | ||||
| :param event: | :param event: |
| variable_loader: VariableLoader, | variable_loader: VariableLoader, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| system_user_id: str, | system_user_id: str, | ||||
| ) -> None: | |||||
| ): | |||||
| super().__init__( | super().__init__( | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| variable_loader=variable_loader, | variable_loader=variable_loader, | ||||
| self._workflow = workflow | self._workflow = workflow | ||||
| self._sys_user_id = system_user_id | self._sys_user_id = system_user_id | ||||
| def run(self) -> None: | |||||
| def run(self): | |||||
| """ | """ | ||||
| Run application | Run application | ||||
| """ | """ |
| _blocking_response_type = WorkflowAppBlockingResponse | _blocking_response_type = WorkflowAppBlockingResponse | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking full response. | Convert blocking full response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response | ||||
| return dict(blocking_response.to_dict()) | return dict(blocking_response.to_dict()) | ||||
| @classmethod | @classmethod | ||||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||||
| """ | """ | ||||
| Convert blocking simple response. | Convert blocking simple response. | ||||
| :param blocking_response: blocking response | :param blocking_response: blocking response |
| workflow_execution_repository: WorkflowExecutionRepository, | workflow_execution_repository: WorkflowExecutionRepository, | ||||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | workflow_node_execution_repository: WorkflowNodeExecutionRepository, | ||||
| draft_var_saver_factory: DraftVariableSaverFactory, | draft_var_saver_factory: DraftVariableSaverFactory, | ||||
| ) -> None: | |||||
| ): | |||||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | self._base_task_pipeline = BasedGenerateTaskPipeline( | ||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| session.rollback() | session.rollback() | ||||
| raise | raise | ||||
| def _ensure_workflow_initialized(self) -> None: | |||||
| def _ensure_workflow_initialized(self): | |||||
| """Fluent validation for workflow state.""" | """Fluent validation for workflow state.""" | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| if tts_publisher: | if tts_publisher: | ||||
| tts_publisher.publish(None) | tts_publisher.publish(None) | ||||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution): | |||||
| invoke_from = self._application_generate_entity.invoke_from | invoke_from = self._application_generate_entity.invoke_from | ||||
| if invoke_from == InvokeFrom.SERVICE_API: | if invoke_from == InvokeFrom.SERVICE_API: | ||||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | created_from = WorkflowAppLogCreatedFrom.SERVICE_API |
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | ||||
| app_id: str, | app_id: str, | ||||
| ) -> None: | |||||
| ): | |||||
| self._queue_manager = queue_manager | self._queue_manager = queue_manager | ||||
| self._variable_loader = variable_loader | self._variable_loader = variable_loader | ||||
| self._app_id = app_id | self._app_id = app_id | ||||
| return graph, variable_pool | return graph, variable_pool | ||||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: | |||||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): | |||||
| """ | """ | ||||
| Handle event | Handle event | ||||
| :param workflow_entry: workflow entry | :param workflow_entry: workflow entry | ||||
| ) | ) | ||||
| ) | ) | ||||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||||
| def _publish_event(self, event: AppQueueEvent): | |||||
| self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) |
| application_generate_entity: AppGenerateEntity, | application_generate_entity: AppGenerateEntity, | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| stream: bool, | stream: bool, | ||||
| ) -> None: | |||||
| ): | |||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self.queue_manager = queue_manager | self.queue_manager = queue_manager | ||||
| self._start_at = time.perf_counter() | self._start_at = time.perf_counter() |
| conversation: Conversation, | conversation: Conversation, | ||||
| message: Message, | message: Message, | ||||
| stream: bool, | stream: bool, | ||||
| ) -> None: | |||||
| ): | |||||
| super().__init__( | super().__init__( | ||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| if self._conversation_name_generate_thread: | if self._conversation_name_generate_thread: | ||||
| self._conversation_name_generate_thread.join() | self._conversation_name_generate_thread.join() | ||||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: | |||||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): | |||||
| """ | """ | ||||
| Save message. | Save message. | ||||
| :return: | :return: | ||||
| application_generate_entity=self._application_generate_entity, | application_generate_entity=self._application_generate_entity, | ||||
| ) | ) | ||||
| def _handle_stop(self, event: QueueStopEvent) -> None: | |||||
| def _handle_stop(self, event: QueueStopEvent): | |||||
| """ | """ | ||||
| Handle stop. | Handle stop. | ||||
| :return: | :return: |
| AdvancedChatAppGenerateEntity, | AdvancedChatAppGenerateEntity, | ||||
| ], | ], | ||||
| task_state: Union[EasyUITaskState, WorkflowTaskState], | task_state: Union[EasyUITaskState, WorkflowTaskState], | ||||
| ) -> None: | |||||
| ): | |||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._task_state = task_state | self._task_state = task_state | ||||
| return None | return None | ||||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent): | |||||
| """ | """ | ||||
| Handle retriever resources. | Handle retriever resources. | ||||
| :param event: event | :param event: event |
| return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" | return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" | ||||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: | |||||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): | |||||
| """Print text with highlighting and no end characters.""" | """Print text with highlighting and no end characters.""" | ||||
| text_to_print = get_colored_text(text, color) if color else text | text_to_print = get_colored_text(text, color) if color else text | ||||
| print(text_to_print, end=end, file=file) | print(text_to_print, end=end, file=file) | ||||
| color: Optional[str] = "" | color: Optional[str] = "" | ||||
| current_loop: int = 1 | current_loop: int = 1 | ||||
| def __init__(self, color: Optional[str] = None) -> None: | |||||
| def __init__(self, color: Optional[str] = None): | |||||
| super().__init__() | super().__init__() | ||||
| """Initialize callback handler.""" | """Initialize callback handler.""" | ||||
| # use a specific color is not specified | # use a specific color is not specified | ||||
| self, | self, | ||||
| tool_name: str, | tool_name: str, | ||||
| tool_inputs: Mapping[str, Any], | tool_inputs: Mapping[str, Any], | ||||
| ) -> None: | |||||
| ): | |||||
| """Do nothing.""" | """Do nothing.""" | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) | print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) | ||||
| message_id: Optional[str] = None, | message_id: Optional[str] = None, | ||||
| timer: Optional[Any] = None, | timer: Optional[Any] = None, | ||||
| trace_manager: Optional[TraceQueueManager] = None, | trace_manager: Optional[TraceQueueManager] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """If not the final action, print out observation.""" | """If not the final action, print out observation.""" | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| print_text("\n[on_tool_end]\n", color=self.color) | print_text("\n[on_tool_end]\n", color=self.color) | ||||
| ) | ) | ||||
| ) | ) | ||||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: | |||||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): | |||||
| """Do nothing.""" | """Do nothing.""" | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") | print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") | ||||
| def on_agent_start(self, thought: str) -> None: | |||||
| def on_agent_start(self, thought: str): | |||||
| """Run on agent start.""" | """Run on agent start.""" | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| if thought: | if thought: | ||||
| else: | else: | ||||
| print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) | print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) | ||||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: | |||||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): | |||||
| """Run on agent end.""" | """Run on agent end.""" | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) | print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) |
| def __init__( | def __init__( | ||||
| self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom | self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom | ||||
| ) -> None: | |||||
| ): | |||||
| self._queue_manager = queue_manager | self._queue_manager = queue_manager | ||||
| self._app_id = app_id | self._app_id = app_id | ||||
| self._message_id = message_id | self._message_id = message_id | ||||
| self._user_id = user_id | self._user_id = user_id | ||||
| self._invoke_from = invoke_from | self._invoke_from = invoke_from | ||||
| def on_query(self, query: str, dataset_id: str) -> None: | |||||
| def on_query(self, query: str, dataset_id: str): | |||||
| """ | """ | ||||
| Handle query. | Handle query. | ||||
| """ | """ | ||||
| db.session.add(dataset_query) | db.session.add(dataset_query) | ||||
| db.session.commit() | db.session.commit() | ||||
| def on_tool_end(self, documents: list[Document]) -> None: | |||||
| def on_tool_end(self, documents: list[Document]): | |||||
| """Handle tool end.""" | """Handle tool end.""" | ||||
| for document in documents: | for document in documents: | ||||
| if document.metadata is not None: | if document.metadata is not None: |
| icon_large: Optional[I18nObject] = None | icon_large: Optional[I18nObject] = None | ||||
| supported_model_types: list[ModelType] | supported_model_types: list[ModelType] | ||||
| def __init__(self, provider_entity: ProviderEntity) -> None: | |||||
| def __init__(self, provider_entity: ProviderEntity): | |||||
| """ | """ | ||||
| Init simple provider. | Init simple provider. | ||||
| load_balancing_enabled: bool = False | load_balancing_enabled: bool = False | ||||
| has_invalid_load_balancing_configs: bool = False | has_invalid_load_balancing_configs: bool = False | ||||
| def raise_for_status(self) -> None: | |||||
| def raise_for_status(self): | |||||
| """ | """ | ||||
| Check model status and raise ValueError if not active. | Check model status and raise ValueError if not active. | ||||
| else [], | else [], | ||||
| ) | ) | ||||
| def validate_provider_credentials( | |||||
| self, credentials: dict, credential_id: str = "", session: Session | None = None | |||||
| ) -> dict: | |||||
| def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): | |||||
| """ | """ | ||||
| Validate custom credentials. | Validate custom credentials. | ||||
| :param credentials: provider credentials | :param credentials: provider credentials | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| def _validate(s: Session) -> dict: | |||||
| def _validate(s: Session): | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self.extract_secret_variables( | provider_credential_secret_variables = self.extract_secret_variables( | ||||
| self.provider.provider_credential_schema.credential_form_schemas | self.provider.provider_credential_schema.credential_form_schemas | ||||
| logger.warning("Error generating next credential name: %s", str(e)) | logger.warning("Error generating next credential name: %s", str(e)) | ||||
| return "API KEY 1" | return "API KEY 1" | ||||
| def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: | |||||
| def create_provider_credential(self, credentials: dict, credential_name: str | None): | |||||
| """ | """ | ||||
| Add custom provider credentials. | Add custom provider credentials. | ||||
| :param credentials: provider credentials | :param credentials: provider credentials | ||||
| credentials: dict, | credentials: dict, | ||||
| credential_id: str, | credential_id: str, | ||||
| credential_name: str | None, | credential_name: str | None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| update a saved provider credential (by credential_id). | update a saved provider credential (by credential_id). | ||||
| credential_record: ProviderCredential | ProviderModelCredential, | credential_record: ProviderCredential | ProviderModelCredential, | ||||
| credential_source: str, | credential_source: str, | ||||
| session: Session, | session: Session, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Update load balancing configurations that reference the given credential_id. | Update load balancing configurations that reference the given credential_id. | ||||
| session.commit() | session.commit() | ||||
| def delete_provider_credential(self, credential_id: str) -> None: | |||||
| def delete_provider_credential(self, credential_id: str): | |||||
| """ | """ | ||||
| Delete a saved provider credential (by credential_id). | Delete a saved provider credential (by credential_id). | ||||
| session.rollback() | session.rollback() | ||||
| raise | raise | ||||
| def switch_active_provider_credential(self, credential_id: str) -> None: | |||||
| def switch_active_provider_credential(self, credential_id: str): | |||||
| """ | """ | ||||
| Switch active provider credential (copy the selected one into current active snapshot). | Switch active provider credential (copy the selected one into current active snapshot). | ||||
| credentials: dict, | credentials: dict, | ||||
| credential_id: str = "", | credential_id: str = "", | ||||
| session: Session | None = None, | session: Session | None = None, | ||||
| ) -> dict: | |||||
| ): | |||||
| """ | """ | ||||
| Validate custom model credentials. | Validate custom model credentials. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| def _validate(s: Session) -> dict: | |||||
| def _validate(s: Session): | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self.extract_secret_variables( | provider_credential_secret_variables = self.extract_secret_variables( | ||||
| self.provider.model_credential_schema.credential_form_schemas | self.provider.model_credential_schema.credential_form_schemas | ||||
| session.rollback() | session.rollback() | ||||
| raise | raise | ||||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||||
| """ | """ | ||||
| Delete a saved provider credential (by credential_id). | Delete a saved provider credential (by credential_id). | ||||
| session.rollback() | session.rollback() | ||||
| raise | raise | ||||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str): | |||||
| """ | """ | ||||
| if model list exist this custom model, switch the custom model credential. | if model list exist this custom model, switch the custom model credential. | ||||
| if model list not exist this custom model, use the credential to add a new custom model record. | if model list not exist this custom model, use the credential to add a new custom model record. | ||||
| session.add(provider_model_record) | session.add(provider_model_record) | ||||
| session.commit() | session.commit() | ||||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||||
| """ | """ | ||||
| switch the custom model credential. | switch the custom model credential. | ||||
| session.add(provider_model_record) | session.add(provider_model_record) | ||||
| session.commit() | session.commit() | ||||
| def delete_custom_model(self, model_type: ModelType, model: str) -> None: | |||||
| def delete_custom_model(self, model_type: ModelType, model: str): | |||||
| """ | """ | ||||
| Delete custom model. | Delete custom model. | ||||
| :param model_type: model type | :param model_type: model type | ||||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | ||||
| ) | ) | ||||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: | |||||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): | |||||
| """ | """ | ||||
| Switch preferred provider type. | Switch preferred provider type. | ||||
| :param provider_type: | :param provider_type: | ||||
| if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | ||||
| return | return | ||||
| def _switch(s: Session) -> None: | |||||
| def _switch(s: Session): | |||||
| # get preferred provider | # get preferred provider | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | model_provider_id = ModelProviderID(self.provider.provider) | ||||
| provider_names = [self.provider.provider] | provider_names = [self.provider.provider] | ||||
| return secret_input_form_variables | return secret_input_form_variables | ||||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: | |||||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): | |||||
| """ | """ | ||||
| Obfuscated credentials. | Obfuscated credentials. | ||||
| description: Optional[str] = None | description: Optional[str] = None | ||||
| def __init__(self, description: Optional[str] = None) -> None: | |||||
| def __init__(self, description: Optional[str] = None): | |||||
| self.description = description | self.description = description | ||||
| timeout: tuple[int, int] = (5, 60) | timeout: tuple[int, int] = (5, 60) | ||||
| """timeout for request connect and read""" | """timeout for request connect and read""" | ||||
| def __init__(self, api_endpoint: str, api_key: str) -> None: | |||||
| def __init__(self, api_endpoint: str, api_key: str): | |||||
| self.api_endpoint = api_endpoint | self.api_endpoint = api_endpoint | ||||
| self.api_key = api_key | self.api_key = api_key | ||||
| def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: | |||||
| def request(self, point: APIBasedExtensionPoint, params: dict): | |||||
| """ | """ | ||||
| Request the api. | Request the api. | ||||
| tenant_id: str | tenant_id: str | ||||
| config: Optional[dict] = None | config: Optional[dict] = None | ||||
| def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: | |||||
| def __init__(self, tenant_id: str, config: Optional[dict] = None): | |||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.config = config | self.config = config | ||||
| """the unique name of external data tool""" | """the unique name of external data tool""" | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| variable: str | variable: str | ||||
| """the tool variable name of app tool""" | """the tool variable name of app tool""" | ||||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: | |||||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): | |||||
| super().__init__(tenant_id, config) | super().__init__(tenant_id, config) | ||||
| self.app_id = app_id | self.app_id = app_id | ||||
| self.variable = variable | self.variable = variable | ||||
| @classmethod | @classmethod | ||||
| @abstractmethod | @abstractmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| class ExternalDataToolFactory: | class ExternalDataToolFactory: | ||||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: | |||||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): | |||||
| extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) | extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) | ||||
| self.__extension_instance = extension_class( | self.__extension_instance = extension_class( | ||||
| tenant_id=tenant_id, app_id=app_id, variable=variable, config=config | tenant_id=tenant_id, app_id=app_id, variable=variable, config=config | ||||
| ) | ) | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, name: str, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None | _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None | ||||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: | |||||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]): | |||||
| global _tool_file_manager_factory | global _tool_file_manager_factory | ||||
| _tool_file_manager_factory = factory | _tool_file_manager_factory = factory |
| pass | pass | ||||
| @classmethod | @classmethod | ||||
| def get_default_config(cls) -> dict: | |||||
| def get_default_config(cls): | |||||
| return { | return { | ||||
| "type": "code", | "type": "code", | ||||
| "config": { | "config": { |
| class Jinja2TemplateTransformer(TemplateTransformer): | class Jinja2TemplateTransformer(TemplateTransformer): | ||||
| @classmethod | @classmethod | ||||
| def transform_response(cls, response: str) -> dict: | |||||
| def transform_response(cls, response: str): | |||||
| """ | """ | ||||
| Transform response to dict | Transform response to dict | ||||
| :param response: response | :param response: response |
| def get_default_code(cls) -> str: | def get_default_code(cls) -> str: | ||||
| return dedent( | return dedent( | ||||
| """ | """ | ||||
| def main(arg1: str, arg2: str) -> dict: | |||||
| def main(arg1: str, arg2: str): | |||||
| return { | return { | ||||
| "result": arg1 + arg2, | "result": arg1 + arg2, | ||||
| } | } |
| else: | else: | ||||
| return None | return None | ||||
| def set(self, credentials: dict) -> None: | |||||
| def set(self, credentials: dict): | |||||
| """ | """ | ||||
| Cache model provider credentials. | Cache model provider credentials. | ||||
| """ | """ | ||||
| redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | ||||
| def delete(self) -> None: | |||||
| def delete(self): | |||||
| """ | """ | ||||
| Delete cached model provider credentials. | Delete cached model provider credentials. | ||||
| return None | return None | ||||
| return None | return None | ||||
| def set(self, config: dict[str, Any]) -> None: | |||||
| def set(self, config: dict[str, Any]): | |||||
| """Cache provider credentials""" | """Cache provider credentials""" | ||||
| redis_client.setex(self.cache_key, 86400, json.dumps(config)) | redis_client.setex(self.cache_key, 86400, json.dumps(config)) | ||||
| def delete(self) -> None: | |||||
| def delete(self): | |||||
| """Delete cached provider credentials""" | """Delete cached provider credentials""" | ||||
| redis_client.delete(self.cache_key) | redis_client.delete(self.cache_key) | ||||
| """Get cached provider credentials""" | """Get cached provider credentials""" | ||||
| return None | return None | ||||
| def set(self, config: dict[str, Any]) -> None: | |||||
| def set(self, config: dict[str, Any]): | |||||
| """Cache provider credentials""" | """Cache provider credentials""" | ||||
| pass | pass | ||||
| def delete(self) -> None: | |||||
| def delete(self): | |||||
| """Delete cached provider credentials""" | """Delete cached provider credentials""" | ||||
| pass | pass |
| else: | else: | ||||
| return None | return None | ||||
| def set(self, parameters: dict) -> None: | |||||
| def set(self, parameters: dict): | |||||
| """Cache model provider credentials.""" | """Cache model provider credentials.""" | ||||
| redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) | redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) | ||||
| def delete(self) -> None: | |||||
| def delete(self): | |||||
| """ | """ | ||||
| Delete cached model provider credentials. | Delete cached model provider credentials. | ||||
| return None | return None | ||||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: | |||||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]): | |||||
| """ | """ | ||||
| Extract 'external_trace_id' from args. | Extract 'external_trace_id' from args. | ||||
| provider_map: dict[str, HostingProvider] | provider_map: dict[str, HostingProvider] | ||||
| moderation_config: Optional[HostedModerationConfig] = None | moderation_config: Optional[HostedModerationConfig] = None | ||||
| def __init__(self) -> None: | |||||
| def __init__(self): | |||||
| self.provider_map = {} | self.provider_map = {} | ||||
| self.moderation_config = None | self.moderation_config = None | ||||
| def init_app(self, app: Flask) -> None: | |||||
| def init_app(self, app: Flask): | |||||
| if dify_config.EDITION != "CLOUD": | if dify_config.EDITION != "CLOUD": | ||||
| return | return | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| dataset_document: DatasetDocument, | dataset_document: DatasetDocument, | ||||
| documents: list[Document], | documents: list[Document], | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| insert index and update document/segment status to completed | insert index and update document/segment status to completed | ||||
| """ | """ | ||||
| @staticmethod | @staticmethod | ||||
| def _update_document_index_status( | def _update_document_index_status( | ||||
| document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None | document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Update the document indexing status. | Update the document indexing status. | ||||
| """ | """ | ||||
| db.session.commit() | db.session.commit() | ||||
| @staticmethod | @staticmethod | ||||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: | |||||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict): | |||||
| """ | """ | ||||
| Update the document segment by document id. | Update the document segment by document id. | ||||
| """ | """ |
| return questions | return questions | ||||
| @classmethod | @classmethod | ||||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: | |||||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): | |||||
| output_parser = RuleConfigGeneratorOutputParser() | output_parser = RuleConfigGeneratorOutputParser() | ||||
| error = "" | error = "" | ||||
| return rule_config | return rule_config | ||||
| @classmethod | @classmethod | ||||
| def generate_code( | |||||
| cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" | |||||
| ) -> dict: | |||||
| def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): | |||||
| if code_language == "python": | if code_language == "python": | ||||
| prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) | prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) | ||||
| else: | else: | ||||
| @staticmethod | @staticmethod | ||||
| def instruction_modify_legacy( | def instruction_modify_legacy( | ||||
| tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None | tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None | ||||
| ) -> dict: | |||||
| ): | |||||
| last_run: Message | None = ( | last_run: Message | None = ( | ||||
| db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() | db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() | ||||
| ) | ) | ||||
| instruction: str, | instruction: str, | ||||
| model_config: dict, | model_config: dict, | ||||
| ideal_output: str | None, | ideal_output: str | None, | ||||
| ) -> dict: | |||||
| ): | |||||
| from services.workflow_service import WorkflowService | from services.workflow_service import WorkflowService | ||||
| app: App | None = db.session.query(App).where(App.id == flow_id).first() | app: App | None = db.session.query(App).where(App.id == flow_id).first() | ||||
| return [] | return [] | ||||
| parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) | parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) | ||||
| def dict_of_event(event: AgentLogEvent) -> dict: | |||||
| def dict_of_event(event: AgentLogEvent): | |||||
| return { | return { | ||||
| "status": event.status, | "status": event.status, | ||||
| "error": event.error, | "error": event.error, | ||||
| instruction: str, | instruction: str, | ||||
| node_type: str, | node_type: str, | ||||
| ideal_output: str | None, | ideal_output: str | None, | ||||
| ) -> dict: | |||||
| ): | |||||
| LAST_RUN = "{{#last_run#}}" | LAST_RUN = "{{#last_run#}}" | ||||
| CURRENT = "{{#current#}}" | CURRENT = "{{#current#}}" | ||||
| ERROR_MESSAGE = "{{#error_message#}}" | ERROR_MESSAGE = "{{#error_message#}}" |
| from typing import Any | |||||
| from core.llm_generator.output_parser.errors import OutputParserError | from core.llm_generator.output_parser.errors import OutputParserError | ||||
| from core.llm_generator.prompts import ( | from core.llm_generator.prompts import ( | ||||
| RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, | RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, | ||||
| RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, | RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, | ||||
| ) | ) | ||||
| def parse(self, text: str) -> Any: | |||||
| def parse(self, text: str): | |||||
| try: | try: | ||||
| expected_keys = ["prompt", "variables", "opening_statement"] | expected_keys = ["prompt", "variables", "opening_statement"] | ||||
| parsed = parse_and_check_json_markdown(text, expected_keys) | parsed = parse_and_check_json_markdown(text, expected_keys) |
| structured_output_schema: Mapping, | structured_output_schema: Mapping, | ||||
| model_parameters: dict, | model_parameters: dict, | ||||
| rules: list[ParameterRule], | rules: list[ParameterRule], | ||||
| ) -> dict: | |||||
| ): | |||||
| """ | """ | ||||
| Handle structured output for models with native JSON schema support. | Handle structured output for models with native JSON schema support. | ||||
| return model_parameters | return model_parameters | ||||
| def _set_response_format(model_parameters: dict, rules: list) -> None: | |||||
| def _set_response_format(model_parameters: dict, rules: list): | |||||
| """ | """ | ||||
| Set the appropriate response format parameter based on model rules. | Set the appropriate response format parameter based on model rules. | ||||
| return structured_output | return structured_output | ||||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: | |||||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping): | |||||
| """ | """ | ||||
| Prepare JSON schema based on model requirements. | Prepare JSON schema based on model requirements. | ||||
| return {"schema": processed_schema, "name": "llm_response"} | return {"schema": processed_schema, "name": "llm_response"} | ||||
| def remove_additional_properties(schema: dict) -> None: | |||||
| def remove_additional_properties(schema: dict): | |||||
| """ | """ | ||||
| Remove additionalProperties fields from JSON schema. | Remove additionalProperties fields from JSON schema. | ||||
| Used for models like Gemini that don't support this property. | Used for models like Gemini that don't support this property. | ||||
| remove_additional_properties(item) | remove_additional_properties(item) | ||||
| def convert_boolean_to_string(schema: dict) -> None: | |||||
| def convert_boolean_to_string(schema: dict): | |||||
| """ | """ | ||||
| Convert boolean type specifications to string in JSON schema. | Convert boolean type specifications to string in JSON schema. | ||||
| import json | import json | ||||
| import re | import re | ||||
| from typing import Any | |||||
| from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | ||||
| def get_format_instructions(self) -> str: | def get_format_instructions(self) -> str: | ||||
| return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | ||||
| def parse(self, text: str) -> Any: | |||||
| def parse(self, text: str): | |||||
| action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) | action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) | ||||
| if action_match is not None: | if action_match is not None: | ||||
| json_obj = json.loads(action_match.group(0).strip()) | json_obj = json.loads(action_match.group(0).strip()) |
| return None | return None | ||||
| return OAuthClientInformation.model_validate(client_information) | return OAuthClientInformation.model_validate(client_information) | ||||
| def save_client_information(self, client_information: OAuthClientInformationFull) -> None: | |||||
| def save_client_information(self, client_information: OAuthClientInformationFull): | |||||
| """Saves client information after dynamic registration.""" | """Saves client information after dynamic registration.""" | ||||
| MCPToolManageService.update_mcp_provider_credentials( | MCPToolManageService.update_mcp_provider_credentials( | ||||
| self.mcp_provider, | self.mcp_provider, | ||||
| refresh_token=credentials.get("refresh_token", ""), | refresh_token=credentials.get("refresh_token", ""), | ||||
| ) | ) | ||||
| def save_tokens(self, tokens: OAuthTokens) -> None: | |||||
| def save_tokens(self, tokens: OAuthTokens): | |||||
| """Stores new OAuth tokens for the current session.""" | """Stores new OAuth tokens for the current session.""" | ||||
| # update mcp provider credentials | # update mcp provider credentials | ||||
| token_dict = tokens.model_dump() | token_dict = tokens.model_dump() | ||||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) | MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) | ||||
| def save_code_verifier(self, code_verifier: str) -> None: | |||||
| def save_code_verifier(self, code_verifier: str): | |||||
| """Saves a PKCE code verifier for the current session.""" | """Saves a PKCE code verifier for the current session.""" | ||||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) | MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) | ||||
| headers: dict[str, Any] | None = None, | headers: dict[str, Any] | None = None, | ||||
| timeout: float = 5.0, | timeout: float = 5.0, | ||||
| sse_read_timeout: float = 5 * 60, | sse_read_timeout: float = 5 * 60, | ||||
| ) -> None: | |||||
| ): | |||||
| """Initialize the SSE transport. | """Initialize the SSE transport. | ||||
| Args: | Args: | ||||
| return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme | return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme | ||||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: | |||||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue): | |||||
| """Handle an 'endpoint' SSE event. | """Handle an 'endpoint' SSE event. | ||||
| Args: | Args: | ||||
| status_queue.put(_StatusReady(endpoint_url)) | status_queue.put(_StatusReady(endpoint_url)) | ||||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: | |||||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue): | |||||
| """Handle a 'message' SSE event. | """Handle a 'message' SSE event. | ||||
| Args: | Args: | ||||
| logger.exception("Error parsing server message") | logger.exception("Error parsing server message") | ||||
| read_queue.put(exc) | read_queue.put(exc) | ||||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue): | |||||
| """Handle a single SSE event. | """Handle a single SSE event. | ||||
| Args: | Args: | ||||
| case _: | case _: | ||||
| logger.warning("Unknown SSE event: %s", sse.event) | logger.warning("Unknown SSE event: %s", sse.event) | ||||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue): | |||||
| """Read and process SSE events. | """Read and process SSE events. | ||||
| Args: | Args: | ||||
| finally: | finally: | ||||
| read_queue.put(None) | read_queue.put(None) | ||||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: | |||||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage): | |||||
| """Send a single message to the server. | """Send a single message to the server. | ||||
| Args: | Args: | ||||
| response.raise_for_status() | response.raise_for_status() | ||||
| logger.debug("Client message sent successfully: %s", response.status_code) | logger.debug("Client message sent successfully: %s", response.status_code) | ||||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: | |||||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue): | |||||
| """Handle writing messages to the server. | """Handle writing messages to the server. | ||||
| Args: | Args: | ||||
| write_queue.put(None) | write_queue.put(None) | ||||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: | |||||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage): | |||||
| """ | """ | ||||
| Send a message to the server using the provided HTTP client. | Send a message to the server using the provided HTTP client. | ||||
| headers: dict[str, Any] | None = None, | headers: dict[str, Any] | None = None, | ||||
| timeout: float | timedelta = 30, | timeout: float | timedelta = 30, | ||||
| sse_read_timeout: float | timedelta = 60 * 5, | sse_read_timeout: float | timedelta = 60 * 5, | ||||
| ) -> None: | |||||
| ): | |||||
| """Initialize the StreamableHTTP transport. | """Initialize the StreamableHTTP transport. | ||||
| Args: | Args: | ||||
| def _maybe_extract_session_id_from_response( | def _maybe_extract_session_id_from_response( | ||||
| self, | self, | ||||
| response: httpx.Response, | response: httpx.Response, | ||||
| ) -> None: | |||||
| ): | |||||
| """Extract and store session ID from response headers.""" | """Extract and store session ID from response headers.""" | ||||
| new_session_id = response.headers.get(MCP_SESSION_ID) | new_session_id = response.headers.get(MCP_SESSION_ID) | ||||
| if new_session_id: | if new_session_id: | ||||
| self, | self, | ||||
| client: httpx.Client, | client: httpx.Client, | ||||
| server_to_client_queue: ServerToClientQueue, | server_to_client_queue: ServerToClientQueue, | ||||
| ) -> None: | |||||
| ): | |||||
| """Handle GET stream for server-initiated messages.""" | """Handle GET stream for server-initiated messages.""" | ||||
| try: | try: | ||||
| if not self.session_id: | if not self.session_id: | ||||
| except Exception as exc: | except Exception as exc: | ||||
| logger.debug("GET stream error (non-fatal): %s", exc) | logger.debug("GET stream error (non-fatal): %s", exc) | ||||
| def _handle_resumption_request(self, ctx: RequestContext) -> None: | |||||
| def _handle_resumption_request(self, ctx: RequestContext): | |||||
| """Handle a resumption request using GET with SSE.""" | """Handle a resumption request using GET with SSE.""" | ||||
| headers = self._update_headers_with_session(ctx.headers) | headers = self._update_headers_with_session(ctx.headers) | ||||
| if ctx.metadata and ctx.metadata.resumption_token: | if ctx.metadata and ctx.metadata.resumption_token: | ||||
| if is_complete: | if is_complete: | ||||
| break | break | ||||
| def _handle_post_request(self, ctx: RequestContext) -> None: | |||||
| def _handle_post_request(self, ctx: RequestContext): | |||||
| """Handle a POST request with response processing.""" | """Handle a POST request with response processing.""" | ||||
| headers = self._update_headers_with_session(ctx.headers) | headers = self._update_headers_with_session(ctx.headers) | ||||
| message = ctx.session_message.message | message = ctx.session_message.message | ||||
| self, | self, | ||||
| response: httpx.Response, | response: httpx.Response, | ||||
| server_to_client_queue: ServerToClientQueue, | server_to_client_queue: ServerToClientQueue, | ||||
| ) -> None: | |||||
| ): | |||||
| """Handle JSON response from the server.""" | """Handle JSON response from the server.""" | ||||
| try: | try: | ||||
| content = response.read() | content = response.read() | ||||
| except Exception as exc: | except Exception as exc: | ||||
| server_to_client_queue.put(exc) | server_to_client_queue.put(exc) | ||||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: | |||||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): | |||||
| """Handle SSE response from the server.""" | """Handle SSE response from the server.""" | ||||
| try: | try: | ||||
| event_source = EventSource(response) | event_source = EventSource(response) | ||||
| self, | self, | ||||
| content_type: str, | content_type: str, | ||||
| server_to_client_queue: ServerToClientQueue, | server_to_client_queue: ServerToClientQueue, | ||||
| ) -> None: | |||||
| ): | |||||
| """Handle unexpected content type in response.""" | """Handle unexpected content type in response.""" | ||||
| error_msg = f"Unexpected content type: {content_type}" | error_msg = f"Unexpected content type: {content_type}" | ||||
| logger.error(error_msg) | logger.error(error_msg) | ||||
| self, | self, | ||||
| server_to_client_queue: ServerToClientQueue, | server_to_client_queue: ServerToClientQueue, | ||||
| request_id: RequestId, | request_id: RequestId, | ||||
| ) -> None: | |||||
| ): | |||||
| """Send a session terminated error response.""" | """Send a session terminated error response.""" | ||||
| jsonrpc_error = JSONRPCError( | jsonrpc_error = JSONRPCError( | ||||
| jsonrpc="2.0", | jsonrpc="2.0", | ||||
| client_to_server_queue: ClientToServerQueue, | client_to_server_queue: ClientToServerQueue, | ||||
| server_to_client_queue: ServerToClientQueue, | server_to_client_queue: ServerToClientQueue, | ||||
| start_get_stream: Callable[[], None], | start_get_stream: Callable[[], None], | ||||
| ) -> None: | |||||
| ): | |||||
| """Handle writing requests to the server. | """Handle writing requests to the server. | ||||
| This method processes messages from the client_to_server_queue and sends them to the server. | This method processes messages from the client_to_server_queue and sends them to the server. | ||||
| except Exception as exc: | except Exception as exc: | ||||
| server_to_client_queue.put(exc) | server_to_client_queue.put(exc) | ||||
| def terminate_session(self, client: httpx.Client) -> None: | |||||
| def terminate_session(self, client: httpx.Client): | |||||
| """Terminate the session by sending a DELETE request.""" | """Terminate the session by sending a DELETE request.""" | ||||
| if not self.session_id: | if not self.session_id: | ||||
| return | return | ||||
| timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | ||||
| ) as client: | ) as client: | ||||
| # Define callbacks that need access to thread pool | # Define callbacks that need access to thread pool | ||||
| def start_get_stream() -> None: | |||||
| def start_get_stream(): | |||||
| """Start a worker thread to handle server-initiated messages.""" | """Start a worker thread to handle server-initiated messages.""" | ||||
| executor.submit(transport.handle_get_stream, client, server_to_client_queue) | executor.submit(transport.handle_get_stream, client, server_to_client_queue) | ||||
| ReceiveNotificationT | ReceiveNotificationT | ||||
| ]""", | ]""", | ||||
| on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], | on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], | ||||
| ) -> None: | |||||
| ): | |||||
| self.request_id = request_id | self.request_id = request_id | ||||
| self.request_meta = request_meta | self.request_meta = request_meta | ||||
| self.request = request | self.request = request | ||||
| exc_type: type[BaseException] | None, | exc_type: type[BaseException] | None, | ||||
| exc_val: BaseException | None, | exc_val: BaseException | None, | ||||
| exc_tb: TracebackType | None, | exc_tb: TracebackType | None, | ||||
| ) -> None: | |||||
| ): | |||||
| """Exit the context manager, performing cleanup and notifying completion.""" | """Exit the context manager, performing cleanup and notifying completion.""" | ||||
| try: | try: | ||||
| if self._completed: | if self._completed: | ||||
| finally: | finally: | ||||
| self._entered = False | self._entered = False | ||||
| def respond(self, response: SendResultT | ErrorData) -> None: | |||||
| def respond(self, response: SendResultT | ErrorData): | |||||
| """Send a response for this request. | """Send a response for this request. | ||||
| Must be called within a context manager block. | Must be called within a context manager block. | ||||
| self._session._send_response(request_id=self.request_id, response=response) | self._session._send_response(request_id=self.request_id, response=response) | ||||
| def cancel(self) -> None: | |||||
| def cancel(self): | |||||
| """Cancel this request and mark it as completed.""" | """Cancel this request and mark it as completed.""" | ||||
| if not self._entered: | if not self._entered: | ||||
| raise RuntimeError("RequestResponder must be used as a context manager") | raise RuntimeError("RequestResponder must be used as a context manager") | ||||
| receive_notification_type: type[ReceiveNotificationT], | receive_notification_type: type[ReceiveNotificationT], | ||||
| # If none, reading will never time out | # If none, reading will never time out | ||||
| read_timeout_seconds: timedelta | None = None, | read_timeout_seconds: timedelta | None = None, | ||||
| ) -> None: | |||||
| ): | |||||
| self._read_stream = read_stream | self._read_stream = read_stream | ||||
| self._write_stream = write_stream | self._write_stream = write_stream | ||||
| self._response_streams = {} | self._response_streams = {} | ||||
| self._receiver_future = self._executor.submit(self._receive_loop) | self._receiver_future = self._executor.submit(self._receive_loop) | ||||
| return self | return self | ||||
| def check_receiver_status(self) -> None: | |||||
| def check_receiver_status(self): | |||||
| """`check_receiver_status` ensures that any exceptions raised during the | """`check_receiver_status` ensures that any exceptions raised during the | ||||
| execution of `_receive_loop` are retrieved and propagated.""" | execution of `_receive_loop` are retrieved and propagated.""" | ||||
| if self._receiver_future and self._receiver_future.done(): | if self._receiver_future and self._receiver_future.done(): | ||||
| def __exit__( | def __exit__( | ||||
| self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | ||||
| ) -> None: | |||||
| ): | |||||
| self._read_stream.put(None) | self._read_stream.put(None) | ||||
| self._write_stream.put(None) | self._write_stream.put(None) | ||||
| self, | self, | ||||
| notification: SendNotificationT, | notification: SendNotificationT, | ||||
| related_request_id: RequestId | None = None, | related_request_id: RequestId | None = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Emits a notification, which is a one-way message that does not expect | Emits a notification, which is a one-way message that does not expect | ||||
| a response. | a response. | ||||
| ) | ) | ||||
| self._write_stream.put(session_message) | self._write_stream.put(session_message) | ||||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: | |||||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData): | |||||
| if isinstance(response, ErrorData): | if isinstance(response, ErrorData): | ||||
| jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) | jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) | ||||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) | session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) | ||||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) | session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) | ||||
| self._write_stream.put(session_message) | self._write_stream.put(session_message) | ||||
| def _receive_loop(self) -> None: | |||||
| def _receive_loop(self): | |||||
| """ | """ | ||||
| Main message processing loop. | Main message processing loop. | ||||
| In a real synchronous implementation, this would likely run in a separate thread. | In a real synchronous implementation, this would likely run in a separate thread. | ||||
| logger.exception("Error in message processing loop") | logger.exception("Error in message processing loop") | ||||
| raise | raise | ||||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: | |||||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]): | |||||
| """ | """ | ||||
| Can be overridden by subclasses to handle a request without needing to | Can be overridden by subclasses to handle a request without needing to | ||||
| listen on the message stream. | listen on the message stream. | ||||
| forwarded on to the message stream. | forwarded on to the message stream. | ||||
| """ | """ | ||||
| def _received_notification(self, notification: ReceiveNotificationT) -> None: | |||||
| def _received_notification(self, notification: ReceiveNotificationT): | |||||
| """ | """ | ||||
| Can be overridden by subclasses to handle a notification without needing | Can be overridden by subclasses to handle a notification without needing | ||||
| to listen on the message stream. | to listen on the message stream. | ||||
| """ | """ | ||||
| def send_progress_notification( | |||||
| self, progress_token: str | int, progress: float, total: float | None = None | |||||
| ) -> None: | |||||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||||
| """ | """ | ||||
| Sends a progress notification for a request that is currently being | Sends a progress notification for a request that is currently being | ||||
| processed. | processed. | ||||
| def _handle_incoming( | def _handle_incoming( | ||||
| self, | self, | ||||
| req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | ||||
| ) -> None: | |||||
| ): | |||||
| """A generic handler for incoming messages. Overwritten by subclasses.""" | """A generic handler for incoming messages. Overwritten by subclasses.""" |
| def __call__( | def __call__( | ||||
| self, | self, | ||||
| params: types.LoggingMessageNotificationParams, | params: types.LoggingMessageNotificationParams, | ||||
| ) -> None: ... | |||||
| ): ... | |||||
| class MessageHandlerFnT(Protocol): | class MessageHandlerFnT(Protocol): | ||||
| def __call__( | def __call__( | ||||
| self, | self, | ||||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | ||||
| ) -> None: ... | |||||
| ): ... | |||||
| def _default_message_handler( | def _default_message_handler( | ||||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | ||||
| ) -> None: | |||||
| ): | |||||
| if isinstance(message, Exception): | if isinstance(message, Exception): | ||||
| raise ValueError(str(message)) | raise ValueError(str(message)) | ||||
| elif isinstance(message, (types.ServerNotification | RequestResponder)): | elif isinstance(message, (types.ServerNotification | RequestResponder)): | ||||
| def _default_logging_callback( | def _default_logging_callback( | ||||
| params: types.LoggingMessageNotificationParams, | params: types.LoggingMessageNotificationParams, | ||||
| ) -> None: | |||||
| ): | |||||
| pass | pass | ||||
| logging_callback: LoggingFnT | None = None, | logging_callback: LoggingFnT | None = None, | ||||
| message_handler: MessageHandlerFnT | None = None, | message_handler: MessageHandlerFnT | None = None, | ||||
| client_info: types.Implementation | None = None, | client_info: types.Implementation | None = None, | ||||
| ) -> None: | |||||
| ): | |||||
| super().__init__( | super().__init__( | ||||
| read_stream, | read_stream, | ||||
| write_stream, | write_stream, | ||||
| types.EmptyResult, | types.EmptyResult, | ||||
| ) | ) | ||||
| def send_progress_notification( | |||||
| self, progress_token: str | int, progress: float, total: float | None = None | |||||
| ) -> None: | |||||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||||
| """Send a progress notification.""" | """Send a progress notification.""" | ||||
| self.send_notification( | self.send_notification( | ||||
| types.ClientNotification( | types.ClientNotification( | ||||
| types.ListToolsResult, | types.ListToolsResult, | ||||
| ) | ) | ||||
| def send_roots_list_changed(self) -> None: | |||||
| def send_roots_list_changed(self): | |||||
| """Send a roots/list_changed notification.""" | """Send a roots/list_changed notification.""" | ||||
| self.send_notification( | self.send_notification( | ||||
| types.ClientNotification( | types.ClientNotification( | ||||
| ) | ) | ||||
| ) | ) | ||||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: | |||||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): | |||||
| ctx = RequestContext[ClientSession, Any]( | ctx = RequestContext[ClientSession, Any]( | ||||
| request_id=responder.request_id, | request_id=responder.request_id, | ||||
| meta=responder.request_meta, | meta=responder.request_meta, | ||||
| def _handle_incoming( | def _handle_incoming( | ||||
| self, | self, | ||||
| req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | ||||
| ) -> None: | |||||
| ): | |||||
| """Handle incoming messages by forwarding to the message handler.""" | """Handle incoming messages by forwarding to the message handler.""" | ||||
| self._message_handler(req) | self._message_handler(req) | ||||
| def _received_notification(self, notification: types.ServerNotification) -> None: | |||||
| def _received_notification(self, notification: types.ServerNotification): | |||||
| """Handle notifications from the server.""" | """Handle notifications from the server.""" | ||||
| # Process specific notification types | # Process specific notification types | ||||
| match notification.root: | match notification.root: |
| self, | self, | ||||
| conversation: Conversation, | conversation: Conversation, | ||||
| model_instance: ModelInstance, | model_instance: ModelInstance, | ||||
| ) -> None: | |||||
| ): | |||||
| self.conversation = conversation | self.conversation = conversation | ||||
| self.model_instance = model_instance | self.model_instance = model_instance | ||||
| Model instance class | Model instance class | ||||
| """ | """ | ||||
| def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: | |||||
| def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): | |||||
| self.provider_model_bundle = provider_model_bundle | self.provider_model_bundle = provider_model_bundle | ||||
| self.model = model | self.model = model | ||||
| self.provider = provider_model_bundle.configuration.provider.provider | self.provider = provider_model_bundle.configuration.provider.provider | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: | |||||
| def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): | |||||
| """ | """ | ||||
| Fetch credentials from provider model bundle | Fetch credentials from provider model bundle | ||||
| :param provider_model_bundle: provider model bundle | :param provider_model_bundle: provider model bundle | ||||
| ), | ), | ||||
| ) | ) | ||||
| def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: | |||||
| def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): | |||||
| """ | """ | ||||
| Round-robin invoke | Round-robin invoke | ||||
| :param function: function to invoke | :param function: function to invoke | ||||
| except Exception as e: | except Exception as e: | ||||
| raise e | raise e | ||||
| def get_tts_voices(self, language: Optional[str] = None) -> list: | |||||
| def get_tts_voices(self, language: Optional[str] = None): | |||||
| """ | """ | ||||
| Invoke large language tts model voices | Invoke large language tts model voices | ||||
| class ModelManager: | class ModelManager: | ||||
| def __init__(self) -> None: | |||||
| def __init__(self): | |||||
| self._provider_manager = ProviderManager() | self._provider_manager = ProviderManager() | ||||
| def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: | def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: | ||||
| model: str, | model: str, | ||||
| load_balancing_configs: list[ModelLoadBalancingConfiguration], | load_balancing_configs: list[ModelLoadBalancingConfiguration], | ||||
| managed_credentials: Optional[dict] = None, | managed_credentials: Optional[dict] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Load balancing model manager | Load balancing model manager | ||||
| :param tenant_id: tenant_id | :param tenant_id: tenant_id | ||||
| return config | return config | ||||
| def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: | |||||
| def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60): | |||||
| """ | """ | ||||
| Cooldown model load balancing config | Cooldown model load balancing config | ||||
| :param config: model load balancing config | :param config: model load balancing config |
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Before invoke callback | Before invoke callback | ||||
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| After invoke callback | After invoke callback | ||||
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Invoke error callback | Invoke error callback | ||||
| """ | """ | ||||
| raise NotImplementedError() | raise NotImplementedError() | ||||
| def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: | |||||
| def print_text(self, text: str, color: Optional[str] = None, end: str = ""): | |||||
| """Print text with highlighting and no end characters.""" | """Print text with highlighting and no end characters.""" | ||||
| text_to_print = self._get_colored_text(text, color) if color else text | text_to_print = self._get_colored_text(text, color) if color else text | ||||
| print(text_to_print, end=end) | print(text_to_print, end=end) |
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Before invoke callback | Before invoke callback | ||||
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| After invoke callback | After invoke callback | ||||
| stop: Optional[Sequence[str]] = None, | stop: Optional[Sequence[str]] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Invoke error callback | Invoke error callback | ||||
| description: Optional[str] = None | description: Optional[str] = None | ||||
| def __init__(self, description: Optional[str] = None) -> None: | |||||
| def __init__(self, description: Optional[str] = None): | |||||
| self.description = description | self.description = description | ||||
| def __str__(self): | def __str__(self): |
| """ | """ | ||||
| return None | return None | ||||
| def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: | |||||
| def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): | |||||
| """ | """ | ||||
| Get default parameter rule for given name | Get default parameter rule for given name | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| callbacks: Optional[list[Callback]] = None, | callbacks: Optional[list[Callback]] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Trigger before invoke callbacks | Trigger before invoke callbacks | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| callbacks: Optional[list[Callback]] = None, | callbacks: Optional[list[Callback]] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Trigger new chunk callbacks | Trigger new chunk callbacks | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| callbacks: Optional[list[Callback]] = None, | callbacks: Optional[list[Callback]] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Trigger after invoke callbacks | Trigger after invoke callbacks | ||||
| stream: bool = True, | stream: bool = True, | ||||
| user: Optional[str] = None, | user: Optional[str] = None, | ||||
| callbacks: Optional[list[Callback]] = None, | callbacks: Optional[list[Callback]] = None, | ||||
| ) -> None: | |||||
| ): | |||||
| """ | """ | ||||
| Trigger invoke error callbacks | Trigger invoke error callbacks | ||||
| return GPT2Tokenizer._get_num_tokens_by_gpt2(text) | return GPT2Tokenizer._get_num_tokens_by_gpt2(text) | ||||
| @staticmethod | @staticmethod | ||||
| def get_encoder() -> Any: | |||||
| def get_encoder(): | |||||
| global _tokenizer, _lock | global _tokenizer, _lock | ||||
| if _tokenizer is not None: | if _tokenizer is not None: | ||||
| return _tokenizer | return _tokenizer |
| except Exception as e: | except Exception as e: | ||||
| raise self._transform_invoke_error(e) | raise self._transform_invoke_error(e) | ||||
| def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: | |||||
| def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): | |||||
| """ | """ | ||||
| Retrieves the list of voices supported by a given text-to-speech (TTS) model. | Retrieves the list of voices supported by a given text-to-speech (TTS) model. | ||||
| return plugin_model_provider_entity | return plugin_model_provider_entity | ||||
| def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: | |||||
| def provider_credentials_validate(self, *, provider: str, credentials: dict): | |||||
| """ | """ | ||||
| Validate provider credentials | Validate provider credentials | ||||
| return filtered_credentials | return filtered_credentials | ||||
| def model_credentials_validate( | |||||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||||
| ) -> dict: | |||||
| def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): | |||||
| """ | """ | ||||
| Validate model credentials | Validate model credentials | ||||
| class CommonValidator: | class CommonValidator: | ||||
| def _validate_and_filter_credential_form_schemas( | def _validate_and_filter_credential_form_schemas( | ||||
| self, credential_form_schemas: list[CredentialFormSchema], credentials: dict | self, credential_form_schemas: list[CredentialFormSchema], credentials: dict | ||||
| ) -> dict: | |||||
| ): | |||||
| need_validate_credential_form_schema_map = {} | need_validate_credential_form_schema_map = {} | ||||
| for credential_form_schema in credential_form_schemas: | for credential_form_schema in credential_form_schemas: | ||||
| if not credential_form_schema.show_on: | if not credential_form_schema.show_on: |
| self.model_type = model_type | self.model_type = model_type | ||||
| self.model_credential_schema = model_credential_schema | self.model_credential_schema = model_credential_schema | ||||
| def validate_and_filter(self, credentials: dict) -> dict: | |||||
| def validate_and_filter(self, credentials: dict): | |||||
| """ | """ | ||||
| Validate model credentials | Validate model credentials | ||||
| def __init__(self, provider_credential_schema: ProviderCredentialSchema): | def __init__(self, provider_credential_schema: ProviderCredentialSchema): | ||||
| self.provider_credential_schema = provider_credential_schema | self.provider_credential_schema = provider_credential_schema | ||||
| def validate_and_filter(self, credentials: dict) -> dict: | |||||
| def validate_and_filter(self, credentials: dict): | |||||
| """ | """ | ||||
| Validate provider credentials | Validate provider credentials | ||||
| from pydantic_extra_types.color import Color | from pydantic_extra_types.color import Color | ||||
| def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: | |||||
| def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): | |||||
| return model.model_dump(mode=mode, **kwargs) | return model.model_dump(mode=mode, **kwargs) | ||||
| exclude_none: bool = False, | exclude_none: bool = False, | ||||
| custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, | custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, | ||||
| sqlalchemy_safe: bool = True, | sqlalchemy_safe: bool = True, | ||||
| ) -> Any: | |||||
| ): | |||||
| custom_encoder = custom_encoder or {} | custom_encoder = custom_encoder or {} | ||||
| if custom_encoder: | if custom_encoder: | ||||
| if type(obj) in custom_encoder: | if type(obj) in custom_encoder: |
| name: str = "api" | name: str = "api" | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response | flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response | ||||
| ) | ) | ||||
| def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: | |||||
| def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict): | |||||
| if self.config is None: | if self.config is None: | ||||
| raise ValueError("The config is not set.") | raise ValueError("The config is not set.") | ||||
| extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) | extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) |
| module: ExtensionModule = ExtensionModule.MODERATION | module: ExtensionModule = ExtensionModule.MODERATION | ||||
| def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: | |||||
| def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): | |||||
| super().__init__(tenant_id, config) | super().__init__(tenant_id, config) | ||||
| self.app_id = app_id | self.app_id = app_id | ||||
| @classmethod | @classmethod | ||||
| @abstractmethod | @abstractmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @classmethod | @classmethod | ||||
| def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: | |||||
| def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool): | |||||
| # inputs_config | # inputs_config | ||||
| inputs_config = config.get("inputs_config") | inputs_config = config.get("inputs_config") | ||||
| if not isinstance(inputs_config, dict): | if not isinstance(inputs_config, dict): |
| class ModerationFactory: | class ModerationFactory: | ||||
| __extension_instance: Moderation | __extension_instance: Moderation | ||||
| def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: | |||||
| def __init__(self, name: str, app_id: str, tenant_id: str, config: dict): | |||||
| extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) | extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) | ||||
| self.__extension_instance = extension_class(app_id, tenant_id, config) | self.__extension_instance = extension_class(app_id, tenant_id, config) | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, name: str, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| name: str = "keywords" | name: str = "keywords" | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| name: str = "openai_moderation" | name: str = "openai_moderation" | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||||
| def validate_config(cls, tenant_id: str, config: dict): | |||||
| """ | """ | ||||
| Validate the incoming form config data. | Validate the incoming form config data. | ||||
| def get_final_output(self) -> str: | def get_final_output(self) -> str: | ||||
| return self.final_output or "" | return self.final_output or "" | ||||
| def append_new_token(self, token: str) -> None: | |||||
| def append_new_token(self, token: str): | |||||
| self.buffer += token | self.buffer += token | ||||
| if not self.thread: | if not self.thread: |
| class PluginEncrypter: | class PluginEncrypter: | ||||
| @classmethod | @classmethod | ||||
| def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: | |||||
| def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt): | |||||
| encrypter, cache = create_provider_encrypter( | encrypter, cache = create_provider_encrypter( | ||||
| tenant_id=tenant.id, | tenant_id=tenant.id, | ||||
| config=payload.config, | config=payload.config, |
| model_config: ParameterExtractorModelConfig, | model_config: ParameterExtractorModelConfig, | ||||
| instruction: str, | instruction: str, | ||||
| query: str, | query: str, | ||||
| ) -> dict: | |||||
| ): | |||||
| """ | """ | ||||
| Invoke parameter extractor node. | Invoke parameter extractor node. | ||||
| classes: list[ClassConfig], | classes: list[ClassConfig], | ||||
| instruction: str, | instruction: str, | ||||
| query: str, | query: str, | ||||
| ) -> dict: | |||||
| ): | |||||
| """ | """ | ||||
| Invoke question classifier node. | Invoke question classifier node. | ||||
| @model_validator(mode="before") | @model_validator(mode="before") | ||||
| @classmethod | @classmethod | ||||
| def validate_category(cls, values: dict) -> dict: | |||||
| def validate_category(cls, values: dict): | |||||
| # auto detect category | # auto detect category | ||||
| if values.get("tool"): | if values.get("tool"): | ||||
| values["category"] = PluginCategory.Tool | values["category"] = PluginCategory.Tool |
| Fetch agent providers for the given tenant. | Fetch agent providers for the given tenant. | ||||
| """ | """ | ||||
| def transformer(json_response: dict[str, Any]) -> dict: | |||||
| def transformer(json_response: dict[str, Any]): | |||||
| for provider in json_response.get("data", []): | for provider in json_response.get("data", []): | ||||
| declaration = provider.get("declaration", {}) or {} | declaration = provider.get("declaration", {}) or {} | ||||
| provider_name = declaration.get("identity", {}).get("name") | provider_name = declaration.get("identity", {}).get("name") | ||||
| """ | """ | ||||
| agent_provider_id = GenericProviderID(provider) | agent_provider_id = GenericProviderID(provider) | ||||
| def transformer(json_response: dict[str, Any]) -> dict: | |||||
| def transformer(json_response: dict[str, Any]): | |||||
| # skip if error occurs | # skip if error occurs | ||||
| if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: | if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: | ||||
| return json_response | return json_response |
| class PluginDaemonError(Exception): | class PluginDaemonError(Exception): | ||||
| """Base class for all plugin daemon errors.""" | """Base class for all plugin daemon errors.""" | ||||
| def __init__(self, description: str) -> None: | |||||
| def __init__(self, description: str): | |||||
| self.description = description | self.description = description | ||||
| def __str__(self) -> str: | def __str__(self) -> str: |
| model: str, | model: str, | ||||
| credentials: dict, | credentials: dict, | ||||
| language: Optional[str] = None, | language: Optional[str] = None, | ||||
| ) -> list[dict]: | |||||
| ): | |||||
| """ | """ | ||||
| Get tts model voices | Get tts model voices | ||||
| """ | """ |
| Fetch tool providers for the given tenant. | Fetch tool providers for the given tenant. | ||||
| """ | """ | ||||
| def transformer(json_response: dict[str, Any]) -> dict: | |||||
| def transformer(json_response: dict[str, Any]): | |||||
| for provider in json_response.get("data", []): | for provider in json_response.get("data", []): | ||||
| declaration = provider.get("declaration", {}) or {} | declaration = provider.get("declaration", {}) or {} | ||||
| provider_name = declaration.get("identity", {}).get("name") | provider_name = declaration.get("identity", {}).get("name") | ||||
| """ | """ | ||||
| tool_provider_id = ToolProviderID(provider) | tool_provider_id = ToolProviderID(provider) | ||||
| def transformer(json_response: dict[str, Any]) -> dict: | |||||
| def transformer(json_response: dict[str, Any]): | |||||
| data = json_response.get("data") | data = json_response.get("data") | ||||
| if data: | if data: | ||||
| for tool in data.get("declaration", {}).get("tools", []): | for tool in data.get("declaration", {}).get("tools", []): |
| bytes_written: int = field(default=0, init=False) | bytes_written: int = field(default=0, init=False) | ||||
| data: bytearray = field(init=False) | data: bytearray = field(init=False) | ||||
| def __post_init__(self) -> None: | |||||
| def __post_init__(self): | |||||
| self.data = bytearray(self.total_length) | self.data = bytearray(self.total_length) | ||||