Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.9.0
| @@ -511,7 +511,7 @@ def add_qdrant_index(field: str): | |||
| from qdrant_client.http.exceptions import UnexpectedResponse | |||
| from qdrant_client.http.models import PayloadSchemaType | |||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig | |||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig | |||
| for binding in bindings: | |||
| if dify_config.QDRANT_URL is None: | |||
| @@ -525,7 +525,21 @@ def add_qdrant_index(field: str): | |||
| prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, | |||
| ) | |||
| try: | |||
| client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) | |||
| params = qdrant_config.to_qdrant_params() | |||
| # Check the type before using | |||
| if isinstance(params, PathQdrantParams): | |||
| # PathQdrantParams case | |||
| client = qdrant_client.QdrantClient(path=params.path) | |||
| else: | |||
| # UrlQdrantParams case - params is UrlQdrantParams | |||
| client = qdrant_client.QdrantClient( | |||
| url=params.url, | |||
| api_key=params.api_key, | |||
| timeout=int(params.timeout), | |||
| verify=params.verify, | |||
| grpc_port=params.grpc_port, | |||
| prefer_grpc=params.prefer_grpc, | |||
| ) | |||
| # create payload index | |||
| client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) | |||
| create_count += 1 | |||
| @@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] | |||
| AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) | |||
| _doc_extensions: list[str] | |||
| if dify_config.ETL_TYPE == "Unstructured": | |||
| DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] | |||
| DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) | |||
| _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] | |||
| _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) | |||
| if dify_config.UNSTRUCTURED_API_URL: | |||
| DOCUMENT_EXTENSIONS.append("ppt") | |||
| DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) | |||
| _doc_extensions.append("ppt") | |||
| else: | |||
| DOCUMENT_EXTENSIONS = [ | |||
| _doc_extensions = [ | |||
| "txt", | |||
| "markdown", | |||
| "md", | |||
| @@ -38,4 +38,4 @@ else: | |||
| "vtt", | |||
| "properties", | |||
| ] | |||
| DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) | |||
| DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] | |||
| @@ -8,7 +8,6 @@ if TYPE_CHECKING: | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||
| from core.tools.plugin_tool.provider import PluginToolProviderController | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| """ | |||
| @@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm" | |||
| api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") | |||
| # Import other controllers | |||
| from . import admin, apikey, extension, feature, ping, setup, version | |||
| from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport] | |||
| # Import app controllers | |||
| from .app import ( | |||
| advanced_prompt_template, | |||
| agent, | |||
| annotation, | |||
| app, | |||
| audio, | |||
| completion, | |||
| conversation, | |||
| conversation_variables, | |||
| generator, | |||
| mcp_server, | |||
| message, | |||
| model_config, | |||
| ops_trace, | |||
| site, | |||
| statistic, | |||
| workflow, | |||
| workflow_app_log, | |||
| workflow_draft_variable, | |||
| workflow_run, | |||
| workflow_statistic, | |||
| advanced_prompt_template, # pyright: ignore[reportUnusedImport] | |||
| agent, # pyright: ignore[reportUnusedImport] | |||
| annotation, # pyright: ignore[reportUnusedImport] | |||
| app, # pyright: ignore[reportUnusedImport] | |||
| audio, # pyright: ignore[reportUnusedImport] | |||
| completion, # pyright: ignore[reportUnusedImport] | |||
| conversation, # pyright: ignore[reportUnusedImport] | |||
| conversation_variables, # pyright: ignore[reportUnusedImport] | |||
| generator, # pyright: ignore[reportUnusedImport] | |||
| mcp_server, # pyright: ignore[reportUnusedImport] | |||
| message, # pyright: ignore[reportUnusedImport] | |||
| model_config, # pyright: ignore[reportUnusedImport] | |||
| ops_trace, # pyright: ignore[reportUnusedImport] | |||
| site, # pyright: ignore[reportUnusedImport] | |||
| statistic, # pyright: ignore[reportUnusedImport] | |||
| workflow, # pyright: ignore[reportUnusedImport] | |||
| workflow_app_log, # pyright: ignore[reportUnusedImport] | |||
| workflow_draft_variable, # pyright: ignore[reportUnusedImport] | |||
| workflow_run, # pyright: ignore[reportUnusedImport] | |||
| workflow_statistic, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| # Import auth controllers | |||
| from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server | |||
| from .auth import ( | |||
| activate, # pyright: ignore[reportUnusedImport] | |||
| data_source_bearer_auth, # pyright: ignore[reportUnusedImport] | |||
| data_source_oauth, # pyright: ignore[reportUnusedImport] | |||
| forgot_password, # pyright: ignore[reportUnusedImport] | |||
| login, # pyright: ignore[reportUnusedImport] | |||
| oauth, # pyright: ignore[reportUnusedImport] | |||
| oauth_server, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| # Import billing controllers | |||
| from .billing import billing, compliance | |||
| from .billing import billing, compliance # pyright: ignore[reportUnusedImport] | |||
| # Import datasets controllers | |||
| from .datasets import ( | |||
| data_source, | |||
| datasets, | |||
| datasets_document, | |||
| datasets_segments, | |||
| external, | |||
| hit_testing, | |||
| metadata, | |||
| website, | |||
| data_source, # pyright: ignore[reportUnusedImport] | |||
| datasets, # pyright: ignore[reportUnusedImport] | |||
| datasets_document, # pyright: ignore[reportUnusedImport] | |||
| datasets_segments, # pyright: ignore[reportUnusedImport] | |||
| external, # pyright: ignore[reportUnusedImport] | |||
| hit_testing, # pyright: ignore[reportUnusedImport] | |||
| metadata, # pyright: ignore[reportUnusedImport] | |||
| website, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| # Import explore controllers | |||
| from .explore import ( | |||
| installed_app, | |||
| parameter, | |||
| recommended_app, | |||
| saved_message, | |||
| installed_app, # pyright: ignore[reportUnusedImport] | |||
| parameter, # pyright: ignore[reportUnusedImport] | |||
| recommended_app, # pyright: ignore[reportUnusedImport] | |||
| saved_message, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| # Explore Audio | |||
| @@ -167,18 +175,18 @@ api.add_resource( | |||
| ) | |||
| # Import tag controllers | |||
| from .tag import tags | |||
| from .tag import tags # pyright: ignore[reportUnusedImport] | |||
| # Import workspace controllers | |||
| from .workspace import ( | |||
| account, | |||
| agent_providers, | |||
| endpoint, | |||
| load_balancing_config, | |||
| members, | |||
| model_providers, | |||
| models, | |||
| plugin, | |||
| tool_providers, | |||
| workspace, | |||
| account, # pyright: ignore[reportUnusedImport] | |||
| agent_providers, # pyright: ignore[reportUnusedImport] | |||
| endpoint, # pyright: ignore[reportUnusedImport] | |||
| load_balancing_config, # pyright: ignore[reportUnusedImport] | |||
| members, # pyright: ignore[reportUnusedImport] | |||
| model_providers, # pyright: ignore[reportUnusedImport] | |||
| models, # pyright: ignore[reportUnusedImport] | |||
| plugin, # pyright: ignore[reportUnusedImport] | |||
| tool_providers, # pyright: ignore[reportUnusedImport] | |||
| workspace, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| @@ -1,8 +1,9 @@ | |||
| from typing import Any, Optional | |||
| from typing import Optional | |||
| import flask_restx | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, fields, marshal_with | |||
| from flask_restx._http import HTTPStatus | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| @@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model): | |||
| ).scalar_one_or_none() | |||
| if resource is None: | |||
| flask_restx.abort(404, message=f"{resource_model.__name__} not found.") | |||
| flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") | |||
| return resource | |||
| @@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource): | |||
| method_decorators = [account_initialization_required, login_required, setup_required] | |||
| resource_type: str | None = None | |||
| resource_model: Optional[Any] = None | |||
| resource_model: Optional[type] = None | |||
| resource_id_field: str | None = None | |||
| token_prefix: str | None = None | |||
| max_keys = 10 | |||
| @@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource): | |||
| if current_key_count >= self.max_keys: | |||
| flask_restx.abort( | |||
| 400, | |||
| HTTPStatus.BAD_REQUEST, | |||
| message=f"Cannot create more than {self.max_keys} API keys for this resource type.", | |||
| custom="max_keys_exceeded", | |||
| ) | |||
| @@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource): | |||
| method_decorators = [account_initialization_required, login_required, setup_required] | |||
| resource_type: str | None = None | |||
| resource_model: Optional[Any] = None | |||
| resource_model: Optional[type] = None | |||
| resource_id_field: str | None = None | |||
| def delete(self, resource_id, api_key_id): | |||
| @@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource): | |||
| ) | |||
| if key is None: | |||
| flask_restx.abort(404, message="API key not found") | |||
| flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") | |||
| db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| @@ -115,6 +115,10 @@ class AppListApi(Resource): | |||
| raise BadRequest("mode is required") | |||
| app_service = AppService() | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| if current_user.current_tenant_id is None: | |||
| raise ValueError("current_user.current_tenant_id cannot be None") | |||
| app = app_service.create_app(current_user.current_tenant_id, args, current_user) | |||
| return app, 201 | |||
| @@ -161,14 +165,26 @@ class AppApi(Resource): | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app(app_model, args) | |||
| # Construct ArgsDict from parsed arguments | |||
| from services.app_service import AppService as AppServiceType | |||
| args_dict: AppServiceType.ArgsDict = { | |||
| "name": args["name"], | |||
| "description": args.get("description", ""), | |||
| "icon_type": args.get("icon_type", ""), | |||
| "icon": args.get("icon", ""), | |||
| "icon_background": args.get("icon_background", ""), | |||
| "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), | |||
| "max_active_requests": args.get("max_active_requests", 0), | |||
| } | |||
| app_model = app_service.update_app(app_model, args_dict) | |||
| return app_model | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def delete(self, app_model): | |||
| """Delete app""" | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -224,10 +240,10 @@ class AppCopyApi(Resource): | |||
| class AppExportApi(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| """Export app""" | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -263,7 +279,7 @@ class AppNameApi(Resource): | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_name(app_model, args.get("name")) | |||
| app_model = app_service.update_app_name(app_model, args["name"]) | |||
| return app_model | |||
| @@ -285,7 +301,7 @@ class AppIconApi(Resource): | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) | |||
| app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") | |||
| return app_model | |||
| @@ -306,7 +322,7 @@ class AppSiteStatus(Resource): | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) | |||
| app_model = app_service.update_app_site_status(app_model, args["enable_site"]) | |||
| return app_model | |||
| @@ -327,7 +343,7 @@ class AppApiStatus(Resource): | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) | |||
| app_model = app_service.update_app_api_status(app_model, args["enable_api"]) | |||
| return app_model | |||
| @@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource): | |||
| class ChatMessageTextApi(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def post(self, app_model: App): | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| @@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource): | |||
| class TextModesApi(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| @@ -1,6 +1,5 @@ | |||
| import logging | |||
| import flask_login | |||
| from flask import request | |||
| from flask_restx import Resource, reqparse | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| @@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs import helper | |||
| from libs.helper import uuid_value | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from models import Account | |||
| from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| @@ -56,11 +56,11 @@ class CompletionMessageApi(Resource): | |||
| streaming = args["response_mode"] != "blocking" | |||
| args["auto_generate_name"] = False | |||
| account = flask_login.current_user | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account or EndUser instance") | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource): | |||
| @account_initialization_required | |||
| @get_app_model(mode=AppMode.COMPLETION) | |||
| def post(self, app_model, task_id): | |||
| account = flask_login.current_user | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) | |||
| return {"result": "success"}, 200 | |||
| @@ -123,11 +123,11 @@ class ChatMessageApi(Resource): | |||
| if external_trace_id: | |||
| args["external_trace_id"] = external_trace_id | |||
| account = flask_login.current_user | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account or EndUser instance") | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource): | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) | |||
| def post(self, app_model, task_id): | |||
| account = flask_login.current_user | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) | |||
| return {"result": "success"}, 200 | |||
| @@ -22,7 +22,7 @@ from fields.conversation_fields import ( | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models import Conversation, EndUser, Message, MessageAnnotation | |||
| from models import Account, Conversation, EndUser, Message, MessageAnnotation | |||
| from models.model import AppMode | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| @@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource): | |||
| conversation_id = str(conversation_id) | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| ConversationService.delete(app_model, conversation_id, current_user) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource): | |||
| conversation_id = str(conversation_id) | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| ConversationService.delete(app_model, conversation_id, current_user) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -1,6 +1,5 @@ | |||
| import logging | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, fields, marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy import exists, select | |||
| @@ -27,7 +26,8 @@ from extensions.ext_database import db | |||
| from fields.conversation_fields import annotation_fields, message_detail_fields | |||
| from libs.helper import uuid_value | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from models.account import Account | |||
| from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| @@ -118,11 +118,14 @@ class ChatMessageListApi(Resource): | |||
| class MessageFeedbackApi(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def post(self, app_model): | |||
| if current_user is None: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("message_id", required=True, type=uuid_value, location="json") | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| @@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource): | |||
| @get_app_model | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_model): | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource): | |||
| class MessageAnnotationCountApi(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() | |||
| @@ -10,7 +10,7 @@ from extensions.ext_database import db | |||
| from fields.app_fields import app_site_fields | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import login_required | |||
| from models import Site | |||
| from models import Account, Site | |||
| def parse_app_site_args(): | |||
| @@ -75,6 +75,8 @@ class AppSite(Resource): | |||
| if value is not None: | |||
| setattr(site, attr_name, value) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| site.updated_by = current_user.id | |||
| site.updated_at = naive_utc_now() | |||
| db.session.commit() | |||
| @@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource): | |||
| raise NotFound | |||
| site.code = Site.generate_code(16) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| site.updated_by = current_user.id | |||
| site.updated_at = naive_utc_now() | |||
| db.session.commit() | |||
| @@ -18,10 +18,10 @@ from models import AppMode, Message | |||
| class DailyMessageStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -75,10 +75,10 @@ WHERE | |||
| class DailyConversationStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource): | |||
| class DailyTerminalsStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -184,10 +184,10 @@ WHERE | |||
| class DailyTokenCostStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -320,10 +320,10 @@ ORDER BY | |||
| class UserSatisfactionRateStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -443,10 +443,10 @@ WHERE | |||
| class TokensPerSecondStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -18,10 +18,10 @@ from models.model import AppMode | |||
| class WorkflowDailyRunsStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -80,10 +80,10 @@ WHERE | |||
| class WorkflowDailyTerminalsStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -142,10 +142,10 @@ WHERE | |||
| class WorkflowDailyTokenCostStatistic(Resource): | |||
| @get_app_model | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| account = current_user | |||
| @@ -77,6 +77,9 @@ class OAuthCallback(Resource): | |||
| if state: | |||
| invite_token = state | |||
| if not code: | |||
| return {"error": "Authorization code is required"}, 400 | |||
| try: | |||
| token = oauth_provider.get_access_token(code) | |||
| user_info = oauth_provider.get_user_info(token) | |||
| @@ -86,7 +89,7 @@ class OAuthCallback(Resource): | |||
| return {"error": "OAuth process failed"}, 400 | |||
| if invite_token and RegisterService.is_valid_invite_token(invite_token): | |||
| invitation = RegisterService._get_invitation_by_token(token=invite_token) | |||
| invitation = RegisterService.get_invitation_by_token(token=invite_token) | |||
| if invitation: | |||
| invitation_email = invitation.get("email", None) | |||
| if invitation_email != user_info.email: | |||
| @@ -1,6 +1,5 @@ | |||
| import logging | |||
| from flask_login import current_user | |||
| from flask_restx import reqparse | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| @@ -28,6 +27,8 @@ from extensions.ext_database import db | |||
| from libs import helper | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.helper import uuid_value | |||
| from libs.login import current_user | |||
| from models import Account | |||
| from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| @@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource): | |||
| db.session.commit() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming | |||
| ) | |||
| @@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource): | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {"result": "success"}, 200 | |||
| @@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource): | |||
| db.session.commit() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True | |||
| ) | |||
| @@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource): | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {"result": "success"}, 200 | |||
| @@ -1,4 +1,3 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| @@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import uuid_value | |||
| from libs.login import current_user | |||
| from models import Account | |||
| from models.model import AppMode | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError | |||
| @@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource): | |||
| pinned = args["pinned"] == "true" | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| with Session(db.engine) as session: | |||
| return WebConversationService.pagination_by_last_id( | |||
| session=session, | |||
| @@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource): | |||
| conversation_id = str(c_id) | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| ConversationService.delete(app_model, conversation_id, current_user) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource): | |||
| args = parser.parse_args() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| return ConversationService.rename( | |||
| app_model, conversation_id, current_user, args["name"], args["auto_generate"] | |||
| ) | |||
| @@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource): | |||
| conversation_id = str(c_id) | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| WebConversationService.pin(app_model, conversation_id, current_user) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource): | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| WebConversationService.unpin(app_model, conversation_id, current_user) | |||
| return {"result": "success"} | |||
| @@ -2,7 +2,6 @@ import logging | |||
| from typing import Any | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, inputs, marshal_with, reqparse | |||
| from sqlalchemy import and_ | |||
| from werkzeug.exceptions import BadRequest, Forbidden, NotFound | |||
| @@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi | |||
| from extensions.ext_database import db | |||
| from fields.installed_app_fields import installed_app_list_fields | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import login_required | |||
| from models import App, InstalledApp, RecommendedApp | |||
| from libs.login import current_user, login_required | |||
| from models import Account, App, InstalledApp, RecommendedApp | |||
| from services.account_service import TenantService | |||
| from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| @@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource): | |||
| @marshal_with(installed_app_list_fields) | |||
| def get(self): | |||
| app_id = request.args.get("app_id", default=None, type=str) | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| current_tenant_id = current_user.current_tenant_id | |||
| if app_id: | |||
| @@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource): | |||
| else: | |||
| installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() | |||
| if current_user.current_tenant is None: | |||
| raise ValueError("current_user.current_tenant must not be None") | |||
| current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) | |||
| installed_app_list: list[dict[str, Any]] = [ | |||
| { | |||
| @@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource): | |||
| if recommended_app is None: | |||
| raise NotFound("App not found") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| current_tenant_id = current_user.current_tenant_id | |||
| app = db.session.query(App).where(App.id == args["app_id"]).first() | |||
| @@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource): | |||
| """ | |||
| def delete(self, installed_app): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| if installed_app.app_owner_tenant_id == current_user.current_tenant_id: | |||
| raise BadRequest("You can't uninstall an app owned by the current tenant") | |||
| @@ -1,6 +1,5 @@ | |||
| import logging | |||
| from flask_login import current_user | |||
| from flask_restx import marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| @@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError | |||
| from fields.message_fields import message_infinite_scroll_pagination_fields | |||
| from libs import helper | |||
| from libs.helper import uuid_value | |||
| from libs.login import current_user | |||
| from models import Account | |||
| from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| @@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource): | |||
| args = parser.parse_args() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| return MessageService.pagination_by_first_id( | |||
| app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] | |||
| ) | |||
| @@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource): | |||
| args = parser.parse_args() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| MessageService.create_feedback( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| @@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): | |||
| streaming = args["response_mode"] == "streaming" | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| response = AppGenerateService.generate_more_like_this( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| @@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| message_id = str(message_id) | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| questions = MessageService.get_suggested_questions_after_answer( | |||
| app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE | |||
| ) | |||
| @@ -1,11 +1,10 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, fields, marshal_with, reqparse | |||
| from constants.languages import languages | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required | |||
| from libs.helper import AppIconUrlField | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from services.recommended_app_service import RecommendedAppService | |||
| app_fields = { | |||
| @@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource): | |||
| parser.add_argument("language", type=str, location="args") | |||
| args = parser.parse_args() | |||
| if args.get("language") and args.get("language") in languages: | |||
| language_prefix = args.get("language") | |||
| language = args.get("language") | |||
| if language and language in languages: | |||
| language_prefix = language | |||
| elif current_user and current_user.interface_language: | |||
| language_prefix = current_user.interface_language | |||
| else: | |||
| @@ -1,4 +1,3 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import fields, marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from fields.conversation_fields import message_file_fields | |||
| from libs.helper import TimestampField, uuid_value | |||
| from libs.login import current_user | |||
| from models import Account | |||
| from services.errors.message import MessageNotExistsError | |||
| from services.saved_message_service import SavedMessageService | |||
| @@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource): | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) | |||
| def post(self, installed_app): | |||
| @@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource): | |||
| args = parser.parse_args() | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| SavedMessageService.save(app_model, current_user, args["message_id"]) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource): | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user must be an Account instance") | |||
| SavedMessageService.delete(app_model, current_user, message_id) | |||
| return {"result": "success"}, 204 | |||
| @@ -22,6 +22,7 @@ from controllers.console.wraps import ( | |||
| ) | |||
| from fields.file_fields import file_fields, upload_config_fields | |||
| from libs.login import login_required | |||
| from models import Account | |||
| from services.file_service import FileService | |||
| PREVIEW_WORDS_LIMIT = 3000 | |||
| @@ -68,6 +69,8 @@ class FileApi(Resource): | |||
| source = None | |||
| try: | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| upload_file = FileService.upload_file( | |||
| filename=file.filename, | |||
| content=file.read(), | |||
| @@ -34,14 +34,14 @@ class VersionApi(Resource): | |||
| return result | |||
| try: | |||
| response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) | |||
| response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) | |||
| except Exception as error: | |||
| logger.warning("Check update version error: %s.", str(error)) | |||
| result["version"] = args.get("current_version") | |||
| result["version"] = args["current_version"] | |||
| return result | |||
| content = json.loads(response.content) | |||
| if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): | |||
| if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): | |||
| result["version"] = content["version"] | |||
| result["release_date"] = content["releaseDate"] | |||
| result["release_notes"] = content["releaseNotes"] | |||
| @@ -49,6 +49,8 @@ class AccountInitApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| if account.status == "active": | |||
| @@ -102,6 +104,8 @@ class AccountProfileApi(Resource): | |||
| @marshal_with(account_fields) | |||
| @enterprise_license_required | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| return current_user | |||
| @@ -111,6 +115,8 @@ class AccountNameApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -130,6 +136,8 @@ class AccountAvatarApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("avatar", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("interface_language", type=supported_language, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("timezone", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -194,6 +208,8 @@ class AccountPasswordApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("password", type=str, required=False, location="json") | |||
| parser.add_argument("new_password", type=str, required=True, location="json") | |||
| @@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(integrate_list_fields) | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() | |||
| @@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| token, code = AccountService.generate_account_deletion_verification_code(account) | |||
| @@ -281,6 +301,8 @@ class AccountDeleteApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| @@ -321,6 +343,8 @@ class EducationVerifyApi(Resource): | |||
| @cloud_edition_billing_enabled | |||
| @marshal_with(verify_fields) | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| return BillingService.EducationIdentity.verify(account.id, account.email) | |||
| @@ -340,6 +364,8 @@ class EducationApi(Resource): | |||
| @only_edition_cloud | |||
| @cloud_edition_billing_enabled | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| @@ -357,6 +383,8 @@ class EducationApi(Resource): | |||
| @cloud_edition_billing_enabled | |||
| @marshal_with(status_fields) | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| res = BillingService.EducationIdentity.status(account.id) | |||
| @@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource): | |||
| raise InvalidTokenError() | |||
| user_email = reset_data.get("email", "") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if user_email != current_user.email: | |||
| raise InvalidEmailError() | |||
| else: | |||
| @@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource): | |||
| AccountService.revoke_change_email_token(args["token"]) | |||
| old_email = reset_data.get("old_email", "") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if current_user.email != old_email: | |||
| raise AccountNotFound() | |||
| @@ -1,8 +1,8 @@ | |||
| from urllib import parse | |||
| from flask import request | |||
| from flask import abort, request | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, abort, marshal_with, reqparse | |||
| from flask_restx import Resource, marshal_with, reqparse | |||
| import services | |||
| from configs import dify_config | |||
| @@ -41,6 +41,10 @@ class MemberListApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_with_role_list_fields) | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| members = TenantService.get_tenant_members(current_user.current_tenant) | |||
| return {"result": "success", "accounts": members}, 200 | |||
| @@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource): | |||
| if not TenantAccountRole.is_non_owner_role(invitee_role): | |||
| return {"code": "invalid-role", "message": "Invalid role"}, 400 | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| inviter = current_user | |||
| if not inviter.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| invitation_results = [] | |||
| console_web_url = dify_config.CONSOLE_WEB_URL | |||
| @@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource): | |||
| for invitee_email in invitee_emails: | |||
| try: | |||
| if not inviter.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| token = RegisterService.invite_new_member( | |||
| inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter | |||
| ) | |||
| @@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource): | |||
| return { | |||
| "result": "success", | |||
| "invitation_results": invitation_results, | |||
| "tenant_id": str(current_user.current_tenant.id), | |||
| "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", | |||
| }, 201 | |||
| @@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, member_id): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| member = db.session.query(Account).where(Account.id == str(member_id)).first() | |||
| if member is None: | |||
| abort(404) | |||
| @@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource): | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 | |||
| return { | |||
| "result": "success", | |||
| "tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "", | |||
| }, 200 | |||
| class MemberUpdateRoleApi(Resource): | |||
| @@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource): | |||
| if not TenantAccountRole.is_valid_role(new_role): | |||
| return {"code": "invalid-role", "message": "Invalid role"}, 400 | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| member = db.session.get(Account, str(member_id)) | |||
| if not member: | |||
| abort(404) | |||
| @@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(account_with_role_list_fields) | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| members = TenantService.get_dataset_operator_members(current_user.current_tenant) | |||
| return {"result": "success", "accounts": members}, 200 | |||
| @@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource): | |||
| raise EmailSendIpLimitError() | |||
| # check if the current user is the owner of the workspace | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| @@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource): | |||
| account=current_user, | |||
| email=email, | |||
| language=language, | |||
| workspace_name=current_user.current_tenant.name, | |||
| workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", | |||
| ) | |||
| return {"result": "success", "data": token} | |||
| @@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource): | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # check if the current user is the owner of the workspace | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| @@ -256,6 +289,10 @@ class OwnerTransfer(Resource): | |||
| args = parser.parse_args() | |||
| # check if the current user is the owner of the workspace | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| @@ -274,9 +311,11 @@ class OwnerTransfer(Resource): | |||
| member = db.session.get(Account, str(member_id)) | |||
| if not member: | |||
| abort(404) | |||
| else: | |||
| member_account = member | |||
| if not TenantService.is_member(member_account, current_user.current_tenant): | |||
| return # Never reached, but helps type checker | |||
| if not current_user.current_tenant: | |||
| raise ValueError("No current tenant") | |||
| if not TenantService.is_member(member, current_user.current_tenant): | |||
| raise MemberNotInTenantError() | |||
| try: | |||
| @@ -286,13 +325,13 @@ class OwnerTransfer(Resource): | |||
| AccountService.send_new_owner_transfer_notify_email( | |||
| account=member, | |||
| email=member.email, | |||
| workspace_name=current_user.current_tenant.name, | |||
| workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", | |||
| ) | |||
| AccountService.send_old_owner_transfer_notify_email( | |||
| account=current_user, | |||
| email=current_user.email, | |||
| workspace_name=current_user.current_tenant.name, | |||
| workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", | |||
| new_owner_email=member.email, | |||
| ) | |||
| @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.helper import StrLen, uuid_value | |||
| from libs.login import login_required | |||
| from models.account import Account | |||
| from services.billing_service import BillingService | |||
| from services.model_provider_service import ModelProviderService | |||
| @@ -21,6 +22,10 @@ class ModelProviderListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| @@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant_id = current_user.current_tenant_id | |||
| # if credential_id is not provided, return current used credential | |||
| parser = reqparse.RequestParser() | |||
| @@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| @@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource): | |||
| model_provider_service = ModelProviderService() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| try: | |||
| model_provider_service.create_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def put(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| @@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource): | |||
| model_provider_service = ModelProviderService() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| try: | |||
| model_provider_service.update_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] | |||
| @@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| service = ModelProviderService() | |||
| service.switch_active_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| @@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| @@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| def get(self, provider: str): | |||
| if provider != "anthropic": | |||
| raise ValueError(f"provider name {provider} is invalid") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| data = BillingService.get_model_provider_payment_link( | |||
| provider_name=provider, | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -25,7 +25,7 @@ from controllers.console.wraps import ( | |||
| from extensions.ext_database import db | |||
| from libs.helper import TimestampField | |||
| from libs.login import login_required | |||
| from models.account import Tenant, TenantStatus | |||
| from models.account import Account, Tenant, TenantStatus | |||
| from services.account_service import TenantService | |||
| from services.feature_service import FeatureService | |||
| from services.file_service import FileService | |||
| @@ -70,6 +70,8 @@ class TenantListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| tenants = TenantService.get_join_tenants(current_user) | |||
| tenant_dicts = [] | |||
| @@ -83,7 +85,7 @@ class TenantListApi(Resource): | |||
| "status": tenant.status, | |||
| "created_at": tenant.created_at, | |||
| "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", | |||
| "current": tenant.id == current_user.current_tenant_id, | |||
| "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, | |||
| } | |||
| tenant_dicts.append(tenant_dict) | |||
| @@ -125,7 +127,11 @@ class TenantApi(Resource): | |||
| if request.path == "/info": | |||
| logger.warning("Deprecated URL /info was used.") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| tenant = current_user.current_tenant | |||
| if not tenant: | |||
| raise ValueError("No current tenant") | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| tenants = TenantService.get_join_tenants(current_user) | |||
| @@ -137,6 +143,8 @@ class TenantApi(Resource): | |||
| else: | |||
| raise Unauthorized("workspace is archived") | |||
| if not tenant: | |||
| raise ValueError("No tenant available") | |||
| return WorkspaceService.get_tenant_info(tenant), 200 | |||
| @@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tenant_id", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| @@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource): | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("workspace_custom") | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("remove_webapp_brand", type=bool, location="json") | |||
| parser.add_argument("replace_webapp_logo", type=str, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant = db.get_or_404(Tenant, current_user.current_tenant_id) | |||
| custom_config_dict = { | |||
| @@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource): | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("workspace_custom") | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| # check file | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| @@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource): | |||
| @account_initialization_required | |||
| # Change workspace name | |||
| def post(self): | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("Invalid user account") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.current_tenant_id: | |||
| raise ValueError("No current tenant") | |||
| tenant = db.get_or_404(Tenant, current_user.current_tenant_id) | |||
| tenant.name = args["name"] | |||
| db.session.commit() | |||
| @@ -15,6 +15,6 @@ api = ExternalApi( | |||
| files_ns = Namespace("files", description="File operations", path="/") | |||
| from . import image_preview, tool_files, upload | |||
| from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport] | |||
| api.add_namespace(files_ns) | |||
| @@ -16,8 +16,8 @@ api = ExternalApi( | |||
| # Create namespace | |||
| inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") | |||
| from . import mail | |||
| from .plugin import plugin | |||
| from .workspace import workspace | |||
| from . import mail as _mail # pyright: ignore[reportUnusedImport] | |||
| from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport] | |||
| from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport] | |||
| api.add_namespace(inner_api_ns) | |||
| @@ -37,9 +37,9 @@ from models.model import EndUser | |||
| @inner_api_ns.route("/invoke/llm") | |||
| class PluginInvokeLLMApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeLLM) | |||
| @inner_api_ns.doc("plugin_invoke_llm") | |||
| @inner_api_ns.doc(description="Invoke LLM models through plugin interface") | |||
| @@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource): | |||
| @inner_api_ns.route("/invoke/llm/structured-output") | |||
| class PluginInvokeLLMWithStructuredOutputApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) | |||
| @inner_api_ns.doc("plugin_invoke_llm_structured") | |||
| @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") | |||
| @@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): | |||
| @inner_api_ns.route("/invoke/text-embedding") | |||
| class PluginInvokeTextEmbeddingApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTextEmbedding) | |||
| @inner_api_ns.doc("plugin_invoke_text_embedding") | |||
| @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") | |||
| @@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource): | |||
| @inner_api_ns.route("/invoke/rerank") | |||
| class PluginInvokeRerankApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeRerank) | |||
| @inner_api_ns.doc("plugin_invoke_rerank") | |||
| @inner_api_ns.doc(description="Invoke rerank models through plugin interface") | |||
| @@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource): | |||
| @inner_api_ns.route("/invoke/tts") | |||
| class PluginInvokeTTSApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTTS) | |||
| @inner_api_ns.doc("plugin_invoke_tts") | |||
| @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") | |||
| @@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource): | |||
| @inner_api_ns.route("/invoke/speech2text") | |||
| class PluginInvokeSpeech2TextApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSpeech2Text) | |||
| @inner_api_ns.doc("plugin_invoke_speech2text") | |||
| @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") | |||
| @@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource): | |||
| @inner_api_ns.route("/invoke/moderation") | |||
| class PluginInvokeModerationApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeModeration) | |||
| @inner_api_ns.doc("plugin_invoke_moderation") | |||
| @inner_api_ns.doc(description="Invoke moderation models through plugin interface") | |||
| @@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource): | |||
| @inner_api_ns.route("/invoke/tool") | |||
| class PluginInvokeToolApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTool) | |||
| @inner_api_ns.doc("plugin_invoke_tool") | |||
| @inner_api_ns.doc(description="Invoke tools through plugin interface") | |||
| @@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource): | |||
| @inner_api_ns.route("/invoke/parameter-extractor") | |||
| class PluginInvokeParameterExtractorNodeApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeParameterExtractorNode) | |||
| @inner_api_ns.doc("plugin_invoke_parameter_extractor") | |||
| @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") | |||
| @@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource): | |||
| @inner_api_ns.route("/invoke/question-classifier") | |||
| class PluginInvokeQuestionClassifierNodeApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) | |||
| @inner_api_ns.doc("plugin_invoke_question_classifier") | |||
| @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") | |||
| @@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): | |||
| @inner_api_ns.route("/invoke/app") | |||
| class PluginInvokeAppApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeApp) | |||
| @inner_api_ns.doc("plugin_invoke_app") | |||
| @inner_api_ns.doc(description="Invoke application through plugin interface") | |||
| @@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource): | |||
| @inner_api_ns.route("/invoke/encrypt") | |||
| class PluginInvokeEncryptApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeEncrypt) | |||
| @inner_api_ns.doc("plugin_invoke_encrypt") | |||
| @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") | |||
| @@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource): | |||
| @inner_api_ns.route("/invoke/summary") | |||
| class PluginInvokeSummaryApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSummary) | |||
| @inner_api_ns.doc("plugin_invoke_summary") | |||
| @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") | |||
| @@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource): | |||
| @inner_api_ns.route("/upload/file/request") | |||
| class PluginUploadFileRequestApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestRequestUploadFile) | |||
| @inner_api_ns.doc("plugin_upload_file_request") | |||
| @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") | |||
| @@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource): | |||
| @inner_api_ns.route("/fetch/app/info") | |||
| class PluginFetchAppInfoApi(Resource): | |||
| @get_user_tenant | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestFetchAppInfo) | |||
| @inner_api_ns.doc("plugin_fetch_app_info") | |||
| @inner_api_ns.doc(description="Fetch application information through plugin interface") | |||
| @@ -1,6 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Optional, ParamSpec, TypeVar | |||
| from typing import Optional, ParamSpec, TypeVar, cast | |||
| from flask import current_app, request | |||
| from flask_login import user_logged_in | |||
| @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session | |||
| from core.file.constants import DEFAULT_SERVICE_API_USER_ID | |||
| from extensions.ext_database import db | |||
| from libs.login import _get_user | |||
| from libs.login import current_user | |||
| from models.account import Tenant | |||
| from models.model import EndUser | |||
| @@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): | |||
| p = parser.parse_args() | |||
| user_id: Optional[str] = p.get("user_id") | |||
| tenant_id: str = p.get("tenant_id") | |||
| user_id = cast(str, p.get("user_id")) | |||
| tenant_id = cast(str, p.get("tenant_id")) | |||
| if not tenant_id: | |||
| raise ValueError("tenant_id is required") | |||
| @@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): | |||
| kwargs["user_model"] = user | |||
| current_app.login_manager._update_request_context_with_user(user) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore | |||
| return view_func(*args, **kwargs) | |||
| @@ -15,6 +15,6 @@ api = ExternalApi( | |||
| mcp_ns = Namespace("mcp", description="MCP operations", path="/") | |||
| from . import mcp | |||
| from . import mcp # pyright: ignore[reportUnusedImport] | |||
| api.add_namespace(mcp_ns) | |||
| @@ -15,9 +15,27 @@ api = ExternalApi( | |||
| service_api_ns = Namespace("service_api", description="Service operations", path="/") | |||
| from . import index | |||
| from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow | |||
| from .dataset import dataset, document, hit_testing, metadata, segment, upload_file | |||
| from .workspace import models | |||
| from . import index # pyright: ignore[reportUnusedImport] | |||
| from .app import ( | |||
| annotation, # pyright: ignore[reportUnusedImport] | |||
| app, # pyright: ignore[reportUnusedImport] | |||
| audio, # pyright: ignore[reportUnusedImport] | |||
| completion, # pyright: ignore[reportUnusedImport] | |||
| conversation, # pyright: ignore[reportUnusedImport] | |||
| file, # pyright: ignore[reportUnusedImport] | |||
| file_preview, # pyright: ignore[reportUnusedImport] | |||
| message, # pyright: ignore[reportUnusedImport] | |||
| site, # pyright: ignore[reportUnusedImport] | |||
| workflow, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| from .dataset import ( | |||
| dataset, # pyright: ignore[reportUnusedImport] | |||
| document, # pyright: ignore[reportUnusedImport] | |||
| hit_testing, # pyright: ignore[reportUnusedImport] | |||
| metadata, # pyright: ignore[reportUnusedImport] | |||
| segment, # pyright: ignore[reportUnusedImport] | |||
| upload_file, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| from .workspace import models # pyright: ignore[reportUnusedImport] | |||
| api.add_namespace(service_api_ns) | |||
| @@ -1,4 +1,5 @@ | |||
| from flask_restx import Resource, reqparse | |||
| from flask_restx._http import HTTPStatus | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import BadRequest, NotFound | |||
| @@ -121,7 +122,7 @@ class ConversationDetailApi(Resource): | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) | |||
| @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) | |||
| def delete(self, app_model: App, end_user: EndUser, c_id): | |||
| """Delete a specific conversation.""" | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| @@ -30,6 +30,7 @@ from extensions.ext_database import db | |||
| from fields.document_fields import document_fields, document_status_fields | |||
| from libs.login import current_user | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.model import EndUser | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| from services.file_service import FileService | |||
| @@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| if not file.filename: | |||
| raise FilenameNotExistsError | |||
| if not isinstance(current_user, EndUser): | |||
| raise ValueError("Invalid user account") | |||
| upload_file = FileService.upload_file( | |||
| filename=file.filename, | |||
| content=file.read(), | |||
| @@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| raise FilenameNotExistsError | |||
| try: | |||
| if not isinstance(current_user, EndUser): | |||
| raise ValueError("Invalid user account") | |||
| upload_file = FileService.upload_file( | |||
| filename=file.filename, | |||
| content=file.read(), | |||
| @@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import _get_user | |||
| from libs.login import current_user | |||
| from models.account import Account, Tenant, TenantAccountJoin, TenantStatus | |||
| from models.dataset import Dataset, RateLimitLog | |||
| from models.model import ApiToken, App, EndUser | |||
| @@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None | |||
| if account: | |||
| account.current_tenant = tenant | |||
| current_app.login_manager._update_request_context_with_user(account) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore | |||
| else: | |||
| raise Unauthorized("Tenant owner account does not exist.") | |||
| else: | |||
| @@ -17,20 +17,20 @@ api = ExternalApi( | |||
| web_ns = Namespace("web", description="Web application API operations", path="/") | |||
| from . import ( | |||
| app, | |||
| audio, | |||
| completion, | |||
| conversation, | |||
| feature, | |||
| files, | |||
| forgot_password, | |||
| login, | |||
| message, | |||
| passport, | |||
| remote_files, | |||
| saved_message, | |||
| site, | |||
| workflow, | |||
| app, # pyright: ignore[reportUnusedImport] | |||
| audio, # pyright: ignore[reportUnusedImport] | |||
| completion, # pyright: ignore[reportUnusedImport] | |||
| conversation, # pyright: ignore[reportUnusedImport] | |||
| feature, # pyright: ignore[reportUnusedImport] | |||
| files, # pyright: ignore[reportUnusedImport] | |||
| forgot_password, # pyright: ignore[reportUnusedImport] | |||
| login, # pyright: ignore[reportUnusedImport] | |||
| message, # pyright: ignore[reportUnusedImport] | |||
| passport, # pyright: ignore[reportUnusedImport] | |||
| remote_files, # pyright: ignore[reportUnusedImport] | |||
| saved_message, # pyright: ignore[reportUnusedImport] | |||
| site, # pyright: ignore[reportUnusedImport] | |||
| workflow, # pyright: ignore[reportUnusedImport] | |||
| ) | |||
| api.add_namespace(web_ns) | |||
| @@ -1 +0,0 @@ | |||
| import core.moderation.base | |||
| @@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| function_call_state = True | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| final_answer = "" | |||
| prompt_messages: list = [] # Initialize prompt_messages | |||
| agent_thought_id = "" # Initialize agent_thought_id | |||
| def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): | |||
| if not final_llm_usage_dict["usage"]: | |||
| @@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| function_call_state = True | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| final_answer = "" | |||
| prompt_messages: list = [] # Initialize prompt_messages | |||
| # get tracing instance | |||
| trace_manager = app_generate_entity.trace_manager | |||
| @@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager: | |||
| @classmethod | |||
| def validate_and_set_defaults( | |||
| cls, tenant_id, config: dict, only_structure_validate: bool = False | |||
| cls, tenant_id: str, config: dict, only_structure_validate: bool = False | |||
| ) -> tuple[dict, list[str]]: | |||
| if not config.get("sensitive_word_avoidance"): | |||
| config["sensitive_word_avoidance"] = {"enabled": False} | |||
| @@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager: | |||
| if not only_structure_validate: | |||
| typ = config["sensitive_word_avoidance"]["type"] | |||
| sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] | |||
| if not isinstance(typ, str): | |||
| raise ValueError("sensitive_word_avoidance.type must be a string") | |||
| sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config") | |||
| if sensitive_word_avoidance_config is None: | |||
| sensitive_word_avoidance_config = {} | |||
| if not isinstance(sensitive_word_avoidance_config, dict): | |||
| raise ValueError("sensitive_word_avoidance.config must be a dict") | |||
| ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) | |||
| @@ -25,10 +25,14 @@ class PromptTemplateConfigManager: | |||
| if chat_prompt_config: | |||
| chat_prompt_messages = [] | |||
| for message in chat_prompt_config.get("prompt", []): | |||
| text = message.get("text") | |||
| if not isinstance(text, str): | |||
| raise ValueError("message text must be a string") | |||
| role = message.get("role") | |||
| if not isinstance(role, str): | |||
| raise ValueError("message role must be a string") | |||
| chat_prompt_messages.append( | |||
| AdvancedChatMessageEntity( | |||
| **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} | |||
| ) | |||
| AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role)) | |||
| ) | |||
| advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) | |||
| @@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| response_chunk: dict[str, Any] = { | |||
| "event": sub_stream_response.event.value, | |||
| "conversation_id": chunk.conversation_id, | |||
| "message_id": chunk.message_id, | |||
| @@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @classmethod | |||
| @@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| response_chunk: dict[str, Any] = { | |||
| "event": sub_stream_response.event.value, | |||
| "conversation_id": chunk.conversation_id, | |||
| "message_id": chunk.message_id, | |||
| @@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| } | |||
| if isinstance(sub_stream_response, MessageEndStreamResponse): | |||
| sub_stream_response_dict = sub_stream_response.to_dict() | |||
| sub_stream_response_dict = sub_stream_response.model_dump(mode="json") | |||
| metadata = sub_stream_response_dict.get("metadata", {}) | |||
| sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) | |||
| response_chunk.update(sub_stream_response_dict) | |||
| @@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._base_task_pipeline._stream: | |||
| if self._base_task_pipeline.stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| @@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: | |||
| """Handle ping events.""" | |||
| yield self._base_task_pipeline._ping_stream_response() | |||
| yield self._base_task_pipeline.ping_stream_response() | |||
| def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: | |||
| """Handle error events.""" | |||
| with self._database_session() as session: | |||
| err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) | |||
| yield self._base_task_pipeline.error_to_stream_response(err) | |||
| def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: | |||
| """Handle workflow started events.""" | |||
| @@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) | |||
| err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) | |||
| err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) | |||
| yield workflow_finish_resp | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| yield self._base_task_pipeline.error_to_stream_response(err) | |||
| def _handle_stop_event( | |||
| self, | |||
| @@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| """Handle advanced chat message end events.""" | |||
| self._ensure_graph_runtime_initialized(graph_runtime_state) | |||
| output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( | |||
| output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( | |||
| self._task_state.answer | |||
| ) | |||
| if output_moderation_answer: | |||
| @@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| message.answer = answer_text | |||
| message.updated_at = naive_utc_now() | |||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | |||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| message_files = [ | |||
| MessageFile( | |||
| @@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| :param text: text | |||
| :return: True if output moderation should direct output, otherwise False | |||
| """ | |||
| if self._base_task_pipeline._output_moderation_handler: | |||
| if self._base_task_pipeline._output_moderation_handler.should_direct_output(): | |||
| self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() | |||
| if self._base_task_pipeline.output_moderation_handler: | |||
| if self._base_task_pipeline.output_moderation_handler.should_direct_output(): | |||
| self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output() | |||
| self._base_task_pipeline.queue_manager.publish( | |||
| QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE | |||
| ) | |||
| @@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| return True | |||
| else: | |||
| self._base_task_pipeline._output_moderation_handler.append_new_token(text) | |||
| self._base_task_pipeline.output_moderation_handler.append_new_token(text) | |||
| return False | |||
| @@ -1,6 +1,6 @@ | |||
| import uuid | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, cast | |||
| from core.agent.entities import AgentEntity | |||
| from core.app.app_config.base_app_config_manager import BaseAppConfigManager | |||
| @@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): | |||
| return filtered_config | |||
| @classmethod | |||
| def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: | |||
| def validate_agent_mode_and_set_defaults( | |||
| cls, tenant_id: str, config: dict[str, Any] | |||
| ) -> tuple[dict[str, Any], list[str]]: | |||
| """ | |||
| Validate agent_mode and set defaults for agent feature | |||
| @@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager): | |||
| if not config.get("agent_mode"): | |||
| config["agent_mode"] = {"enabled": False, "tools": []} | |||
| if not isinstance(config["agent_mode"], dict): | |||
| agent_mode = config["agent_mode"] | |||
| if not isinstance(agent_mode, dict): | |||
| raise ValueError("agent_mode must be of object type") | |||
| if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: | |||
| config["agent_mode"]["enabled"] = False | |||
| # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing | |||
| agent_mode = cast(dict[str, Any], agent_mode) | |||
| if not isinstance(config["agent_mode"]["enabled"], bool): | |||
| if "enabled" not in agent_mode or not agent_mode["enabled"]: | |||
| agent_mode["enabled"] = False | |||
| if not isinstance(agent_mode["enabled"], bool): | |||
| raise ValueError("enabled in agent_mode must be of boolean type") | |||
| if not config["agent_mode"].get("strategy"): | |||
| config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value | |||
| if not agent_mode.get("strategy"): | |||
| agent_mode["strategy"] = PlanningStrategy.ROUTER.value | |||
| if config["agent_mode"]["strategy"] not in [ | |||
| member.value for member in list(PlanningStrategy.__members__.values()) | |||
| ]: | |||
| if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: | |||
| raise ValueError("strategy in agent_mode must be in the specified strategy list") | |||
| if not config["agent_mode"].get("tools"): | |||
| config["agent_mode"]["tools"] = [] | |||
| if not agent_mode.get("tools"): | |||
| agent_mode["tools"] = [] | |||
| if not isinstance(config["agent_mode"]["tools"], list): | |||
| if not isinstance(agent_mode["tools"], list): | |||
| raise ValueError("tools in agent_mode must be a list of objects") | |||
| for tool in config["agent_mode"]["tools"]: | |||
| for tool in agent_mode["tools"]: | |||
| key = list(tool.keys())[0] | |||
| if key in OLD_TOOLS: | |||
| # old style, use tool name as key | |||
| @@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response = cls.convert_blocking_full_response(blocking_response) | |||
| metadata = response.get("metadata", {}) | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| if isinstance(metadata, dict): | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| else: | |||
| response["metadata"] = {} | |||
| return response | |||
| @@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @classmethod | |||
| @@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| } | |||
| if isinstance(sub_stream_response, MessageEndStreamResponse): | |||
| sub_stream_response_dict = sub_stream_response.to_dict() | |||
| sub_stream_response_dict = sub_stream_response.model_dump(mode="json") | |||
| metadata = sub_stream_response_dict.get("metadata", {}) | |||
| sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) | |||
| response_chunk.update(sub_stream_response_dict) | |||
| @@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @@ -32,6 +32,7 @@ class AppQueueManager: | |||
| self._task_id = task_id | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| self.invoke_from = invoke_from # Public accessor for invoke_from | |||
| user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" | |||
| redis_client.setex( | |||
| @@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response = cls.convert_blocking_full_response(blocking_response) | |||
| metadata = response.get("metadata", {}) | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| if isinstance(metadata, dict): | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| else: | |||
| response["metadata"] = {} | |||
| return response | |||
| @@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @classmethod | |||
| @@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| } | |||
| if isinstance(sub_stream_response, MessageEndStreamResponse): | |||
| sub_stream_response_dict = sub_stream_response.to_dict() | |||
| sub_stream_response_dict = sub_stream_response.model_dump(mode="json") | |||
| metadata = sub_stream_response_dict.get("metadata", {}) | |||
| sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) | |||
| response_chunk.update(sub_stream_response_dict) | |||
| @@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| raise MoreLikeThisDisabledError() | |||
| app_model_config = message.app_model_config | |||
| if not app_model_config: | |||
| raise ValueError("Message app_model_config is None") | |||
| override_model_config_dict = app_model_config.to_dict() | |||
| model_dict = override_model_config_dict["model"] | |||
| completion_params = model_dict.get("completion_params") | |||
| @@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response = cls.convert_blocking_full_response(blocking_response) | |||
| metadata = response.get("metadata", {}) | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| if isinstance(metadata, dict): | |||
| response["metadata"] = cls._get_simple_metadata(metadata) | |||
| else: | |||
| response["metadata"] = {} | |||
| return response | |||
| @@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @classmethod | |||
| @@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| } | |||
| if isinstance(sub_stream_response, MessageEndStreamResponse): | |||
| sub_stream_response_dict = sub_stream_response.to_dict() | |||
| sub_stream_response_dict = sub_stream_response.model_dump(mode="json") | |||
| metadata = sub_stream_response_dict.get("metadata", {}) | |||
| if not isinstance(metadata, dict): | |||
| metadata = {} | |||
| sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) | |||
| response_chunk.update(sub_stream_response_dict) | |||
| if isinstance(sub_stream_response, ErrorStreamResponse): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| :param blocking_response: blocking response | |||
| :return: | |||
| """ | |||
| return dict(blocking_response.to_dict()) | |||
| return blocking_response.model_dump() | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||
| @@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| response_chunk: dict[str, object] = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| @@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @classmethod | |||
| @@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| response_chunk: dict[str, object] = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| @@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| response_chunk.update(sub_stream_response.model_dump(mode="json")) | |||
| yield response_chunk | |||
| @@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._workflow_run_id = "" | |||
| self._invoke_from = queue_manager._invoke_from | |||
| self._invoke_from = queue_manager.invoke_from | |||
| self._draft_var_saver_factory = draft_var_saver_factory | |||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| @@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| :return: | |||
| """ | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._base_task_pipeline._stream: | |||
| if self._base_task_pipeline.stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| @@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline: | |||
| def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: | |||
| """Handle ping events.""" | |||
| yield self._base_task_pipeline._ping_stream_response() | |||
| yield self._base_task_pipeline.ping_stream_response() | |||
| def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: | |||
| """Handle error events.""" | |||
| err = self._base_task_pipeline._handle_error(event=event) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| err = self._base_task_pipeline.handle_error(event=event) | |||
| yield self._base_task_pipeline.error_to_stream_response(err) | |||
| def _handle_workflow_started_event( | |||
| self, event: QueueWorkflowStartedEvent, **kwargs | |||
| @@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): | |||
| """ | |||
| # app config | |||
| app_config: EasyUIBasedAppConfig | |||
| app_config: EasyUIBasedAppConfig = None # type: ignore | |||
| model_conf: ModelConfigWithCredentialsEntity | |||
| query: Optional[str] = None | |||
| @@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): | |||
| """ | |||
| # app config | |||
| app_config: WorkflowUIBasedAppConfig | |||
| app_config: WorkflowUIBasedAppConfig = None # type: ignore | |||
| workflow_run_id: Optional[str] = None | |||
| query: str | |||
| @@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| """ | |||
| # app config | |||
| app_config: WorkflowUIBasedAppConfig | |||
| app_config: WorkflowUIBasedAppConfig = None # type: ignore | |||
| workflow_execution_id: str | |||
| class SingleIterationRunEntity(BaseModel): | |||
| @@ -5,7 +5,6 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| @@ -92,9 +91,6 @@ class StreamResponse(BaseModel): | |||
| event: StreamEvent | |||
| task_id: str | |||
| def to_dict(self): | |||
| return jsonable_encoder(self) | |||
| class ErrorStreamResponse(StreamResponse): | |||
| """ | |||
| @@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel): | |||
| task_id: str | |||
| def to_dict(self): | |||
| return jsonable_encoder(self) | |||
| class ChatbotAppBlockingResponse(AppBlockingResponse): | |||
| """ | |||
| @@ -35,6 +35,9 @@ class AnnotationReplyFeature: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| if not collection_binding_detail: | |||
| return None | |||
| try: | |||
| score_threshold = annotation_setting.score_threshold or 1 | |||
| embedding_provider_name = collection_binding_detail.provider_name | |||
| @@ -1 +1,3 @@ | |||
| from .rate_limit import RateLimit | |||
| __all__ = ["RateLimit"] | |||
| @@ -19,7 +19,7 @@ class RateLimit: | |||
| _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes | |||
| _instance_dict: dict[str, "RateLimit"] = {} | |||
| def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): | |||
| def __new__(cls, client_id: str, max_active_requests: int): | |||
| if client_id not in cls._instance_dict: | |||
| instance = super().__new__(cls) | |||
| cls._instance_dict[client_id] = instance | |||
| @@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self._start_at = time.perf_counter() | |||
| self._output_moderation_handler = self._init_output_moderation() | |||
| self._stream = stream | |||
| self.start_at = time.perf_counter() | |||
| self.output_moderation_handler = self._init_output_moderation() | |||
| self.stream = stream | |||
| def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): | |||
| def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): | |||
| logger.debug("error: %s", event.error) | |||
| e = event.error | |||
| err: Exception | |||
| @@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline: | |||
| return message | |||
| def _error_to_stream_response(self, e: Exception): | |||
| def error_to_stream_response(self, e: Exception): | |||
| """ | |||
| Error to stream response. | |||
| :param e: exception | |||
| @@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline: | |||
| """ | |||
| return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) | |||
| def _ping_stream_response(self) -> PingStreamResponse: | |||
| def ping_stream_response(self) -> PingStreamResponse: | |||
| """ | |||
| Ping stream response. | |||
| :return: | |||
| @@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline: | |||
| ) | |||
| return None | |||
| def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: | |||
| def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: | |||
| """ | |||
| Handle output moderation when task finished. | |||
| :param completion: completion | |||
| :return: | |||
| """ | |||
| # response moderation | |||
| if self._output_moderation_handler: | |||
| self._output_moderation_handler.stop_thread() | |||
| if self.output_moderation_handler: | |||
| self.output_moderation_handler.stop_thread() | |||
| completion, flagged = self._output_moderation_handler.moderation_completion( | |||
| completion, flagged = self.output_moderation_handler.moderation_completion( | |||
| completion=completion, public_event=False | |||
| ) | |||
| self._output_moderation_handler = None | |||
| self.output_moderation_handler = None | |||
| if flagged: | |||
| return completion | |||
| @@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| ) | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._stream: | |||
| if self.stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| @@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| if isinstance(event, QueueErrorEvent): | |||
| with Session(db.engine) as session: | |||
| err = self._handle_error(event=event, session=session, message_id=self._message_id) | |||
| err = self.handle_error(event=event, session=session, message_id=self._message_id) | |||
| session.commit() | |||
| yield self._error_to_stream_response(err) | |||
| yield self.error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): | |||
| if isinstance(event, QueueMessageEndEvent): | |||
| @@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| self._handle_stop(event) | |||
| # handle output moderation | |||
| output_moderation_answer = self._handle_output_moderation_when_task_finished( | |||
| output_moderation_answer = self.handle_output_moderation_when_task_finished( | |||
| cast(str, self._task_state.llm_result.message.content) | |||
| ) | |||
| if output_moderation_answer: | |||
| @@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| yield self.ping_stream_response() | |||
| else: | |||
| continue | |||
| if publisher: | |||
| @@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| message.answer_tokens = usage.completion_tokens | |||
| message.answer_unit_price = usage.completion_unit_price | |||
| message.answer_price_unit = usage.completion_price_unit | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.provider_response_latency = time.perf_counter() - self.start_at | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| self._task_state.llm_result.usage.latency = message.provider_response_latency | |||
| @@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| # transform usage | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| self._task_state.llm_result.usage = model_type_instance._calc_response_usage( | |||
| self._task_state.llm_result.usage = model_type_instance.calc_response_usage( | |||
| model, credentials, prompt_tokens, completion_tokens | |||
| ) | |||
| @@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| :param text: text | |||
| :return: True if output moderation should direct output, otherwise False | |||
| """ | |||
| if self._output_moderation_handler: | |||
| if self._output_moderation_handler.should_direct_output(): | |||
| if self.output_moderation_handler: | |||
| if self.output_moderation_handler.should_direct_output(): | |||
| # stop subscribe new token when output moderation should direct output | |||
| self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() | |||
| self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output() | |||
| self.queue_manager.publish( | |||
| QueueLLMChunkEvent( | |||
| chunk=LLMResultChunk( | |||
| @@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| ) | |||
| return True | |||
| else: | |||
| self._output_moderation_handler.append_new_token(text) | |||
| self.output_moderation_handler.append_new_token(text) | |||
| return False | |||
| @@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher: | |||
| self.voice = voice | |||
| if not voice or voice not in values: | |||
| self.voice = self.voices[0].get("value") | |||
| self.MAX_SENTENCE = 2 | |||
| self.max_sentence = 2 | |||
| self._last_audio_event: Optional[AudioTrunk] = None | |||
| # FIXME better way to handle this threading.start | |||
| threading.Thread(target=self._runtime).start() | |||
| @@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher: | |||
| self.msg_text += message.event.outputs.get("output", "") | |||
| self.last_message = message | |||
| sentence_arr, text_tmp = self._extract_sentence(self.msg_text) | |||
| if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): | |||
| self.MAX_SENTENCE += 1 | |||
| if len(sentence_arr) >= min(self.max_sentence, 7): | |||
| self.max_sentence += 1 | |||
| text_content = "".join(sentence_arr) | |||
| futures_result = self.executor.submit( | |||
| _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice | |||
| @@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel): | |||
| def __setitem__(self, key, value): | |||
| self.configurations[key] = value | |||
| def __contains__(self, key): | |||
| if "/" not in key: | |||
| key = str(ModelProviderID(key)) | |||
| return key in self.configurations | |||
| def __iter__(self): | |||
| return iter(self.configurations) | |||
| # Return an iterator of (key, value) tuples to match BaseModel's __iter__ | |||
| yield from self.configurations.items() | |||
| def values(self) -> Iterator[ProviderConfiguration]: | |||
| return iter(self.configurations.values()) | |||
| @@ -98,7 +98,7 @@ def to_prompt_message_content( | |||
| def download(f: File, /): | |||
| if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): | |||
| return _download_file_content(f._storage_key) | |||
| return _download_file_content(f.storage_key) | |||
| elif f.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| response = ssrf_proxy.get(f.remote_url, follow_redirects=True) | |||
| response.raise_for_status() | |||
| @@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /): | |||
| response.raise_for_status() | |||
| data = response.content | |||
| case FileTransferMethod.LOCAL_FILE: | |||
| data = _download_file_content(f._storage_key) | |||
| data = _download_file_content(f.storage_key) | |||
| case FileTransferMethod.TOOL_FILE: | |||
| data = _download_file_content(f._storage_key) | |||
| data = _download_file_content(f.storage_key) | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return encoded_string | |||
| @@ -146,3 +146,11 @@ class File(BaseModel): | |||
| if not self.related_id: | |||
| raise ValueError("Missing file related_id") | |||
| return self | |||
| @property | |||
| def storage_key(self) -> str: | |||
| return self._storage_key | |||
| @storage_key.setter | |||
| def storage_key(self, value: str): | |||
| self._storage_key = value | |||
| @@ -13,18 +13,18 @@ logger = logging.getLogger(__name__) | |||
| SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True | |||
| http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True | |||
| try: | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY | |||
| http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() | |||
| config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY | |||
| http_request_node_ssl_verify_lower = str(config_value).lower() | |||
| if http_request_node_ssl_verify_lower == "true": | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = True | |||
| http_request_node_ssl_verify = True | |||
| elif http_request_node_ssl_verify_lower == "false": | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = False | |||
| http_request_node_ssl_verify = False | |||
| else: | |||
| raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") | |||
| except NameError: | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = True | |||
| http_request_node_ssl_verify = True | |||
| BACKOFF_FACTOR = 0.5 | |||
| STATUS_FORCELIST = [429, 500, 502, 503, 504] | |||
| @@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| ) | |||
| if "ssl_verify" not in kwargs: | |||
| kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY | |||
| kwargs["ssl_verify"] = http_request_node_ssl_verify | |||
| ssl_verify = kwargs.pop("ssl_verify") | |||
| @@ -529,6 +529,7 @@ class IndexingRunner: | |||
| # chunk nodes by chunk size | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| create_keyword_thread = None | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": | |||
| # create keyword index | |||
| create_keyword_thread = threading.Thread( | |||
| @@ -567,7 +568,11 @@ class IndexingRunner: | |||
| for future in futures: | |||
| tokens += future.result() | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": | |||
| if ( | |||
| dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX | |||
| and dataset.indexing_technique == "economy" | |||
| and create_keyword_thread is not None | |||
| ): | |||
| create_keyword_thread.join() | |||
| indexing_end_at = time.perf_counter() | |||
| @@ -20,7 +20,7 @@ from core.llm_generator.prompts import ( | |||
| ) | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage | |||
| from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| @@ -313,14 +313,20 @@ class LLMGenerator: | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] | |||
| prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] | |||
| response: LLMResult = model_instance.invoke_llm( | |||
| # Explicitly use the non-streaming overload | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters={"temperature": 0.01, "max_tokens": 2000}, | |||
| stream=False, | |||
| ) | |||
| # Runtime type check since pyright has issues with the overload | |||
| if not isinstance(result, LLMResult): | |||
| raise TypeError("Expected LLMResult when stream=False") | |||
| response = result | |||
| answer = cast(str, response.message.content) | |||
| return answer.strip() | |||
| @@ -45,6 +45,7 @@ class SpecialModelType(StrEnum): | |||
| @overload | |||
| def invoke_llm_with_structured_output( | |||
| *, | |||
| provider: str, | |||
| model_schema: AIModelEntity, | |||
| model_instance: ModelInstance, | |||
| @@ -53,14 +54,13 @@ def invoke_llm_with_structured_output( | |||
| model_parameters: Optional[Mapping] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stream: Literal[True] = True, | |||
| stream: Literal[True], | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... | |||
| @overload | |||
| def invoke_llm_with_structured_output( | |||
| *, | |||
| provider: str, | |||
| model_schema: AIModelEntity, | |||
| model_instance: ModelInstance, | |||
| @@ -69,14 +69,13 @@ def invoke_llm_with_structured_output( | |||
| model_parameters: Optional[Mapping] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stream: Literal[False] = False, | |||
| stream: Literal[False], | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> LLMResultWithStructuredOutput: ... | |||
| @overload | |||
| def invoke_llm_with_structured_output( | |||
| *, | |||
| provider: str, | |||
| model_schema: AIModelEntity, | |||
| model_instance: ModelInstance, | |||
| @@ -89,9 +88,8 @@ def invoke_llm_with_structured_output( | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... | |||
| def invoke_llm_with_structured_output( | |||
| *, | |||
| provider: str, | |||
| model_schema: AIModelEntity, | |||
| model_instance: ModelInstance, | |||
| @@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 | |||
| @final | |||
| class _StatusReady: | |||
| def __init__(self, endpoint_url: str): | |||
| self._endpoint_url = endpoint_url | |||
| self.endpoint_url = endpoint_url | |||
| @final | |||
| class _StatusError: | |||
| def __init__(self, exc: Exception): | |||
| self._exc = exc | |||
| self.exc = exc | |||
| # Type aliases for better readability | |||
| @@ -211,9 +211,9 @@ class SSETransport: | |||
| raise ValueError("failed to get endpoint URL") | |||
| if isinstance(status, _StatusReady): | |||
| return status._endpoint_url | |||
| return status.endpoint_url | |||
| elif isinstance(status, _StatusError): | |||
| raise status._exc | |||
| raise status.exc | |||
| else: | |||
| raise ValueError("failed to get endpoint URL") | |||
| @@ -38,6 +38,7 @@ def handle_mcp_request( | |||
| """ | |||
| request_type = type(request.root) | |||
| request_root = request.root | |||
| def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: | |||
| """Create success response with business result data""" | |||
| @@ -58,21 +59,20 @@ def handle_mcp_request( | |||
| error=error_data, | |||
| ) | |||
| # Request handler mapping using functional approach | |||
| request_handlers = { | |||
| mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), | |||
| mcp_types.ListToolsRequest: lambda: handle_list_tools( | |||
| app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict | |||
| ), | |||
| mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), | |||
| mcp_types.PingRequest: lambda: handle_ping(), | |||
| } | |||
| try: | |||
| # Dispatch request to appropriate handler | |||
| handler = request_handlers.get(request_type) | |||
| if handler: | |||
| return create_success_response(handler()) | |||
| # Dispatch request to appropriate handler based on instance type | |||
| if isinstance(request_root, mcp_types.InitializeRequest): | |||
| return create_success_response(handle_initialize(mcp_server.description)) | |||
| elif isinstance(request_root, mcp_types.ListToolsRequest): | |||
| return create_success_response( | |||
| handle_list_tools( | |||
| app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict | |||
| ) | |||
| ) | |||
| elif isinstance(request_root, mcp_types.CallToolRequest): | |||
| return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) | |||
| elif isinstance(request_root, mcp_types.PingRequest): | |||
| return create_success_response(handle_ping()) | |||
| else: | |||
| return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") | |||
| @@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| self.request_meta = request_meta | |||
| self.request = request | |||
| self._session = session | |||
| self._completed = False | |||
| self.completed = False | |||
| self._on_complete = on_complete | |||
| self._entered = False # Track if we're in a context manager | |||
| @@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| ): | |||
| """Exit the context manager, performing cleanup and notifying completion.""" | |||
| try: | |||
| if self._completed: | |||
| if self.completed: | |||
| self._on_complete(self) | |||
| finally: | |||
| self._entered = False | |||
| @@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| """ | |||
| if not self._entered: | |||
| raise RuntimeError("RequestResponder must be used as a context manager") | |||
| assert not self._completed, "Request already responded to" | |||
| assert not self.completed, "Request already responded to" | |||
| self._completed = True | |||
| self.completed = True | |||
| self._session._send_response(request_id=self.request_id, response=response) | |||
| @@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| if not self._entered: | |||
| raise RuntimeError("RequestResponder must be used as a context manager") | |||
| self._completed = True # Mark as completed so it's removed from in_flight | |||
| self.completed = True # Mark as completed so it's removed from in_flight | |||
| # Send an error response to indicate cancellation | |||
| self._session._send_response( | |||
| request_id=self.request_id, | |||
| @@ -351,7 +351,7 @@ class BaseSession( | |||
| self._in_flight[responder.request_id] = responder | |||
| self._received_request(responder) | |||
| if not responder._completed: | |||
| if not responder.completed: | |||
| self._handle_incoming(responder) | |||
| elif isinstance(message.message.root, JSONRPCNotification): | |||
| @@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel): | |||
| ) | |||
| return 0 | |||
| def _calc_response_usage( | |||
| def calc_response_usage( | |||
| self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int | |||
| ) -> LLMUsage: | |||
| """ | |||
| @@ -1,4 +1,5 @@ | |||
| import enum | |||
| import json | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel, Field, field_validator | |||
| @@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): | |||
| # Try to parse JSON string for arrays | |||
| if isinstance(value, str): | |||
| try: | |||
| import json | |||
| parsed_value = json.loads(value) | |||
| if isinstance(parsed_value, list): | |||
| return parsed_value | |||
| @@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): | |||
| # Try to parse JSON string for objects | |||
| if isinstance(value, str): | |||
| try: | |||
| import json | |||
| parsed_value = json.loads(value) | |||
| if isinstance(parsed_value, dict): | |||
| return parsed_value | |||
| @@ -82,7 +82,9 @@ def merge_blob_chunks( | |||
| message_class = type(resp) | |||
| merged_message = message_class( | |||
| type=ToolInvokeMessage.MessageType.BLOB, | |||
| message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), | |||
| message=ToolInvokeMessage.BlobMessage( | |||
| blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) | |||
| ), | |||
| meta=resp.meta, | |||
| ) | |||
| yield cast(MessageType, merged_message) | |||
| @@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform): | |||
| with_memory_prompt=histories is not None, | |||
| ) | |||
| variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} | |||
| custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] | |||
| special_variable_keys_obj = prompt_template_config["special_variable_keys"] | |||
| for v in prompt_template_config["special_variable_keys"]: | |||
| # Type check for custom_variable_keys | |||
| if not isinstance(custom_variable_keys_obj, list): | |||
| raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") | |||
| custom_variable_keys = cast(list[str], custom_variable_keys_obj) | |||
| # Type check for special_variable_keys | |||
| if not isinstance(special_variable_keys_obj, list): | |||
| raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") | |||
| special_variable_keys = cast(list[str], special_variable_keys_obj) | |||
| variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} | |||
| for v in special_variable_keys: | |||
| # support #context#, #query# and #histories# | |||
| if v == "#context#": | |||
| variables["#context#"] = context or "" | |||
| @@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform): | |||
| variables["#histories#"] = histories or "" | |||
| prompt_template = prompt_template_config["prompt_template"] | |||
| if not isinstance(prompt_template, PromptTemplateParser): | |||
| raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}") | |||
| prompt = prompt_template.format(variables) | |||
| return prompt, prompt_template_config["prompt_rules"] | |||
| prompt_rules = prompt_template_config["prompt_rules"] | |||
| if not isinstance(prompt_rules, dict): | |||
| raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") | |||
| return prompt, prompt_rules | |||
| def get_prompt_template( | |||
| self, | |||
| @@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform): | |||
| has_context: bool, | |||
| query_in_prompt: bool, | |||
| with_memory_prompt: bool = False, | |||
| ): | |||
| ) -> dict[str, object]: | |||
| prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) | |||
| custom_variable_keys = [] | |||
| special_variable_keys = [] | |||
| custom_variable_keys: list[str] = [] | |||
| special_variable_keys: list[str] = [] | |||
| prompt = "" | |||
| for order in prompt_rules["system_prompt_orders"]: | |||
| @@ -40,6 +40,19 @@ if TYPE_CHECKING: | |||
| MetadataFilter = Union[DictFilter, common_types.Filter] | |||
| class PathQdrantParams(BaseModel): | |||
| path: str | |||
| class UrlQdrantParams(BaseModel): | |||
| url: str | |||
| api_key: Optional[str] | |||
| timeout: float | |||
| verify: bool | |||
| grpc_port: int | |||
| prefer_grpc: bool | |||
| class QdrantConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] = None | |||
| @@ -50,7 +63,7 @@ class QdrantConfig(BaseModel): | |||
| replication_factor: int = 1 | |||
| write_consistency_factor: int = 1 | |||
| def to_qdrant_params(self): | |||
| def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams: | |||
| if self.endpoint and self.endpoint.startswith("path:"): | |||
| path = self.endpoint.replace("path:", "") | |||
| if not os.path.isabs(path): | |||
| @@ -58,23 +71,23 @@ class QdrantConfig(BaseModel): | |||
| raise ValueError("Root path is not set") | |||
| path = os.path.join(self.root_path, path) | |||
| return {"path": path} | |||
| return PathQdrantParams(path=path) | |||
| else: | |||
| return { | |||
| "url": self.endpoint, | |||
| "api_key": self.api_key, | |||
| "timeout": self.timeout, | |||
| "verify": self.endpoint.startswith("https"), | |||
| "grpc_port": self.grpc_port, | |||
| "prefer_grpc": self.prefer_grpc, | |||
| } | |||
| return UrlQdrantParams( | |||
| url=self.endpoint, | |||
| api_key=self.api_key, | |||
| timeout=self.timeout, | |||
| verify=self.endpoint.startswith("https"), | |||
| grpc_port=self.grpc_port, | |||
| prefer_grpc=self.prefer_grpc, | |||
| ) | |||
| class QdrantVector(BaseVector): | |||
| def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): | |||
| super().__init__(collection_name) | |||
| self._client_config = config | |||
| self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) | |||
| self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump()) | |||
| self._distance_func = distance_func.upper() | |||
| self._group_id = group_id | |||
| @@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): | |||
| self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER | |||
| # In-memory cache for workflow node executions | |||
| self._execution_cache: dict[str, WorkflowNodeExecution] = {} | |||
| self._execution_cache = {} | |||
| # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval | |||
| self._workflow_execution_mapping: dict[str, list[str]] = {} | |||
| self._workflow_execution_mapping = {} | |||
| logger.info( | |||
| "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", | |||
| @@ -4,7 +4,7 @@ from .types import SegmentType | |||
| class SegmentGroup(Segment): | |||
| value_type: SegmentType = SegmentType.GROUP | |||
| value: list[Segment] | |||
| value: list[Segment] = None # type: ignore | |||
| @property | |||
| def text(self): | |||
| @@ -74,12 +74,12 @@ class NoneSegment(Segment): | |||
| class StringSegment(Segment): | |||
| value_type: SegmentType = SegmentType.STRING | |||
| value: str | |||
| value: str = None # type: ignore | |||
| class FloatSegment(Segment): | |||
| value_type: SegmentType = SegmentType.FLOAT | |||
| value: float | |||
| value: float = None # type: ignore | |||
| # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. | |||
| # The following tests cannot pass. | |||
| # | |||
| @@ -98,12 +98,12 @@ class FloatSegment(Segment): | |||
| class IntegerSegment(Segment): | |||
| value_type: SegmentType = SegmentType.INTEGER | |||
| value: int | |||
| value: int = None # type: ignore | |||
| class ObjectSegment(Segment): | |||
| value_type: SegmentType = SegmentType.OBJECT | |||
| value: Mapping[str, Any] | |||
| value: Mapping[str, Any] = None # type: ignore | |||
| @property | |||
| def text(self) -> str: | |||
| @@ -136,7 +136,7 @@ class ArraySegment(Segment): | |||
| class FileSegment(Segment): | |||
| value_type: SegmentType = SegmentType.FILE | |||
| value: File | |||
| value: File = None # type: ignore | |||
| @property | |||
| def markdown(self) -> str: | |||
| @@ -153,17 +153,17 @@ class FileSegment(Segment): | |||
| class BooleanSegment(Segment): | |||
| value_type: SegmentType = SegmentType.BOOLEAN | |||
| value: bool | |||
| value: bool = None # type: ignore | |||
| class ArrayAnySegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_ANY | |||
| value: Sequence[Any] | |||
| value: Sequence[Any] = None # type: ignore | |||
| class ArrayStringSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_STRING | |||
| value: Sequence[str] | |||
| value: Sequence[str] = None # type: ignore | |||
| @property | |||
| def text(self) -> str: | |||
| @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): | |||
| class ArrayNumberSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_NUMBER | |||
| value: Sequence[float | int] | |||
| value: Sequence[float | int] = None # type: ignore | |||
| class ArrayObjectSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_OBJECT | |||
| value: Sequence[Mapping[str, Any]] | |||
| value: Sequence[Mapping[str, Any]] = None # type: ignore | |||
| class ArrayFileSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_FILE | |||
| value: Sequence[File] | |||
| value: Sequence[File] = None # type: ignore | |||
| @property | |||
| def markdown(self) -> str: | |||
| @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): | |||
| class ArrayBooleanSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_BOOLEAN | |||
| value: Sequence[bool] | |||
| value: Sequence[bool] = None # type: ignore | |||
| def get_segment_discriminator(v: Any) -> SegmentType | None: | |||
| @@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode | |||
| class WorkflowNodeRunFailedError(Exception): | |||
| def __init__(self, node: BaseNode, err_msg: str): | |||
| self._node = node | |||
| self._error = err_msg | |||
| self.node = node | |||
| self.error = err_msg | |||
| super().__init__(f"Node {node.title} run failed: {err_msg}") | |||
| @@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode): | |||
| return "1" | |||
| def _run(self): | |||
| inputs: dict[str, list] = {} | |||
| process_data: dict[str, list] = {} | |||
| inputs: dict[str, Sequence[object]] = {} | |||
| process_data: dict[str, Sequence[object]] = {} | |||
| outputs: dict[str, Any] = {} | |||
| variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) | |||
| @@ -1183,7 +1183,8 @@ def _combine_message_content_with_role( | |||
| return AssistantPromptMessage(content=contents) | |||
| case PromptMessageRole.SYSTEM: | |||
| return SystemPromptMessage(content=contents) | |||
| raise NotImplementedError(f"Role {role} is not supported") | |||
| case _: | |||
| raise NotImplementedError(f"Role {role} is not supported") | |||
| def _render_jinja2_message( | |||
| @@ -462,9 +462,9 @@ class StorageKeyLoader: | |||
| upload_file_row = upload_files.get(model_id) | |||
| if upload_file_row is None: | |||
| raise ValueError(f"Upload file not found for id: {model_id}") | |||
| file._storage_key = upload_file_row.key | |||
| file.storage_key = upload_file_row.key | |||
| elif file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| tool_file_row = tool_files.get(model_id) | |||
| if tool_file_row is None: | |||
| raise ValueError(f"Tool file not found for id: {model_id}") | |||
| file._storage_key = tool_file_row.file_key | |||
| file.storage_key = tool_file_row.file_key | |||
| @@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str: | |||
| if isinstance(v, Segment): | |||
| return v.value_type.exposed_type().value | |||
| else: | |||
| return v["value_type"].exposed_type().value | |||
| value_type = v.get("value_type") | |||
| if value_type is None: | |||
| raise ValueError("value_type is required but not provided") | |||
| return value_type.exposed_type().value | |||
| @@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api): | |||
| headers["WWW-Authenticate"] = 'Bearer realm="api"' | |||
| return data, status_code, headers | |||
| _ = handle_http_exception | |||
| @api.errorhandler(ValueError) | |||
| def handle_value_error(e: ValueError): | |||
| got_request_exception.send(current_app, exception=e) | |||
| @@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api): | |||
| data = {"code": "invalid_param", "message": str(e), "status": status_code} | |||
| return data, status_code | |||
| _ = handle_value_error | |||
| @api.errorhandler(AppInvokeQuotaExceededError) | |||
| def handle_quota_exceeded(e: AppInvokeQuotaExceededError): | |||
| got_request_exception.send(current_app, exception=e) | |||
| @@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api): | |||
| data = {"code": "too_many_requests", "message": str(e), "status": status_code} | |||
| return data, status_code | |||
| _ = handle_quota_exceeded | |||
| @api.errorhandler(Exception) | |||
| def handle_general_exception(e: Exception): | |||
| got_request_exception.send(current_app, exception=e) | |||
| status_code = 500 | |||
| data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) | |||
| data = getattr(e, "data", {"message": http_status_message(status_code)}) | |||
| # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) | |||
| if not isinstance(data, Mapping): | |||
| if not isinstance(data, dict): | |||
| data = {"message": str(e)} | |||
| data.setdefault("code", "unknown") | |||
| @@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api): | |||
| exc_info: Any = sys.exc_info() | |||
| if exc_info[1] is None: | |||
| exc_info = None | |||
| current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type] | |||
| current_app.log_exception(exc_info) | |||
| return data, status_code | |||
| _ = handle_general_exception | |||
| class ExternalApi(Api): | |||
| _authorizations = { | |||
| @@ -167,13 +167,6 @@ class DatetimeString: | |||
| return value | |||
| def _get_float(value): | |||
| try: | |||
| return float(value) | |||
| except (TypeError, ValueError): | |||
| raise ValueError(f"{value} is not a valid float") | |||
| def timezone(timezone_string): | |||
| if timezone_string and timezone_string in available_timezones(): | |||
| return timezone_string | |||
| @@ -1,24 +1,44 @@ | |||
| { | |||
| "include": ["."], | |||
| "exclude": [".venv", "tests/", "migrations/"], | |||
| "ignore": [ | |||
| "core/", | |||
| "controllers/", | |||
| "tasks/", | |||
| "services/", | |||
| "schedule/", | |||
| "extensions/", | |||
| "utils/", | |||
| "repositories/", | |||
| "libs/", | |||
| "fields/", | |||
| "factories/", | |||
| "events/", | |||
| "contexts/", | |||
| "constants/", | |||
| "commands.py" | |||
| "exclude": [ | |||
| ".venv", | |||
| "tests/", | |||
| "migrations/", | |||
| "core/rag", | |||
| "extensions", | |||
| "libs", | |||
| "controllers/console/datasets", | |||
| "controllers/service_api/dataset", | |||
| "core/ops", | |||
| "core/tools", | |||
| "core/model_runtime", | |||
| "core/workflow", | |||
| "core/app/app_config/easy_ui_based_app/dataset" | |||
| ], | |||
| "typeCheckingMode": "strict", | |||
| "allowedUntypedLibraries": [ | |||
| "flask_restx", | |||
| "flask_login", | |||
| "opentelemetry.instrumentation.celery", | |||
| "opentelemetry.instrumentation.flask", | |||
| "opentelemetry.instrumentation.requests", | |||
| "opentelemetry.instrumentation.sqlalchemy", | |||
| "opentelemetry.instrumentation.redis" | |||
| ], | |||
| "reportUnknownMemberType": "hint", | |||
| "reportUnknownParameterType": "hint", | |||
| "reportUnknownArgumentType": "hint", | |||
| "reportUnknownVariableType": "hint", | |||
| "reportUnknownLambdaType": "hint", | |||
| "reportMissingParameterType": "hint", | |||
| "reportMissingTypeArgument": "hint", | |||
| "reportUnnecessaryContains": "hint", | |||
| "reportUnnecessaryComparison": "hint", | |||
| "reportUnnecessaryCast": "hint", | |||
| "reportUnnecessaryIsInstance": "hint", | |||
| "reportUntypedFunctionDecorator": "hint", | |||
| "reportAttributeAccessIssue": "hint", | |||
| "pythonVersion": "3.11", | |||
| "pythonPlatform": "All" | |||
| } | |||
| @@ -1318,7 +1318,7 @@ class RegisterService: | |||
| def get_invitation_if_token_valid( | |||
| cls, workspace_id: Optional[str], email: str, token: str | |||
| ) -> Optional[dict[str, Any]]: | |||
| invitation_data = cls._get_invitation_by_token(token, workspace_id, email) | |||
| invitation_data = cls.get_invitation_by_token(token, workspace_id, email) | |||
| if not invitation_data: | |||
| return None | |||
| @@ -1355,7 +1355,7 @@ class RegisterService: | |||
| } | |||
| @classmethod | |||
| def _get_invitation_by_token( | |||
| def get_invitation_by_token( | |||
| cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None | |||
| ) -> Optional[dict[str, str]]: | |||
| if workspace_id is not None and email is not None: | |||
| @@ -349,7 +349,7 @@ class AppAnnotationService: | |||
| try: | |||
| # Skip the first row | |||
| df = pd.read_csv(file, dtype=str) | |||
| df = pd.read_csv(file.stream, dtype=str) | |||
| result = [] | |||
| for _, row in df.iterrows(): | |||
| content = {"question": row.iloc[0], "answer": row.iloc[1]} | |||
| @@ -463,15 +463,23 @@ class AppAnnotationService: | |||
| annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() | |||
| if annotation_setting: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| if collection_binding_detail: | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| else: | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": {}, | |||
| } | |||
| return {"enabled": False} | |||
| @classmethod | |||
| @@ -506,15 +514,23 @@ class AppAnnotationService: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| if collection_binding_detail: | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| else: | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": {}, | |||
| } | |||
| @classmethod | |||
| def clear_all_annotations(cls, app_id: str): | |||
| @@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs: | |||
| datetime.timedelta(hours=1), | |||
| ] | |||
| tenant_count = 0 | |||
| for test_interval in test_intervals: | |||
| tenant_count = ( | |||
| session.query(Tenant.id) | |||
| @@ -134,11 +134,14 @@ class DatasetService: | |||
| # Check if tag_ids is not empty to avoid WHERE false condition | |||
| if tag_ids and len(tag_ids) > 0: | |||
| target_ids = TagService.get_target_ids_by_tag_ids( | |||
| "knowledge", | |||
| tenant_id, # ty: ignore [invalid-argument-type] | |||
| tag_ids, | |||
| ) | |||
| if tenant_id is not None: | |||
| target_ids = TagService.get_target_ids_by_tag_ids( | |||
| "knowledge", | |||
| tenant_id, | |||
| tag_ids, | |||
| ) | |||
| else: | |||
| target_ids = [] | |||
| if target_ids and len(target_ids) > 0: | |||
| query = query.where(Dataset.id.in_(target_ids)) | |||
| else: | |||
| @@ -987,7 +990,8 @@ class DocumentService: | |||
| for document in documents | |||
| if document.data_source_type == "upload_file" and document.data_source_info_dict | |||
| ] | |||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | |||
| if dataset.doc_form is not None: | |||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | |||
| for document in documents: | |||
| db.session.delete(document) | |||
| @@ -2688,56 +2692,6 @@ class SegmentService: | |||
| return paginated_segments.items, paginated_segments.total | |||
| @classmethod | |||
| def update_segment_by_id( | |||
| cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str | |||
| ) -> tuple[DocumentSegment, Document]: | |||
| """Update a segment by its ID with validation and checks.""" | |||
| # check dataset | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check embedding model setting if high quality | |||
| if dataset.indexing_technique == "high_quality": | |||
| try: | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=user_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ValueError( | |||
| "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ValueError(ex.description) | |||
| # check segment | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # validate and update segment | |||
| cls.segment_create_args_validate(segment_data, document) | |||
| updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) | |||
| return updated_segment, document | |||
| @classmethod | |||
| def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: | |||
| """Get a segment by its ID.""" | |||
| @@ -181,7 +181,7 @@ class ExternalDatasetService: | |||
| do http request depending on api bundle | |||
| """ | |||
| kwargs = { | |||
| kwargs: dict[str, Any] = { | |||
| "url": settings.url, | |||
| "headers": settings.headers, | |||
| "follow_redirects": True, | |||
| @@ -1,7 +1,7 @@ | |||
| import hashlib | |||
| import os | |||
| import uuid | |||
| from typing import Any, Literal, Union | |||
| from typing import Literal, Union | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -35,7 +35,7 @@ class FileService: | |||
| filename: str, | |||
| content: bytes, | |||
| mimetype: str, | |||
| user: Union[Account, EndUser, Any], | |||
| user: Union[Account, EndUser], | |||
| source: Literal["datasets"] | None = None, | |||
| source_url: str = "", | |||
| ) -> UploadFile: | |||
| @@ -165,7 +165,7 @@ class ModelLoadBalancingService: | |||
| try: | |||
| if load_balancing_config.encrypted_config: | |||
| credentials = json.loads(load_balancing_config.encrypted_config) | |||
| credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) | |||
| else: | |||
| credentials = {} | |||
| except JSONDecodeError: | |||
| @@ -180,11 +180,13 @@ class ModelLoadBalancingService: | |||
| for variable in credential_secret_variables: | |||
| if variable in credentials: | |||
| try: | |||
| credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| credentials.get(variable), # ty: ignore [invalid-argument-type] | |||
| decoding_rsa_key, | |||
| decoding_cipher_rsa, | |||
| ) | |||
| token_value = credentials.get(variable) | |||
| if isinstance(token_value, str): | |||
| credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| token_value, | |||
| decoding_rsa_key, | |||
| decoding_cipher_rsa, | |||
| ) | |||
| except ValueError: | |||
| pass | |||
| @@ -345,8 +347,9 @@ class ModelLoadBalancingService: | |||
| credential_id = config.get("credential_id") | |||
| enabled = config.get("enabled") | |||
| credential_record: ProviderCredential | ProviderModelCredential | None = None | |||
| if credential_id: | |||
| credential_record: ProviderCredential | ProviderModelCredential | None = None | |||
| if config_from == "predefined-model": | |||
| credential_record = ( | |||
| db.session.query(ProviderCredential) | |||
| @@ -99,6 +99,7 @@ class PluginMigration: | |||
| datetime.timedelta(hours=1), | |||
| ] | |||
| tenant_count = 0 | |||
| for test_interval in test_intervals: | |||
| tenant_count = ( | |||
| session.query(Tenant.id) | |||
| @@ -223,8 +223,8 @@ class BuiltinToolManageService: | |||
| """ | |||
| add builtin tool provider | |||
| """ | |||
| try: | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine) as session: | |||
| try: | |||
| lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" | |||
| with redis_client.lock(lock, timeout=20): | |||
| provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) | |||
| @@ -285,9 +285,9 @@ class BuiltinToolManageService: | |||
| session.add(db_provider) | |||
| session.commit() | |||
| except Exception as e: | |||
| session.rollback() | |||
| raise ValueError(str(e)) | |||
| except Exception as e: | |||
| session.rollback() | |||
| raise ValueError(str(e)) | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| @@ -18,6 +18,7 @@ from core.helper import encrypter | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.simple_prompt_transform import SimplePromptTransform | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.nodes import NodeType | |||
| from events.app_event import app_was_created | |||
| from extensions.ext_database import db | |||
| @@ -420,7 +421,11 @@ class WorkflowConverter: | |||
| query_in_prompt=False, | |||
| ) | |||
| template = prompt_template_config["prompt_template"].template | |||
| prompt_template_obj = prompt_template_config["prompt_template"] | |||
| if not isinstance(prompt_template_obj, PromptTemplateParser): | |||
| raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") | |||
| template = prompt_template_obj.template | |||
| if not template: | |||
| prompts = [] | |||
| else: | |||
| @@ -457,7 +462,11 @@ class WorkflowConverter: | |||
| query_in_prompt=False, | |||
| ) | |||
| template = prompt_template_config["prompt_template"].template | |||
| prompt_template_obj = prompt_template_config["prompt_template"] | |||
| if not isinstance(prompt_template_obj, PromptTemplateParser): | |||
| raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") | |||
| template = prompt_template_obj.template | |||
| template = self._replace_template_variables( | |||
| template=template, | |||
| variables=start_node["data"]["variables"], | |||
| @@ -467,6 +476,9 @@ class WorkflowConverter: | |||
| prompts = {"text": template} | |||
| prompt_rules = prompt_template_config["prompt_rules"] | |||
| if not isinstance(prompt_rules, dict): | |||
| raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") | |||
| role_prefix = { | |||
| "user": prompt_rules.get("human_prefix", "Human"), | |||
| "assistant": prompt_rules.get("assistant_prefix", "Assistant"), | |||
| @@ -769,10 +769,10 @@ class WorkflowService: | |||
| ) | |||
| error = node_run_result.error if not run_succeeded else None | |||
| except WorkflowNodeRunFailedError as e: | |||
| node = e._node | |||
| node = e.node | |||
| run_succeeded = False | |||
| node_run_result = None | |||
| error = e._error | |||
| error = e.error | |||
| # Create a NodeExecution domain model | |||
| node_execution = WorkflowNodeExecution( | |||
| @@ -12,7 +12,7 @@ class WorkspaceService: | |||
| def get_tenant_info(cls, tenant: Tenant): | |||
| if not tenant: | |||
| return None | |||
| tenant_info = { | |||
| tenant_info: dict[str, object] = { | |||
| "id": tenant.id, | |||
| "name": tenant.name, | |||
| "plan": tenant.plan, | |||
| @@ -3278,7 +3278,7 @@ class TestRegisterService: | |||
| redis_client.setex(cache_key, 24 * 60 * 60, account_id) | |||
| # Execute invitation retrieval | |||
| result = RegisterService._get_invitation_by_token( | |||
| result = RegisterService.get_invitation_by_token( | |||
| token=token, | |||
| workspace_id=workspace_id, | |||
| email=email, | |||
| @@ -3316,7 +3316,7 @@ class TestRegisterService: | |||
| redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) | |||
| # Execute invitation retrieval | |||
| result = RegisterService._get_invitation_by_token(token=token) | |||
| result = RegisterService.get_invitation_by_token(token=token) | |||
| # Verify result contains expected data | |||
| assert result is not None | |||
| @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( | |||
| VariableEntityType, | |||
| ) | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from models.account import Account, Tenant | |||
| from models.api_based_extension import APIBasedExtension | |||
| from models.model import App, AppMode, AppModelConfig | |||
| @@ -37,7 +38,7 @@ class TestWorkflowConverter: | |||
| # Setup default mock returns | |||
| mock_encrypter.decrypt_token.return_value = "decrypted_api_key" | |||
| mock_prompt_transform.return_value.get_prompt_template.return_value = { | |||
| "prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), | |||
| "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"), | |||
| "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, | |||
| } | |||
| mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() | |||