| @@ -20,7 +20,7 @@ from fields.conversation_fields import ( | |||
| conversation_pagination_fields, | |||
| conversation_with_summary_pagination_fields, | |||
| ) | |||
| from libs.helper import datetime_string | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation | |||
| @@ -36,8 +36,8 @@ class CompletionConversationApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument( | |||
| "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" | |||
| ) | |||
| @@ -143,8 +143,8 @@ class ChatConversationApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument( | |||
| "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" | |||
| ) | |||
| @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from extensions.ext_database import db | |||
| from libs.helper import datetime_string | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode | |||
| @@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| @@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from extensions.ext_database import db | |||
| from libs.helper import datetime_string | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode | |||
| from models.workflow import WorkflowRunTriggeredFrom | |||
| @@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -8,7 +8,7 @@ from constants.languages import supported_language | |||
| from controllers.console import api | |||
| from controllers.console.error import AlreadyActivateError | |||
| from extensions.ext_database import db | |||
| from libs.helper import email, str_len, timezone | |||
| from libs.helper import StrLen, email, timezone | |||
| from libs.password import hash_password, valid_password | |||
| from models.account import AccountStatus | |||
| from services.account_service import RegisterService | |||
| @@ -37,7 +37,7 @@ class ActivateApi(Resource): | |||
| parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("email", type=email, required=False, nullable=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "interface_language", type=supported_language, required=True, nullable=False, location="json" | |||
| @@ -4,7 +4,7 @@ from flask import session | |||
| from flask_restful import Resource, reqparse | |||
| from configs import dify_config | |||
| from libs.helper import str_len | |||
| from libs.helper import StrLen | |||
| from models.model import DifySetup | |||
| from services.account_service import TenantService | |||
| @@ -28,7 +28,7 @@ class InitValidateAPI(Resource): | |||
| raise AlreadySetupError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("password", type=str_len(30), required=True, location="json") | |||
| parser.add_argument("password", type=StrLen(30), required=True, location="json") | |||
| input_password = parser.parse_args()["password"] | |||
| if input_password != os.environ.get("INIT_PASSWORD"): | |||
| @@ -4,7 +4,7 @@ from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from configs import dify_config | |||
| from libs.helper import email, get_remote_ip, str_len | |||
| from libs.helper import StrLen, email, get_remote_ip | |||
| from libs.password import valid_password | |||
| from models.model import DifySetup | |||
| from services.account_service import RegisterService, TenantService | |||
| @@ -40,7 +40,7 @@ class SetupApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("name", type=str_len(30), required=True, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, location="json") | |||
| parser.add_argument("password", type=valid_password, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner | |||
| from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter | |||
| from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| @@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| return generate_task_pipeline.process() | |||
| except ValueError as e: | |||
| if e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise GenerateTaskStoppedException() | |||
| raise GenerateTaskStoppedError() | |||
| else: | |||
| logger.exception(e) | |||
| raise e | |||
| @@ -21,7 +21,7 @@ class AudioTrunk: | |||
| self.status = status | |||
| def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): | |||
| def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): | |||
| if not text_content or text_content.isspace(): | |||
| return | |||
| return model_instance.invoke_tts( | |||
| @@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher: | |||
| if message is None: | |||
| if self.msg_text and len(self.msg_text.strip()) > 0: | |||
| futures_result = self.executor.submit( | |||
| _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice | |||
| _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice | |||
| ) | |||
| future_queue.put(futures_result) | |||
| break | |||
| @@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher: | |||
| self.MAX_SENTENCE += 1 | |||
| text_content = "".join(sentence_arr) | |||
| futures_result = self.executor.submit( | |||
| _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice | |||
| _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice | |||
| ) | |||
| future_queue.put(futures_result) | |||
| if text_tmp: | |||
| @@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher: | |||
| break | |||
| future_queue.put(None) | |||
| def checkAndGetAudio(self) -> AudioTrunk | None: | |||
| def check_and_get_audio(self) -> AudioTrunk | None: | |||
| try: | |||
| if self._last_audio_event and self._last_audio_event.status == "finish": | |||
| if self.executor: | |||
| @@ -19,7 +19,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| ) | |||
| from core.moderation.base import ModerationException | |||
| from core.moderation.base import ModerationError | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| query=query, | |||
| message_id=message_id, | |||
| ) | |||
| except ModerationException as e: | |||
| except ModerationError as e: | |||
| self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) | |||
| return True | |||
| @@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| stream_response=stream_response, | |||
| ) | |||
| def _listenAudioMsg(self, publisher, task_id: str): | |||
| def _listen_audio_msg(self, publisher, task_id: str): | |||
| if not publisher: | |||
| return None | |||
| audio_msg: AudioTrunk = publisher.checkAndGetAudio() | |||
| audio_msg: AudioTrunk = publisher.check_and_get_audio() | |||
| if audio_msg and audio_msg.status != "finish": | |||
| return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) | |||
| return None | |||
| @@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) | |||
| audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| @@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| try: | |||
| if not tts_publisher: | |||
| break | |||
| audio_trunk = tts_publisher.checkAndGetAudio() | |||
| audio_trunk = tts_publisher.check_and_get_audio() | |||
| if audio_trunk is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| @@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan | |||
| from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager | |||
| from core.app.apps.agent_chat.app_runner import AgentChatAppRunner | |||
| from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom | |||
| @@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.moderation.base import ModerationException | |||
| from core.moderation.base import ModerationError | |||
| from core.tools.entities.tool_entities import ToolRuntimeVariablePool | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, Message, MessageAgentThought | |||
| @@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner): | |||
| query=query, | |||
| message_id=message.id, | |||
| ) | |||
| except ModerationException as e: | |||
| except ModerationError as e: | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_generate_entity=application_generate_entity, | |||
| @@ -171,5 +171,5 @@ class AppQueueManager: | |||
| ) | |||
| class GenerateTaskStoppedException(Exception): | |||
| class GenerateTaskStoppedError(Exception): | |||
| pass | |||
| @@ -10,7 +10,7 @@ from pydantic import ValidationError | |||
| from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.chat.app_config_manager import ChatAppConfigManager | |||
| from core.app.apps.chat.app_runner import ChatAppRunner | |||
| from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter | |||
| @@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.moderation.base import ModerationException | |||
| from core.moderation.base import ModerationError | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, Message | |||
| @@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner): | |||
| query=query, | |||
| message_id=message.id, | |||
| ) | |||
| except ModerationException as e: | |||
| except ModerationError as e: | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_generate_entity=application_generate_entity, | |||
| @@ -10,7 +10,7 @@ from pydantic import ValidationError | |||
| from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.completion.app_config_manager import CompletionAppConfigManager | |||
| from core.app.apps.completion.app_runner import CompletionAppRunner | |||
| from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter | |||
| @@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager=queue_manager, | |||
| message=message, | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| ) | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.model_manager import ModelInstance | |||
| from core.moderation.base import ModerationException | |||
| from core.moderation.base import ModerationError | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from extensions.ext_database import db | |||
| from models.model import App, Message | |||
| @@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner): | |||
| query=query, | |||
| message_id=message.id, | |||
| ) | |||
| except ModerationException as e: | |||
| except ModerationError as e: | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_generate_entity=application_generate_entity, | |||
| @@ -8,7 +8,7 @@ from sqlalchemy import and_ | |||
| from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom | |||
| from core.app.apps.base_app_generator import BaseAppGenerator | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| AgentChatAppGenerateEntity, | |||
| @@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| return generate_task_pipeline.process() | |||
| except ValueError as e: | |||
| if e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise GenerateTaskStoppedException() | |||
| raise GenerateTaskStoppedError() | |||
| else: | |||
| logger.exception(e) | |||
| raise e | |||
| @@ -1,4 +1,4 @@ | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| @@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager): | |||
| self.stop_listen() | |||
| if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): | |||
| raise GenerateTaskStoppedException() | |||
| raise GenerateTaskStoppedError() | |||
| @@ -12,7 +12,7 @@ from pydantic import ValidationError | |||
| import contexts | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.apps.base_app_generator import BaseAppGenerator | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager | |||
| from core.app.apps.workflow.app_runner import WorkflowAppRunner | |||
| @@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| return generate_task_pipeline.process() | |||
| except ValueError as e: | |||
| if e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise GenerateTaskStoppedException() | |||
| raise GenerateTaskStoppedError() | |||
| else: | |||
| logger.exception(e) | |||
| raise e | |||
| @@ -1,4 +1,4 @@ | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| @@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager): | |||
| self.stop_listen() | |||
| if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): | |||
| raise GenerateTaskStoppedException() | |||
| raise GenerateTaskStoppedError() | |||
| @@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) | |||
| def _listenAudioMsg(self, publisher, task_id: str): | |||
| def _listen_audio_msg(self, publisher, task_id: str): | |||
| if not publisher: | |||
| return None | |||
| audio_msg: AudioTrunk = publisher.checkAndGetAudio() | |||
| audio_msg: AudioTrunk = publisher.check_and_get_audio() | |||
| if audio_msg and audio_msg.status != "finish": | |||
| return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) | |||
| return None | |||
| @@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) | |||
| audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| @@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| try: | |||
| if not tts_publisher: | |||
| break | |||
| audio_trunk = tts_publisher.checkAndGetAudio() | |||
| audio_trunk = tts_publisher.check_and_get_audio() | |||
| if audio_trunk is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| @@ -15,6 +15,7 @@ class Segment(BaseModel): | |||
| value: Any | |||
| @field_validator("value_type") | |||
| @classmethod | |||
| def validate_value_type(cls, value): | |||
| """ | |||
| This validator checks if the provided value is equal to the default value of the 'value_type' field. | |||
| @@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| stream_response=stream_response, | |||
| ) | |||
| def _listenAudioMsg(self, publisher, task_id: str): | |||
| def _listen_audio_msg(self, publisher, task_id: str): | |||
| if publisher is None: | |||
| return None | |||
| audio_msg: AudioTrunk = publisher.checkAndGetAudio() | |||
| audio_msg: AudioTrunk = publisher.check_and_get_audio() | |||
| if audio_msg and audio_msg.status != "finish": | |||
| # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') | |||
| return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) | |||
| @@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) | |||
| for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listenAudioMsg(publisher, task_id) | |||
| audio_response = self._listen_audio_msg(publisher, task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| @@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: | |||
| if publisher is None: | |||
| break | |||
| audio = publisher.checkAndGetAudio() | |||
| audio = publisher.check_and_get_audio() | |||
| if audio is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| @@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer | |||
| logger = logging.getLogger(__name__) | |||
| class CodeExecutionException(Exception): | |||
| class CodeExecutionError(Exception): | |||
| pass | |||
| @@ -86,15 +86,15 @@ class CodeExecutor: | |||
| ), | |||
| ) | |||
| if response.status_code == 503: | |||
| raise CodeExecutionException("Code execution service is unavailable") | |||
| raise CodeExecutionError("Code execution service is unavailable") | |||
| elif response.status_code != 200: | |||
| raise Exception( | |||
| f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" | |||
| ) | |||
| except CodeExecutionException as e: | |||
| except CodeExecutionError as e: | |||
| raise e | |||
| except Exception as e: | |||
| raise CodeExecutionException( | |||
| raise CodeExecutionError( | |||
| "Failed to execute code, which is likely a network issue," | |||
| " please check if the sandbox service is running." | |||
| f" ( Error: {str(e)} )" | |||
| @@ -103,15 +103,15 @@ class CodeExecutor: | |||
| try: | |||
| response = response.json() | |||
| except: | |||
| raise CodeExecutionException("Failed to parse response") | |||
| raise CodeExecutionError("Failed to parse response") | |||
| if (code := response.get("code")) != 0: | |||
| raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") | |||
| raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") | |||
| response = CodeExecutionResponse(**response) | |||
| if response.data.error: | |||
| raise CodeExecutionException(response.data.error) | |||
| raise CodeExecutionError(response.data.error) | |||
| return response.data.stdout or "" | |||
| @@ -126,13 +126,13 @@ class CodeExecutor: | |||
| """ | |||
| template_transformer = cls.code_template_transformers.get(language) | |||
| if not template_transformer: | |||
| raise CodeExecutionException(f"Unsupported language {language}") | |||
| raise CodeExecutionError(f"Unsupported language {language}") | |||
| runner, preload = template_transformer.transform_caller(code, inputs) | |||
| try: | |||
| response = cls.execute_code(language, preload, runner) | |||
| except CodeExecutionException as e: | |||
| except CodeExecutionError as e: | |||
| raise e | |||
| return template_transformer.transform_response(response) | |||
| @@ -78,8 +78,8 @@ class IndexingRunner: | |||
| dataset_document=dataset_document, | |||
| documents=documents, | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) | |||
| except DocumentIsPausedError: | |||
| raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e.description) | |||
| @@ -134,8 +134,8 @@ class IndexingRunner: | |||
| self._load( | |||
| index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) | |||
| except DocumentIsPausedError: | |||
| raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e.description) | |||
| @@ -192,8 +192,8 @@ class IndexingRunner: | |||
| self._load( | |||
| index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) | |||
| except DocumentIsPausedError: | |||
| raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e.description) | |||
| @@ -756,7 +756,7 @@ class IndexingRunner: | |||
| indexing_cache_key = "document_{}_is_paused".format(document_id) | |||
| result = redis_client.get(indexing_cache_key) | |||
| if result: | |||
| raise DocumentIsPausedException() | |||
| raise DocumentIsPausedError() | |||
| @staticmethod | |||
| def _update_document_index_status( | |||
| @@ -767,10 +767,10 @@ class IndexingRunner: | |||
| """ | |||
| count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() | |||
| if count > 0: | |||
| raise DocumentIsPausedException() | |||
| raise DocumentIsPausedError() | |||
| document = DatasetDocument.query.filter_by(id=document_id).first() | |||
| if not document: | |||
| raise DocumentIsDeletedPausedException() | |||
| raise DocumentIsDeletedPausedError() | |||
| update_params = {DatasetDocument.indexing_status: after_indexing_status} | |||
| @@ -875,9 +875,9 @@ class IndexingRunner: | |||
| pass | |||
| class DocumentIsPausedException(Exception): | |||
| class DocumentIsPausedError(Exception): | |||
| pass | |||
| class DocumentIsDeletedPausedException(Exception): | |||
| class DocumentIsDeletedPausedError(Exception): | |||
| pass | |||
| @@ -1,2 +1,2 @@ | |||
| class OutputParserException(Exception): | |||
| class OutputParserError(Exception): | |||
| pass | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Any | |||
| from core.llm_generator.output_parser.errors import OutputParserException | |||
| from core.llm_generator.output_parser.errors import OutputParserError | |||
| from core.llm_generator.prompts import ( | |||
| RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, | |||
| RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, | |||
| @@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser: | |||
| raise ValueError("Expected 'opening_statement' to be a str.") | |||
| return parsed | |||
| except Exception as e: | |||
| raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") | |||
| raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") | |||
| @@ -7,7 +7,7 @@ from requests import post | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( | |||
| BadRequestError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| @@ -124,7 +124,7 @@ class BaichuanModel: | |||
| if err == "invalid_api_key": | |||
| raise InvalidAPIKeyError(msg) | |||
| elif err == "insufficient_quota": | |||
| raise InsufficientAccountBalance(msg) | |||
| raise InsufficientAccountBalanceError(msg) | |||
| elif err == "invalid_authentication": | |||
| raise InvalidAuthenticationError(msg) | |||
| elif err == "invalid_request_error": | |||
| @@ -10,7 +10,7 @@ class RateLimitReachedError(Exception): | |||
| pass | |||
| class InsufficientAccountBalance(Exception): | |||
| class InsufficientAccountBalanceError(Exception): | |||
| pass | |||
| @@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B | |||
| from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel | |||
| from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( | |||
| BadRequestError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| @@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel): | |||
| InvokeRateLimitError: [RateLimitReachedError], | |||
| InvokeAuthorizationError: [ | |||
| InvalidAuthenticationError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InvalidAPIKeyError, | |||
| ], | |||
| InvokeBadRequestError: [BadRequestError, KeyError], | |||
| @@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE | |||
| from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer | |||
| from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( | |||
| BadRequestError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| @@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): | |||
| if err == "invalid_api_key": | |||
| raise InvalidAPIKeyError(msg) | |||
| elif err == "insufficient_quota": | |||
| raise InsufficientAccountBalance(msg) | |||
| raise InsufficientAccountBalanceError(msg) | |||
| elif err == "invalid_authentication": | |||
| raise InvalidAuthenticationError(msg) | |||
| elif err and "rate" in err: | |||
| @@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): | |||
| InvokeRateLimitError: [RateLimitReachedError], | |||
| InvokeAuthorizationError: [ | |||
| InvalidAuthenticationError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InvalidAPIKeyError, | |||
| ], | |||
| InvokeBadRequestError: [BadRequestError, KeyError], | |||
| @@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import ( | |||
| ) | |||
| class _CommonOAI_API_Compat: | |||
| class _CommonOaiApiCompat: | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| """ | |||
| @@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import ( | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat | |||
| from core.model_runtime.utils import helper | |||
| logger = logging.getLogger(__name__) | |||
| class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): | |||
| class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): | |||
| """ | |||
| Model class for OpenAI large language model. | |||
| """ | |||
| @@ -6,10 +6,10 @@ import requests | |||
| from core.model_runtime.errors.invoke import InvokeBadRequestError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat | |||
| class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): | |||
| class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): | |||
| """ | |||
| Model class for OpenAI Compatible Speech to text model. | |||
| """ | |||
| @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat | |||
| class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): | |||
| class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): | |||
| """ | |||
| Model class for an OpenAI API-compatible text embedding model. | |||
| """ | |||
| @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat | |||
| from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat | |||
| class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): | |||
| class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): | |||
| """ | |||
| Model class for an OpenAI API-compatible text embedding model. | |||
| """ | |||
| @@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import ( | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error | |||
| from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService | |||
| from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService | |||
| class MaaSClient(MaasService): | |||
| @@ -106,7 +106,7 @@ class MaaSClient(MaasService): | |||
| def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: | |||
| try: | |||
| resp = fn() | |||
| except MaasException as e: | |||
| except MaasError as e: | |||
| raise wrap_error(e) | |||
| return resp | |||
| @@ -1,144 +1,144 @@ | |||
| from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException | |||
| from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError | |||
| class ClientSDKRequestError(MaasException): | |||
| class ClientSDKRequestError(MaasError): | |||
| pass | |||
| class SignatureDoesNotMatch(MaasException): | |||
| class SignatureDoesNotMatchError(MaasError): | |||
| pass | |||
| class RequestTimeout(MaasException): | |||
| class RequestTimeoutError(MaasError): | |||
| pass | |||
| class ServiceConnectionTimeout(MaasException): | |||
| class ServiceConnectionTimeoutError(MaasError): | |||
| pass | |||
| class MissingAuthenticationHeader(MaasException): | |||
| class MissingAuthenticationHeaderError(MaasError): | |||
| pass | |||
| class AuthenticationHeaderIsInvalid(MaasException): | |||
| class AuthenticationHeaderIsInvalidError(MaasError): | |||
| pass | |||
| class InternalServiceError(MaasException): | |||
| class InternalServiceError(MaasError): | |||
| pass | |||
| class MissingParameter(MaasException): | |||
| class MissingParameterError(MaasError): | |||
| pass | |||
| class InvalidParameter(MaasException): | |||
| class InvalidParameterError(MaasError): | |||
| pass | |||
| class AuthenticationExpire(MaasException): | |||
| class AuthenticationExpireError(MaasError): | |||
| pass | |||
| class EndpointIsInvalid(MaasException): | |||
| class EndpointIsInvalidError(MaasError): | |||
| pass | |||
| class EndpointIsNotEnable(MaasException): | |||
| class EndpointIsNotEnableError(MaasError): | |||
| pass | |||
| class ModelNotSupportStreamMode(MaasException): | |||
| class ModelNotSupportStreamModeError(MaasError): | |||
| pass | |||
| class ReqTextExistRisk(MaasException): | |||
| class ReqTextExistRiskError(MaasError): | |||
| pass | |||
| class RespTextExistRisk(MaasException): | |||
| class RespTextExistRiskError(MaasError): | |||
| pass | |||
| class EndpointRateLimitExceeded(MaasException): | |||
| class EndpointRateLimitExceededError(MaasError): | |||
| pass | |||
| class ServiceConnectionRefused(MaasException): | |||
| class ServiceConnectionRefusedError(MaasError): | |||
| pass | |||
| class ServiceConnectionClosed(MaasException): | |||
| class ServiceConnectionClosedError(MaasError): | |||
| pass | |||
| class UnauthorizedUserForEndpoint(MaasException): | |||
| class UnauthorizedUserForEndpointError(MaasError): | |||
| pass | |||
| class InvalidEndpointWithNoURL(MaasException): | |||
| class InvalidEndpointWithNoURLError(MaasError): | |||
| pass | |||
| class EndpointAccountRpmRateLimitExceeded(MaasException): | |||
| class EndpointAccountRpmRateLimitExceededError(MaasError): | |||
| pass | |||
| class EndpointAccountTpmRateLimitExceeded(MaasException): | |||
| class EndpointAccountTpmRateLimitExceededError(MaasError): | |||
| pass | |||
| class ServiceResourceWaitQueueFull(MaasException): | |||
| class ServiceResourceWaitQueueFullError(MaasError): | |||
| pass | |||
| class EndpointIsPending(MaasException): | |||
| class EndpointIsPendingError(MaasError): | |||
| pass | |||
| class ServiceNotOpen(MaasException): | |||
| class ServiceNotOpenError(MaasError): | |||
| pass | |||
| AuthErrors = { | |||
| "SignatureDoesNotMatch": SignatureDoesNotMatch, | |||
| "MissingAuthenticationHeader": MissingAuthenticationHeader, | |||
| "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, | |||
| "AuthenticationExpire": AuthenticationExpire, | |||
| "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, | |||
| "SignatureDoesNotMatch": SignatureDoesNotMatchError, | |||
| "MissingAuthenticationHeader": MissingAuthenticationHeaderError, | |||
| "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, | |||
| "AuthenticationExpire": AuthenticationExpireError, | |||
| "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, | |||
| } | |||
| BadRequestErrors = { | |||
| "MissingParameter": MissingParameter, | |||
| "InvalidParameter": InvalidParameter, | |||
| "EndpointIsInvalid": EndpointIsInvalid, | |||
| "EndpointIsNotEnable": EndpointIsNotEnable, | |||
| "ModelNotSupportStreamMode": ModelNotSupportStreamMode, | |||
| "ReqTextExistRisk": ReqTextExistRisk, | |||
| "RespTextExistRisk": RespTextExistRisk, | |||
| "InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, | |||
| "ServiceNotOpen": ServiceNotOpen, | |||
| "MissingParameter": MissingParameterError, | |||
| "InvalidParameter": InvalidParameterError, | |||
| "EndpointIsInvalid": EndpointIsInvalidError, | |||
| "EndpointIsNotEnable": EndpointIsNotEnableError, | |||
| "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, | |||
| "ReqTextExistRisk": ReqTextExistRiskError, | |||
| "RespTextExistRisk": RespTextExistRiskError, | |||
| "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, | |||
| "ServiceNotOpen": ServiceNotOpenError, | |||
| } | |||
| RateLimitErrors = { | |||
| "EndpointRateLimitExceeded": EndpointRateLimitExceeded, | |||
| "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, | |||
| "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, | |||
| "EndpointRateLimitExceeded": EndpointRateLimitExceededError, | |||
| "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, | |||
| "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError, | |||
| } | |||
| ServerUnavailableErrors = { | |||
| "InternalServiceError": InternalServiceError, | |||
| "EndpointIsPending": EndpointIsPending, | |||
| "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, | |||
| "EndpointIsPending": EndpointIsPendingError, | |||
| "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, | |||
| } | |||
| ConnectionErrors = { | |||
| "ClientSDKRequestError": ClientSDKRequestError, | |||
| "RequestTimeout": RequestTimeout, | |||
| "ServiceConnectionTimeout": ServiceConnectionTimeout, | |||
| "ServiceConnectionRefused": ServiceConnectionRefused, | |||
| "ServiceConnectionClosed": ServiceConnectionClosed, | |||
| "RequestTimeout": RequestTimeoutError, | |||
| "ServiceConnectionTimeout": ServiceConnectionTimeoutError, | |||
| "ServiceConnectionRefused": ServiceConnectionRefusedError, | |||
| "ServiceConnectionClosed": ServiceConnectionClosedError, | |||
| } | |||
| ErrorCodeMap = { | |||
| @@ -150,7 +150,7 @@ ErrorCodeMap = { | |||
| } | |||
| def wrap_error(e: MaasException) -> Exception: | |||
| def wrap_error(e: MaasError) -> Exception: | |||
| if ErrorCodeMap.get(e.code): | |||
| return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) | |||
| return e | |||
| @@ -1,4 +1,4 @@ | |||
| from .common import ChatRole | |||
| from .maas import MaasException, MaasService | |||
| from .maas import MaasError, MaasService | |||
| __all__ = ["MaasService", "ChatRole", "MaasException"] | |||
| __all__ = ["MaasService", "ChatRole", "MaasError"] | |||
| @@ -63,7 +63,7 @@ class MaasService(Service): | |||
| raise | |||
| if res.error is not None and res.error.code_n != 0: | |||
| raise MaasException( | |||
| raise MaasError( | |||
| res.error.code_n, | |||
| res.error.code, | |||
| res.error.message, | |||
| @@ -72,7 +72,7 @@ class MaasService(Service): | |||
| yield res | |||
| return iter_fn() | |||
| except MaasException: | |||
| except MaasError: | |||
| raise | |||
| except Exception as e: | |||
| raise new_client_sdk_request_error(str(e)) | |||
| @@ -94,7 +94,7 @@ class MaasService(Service): | |||
| resp["req_id"] = req_id | |||
| return resp | |||
| except MaasException as e: | |||
| except MaasError as e: | |||
| raise e | |||
| except Exception as e: | |||
| raise new_client_sdk_request_error(str(e), req_id) | |||
| @@ -147,14 +147,14 @@ class MaasService(Service): | |||
| raise new_client_sdk_request_error(raw, req_id) | |||
| if resp.error: | |||
| raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) | |||
| raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) | |||
| else: | |||
| raise new_client_sdk_request_error(resp, req_id) | |||
| return res | |||
| class MaasException(Exception): | |||
| class MaasError(Exception): | |||
| def __init__(self, code_n, code, message, req_id): | |||
| self.code_n = code_n | |||
| self.code = code | |||
| @@ -172,7 +172,7 @@ class MaasException(Exception): | |||
| def new_client_sdk_request_error(raw, req_id=""): | |||
| return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) | |||
| return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) | |||
| class BinaryResponseContent: | |||
| @@ -192,7 +192,7 @@ class BinaryResponseContent: | |||
| if len(error_bytes) > 0: | |||
| resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) | |||
| raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) | |||
| raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) | |||
| def iter_bytes(self) -> Iterator[bytes]: | |||
| yield from self.response | |||
| @@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( | |||
| AuthErrors, | |||
| BadRequestErrors, | |||
| ConnectionErrors, | |||
| MaasException, | |||
| MaasError, | |||
| RateLimitErrors, | |||
| ServerUnavailableErrors, | |||
| ) | |||
| @@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): | |||
| }, | |||
| [UserPromptMessage(content="ping\nAnswer: ")], | |||
| ) | |||
| except MaasException as e: | |||
| except MaasError as e: | |||
| raise CredentialsValidateFailedError(e.message) | |||
| @staticmethod | |||
| @@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( | |||
| AuthErrors, | |||
| BadRequestErrors, | |||
| ConnectionErrors, | |||
| MaasException, | |||
| MaasError, | |||
| RateLimitErrors, | |||
| ServerUnavailableErrors, | |||
| ) | |||
| @@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): | |||
| def _validate_credentials_v2(self, model: str, credentials: dict) -> None: | |||
| try: | |||
| self._invoke(model=model, credentials=credentials, texts=["ping"]) | |||
| except MaasException as e: | |||
| except MaasError as e: | |||
| raise CredentialsValidateFailedError(e.message) | |||
| def _validate_credentials_v3(self, model: str, credentials: dict) -> None: | |||
| @@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: | |||
| InvokeRateLimitError: [RateLimitReachedError], | |||
| InvokeAuthorizationError: [ | |||
| InvalidAuthenticationError, | |||
| InsufficientAccountBalance, | |||
| InsufficientAccountBalanceError, | |||
| InvalidAPIKeyError, | |||
| ], | |||
| InvokeBadRequestError: [BadRequestError, KeyError], | |||
| @@ -42,7 +42,7 @@ class RateLimitReachedError(Exception): | |||
| pass | |||
| class InsufficientAccountBalance(Exception): | |||
| class InsufficientAccountBalanceError(Exception): | |||
| pass | |||
| @@ -76,7 +76,7 @@ class Moderation(Extensible, ABC): | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: | |||
| def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: | |||
| # inputs_config | |||
| inputs_config = config.get("inputs_config") | |||
| if not isinstance(inputs_config, dict): | |||
| @@ -111,5 +111,5 @@ class Moderation(Extensible, ABC): | |||
| raise ValueError("outputs_config.preset_response must be less than 100 characters") | |||
| class ModerationException(Exception): | |||
| class ModerationError(Exception): | |||
| pass | |||
| @@ -2,7 +2,7 @@ import logging | |||
| from typing import Optional | |||
| from core.app.app_config.entities import AppConfig | |||
| from core.moderation.base import ModerationAction, ModerationException | |||
| from core.moderation.base import ModerationAction, ModerationError | |||
| from core.moderation.factory import ModerationFactory | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| @@ -61,7 +61,7 @@ class InputModeration: | |||
| return False, inputs, query | |||
| if moderation_result.action == ModerationAction.DIRECT_OUTPUT: | |||
| raise ModerationException(moderation_result.preset_response) | |||
| raise ModerationError(moderation_result.preset_response) | |||
| elif moderation_result.action == ModerationAction.OVERRIDDEN: | |||
| inputs = moderation_result.inputs | |||
| query = moderation_result.query | |||
| @@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig): | |||
| host: str = "https://api.langfuse.com" | |||
| @field_validator("host") | |||
| @classmethod | |||
| def set_value(cls, v, info: ValidationInfo): | |||
| if v is None or v == "": | |||
| v = "https://api.langfuse.com" | |||
| @@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig): | |||
| endpoint: str = "https://api.smith.langchain.com" | |||
| @field_validator("endpoint") | |||
| @classmethod | |||
| def set_value(cls, v, info: ValidationInfo): | |||
| if v is None or v == "": | |||
| v = "https://api.smith.langchain.com" | |||
| @@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel): | |||
| metadata: dict[str, Any] | |||
| @field_validator("inputs", "outputs") | |||
| @classmethod | |||
| def ensure_type(cls, v): | |||
| if v is None: | |||
| return None | |||
| @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): | |||
| ) | |||
| @field_validator("input", "output") | |||
| @classmethod | |||
| def ensure_dict(cls, v, info: ValidationInfo): | |||
| field_name = info.field_name | |||
| return validate_input_output(v, field_name) | |||
| @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): | |||
| ) | |||
| @field_validator("input", "output") | |||
| @classmethod | |||
| def ensure_dict(cls, v, info: ValidationInfo): | |||
| field_name = info.field_name | |||
| return validate_input_output(v, field_name) | |||
| @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): | |||
| totalCost: Optional[float] = None | |||
| @field_validator("input", "output") | |||
| @classmethod | |||
| def ensure_dict(cls, v, info: ValidationInfo): | |||
| field_name = info.field_name | |||
| return validate_input_output(v, field_name) | |||
| @@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel): | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @field_validator("input", "output") | |||
| @classmethod | |||
| def ensure_dict(cls, v, info: ValidationInfo): | |||
| field_name = info.field_name | |||
| return validate_input_output(v, field_name) | |||
| @@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): | |||
| output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") | |||
| @field_validator("inputs", "outputs") | |||
| @classmethod | |||
| def ensure_dict(cls, v, info: ValidationInfo): | |||
| field_name = info.field_name | |||
| values = info.data | |||
| @@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): | |||
| return v | |||
| return v | |||
| @classmethod | |||
| @field_validator("start_time", "end_time") | |||
| def format_time(cls, v, info: ValidationInfo): | |||
| if not isinstance(v, datetime): | |||
| @@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel): | |||
| password: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config HOST is required") | |||
| @@ -28,6 +28,7 @@ class MilvusConfig(BaseModel): | |||
| database: str = "default" | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values.get("uri"): | |||
| raise ValueError("config MILVUS_URI is required") | |||
| @@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel): | |||
| secure: bool = False | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values.get("host"): | |||
| raise ValueError("config OPENSEARCH_HOST is required") | |||
| @@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel): | |||
| database: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config ORACLE_HOST is required") | |||
| @@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel): | |||
| database: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config PGVECTO_RS_HOST is required") | |||
| @@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel): | |||
| database: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config PGVECTOR_HOST is required") | |||
| @@ -34,6 +34,7 @@ class RelytConfig(BaseModel): | |||
| database: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config RELYT_HOST is required") | |||
| @@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel): | |||
| program_name: str | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config TIDB_VECTOR_HOST is required") | |||
| @@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel): | |||
| batch_size: int = 100 | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["endpoint"]: | |||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||
| @@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool): | |||
| self.create_blob_message( | |||
| blob=b64decode(image.b64_json), | |||
| meta={"mime_type": "image/png"}, | |||
| save_as=self.VARIABLE_KEY.IMAGE.value, | |||
| save_as=self.VariableKey.IMAGE.value, | |||
| ) | |||
| ) | |||
| result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) | |||
| @@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool): | |||
| self.create_blob_message( | |||
| blob=b64decode(image.b64_json), | |||
| meta={"mime_type": "image/png"}, | |||
| save_as=self.VARIABLE_KEY.IMAGE.value, | |||
| save_as=self.VariableKey.IMAGE.value, | |||
| ) | |||
| ) | |||
| @@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool): | |||
| for image in response.data: | |||
| mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) | |||
| blob_message = self.create_blob_message( | |||
| blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| result.append(blob_message) | |||
| return result | |||
| @@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool): | |||
| self.create_blob_message( | |||
| blob=b64decode(client_result.image_file), | |||
| meta={"mime_type": f"image/{client_result.image_type}"}, | |||
| save_as=self.VARIABLE_KEY.IMAGE.value, | |||
| save_as=self.VariableKey.IMAGE.value, | |||
| ) | |||
| ) | |||
| @@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): | |||
| self.create_blob_message( | |||
| blob=b64decode(image_encoded), | |||
| meta={"mime_type": f"image/{image.image_type}"}, | |||
| save_as=self.VARIABLE_KEY.IMAGE.value, | |||
| save_as=self.VariableKey.IMAGE.value, | |||
| ) | |||
| ) | |||
| @@ -46,7 +46,7 @@ class QRCodeGeneratorTool(BuiltinTool): | |||
| image = self._generate_qrcode(content, border, error_correction) | |||
| image_bytes = self._image_to_byte_array(image) | |||
| return self.create_blob_message( | |||
| blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| except Exception: | |||
| logging.exception(f"Failed to generate QR code for content: {content}") | |||
| @@ -32,5 +32,5 @@ class FluxTool(BuiltinTool): | |||
| res = response.json() | |||
| result = [self.create_json_message(res)] | |||
| for image in res.get("images", []): | |||
| result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) | |||
| result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) | |||
| return result | |||
| @@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool): | |||
| res = response.json() | |||
| result = [self.create_json_message(res)] | |||
| for image in res.get("images", []): | |||
| result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) | |||
| result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) | |||
| return result | |||
| @@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| class AssembleHeaderException(Exception): | |||
| class AssembleHeaderError(Exception): | |||
| def __init__(self, msg): | |||
| self.message = msg | |||
| class Url: | |||
| def __init__(this, host, path, schema): | |||
| this.host = host | |||
| this.path = path | |||
| this.schema = schema | |||
| def __init__(self, host, path, schema): | |||
| self.host = host | |||
| self.path = path | |||
| self.schema = schema | |||
| # calculate sha256 and encode to base64 | |||
| @@ -41,7 +41,7 @@ def parse_url(request_url): | |||
| schema = request_url[: stidx + 3] | |||
| edidx = host.index("/") | |||
| if edidx <= 0: | |||
| raise AssembleHeaderException("invalid request url:" + request_url) | |||
| raise AssembleHeaderError("invalid request url:" + request_url) | |||
| path = host[edidx:] | |||
| host = host[:edidx] | |||
| u = Url(host, path, schema) | |||
| @@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool): | |||
| self.create_blob_message( | |||
| blob=b64decode(image["base64_image"]), | |||
| meta={"mime_type": "image/png"}, | |||
| save_as=self.VARIABLE_KEY.IMAGE.value, | |||
| save_as=self.VariableKey.IMAGE.value, | |||
| ) | |||
| ) | |||
| return result | |||
| @@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): | |||
| raise Exception(response.text) | |||
| return self.create_blob_message( | |||
| blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| @@ -260,7 +260,7 @@ class StableDiffusionTool(BuiltinTool): | |||
| image = response.json()["images"][0] | |||
| return self.create_blob_message( | |||
| blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| except Exception as e: | |||
| @@ -294,7 +294,7 @@ class StableDiffusionTool(BuiltinTool): | |||
| image = response.json()["images"][0] | |||
| return self.create_blob_message( | |||
| blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| except Exception as e: | |||
| @@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool): | |||
| ).content | |||
| return self.create_blob_message( | |||
| blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value | |||
| blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| @@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool): | |||
| if image_id.startswith("__test_"): | |||
| image_binary = b64decode(VECTORIZER_ICON_PNG) | |||
| else: | |||
| image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) | |||
| image_binary = self.get_variable_file(self.VariableKey.IMAGE) | |||
| if not image_binary: | |||
| return self.create_text_message("Image not found, please request user to generate image firstly.") | |||
| @@ -63,7 +63,7 @@ class Tool(BaseModel, ABC): | |||
| def __init__(self, **data: Any): | |||
| super().__init__(**data) | |||
| class VARIABLE_KEY(Enum): | |||
| class VariableKey(Enum): | |||
| IMAGE = "image" | |||
| def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": | |||
| @@ -142,7 +142,7 @@ class Tool(BaseModel, ABC): | |||
| if not self.variables: | |||
| return None | |||
| return self.get_variable(self.VARIABLE_KEY.IMAGE) | |||
| return self.get_variable(self.VariableKey.IMAGE) | |||
| def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: | |||
| """ | |||
| @@ -189,7 +189,7 @@ class Tool(BaseModel, ABC): | |||
| result = [] | |||
| for variable in self.variables.pool: | |||
| if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): | |||
| if variable.name.startswith(self.VariableKey.IMAGE.value): | |||
| result.append(variable) | |||
| return result | |||
| @@ -8,7 +8,7 @@ from typing import Any, Optional | |||
| from flask import Flask, current_app | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import ( | |||
| NodeRunMetadataKey, | |||
| @@ -669,7 +669,7 @@ class GraphEngine: | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| # trigger node run failed event | |||
| route_node_state.status = RouteNodeState.Status.FAILED | |||
| route_node_state.failed_reason = "Workflow stopped." | |||
| @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from configs import dify_config | |||
| from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage | |||
| from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage | |||
| from core.helper.code_executor.code_node_provider import CodeNodeProvider | |||
| from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider | |||
| from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider | |||
| @@ -61,7 +61,7 @@ class CodeNode(BaseNode): | |||
| # Transform result | |||
| result = self._transform_result(result, node_data.outputs) | |||
| except (CodeExecutionException, ValueError) as e: | |||
| except (CodeExecutionError, ValueError) as e: | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) | |||
| @@ -2,7 +2,7 @@ import os | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage | |||
| from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData | |||
| @@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode): | |||
| result = CodeExecutor.execute_workflow_code_template( | |||
| language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables | |||
| ) | |||
| except CodeExecutionException as e: | |||
| except CodeExecutionError as e: | |||
| return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) | |||
| if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: | |||
| @@ -6,7 +6,7 @@ from typing import Any, Optional, cast | |||
| from configs import dify_config | |||
| from core.app.app_config.entities import FileExtraConfig | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.file_obj import FileTransferMethod, FileType, FileVar | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| @@ -103,7 +103,7 @@ class WorkflowEntry: | |||
| for callback in callbacks: | |||
| callback.on_event(event=event) | |||
| yield event | |||
| except GenerateTaskStoppedException: | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when workflow entry running") | |||
| @@ -5,7 +5,7 @@ import time | |||
| import click | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from events.event_handlers.document_index_event import document_index_created | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document | |||
| @@ -43,7 +43,7 @@ def handle(sender, **kwargs): | |||
| indexing_runner.run(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -5,7 +5,7 @@ from collections.abc import Generator | |||
| from contextlib import closing | |||
| from flask import Flask | |||
| from google.cloud import storage as GoogleCloudStorage | |||
| from google.cloud import storage as google_cloud_storage | |||
| from extensions.storage.base_storage import BaseStorage | |||
| @@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage): | |||
| service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") | |||
| # convert str to object | |||
| service_account_obj = json.loads(service_account_json) | |||
| self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) | |||
| self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) | |||
| else: | |||
| self.client = GoogleCloudStorage.Client() | |||
| self.client = google_cloud_storage.Client() | |||
| def save(self, filename, data): | |||
| bucket = self.client.get_bucket(self.bucket_name) | |||
| @@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord | |||
| from Crypto.Util.strxor import strxor | |||
| class PKCS1OAEP_Cipher: | |||
| class PKCS1OAepCipher: | |||
| """Cipher object for PKCS#1 v1.5 OAEP. | |||
| Do not create directly: use :func:`new` instead.""" | |||
| @@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): | |||
| if randfunc is None: | |||
| randfunc = Random.get_random_bytes | |||
| return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) | |||
| return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) | |||
| @@ -84,7 +84,7 @@ def timestamp_value(timestamp): | |||
| raise ValueError(error) | |||
| class str_len: | |||
| class StrLen: | |||
| """Restrict input to an integer in a range (inclusive)""" | |||
| def __init__(self, max_length, argument="argument"): | |||
| @@ -102,7 +102,7 @@ class str_len: | |||
| return value | |||
| class float_range: | |||
| class FloatRange: | |||
| """Restrict input to an float in a range (inclusive)""" | |||
| def __init__(self, low, high, argument="argument"): | |||
| @@ -121,7 +121,7 @@ class float_range: | |||
| return value | |||
| class datetime_string: | |||
| class DatetimeString: | |||
| def __init__(self, format, argument="argument"): | |||
| self.format = format | |||
| self.argument = argument | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from core.llm_generator.output_parser.errors import OutputParserException | |||
| from core.llm_generator.output_parser.errors import OutputParserError | |||
| def parse_json_markdown(json_string: str) -> dict: | |||
| @@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: | |||
| try: | |||
| json_obj = parse_json_markdown(text) | |||
| except json.JSONDecodeError as e: | |||
| raise OutputParserException(f"Got invalid JSON object. Error: {e}") | |||
| raise OutputParserError(f"Got invalid JSON object. Error: {e}") | |||
| for key in expected_keys: | |||
| if key not in json_obj: | |||
| raise OutputParserException( | |||
| raise OutputParserError( | |||
| f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" | |||
| ) | |||
| return json_obj | |||
| @@ -15,8 +15,8 @@ select = [ | |||
| "C4", # flake8-comprehensions | |||
| "F", # pyflakes rules | |||
| "I", # isort rules | |||
| "N", # pep8-naming | |||
| "UP", # pyupgrade rules | |||
| "B035", # static-key-dict-comprehension | |||
| "E101", # mixed-spaces-and-tabs | |||
| "E111", # indentation-with-invalid-multiple | |||
| "E112", # no-indented-block | |||
| @@ -47,9 +47,10 @@ ignore = [ | |||
| "B006", # mutable-argument-default | |||
| "B007", # unused-loop-control-variable | |||
| "B026", # star-arg-unpacking-after-keyword-arg | |||
| # "B901", # return-in-generator | |||
| "B904", # raise-without-from-inside-except | |||
| "B905", # zip-without-explicit-strict | |||
| "N806", # non-lowercase-variable-in-function | |||
| "N815", # mixed-case-variable-in-class-scope | |||
| ] | |||
| [tool.ruff.lint.per-file-ignores] | |||
| @@ -65,6 +66,12 @@ ignore = [ | |||
| "F401", # unused-import | |||
| "F811", # redefined-while-unused | |||
| ] | |||
| "configs/*" = [ | |||
| "N802", # invalid-function-name | |||
| ] | |||
| "libs/gmpy2_pkcs10aep_cipher.py" = [ | |||
| "N803", # invalid-argument-name | |||
| ] | |||
| [tool.ruff.format] | |||
| exclude = [ | |||
| @@ -32,7 +32,7 @@ from services.errors.account import ( | |||
| NoPermissionError, | |||
| RateLimitExceededError, | |||
| RoleAlreadyAssignedError, | |||
| TenantNotFound, | |||
| TenantNotFoundError, | |||
| ) | |||
| from tasks.mail_invite_member_task import send_invite_member_mail_task | |||
| from tasks.mail_reset_password_task import send_reset_password_mail_task | |||
| @@ -311,13 +311,13 @@ class TenantService: | |||
| """Get tenant by account and add the role""" | |||
| tenant = account.current_tenant | |||
| if not tenant: | |||
| raise TenantNotFound("Tenant not found.") | |||
| raise TenantNotFoundError("Tenant not found.") | |||
| ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() | |||
| if ta: | |||
| tenant.role = ta.role | |||
| else: | |||
| raise TenantNotFound("Tenant not found for the account.") | |||
| raise TenantNotFoundError("Tenant not found for the account.") | |||
| return tenant | |||
| @staticmethod | |||
| @@ -614,8 +614,8 @@ class RegisterService: | |||
| "email": account.email, | |||
| "workspace_id": tenant.id, | |||
| } | |||
| expiryHours = dify_config.INVITE_EXPIRY_HOURS | |||
| redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) | |||
| expiry_hours = dify_config.INVITE_EXPIRY_HOURS | |||
| redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) | |||
| return token | |||
| @classmethod | |||
| @@ -1,7 +1,7 @@ | |||
| from services.errors.base import BaseServiceError | |||
| class AccountNotFound(BaseServiceError): | |||
| class AccountNotFoundError(BaseServiceError): | |||
| pass | |||
| @@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError): | |||
| pass | |||
| class TenantNotFound(BaseServiceError): | |||
| class TenantNotFoundError(BaseServiceError): | |||
| pass | |||
| @@ -6,7 +6,7 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| @@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| logging.info( | |||
| click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") | |||
| ) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -6,7 +6,7 @@ import click | |||
| from celery import shared_task | |||
| from configs import dify_config | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document | |||
| from services.feature_service import FeatureService | |||
| @@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): | |||
| indexing_runner.run(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -6,7 +6,7 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| @@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): | |||
| indexing_runner.run([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -6,7 +6,7 @@ import click | |||
| from celery import shared_task | |||
| from configs import dify_config | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| @@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): | |||
| indexing_runner.run(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -5,7 +5,7 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import DocumentIsPausedException, IndexingRunner | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document | |||
| @@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): | |||
| logging.info( | |||
| click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") | |||
| ) | |||
| except DocumentIsPausedException as ex: | |||
| except DocumentIsPausedError as ex: | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| except Exception: | |||
| pass | |||
| @@ -70,6 +70,7 @@ class MockTEIClass: | |||
| }, | |||
| } | |||
| @staticmethod | |||
| def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: | |||
| # Example response: | |||
| # [ | |||
| @@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch | |||
| class MockedHttp: | |||
| @staticmethod | |||
| def httpx_request( | |||
| method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs | |||
| ) -> httpx.Response: | |||
| @@ -13,7 +13,7 @@ from xinference_client.types import Embedding | |||
| class MockTcvectordbClass: | |||
| def VectorDBClient( | |||
| def mock_vector_db_client( | |||
| self, | |||
| url=None, | |||
| username="", | |||
| @@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" | |||
| @pytest.fixture | |||
| def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): | |||
| if MOCK: | |||
| monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) | |||
| monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) | |||
| monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) | |||
| monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) | |||
| monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) | |||
| @@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" | |||
| class MockedHttp: | |||
| @staticmethod | |||
| def httpx_request( | |||
| method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs | |||
| ) -> httpx.Response: | |||
| @@ -1,11 +1,11 @@ | |||
| import pytest | |||
| from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor | |||
| from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor | |||
| CODE_LANGUAGE = "unsupported_language" | |||
| def test_unsupported_with_code_template(): | |||
| with pytest.raises(CodeExecutionException) as e: | |||
| with pytest.raises(CodeExecutionError) as e: | |||
| CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) | |||
| assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" | |||