Selaa lähdekoodia

[CHORE]: remove redundant-cast (#24807)

tags/1.8.1
willzhao 2 kuukautta sitten
vanhempi
commit
ffba341258
No account linked to committer's email address

+ 1
- 1
api/core/app/apps/advanced_chat/app_runner.py Näytä tiedosto

@@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)

# init graph

+ 1
- 1
api/core/helper/encrypter.py Näytä tiedosto

@@ -3,7 +3,7 @@ import base64
from libs import rsa


def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:

+ 0
- 18
api/core/model_manager.py Näytä tiedosto

@@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")

self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")

self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")

self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

+ 0
- 1
api/core/prompt/utils/prompt_message_util.py Näytä tiedosto

@@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

+ 3
- 3
api/core/provider_manager.py Näytä tiedosto

@@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):

+ 1
- 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py Näytä tiedosto

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union

import qdrant_client
from flask import current_app
@@ -426,7 +426,6 @@ class QdrantVector(BaseVector):

def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()

@classmethod

+ 2
- 2
api/core/rag/extractor/markdown_extractor.py Näytä tiedosto

@@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))

markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]


+ 1
- 1
api/core/rag/extractor/notion_extractor.py Näytä tiedosto

@@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)

return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

+ 2
- 2
api/core/rag/extractor/pdf_extractor.py Näytä tiedosto

@@ -2,7 +2,7 @@

import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional

from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

+ 9
- 12
api/core/tools/tool_manager.py Näytä tiedosto

@@ -331,16 +331,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@@ -648,8 +645,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):

+ 2
- 3
api/core/tools/utils/message_transformer.py Näytä tiedosto

@@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID

import numpy as np
@@ -159,8 +159,7 @@ class ToolFileMessageTransformer:

elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message

+ 8
- 11
api/core/tools/utils/model_invocation_utils.py Näytä tiedosto

@@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()

try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

+ 3
- 3
api/core/tools/workflow_as_tool/tool.py Näytä tiedosto

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, cast
from typing import Any, Optional

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@@ -204,14 +204,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)


+ 2
- 2
api/core/variables/variables.py Näytä tiedosto

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4

from pydantic import Discriminator, Field, Tag
@@ -86,7 +86,7 @@ class SecretVariable(StringVariable):

@property
def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)


class NoneVariable(NoneSegment, Variable):

+ 1
- 1
api/core/workflow/graph_engine/graph_engine.py Näytä tiedosto

@@ -374,7 +374,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue

edge = cast(GraphEdge, sub_edge_mappings[0])
edge = sub_edge_mappings[0]
if edge.run_condition is None:
logger.warning("Edge %s run condition is None", edge.target_node_id)
continue

+ 2
- 3
api/core/workflow/nodes/agent/agent_node.py Näytä tiedosto

@@ -153,7 +153,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@@ -394,8 +394,7 @@ class AgentNode(BaseNode):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

+ 2
- 2
api/core/workflow/nodes/document_extractor/node.py Näytä tiedosto

@@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"

yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e


+ 1
- 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Näytä tiedosto

@@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""


+ 2
- 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py Näytä tiedosto

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
return "1"

def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool

# extract variables

+ 2
- 2
api/core/workflow/nodes/tool/tool_node.py Näytä tiedosto

@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""

node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data

# fetch tool icon
tool_info = {

+ 1
- 2
api/core/workflow/workflow_entry.py Näytä tiedosto

@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional

from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -261,7 +261,6 @@ class WorkflowEntry:
environment_variables=[],
)

node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
id=str(uuid.uuid4()),

+ 1
- 2
api/factories/file_factory.py Näytä tiedosto

@@ -3,7 +3,7 @@ import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from typing import Any

import httpx
from sqlalchemy import select
@@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
mime_type = ""

resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))

+ 1
- 1
api/models/tools.py Näytä tiedosto

@@ -308,7 +308,7 @@ class MCPToolProvider(Base):

@property
def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
return encrypter.decrypt_token(self.tenant_id, self.server_url)

@property
def masked_server_url(self) -> str:

+ 3
- 3
api/services/account_service.py Näytä tiedosto

@@ -146,7 +146,7 @@ class AccountService:
account.last_active_at = naive_utc_now()
db.session.commit()

return cast(Account, account)
return account

@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -191,7 +191,7 @@ class AccountService:

db.session.commit()

return cast(Account, account)
return account

@staticmethod
def update_account_password(account, password, new_password):
@@ -1127,7 +1127,7 @@ class TenantService:
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)

return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict

@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:

+ 3
- 3
api/services/annotation_service.py Näytä tiedosto

@@ -1,5 +1,5 @@
import uuid
from typing import cast
from typing import Optional

import pandas as pd
from flask_login import current_user
@@ -40,7 +40,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")

annotation = message.annotation
annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@@ -70,7 +70,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return cast(MessageAnnotation, annotation)
return annotation

@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:

+ 0
- 6
api/tests/integration_tests/workflow/nodes/test_code.py Näytä tiedosto

@@ -1,7 +1,6 @@
import time
import uuid
from os import getenv
from typing import cast

import pytest

@@ -13,7 +12,6 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -238,8 +236,6 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}

node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node._node_data.outputs)

@@ -334,8 +330,6 @@ def test_execute_code_output_object_list():
]
}

node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node._node_data.outputs)


Loading…
Peruuta
Tallenna