| @@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| environment_variables=self._workflow.environment_variables, | |||
| # Based on the definition of `VariableUnion`, | |||
| # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | |||
| conversation_variables=cast(list[VariableUnion], conversation_variables), | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| # init graph | |||
| @@ -3,7 +3,7 @@ import base64 | |||
| from libs import rsa | |||
| def obfuscated_token(token: str): | |||
| def obfuscated_token(token: str) -> str: | |||
| if not token: | |||
| return token | |||
| if len(token) <= 8: | |||
| @@ -158,8 +158,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, LargeLanguageModel): | |||
| raise Exception("Model type instance is not LargeLanguageModel") | |||
| self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) | |||
| return cast( | |||
| Union[LLMResult, Generator], | |||
| self._round_robin_invoke( | |||
| @@ -188,8 +186,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, LargeLanguageModel): | |||
| raise Exception("Model type instance is not LargeLanguageModel") | |||
| self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) | |||
| return cast( | |||
| int, | |||
| self._round_robin_invoke( | |||
| @@ -214,8 +210,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, TextEmbeddingModel): | |||
| raise Exception("Model type instance is not TextEmbeddingModel") | |||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||
| return cast( | |||
| TextEmbeddingResult, | |||
| self._round_robin_invoke( | |||
| @@ -237,8 +231,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, TextEmbeddingModel): | |||
| raise Exception("Model type instance is not TextEmbeddingModel") | |||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||
| return cast( | |||
| list[int], | |||
| self._round_robin_invoke( | |||
| @@ -269,8 +261,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, RerankModel): | |||
| raise Exception("Model type instance is not RerankModel") | |||
| self.model_type_instance = cast(RerankModel, self.model_type_instance) | |||
| return cast( | |||
| RerankResult, | |||
| self._round_robin_invoke( | |||
| @@ -295,8 +285,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, ModerationModel): | |||
| raise Exception("Model type instance is not ModerationModel") | |||
| self.model_type_instance = cast(ModerationModel, self.model_type_instance) | |||
| return cast( | |||
| bool, | |||
| self._round_robin_invoke( | |||
| @@ -318,8 +306,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, Speech2TextModel): | |||
| raise Exception("Model type instance is not Speech2TextModel") | |||
| self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) | |||
| return cast( | |||
| str, | |||
| self._round_robin_invoke( | |||
| @@ -343,8 +329,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, TTSModel): | |||
| raise Exception("Model type instance is not TTSModel") | |||
| self.model_type_instance = cast(TTSModel, self.model_type_instance) | |||
| return cast( | |||
| Iterable[bytes], | |||
| self._round_robin_invoke( | |||
| @@ -404,8 +388,6 @@ class ModelInstance: | |||
| """ | |||
| if not isinstance(self.model_type_instance, TTSModel): | |||
| raise Exception("Model type instance is not TTSModel") | |||
| self.model_type_instance = cast(TTSModel, self.model_type_instance) | |||
| return self.model_type_instance.get_tts_model_voices( | |||
| model=self.model, credentials=self.credentials, language=language | |||
| ) | |||
| @@ -87,7 +87,6 @@ class PromptMessageUtil: | |||
| if isinstance(prompt_message.content, list): | |||
| for content in prompt_message.content: | |||
| if content.type == PromptMessageContentType.TEXT: | |||
| content = cast(TextPromptMessageContent, content) | |||
| text += content.data | |||
| else: | |||
| content = cast(ImagePromptMessageContent, content) | |||
| @@ -2,7 +2,7 @@ import contextlib | |||
| import json | |||
| from collections import defaultdict | |||
| from json import JSONDecodeError | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, Optional | |||
| from sqlalchemy import select | |||
| from sqlalchemy.exc import IntegrityError | |||
| @@ -154,8 +154,8 @@ class ProviderManager: | |||
| for provider_entity in provider_entities: | |||
| # handle include, exclude | |||
| if is_filtered( | |||
| include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), | |||
| exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), | |||
| include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, | |||
| exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, | |||
| data=provider_entity, | |||
| name_func=lambda x: x.provider, | |||
| ): | |||
| @@ -3,7 +3,7 @@ import os | |||
| import uuid | |||
| from collections.abc import Generator, Iterable, Sequence | |||
| from itertools import islice | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| import qdrant_client | |||
| from flask import current_app | |||
| @@ -426,7 +426,6 @@ class QdrantVector(BaseVector): | |||
| def _reload_if_needed(self): | |||
| if isinstance(self._client, QdrantLocal): | |||
| self._client = cast(QdrantLocal, self._client) | |||
| self._client._load() | |||
| @classmethod | |||
| @@ -2,7 +2,7 @@ | |||
| import re | |||
| from pathlib import Path | |||
| from typing import Optional, cast | |||
| from typing import Optional | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.extractor.helpers import detect_file_encodings | |||
| @@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor): | |||
| markdown_tups.append((current_header, current_text)) | |||
| markdown_tups = [ | |||
| (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) | |||
| (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value)) | |||
| for key, value in markdown_tups | |||
| ] | |||
| @@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor): | |||
| f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}" | |||
| ) | |||
| return cast(str, data_source_binding.access_token) | |||
| return data_source_binding.access_token | |||
| @@ -2,7 +2,7 @@ | |||
| import contextlib | |||
| from collections.abc import Iterator | |||
| from typing import Optional, cast | |||
| from typing import Optional | |||
| from core.rag.extractor.blob.blob import Blob | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| @@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor): | |||
| plaintext_file_exists = False | |||
| if self._file_cache_key: | |||
| with contextlib.suppress(FileNotFoundError): | |||
| text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") | |||
| text = storage.load(self._file_cache_key).decode("utf-8") | |||
| plaintext_file_exists = True | |||
| return [Document(page_content=text)] | |||
| documents = list(self.load()) | |||
| @@ -331,16 +331,13 @@ class ToolManager: | |||
| if controller_tools is None or len(controller_tools) == 0: | |||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | |||
| return cast( | |||
| WorkflowTool, | |||
| controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( | |||
| runtime=ToolRuntime( | |||
| tenant_id=tenant_id, | |||
| credentials={}, | |||
| invoke_from=invoke_from, | |||
| tool_invoke_from=tool_invoke_from, | |||
| ) | |||
| ), | |||
| return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( | |||
| runtime=ToolRuntime( | |||
| tenant_id=tenant_id, | |||
| credentials={}, | |||
| invoke_from=invoke_from, | |||
| tool_invoke_from=tool_invoke_from, | |||
| ) | |||
| ) | |||
| elif provider_type == ToolProviderType.APP: | |||
| raise NotImplementedError("app provider not implemented") | |||
| @@ -648,8 +645,8 @@ class ToolManager: | |||
| for provider in builtin_providers: | |||
| # handle include, exclude | |||
| if is_filtered( | |||
| include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), | |||
| exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), | |||
| include_set=dify_config.POSITION_TOOL_INCLUDES_SET, | |||
| exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, | |||
| data=provider, | |||
| name_func=lambda x: x.identity.name, | |||
| ): | |||
| @@ -3,7 +3,7 @@ from collections.abc import Generator | |||
| from datetime import date, datetime | |||
| from decimal import Decimal | |||
| from mimetypes import guess_extension | |||
| from typing import Optional, cast | |||
| from typing import Optional | |||
| from uuid import UUID | |||
| import numpy as np | |||
| @@ -159,8 +159,7 @@ class ToolFileMessageTransformer: | |||
| elif message.type == ToolInvokeMessage.MessageType.JSON: | |||
| if isinstance(message.message, ToolInvokeMessage.JsonMessage): | |||
| json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) | |||
| json_msg.json_object = safe_json_value(json_msg.json_object) | |||
| message.message.json_object = safe_json_value(message.message.json_object) | |||
| yield message | |||
| else: | |||
| yield message | |||
| @@ -129,17 +129,14 @@ class ModelInvocationUtils: | |||
| db.session.commit() | |||
| try: | |||
| response: LLMResult = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| tools=[], | |||
| stop=[], | |||
| stream=False, | |||
| user=user_id, | |||
| callbacks=[], | |||
| ), | |||
| response: LLMResult = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| tools=[], | |||
| stop=[], | |||
| stream=False, | |||
| user=user_id, | |||
| callbacks=[], | |||
| ) | |||
| except InvokeRateLimitError as e: | |||
| raise InvokeModelError(f"Invoke rate limit error: {e}") | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| import logging | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, Optional | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | |||
| from core.tools.__base.tool import Tool | |||
| @@ -204,14 +204,14 @@ class WorkflowTool(Tool): | |||
| item = self._update_file_mapping(item) | |||
| file = build_from_mapping( | |||
| mapping=item, | |||
| tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), | |||
| tenant_id=str(self.runtime.tenant_id), | |||
| ) | |||
| files.append(file) | |||
| elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| value = self._update_file_mapping(value) | |||
| file = build_from_mapping( | |||
| mapping=value, | |||
| tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), | |||
| tenant_id=str(self.runtime.tenant_id), | |||
| ) | |||
| files.append(file) | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Annotated, TypeAlias, cast | |||
| from typing import Annotated, TypeAlias | |||
| from uuid import uuid4 | |||
| from pydantic import Discriminator, Field, Tag | |||
| @@ -86,7 +86,7 @@ class SecretVariable(StringVariable): | |||
| @property | |||
| def log(self) -> str: | |||
| return cast(str, encrypter.obfuscated_token(self.value)) | |||
| return encrypter.obfuscated_token(self.value) | |||
| class NoneVariable(NoneSegment, Variable): | |||
| @@ -374,7 +374,7 @@ class GraphEngine: | |||
| if len(sub_edge_mappings) == 0: | |||
| continue | |||
| edge = cast(GraphEdge, sub_edge_mappings[0]) | |||
| edge = sub_edge_mappings[0] | |||
| if edge.run_condition is None: | |||
| logger.warning("Edge %s run condition is None", edge.target_node_id) | |||
| continue | |||
| @@ -153,7 +153,7 @@ class AgentNode(BaseNode): | |||
| messages=message_stream, | |||
| tool_info={ | |||
| "icon": self.agent_strategy_icon, | |||
| "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, | |||
| "agent_strategy": self._node_data.agent_strategy_name, | |||
| }, | |||
| parameters_for_log=parameters_for_log, | |||
| user_id=self.user_id, | |||
| @@ -394,8 +394,7 @@ class AgentNode(BaseNode): | |||
| current_plugin = next( | |||
| plugin | |||
| for plugin in plugins | |||
| if f"{plugin.plugin_id}/{plugin.name}" | |||
| == cast(AgentNodeData, self._node_data).agent_strategy_provider_name | |||
| if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name | |||
| ) | |||
| icon = current_plugin.declaration.icon | |||
| except StopIteration: | |||
| @@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str: | |||
| encoding = "utf-8" | |||
| yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) | |||
| return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) | |||
| return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) | |||
| except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: | |||
| # If decoding fails, try with utf-8 as last resort | |||
| try: | |||
| yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) | |||
| return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) | |||
| return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) | |||
| except (UnicodeDecodeError, yaml.YAMLError): | |||
| raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e | |||
| @@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode): | |||
| """ | |||
| Run the node. | |||
| """ | |||
| node_data = cast(ParameterExtractorNodeData, self._node_data) | |||
| node_data = self._node_data | |||
| variable = self.graph_runtime_state.variable_pool.get(node_data.query) | |||
| query = variable.text if variable else "" | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| from typing import TYPE_CHECKING, Any, Optional | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| @@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode): | |||
| return "1" | |||
| def _run(self): | |||
| node_data = cast(QuestionClassifierNodeData, self._node_data) | |||
| node_data = self._node_data | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| # extract variables | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, Optional | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| @@ -57,7 +57,7 @@ class ToolNode(BaseNode): | |||
| Run the tool node | |||
| """ | |||
| node_data = cast(ToolNodeData, self._node_data) | |||
| node_data = self._node_data | |||
| # fetch tool icon | |||
| tool_info = { | |||
| @@ -2,7 +2,7 @@ import logging | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, Optional | |||
| from configs import dify_config | |||
| from core.app.apps.exc import GenerateTaskStoppedError | |||
| @@ -261,7 +261,6 @@ class WorkflowEntry: | |||
| environment_variables=[], | |||
| ) | |||
| node_cls = cast(type[BaseNode], node_cls) | |||
| # init workflow run state | |||
| node: BaseNode = node_cls( | |||
| id=str(uuid.uuid4()), | |||
| @@ -3,7 +3,7 @@ import os | |||
| import urllib.parse | |||
| import uuid | |||
| from collections.abc import Callable, Mapping, Sequence | |||
| from typing import Any, cast | |||
| from typing import Any | |||
| import httpx | |||
| from sqlalchemy import select | |||
| @@ -258,7 +258,6 @@ def _get_remote_file_info(url: str): | |||
| mime_type = "" | |||
| resp = ssrf_proxy.head(url, follow_redirects=True) | |||
| resp = cast(httpx.Response, resp) | |||
| if resp.status_code == httpx.codes.OK: | |||
| if content_disposition := resp.headers.get("Content-Disposition"): | |||
| filename = str(content_disposition.split("filename=")[-1].strip('"')) | |||
| @@ -308,7 +308,7 @@ class MCPToolProvider(Base): | |||
| @property | |||
| def decrypted_server_url(self) -> str: | |||
| return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url)) | |||
| return encrypter.decrypt_token(self.tenant_id, self.server_url) | |||
| @property | |||
| def masked_server_url(self) -> str: | |||
| @@ -146,7 +146,7 @@ class AccountService: | |||
| account.last_active_at = naive_utc_now() | |||
| db.session.commit() | |||
| return cast(Account, account) | |||
| return account | |||
| @staticmethod | |||
| def get_account_jwt_token(account: Account) -> str: | |||
| @@ -191,7 +191,7 @@ class AccountService: | |||
| db.session.commit() | |||
| return cast(Account, account) | |||
| return account | |||
| @staticmethod | |||
| def update_account_password(account, password, new_password): | |||
| @@ -1127,7 +1127,7 @@ class TenantService: | |||
| def get_custom_config(tenant_id: str) -> dict: | |||
| tenant = db.get_or_404(Tenant, tenant_id) | |||
| return cast(dict, tenant.custom_config_dict) | |||
| return tenant.custom_config_dict | |||
| @staticmethod | |||
| def is_owner(account: Account, tenant: Tenant) -> bool: | |||
| @@ -1,5 +1,5 @@ | |||
| import uuid | |||
| from typing import cast | |||
| from typing import Optional | |||
| import pandas as pd | |||
| from flask_login import current_user | |||
| @@ -40,7 +40,7 @@ class AppAnnotationService: | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| annotation = message.annotation | |||
| annotation: Optional[MessageAnnotation] = message.annotation | |||
| # save the message annotation | |||
| if annotation: | |||
| annotation.content = args["answer"] | |||
| @@ -70,7 +70,7 @@ class AppAnnotationService: | |||
| app_id, | |||
| annotation_setting.collection_binding_id, | |||
| ) | |||
| return cast(MessageAnnotation, annotation) | |||
| return annotation | |||
| @classmethod | |||
| def enable_app_annotation(cls, args: dict, app_id: str) -> dict: | |||
| @@ -1,7 +1,6 @@ | |||
| import time | |||
| import uuid | |||
| from os import getenv | |||
| from typing import cast | |||
| import pytest | |||
| @@ -13,7 +12,6 @@ from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.code.code_node import CodeNode | |||
| from core.workflow.nodes.code.entities import CodeNodeData | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| from models.workflow import WorkflowType | |||
| @@ -238,8 +236,6 @@ def test_execute_code_output_validator_depth(): | |||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||
| } | |||
| node._node_data = cast(CodeNodeData, node._node_data) | |||
| # validate | |||
| node._transform_result(result, node._node_data.outputs) | |||
| @@ -334,8 +330,6 @@ def test_execute_code_output_object_list(): | |||
| ] | |||
| } | |||
| node._node_data = cast(CodeNodeData, node._node_data) | |||
| # validate | |||
| node._transform_result(result, node._node_data.outputs) | |||