Bladeren bron

[CHORE]: remove redundant-cast (#24807)

tags/1.8.1
willzhao 2 maanden geleden
bovenliggende
commit
ffba341258
No account linked to committer's email address

+ 1
- 1
api/core/app/apps/advanced_chat/app_runner.py Bestand weergeven

@@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)

# init graph

+ 1
- 1
api/core/helper/encrypter.py Bestand weergeven

@@ -3,7 +3,7 @@ import base64
from libs import rsa


def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:

+ 0
- 18
api/core/model_manager.py Bestand weergeven

@@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")

self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")

self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")

self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

+ 0
- 1
api/core/prompt/utils/prompt_message_util.py Bestand weergeven

@@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

+ 3
- 3
api/core/provider_manager.py Bestand weergeven

@@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):

+ 1
- 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py Bestand weergeven

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union

import qdrant_client
from flask import current_app
@@ -426,7 +426,6 @@ class QdrantVector(BaseVector):

def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()

@classmethod

+ 2
- 2
api/core/rag/extractor/markdown_extractor.py Bestand weergeven

@@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))

markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]


+ 1
- 1
api/core/rag/extractor/notion_extractor.py Bestand weergeven

@@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)

return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

+ 2
- 2
api/core/rag/extractor/pdf_extractor.py Bestand weergeven

@@ -2,7 +2,7 @@

import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional

from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

+ 9
- 12
api/core/tools/tool_manager.py Bestand weergeven

@@ -331,16 +331,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@@ -648,8 +645,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):

+ 2
- 3
api/core/tools/utils/message_transformer.py Bestand weergeven

@@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID

import numpy as np
@@ -159,8 +159,7 @@ class ToolFileMessageTransformer:

elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message

+ 8
- 11
api/core/tools/utils/model_invocation_utils.py Bestand weergeven

@@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()

try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

+ 3
- 3
api/core/tools/workflow_as_tool/tool.py Bestand weergeven

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, cast
from typing import Any, Optional

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@@ -204,14 +204,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)


+ 2
- 2
api/core/variables/variables.py Bestand weergeven

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4

from pydantic import Discriminator, Field, Tag
@@ -86,7 +86,7 @@ class SecretVariable(StringVariable):

@property
def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)


class NoneVariable(NoneSegment, Variable):

+ 1
- 1
api/core/workflow/graph_engine/graph_engine.py Bestand weergeven

@@ -374,7 +374,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue

edge = cast(GraphEdge, sub_edge_mappings[0])
edge = sub_edge_mappings[0]
if edge.run_condition is None:
logger.warning("Edge %s run condition is None", edge.target_node_id)
continue

+ 2
- 3
api/core/workflow/nodes/agent/agent_node.py Bestand weergeven

@@ -153,7 +153,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@@ -394,8 +394,7 @@ class AgentNode(BaseNode):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

+ 2
- 2
api/core/workflow/nodes/document_extractor/node.py Bestand weergeven

@@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"

yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e


+ 1
- 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Bestand weergeven

@@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""


+ 2
- 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py Bestand weergeven

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
return "1"

def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool

# extract variables

+ 2
- 2
api/core/workflow/nodes/tool/tool_node.py Bestand weergeven

@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""

node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data

# fetch tool icon
tool_info = {

+ 1
- 2
api/core/workflow/workflow_entry.py Bestand weergeven

@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional

from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -261,7 +261,6 @@ class WorkflowEntry:
environment_variables=[],
)

node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
id=str(uuid.uuid4()),

+ 1
- 2
api/factories/file_factory.py Bestand weergeven

@@ -3,7 +3,7 @@ import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from typing import Any

import httpx
from sqlalchemy import select
@@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
mime_type = ""

resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))

+ 1
- 1
api/models/tools.py Bestand weergeven

@@ -308,7 +308,7 @@ class MCPToolProvider(Base):

@property
def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
return encrypter.decrypt_token(self.tenant_id, self.server_url)

@property
def masked_server_url(self) -> str:

+ 3
- 3
api/services/account_service.py Bestand weergeven

@@ -146,7 +146,7 @@ class AccountService:
account.last_active_at = naive_utc_now()
db.session.commit()

return cast(Account, account)
return account

@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -191,7 +191,7 @@ class AccountService:

db.session.commit()

return cast(Account, account)
return account

@staticmethod
def update_account_password(account, password, new_password):
@@ -1127,7 +1127,7 @@ class TenantService:
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)

return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict

@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:

+ 3
- 3
api/services/annotation_service.py Bestand weergeven

@@ -1,5 +1,5 @@
import uuid
from typing import cast
from typing import Optional

import pandas as pd
from flask_login import current_user
@@ -40,7 +40,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")

annotation = message.annotation
annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@@ -70,7 +70,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return cast(MessageAnnotation, annotation)
return annotation

@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:

+ 0
- 6
api/tests/integration_tests/workflow/nodes/test_code.py Bestand weergeven

@@ -1,7 +1,6 @@
import time
import uuid
from os import getenv
from typing import cast

import pytest

@@ -13,7 +12,6 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -238,8 +236,6 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}

node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node._node_data.outputs)

@@ -334,8 +330,6 @@ def test_execute_code_output_object_list():
]
}

node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node._node_data.outputs)


Laden…
Annuleren
Opslaan