浏览代码

Fix basedpyright type errors (#25435)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.0
-LAN- 1 个月前
父节点
当前提交
08dd3f7b50
没有帐户链接到提交者的电子邮件
共有 100 个文件被更改,包括 839 次插入489 次删除
  1. 16
    2
      api/commands.py
  2. 6
    6
      api/constants/__init__.py
  3. 0
    1
      api/contexts/__init__.py
  4. 54
    46
      api/controllers/console/__init__.py
  5. 7
    6
      api/controllers/console/apikey.py
  6. 23
    7
      api/controllers/console/app/app.py
  7. 2
    2
      api/controllers/console/app/audio.py
  8. 14
    14
      api/controllers/console/app/completion.py
  9. 5
    1
      api/controllers/console/app/conversation.py
  10. 9
    4
      api/controllers/console/app/message.py
  11. 5
    1
      api/controllers/console/app/site.py
  12. 6
    6
      api/controllers/console/app/statistic.py
  13. 3
    3
      api/controllers/console/app/workflow_statistic.py
  14. 4
    1
      api/controllers/console/auth/oauth.py
  15. 10
    1
      api/controllers/console/explore/completion.py
  16. 12
    1
      api/controllers/console/explore/conversation.py
  17. 10
    3
      api/controllers/console/explore/installed_app.py
  18. 10
    1
      api/controllers/console/explore/message.py
  19. 4
    4
      api/controllers/console/explore/recommended_app.py
  20. 8
    1
      api/controllers/console/explore/saved_message.py
  21. 3
    0
      api/controllers/console/files.py
  22. 3
    3
      api/controllers/console/version.py
  23. 32
    0
      api/controllers/console/workspace/account.py
  24. 49
    10
      api/controllers/console/workspace/members.py
  25. 37
    0
      api/controllers/console/workspace/model_providers.py
  26. 22
    2
      api/controllers/console/workspace/workspace.py
  27. 1
    1
      api/controllers/files/__init__.py
  28. 3
    3
      api/controllers/inner_api/__init__.py
  29. 15
    15
      api/controllers/inner_api/plugin/plugin.py
  30. 5
    5
      api/controllers/inner_api/plugin/wraps.py
  31. 1
    1
      api/controllers/mcp/__init__.py
  32. 22
    4
      api/controllers/service_api/__init__.py
  33. 2
    1
      api/controllers/service_api/app/conversation.py
  34. 6
    0
      api/controllers/service_api/dataset/document.py
  35. 2
    2
      api/controllers/service_api/wraps.py
  36. 14
    14
      api/controllers/web/__init__.py
  37. 0
    1
      api/core/__init__.py
  38. 2
    0
      api/core/agent/cot_agent_runner.py
  39. 1
    0
      api/core/agent/fc_agent_runner.py
  40. 9
    2
      api/core/app/app_config/common/sensitive_word_avoidance/manager.py
  41. 7
    3
      api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
  42. 6
    6
      api/core/app/apps/advanced_chat/generate_response_converter.py
  43. 12
    12
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  44. 19
    15
      api/core/app/apps/agent_chat/app_config_manager.py
  45. 7
    4
      api/core/app/apps/agent_chat/generate_response_converter.py
  46. 1
    0
      api/core/app/apps/base_app_queue_manager.py
  47. 7
    4
      api/core/app/apps/chat/generate_response_converter.py
  48. 2
    0
      api/core/app/apps/completion/app_generator.py
  49. 9
    4
      api/core/app/apps/completion/generate_response_converter.py
  50. 5
    5
      api/core/app/apps/workflow/generate_response_converter.py
  51. 5
    5
      api/core/app/apps/workflow/generate_task_pipeline.py
  52. 3
    3
      api/core/app/entities/app_invoke_entities.py
  53. 0
    7
      api/core/app/entities/task_entities.py
  54. 3
    0
      api/core/app/features/annotation_reply/annotation_reply.py
  55. 2
    0
      api/core/app/features/rate_limiting/__init__.py
  56. 1
    1
      api/core/app/features/rate_limiting/rate_limit.py
  57. 11
    11
      api/core/app/task_pipeline/based_generate_task_pipeline.py
  58. 11
    11
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  59. 3
    3
      api/core/base/tts/app_generator_tts_publisher.py
  60. 7
    1
      api/core/entities/provider_configuration.py
  61. 3
    3
      api/core/file/file_manager.py
  62. 8
    0
      api/core/file/models.py
  63. 7
    7
      api/core/helper/ssrf_proxy.py
  64. 6
    1
      api/core/indexing_runner.py
  65. 9
    3
      api/core/llm_generator/llm_generator.py
  66. 6
    8
      api/core/llm_generator/output_parser/structured_output.py
  67. 4
    4
      api/core/mcp/client/sse_client.py
  68. 14
    14
      api/core/mcp/server/streamable_http.py
  69. 6
    6
      api/core/mcp/session/base_session.py
  70. 1
    1
      api/core/model_runtime/model_providers/__base/large_language_model.py
  71. 1
    4
      api/core/plugin/entities/parameters.py
  72. 3
    1
      api/core/plugin/utils/chunk_merger.py
  73. 26
    6
      api/core/prompt/simple_prompt_transform.py
  74. 24
    11
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  75. 2
    2
      api/core/repositories/celery_workflow_node_execution_repository.py
  76. 1
    1
      api/core/variables/segment_group.py
  77. 12
    12
      api/core/variables/segments.py
  78. 2
    2
      api/core/workflow/errors.py
  79. 2
    2
      api/core/workflow/nodes/list_operator/node.py
  80. 2
    1
      api/core/workflow/nodes/llm/node.py
  81. 2
    2
      api/factories/file_factory.py
  82. 4
    1
      api/fields/_value_type_serializer.py
  83. 11
    3
      api/libs/external_api.py
  84. 0
    7
      api/libs/helper.py
  85. 37
    17
      api/pyrightconfig.json
  86. 2
    2
      api/services/account_service.py
  87. 35
    19
      api/services/annotation_service.py
  88. 1
    0
      api/services/clear_free_plan_tenant_expired_logs.py
  89. 10
    56
      api/services/dataset_service.py
  90. 1
    1
      api/services/external_knowledge_service.py
  91. 2
    2
      api/services/file_service.py
  92. 10
    7
      api/services/model_load_balancing_service.py
  93. 1
    0
      api/services/plugin/plugin_migration.py
  94. 5
    5
      api/services/tools/builtin_tools_manage_service.py
  95. 14
    2
      api/services/workflow/workflow_converter.py
  96. 2
    2
      api/services/workflow_service.py
  97. 1
    1
      api/services/workspace_service.py
  98. 2
    2
      api/tests/test_containers_integration_tests/services/test_account_service.py
  99. 2
    1
      api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py
  100. 0
    0
      api/tests/unit_tests/services/test_account_service.py

+ 16
- 2
api/commands.py 查看文件

@@ -511,7 +511,7 @@ def add_qdrant_index(field: str):
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig

for binding in bindings:
if dify_config.QDRANT_URL is None:
@@ -525,7 +525,21 @@ def add_qdrant_index(field: str):
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
params = qdrant_config.to_qdrant_params()
# Check the type before using
if isinstance(params, PathQdrantParams):
# PathQdrantParams case
client = qdrant_client.QdrantClient(path=params.path)
else:
# UrlQdrantParams case - params is UrlQdrantParams
client = qdrant_client.QdrantClient(
url=params.url,
api_key=params.api_key,
timeout=int(params.timeout),
verify=params.verify,
grpc_port=params.grpc_port,
prefer_grpc=params.prefer_grpc,
)
# create payload index
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1

+ 6
- 6
api/constants/__init__.py 查看文件

@@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])


_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
_doc_extensions.append("ppt")
else:
DOCUMENT_EXTENSIONS = [
_doc_extensions = [
"txt",
"markdown",
"md",
@@ -38,4 +38,4 @@ else:
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]

+ 0
- 1
api/contexts/__init__.py 查看文件

@@ -8,7 +8,6 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool


"""

+ 54
- 46
api/controllers/console/__init__.py 查看文件

@@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")

# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]

# Import app controllers
from .app import (
advanced_prompt_template,
agent,
annotation,
app,
audio,
completion,
conversation,
conversation_variables,
generator,
mcp_server,
message,
model_config,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_draft_variable,
workflow_run,
workflow_statistic,
advanced_prompt_template, # pyright: ignore[reportUnusedImport]
agent, # pyright: ignore[reportUnusedImport]
annotation, # pyright: ignore[reportUnusedImport]
app, # pyright: ignore[reportUnusedImport]
audio, # pyright: ignore[reportUnusedImport]
completion, # pyright: ignore[reportUnusedImport]
conversation, # pyright: ignore[reportUnusedImport]
conversation_variables, # pyright: ignore[reportUnusedImport]
generator, # pyright: ignore[reportUnusedImport]
mcp_server, # pyright: ignore[reportUnusedImport]
message, # pyright: ignore[reportUnusedImport]
model_config, # pyright: ignore[reportUnusedImport]
ops_trace, # pyright: ignore[reportUnusedImport]
site, # pyright: ignore[reportUnusedImport]
statistic, # pyright: ignore[reportUnusedImport]
workflow, # pyright: ignore[reportUnusedImport]
workflow_app_log, # pyright: ignore[reportUnusedImport]
workflow_draft_variable, # pyright: ignore[reportUnusedImport]
workflow_run, # pyright: ignore[reportUnusedImport]
workflow_statistic, # pyright: ignore[reportUnusedImport]
)

# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
from .auth import (
activate, # pyright: ignore[reportUnusedImport]
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
data_source_oauth, # pyright: ignore[reportUnusedImport]
forgot_password, # pyright: ignore[reportUnusedImport]
login, # pyright: ignore[reportUnusedImport]
oauth, # pyright: ignore[reportUnusedImport]
oauth_server, # pyright: ignore[reportUnusedImport]
)

# Import billing controllers
from .billing import billing, compliance
from .billing import billing, compliance # pyright: ignore[reportUnusedImport]

# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
external,
hit_testing,
metadata,
website,
data_source, # pyright: ignore[reportUnusedImport]
datasets, # pyright: ignore[reportUnusedImport]
datasets_document, # pyright: ignore[reportUnusedImport]
datasets_segments, # pyright: ignore[reportUnusedImport]
external, # pyright: ignore[reportUnusedImport]
hit_testing, # pyright: ignore[reportUnusedImport]
metadata, # pyright: ignore[reportUnusedImport]
website, # pyright: ignore[reportUnusedImport]
)

# Import explore controllers
from .explore import (
installed_app,
parameter,
recommended_app,
saved_message,
installed_app, # pyright: ignore[reportUnusedImport]
parameter, # pyright: ignore[reportUnusedImport]
recommended_app, # pyright: ignore[reportUnusedImport]
saved_message, # pyright: ignore[reportUnusedImport]
)

# Explore Audio
@@ -167,18 +175,18 @@ api.add_resource(
)

# Import tag controllers
from .tag import tags
from .tag import tags # pyright: ignore[reportUnusedImport]

# Import workspace controllers
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
account, # pyright: ignore[reportUnusedImport]
agent_providers, # pyright: ignore[reportUnusedImport]
endpoint, # pyright: ignore[reportUnusedImport]
load_balancing_config, # pyright: ignore[reportUnusedImport]
members, # pyright: ignore[reportUnusedImport]
model_providers, # pyright: ignore[reportUnusedImport]
models, # pyright: ignore[reportUnusedImport]
plugin, # pyright: ignore[reportUnusedImport]
tool_providers, # pyright: ignore[reportUnusedImport]
workspace, # pyright: ignore[reportUnusedImport]
)

+ 7
- 6
api/controllers/console/apikey.py 查看文件

@@ -1,8 +1,9 @@
from typing import Any, Optional
from typing import Optional

import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
).scalar_one_or_none()

if resource is None:
flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")

return resource

@@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]

resource_type: str | None = None
resource_model: Optional[Any] = None
resource_model: Optional[type] = None
resource_id_field: str | None = None
token_prefix: str | None = None
max_keys = 10
@@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource):

if current_key_count >= self.max_keys:
flask_restx.abort(
400,
HTTPStatus.BAD_REQUEST,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
custom="max_keys_exceeded",
)
@@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]

resource_type: str | None = None
resource_model: Optional[Any] = None
resource_model: Optional[type] = None
resource_id_field: str | None = None

def delete(self, resource_id, api_key_id):
@@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource):
)

if key is None:
flask_restx.abort(404, message="API key not found")
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")

db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

+ 23
- 7
api/controllers/console/app/app.py 查看文件

@@ -115,6 +115,10 @@ class AppListApi(Resource):
raise BadRequest("mode is required")

app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)

return app, 201
@@ -161,14 +165,26 @@ class AppApi(Resource):
args = parser.parse_args()

app_service = AppService()
app_model = app_service.update_app(app_model, args)
# Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType

args_dict: AppServiceType.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
}
app_model = app_service.update_app(app_model, args_dict)

return app_model

@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -224,10 +240,10 @@ class AppCopyApi(Resource):


class AppExportApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -263,7 +279,7 @@ class AppNameApi(Resource):
args = parser.parse_args()

app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get("name"))
app_model = app_service.update_app_name(app_model, args["name"])

return app_model

@@ -285,7 +301,7 @@ class AppIconApi(Resource):
args = parser.parse_args()

app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")

return app_model

@@ -306,7 +322,7 @@ class AppSiteStatus(Resource):
args = parser.parse_args()

app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
app_model = app_service.update_app_site_status(app_model, args["enable_site"])

return app_model

@@ -327,7 +343,7 @@ class AppApiStatus(Resource):
args = parser.parse_args()

app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
app_model = app_service.update_app_api_status(app_model, args["enable_api"])

return app_model


+ 2
- 2
api/controllers/console/app/audio.py 查看文件

@@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):


class ChatMessageTextApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
try:
parser = reqparse.RequestParser()
@@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):


class TextModesApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
try:
parser = reqparse.RequestParser()

+ 14
- 14
api/controllers/console/app/completion.py 查看文件

@@ -1,6 +1,5 @@
import logging

import flask_login
from flask import request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
@@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import login_required
from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False

account = flask_login.current_user

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)

return helper.compact_generate_response(response)
@@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id):
account = flask_login.current_user
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

return {"result": "success"}, 200

@@ -123,11 +123,11 @@ class ChatMessageApi(Resource):
if external_trace_id:
args["external_trace_id"] = external_trace_id

account = flask_login.current_user

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)

return helper.compact_generate_response(response)
@@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id):
account = flask_login.current_user
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

return {"result": "success"}, 200


+ 5
- 1
api/controllers/console/app/conversation.py 查看文件

@@ -22,7 +22,7 @@ from fields.conversation_fields import (
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models import Account, Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id)

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id)

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

+ 9
- 4
api/controllers/console/app/message.py 查看文件

@@ -1,6 +1,5 @@
import logging

from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
@@ -27,7 +26,8 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
@@ -118,11 +118,14 @@ class ChatMessageListApi(Resource):


class MessageFeedbackApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model):
if current_user is None:
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
@@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource):
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
raise Forbidden()

@@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource):


class MessageAnnotationCountApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()


+ 5
- 1
api/controllers/console/app/site.py 查看文件

@@ -10,7 +10,7 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Site
from models import Account, Site


def parse_app_site_args():
@@ -75,6 +75,8 @@ class AppSite(Resource):
if value is not None:
setattr(site, attr_name, value)

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()
@@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound

site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()

+ 6
- 6
api/controllers/console/app/statistic.py 查看文件

@@ -18,10 +18,10 @@ from models import AppMode, Message


class DailyMessageStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -75,10 +75,10 @@ WHERE


class DailyConversationStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):


class DailyTerminalsStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -184,10 +184,10 @@ WHERE


class DailyTokenCostStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -320,10 +320,10 @@ ORDER BY


class UserSatisfactionRateStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -443,10 +443,10 @@ WHERE


class TokensPerSecondStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user


+ 3
- 3
api/controllers/console/app/workflow_statistic.py 查看文件

@@ -18,10 +18,10 @@ from models.model import AppMode


class WorkflowDailyRunsStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -80,10 +80,10 @@ WHERE


class WorkflowDailyTerminalsStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user

@@ -142,10 +142,10 @@ WHERE


class WorkflowDailyTokenCostStatistic(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user


+ 4
- 1
api/controllers/console/auth/oauth.py 查看文件

@@ -77,6 +77,9 @@ class OAuthCallback(Resource):
if state:
invite_token = state

if not code:
return {"error": "Authorization code is required"}, 400

try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
@@ -86,7 +89,7 @@ class OAuthCallback(Resource):
return {"error": "OAuth process failed"}, 400

if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token)
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:

+ 10
- 1
api/controllers/console/explore/completion.py 查看文件

@@ -1,6 +1,5 @@
import logging

from flask_login import current_user
from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound

@@ -28,6 +27,8 @@ from extensions.ext_database import db
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource):
db.session.commit()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
@@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {"result": "success"}, 200
@@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource):
db.session.commit()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
@@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {"result": "success"}, 200

+ 12
- 1
api/controllers/console/explore/conversation.py 查看文件

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
@@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
@@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource):
pinned = args["pinned"] == "true"

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
@@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource):

conversation_id = str(c_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource):
args = parser.parse_args()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
)
@@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource):
conversation_id = str(c_id)

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource):
raise NotChatAppError()

conversation_id = str(c_id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)

return {"result": "success"}

+ 10
- 3
api/controllers/console/explore/installed_app.py 查看文件

@@ -2,7 +2,6 @@ import logging
from typing import Any

from flask import request
from flask_login import current_user
from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from libs.login import current_user, login_required
from models import Account, App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id

if app_id:
@@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource):
else:
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()

if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [
{
@@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("App not found")

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first()

@@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource):
"""

def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")


+ 10
- 1
api/controllers/console/explore/message.py 查看文件

@@ -1,6 +1,5 @@
import logging

from flask_login import current_user
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
@@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource):
args = parser.parse_args()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
@@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
@@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
@@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id)

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)

+ 4
- 4
api/controllers/console/explore/recommended_app.py 查看文件

@@ -1,11 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse

from constants.languages import languages
from controllers.console import api
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import login_required
from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService

app_fields = {
@@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource):
parser.add_argument("language", type=str, location="args")
args = parser.parse_args()

if args.get("language") and args.get("language") in languages:
language_prefix = args.get("language")
language = args.get("language")
if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language
else:

+ 8
- 1
api/controllers/console/explore/saved_message.py 查看文件

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
@@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService

@@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource):
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])

def post(self, installed_app):
@@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource):
args = parser.parse_args()

try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()

if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id)

return {"result": "success"}, 204

+ 3
- 0
api/controllers/console/files.py 查看文件

@@ -22,6 +22,7 @@ from controllers.console.wraps import (
)
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import Account
from services.file_service import FileService

PREVIEW_WORDS_LIMIT = 3000
@@ -68,6 +69,8 @@ class FileApi(Resource):
source = None

try:
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),

+ 3
- 3
api/controllers/console/version.py 查看文件

@@ -34,14 +34,14 @@ class VersionApi(Resource):
return result

try:
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args.get("current_version")
result["version"] = args["current_version"]
return result

content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]

+ 32
- 0
api/controllers/console/workspace/account.py 查看文件

@@ -49,6 +49,8 @@ class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

if account.status == "active":
@@ -102,6 +104,8 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
return current_user


@@ -111,6 +115,8 @@ class AccountNameApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
@@ -130,6 +136,8 @@ class AccountAvatarApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
@@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args()
@@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args()
@@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
@@ -194,6 +208,8 @@ class AccountPasswordApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json")
@@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
@@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

token, code = AccountService.generate_account_deletion_verification_code(account)
@@ -281,6 +301,8 @@ class AccountDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

parser = reqparse.RequestParser()
@@ -321,6 +343,8 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

return BillingService.EducationIdentity.verify(account.id, account.email)
@@ -340,6 +364,8 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

parser = reqparse.RequestParser()
@@ -357,6 +383,8 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user

res = BillingService.EducationIdentity.status(account.id)
@@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")

if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email:
raise InvalidEmailError()
else:
@@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"])

old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if current_user.email != old_email:
raise AccountNotFound()


+ 49
- 10
api/controllers/console/workspace/members.py 查看文件

@@ -1,8 +1,8 @@
from urllib import parse

from flask import request
from flask import abort, request
from flask_login import current_user
from flask_restx import Resource, abort, marshal_with, reqparse
from flask_restx import Resource, marshal_with, reqparse

import services
from configs import dify_config
@@ -41,6 +41,10 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200

@@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource):
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400

if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL

@@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource):

for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
@@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource):
return {
"result": "success",
"invitation_results": invitation_results,
"tenant_id": str(current_user.current_tenant.id),
"tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "",
}, 201


@@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None:
abort(404)
@@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource):
except Exception as e:
raise ValueError(str(e))

return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
return {
"result": "success",
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
}, 200


class MemberUpdateRoleApi(Resource):
@@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400

if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
if not member:
abort(404)
@@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200

@@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource):
raise EmailSendIpLimitError()

# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()

@@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource):
account=current_user,
email=email,
language=language,
workspace_name=current_user.current_tenant.name,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
)

return {"result": "success", "data": token}
@@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()

@@ -256,6 +289,10 @@ class OwnerTransfer(Resource):
args = parser.parse_args()

# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()

@@ -274,9 +311,11 @@ class OwnerTransfer(Resource):
member = db.session.get(Account, str(member_id))
if not member:
abort(404)
else:
member_account = member
if not TenantService.is_member(member_account, current_user.current_tenant):
return # Never reached, but helps type checker

if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_member(member, current_user.current_tenant):
raise MemberNotInTenantError()

try:
@@ -286,13 +325,13 @@ class OwnerTransfer(Resource):
AccountService.send_new_owner_transfer_notify_email(
account=member,
email=member.email,
workspace_name=current_user.current_tenant.name,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
)

AccountService.send_old_owner_transfer_notify_email(
account=current_user,
email=current_user.email,
workspace_name=current_user.current_tenant.name,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
new_owner_email=member.email,
)


+ 37
- 0
api/controllers/console/workspace/model_providers.py 查看文件

@@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService

@@ -21,6 +22,10 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
@@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
@@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner:
raise Forbidden()

@@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource):

model_provider_service = ModelProviderService()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
@@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner:
raise Forbidden()

@@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource):

model_provider_service = ModelProviderService()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
@@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
@@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
@@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id

model_provider_service = ModelProviderService()
@@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner:
raise Forbidden()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
@@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,

+ 22
- 2
api/controllers/console/workspace/workspace.py 查看文件

@@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from models.account import Tenant, TenantStatus
from models.account import Account, Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -70,6 +70,8 @@ class TenantListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []

@@ -83,7 +85,7 @@ class TenantListApi(Resource):
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id,
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
}

tenant_dicts.append(tenant_dict)
@@ -125,7 +127,11 @@ class TenantApi(Resource):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")

if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")

if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user)
@@ -137,6 +143,8 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")

if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200


@@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
@@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)

custom_config_dict = {
@@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required
# Change workspace name
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()

if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"]
db.session.commit()

+ 1
- 1
api/controllers/files/__init__.py 查看文件

@@ -15,6 +15,6 @@ api = ExternalApi(

files_ns = Namespace("files", description="File operations", path="/")

from . import image_preview, tool_files, upload
from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport]

api.add_namespace(files_ns)

+ 3
- 3
api/controllers/inner_api/__init__.py 查看文件

@@ -16,8 +16,8 @@ api = ExternalApi(
# Create namespace
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")

from . import mail
from .plugin import plugin
from .workspace import workspace
from . import mail as _mail # pyright: ignore[reportUnusedImport]
from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport]
from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport]

api.add_namespace(inner_api_ns)

+ 15
- 15
api/controllers/inner_api/plugin/plugin.py 查看文件

@@ -37,9 +37,9 @@ from models.model import EndUser

@inner_api_ns.route("/invoke/llm")
class PluginInvokeLLMApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM)
@inner_api_ns.doc("plugin_invoke_llm")
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
@@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):

@inner_api_ns.route("/invoke/llm/structured-output")
class PluginInvokeLLMWithStructuredOutputApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
@inner_api_ns.doc("plugin_invoke_llm_structured")
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
@@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):

@inner_api_ns.route("/invoke/text-embedding")
class PluginInvokeTextEmbeddingApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding)
@inner_api_ns.doc("plugin_invoke_text_embedding")
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
@@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):

@inner_api_ns.route("/invoke/rerank")
class PluginInvokeRerankApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank)
@inner_api_ns.doc("plugin_invoke_rerank")
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
@@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):

@inner_api_ns.route("/invoke/tts")
class PluginInvokeTTSApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS)
@inner_api_ns.doc("plugin_invoke_tts")
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
@@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):

@inner_api_ns.route("/invoke/speech2text")
class PluginInvokeSpeech2TextApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text)
@inner_api_ns.doc("plugin_invoke_speech2text")
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
@@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):

@inner_api_ns.route("/invoke/moderation")
class PluginInvokeModerationApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration)
@inner_api_ns.doc("plugin_invoke_moderation")
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
@@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):

@inner_api_ns.route("/invoke/tool")
class PluginInvokeToolApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool)
@inner_api_ns.doc("plugin_invoke_tool")
@inner_api_ns.doc(description="Invoke tools through plugin interface")
@@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):

@inner_api_ns.route("/invoke/parameter-extractor")
class PluginInvokeParameterExtractorNodeApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
@@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):

@inner_api_ns.route("/invoke/question-classifier")
class PluginInvokeQuestionClassifierNodeApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
@inner_api_ns.doc("plugin_invoke_question_classifier")
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
@@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):

@inner_api_ns.route("/invoke/app")
class PluginInvokeAppApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp)
@inner_api_ns.doc("plugin_invoke_app")
@inner_api_ns.doc(description="Invoke application through plugin interface")
@@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):

@inner_api_ns.route("/invoke/encrypt")
class PluginInvokeEncryptApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
@inner_api_ns.doc("plugin_invoke_encrypt")
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
@@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):

@inner_api_ns.route("/invoke/summary")
class PluginInvokeSummaryApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary)
@inner_api_ns.doc("plugin_invoke_summary")
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
@@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):

@inner_api_ns.route("/upload/file/request")
class PluginUploadFileRequestApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile)
@inner_api_ns.doc("plugin_upload_file_request")
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
@@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):

@inner_api_ns.route("/fetch/app/info")
class PluginFetchAppInfoApi(Resource):
@get_user_tenant
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
@inner_api_ns.doc("plugin_fetch_app_info")
@inner_api_ns.doc(description="Fetch application information through plugin interface")

+ 5
- 5
api/controllers/inner_api/plugin/wraps.py 查看文件

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional, ParamSpec, TypeVar
from typing import Optional, ParamSpec, TypeVar, cast

from flask import current_app, request
from flask_login import user_logged_in
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session

from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db
from libs.login import _get_user
from libs.login import current_user
from models.account import Tenant
from models.model import EndUser

@@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):

p = parser.parse_args()

user_id: Optional[str] = p.get("user_id")
tenant_id: str = p.get("tenant_id")
user_id = cast(str, p.get("user_id"))
tenant_id = cast(str, p.get("tenant_id"))

if not tenant_id:
raise ValueError("tenant_id is required")
@@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
kwargs["user_model"] = user

current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore

return view_func(*args, **kwargs)


+ 1
- 1
api/controllers/mcp/__init__.py 查看文件

@@ -15,6 +15,6 @@ api = ExternalApi(

mcp_ns = Namespace("mcp", description="MCP operations", path="/")

from . import mcp
from . import mcp # pyright: ignore[reportUnusedImport]

api.add_namespace(mcp_ns)

+ 22
- 4
api/controllers/service_api/__init__.py 查看文件

@@ -15,9 +15,27 @@ api = ExternalApi(

service_api_ns = Namespace("service_api", description="Service operations", path="/")

from . import index
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
from . import index # pyright: ignore[reportUnusedImport]
from .app import (
annotation, # pyright: ignore[reportUnusedImport]
app, # pyright: ignore[reportUnusedImport]
audio, # pyright: ignore[reportUnusedImport]
completion, # pyright: ignore[reportUnusedImport]
conversation, # pyright: ignore[reportUnusedImport]
file, # pyright: ignore[reportUnusedImport]
file_preview, # pyright: ignore[reportUnusedImport]
message, # pyright: ignore[reportUnusedImport]
site, # pyright: ignore[reportUnusedImport]
workflow, # pyright: ignore[reportUnusedImport]
)
from .dataset import (
dataset, # pyright: ignore[reportUnusedImport]
document, # pyright: ignore[reportUnusedImport]
hit_testing, # pyright: ignore[reportUnusedImport]
metadata, # pyright: ignore[reportUnusedImport]
segment, # pyright: ignore[reportUnusedImport]
upload_file, # pyright: ignore[reportUnusedImport]
)
from .workspace import models # pyright: ignore[reportUnusedImport]

api.add_namespace(service_api_ns)

+ 2
- 1
api/controllers/service_api/app/conversation.py 查看文件

@@ -1,4 +1,5 @@
from flask_restx import Resource, reqparse
from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@@ -121,7 +122,7 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)

+ 6
- 0
api/controllers/service_api/dataset/document.py 查看文件

@@ -30,6 +30,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.model import EndUser
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
@@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError

if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")

upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
@@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FilenameNotExistsError

try:
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),

+ 2
- 2
api/controllers/service_api/wraps.py 查看文件

@@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import _get_user
from libs.login import current_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
@@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:

+ 14
- 14
api/controllers/web/__init__.py 查看文件

@@ -17,20 +17,20 @@ api = ExternalApi(
web_ns = Namespace("web", description="Web application API operations", path="/")

from . import (
app,
audio,
completion,
conversation,
feature,
files,
forgot_password,
login,
message,
passport,
remote_files,
saved_message,
site,
workflow,
app, # pyright: ignore[reportUnusedImport]
audio, # pyright: ignore[reportUnusedImport]
completion, # pyright: ignore[reportUnusedImport]
conversation, # pyright: ignore[reportUnusedImport]
feature, # pyright: ignore[reportUnusedImport]
files, # pyright: ignore[reportUnusedImport]
forgot_password, # pyright: ignore[reportUnusedImport]
login, # pyright: ignore[reportUnusedImport]
message, # pyright: ignore[reportUnusedImport]
passport, # pyright: ignore[reportUnusedImport]
remote_files, # pyright: ignore[reportUnusedImport]
saved_message, # pyright: ignore[reportUnusedImport]
site, # pyright: ignore[reportUnusedImport]
workflow, # pyright: ignore[reportUnusedImport]
)

api.add_namespace(web_ns)

+ 0
- 1
api/core/__init__.py 查看文件

@@ -1 +0,0 @@
import core.moderation.base

+ 2
- 0
api/core/agent/cot_agent_runner.py 查看文件

@@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id

def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:

+ 1
- 0
api/core/agent/fc_agent_runner.py 查看文件

@@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages

# get tracing instance
trace_manager = app_generate_entity.trace_manager

+ 9
- 2
api/core/app/app_config/common/sensitive_word_avoidance/manager.py 查看文件

@@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager:

@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
@@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager:

if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
if not isinstance(typ, str):
raise ValueError("sensitive_word_avoidance.type must be a string")

sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
if sensitive_word_avoidance_config is None:
sensitive_word_avoidance_config = {}
if not isinstance(sensitive_word_avoidance_config, dict):
raise ValueError("sensitive_word_avoidance.config must be a dict")

ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)


+ 7
- 3
api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py 查看文件

@@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
text = message.get("text")
if not isinstance(text, str):
raise ValueError("message text must be a string")
role = message.get("role")
if not isinstance(role, str):
raise ValueError("message role must be a string")
chat_prompt_messages.append(
AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
)

advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

+ 6
- 6
api/core/app/apps/advanced_chat/generate_response_converter.py 查看文件

@@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue

response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

@classmethod
@@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue

response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))

yield response_chunk

+ 12
- 12
api/core/app/apps/advanced_chat/generate_task_pipeline.py 查看文件

@@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline:

def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()

def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
with self._database_session() as session:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)

def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
@@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)

yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
yield self._base_task_pipeline.error_to_stream_response(err)

def _handle_stop_event(
self,
@@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)

output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
@@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline:

message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
@@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline:
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
@@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return True
else:
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
self._base_task_pipeline.output_moderation_handler.append_new_token(text)

return False


+ 19
- 15
api/core/app/apps/agent_chat/app_config_manager.py 查看文件

@@ -1,6 +1,6 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, Optional, cast

from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return filtered_config

@classmethod
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate agent_mode and set defaults for agent feature

@@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}

if not isinstance(config["agent_mode"], dict):
agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")

if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
agent_mode = cast(dict[str, Any], agent_mode)

if not isinstance(config["agent_mode"]["enabled"], bool):
if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False

if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")

if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value

if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

if not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not agent_mode.get("tools"):
agent_mode["tools"] = []

if not isinstance(config["agent_mode"]["tools"], list):
if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")

for tool in config["agent_mode"]["tools"]:
for tool in agent_mode["tools"]:
key = list(tool.keys())[0]
if key in OLD_TOOLS:
# old style, use tool name as key

+ 7
- 4
api/core/app/apps/agent_chat/generate_response_converter.py 查看文件

@@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}

return response

@@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

@classmethod
@@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))

yield response_chunk

+ 1
- 0
api/core/app/apps/base_app_queue_manager.py 查看文件

@@ -32,6 +32,7 @@ class AppQueueManager:
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from

user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(

+ 7
- 4
api/core/app/apps/chat/generate_response_converter.py 查看文件

@@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}

return response

@@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

@classmethod
@@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))

yield response_chunk

+ 2
- 0
api/core/app/apps/completion/app_generator.py 查看文件

@@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MoreLikeThisDisabledError()

app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")

+ 9
- 4
api/core/app/apps/completion/generate_response_converter.py 查看文件

@@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}

return response

@@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

@classmethod
@@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))

yield response_chunk

+ 5
- 5
api/core/app/apps/workflow/generate_response_converter.py 查看文件

@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
return blocking_response.model_dump()

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
@@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue

response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

@classmethod
@@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue

response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

+ 5
- 5
api/core/app/apps/workflow/generate_task_pipeline.py 查看文件

@@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory

def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline:
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline:

def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()

def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline.error_to_stream_response(err)

def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs

+ 3
- 3
api/core/app/entities/app_invoke_entities.py 查看文件

@@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""

# app config
app_config: EasyUIBasedAppConfig
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity

query: Optional[str] = None
@@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""

# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore

workflow_run_id: Optional[str] = None
query: str
@@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""

# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str

class SingleIterationRunEntity(BaseModel):

+ 0
- 7
api/core/app/entities/task_entities.py 查看文件

@@ -5,7 +5,6 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -92,9 +91,6 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str

def to_dict(self):
return jsonable_encoder(self)


class ErrorStreamResponse(StreamResponse):
"""
@@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel):

task_id: str

def to_dict(self):
return jsonable_encoder(self)


class ChatbotAppBlockingResponse(AppBlockingResponse):
"""

+ 3
- 0
api/core/app/features/annotation_reply/annotation_reply.py 查看文件

@@ -35,6 +35,9 @@ class AnnotationReplyFeature:

collection_binding_detail = annotation_setting.collection_binding_detail

if not collection_binding_detail:
return None

try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name

+ 2
- 0
api/core/app/features/rate_limiting/__init__.py 查看文件

@@ -1 +1,3 @@
from .rate_limit import RateLimit

__all__ = ["RateLimit"]

+ 1
- 1
api/core/app/features/rate_limiting/rate_limit.py 查看文件

@@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}

def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance

+ 11
- 11
api/core/app/task_pipeline/based_generate_task_pipeline.py 查看文件

@@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream

def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline:

return message

def _error_to_stream_response(self, e: Exception):
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
@@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

def _ping_stream_response(self) -> PingStreamResponse:
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
@@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline:
)
return None

def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""
Handle output moderation when task finished.
:param completion: completion
:return:
"""
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()

completion, flagged = self._output_moderation_handler.moderation_completion(
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)

self._output_moderation_handler = None
self.output_moderation_handler = None
if flagged:
return completion


+ 11
- 11
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py 查看文件

@@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):

if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
@@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event)

# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
@@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self.ping_stream_response()
else:
continue
if publisher:
@@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
@@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)

@@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
@@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self.output_moderation_handler.append_new_token(text)

return False

+ 3
- 3
api/core/base/tts/app_generator_tts_publisher.py 查看文件

@@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher:
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self.max_sentence = 2
self._last_audio_event: Optional[AudioTrunk] = None
# FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start()
@@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher:
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice

+ 7
- 1
api/core/entities/provider_configuration.py 查看文件

@@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel):
def __setitem__(self, key, value):
self.configurations[key] = value

def __contains__(self, key):
if "/" not in key:
key = str(ModelProviderID(key))
return key in self.configurations

def __iter__(self):
return iter(self.configurations)
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()

def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())

+ 3
- 3
api/core/file/file_manager.py 查看文件

@@ -98,7 +98,7 @@ def to_prompt_message_content(

def download(f: File, /):
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
@@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /):
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
data = _download_file_content(f._storage_key)
data = _download_file_content(f.storage_key)
case FileTransferMethod.TOOL_FILE:
data = _download_file_content(f._storage_key)
data = _download_file_content(f.storage_key)

encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string

+ 8
- 0
api/core/file/models.py 查看文件

@@ -146,3 +146,11 @@ class File(BaseModel):
if not self.related_id:
raise ValueError("Missing file related_id")
return self

@property
def storage_key(self) -> str:
return self._storage_key

@storage_key.setter
def storage_key(self, value: str):
self._storage_key = value

+ 7
- 7
api/core/helper/ssrf_proxy.py 查看文件

@@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(config_value).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
http_request_node_ssl_verify = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
http_request_node_ssl_verify = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
http_request_node_ssl_verify = True

BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
@@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)

if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
kwargs["ssl_verify"] = http_request_node_ssl_verify

ssl_verify = kwargs.pop("ssl_verify")


+ 6
- 1
api/core/indexing_runner.py 查看文件

@@ -529,6 +529,7 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
create_keyword_thread = None
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index
create_keyword_thread = threading.Thread(
@@ -567,7 +568,11 @@ class IndexingRunner:

for future in futures:
tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
if (
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
create_keyword_thread.join()
indexing_end_at = time.perf_counter()


+ 9
- 3
api/core/llm_generator/llm_generator.py 查看文件

@@ -20,7 +20,7 @@ from core.llm_generator.prompts import (
)
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.entities.trace_entity import TraceTaskName
@@ -313,14 +313,20 @@ class LLMGenerator:
model_type=ModelType.LLM,
)

prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

response: LLMResult = model_instance.invoke_llm(
# Explicitly use the non-streaming overload
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
)

# Runtime type check since pyright has issues with the overload
if not isinstance(result, LLMResult):
raise TypeError("Expected LLMResult when stream=False")
response = result

answer = cast(str, response.message.content)
return answer.strip()


+ 6
- 8
api/core/llm_generator/output_parser/structured_output.py 查看文件

@@ -45,6 +45,7 @@ class SpecialModelType(StrEnum):

@overload
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
@@ -53,14 +54,13 @@ def invoke_llm_with_structured_output(
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[True] = True,
stream: Literal[True],
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...


@overload
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
@@ -69,14 +69,13 @@ def invoke_llm_with_structured_output(
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[False] = False,
stream: Literal[False],
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput: ...


@overload
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
@@ -89,9 +88,8 @@ def invoke_llm_with_structured_output(
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...


def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,

+ 4
- 4
api/core/mcp/client/sse_client.py 查看文件

@@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
@final
class _StatusReady:
def __init__(self, endpoint_url: str):
self._endpoint_url = endpoint_url
self.endpoint_url = endpoint_url


@final
class _StatusError:
def __init__(self, exc: Exception):
self._exc = exc
self.exc = exc


# Type aliases for better readability
@@ -211,9 +211,9 @@ class SSETransport:
raise ValueError("failed to get endpoint URL")

if isinstance(status, _StatusReady):
return status._endpoint_url
return status.endpoint_url
elif isinstance(status, _StatusError):
raise status._exc
raise status.exc
else:
raise ValueError("failed to get endpoint URL")


+ 14
- 14
api/core/mcp/server/streamable_http.py 查看文件

@@ -38,6 +38,7 @@ def handle_mcp_request(
"""

request_type = type(request.root)
request_root = request.root

def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
@@ -58,21 +59,20 @@ def handle_mcp_request(
error=error_data,
)

# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}

try:
# Dispatch request to appropriate handler
handler = request_handlers.get(request_type)
if handler:
return create_success_response(handler())
# Dispatch request to appropriate handler based on instance type
if isinstance(request_root, mcp_types.InitializeRequest):
return create_success_response(handle_initialize(mcp_server.description))
elif isinstance(request_root, mcp_types.ListToolsRequest):
return create_success_response(
handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
)
)
elif isinstance(request_root, mcp_types.CallToolRequest):
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
elif isinstance(request_root, mcp_types.PingRequest):
return create_success_response(handle_ping())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")


+ 6
- 6
api/core/mcp/session/base_session.py 查看文件

@@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.request_meta = request_meta
self.request = request
self._session = session
self._completed = False
self.completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager

@@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
):
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
if self.completed:
self._on_complete(self)
finally:
self._entered = False
@@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
assert not self.completed, "Request already responded to"

self._completed = True
self.completed = True

self._session._send_response(request_id=self.request_id, response=response)

@@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")

self._completed = True # Mark as completed so it's removed from in_flight
self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
request_id=self.request_id,
@@ -351,7 +351,7 @@ class BaseSession(
self._in_flight[responder.request_id] = responder
self._received_request(responder)

if not responder._completed:
if not responder.completed:
self._handle_incoming(responder)

elif isinstance(message.message.root, JSONRPCNotification):

+ 1
- 1
api/core/model_runtime/model_providers/__base/large_language_model.py 查看文件

@@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel):
)
return 0

def _calc_response_usage(
def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
"""

+ 1
- 4
api/core/plugin/entities/parameters.py 查看文件

@@ -1,4 +1,5 @@
import enum
import json
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator
@@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for arrays
if isinstance(value, str):
try:
import json

parsed_value = json.loads(value)
if isinstance(parsed_value, list):
return parsed_value
@@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for objects
if isinstance(value, str):
try:
import json

parsed_value = json.loads(value)
if isinstance(parsed_value, dict):
return parsed_value

+ 3
- 1
api/core/plugin/utils/chunk_merger.py 查看文件

@@ -82,7 +82,9 @@ def merge_blob_chunks(
message_class = type(resp)
merged_message = message_class(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
message=ToolInvokeMessage.BlobMessage(
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
),
meta=resp.meta,
)
yield cast(MessageType, merged_message)

+ 26
- 6
api/core/prompt/simple_prompt_transform.py 查看文件

@@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform):
with_memory_prompt=histories is not None,
)

variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
special_variable_keys_obj = prompt_template_config["special_variable_keys"]

for v in prompt_template_config["special_variable_keys"]:
# Type check for custom_variable_keys
if not isinstance(custom_variable_keys_obj, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
custom_variable_keys = cast(list[str], custom_variable_keys_obj)

# Type check for special_variable_keys
if not isinstance(special_variable_keys_obj, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
special_variable_keys = cast(list[str], special_variable_keys_obj)

variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}

for v in special_variable_keys:
# support #context#, #query# and #histories#
if v == "#context#":
variables["#context#"] = context or ""
@@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform):
variables["#histories#"] = histories or ""

prompt_template = prompt_template_config["prompt_template"]
if not isinstance(prompt_template, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}")

prompt = prompt_template.format(variables)

return prompt, prompt_template_config["prompt_rules"]
prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")

return prompt, prompt_rules

def get_prompt_template(
self,
@@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform):
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
):
) -> dict[str, object]:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

custom_variable_keys = []
special_variable_keys = []
custom_variable_keys: list[str] = []
special_variable_keys: list[str] = []

prompt = ""
for order in prompt_rules["system_prompt_orders"]:

+ 24
- 11
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 查看文件

@@ -40,6 +40,19 @@ if TYPE_CHECKING:
MetadataFilter = Union[DictFilter, common_types.Filter]


class PathQdrantParams(BaseModel):
path: str


class UrlQdrantParams(BaseModel):
url: str
api_key: Optional[str]
timeout: float
verify: bool
grpc_port: int
prefer_grpc: bool


class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str] = None
@@ -50,7 +63,7 @@ class QdrantConfig(BaseModel):
replication_factor: int = 1
write_consistency_factor: int = 1

def to_qdrant_params(self):
def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
@@ -58,23 +71,23 @@ class QdrantConfig(BaseModel):
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)

return {"path": path}
return PathQdrantParams(path=path)
else:
return {
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": self.endpoint.startswith("https"),
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
return UrlQdrantParams(
url=self.endpoint,
api_key=self.api_key,
timeout=self.timeout,
verify=self.endpoint.startswith("https"),
grpc_port=self.grpc_port,
prefer_grpc=self.prefer_grpc,
)


class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
self._distance_func = distance_func.upper()
self._group_id = group_id


+ 2
- 2
api/core/repositories/celery_workflow_node_execution_repository.py 查看文件

@@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

# In-memory cache for workflow node executions
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
self._execution_cache = {}

# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
self._workflow_execution_mapping: dict[str, list[str]] = {}
self._workflow_execution_mapping = {}

logger.info(
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",

+ 1
- 1
api/core/variables/segment_group.py 查看文件

@@ -4,7 +4,7 @@ from .types import SegmentType

class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
value: list[Segment]
value: list[Segment] = None # type: ignore

@property
def text(self):

+ 12
- 12
api/core/variables/segments.py 查看文件

@@ -74,12 +74,12 @@ class NoneSegment(Segment):

class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str
value: str = None # type: ignore


class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT
value: float
value: float = None # type: ignore
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
@@ -98,12 +98,12 @@ class FloatSegment(Segment):

class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER
value: int
value: int = None # type: ignore


class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any]
value: Mapping[str, Any] = None # type: ignore

@property
def text(self) -> str:
@@ -136,7 +136,7 @@ class ArraySegment(Segment):

class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
value: File
value: File = None # type: ignore

@property
def markdown(self) -> str:
@@ -153,17 +153,17 @@ class FileSegment(Segment):

class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool
value: bool = None # type: ignore


class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
value: Sequence[Any] = None # type: ignore


class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str]
value: Sequence[str] = None # type: ignore

@property
def text(self) -> str:
@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):

class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int]
value: Sequence[float | int] = None # type: ignore


class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
value: Sequence[Mapping[str, Any]] = None # type: ignore


class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File]
value: Sequence[File] = None # type: ignore

@property
def markdown(self) -> str:
@@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):

class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool]
value: Sequence[bool] = None # type: ignore


def get_segment_discriminator(v: Any) -> SegmentType | None:

+ 2
- 2
api/core/workflow/errors.py 查看文件

@@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode

class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: BaseNode, err_msg: str):
self._node = node
self._error = err_msg
self.node = node
self.error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")

+ 2
- 2
api/core/workflow/nodes/list_operator/node.py 查看文件

@@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode):
return "1"

def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
inputs: dict[str, Sequence[object]] = {}
process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {}

variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)

+ 2
- 1
api/core/workflow/nodes/llm/node.py 查看文件

@@ -1183,7 +1183,8 @@ def _combine_message_content_with_role(
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported")
case _:
raise NotImplementedError(f"Role {role} is not supported")


def _render_jinja2_message(

+ 2
- 2
api/factories/file_factory.py 查看文件

@@ -462,9 +462,9 @@ class StorageKeyLoader:
upload_file_row = upload_files.get(model_id)
if upload_file_row is None:
raise ValueError(f"Upload file not found for id: {model_id}")
file._storage_key = upload_file_row.key
file.storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(model_id)
if tool_file_row is None:
raise ValueError(f"Tool file not found for id: {model_id}")
file._storage_key = tool_file_row.file_key
file.storage_key = tool_file_row.file_key

+ 4
- 1
api/fields/_value_type_serializer.py 查看文件

@@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
else:
return v["value_type"].exposed_type().value
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return value_type.exposed_type().value

+ 11
- 3
api/libs/external_api.py 查看文件

@@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api):
headers["WWW-Authenticate"] = 'Bearer realm="api"'
return data, status_code, headers

_ = handle_http_exception

@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
@@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api):
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code

_ = handle_value_error

@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
@@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api):
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code

_ = handle_quota_exceeded

@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)

status_code = 500
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
data = getattr(e, "data", {"message": http_status_message(status_code)})

# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, Mapping):
if not isinstance(data, dict):
data = {"message": str(e)}

data.setdefault("code", "unknown")
@@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api):
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type]
current_app.log_exception(exc_info)

return data, status_code

_ = handle_general_exception


class ExternalApi(Api):
_authorizations = {

+ 0
- 7
api/libs/helper.py 查看文件

@@ -167,13 +167,6 @@ class DatetimeString:
return value


def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError(f"{value} is not a valid float")


def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string

+ 37
- 17
api/pyrightconfig.json 查看文件

@@ -1,24 +1,44 @@
{
"include": ["."],
"exclude": [".venv", "tests/", "migrations/"],
"ignore": [
"core/",
"controllers/",
"tasks/",
"services/",
"schedule/",
"extensions/",
"utils/",
"repositories/",
"libs/",
"fields/",
"factories/",
"events/",
"contexts/",
"constants/",
"commands.py"
"exclude": [
".venv",
"tests/",
"migrations/",
"core/rag",
"extensions",
"libs",
"controllers/console/datasets",
"controllers/service_api/dataset",
"core/ops",
"core/tools",
"core/model_runtime",
"core/workflow",
"core/app/app_config/easy_ui_based_app/dataset"
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
"flask_restx",
"flask_login",
"opentelemetry.instrumentation.celery",
"opentelemetry.instrumentation.flask",
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
"reportUnknownArgumentType": "hint",
"reportUnknownVariableType": "hint",
"reportUnknownLambdaType": "hint",
"reportMissingParameterType": "hint",
"reportMissingTypeArgument": "hint",
"reportUnnecessaryContains": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryCast": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",

"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"
}

+ 2
- 2
api/services/account_service.py 查看文件

@@ -1318,7 +1318,7 @@ class RegisterService:
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None

@@ -1355,7 +1355,7 @@ class RegisterService:
}

@classmethod
def _get_invitation_by_token(
def get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None:

+ 35
- 19
api/services/annotation_service.py 查看文件

@@ -349,7 +349,7 @@ class AppAnnotationService:

try:
# Skip the first row
df = pd.read_csv(file, dtype=str)
df = pd.read_csv(file.stream, dtype=str)
result = []
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
@@ -463,15 +463,23 @@ class AppAnnotationService:
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False}

@classmethod
@@ -506,15 +514,23 @@ class AppAnnotationService:

collection_binding_detail = annotation_setting.collection_binding_detail

return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}

@classmethod
def clear_all_annotations(cls, app_id: str):

+ 1
- 0
api/services/clear_free_plan_tenant_expired_logs.py 查看文件

@@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs:
datetime.timedelta(hours=1),
]

tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

+ 10
- 56
api/services/dataset_service.py 查看文件

@@ -134,11 +134,14 @@ class DatasetService:

# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids(
"knowledge",
tenant_id, # ty: ignore [invalid-argument-type]
tag_ids,
)
if tenant_id is not None:
target_ids = TagService.get_target_ids_by_tag_ids(
"knowledge",
tenant_id,
tag_ids,
)
else:
target_ids = []
if target_ids and len(target_ids) > 0:
query = query.where(Dataset.id.in_(target_ids))
else:
@@ -987,7 +990,8 @@ class DocumentService:
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)

for document in documents:
db.session.delete(document)
@@ -2688,56 +2692,6 @@ class SegmentService:

return paginated_segments.items, paginated_segments.total

@classmethod
def update_segment_by_id(
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

# check user's model setting
DatasetService.check_dataset_model_setting(dataset)

# check document
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")

# check embedding model setting if high quality
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=user_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)

# check segment
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")

# validate and update segment
cls.segment_create_args_validate(segment_data, document)
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)

return updated_segment, document

@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""

+ 1
- 1
api/services/external_knowledge_service.py 查看文件

@@ -181,7 +181,7 @@ class ExternalDatasetService:
do http request depending on api bundle
"""

kwargs = {
kwargs: dict[str, Any] = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,

+ 2
- 2
api/services/file_service.py 查看文件

@@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Any, Literal, Union
from typing import Literal, Union

from werkzeug.exceptions import NotFound

@@ -35,7 +35,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser, Any],
user: Union[Account, EndUser],
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:

+ 10
- 7
api/services/model_load_balancing_service.py 查看文件

@@ -165,7 +165,7 @@ class ModelLoadBalancingService:

try:
if load_balancing_config.encrypted_config:
credentials = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
@@ -180,11 +180,13 @@ class ModelLoadBalancingService:
for variable in credential_secret_variables:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable), # ty: ignore [invalid-argument-type]
decoding_rsa_key,
decoding_cipher_rsa,
)
token_value = credentials.get(variable)
if isinstance(token_value, str):
credentials[variable] = encrypter.decrypt_token_with_decoding(
token_value,
decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError:
pass

@@ -345,8 +347,9 @@ class ModelLoadBalancingService:
credential_id = config.get("credential_id")
enabled = config.get("enabled")

credential_record: ProviderCredential | ProviderModelCredential | None = None

if credential_id:
credential_record: ProviderCredential | ProviderModelCredential | None = None
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)

+ 1
- 0
api/services/plugin/plugin_migration.py 查看文件

@@ -99,6 +99,7 @@ class PluginMigration:
datetime.timedelta(hours=1),
]

tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

+ 5
- 5
api/services/tools/builtin_tools_manage_service.py 查看文件

@@ -223,8 +223,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@@ -285,9 +285,9 @@ class BuiltinToolManageService:

session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}

@staticmethod

+ 14
- 2
api/services/workflow/workflow_converter.py 查看文件

@@ -18,6 +18,7 @@ from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
@@ -420,7 +421,11 @@ class WorkflowConverter:
query_in_prompt=False,
)

template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")

template = prompt_template_obj.template
if not template:
prompts = []
else:
@@ -457,7 +462,11 @@ class WorkflowConverter:
query_in_prompt=False,
)

template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")

template = prompt_template_obj.template
template = self._replace_template_variables(
template=template,
variables=start_node["data"]["variables"],
@@ -467,6 +476,9 @@ class WorkflowConverter:
prompts = {"text": template}

prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")

role_prefix = {
"user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),

+ 2
- 2
api/services/workflow_service.py 查看文件

@@ -769,10 +769,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node = e._node
node = e.node
run_succeeded = False
node_run_result = None
error = e._error
error = e.error

# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(

+ 1
- 1
api/services/workspace_service.py 查看文件

@@ -12,7 +12,7 @@ class WorkspaceService:
def get_tenant_info(cls, tenant: Tenant):
if not tenant:
return None
tenant_info = {
tenant_info: dict[str, object] = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,

+ 2
- 2
api/tests/test_containers_integration_tests/services/test_account_service.py 查看文件

@@ -3278,7 +3278,7 @@ class TestRegisterService:
redis_client.setex(cache_key, 24 * 60 * 60, account_id)

# Execute invitation retrieval
result = RegisterService._get_invitation_by_token(
result = RegisterService.get_invitation_by_token(
token=token,
workspace_id=workspace_id,
email=email,
@@ -3316,7 +3316,7 @@ class TestRegisterService:
redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data))

# Execute invitation retrieval
result = RegisterService._get_invitation_by_token(token=token)
result = RegisterService.get_invitation_by_token(token=token)

# Verify result contains expected data
assert result is not None

+ 2
- 1
api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py 查看文件

@@ -14,6 +14,7 @@ from core.app.app_config.entities import (
VariableEntityType,
)
from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.account import Account, Tenant
from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig
@@ -37,7 +38,7 @@ class TestWorkflowConverter:
# Setup default mock returns
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
mock_prompt_transform.return_value.get_prompt_template.return_value = {
"prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(),
"prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
}
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()

+ 0
- 0
api/tests/unit_tests/services/test_account_service.py 查看文件


部分文件因为文件数量过多而无法显示

正在加载...
取消
保存