瀏覽代碼

chore: adopt StrEnum and auto() for some string-typed enums (#25129)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/1.9.0
Krito. 1 月之前
父節點
當前提交
a13d7987e0
No account linked to committer's email address
共有 68 個文件被更改,包括 562 次插入563 次删除
  1. 2
    2
      api/commands.py
  2. 2
    2
      api/configs/middleware/vdb/opensearch_config.py
  3. 5
    5
      api/constants/model_template.py
  4. 1
    1
      api/controllers/console/app/conversation.py
  5. 1
    1
      api/controllers/console/app/model_config.py
  6. 1
    1
      api/controllers/console/explore/parameter.py
  7. 1
    1
      api/controllers/mcp/mcp.py
  8. 1
    1
      api/controllers/service_api/app/app.py
  9. 1
    1
      api/controllers/web/app.py
  10. 15
    15
      api/core/agent/plugin_entities.py
  11. 2
    2
      api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
  12. 9
    9
      api/core/app/app_config/entities.py
  13. 6
    6
      api/core/app/entities/queue_entities.py
  14. 28
    28
      api/core/app/entities/task_entities.py
  15. 1
    1
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  16. 1
    1
      api/core/app/task_pipeline/message_cycle_manager.py
  17. 6
    6
      api/core/entities/agent_entities.py
  18. 4
    4
      api/core/entities/embedding_type.py
  19. 4
    4
      api/core/entities/model_entities.py
  20. 24
    24
      api/core/entities/parameter_entities.py
  21. 20
    20
      api/core/entities/provider_entities.py
  22. 4
    4
      api/core/extension/extensible.py
  23. 3
    3
      api/core/helper/model_provider_cache.py
  24. 3
    3
      api/core/helper/tool_parameter_cache.py
  25. 9
    9
      api/core/mcp/server/streamable_http.py
  26. 13
    13
      api/core/model_runtime/entities/message_entities.py
  27. 48
    48
      api/core/model_runtime/entities/model_entities.py
  28. 5
    5
      api/core/model_runtime/entities/provider_entities.py
  29. 1
    1
      api/core/model_runtime/model_providers/__base/text_embedding_model.py
  30. 2
    2
      api/core/model_runtime/utils/encoders.py
  31. 4
    4
      api/core/moderation/base.py
  32. 2
    2
      api/core/ops/aliyun_trace/entities/semconv.py
  33. 5
    5
      api/core/plugin/backwards_invocation/app.py
  34. 25
    25
      api/core/plugin/entities/parameters.py
  35. 14
    14
      api/core/plugin/entities/plugin.py
  36. 4
    4
      api/core/prompt/simple_prompt_transform.py
  37. 4
    4
      api/core/rag/datasource/vdb/field.py
  38. 2
    2
      api/core/rag/datasource/vdb/myscale/myscale_vector.py
  39. 2
    2
      api/core/rag/extractor/entity/datasource_type.py
  40. 7
    7
      api/core/rag/index_processor/constant/built_in_field.py
  41. 77
    78
      api/core/tools/entities/tool_entities.py
  42. 7
    7
      api/core/workflow/graph_engine/entities/runtime_route_state.py
  43. 8
    8
      api/core/workflow/nodes/agent/entities.py
  44. 4
    4
      api/core/workflow/nodes/answer/entities.py
  45. 2
    2
      api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
  46. 6
    6
      api/extensions/storage/clickzetta_volume/file_lifecycle.py
  47. 2
    2
      api/extensions/storage/clickzetta_volume/volume_permissions.py
  48. 23
    23
      api/libs/email_i18n.py
  49. 1
    1
      api/libs/helper.py
  50. 6
    6
      api/models/dataset.py
  51. 8
    8
      api/models/model.py
  52. 8
    8
      api/models/provider.py
  53. 5
    5
      api/models/workflow.py
  54. 4
    4
      api/services/advanced_prompt_template_service.py
  55. 9
    9
      api/services/app_generate_service.py
  56. 6
    6
      api/services/app_service.py
  57. 2
    2
      api/services/audio_service.py
  58. 1
    1
      api/services/dataset_service.py
  59. 1
    1
      api/services/message_service.py
  60. 20
    20
      api/services/metadata_service.py
  61. 1
    1
      api/services/plugin/plugin_migration.py
  62. 4
    4
      api/services/workflow/workflow_converter.py
  63. 3
    3
      api/services/workflow_service.py
  64. 30
    30
      api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py
  65. 12
    12
      api/tests/test_containers_integration_tests/services/test_metadata_service.py
  66. 8
    8
      api/tests/test_containers_integration_tests/services/test_workflow_service.py
  67. 10
    10
      api/tests/unit_tests/core/mcp/server/test_streamable_http.py
  68. 2
    2
      api/tests/unit_tests/services/workflow/test_workflow_converter.py

+ 2
- 2
api/commands.py 查看文件

@@ -477,12 +477,12 @@ def convert_to_agent_apps():
click.echo(f"Converting app: {app.id}")

try:
app.mode = AppMode.AGENT_CHAT.value
app.mode = AppMode.AGENT_CHAT
db.session.commit()

# update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
{Conversation.mode: AppMode.AGENT_CHAT}
)

db.session.commit()

+ 2
- 2
api/configs/middleware/vdb/opensearch_config.py 查看文件

@@ -1,4 +1,4 @@
import enum
from enum import Enum
from typing import Literal, Optional

from pydantic import Field, PositiveInt
@@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""

class AuthMethod(enum.StrEnum):
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""

+ 5
- 5
api/constants/model_template.py 查看文件

@@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
"mode": AppMode.WORKFLOW.value,
"mode": AppMode.WORKFLOW,
"enable_site": True,
"enable_api": True,
}
@@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# completion default mode
AppMode.COMPLETION: {
"app": {
"mode": AppMode.COMPLETION.value,
"mode": AppMode.COMPLETION,
"enable_site": True,
"enable_api": True,
},
@@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# chat default mode
AppMode.CHAT: {
"app": {
"mode": AppMode.CHAT.value,
"mode": AppMode.CHAT,
"enable_site": True,
"enable_api": True,
},
@@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"mode": AppMode.ADVANCED_CHAT,
"enable_site": True,
"enable_api": True,
},
@@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# agent-chat default mode
AppMode.AGENT_CHAT: {
"app": {
"mode": AppMode.AGENT_CHAT.value,
"mode": AppMode.AGENT_CHAT,
"enable_site": True,
"enable_api": True,
},

+ 1
- 1
api/controllers/console/app/conversation.py 查看文件

@@ -307,7 +307,7 @@ class ChatConversationApi(Resource):
.having(func.count(Message.id) >= args["message_count_gte"])
)

if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)

match args["sort_by"]:

+ 1
- 1
api/controllers/console/app/model_config.py 查看文件

@@ -74,7 +74,7 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()

+ 1
- 1
api/controllers/console/explore/parameter.py 查看文件

@@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource):
if app_model is None:
raise AppUnavailableError()

if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

+ 1
- 1
api/controllers/mcp/mcp.py 查看文件

@@ -150,7 +150,7 @@ class MCPAppApi(Resource):
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)

+ 1
- 1
api/controllers/service_api/app/app.py 查看文件

@@ -29,7 +29,7 @@ class AppParameterApi(Resource):

Returns the input form parameters and configuration for the application.
"""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

+ 1
- 1
api/controllers/web/app.py 查看文件

@@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

+ 15
- 15
api/core/agent/plugin_entities.py 查看文件

@@ -1,4 +1,4 @@
import enum
from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):


class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
class AgentStrategyParameterType(StrEnum):
"""
Keep all the types from PluginParameterType
"""

STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY

# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES

def as_normal_type(self):
return as_normal_type(self)
@@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
pass


class AgentFeature(enum.StrEnum):
class AgentFeature(StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""

+ 2
- 2
api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py 查看文件

@@ -70,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE

prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@@ -90,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")

if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"

+ 9
- 9
api/core/app/app_config/entities.py 查看文件

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field, field_validator
@@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""

class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""

SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()

@classmethod
def value_of(cls, value: str):
@@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""

class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""

SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()

@classmethod
def value_of(cls, value: str):
@@ -293,12 +293,12 @@ class AppConfig(BaseModel):
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""

ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"


+ 6
- 6
api/core/app/entities/queue_entities.py 查看文件

@@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional

from pydantic import BaseModel
@@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""

class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""

USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()

event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy

+ 28
- 28
api/core/app/entities/task_entities.py 查看文件

@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field
@@ -50,37 +50,37 @@ class WorkflowTaskState(TaskState):
answer: str = ""


class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""

PING = "ping"
ERROR = "error"
MESSAGE = "message"
MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
LOOP_STARTED = "loop_started"
LOOP_NEXT = "loop_next"
LOOP_COMPLETED = "loop_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
PING = auto()
ERROR = auto()
MESSAGE = auto()
MESSAGE_END = auto()
TTS_MESSAGE = auto()
TTS_MESSAGE_END = auto()
MESSAGE_FILE = auto()
MESSAGE_REPLACE = auto()
AGENT_THOUGHT = auto()
AGENT_MESSAGE = auto()
WORKFLOW_STARTED = auto()
WORKFLOW_FINISHED = auto()
NODE_STARTED = auto()
NODE_FINISHED = auto()
NODE_RETRY = auto()
PARALLEL_BRANCH_STARTED = auto()
PARALLEL_BRANCH_FINISHED = auto()
ITERATION_STARTED = auto()
ITERATION_NEXT = auto()
ITERATION_COMPLETED = auto()
LOOP_STARTED = auto()
LOOP_NEXT = auto()
LOOP_COMPLETED = auto()
TEXT_CHUNK = auto()
TEXT_REPLACE = auto()
AGENT_LOG = auto()


class StreamResponse(BaseModel):

+ 1
- 1
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py 查看文件

@@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(

+ 1
- 1
api/core/app/task_pipeline/message_cycle_manager.py 查看文件

@@ -92,7 +92,7 @@ class MessageCycleManager:
if not conversation:
return

if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return

+ 6
- 6
api/core/entities/agent_entities.py 查看文件

@@ -1,8 +1,8 @@
from enum import Enum
from enum import StrEnum, auto


class PlanningStrategy(Enum):
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"
class PlanningStrategy(StrEnum):
ROUTER = auto()
REACT_ROUTER = auto()
REACT = auto()
FUNCTION_CALL = auto()

+ 4
- 4
api/core/entities/embedding_type.py 查看文件

@@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum, auto


class EmbeddingInputType(Enum):
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""

DOCUMENT = "document"
QUERY = "query"
DOCUMENT = auto()
QUERY = auto()

+ 4
- 4
api/core/entities/model_entities.py 查看文件

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Optional

from pydantic import BaseModel, ConfigDict
@@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ProviderEntity


class ModelStatus(Enum):
class ModelStatus(StrEnum):
"""
Enum class for model status.
"""

ACTIVE = "active"
ACTIVE = auto()
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
DISABLED = auto()
CREDENTIAL_REMOVED = "credential-removed"



+ 24
- 24
api/core/entities/parameter_entities.py 查看文件

@@ -1,20 +1,20 @@
from enum import StrEnum
from enum import StrEnum, auto


class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SELECT = auto()
STRING = auto()
NUMBER = auto()
FILE = auto()
FILES = auto()
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
BOOLEAN = auto()
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
ANY = "any"
ANY = auto()

# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
@@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):

# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()


class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
ALL = auto()
CHAT = auto()
WORKFLOW = auto()
COMPLETION = auto()


class ModelSelectorScope(StrEnum):
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
RERANK = auto()
TTS = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
VISION = auto()


class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"
ALL = auto()
CUSTOM = auto()
BUILTIN = auto()
WORKFLOW = auto()

+ 20
- 20
api/core/entities/provider_entities.py 查看文件

@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum, auto
from typing import Optional, Union

from pydantic import BaseModel, ConfigDict, Field
@@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject


class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""

FREE = "free"
FREE = auto()
"""third-party free quota"""

TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""

@staticmethod
@@ -31,20 +31,20 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")


class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class QuotaUnit(StrEnum):
TIMES = auto()
TOKENS = auto()
CREDITS = auto()


class SystemConfigurationStatus(Enum):
class SystemConfigurationStatus(StrEnum):
"""
Enum class for system configuration status.
"""

ACTIVE = "active"
ACTIVE = auto()
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
UNSUPPORTED = auto()


class RestrictModel(BaseModel):
@@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel):
Base model class for common provider settings like credentials
"""

class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
class Type(StrEnum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT
TEXT_INPUT = CommonParameterType.TEXT_INPUT
SELECT = CommonParameterType.SELECT
BOOLEAN = CommonParameterType.BOOLEAN
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR

@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":

+ 4
- 4
api/core/extension/extensible.py 查看文件

@@ -1,8 +1,8 @@
import enum
import importlib.util
import json
import logging
import os
from enum import StrEnum, auto
from pathlib import Path
from typing import Any, Optional

@@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map
logger = logging.getLogger(__name__)


class ExtensionModule(enum.Enum):
MODERATION = "moderation"
EXTERNAL_DATA_TOOL = "external_data_tool"
class ExtensionModule(StrEnum):
MODERATION = auto()
EXTERNAL_DATA_TOOL = auto()


class ModuleExtension(BaseModel):

+ 3
- 3
api/core/helper/model_provider_cache.py 查看文件

@@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCacheType(StrEnum):
PROVIDER = "provider"
MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
@@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum):

class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

def get(self) -> Optional[dict]:
"""

+ 3
- 3
api/core/helper/tool_parameter_cache.py 查看文件

@@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolParameterCacheType(Enum):
class ToolParameterCacheType(StrEnum):
PARAMETER = "tool_parameter"


@@ -15,7 +15,7 @@ class ToolParameterCache:
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)


+ 9
- 9
api/core/mcp/server/streamable_http.py 查看文件

@@ -142,7 +142,7 @@ def handle_call_tool(
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
streaming=app.mode == AppMode.AGENT_CHAT,
)

answer = extract_answer_from_response(app, response)
@@ -157,7 +157,7 @@ def build_parameter_schema(
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)

if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
return {
"type": "object",
"properties": parameters,
@@ -175,9 +175,9 @@ def build_parameter_schema(

def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
@@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))

+ 13
- 13
api/core/model_runtime/entities/message_entities.py 查看文件

@@ -1,20 +1,20 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Optional, Union

from pydantic import BaseModel, Field, field_serializer, field_validator


class PromptMessageRole(Enum):
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message.
"""

SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()

@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
@@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
Enum class for prompt message content type.
"""

TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()


class PromptMessageContent(ABC, BaseModel):
@@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""

class DETAIL(StrEnum):
LOW = "low"
HIGH = "high"
LOW = auto()
HIGH = auto()

type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW

+ 48
- 48
api/core/model_runtime/entities/model_entities.py 查看文件

@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, model_validator
@@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject


class ModelType(Enum):
class ModelType(StrEnum):
"""
Enum class for model type.
"""

LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()

@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -26,17 +26,17 @@ class ModelType(Enum):

:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM.value}:
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK.value}:
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS.value}:
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION.value:
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
@@ -63,7 +63,7 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}")


class FetchFrom(Enum):
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
@@ -72,7 +72,7 @@ class FetchFrom(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"


class ModelFeature(Enum):
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
@@ -80,11 +80,11 @@ class ModelFeature(Enum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"


@@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
Enum class for parameter template variable.
"""

TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()

@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
@@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
raise ValueError(f"invalid parameter name {value}")


class ParameterType(Enum):
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""

FLOAT = "float"
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()


class ModelPropertyKey(Enum):
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""

MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()


class ProviderModel(BaseModel):
@@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
pass


class PriceType(Enum):
class PriceType(StrEnum):
"""
Enum class for price type.
"""

INPUT = "input"
OUTPUT = "output"
INPUT = auto()
OUTPUT = auto()


class PriceInfo(BaseModel):

+ 5
- 5
api/core/model_runtime/entities/provider_entities.py 查看文件

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import Enum, StrEnum, auto
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"


class FormType(Enum):
class FormType(StrEnum):
"""
Enum class for form type.
"""

TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
RADIO = "radio"
SWITCH = "switch"
SELECT = auto()
RADIO = auto()
SWITCH = auto()


class FormShowOnObject(BaseModel):

+ 1
- 1
api/core/model_runtime/model_providers/__base/text_embedding_model.py 查看文件

@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
model=model,
credentials=credentials,
texts=texts,
input_type=input_type.value,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)

+ 2
- 2
api/core/model_runtime/utils/encoders.py 查看文件

@@ -18,7 +18,7 @@ from pydantic_core import Url
from pydantic_extra_types.color import Color


def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
return model.model_dump(mode=mode, **kwargs)


@@ -100,7 +100,7 @@ def jsonable_encoder(
exclude_none: bool = False,
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
):
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:

+ 4
- 4
api/core/moderation/base.py 查看文件

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from enum import StrEnum, auto
from typing import Optional

from pydantic import BaseModel, Field
@@ -7,9 +7,9 @@ from pydantic import BaseModel, Field
from core.extension.extensible import Extensible, ExtensionModule


class ModerationAction(Enum):
DIRECT_OUTPUT = "direct_output"
OVERRIDDEN = "overridden"
class ModerationAction(StrEnum):
DIRECT_OUTPUT = auto()
OVERRIDDEN = auto()


class ModerationInputsResult(BaseModel):

+ 2
- 2
api/core/ops/aliyun_trace/entities/semconv.py 查看文件

@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum

# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
@@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"


class GenAISpanKind(Enum):
class GenAISpanKind(StrEnum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"

+ 5
- 5
api/core/plugin/backwards_invocation/app.py 查看文件

@@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
app = cls._get_app(app_id, tenant_id)

"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
@@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):

conversation_id = conversation_id or ""

if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")

@@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
if app.mode == AppMode.ADVANCED_CHAT.value:
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
@@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.AGENT_CHAT.value:
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
@@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT.value:
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,

+ 25
- 25
api/core/plugin/entities/parameters.py 查看文件

@@ -1,5 +1,5 @@
import enum
import json
from enum import StrEnum, auto
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator
@@ -25,44 +25,44 @@ class PluginParameterOption(BaseModel):
return value


class PluginParameterType(enum.StrEnum):
class PluginParameterType(StrEnum):
"""
all available parameter types
"""

STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT

# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES

# MCP object and array type parameters
ARRAY = CommonParameterType.ARRAY.value
OBJECT = CommonParameterType.OBJECT.value
ARRAY = CommonParameterType.ARRAY
OBJECT = CommonParameterType.OBJECT


class MCPServerParameterType(enum.StrEnum):
class MCPServerParameterType(StrEnum):
"""
MCP server got complex parameter types
"""

ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()


class PluginParameterAutoGenerate(BaseModel):
class Type(enum.StrEnum):
PROMPT_INSTRUCTION = "prompt_instruction"
class Type(StrEnum):
PROMPT_INSTRUCTION = auto()

type: Type

@@ -93,7 +93,7 @@ class PluginParameter(BaseModel):
return v


def as_normal_type(typ: enum.StrEnum):
def as_normal_type(typ: StrEnum):
if typ.value in {
PluginParameterType.SECRET_INPUT,
PluginParameterType.SELECT,
@@ -102,7 +102,7 @@ def as_normal_type(typ: enum.StrEnum):
return typ.value


def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
def cast_parameter_value(typ: StrEnum, value: Any, /):
try:
match typ.value:
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
@@ -190,7 +190,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")


def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
"""
init frontend parameter by rule
"""

+ 14
- 14
api/core/plugin/entities/plugin.py 查看文件

@@ -1,7 +1,7 @@
import datetime
import enum
import re
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Optional

from packaging.version import InvalidVersion, Version
@@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity


class PluginInstallationSource(enum.StrEnum):
Github = "github"
Marketplace = "marketplace"
Package = "package"
Remote = "remote"
class PluginInstallationSource(StrEnum):
Github = auto()
Marketplace = auto()
Package = auto()
Remote = auto()


class PluginResourceRequirements(BaseModel):
@@ -58,10 +58,10 @@ class PluginResourceRequirements(BaseModel):
permission: Optional[Permission] = Field(default=None)


class PluginCategory(enum.StrEnum):
Tool = "tool"
Model = "model"
Extension = "extension"
class PluginCategory(StrEnum):
Tool = auto()
Model = auto()
Extension = auto()
AgentStrategy = "agent-strategy"


@@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID):


class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value
Marketplace = PluginInstallationSource.Marketplace.value
Package = PluginInstallationSource.Package.value
class Type(StrEnum):
Github = PluginInstallationSource.Github
Marketplace = PluginInstallationSource.Marketplace
Package = PluginInstallationSource.Package

class Github(BaseModel):
repo: str

+ 4
- 4
api/core/prompt/simple_prompt_transform.py 查看文件

@@ -1,7 +1,7 @@
import enum
import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.app_config.entities import PromptTemplateEntity
@@ -25,9 +25,9 @@ if TYPE_CHECKING:
from core.file.models import File


class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
class ModelMode(StrEnum):
COMPLETION = auto()
CHAT = auto()


prompt_file_contents: dict[str, Any] = {}

+ 4
- 4
api/core/rag/datasource/vdb/field.py 查看文件

@@ -1,13 +1,13 @@
from enum import Enum
from enum import StrEnum, auto


class Field(Enum):
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
SPARSE_VECTOR = auto()
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

+ 2
- 2
api/core/rag/datasource/vdb/myscale/myscale_vector.py 查看文件

@@ -1,7 +1,7 @@
import json
import logging
import uuid
from enum import Enum
from enum import StrEnum
from typing import Any

from clickhouse_connect import get_client
@@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel):
fts_params: str


class SortOrder(Enum):
class SortOrder(StrEnum):
ASC = "ASC"
DESC = "DESC"


+ 2
- 2
api/core/rag/extractor/entity/datasource_type.py 查看文件

@@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class DatasourceType(Enum):
class DatasourceType(StrEnum):
FILE = "upload_file"
NOTION = "notion_import"
WEBSITE = "website_crawl"

+ 7
- 7
api/core/rag/index_processor/constant/built_in_field.py 查看文件

@@ -1,15 +1,15 @@
from enum import Enum, StrEnum
from enum import StrEnum, auto


class BuiltInField(StrEnum):
document_name = "document_name"
uploader = "uploader"
upload_date = "upload_date"
last_update_date = "last_update_date"
source = "source"
document_name = auto()
uploader = auto()
upload_date = auto()
last_update_date = auto()
source = auto()


class MetadataDataSource(Enum):
class MetadataDataSource(StrEnum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"

+ 77
- 78
api/core/tools/entities/tool_entities.py 查看文件

@@ -1,8 +1,7 @@
import base64
import contextlib
import enum
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
@@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY


class ToolLabelEnum(Enum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
class ToolProviderType(enum.StrEnum):
class ToolLabelEnum(StrEnum):
SEARCH = auto()
IMAGE = auto()
VIDEOS = auto()
WEATHER = auto()
FINANCE = auto()
DESIGN = auto()
TRAVEL = auto()
SOCIAL = auto()
NEWS = auto()
MEDICAL = auto()
PRODUCTIVITY = auto()
EDUCATION = auto()
BUSINESS = auto()
ENTERTAINMENT = auto()
UTILITIES = auto()
OTHER = auto()
class ToolProviderType(StrEnum):
"""
Enum class for tool provider
"""

PLUGIN = "plugin"
PLUGIN = auto()
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
APP = "app"
WORKFLOW = auto()
API = auto()
APP = auto()
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = "mcp"
MCP = auto()

@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
@@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
raise ValueError(f"invalid mode value {value}")


class ApiProviderSchemaType(Enum):
class ApiProviderSchemaType(StrEnum):
"""
Enum class for api provider schema type.
"""

OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
OPENAPI = auto()
SWAGGER = auto()
OPENAI_PLUGIN = auto()
OPENAI_ACTIONS = auto()

@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
@@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
raise ValueError(f"invalid mode value {value}")


class ApiProviderAuthType(Enum):
class ApiProviderAuthType(StrEnum):
"""
Enum class for api provider auth type.
"""

NONE = "none"
API_KEY_HEADER = "api_key_header"
API_KEY_QUERY = "api_key_query"
NONE = auto()
API_KEY_HEADER = auto()
API_KEY_QUERY = auto()

@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
@@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
return value

class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
class LogStatus(StrEnum):
START = auto()
ERROR = auto()
SUCCESS = auto()

id: str
label: str = Field(..., description="The label of the log")
@@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
RETRIEVER_RESOURCES = "retriever_resources"
class MessageType(StrEnum):
TEXT = auto()
IMAGE = auto()
LINK = auto()
BLOB = auto()
JSON = auto()
IMAGE_LINK = auto()
BINARY_LINK = auto()
VARIABLE = auto()
FILE = auto()
LOG = auto()
BLOB_CHUNK = auto()
RETRIEVER_RESOURCES = auto()

type: MessageType = MessageType.TEXT
"""
@@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
Overrides type
"""

class ToolParameterType(enum.StrEnum):
class ToolParameterType(StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""

STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
ANY = PluginParameterType.ANY.value
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
APP_SELECTOR = PluginParameterType.APP_SELECTOR
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
ANY = PluginParameterType.ANY
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT

# MCP object and array type parameters
ARRAY = MCPServerParameterType.ARRAY.value
OBJECT = MCPServerParameterType.OBJECT.value
ARRAY = MCPServerParameterType.ARRAY
OBJECT = MCPServerParameterType.OBJECT

# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES

def as_normal_type(self):
return as_normal_type(self)
@@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
def cast_value(self, value: Any):
return cast_parameter_value(self, value)

class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
class ToolParameterForm(StrEnum):
SCHEMA = auto() # should be set while adding tool
FORM = auto() # should be set before invoking tool
LLM = auto() # will be set by LLM

type: ToolParameterType = Field(..., description="The type of the parameter")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
@@ -446,14 +445,14 @@ class ToolLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")


class ToolInvokeFrom(Enum):
class ToolInvokeFrom(StrEnum):
"""
Enum class for tool invoke
"""

WORKFLOW = "workflow"
AGENT = "agent"
PLUGIN = "plugin"
WORKFLOW = auto()
AGENT = auto()
PLUGIN = auto()


class ToolSelector(BaseModel):
@@ -478,9 +477,9 @@ class ToolSelector(BaseModel):
return self.model_dump()


class CredentialType(enum.StrEnum):
class CredentialType(StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
OAUTH2 = auto()

def get_name(self):
if self == CredentialType.API_KEY:

+ 7
- 7
api/core/workflow/graph_engine/entities/runtime_route_state.py 查看文件

@@ -1,6 +1,6 @@
import uuid
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from typing import Optional

from pydantic import BaseModel, Field
@@ -11,12 +11,12 @@ from libs.datetime_utils import naive_utc_now


class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
class Status(StrEnum):
RUNNING = auto()
SUCCESS = auto()
FAILED = auto()
PAUSED = auto()
EXCEPTION = auto()

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""

+ 8
- 8
api/core/workflow/nodes/agent/entities.py 查看文件

@@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union

from pydantic import BaseModel
@@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]


class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()


class AgentOldVersionModelFeatures(StrEnum):
@@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

+ 4
- 4
api/core/workflow/nodes/answer/entities.py 查看文件

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto

from pydantic import BaseModel, Field

@@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
Generate Route Chunk.
"""

class ChunkType(Enum):
VAR = "var"
TEXT = "text"
class ChunkType(StrEnum):
VAR = auto()
TEXT = auto()

type: ChunkType = Field(..., description="generate route chunk type")


+ 2
- 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py 查看文件

@@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":

+ 6
- 6
api/extensions/storage/clickzetta_volume/file_lifecycle.py 查看文件

@@ -9,19 +9,19 @@ import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional

logger = logging.getLogger(__name__)


class FileStatus(Enum):
class FileStatus(StrEnum):
"""File status enumeration"""

ACTIVE = "active" # Active status
ARCHIVED = "archived" # Archived
DELETED = "deleted" # Deleted (soft delete)
BACKUP = "backup" # Backup file
ACTIVE = auto() # Active status
ARCHIVED = auto() # Archived
DELETED = auto() # Deleted (soft delete)
BACKUP = auto() # Backup file


@dataclass

+ 2
- 2
api/extensions/storage/clickzetta_volume/volume_permissions.py 查看文件

@@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen
"""

import logging
from enum import Enum
from enum import StrEnum
from typing import Optional

logger = logging.getLogger(__name__)


class VolumePermission(Enum):
class VolumePermission(StrEnum):
"""Volume permission type enumeration"""

READ = "SELECT" # Corresponds to ClickZetta's SELECT permission

+ 23
- 23
api/libs/email_i18n.py 查看文件

@@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic.
"""

from dataclasses import dataclass
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional, Protocol

from flask import render_template
@@ -17,30 +17,30 @@ from extensions.ext_mail import mail
from services.feature_service import BrandingModel, FeatureService


class EmailType(Enum):
class EmailType(StrEnum):
"""Enumeration of supported email types."""

RESET_PASSWORD = "reset_password"
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = "reset_password_when_account_not_exist"
INVITE_MEMBER = "invite_member"
EMAIL_CODE_LOGIN = "email_code_login"
CHANGE_EMAIL_OLD = "change_email_old"
CHANGE_EMAIL_NEW = "change_email_new"
CHANGE_EMAIL_COMPLETED = "change_email_completed"
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
ENTERPRISE_CUSTOM = "enterprise_custom"
QUEUE_MONITOR_ALERT = "queue_monitor_alert"
DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
EMAIL_REGISTER = "email_register"
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = "email_register_when_account_exist"
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = "reset_password_when_account_not_exist_no_register"
class EmailLanguage(Enum):
RESET_PASSWORD = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
INVITE_MEMBER = auto()
EMAIL_CODE_LOGIN = auto()
CHANGE_EMAIL_OLD = auto()
CHANGE_EMAIL_NEW = auto()
CHANGE_EMAIL_COMPLETED = auto()
OWNER_TRANSFER_CONFIRM = auto()
OWNER_TRANSFER_OLD_NOTIFY = auto()
OWNER_TRANSFER_NEW_NOTIFY = auto()
ACCOUNT_DELETION_SUCCESS = auto()
ACCOUNT_DELETION_VERIFICATION = auto()
ENTERPRISE_CUSTOM = auto()
QUEUE_MONITOR_ALERT = auto()
DOCUMENT_CLEAN_NOTIFY = auto()
EMAIL_REGISTER = auto()
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
class EmailLanguage(StrEnum):
"""Supported email languages with fallback handling."""

EN_US = "en-US"

+ 1
- 1
api/libs/helper.py 查看文件

@@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
if isinstance(obj, dict) and "app" in obj:
obj = obj["app"]

if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return None


+ 6
- 6
api/models/dataset.py 查看文件

@@ -224,35 +224,35 @@ class Dataset(Base):
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.document_name.value,
"name": BuiltInField.document_name,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.uploader.value,
"name": BuiltInField.uploader,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.upload_date.value,
"name": BuiltInField.upload_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.last_update_date.value,
"name": BuiltInField.last_update_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.source.value,
"name": BuiltInField.source,
"type": "string",
}
)
@@ -544,7 +544,7 @@ class Document(Base):
"id": "built-in",
"name": BuiltInField.source,
"type": "string",
"value": MetadataDataSource[self.data_source_type].value,
"value": MetadataDataSource[self.data_source_type],
}
)
return built_in_fields

+ 8
- 8
api/models/model.py 查看文件

@@ -3,7 +3,7 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from core.plugin.entities.plugin import GenericProviderID
@@ -62,9 +62,9 @@ class AppMode(StrEnum):
raise ValueError(f"invalid mode value {value}")


class IconType(Enum):
IMAGE = "image"
EMOJI = "emoji"
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()


class App(Base):
@@ -149,15 +149,15 @@ class App(Base):
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
"strategy", ""
) in {"function_call", "react"}:
self.mode = AppMode.AGENT_CHAT.value
self.mode = AppMode.AGENT_CHAT
db.session.commit()
return True
return False

@property
def mode_compatible_with_agent(self) -> str:
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value
if self.mode == AppMode.CHAT and self.is_agent:
return AppMode.AGENT_CHAT

return str(self.mode)

@@ -713,7 +713,7 @@ class Conversation(Base):
model_config = {}
app_model_config: Optional[AppModelConfig] = None

if self.mode == AppMode.ADVANCED_CHAT.value:
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs

+ 8
- 8
api/models/provider.py 查看文件

@@ -1,5 +1,5 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from functools import cached_property
from typing import Optional

@@ -12,9 +12,9 @@ from .engine import db
from .types import StringUUID


class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
class ProviderType(StrEnum):
CUSTOM = auto()
SYSTEM = auto()

@staticmethod
def value_of(value: str) -> "ProviderType":
@@ -24,14 +24,14 @@ class ProviderType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")


class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""

FREE = "free"
FREE = auto()
"""third-party free quota"""

TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""

@staticmethod

+ 5
- 5
api/models/workflow.py 查看文件

@@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4

@@ -41,13 +41,13 @@ from .types import EnumText, StringUUID
logger = logging.getLogger(__name__)


class WorkflowType(Enum):
class WorkflowType(StrEnum):
"""
Workflow Type Enum
"""

WORKFLOW = "workflow"
CHAT = "chat"
WORKFLOW = auto()
CHAT = auto()

@classmethod
def value_of(cls, value: str) -> "WorkflowType":
@@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base):
return extras


class WorkflowAppLogCreatedFrom(Enum):
class WorkflowAppLogCreatedFrom(StrEnum):
"""
Workflow App Log Created From Enum
"""

+ 4
- 4
api/services/advanced_prompt_template_service.py 查看文件

@@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)

if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),

+ 9
- 9
api/services/app_generate_service.py 查看文件

@@ -60,7 +60,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
@@ -69,7 +69,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
@@ -78,7 +78,7 @@ class AppGenerateService:
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
@@ -87,7 +87,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@@ -103,7 +103,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@@ -155,14 +155,14 @@ class AppGenerateService:

@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
@@ -174,14 +174,14 @@ class AppGenerateService:

@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(

+ 6
- 6
api/services/app_service.py 查看文件

@@ -40,15 +40,15 @@ class AppService:
filters = [App.tenant_id == tenant_id, App.is_universal == False]

if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW.value)
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION.value)
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT.value)
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
filters.append(App.mode == AppMode.AGENT_CHAT)

if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@@ -171,7 +171,7 @@ class AppService:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app

+ 2
- 2
api/services/audio_service.py 查看文件

@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
@@ -88,7 +88,7 @@ class AudioService:
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:

+ 1
- 1
api/services/dataset_service.py 查看文件

@@ -1004,7 +1004,7 @@ class DocumentService:
if dataset.built_in_field_enabled:
if document.doc_metadata:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = name
doc_metadata[BuiltInField.document_name] = name
document.doc_metadata = doc_metadata

document.name = name

+ 1
- 1
api/services/message_service.py 查看文件

@@ -229,7 +229,7 @@ class MessageService:

model_manager = ModelManager()

if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)

+ 20
- 20
api/services/metadata_service.py 查看文件

@@ -131,11 +131,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name.value, "type": "string"},
{"name": BuiltInField.uploader.value, "type": "string"},
{"name": BuiltInField.upload_date.value, "type": "time"},
{"name": BuiltInField.last_update_date.value, "type": "time"},
{"name": BuiltInField.source.value, "type": "string"},
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
]

@staticmethod
@@ -153,11 +153,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
@@ -183,11 +183,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
doc_metadata.pop(BuiltInField.last_update_date.value, None)
doc_metadata.pop(BuiltInField.source.value, None)
doc_metadata.pop(BuiltInField.document_name, None)
doc_metadata.pop(BuiltInField.uploader, None)
doc_metadata.pop(BuiltInField.upload_date, None)
doc_metadata.pop(BuiltInField.last_update_date, None)
doc_metadata.pop(BuiltInField.source, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
@@ -211,11 +211,11 @@ class MetadataService:
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()

+ 1
- 1
api/services/plugin/plugin_migration.py 查看文件

@@ -256,7 +256,7 @@ class PluginMigration:
return []

agent_app_model_config_ids = [
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]

rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()

+ 4
- 4
api/services/workflow/workflow_converter.py 查看文件

@@ -65,7 +65,7 @@ class WorkflowConverter:
new_app = App()
new_app.tenant_id = app_model.tenant_id
new_app.name = name or app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
new_app.icon_type = icon_type or app_model.icon_type
new_app.icon = icon or app_model.icon
new_app.icon_background = icon_background or app_model.icon_background
@@ -203,7 +203,7 @@ class WorkflowConverter:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@@ -279,7 +279,7 @@ class WorkflowConverter:
"app_id": app_model.id,
"tool_variable": tool_variable,
"inputs": inputs,
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
},
}

@@ -618,7 +618,7 @@ class WorkflowConverter:
:param app_model: App instance
:return: AppMode
"""
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return AppMode.WORKFLOW
else:
return AppMode.ADVANCED_CHAT

+ 3
- 3
api/services/workflow_service.py 查看文件

@@ -828,7 +828,7 @@ class WorkflowService:
# chatbot convert to workflow mode
workflow_converter = WorkflowConverter()

if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")

# convert to workflow
@@ -844,11 +844,11 @@ class WorkflowService:
return new_app

def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)

+ 30
- 30
api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py 查看文件

@@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService:

# Test data for Baichuan model
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
@@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService:

# Test data for common model
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService:

for model_name in test_cases:
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": model_name,
"has_context": "true",
@@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")

# Assert: Verify the expected outcomes
assert result is not None
@@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")

# Assert: Verify empty dict is returned
assert result == {}
@@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")

# Assert: Verify the expected outcomes
assert result is not None
@@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")

# Assert: Verify the expected outcomes
assert result is not None
@@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")

# Assert: Verify empty dict is returned
assert result == {}
@@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
model_modes = ["completion", "chat"]

for app_mode in app_modes:
@@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()

# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
model_modes = ["completion", "chat"]

for app_mode in app_modes:
@@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService:
# Test edge cases
edge_cases = [
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
{"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "",
@@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService:

# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService:

# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
@@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService:
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService:
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",

+ 12
- 12
api/tests/test_containers_integration_tests/services/test_metadata_service.py 查看文件

@@ -255,7 +255,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id

# Try to create metadata with built-in field name
built_in_field_name = BuiltInField.document_name.value
built_in_field_name = BuiltInField.document_name
metadata_args = MetadataArgs(type="string", name=built_in_field_name)

# Act & Assert: Verify proper error handling
@@ -375,7 +375,7 @@ class TestMetadataService:
metadata = MetadataService.create_metadata(dataset.id, metadata_args)

# Try to update with built-in field name
built_in_field_name = BuiltInField.document_name.value
built_in_field_name = BuiltInField.document_name

with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
@@ -540,11 +540,11 @@ class TestMetadataService:
field_names = [field["name"] for field in result]
field_types = [field["type"] for field in result]

assert BuiltInField.document_name.value in field_names
assert BuiltInField.uploader.value in field_names
assert BuiltInField.upload_date.value in field_names
assert BuiltInField.last_update_date.value in field_names
assert BuiltInField.source.value in field_names
assert BuiltInField.document_name in field_names
assert BuiltInField.uploader in field_names
assert BuiltInField.upload_date in field_names
assert BuiltInField.last_update_date in field_names
assert BuiltInField.source in field_names

# Verify field types
assert "string" in field_types
@@ -682,11 +682,11 @@ class TestMetadataService:

# Set document metadata with built-in fields
document.doc_metadata = {
BuiltInField.document_name.value: document.name,
BuiltInField.uploader.value: "test_uploader",
BuiltInField.upload_date.value: 1234567890.0,
BuiltInField.last_update_date.value: 1234567890.0,
BuiltInField.source.value: "test_source",
BuiltInField.document_name: document.name,
BuiltInField.uploader: "test_uploader",
BuiltInField.upload_date: 1234567890.0,
BuiltInField.last_update_date: 1234567890.0,
BuiltInField.source: "test_source",
}
db.session.add(document)
db.session.commit()

+ 8
- 8
api/tests/test_containers_integration_tests/services/test_workflow_service.py 查看文件

@@ -96,7 +96,7 @@ class TestWorkflowService:
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
@@ -883,7 +883,7 @@ class TestWorkflowService:

# Create chat mode app
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT

# Create app model config (required for conversion)
from models.model import AppModelConfig
@@ -926,7 +926,7 @@ class TestWorkflowService:

# Assert
assert result is not None
assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
assert result.name == conversion_args["name"]
assert result.icon == conversion_args["icon"]
assert result.icon_type == conversion_args["icon_type"]
@@ -945,7 +945,7 @@ class TestWorkflowService:

# Create completion mode app
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.COMPLETION.value
app.mode = AppMode.COMPLETION

# Create app model config (required for conversion)
from models.model import AppModelConfig
@@ -988,7 +988,7 @@ class TestWorkflowService:

# Assert
assert result is not None
assert result.mode == AppMode.WORKFLOW.value
assert result.mode == AppMode.WORKFLOW
assert result.name == conversion_args["name"]
assert result.icon == conversion_args["icon"]
assert result.icon_type == conversion_args["icon_type"]
@@ -1007,7 +1007,7 @@ class TestWorkflowService:

# Create workflow mode app (already in workflow mode)
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW

from extensions.ext_database import db

@@ -1030,7 +1030,7 @@ class TestWorkflowService:
# Arrange
fake = Faker()
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.ADVANCED_CHAT.value
app.mode = AppMode.ADVANCED_CHAT

from extensions.ext_database import db

@@ -1061,7 +1061,7 @@ class TestWorkflowService:
# Arrange
fake = Faker()
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW

from extensions.ext_database import db


+ 10
- 10
api/tests/unit_tests/core/mcp/server/test_streamable_http.py 查看文件

@@ -29,7 +29,7 @@ class TestHandleMCPRequest:
"""Setup test fixtures"""
self.app = Mock(spec=App)
self.app.name = "test_app"
self.app.mode = AppMode.CHAT.value
self.app.mode = AppMode.CHAT

self.mcp_server = Mock(spec=AppMCPServer)
self.mcp_server.description = "Test server"
@@ -196,7 +196,7 @@ class TestIndividualHandlers:
def test_handle_list_tools(self):
"""Test list tools handler"""
app_name = "test_app"
app_mode = AppMode.CHAT.value
app_mode = AppMode.CHAT
description = "Test server"
parameters_dict: dict[str, str] = {}
user_input_form: list[VariableEntity] = []
@@ -212,7 +212,7 @@ class TestIndividualHandlers:
def test_handle_call_tool(self, mock_app_generate):
"""Test call tool handler"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT

# Create mock request
mock_request = Mock()
@@ -252,7 +252,7 @@ class TestUtilityFunctions:

def test_build_parameter_schema_chat_mode(self):
"""Test building parameter schema for chat mode"""
app_mode = AppMode.CHAT.value
app_mode = AppMode.CHAT
parameters_dict: dict[str, str] = {"name": "Enter your name"}

user_input_form = [
@@ -275,7 +275,7 @@ class TestUtilityFunctions:

def test_build_parameter_schema_workflow_mode(self):
"""Test building parameter schema for workflow mode"""
app_mode = AppMode.WORKFLOW.value
app_mode = AppMode.WORKFLOW
parameters_dict: dict[str, str] = {"input_text": "Enter text"}

user_input_form = [
@@ -298,7 +298,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_chat_mode(self):
"""Test preparing tool arguments for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT

arguments = {"query": "test question", "name": "John"}

@@ -312,7 +312,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_workflow_mode(self):
"""Test preparing tool arguments for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW

arguments = {"input_text": "test input"}

@@ -324,7 +324,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_completion_mode(self):
"""Test preparing tool arguments for completion mode"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value
app.mode = AppMode.COMPLETION

arguments = {"name": "John"}

@@ -336,7 +336,7 @@ class TestUtilityFunctions:
def test_extract_answer_from_mapping_response_chat(self):
"""Test extracting answer from mapping response for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT

response = {"answer": "test answer", "other": "data"}

@@ -347,7 +347,7 @@ class TestUtilityFunctions:
def test_extract_answer_from_mapping_response_workflow(self):
"""Test extracting answer from mapping response for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW

response = {"data": {"outputs": {"result": "test result"}}}


+ 2
- 2
api/tests/unit_tests/services/workflow/test_workflow_converter.py 查看文件

@@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT.value
app_model.mode = AppMode.CHAT

api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
@@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW.value
app_model.mode = AppMode.WORKFLOW

api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(

Loading…
取消
儲存