瀏覽代碼

chore: improve typing

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.2^2
-LAN- 1 月之前
父節點
當前提交
7ba1f0a046
沒有連結到貢獻者的電子郵件帳戶。

+ 1
- 1
api/core/workflow/entities/graph_runtime_state.py 查看文件

@@ -25,7 +25,7 @@ class GraphRuntimeState(BaseModel):
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
node_run_steps: int = 0,
**kwargs,
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
super().__init__(**kwargs)

+ 4
- 4
api/core/workflow/entities/variable_pool.py 查看文件

@@ -14,7 +14,7 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_V
from core.workflow.system_variable import SystemVariable
from factories import variable_factory

VariableValue = Union[str, int, float, dict, list, File]
VariableValue = Union[str, int, float, dict[str, object], list[object], File]

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

@@ -40,11 +40,11 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
default_factory=list[VariableUnion],
)

def model_post_init(self, context: Any, /) -> None:
@@ -191,7 +191,7 @@ class VariablePool(BaseModel):

def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
segments: list[Segment] = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)

+ 4
- 2
api/core/workflow/graph/graph_runtime_state_protocol.py 查看文件

@@ -1,16 +1,18 @@
from collections.abc import Mapping
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment


class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""

def get(self, node_id: str, variable_key: str) -> Any:
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (read-only)."""
...

def get_all_by_node(self, node_id: str) -> dict[str, Any]:
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (read-only)."""
...


+ 8
- 7
api/core/workflow/graph/read_only_state_wrapper.py 查看文件

@@ -1,7 +1,9 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool

@@ -12,19 +14,18 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool):
self._variable_pool = variable_pool

def get(self, node_id: str, variable_key: str) -> Any:
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
value = self._variable_pool.get(node_id, variable_key)
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None

def get_all_by_node(self, node_id: str) -> dict[str, Any]:
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
variables = {}
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# FIXME(-LAN-): Handle the actual Variable object structure
value = var.value if hasattr(var, "value") else var
variables[key] = deepcopy(value)
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
return variables



+ 6
- 1
api/core/workflow/node_events/base.py 查看文件

@@ -13,6 +13,11 @@ class NodeEventBase(BaseModel):
pass


def _default_metadata():
v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
return v


class NodeRunResult(BaseModel):
"""
Node Run Result.
@@ -23,7 +28,7 @@ class NodeRunResult(BaseModel):
inputs: Mapping[str, Any] = Field(default_factory=dict)
process_data: Mapping[str, Any] = Field(default_factory=dict)
outputs: Mapping[str, Any] = Field(default_factory=dict)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=dict)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)

edge_source_handle: str = "source" # source handle id of node with multiple branches

+ 106
- 37
api/core/workflow/utils/condition/processor.py 查看文件

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Literal, NamedTuple, Union
from typing import Literal, NamedTuple

from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
@@ -10,7 +10,7 @@ from core.workflow.entities import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator


def _convert_to_bool(value: Any) -> bool:
def _convert_to_bool(value: object) -> bool:
if isinstance(value, int):
return bool(value)

@@ -23,7 +23,7 @@ def _convert_to_bool(value: Any) -> bool:


class ConditionCheckResult(NamedTuple):
inputs: Sequence[Mapping[str, Any]]
inputs: Sequence[Mapping[str, object]]
group_results: Sequence[bool]
final_result: bool

@@ -36,7 +36,7 @@ class ConditionProcessor:
conditions: Sequence[Condition],
operator: Literal["and", "or"],
) -> ConditionCheckResult:
input_conditions: list[Mapping[str, Any]] = []
input_conditions: list[Mapping[str, object]] = []
group_results: list[bool] = []

for condition in conditions:
@@ -103,8 +103,8 @@ class ConditionProcessor:
def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
value: object,
expected: str | Sequence[str] | bool | Sequence[bool] | None,
) -> bool:
match operator:
case "contains":
@@ -144,7 +144,17 @@ def _evaluate_condition(
case "not in":
return _assert_not_in(value=value, expected=expected)
case "all of" if isinstance(expected, list):
return _assert_all_of(value=value, expected=expected)
# Type narrowing: at this point expected is a list, could be list[str] or list[bool]
if all(isinstance(item, str) for item in expected):
# Create a new typed list to satisfy type checker
str_list: list[str] = [item for item in expected if isinstance(item, str)]
return _assert_all_of(value=value, expected=str_list)
elif all(isinstance(item, bool) for item in expected):
# Create a new typed list to satisfy type checker
bool_list: list[bool] = [item for item in expected if isinstance(item, bool)]
return _assert_all_of_bool(value=value, expected=bool_list)
else:
raise ValueError("all of operator expects homogeneous list of strings or booleans")
case "exists":
return _assert_exists(value=value)
case "not exists":
@@ -153,55 +163,73 @@ def _evaluate_condition(
raise ValueError(f"Unsupported operator: {operator}")


def _assert_contains(*, value: Any, expected: Any) -> bool:
def _assert_contains(*, value: object, expected: object) -> bool:
if not value:
return False

if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")

if expected not in value:
return False
# Type checking ensures value is str or list at this point
if isinstance(value, str):
if not isinstance(expected, str):
expected = str(expected)
if expected not in value:
return False
else: # value is list
if expected not in value:
return False
return True


def _assert_not_contains(*, value: Any, expected: Any) -> bool:
def _assert_not_contains(*, value: object, expected: object) -> bool:
if not value:
return True

if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")

if expected in value:
return False
# Type checking ensures value is str or list at this point
if isinstance(value, str):
if not isinstance(expected, str):
expected = str(expected)
if expected in value:
return False
else: # value is list
if expected in value:
return False
return True


def _assert_start_with(*, value: Any, expected: Any) -> bool:
def _assert_start_with(*, value: object, expected: object) -> bool:
if not value:
return False

if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")

if not isinstance(expected, str):
raise ValueError("Expected value must be a string for startswith")
if not value.startswith(expected):
return False
return True


def _assert_end_with(*, value: Any, expected: Any) -> bool:
def _assert_end_with(*, value: object, expected: object) -> bool:
if not value:
return False

if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")

if not isinstance(expected, str):
raise ValueError("Expected value must be a string for endswith")
if not value.endswith(expected):
return False
return True


def _assert_is(*, value: Any, expected: Any) -> bool:
def _assert_is(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -213,7 +241,7 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
return True


def _assert_is_not(*, value: Any, expected: Any) -> bool:
def _assert_is_not(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -225,19 +253,19 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
return True


def _assert_empty(*, value: Any) -> bool:
def _assert_empty(*, value: object) -> bool:
if not value:
return True
return False


def _assert_not_empty(*, value: Any) -> bool:
def _assert_not_empty(*, value: object) -> bool:
if value:
return True
return False


def _assert_equal(*, value: Any, expected: Any) -> bool:
def _assert_equal(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -246,10 +274,16 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:

# Handle boolean comparison
if isinstance(value, bool):
if not isinstance(expected, (bool, int, str)):
raise ValueError(f"Cannot convert {type(expected)} to bool")
expected = bool(expected)
elif isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value != expected:
@@ -257,7 +291,7 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
return True


def _assert_not_equal(*, value: Any, expected: Any) -> bool:
def _assert_not_equal(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -266,10 +300,16 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:

# Handle boolean comparison
if isinstance(value, bool):
if not isinstance(expected, (bool, int, str)):
raise ValueError(f"Cannot convert {type(expected)} to bool")
expected = bool(expected)
elif isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value == expected:
@@ -277,7 +317,7 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
return True


def _assert_greater_than(*, value: Any, expected: Any) -> bool:
def _assert_greater_than(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -285,8 +325,12 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")

if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value <= expected:
@@ -294,7 +338,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
return True


def _assert_less_than(*, value: Any, expected: Any) -> bool:
def _assert_less_than(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -302,8 +346,12 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")

if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value >= expected:
@@ -311,7 +359,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
return True


def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -319,8 +367,12 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")

if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value < expected:
@@ -328,7 +380,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
return True


def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
if value is None:
return False

@@ -336,8 +388,12 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")

if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)

if value > expected:
@@ -345,19 +401,19 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
return True


def _assert_null(*, value: Any) -> bool:
def _assert_null(*, value: object) -> bool:
if value is None:
return True
return False


def _assert_not_null(*, value: Any) -> bool:
def _assert_not_null(*, value: object) -> bool:
if value is not None:
return True
return False


def _assert_in(*, value: Any, expected: Any) -> bool:
def _assert_in(*, value: object, expected: object) -> bool:
if not value:
return False

@@ -369,7 +425,7 @@ def _assert_in(*, value: Any, expected: Any) -> bool:
return True


def _assert_not_in(*, value: Any, expected: Any) -> bool:
def _assert_not_in(*, value: object, expected: object) -> bool:
if not value:
return True

@@ -381,20 +437,33 @@ def _assert_not_in(*, value: Any, expected: Any) -> bool:
return True


def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool:
def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool:
if not value:
return False

if not all(item in value for item in expected):
# Ensure value is a container that supports 'in' operator
if not isinstance(value, (list, tuple, set, str)):
return False
return True

return all(item in value for item in expected)


def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool:
if not value:
return False

# Ensure value is a container that supports 'in' operator
if not isinstance(value, (list, tuple, set)):
return False

return all(item in value for item in expected)


def _assert_exists(*, value: Any) -> bool:
def _assert_exists(*, value: object) -> bool:
return value is not None


def _assert_not_exists(*, value: Any) -> bool:
def _assert_not_exists(*, value: object) -> bool:
return value is None


@@ -404,7 +473,7 @@ def _process_sub_conditions(
operator: Literal["and", "or"],
) -> bool:
files = variable.value
group_results = []
group_results: list[bool] = []
for condition in sub_conditions:
key = FileAttribute(condition.key)
values = [file_manager.get_attr(file=file, attr=key) for file in files]
@@ -415,14 +484,14 @@ def _process_sub_conditions(
if expected_value and not expected_value.startswith("."):
expected_value = "." + expected_value

normalized_values = []
normalized_values: list[object] = []
for value in values:
if value and isinstance(value, str):
if not value.startswith("."):
value = "." + value
normalized_values.append(value)
values = normalized_values
sub_group_results = [
sub_group_results: list[bool] = [
_evaluate_condition(
value=value,
operator=condition.comparison_operator,

Loading…
取消
儲存