浏览代码

[CHORE]: remove redundant-cast (#24807)

tags/1.8.1
willzhao 2 个月前
父节点
当前提交
ffba341258
没有帐户链接到提交者的电子邮件

+ 1
- 1
api/core/app/apps/advanced_chat/app_runner.py 查看文件

environment_variables=self._workflow.environment_variables, environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
) )


# init graph # init graph

+ 1
- 1
api/core/helper/encrypter.py 查看文件

from libs import rsa from libs import rsa




def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token: if not token:
return token return token
if len(token) <= 8: if len(token) <= 8:

+ 0
- 18
api/core/model_manager.py 查看文件

""" """
if not isinstance(self.model_type_instance, LargeLanguageModel): if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
Union[LLMResult, Generator], Union[LLMResult, Generator],
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, LargeLanguageModel): if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
int, int,
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
TextEmbeddingResult, TextEmbeddingResult,
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
list[int], list[int],
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, RerankModel): if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel") raise Exception("Model type instance is not RerankModel")

self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast( return cast(
RerankResult, RerankResult,
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, ModerationModel): if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel") raise Exception("Model type instance is not ModerationModel")

self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast( return cast(
bool, bool,
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, Speech2TextModel): if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel") raise Exception("Model type instance is not Speech2TextModel")

self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast( return cast(
str, str,
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast( return cast(
Iterable[bytes], Iterable[bytes],
self._round_robin_invoke( self._round_robin_invoke(
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices( return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language model=self.model, credentials=self.credentials, language=language
) )

+ 0
- 1
api/core/prompt/utils/prompt_message_util.py 查看文件

if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
for content in prompt_message.content: for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT: if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data text += content.data
else: else:
content = cast(ImagePromptMessageContent, content) content = cast(ImagePromptMessageContent, content)

+ 3
- 3
api/core/provider_manager.py 查看文件

import json import json
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional


from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
for provider_entity in provider_entities: for provider_entity in provider_entities:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity, data=provider_entity,
name_func=lambda x: x.provider, name_func=lambda x: x.provider,
): ):

+ 1
- 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 查看文件

import uuid import uuid
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union


import qdrant_client import qdrant_client
from flask import current_app from flask import current_app


def _reload_if_needed(self): def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal): if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load() self._client._load()


@classmethod @classmethod

+ 2
- 2
api/core/rag/extractor/markdown_extractor.py 查看文件



import re import re
from pathlib import Path from pathlib import Path
from typing import Optional, cast
from typing import Optional


from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings from core.rag.extractor.helpers import detect_file_encodings
markdown_tups.append((current_header, current_text)) markdown_tups.append((current_header, current_text))


markdown_tups = [ markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups for key, value in markdown_tups
] ]



+ 1
- 1
api/core/rag/extractor/notion_extractor.py 查看文件

f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}" f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
) )


return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

+ 2
- 2
api/core/rag/extractor/pdf_extractor.py 查看文件



import contextlib import contextlib
from collections.abc import Iterator from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional


from core.rag.extractor.blob.blob import Blob from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
plaintext_file_exists = False plaintext_file_exists = False
if self._file_cache_key: if self._file_cache_key:
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True plaintext_file_exists = True
return [Document(page_content=text)] return [Document(page_content=text)]
documents = list(self.load()) documents = list(self.load())

+ 9
- 12
api/core/tools/tool_manager.py 查看文件

if controller_tools is None or len(controller_tools) == 0: if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")


return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
) )
elif provider_type == ToolProviderType.APP: elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented") raise NotImplementedError("app provider not implemented")
for provider in builtin_providers: for provider in builtin_providers:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider, data=provider,
name_func=lambda x: x.identity.name, name_func=lambda x: x.identity.name,
): ):

+ 2
- 3
api/core/tools/utils/message_transformer.py 查看文件

from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID from uuid import UUID


import numpy as np import numpy as np


elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage): if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message yield message
else: else:
yield message yield message

+ 8
- 11
api/core/tools/utils/model_invocation_utils.py 查看文件

db.session.commit() db.session.commit()


try: try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
) )
except InvokeRateLimitError as e: except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}") raise InvokeModelError(f"Invoke rate limit error: {e}")

+ 3
- 3
api/core/tools/workflow_as_tool/tool.py 查看文件

import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, cast
from typing import Any, Optional


from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
item = self._update_file_mapping(item) item = self._update_file_mapping(item)
file = build_from_mapping( file = build_from_mapping(
mapping=item, mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
) )
files.append(file) files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value) value = self._update_file_mapping(value)
file = build_from_mapping( file = build_from_mapping(
mapping=value, mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
) )
files.append(file) files.append(file)



+ 2
- 2
api/core/variables/variables.py 查看文件

from collections.abc import Sequence from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4 from uuid import uuid4


from pydantic import Discriminator, Field, Tag from pydantic import Discriminator, Field, Tag


@property @property
def log(self) -> str: def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)




class NoneVariable(NoneSegment, Variable): class NoneVariable(NoneSegment, Variable):

+ 1
- 1
api/core/workflow/graph_engine/graph_engine.py 查看文件

if len(sub_edge_mappings) == 0: if len(sub_edge_mappings) == 0:
continue continue


edge = cast(GraphEdge, sub_edge_mappings[0])
edge = sub_edge_mappings[0]
if edge.run_condition is None: if edge.run_condition is None:
logger.warning("Edge %s run condition is None", edge.target_node_id) logger.warning("Edge %s run condition is None", edge.target_node_id)
continue continue

+ 2
- 3
api/core/workflow/nodes/agent/agent_node.py 查看文件

messages=message_stream, messages=message_stream,
tool_info={ tool_info={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
}, },
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id, user_id=self.user_id,
current_plugin = next( current_plugin = next(
plugin plugin
for plugin in plugins for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
) )
icon = current_plugin.declaration.icon icon = current_plugin.declaration.icon
except StopIteration: except StopIteration:

+ 2
- 2
api/core/workflow/nodes/document_extractor/node.py 查看文件

encoding = "utf-8" encoding = "utf-8"


yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort # If decoding fails, try with utf-8 as last resort
try: try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError): except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e



+ 1
- 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py 查看文件

""" """
Run the node. Run the node.
""" """
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query) variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else "" query = variable.text if variable else ""



+ 2
- 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py 查看文件

import json import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional


from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
return "1" return "1"


def _run(self): def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool


# extract variables # extract variables

+ 2
- 2
api/core/workflow/nodes/tool/tool_node.py 查看文件

from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional


from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
Run the tool node Run the tool node
""" """


node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data


# fetch tool icon # fetch tool icon
tool_info = { tool_info = {

+ 1
- 2
api/core/workflow/workflow_entry.py 查看文件

import time import time
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional


from configs import dify_config from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.exc import GenerateTaskStoppedError
environment_variables=[], environment_variables=[],
) )


node_cls = cast(type[BaseNode], node_cls)
# init workflow run state # init workflow run state
node: BaseNode = node_cls( node: BaseNode = node_cls(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),

+ 1
- 2
api/factories/file_factory.py 查看文件

import urllib.parse import urllib.parse
import uuid import uuid
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from typing import Any


import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
mime_type = "" mime_type = ""


resp = ssrf_proxy.head(url, follow_redirects=True) resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK: if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"): if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"')) filename = str(content_disposition.split("filename=")[-1].strip('"'))

+ 1
- 1
api/models/tools.py 查看文件



@property @property
def decrypted_server_url(self) -> str: def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
return encrypter.decrypt_token(self.tenant_id, self.server_url)


@property @property
def masked_server_url(self) -> str: def masked_server_url(self) -> str:

+ 3
- 3
api/services/account_service.py 查看文件

account.last_active_at = naive_utc_now() account.last_active_at = naive_utc_now()
db.session.commit() db.session.commit()


return cast(Account, account)
return account


@staticmethod @staticmethod
def get_account_jwt_token(account: Account) -> str: def get_account_jwt_token(account: Account) -> str:


db.session.commit() db.session.commit()


return cast(Account, account)
return account


@staticmethod @staticmethod
def update_account_password(account, password, new_password): def update_account_password(account, password, new_password):
def get_custom_config(tenant_id: str) -> dict: def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id) tenant = db.get_or_404(Tenant, tenant_id)


return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict


@staticmethod @staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool: def is_owner(account: Account, tenant: Tenant) -> bool:

+ 3
- 3
api/services/annotation_service.py 查看文件

import uuid import uuid
from typing import cast
from typing import Optional


import pandas as pd import pandas as pd
from flask_login import current_user from flask_login import current_user
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")


annotation = message.annotation
annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation # save the message annotation
if annotation: if annotation:
annotation.content = args["answer"] annotation.content = args["answer"]
app_id, app_id,
annotation_setting.collection_binding_id, annotation_setting.collection_binding_id,
) )
return cast(MessageAnnotation, annotation)
return annotation


@classmethod @classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict: def enable_app_annotation(cls, args: dict, app_id: str) -> dict:

+ 0
- 6
api/tests/integration_tests/workflow/nodes/test_code.py 查看文件

import time import time
import uuid import uuid
from os import getenv from os import getenv
from typing import cast


import pytest import pytest


from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
} }


node._node_data = cast(CodeNodeData, node._node_data)

# validate # validate
node._transform_result(result, node._node_data.outputs) node._transform_result(result, node._node_data.outputs)


] ]
} }


node._node_data = cast(CodeNodeData, node._node_data)

# validate # validate
node._transform_result(result, node._node_data.outputs) node._transform_result(result, node._node_data.outputs)



正在加载...
取消
保存