Browse Source

Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

tags/2.0.0-beta.2
-LAN- 1 month ago
parent
commit
23cd615489
No account linked to committer's email address
100 changed files with 399 additions and 368 deletions
  1. 4
    0
      .github/workflows/style.yml
  2. 8
    0
      .gitignore
  3. 3
    0
      api/app_factory.py
  4. 1
    2
      api/commands.py
  5. 1
    2
      api/configs/middleware/__init__.py
  6. 3
    2
      api/configs/middleware/vdb/clickzetta_config.py
  7. 3
    2
      api/configs/middleware/vdb/matrixone_config.py
  8. 1
    1
      api/configs/packaging/__init__.py
  9. 32
    30
      api/configs/remote_settings_sources/apollo/client.py
  10. 6
    4
      api/configs/remote_settings_sources/apollo/python_3x.py
  11. 6
    5
      api/configs/remote_settings_sources/apollo/utils.py
  12. 1
    1
      api/configs/remote_settings_sources/base.py
  13. 5
    8
      api/configs/remote_settings_sources/nacos/__init__.py
  14. 15
    7
      api/configs/remote_settings_sources/nacos/http_request.py
  15. 1
    1
      api/configs/remote_settings_sources/nacos/utils.py
  16. 6
    2
      api/controllers/console/admin.py
  17. 1
    1
      api/controllers/console/apikey.py
  18. 1
    1
      api/controllers/console/app/generator.py
  19. 3
    4
      api/controllers/console/app/workflow_draft_variable.py
  20. 13
    13
      api/controllers/console/auth/oauth_server.py
  21. 6
    3
      api/controllers/console/billing/billing.py
  22. 6
    0
      api/controllers/console/datasets/datasets_document.py
  23. 2
    0
      api/controllers/console/explore/parameter.py
  24. 4
    0
      api/controllers/console/explore/workflow.py
  25. 12
    14
      api/controllers/console/explore/wraps.py
  26. 7
    2
      api/controllers/console/workspace/__init__.py
  27. 33
    28
      api/controllers/console/wraps.py
  28. 1
    1
      api/controllers/mcp/mcp.py
  29. 4
    4
      api/controllers/service_api/dataset/segment.py
  30. 10
    28
      api/controllers/service_api/wraps.py
  31. 4
    0
      api/controllers/web/wraps.py
  32. 1
    1
      api/core/agent/base_agent_runner.py
  33. 1
    1
      api/core/agent/cot_agent_runner.py
  34. 1
    1
      api/core/agent/entities.py
  35. 1
    1
      api/core/app/app_config/easy_ui_based_app/dataset/manager.py
  36. 1
    1
      api/core/app/app_config/easy_ui_based_app/model_config/manager.py
  37. 1
    1
      api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
  38. 1
    1
      api/core/app/apps/advanced_chat/app_config_manager.py
  39. 1
    1
      api/core/app/apps/advanced_chat/app_generator.py
  40. 4
    3
      api/core/app/apps/advanced_chat/app_runner.py
  41. 3
    3
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  42. 1
    1
      api/core/app/apps/agent_chat/app_config_manager.py
  43. 1
    1
      api/core/app/apps/agent_chat/app_generator.py
  44. 1
    1
      api/core/app/apps/agent_chat/app_runner.py
  45. 2
    2
      api/core/app/apps/agent_chat/generate_response_converter.py
  46. 1
    1
      api/core/app/apps/base_app_generate_response_converter.py
  47. 1
    1
      api/core/app/apps/base_app_generator.py
  48. 6
    6
      api/core/app/apps/base_app_queue_manager.py
  49. 4
    6
      api/core/app/apps/base_app_runner.py
  50. 1
    1
      api/core/app/apps/chat/app_config_manager.py
  51. 1
    1
      api/core/app/apps/chat/app_generator.py
  52. 1
    1
      api/core/app/apps/chat/app_runner.py
  53. 2
    2
      api/core/app/apps/chat/generate_response_converter.py
  54. 1
    1
      api/core/app/apps/common/workflow_response_converter.py
  55. 1
    1
      api/core/app/apps/completion/app_config_manager.py
  56. 4
    1
      api/core/app/apps/completion/app_generator.py
  57. 1
    1
      api/core/app/apps/completion/app_runner.py
  58. 2
    2
      api/core/app/apps/completion/generate_response_converter.py
  59. 2
    2
      api/core/app/apps/message_based_app_queue_manager.py
  60. 1
    1
      api/core/app/apps/workflow/app_config_manager.py
  61. 2
    2
      api/core/app/apps/workflow/app_queue_manager.py
  62. 3
    2
      api/core/app/apps/workflow/app_runner.py
  63. 2
    2
      api/core/app/apps/workflow/generate_response_converter.py
  64. 3
    3
      api/core/app/apps/workflow/generate_task_pipeline.py
  65. 3
    3
      api/core/app/apps/workflow_app_runner.py
  66. 1
    1
      api/core/app/task_pipeline/based_generate_task_pipeline.py
  67. 3
    3
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  68. 2
    2
      api/core/app/task_pipeline/message_cycle_manager.py
  69. 7
    7
      api/core/callback_handler/agent_tool_callback_handler.py
  70. 3
    3
      api/core/callback_handler/index_tool_callback_handler.py
  71. 0
    2
      api/core/datasource/__base/datasource_runtime.py
  72. 2
    2
      api/core/entities/model_entities.py
  73. 16
    18
      api/core/entities/provider_configuration.py
  74. 1
    1
      api/core/errors/error.py
  75. 2
    2
      api/core/extension/api_based_extension_requestor.py
  76. 1
    1
      api/core/extension/extensible.py
  77. 1
    1
      api/core/external_data_tool/api/api.py
  78. 2
    2
      api/core/external_data_tool/base.py
  79. 2
    2
      api/core/external_data_tool/factory.py
  80. 1
    1
      api/core/file/tool_file_parser.py
  81. 1
    1
      api/core/helper/code_executor/code_node_provider.py
  82. 1
    1
      api/core/helper/code_executor/jinja2/jinja2_transformer.py
  83. 1
    1
      api/core/helper/code_executor/python3/python3_code_provider.py
  84. 2
    2
      api/core/helper/model_provider_cache.py
  85. 4
    4
      api/core/helper/provider_cache.py
  86. 2
    2
      api/core/helper/tool_parameter_cache.py
  87. 1
    1
      api/core/helper/trace_id_helper.py
  88. 2
    2
      api/core/hosting_configuration.py
  89. 14
    11
      api/core/indexing_runner.py
  90. 6
    8
      api/core/llm_generator/llm_generator.py
  91. 1
    3
      api/core/llm_generator/output_parser/rule_config_generator.py
  92. 5
    5
      api/core/llm_generator/output_parser/structured_output.py
  93. 1
    2
      api/core/llm_generator/output_parser/suggested_questions_after_answer.py
  94. 3
    3
      api/core/mcp/auth/auth_provider.py
  95. 8
    8
      api/core/mcp/client/sse_client.py
  96. 12
    12
      api/core/mcp/client/streamable_client.py
  97. 14
    16
      api/core/mcp/session/base_session.py
  98. 10
    12
      api/core/mcp/session/client_session.py
  99. 2
    1
      api/core/memory/token_buffer_memory.py
  100. 0
    0
      api/core/model_manager.py

+ 4
- 0
.github/workflows/style.yml View File

@@ -47,6 +47,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check

- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .

- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example

+ 8
- 0
.gitignore View File

@@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template
!.vscode/README.md
api/.vscode
web/.vscode
# vscode Code History Extension
.history

@@ -215,6 +216,13 @@ mise.toml
# Next.js build output
.next/

# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js

# AI Assistant
.roo/
api/.env.backup

+ 3
- 0
api/app_factory.py View File

@@ -25,6 +25,9 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()

# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request

return dify_app



+ 1
- 2
api/commands.py View File

@@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
@@ -1494,7 +1493,7 @@ def transform_datasource_credentials():
for credential in credentials:
auth_count += 1
# get credential api key
credentials_json =json.loads(credential.credentials)
credentials_json = json.loads(credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {

+ 1
- 2
api/configs/middleware/__init__.py View File

@@ -300,8 +300,7 @@ class DatasetQueueMonitorConfig(BaseSettings):

class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
DatabaseConfig,
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers

+ 3
- 2
api/configs/middleware/vdb/clickzetta_config.py View File

@@ -1,9 +1,10 @@
from typing import Optional

from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings


class ClickzettaConfig(BaseModel):
class ClickzettaConfig(BaseSettings):
"""
Clickzetta Lakehouse vector database configuration
"""

+ 3
- 2
api/configs/middleware/vdb/matrixone_config.py View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings


class MatrixoneConfig(BaseModel):
class MatrixoneConfig(BaseSettings):
"""Matrixone vector database configuration."""

MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

+ 1
- 1
api/configs/packaging/__init__.py View File

@@ -1,6 +1,6 @@
from pydantic import Field

from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
from configs.packaging.pyproject import PyProjectTomlConfig


class PackagingInfo(PyProjectTomlConfig):

+ 32
- 30
api/configs/remote_settings_sources/apollo/client.py View File

@@ -4,8 +4,9 @@ import logging
import os
import threading
import time
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import Any

from .python_3x import http_request, makedirs_wrapper
from .utils import (
@@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
config_url: str,
app_id: str,
cluster: str = "default",
secret: str = "",
start_hot_update: bool = True,
change_listener: Callable[[str, str, str, Any], None] | None = None,
_notification_map: dict[str, int] | None = None,
):
# Core routing parameters
self.config_url = config_url
@@ -47,17 +48,17 @@ class ApolloClient:
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._cache: dict[str, dict[str, Any]] = {}
self._no_key: dict[str, str] = {}
self._hash: dict[str, str] = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._long_poll_thread: threading.Thread | None = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key = None
self.last_release_key: str | None = None
# Private startup method
self._path_checker()
if start_hot_update:
@@ -68,7 +69,7 @@ class ApolloClient:
heartbeat.daemon = True
heartbeat.start()

def get_json_from_net(self, namespace="application"):
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
@@ -88,7 +89,7 @@ class ApolloClient:
logger.exception("an error occurred in get_json_from_net")
return None

def get_value(self, key, default_val=None, namespace="application"):
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
@@ -104,7 +105,8 @@ class ApolloClient:
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
self._update_cache_and_file(namespace_data, namespace)
if namespace_data is not None:
self._update_cache_and_file(namespace_data, namespace)
return val

# read the file configuration
@@ -126,23 +128,23 @@ class ApolloClient:
# to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice
# and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key):
def _set_local_cache_none(self, namespace: str, key: str) -> None:
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key

def _start_hot_update(self):
def _start_hot_update(self) -> None:
self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()

def stop(self):
def stop(self) -> None:
self._stopping = True
logger.info("Stopping listener...")

# Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv):
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
if self._change_listener is None:
return
if old_kv is None:
@@ -168,12 +170,12 @@ class ApolloClient:
except BaseException as e:
logger.warning(str(e))

def _path_checker(self):
def _path_checker(self) -> None:
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)

# update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
@@ -187,7 +189,7 @@ class ApolloClient:
self._hash[namespace] = new_hash

# get the configuration from the local file
def _get_local_cache(self, namespace="application"):
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
@@ -195,8 +197,8 @@ class ApolloClient:
return result
return {}

def _long_poll(self):
notifications = []
def _long_poll(self) -> None:
notifications: list[dict[str, Any]] = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
@@ -236,7 +238,7 @@ class ApolloClient:
except Exception as e:
logger.warning(str(e))

def _get_net_and_set_local(self, namespace, n_id, call_change=False):
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
@@ -248,7 +250,7 @@ class ApolloClient:
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)

def _listener(self):
def _listener(self) -> None:
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
@@ -266,13 +268,13 @@ class ApolloClient:
headers["Timestamp"] = time_unix_now
return headers

def _heart_beat(self):
def _heart_beat(self) -> None:
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes

def _do_heart_beat(self, namespace):
def _do_heart_beat(self, namespace: str) -> None:
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
@@ -292,7 +294,7 @@ class ApolloClient:
logger.exception("an error occurred in _do_heart_beat")
return None

def get_all_dicts(self, namespace):
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)

+ 6
- 4
api/configs/remote_settings_sources/apollo/python_3x.py View File

@@ -2,6 +2,8 @@ import logging
import os
import ssl
import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse
from urllib.error import HTTPError

@@ -19,9 +21,9 @@ urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)


def http_request(url, timeout, headers={}):
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
try:
request = urllib.request.Request(url, headers=headers)
request = urllib.request.Request(url, headers=dict(headers))
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
@@ -33,9 +35,9 @@ def http_request(url, timeout, headers={}):
raise e


def url_encode(params):
def url_encode(params: dict[str, Any]) -> str:
return parse.urlencode(params)


def makedirs_wrapper(path):
def makedirs_wrapper(path: str) -> None:
os.makedirs(path, exist_ok=True)

+ 6
- 5
api/configs/remote_settings_sources/apollo/utils.py View File

@@ -1,5 +1,6 @@
import hashlib
import socket
from typing import Any

from .python_3x import url_encode

@@ -10,7 +11,7 @@ NAMESPACE_NAME = "namespaceName"


# add timestamps uris and keys
def signature(timestamp, uri, secret):
def signature(timestamp: str, uri: str, secret: str) -> str:
import base64
import hmac

@@ -19,16 +20,16 @@ def signature(timestamp, uri, secret):
return base64.b64encode(hmac_code).decode()


def url_encode_wrapper(params):
def url_encode_wrapper(params: dict[str, Any]) -> str:
return url_encode(params)


def no_key_cache_key(namespace, key):
def no_key_cache_key(namespace: str, key: str) -> str:
return f"{namespace}{len(namespace)}{key}"


# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key):
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
@@ -38,7 +39,7 @@ def get_value_from_dict(namespace_cache, key):
return None


def init_ip():
def init_ip() -> str:
ip = ""
s = None
try:

+ 1
- 1
api/configs/remote_settings_sources/base.py View File

@@ -11,5 +11,5 @@ class RemoteSettingsSource:
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError

def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
return value

+ 5
- 8
api/configs/remote_settings_sources/nacos/__init__.py View File

@@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)

from configs.remote_settings_sources.base import RemoteSettingsSource

from .utils import _parse_config
from .utils import parse_config


class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.configs = configs
self.remote_configs: dict[str, Any] = {}
self.remote_configs: dict[str, str] = {}
self.async_init()

def async_init(self):
def async_init(self) -> None:
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
@@ -33,18 +33,15 @@ class NacosSettingsSource(RemoteSettingsSource):
logger.exception("[get-access-token] exception occurred")
raise

def _parse_config(self, content: str) -> dict:
def _parse_config(self, content: str) -> dict[str, str]:
if not content:
return {}
try:
return _parse_config(self, content)
return parse_config(content)
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")

def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

field_value = self.remote_configs.get(field_name)
if field_value is None:
return None, field_name, False

+ 15
- 7
api/configs/remote_settings_sources/nacos/http_request.py View File

@@ -17,11 +17,17 @@ class NacosHttpClient:
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token = None
self.token: str | None = None
self.token_ttl = 18000
self.token_expire_time: float = 0

def http_request(self, url, method="GET", headers=None, params=None):
def http_request(
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
try:
self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
@@ -30,7 +36,7 @@ class NacosHttpClient:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"

def _inject_auth_info(self, headers, params, module="config"):
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})

if module == "login":
@@ -45,16 +51,17 @@ class NacosHttpClient:
headers["timeStamp"] = ts
if self.username and self.password:
self.get_access_token(force_refresh=False)
params["accessToken"] = self.token
if self.token is not None:
params["accessToken"] = self.token

def __do_sign(self, sign_str, sk):
def __do_sign(self, sign_str: str, sk: str) -> str:
return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode()
.strip()
)

def get_sign_str(self, group, tenant, ts):
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
sign_str = ""
if tenant:
sign_str = tenant + "+"
@@ -63,7 +70,7 @@ class NacosHttpClient:
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str

def get_access_token(self, force_refresh=False):
def get_access_token(self, force_refresh: bool = False) -> str | None:
current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token
@@ -77,6 +84,7 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
return self.token
except Exception:
logger.exception("[get-access-token] exception occur")
raise

+ 1
- 1
api/configs/remote_settings_sources/nacos/utils.py View File

@@ -1,4 +1,4 @@
def _parse_config(self, content: str) -> dict[str, str]:
def parse_config(content: str) -> dict[str, str]:
config: dict[str, str] = {}
if not content:
return config

+ 6
- 2
api/controllers/console/admin.py View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import request
from flask_restx import Resource, reqparse
@@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized

P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
@@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp


def admin_required(view):
def admin_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")


+ 1
- 1
api/controllers/console/apikey.py View File

@@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource):
custom="max_keys_exceeded",
)

key = ApiToken.generate_api_key(self.token_prefix, 24)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id

+ 1
- 1
api/controllers/console/app/generator.py View File

@@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self) -> dict:
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()

+ 3
- 4
api/controllers/console/app/workflow_draft_variable.py View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, NoReturn
from typing import NoReturn

from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@@ -31,7 +31,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)


def _convert_values_to_json_serializable_object(value: Segment) -> Any:
def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
@@ -42,8 +42,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any:
return value.value


def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
"""Serialize variable value. If variable is truncated, return the truncated value."""
def _serialize_var_value(variable: WorkflowDraftVariable):
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)

+ 13
- 13
api/controllers/console/auth/oauth_server.py View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import cast
from typing import Concatenate, ParamSpec, TypeVar, cast

import flask_login
from flask import jsonify, request
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,

from .. import api

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

def oauth_server_client_id_required(view):

def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app:
raise NotFound("client_id is invalid")

kwargs["oauth_provider_app"] = oauth_provider_app

return view(*args, **kwargs)
return view(self, oauth_provider_app, *args, **kwargs)

return decorated


def oauth_server_access_token_required(view):
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")

authorization_header = request.headers.get("Authorization")
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer"
return response

kwargs["account"] = account

return view(*args, **kwargs)
return view(self, oauth_provider_app, account, *args, **kwargs)

return decorated


+ 6
- 3
api/controllers/console/billing/billing.py View File

@@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse

from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import login_required
from libs.login import current_user, login_required
from models.model import Account
from services.billing_service import BillingService


@@ -17,9 +17,10 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)

BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
@@ -31,7 +32,9 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)



+ 6
- 0
api/controllers/console/datasets/datasets_document.py View File

@@ -477,6 +477,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_source_info = document.data_source_info_dict

if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
@@ -493,6 +495,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)

elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
@@ -506,6 +510,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={

+ 2
- 0
api/controllers/console/explore/parameter.py View File

@@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model)



+ 4
- 0
api/controllers/console/explore/workflow.py View File

@@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
Run workflow
"""
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
@@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
Stop workflow task
"""
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()

+ 12
- 14
api/controllers/console/explore/wraps.py View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar

from flask_login import current_user
from flask_restx import Resource
@@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

def installed_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")

installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)

del kwargs["installed_app_id"]

def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
installed_app = (
db.session.query(InstalledApp)
.where(
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator


def user_allowed_to_access_app(view=None):
def decorator(view):
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs):
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id

+ 7
- 2
api/controllers/console/workspace/__init__.py View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask_login import current_user
from sqlalchemy.orm import Session
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission

P = ParamSpec("P")
R = TypeVar("R")


def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user
tenant_id = user.current_tenant_id


+ 33
- 28
api/controllers/console/wraps.py View File

@@ -2,7 +2,9 @@ import contextlib
import json
import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import abort, request
from flask_login import current_user
@@ -19,10 +21,13 @@ from services.operation_service import OperationService

from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout

P = ParamSpec("P")
R = TypeVar("R")

def account_initialization_required(view):

def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user

@@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated


def only_edition_cloud(view):
def only_edition_cloud(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
abort(404)

@@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated


def only_edition_enterprise(view):
def only_edition_enterprise(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)

@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated


def only_edition_self_hosted(view):
def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
abort(404)

@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated


def cloud_edition_billing_enabled(view):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):


def cloud_edition_billing_resource_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
members = features.members
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):


def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):


def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor


def cloud_utm_record(view):
def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)

@@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated


def setup_required(view):
def setup_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
@@ -212,9 +217,9 @@ def setup_required(view):
return decorated


def enterprise_license_required(view):
def enterprise_license_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated


def email_password_login_enabled(view):
def email_password_login_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated


def enable_change_email(view):
def enable_change_email(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_change_email:
return view(*args, **kwargs)
@@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated


def is_allow_transfer_owner(view):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)

+ 1
- 1
api/controllers/mcp/mcp.py View File

@@ -99,7 +99,7 @@ class MCPAppApi(Resource):

return mcp_server, app

def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
def _validate_server_status(self, mcp_server: AppMCPServer):
"""Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")

+ 4
- 4
api/controllers/service_api/dataset/segment.py View File

@@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")

# validate segment belongs to the specified document
if segment.document_id != document_id:
if str(segment.document_id) != str(document_id):
raise NotFound("Document not found.")

# check child chunk
@@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")

# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.")

try:
@@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")

# validate segment belongs to the specified document
if segment.document_id != document_id:
if str(segment.document_id) != str(document_id):
raise NotFound("Segment not found.")

# get child chunk
@@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")

# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.")

# validate args

+ 10
- 28
api/controllers/service_api/wraps.py View File

@@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
from typing import Optional, ParamSpec, TypeVar

from flask import current_app, request
from flask_login import user_logged_in
@@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService

P = ParamSpec("P")
R = TypeVar("R")


class WhereisUserArg(StrEnum):
"""
@@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")

tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")

kwargs["app_model"] = app_model

if fetch_user_arg:
@@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio


def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view):
def decorated(*args, **kwargs):
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)

@@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):


def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
@@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s


def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)

if resource == "knowledge":

+ 4
- 0
api/controllers/web/wraps.py View File

@@ -1,5 +1,6 @@
from datetime import UTC, datetime
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import request
from flask_restx import Resource
@@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService

P = ParamSpec("P")
R = TypeVar("R")


def validate_jwt_token(view=None):
def decorator(view):

+ 1
- 1
api/core/agent/base_agent_runner.py View File

@@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner):
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
) -> None:
):
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.conversation = conversation

+ 1
- 1
api/core/agent/cot_agent_runner.py View File

@@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

return instruction

def _init_react_state(self, query) -> None:
def _init_react_state(self, query):
"""
init agent scratchpad
"""

+ 1
- 1
api/core/agent/entities.py View File

@@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
action_name: str
action_input: Union[dict, str]

def to_dict(self) -> dict:
def to_dict(self):
"""
Convert to dictionary.
"""

+ 1
- 1
api/core/app/app_config/easy_ui_based_app/dataset/manager.py View File

@@ -158,7 +158,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
"""
Extract dataset config for legacy compatibility


+ 1
- 1
api/core/app/app_config/easy_ui_based_app/model_config/manager.py View File

@@ -105,7 +105,7 @@ class ModelConfigManager:
return dict(config), ["model"]

@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:
def validate_model_completion_params(cls, cp: dict):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

+ 1
- 1
api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py View File

@@ -122,7 +122,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]

@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
def validate_post_prompt_and_set_defaults(cls, config: dict):
"""
Validate post_prompt and set defaults for prompt feature


+ 1
- 1
api/core/app/apps/advanced_chat/app_config_manager.py View File

@@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for advanced chat app model config


+ 1
- 1
api/core/app/apps/advanced_chat/app_generator.py View File

@@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

+ 4
- 3
api/core/app/apps/advanced_chat/app_runner.py View File

@@ -55,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
@@ -69,7 +69,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.system_user_id = system_user_id
self._app = app

def run(self) -> None:
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)

@@ -184,6 +184,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
@@ -238,7 +239,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

return False

def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
"""
Direct output
"""

+ 3
- 3
api/core/app/apps/advanced_chat/generate_task_pipeline.py View File

@@ -96,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@@ -284,7 +284,7 @@ class AdvancedChatAppGenerateTaskPipeline:
session.rollback()
raise

def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
@@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()

def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
message = self._get_message(session=session)

# If there are assistant files, remove markdown image links from answer

+ 1
- 1
api/core/app/apps/agent_chat/app_config_manager.py View File

@@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
"""
Validate for agent chat app model config


+ 1
- 1
api/core/app/apps/agent_chat/app_generator.py View File

@@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

+ 1
- 1
api/core/app/apps/agent_chat/app_runner.py View File

@@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run assistant application
:param application_generate_entity: application generate entity

+ 2
- 2
api/core/app/apps/agent_chat/generate_response_converter.py View File

@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response

@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

+ 1
- 1
api/core/app/apps/base_app_generate_response_converter.py View File

@@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
return metadata

@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict:
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception

+ 1
- 1
api/core/app/apps/base_app_generator.py View File

@@ -158,7 +158,7 @@ class BaseAppGenerator:

return value

def _sanitize_value(self, value: Any) -> Any:
def _sanitize_value(self, value: Any):
if isinstance(value, str):
return value.replace("\x00", "")
return value

+ 6
- 6
api/core/app/apps/base_app_queue_manager.py View File

@@ -25,7 +25,7 @@ class PublishFrom(IntEnum):


class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id:
raise ValueError("user is required")

@@ -73,14 +73,14 @@ class AppQueueManager:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10

def stop_listen(self) -> None:
def stop_listen(self):
"""
Stop listen to queue
:return:
"""
self._q.put(None)

def publish_error(self, e, pub_from: PublishFrom) -> None:
def publish_error(self, e, pub_from: PublishFrom):
"""
Publish error
:param e: error
@@ -89,7 +89,7 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)

def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@@ -100,7 +100,7 @@ class AppQueueManager:
self._publish(event, pub_from)

@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@@ -110,7 +110,7 @@ class AppQueueManager:
raise NotImplementedError

@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
"""
Set task stop flag
:return:

+ 4
- 6
api/core/app/apps/base_app_runner.py View File

@@ -162,7 +162,7 @@ class AppRunner:
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
):
"""
Direct output
:param queue_manager: application queue manager
@@ -204,7 +204,7 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result
@@ -220,9 +220,7 @@ class AppRunner:
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
"""
Handle invoke result direct
:param invoke_result: invoke result
@@ -239,7 +237,7 @@ class AppRunner:

def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result

+ 1
- 1
api/core/app/apps/chat/app_config_manager.py View File

@@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for chat app model config


+ 1
- 1
api/core/app/apps/chat/app_generator.py View File

@@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

+ 1
- 1
api/core/app/apps/chat/app_runner.py View File

@@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity

+ 2
- 2
api/core/app/apps/chat/generate_response_converter.py View File

@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response

@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

+ 1
- 1
api/core/app/apps/common/workflow_response_converter.py View File

@@ -56,7 +56,7 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._user = user
self._truncator = VariableTruncator.default()

+ 1
- 1
api/core/app/apps/completion/app_config_manager.py View File

@@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for completion app model config


+ 4
- 1
api/core/app/apps/completion/app_generator.py View File

@@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MessageNotExistsError()

current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()

more_like_this = current_app_model_config.more_like_this_dict

if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:

+ 1
- 1
api/core/app/apps/completion/app_runner.py View File

@@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner):

def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity

+ 2
- 2
api/core/app/apps/completion/generate_response_converter.py View File

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response

@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

+ 2
- 2
api/core/app/apps/message_based_app_queue_manager.py View File

@@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
):
super().__init__(task_id, user_id, invoke_from)

self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)

def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

+ 1
- 1
api/core/app/apps/workflow/app_config_manager.py View File

@@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
return app_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for workflow app model config


+ 2
- 2
api/core/app/apps/workflow/app_queue_manager.py View File

@@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (


class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from)

self._app_mode = app_mode

def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

+ 3
- 2
api/core/app/apps/workflow/app_runner.py View File

@@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
@@ -44,7 +44,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow
self._sys_user_id = system_user_id

def run(self) -> None:
def run(self):
"""
Run application
"""
@@ -127,6 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)

+ 2
- 2
api/core/app/apps/workflow/generate_response_converter.py View File

@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.to_dict())

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

+ 3
- 3
api/core/app/apps/workflow/generate_task_pipeline.py View File

@@ -88,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@@ -259,7 +259,7 @@ class WorkflowAppGenerateTaskPipeline:
session.rollback()
raise

def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
@@ -697,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)

def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API

+ 3
- 3
api/core/app/apps/workflow_app_runner.py View File

@@ -67,7 +67,7 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
@@ -348,7 +348,7 @@ class WorkflowBasedAppRunner:

return graph, variable_pool

def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
:param workflow_entry: workflow entry
@@ -580,5 +580,5 @@ class WorkflowBasedAppRunner:
)
)

def _publish_event(self, event: AppQueueEvent) -> None:
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

+ 1
- 1
api/core/app/task_pipeline/based_generate_task_pipeline.py View File

@@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline:
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()

+ 3
- 3
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py View File

@@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
conversation: Conversation,
message: Message,
stream: bool,
) -> None:
):
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()

def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
"""
Save message.
:return:
@@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
application_generate_entity=self._application_generate_entity,
)

def _handle_stop(self, event: QueueStopEvent) -> None:
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
:return:

+ 2
- 2
api/core/app/task_pipeline/message_cycle_manager.py View File

@@ -48,7 +48,7 @@ class MessageCycleManager:
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state

@@ -132,7 +132,7 @@ class MessageCycleManager:

return None

def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
"""
Handle retriever resources.
:param event: event

+ 7
- 7
api/core/callback_handler/agent_tool_callback_handler.py View File

@@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str:
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None):
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
@@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel):
color: Optional[str] = ""
current_loop: int = 1

def __init__(self, color: Optional[str] = None) -> None:
def __init__(self, color: Optional[str] = None):
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
@@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
) -> None:
):
"""Do nothing."""
if dify_config.DEBUG:
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel):
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
):
"""If not the final action, print out observation."""
if dify_config.DEBUG:
print_text("\n[on_tool_end]\n", color=self.color)
@@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel):
)
)

def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
"""Do nothing."""
if dify_config.DEBUG:
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")

def on_agent_start(self, thought: str) -> None:
def on_agent_start(self, thought: str):
"""Run on agent start."""
if dify_config.DEBUG:
if thought:
@@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel):
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any):
"""Run on agent end."""
if dify_config.DEBUG:
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

+ 3
- 3
api/core/callback_handler/index_tool_callback_handler.py View File

@@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler:

def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None:
):
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
self._user_id = user_id
self._invoke_from = invoke_from

def on_query(self, query: str, dataset_id: str) -> None:
def on_query(self, query: str, dataset_id: str):
"""
Handle query.
"""
@@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query)
db.session.commit()

def on_tool_end(self, documents: list[Document]) -> None:
def on_tool_end(self, documents: list[Document]):
"""Handle tool end."""
for document in documents:
if document.metadata is not None:

+ 0
- 2
api/core/datasource/__base/datasource_runtime.py View File

@@ -30,8 +30,6 @@ class FakeDatasourceRuntime(DatasourceRuntime):
"""

def __init__(self):


super().__init__(
tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id",

+ 2
- 2
api/core/entities/model_entities.py View File

@@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]

def __init__(self, provider_entity: ProviderEntity) -> None:
def __init__(self, provider_entity: ProviderEntity):
"""
Init simple provider.

@@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False

def raise_for_status(self) -> None:
def raise_for_status(self):
"""
Check model status and raise ValueError if not active.


+ 16
- 18
api/core/entities/provider_configuration.py View File

@@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel):
else [],
)

def validate_provider_credentials(
self, credentials: dict, credential_id: str = "", session: Session | None = None
) -> dict:
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""

def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
@@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"

def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict,
credential_id: str,
credential_name: str | None,
) -> None:
):
"""
update a saved provider credential (by credential_id).

@@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel):
credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str,
session: Session,
) -> None:
):
"""
Update load balancing configurations that reference the given credential_id.

@@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel):

session.commit()

def delete_provider_credential(self, credential_id: str) -> None:
def delete_provider_credential(self, credential_id: str):
"""
Delete a saved provider credential (by credential_id).

@@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise

def switch_active_provider_credential(self, credential_id: str) -> None:
def switch_active_provider_credential(self, credential_id: str):
"""
Switch active provider credential (copy the selected one into current active snapshot).

@@ -815,7 +813,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict,
credential_id: str = "",
session: Session | None = None,
) -> dict:
):
"""
Validate custom model credentials.

@@ -826,7 +824,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""

def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
@@ -1010,7 +1008,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise

def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
Delete a saved provider credential (by credential_id).

@@ -1080,7 +1078,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise

def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
"""
if model list exist this custom model, switch the custom model credential.
if model list not exist this custom model, use the credential to add a new custom model record.
@@ -1123,7 +1121,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()

def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
switch the custom model credential.

@@ -1153,7 +1151,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()

def delete_custom_model(self, model_type: ModelType, model: str) -> None:
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.
:param model_type: model type
@@ -1350,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)

def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
"""
Switch preferred provider type.
:param provider_type:
@@ -1362,7 +1360,7 @@ class ProviderConfiguration(BaseModel):
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return

def _switch(s: Session) -> None:
def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
@@ -1406,7 +1404,7 @@ class ProviderConfiguration(BaseModel):

return secret_input_form_variables

def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.


+ 1
- 1
api/core/errors/error.py View File

@@ -6,7 +6,7 @@ class LLMError(ValueError):

description: Optional[str] = None

def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: Optional[str] = None):
self.description = description



+ 2
- 2
api/core/extension/api_based_extension_requestor.py View File

@@ -10,11 +10,11 @@ class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""

def __init__(self, api_endpoint: str, api_key: str) -> None:
def __init__(self, api_endpoint: str, api_key: str):
self.api_endpoint = api_endpoint
self.api_key = api_key

def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
def request(self, point: APIBasedExtensionPoint, params: dict):
"""
Request the api.


+ 1
- 1
api/core/extension/extensible.py View File

@@ -34,7 +34,7 @@ class Extensible:
tenant_id: str
config: Optional[dict] = None

def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
def __init__(self, tenant_id: str, config: Optional[dict] = None):
self.tenant_id = tenant_id
self.config = config


+ 1
- 1
api/core/external_data_tool/api/api.py View File

@@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""

@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.


+ 2
- 2
api/core/external_data_tool/base.py View File

@@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str
"""the tool variable name of app tool"""

def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None):
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable

@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.


+ 2
- 2
api/core/external_data_tool/factory.py View File

@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension


class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)

@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
def validate_config(cls, name: str, tenant_id: str, config: dict):
"""
Validate the incoming form config data.


+ 1
- 1
api/core/file/tool_file_parser.py View File

@@ -7,6 +7,6 @@ if TYPE_CHECKING:
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None


def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
global _tool_file_manager_factory
_tool_file_manager_factory = factory

+ 1
- 1
api/core/helper/code_executor/code_node_provider.py View File

@@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel):
pass

@classmethod
def get_default_config(cls) -> dict:
def get_default_config(cls):
return {
"type": "code",
"config": {

+ 1
- 1
api/core/helper/code_executor/jinja2/jinja2_transformer.py View File

@@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer

class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def transform_response(cls, response: str) -> dict:
def transform_response(cls, response: str):
"""
Transform response to dict
:param response: response

+ 1
- 1
api/core/helper/code_executor/python3/python3_code_provider.py View File

@@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str:
return dedent(
"""
def main(arg1: str, arg2: str) -> dict:
def main(arg1: str, arg2: str):
return {
"result": arg1 + arg2,
}

+ 2
- 2
api/core/helper/model_provider_cache.py View File

@@ -34,7 +34,7 @@ class ProviderCredentialsCache:
else:
return None

def set(self, credentials: dict) -> None:
def set(self, credentials: dict):
"""
Cache model provider credentials.

@@ -43,7 +43,7 @@ class ProviderCredentialsCache:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

def delete(self) -> None:
def delete(self):
"""
Delete cached model provider credentials.


+ 4
- 4
api/core/helper/provider_cache.py View File

@@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC):
return None
return None

def set(self, config: dict[str, Any]) -> None:
def set(self, config: dict[str, Any]):
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))

def delete(self) -> None:
def delete(self):
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)

@@ -75,10 +75,10 @@ class NoOpProviderCredentialCache:
"""Get cached provider credentials"""
return None

def set(self, config: dict[str, Any]) -> None:
def set(self, config: dict[str, Any]):
"""Cache provider credentials"""
pass

def delete(self) -> None:
def delete(self):
"""Delete cached provider credentials"""
pass

+ 2
- 2
api/core/helper/tool_parameter_cache.py View File

@@ -37,11 +37,11 @@ class ToolParameterCache:
else:
return None

def set(self, parameters: dict) -> None:
def set(self, parameters: dict):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

def delete(self) -> None:
def delete(self):
"""
Delete cached model provider credentials.


+ 1
- 1
api/core/helper/trace_id_helper.py View File

@@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]:
return None


def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
def extract_external_trace_id_from_args(args: Mapping[str, Any]):
"""
Extract 'external_trace_id' from args.


+ 2
- 2
api/core/hosting_configuration.py View File

@@ -44,11 +44,11 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None

def __init__(self) -> None:
def __init__(self):
self.provider_map = {}
self.moderation_config = None

def init_app(self, app: Flask) -> None:
def init_app(self, app: Flask):
if dify_config.EDITION != "CLOUD":
return


+ 14
- 11
api/core/indexing_runner.py View File

@@ -270,7 +270,9 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = [] # type: ignore
# keep separate, avoid union-list ambiguity
preview_texts: list[PreviewDetail] = []
qa_preview_texts: list[QAPreviewDetail] = []

total_segments = 0
index_type = doc_form
@@ -293,14 +295,14 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
qa_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
qa_preview_texts.append(qa_detail)
else:
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
preview_detail = PreviewDetail(content=document.page_content)
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail)

# delete image files and related db records
@@ -321,8 +323,8 @@ class IndexingRunner:
db.session.delete(image_file)

if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -425,6 +427,7 @@ class IndexingRunner:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -451,7 +454,7 @@ class IndexingRunner:
embedding_model_instance=embedding_model_instance,
)

return character_splitter # type: ignore
return character_splitter

def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -510,7 +513,7 @@ class IndexingRunner:
dataset: Dataset,
dataset_document: DatasetDocument,
documents: list[Document],
) -> None:
):
"""
insert index and update document/segment status to completed
"""
@@ -649,7 +652,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
) -> None:
):
"""
Update the document indexing status.
"""
@@ -668,7 +671,7 @@ class IndexingRunner:
db.session.commit()

@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
"""
Update the document segment by document id.
"""

+ 6
- 8
api/core/llm_generator/llm_generator.py View File

@@ -129,7 +129,7 @@ class LLMGenerator:
return questions

@classmethod
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
output_parser = RuleConfigGeneratorOutputParser()

error = ""
@@ -264,9 +264,7 @@ class LLMGenerator:
return rule_config

@classmethod
def generate_code(
cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
) -> dict:
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
@@ -375,7 +373,7 @@ class LLMGenerator:
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict:
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
@@ -415,7 +413,7 @@ class LLMGenerator:
instruction: str,
model_config: dict,
ideal_output: str | None,
) -> dict:
):
from services.workflow_service import WorkflowService

session = db.session()
@@ -455,7 +453,7 @@ class LLMGenerator:
return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)

def dict_of_event(event: AgentLogEvent) -> dict:
def dict_of_event(event: AgentLogEvent):
return {
"status": event.status,
"error": event.error,
@@ -493,7 +491,7 @@ class LLMGenerator:
instruction: str,
node_type: str,
ideal_output: str | None,
) -> dict:
):
LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"
ERROR_MESSAGE = "{{#error_message#}}"

+ 1
- 3
api/core/llm_generator/output_parser/rule_config_generator.py View File

@@ -1,5 +1,3 @@
from typing import Any

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
@@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser:
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
)

def parse(self, text: str) -> Any:
def parse(self, text: str):
try:
expected_keys = ["prompt", "variables", "opening_statement"]
parsed = parse_and_check_json_markdown(text, expected_keys)

+ 5
- 5
api/core/llm_generator/output_parser/structured_output.py View File

@@ -210,7 +210,7 @@ def _handle_native_json_schema(
structured_output_schema: Mapping,
model_parameters: dict,
rules: list[ParameterRule],
) -> dict:
):
"""
Handle structured output for models with native JSON schema support.

@@ -232,7 +232,7 @@ def _handle_native_json_schema(
return model_parameters


def _set_response_format(model_parameters: dict, rules: list) -> None:
def _set_response_format(model_parameters: dict, rules: list):
"""
Set the appropriate response format parameter based on model rules.

@@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
return structured_output


def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
"""
Prepare JSON schema based on model requirements.

@@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}


def remove_additional_properties(schema: dict) -> None:
def remove_additional_properties(schema: dict):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
@@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None:
remove_additional_properties(item)


def convert_boolean_to_string(schema: dict) -> None:
def convert_boolean_to_string(schema: dict):
"""
Convert boolean type specifications to string in JSON schema.


+ 1
- 2
api/core/llm_generator/output_parser/suggested_questions_after_answer.py View File

@@ -1,6 +1,5 @@
import json
import re
from typing import Any

from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

@@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

def parse(self, text: str) -> Any:
def parse(self, text: str):
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
if action_match is not None:
json_obj = json.loads(action_match.group(0).strip())

+ 3
- 3
api/core/mcp/auth/auth_provider.py View File

@@ -44,7 +44,7 @@ class OAuthClientProvider:
return None
return OAuthClientInformation.model_validate(client_information)

def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
@@ -63,13 +63,13 @@ class OAuthClientProvider:
refresh_token=credentials.get("refresh_token", ""),
)

def save_tokens(self, tokens: OAuthTokens) -> None:
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)

def save_code_verifier(self, code_verifier: str) -> None:
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})


+ 8
- 8
api/core/mcp/client/sse_client.py View File

@@ -47,7 +47,7 @@ class SSETransport:
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> None:
):
"""Initialize the SSE transport.

Args:
@@ -76,7 +76,7 @@ class SSETransport:

return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme

def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
"""Handle an 'endpoint' SSE event.

Args:
@@ -94,7 +94,7 @@ class SSETransport:

status_queue.put(_StatusReady(endpoint_url))

def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
"""Handle a 'message' SSE event.

Args:
@@ -110,7 +110,7 @@ class SSETransport:
logger.exception("Error parsing server message")
read_queue.put(exc)

def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
"""Handle a single SSE event.

Args:
@@ -126,7 +126,7 @@ class SSETransport:
case _:
logger.warning("Unknown SSE event: %s", sse.event)

def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
"""Read and process SSE events.

Args:
@@ -144,7 +144,7 @@ class SSETransport:
finally:
read_queue.put(None)

def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
"""Send a single message to the server.

Args:
@@ -163,7 +163,7 @@ class SSETransport:
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)

def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
"""Handle writing messages to the server.

Args:
@@ -303,7 +303,7 @@ def sse_client(
write_queue.put(None)


def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
"""
Send a message to the server using the provided HTTP client.


+ 12
- 12
api/core/mcp/client/streamable_client.py View File

@@ -82,7 +82,7 @@ class StreamableHTTPTransport:
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
) -> None:
):
"""Initialize the StreamableHTTP transport.

Args:
@@ -122,7 +122,7 @@ class StreamableHTTPTransport:
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
):
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
@@ -173,7 +173,7 @@ class StreamableHTTPTransport:
self,
client: httpx.Client,
server_to_client_queue: ServerToClientQueue,
) -> None:
):
"""Handle GET stream for server-initiated messages."""
try:
if not self.session_id:
@@ -197,7 +197,7 @@ class StreamableHTTPTransport:
except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc)

def _handle_resumption_request(self, ctx: RequestContext) -> None:
def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
@@ -230,7 +230,7 @@ class StreamableHTTPTransport:
if is_complete:
break

def _handle_post_request(self, ctx: RequestContext) -> None:
def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
@@ -278,7 +278,7 @@ class StreamableHTTPTransport:
self,
response: httpx.Response,
server_to_client_queue: ServerToClientQueue,
) -> None:
):
"""Handle JSON response from the server."""
try:
content = response.read()
@@ -288,7 +288,7 @@ class StreamableHTTPTransport:
except Exception as exc:
server_to_client_queue.put(exc)

def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
@@ -307,7 +307,7 @@ class StreamableHTTPTransport:
self,
content_type: str,
server_to_client_queue: ServerToClientQueue,
) -> None:
):
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg)
@@ -317,7 +317,7 @@ class StreamableHTTPTransport:
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
) -> None:
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
@@ -333,7 +333,7 @@ class StreamableHTTPTransport:
client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None],
) -> None:
):
"""Handle writing requests to the server.

This method processes messages from the client_to_server_queue and sends them to the server.
@@ -379,7 +379,7 @@ class StreamableHTTPTransport:
except Exception as exc:
server_to_client_queue.put(exc)

def terminate_session(self, client: httpx.Client) -> None:
def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
@@ -441,7 +441,7 @@ def streamablehttp_client(
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)


+ 14
- 16
api/core/mcp/session/base_session.py View File

@@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
):
self.request_id = request_id
self.request_meta = request_meta
self.request = request
@@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
):
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
@@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
finally:
self._entered = False

def respond(self, response: SendResultT | ErrorData) -> None:
def respond(self, response: SendResultT | ErrorData):
"""Send a response for this request.

Must be called within a context manager block.
@@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):

self._session._send_response(request_id=self.request_id, response=response)

def cancel(self) -> None:
def cancel(self):
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
@@ -163,7 +163,7 @@ class BaseSession(
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
):
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
@@ -183,7 +183,7 @@ class BaseSession(
self._receiver_future = self._executor.submit(self._receive_loop)
return self

def check_receiver_status(self) -> None:
def check_receiver_status(self):
"""`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done():
@@ -191,7 +191,7 @@ class BaseSession(

def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
):
self._read_stream.put(None)
self._write_stream.put(None)

@@ -277,7 +277,7 @@ class BaseSession(
self,
notification: SendNotificationT,
related_request_id: RequestId | None = None,
) -> None:
):
"""
Emits a notification, which is a one-way message that does not expect
a response.
@@ -296,7 +296,7 @@ class BaseSession(
)
self._write_stream.put(session_message)

def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -310,7 +310,7 @@ class BaseSession(
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message)

def _receive_loop(self) -> None:
def _receive_loop(self):
"""
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
@@ -382,7 +382,7 @@ class BaseSession(
logger.exception("Error in message processing loop")
raise

def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
@@ -391,15 +391,13 @@ class BaseSession(
forwarded on to the message stream.
"""

def _received_notification(self, notification: ReceiveNotificationT) -> None:
def _received_notification(self, notification: ReceiveNotificationT):
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""

def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""
Sends a progress notification for a request that is currently being
processed.
@@ -408,5 +406,5 @@ class BaseSession(
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
):
"""A generic handler for incoming messages. Overwritten by subclasses."""

+ 10
- 12
api/core/mcp/session/client_session.py View File

@@ -28,19 +28,19 @@ class LoggingFnT(Protocol):
def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ...
): ...


class MessageHandlerFnT(Protocol):
def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ...
): ...


def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
):
if isinstance(message, Exception):
raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)):
@@ -68,7 +68,7 @@ def _default_list_roots_callback(

def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
):
pass


@@ -94,7 +94,7 @@ class ClientSession(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
) -> None:
):
super().__init__(
read_stream,
write_stream,
@@ -155,9 +155,7 @@ class ClientSession(
types.EmptyResult,
)

def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""Send a progress notification."""
self.send_notification(
types.ClientNotification(
@@ -314,7 +312,7 @@ class ClientSession(
types.ListToolsResult,
)

def send_roots_list_changed(self) -> None:
def send_roots_list_changed(self):
"""Send a roots/list_changed notification."""
self.send_notification(
types.ClientNotification(
@@ -324,7 +322,7 @@ class ClientSession(
)
)

def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
@@ -352,11 +350,11 @@ class ClientSession(
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
):
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)

def _received_notification(self, notification: types.ServerNotification) -> None:
def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:

+ 2
- 1
api/core/memory/token_buffer_memory.py View File

@@ -27,7 +27,7 @@ class TokenBufferMemory:
self,
conversation: Conversation,
model_instance: ModelInstance,
) -> None:
):
self.conversation = conversation
self.model_instance = model_instance

@@ -124,6 +124,7 @@ class TokenBufferMemory:

messages = list(reversed(thread_messages))

curr_message_tokens = 0
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files

+ 0
- 0
api/core/model_manager.py View File


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save