| @@ -47,6 +47,10 @@ jobs: | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: dev/basedpyright-check | |||
| - name: Run Mypy Type Checks | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . | |||
| - name: Dotenv check | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example | |||
| @@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info | |||
| !.vscode/launch.json.template | |||
| !.vscode/README.md | |||
| api/.vscode | |||
| web/.vscode | |||
| # vscode Code History Extension | |||
| .history | |||
| @@ -215,6 +216,13 @@ mise.toml | |||
| # Next.js build output | |||
| .next/ | |||
| # PWA generated files | |||
| web/public/sw.js | |||
| web/public/sw.js.map | |||
| web/public/workbox-*.js | |||
| web/public/workbox-*.js.map | |||
| web/public/fallback-*.js | |||
| # AI Assistant | |||
| .roo/ | |||
| api/.env.backup | |||
| @@ -25,6 +25,9 @@ def create_flask_app_with_configs() -> DifyApp: | |||
| # add an unique identifier to each request | |||
| RecyclableContextVar.increment_thread_recycles() | |||
| # Capture the decorator's return value to avoid pyright reportUnusedFunction | |||
| _ = before_request | |||
| return dify_app | |||
| @@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| from core.helper import encrypter | |||
| from core.plugin.entities.plugin import PluginInstallationSource | |||
| from core.plugin.impl.plugin import PluginInstaller | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| @@ -1494,7 +1493,7 @@ def transform_datasource_credentials(): | |||
| for credential in credentials: | |||
| auth_count += 1 | |||
| # get credential api key | |||
| credentials_json =json.loads(credential.credentials) | |||
| credentials_json = json.loads(credential.credentials) | |||
| api_key = credentials_json.get("config", {}).get("api_key") | |||
| base_url = credentials_json.get("config", {}).get("base_url") | |||
| new_credentials = { | |||
| @@ -300,8 +300,7 @@ class DatasetQueueMonitorConfig(BaseSettings): | |||
| class MiddlewareConfig( | |||
| # place the configs in alphabet order | |||
| CeleryConfig, | |||
| DatabaseConfig, | |||
| CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig | |||
| KeywordStoreConfig, | |||
| RedisConfig, | |||
| # configs of storage and storage providers | |||
| @@ -1,9 +1,10 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import Field | |||
| from pydantic_settings import BaseSettings | |||
| class ClickzettaConfig(BaseModel): | |||
| class ClickzettaConfig(BaseSettings): | |||
| """ | |||
| Clickzetta Lakehouse vector database configuration | |||
| """ | |||
| @@ -1,7 +1,8 @@ | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import Field | |||
| from pydantic_settings import BaseSettings | |||
| class MatrixoneConfig(BaseModel): | |||
| class MatrixoneConfig(BaseSettings): | |||
| """Matrixone vector database configuration.""" | |||
| MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") | |||
| @@ -1,6 +1,6 @@ | |||
| from pydantic import Field | |||
| from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig | |||
| from configs.packaging.pyproject import PyProjectTomlConfig | |||
| class PackagingInfo(PyProjectTomlConfig): | |||
| @@ -4,8 +4,9 @@ import logging | |||
| import os | |||
| import threading | |||
| import time | |||
| from collections.abc import Mapping | |||
| from collections.abc import Callable, Mapping | |||
| from pathlib import Path | |||
| from typing import Any | |||
| from .python_3x import http_request, makedirs_wrapper | |||
| from .utils import ( | |||
| @@ -25,13 +26,13 @@ logger = logging.getLogger(__name__) | |||
| class ApolloClient: | |||
| def __init__( | |||
| self, | |||
| config_url, | |||
| app_id, | |||
| cluster="default", | |||
| secret="", | |||
| start_hot_update=True, | |||
| change_listener=None, | |||
| _notification_map=None, | |||
| config_url: str, | |||
| app_id: str, | |||
| cluster: str = "default", | |||
| secret: str = "", | |||
| start_hot_update: bool = True, | |||
| change_listener: Callable[[str, str, str, Any], None] | None = None, | |||
| _notification_map: dict[str, int] | None = None, | |||
| ): | |||
| # Core routing parameters | |||
| self.config_url = config_url | |||
| @@ -47,17 +48,17 @@ class ApolloClient: | |||
| # Private control variables | |||
| self._cycle_time = 5 | |||
| self._stopping = False | |||
| self._cache = {} | |||
| self._no_key = {} | |||
| self._hash = {} | |||
| self._cache: dict[str, dict[str, Any]] = {} | |||
| self._no_key: dict[str, str] = {} | |||
| self._hash: dict[str, str] = {} | |||
| self._pull_timeout = 75 | |||
| self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" | |||
| self._long_poll_thread = None | |||
| self._long_poll_thread: threading.Thread | None = None | |||
| self._change_listener = change_listener # "add" "delete" "update" | |||
| if _notification_map is None: | |||
| _notification_map = {"application": -1} | |||
| self._notification_map = _notification_map | |||
| self.last_release_key = None | |||
| self.last_release_key: str | None = None | |||
| # Private startup method | |||
| self._path_checker() | |||
| if start_hot_update: | |||
| @@ -68,7 +69,7 @@ class ApolloClient: | |||
| heartbeat.daemon = True | |||
| heartbeat.start() | |||
| def get_json_from_net(self, namespace="application"): | |||
| def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None: | |||
| url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( | |||
| self.config_url, self.app_id, self.cluster, namespace, "", self.ip | |||
| ) | |||
| @@ -88,7 +89,7 @@ class ApolloClient: | |||
| logger.exception("an error occurred in get_json_from_net") | |||
| return None | |||
| def get_value(self, key, default_val=None, namespace="application"): | |||
| def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any: | |||
| try: | |||
| # read memory configuration | |||
| namespace_cache = self._cache.get(namespace) | |||
| @@ -104,7 +105,8 @@ class ApolloClient: | |||
| namespace_data = self.get_json_from_net(namespace) | |||
| val = get_value_from_dict(namespace_data, key) | |||
| if val is not None: | |||
| self._update_cache_and_file(namespace_data, namespace) | |||
| if namespace_data is not None: | |||
| self._update_cache_and_file(namespace_data, namespace) | |||
| return val | |||
| # read the file configuration | |||
| @@ -126,23 +128,23 @@ class ApolloClient: | |||
| # to ensure the real-time correctness of the function call. | |||
| # If the user does not have the same default val twice | |||
| # and the default val is used here, there may be a problem. | |||
| def _set_local_cache_none(self, namespace, key): | |||
| def _set_local_cache_none(self, namespace: str, key: str) -> None: | |||
| no_key = no_key_cache_key(namespace, key) | |||
| self._no_key[no_key] = key | |||
| def _start_hot_update(self): | |||
| def _start_hot_update(self) -> None: | |||
| self._long_poll_thread = threading.Thread(target=self._listener) | |||
| # When the asynchronous thread is started, the daemon thread will automatically exit | |||
| # when the main thread is launched. | |||
| self._long_poll_thread.daemon = True | |||
| self._long_poll_thread.start() | |||
| def stop(self): | |||
| def stop(self) -> None: | |||
| self._stopping = True | |||
| logger.info("Stopping listener...") | |||
| # Call the set callback function, and if it is abnormal, try it out | |||
| def _call_listener(self, namespace, old_kv, new_kv): | |||
| def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None: | |||
| if self._change_listener is None: | |||
| return | |||
| if old_kv is None: | |||
| @@ -168,12 +170,12 @@ class ApolloClient: | |||
| except BaseException as e: | |||
| logger.warning(str(e)) | |||
| def _path_checker(self): | |||
| def _path_checker(self) -> None: | |||
| if not os.path.isdir(self._cache_file_path): | |||
| makedirs_wrapper(self._cache_file_path) | |||
| # update the local cache and file cache | |||
| def _update_cache_and_file(self, namespace_data, namespace="application"): | |||
| def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None: | |||
| # update the local cache | |||
| self._cache[namespace] = namespace_data | |||
| # update the file cache | |||
| @@ -187,7 +189,7 @@ class ApolloClient: | |||
| self._hash[namespace] = new_hash | |||
| # get the configuration from the local file | |||
| def _get_local_cache(self, namespace="application"): | |||
| def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]: | |||
| cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") | |||
| if os.path.isfile(cache_file_path): | |||
| with open(cache_file_path) as f: | |||
| @@ -195,8 +197,8 @@ class ApolloClient: | |||
| return result | |||
| return {} | |||
| def _long_poll(self): | |||
| notifications = [] | |||
| def _long_poll(self) -> None: | |||
| notifications: list[dict[str, Any]] = [] | |||
| for key in self._cache: | |||
| namespace_data = self._cache[key] | |||
| notification_id = -1 | |||
| @@ -236,7 +238,7 @@ class ApolloClient: | |||
| except Exception as e: | |||
| logger.warning(str(e)) | |||
| def _get_net_and_set_local(self, namespace, n_id, call_change=False): | |||
| def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None: | |||
| namespace_data = self.get_json_from_net(namespace) | |||
| if not namespace_data: | |||
| return | |||
| @@ -248,7 +250,7 @@ class ApolloClient: | |||
| new_kv = namespace_data.get(CONFIGURATIONS) | |||
| self._call_listener(namespace, old_kv, new_kv) | |||
| def _listener(self): | |||
| def _listener(self) -> None: | |||
| logger.info("start long_poll") | |||
| while not self._stopping: | |||
| self._long_poll() | |||
| @@ -266,13 +268,13 @@ class ApolloClient: | |||
| headers["Timestamp"] = time_unix_now | |||
| return headers | |||
| def _heart_beat(self): | |||
| def _heart_beat(self) -> None: | |||
| while not self._stopping: | |||
| for namespace in self._notification_map: | |||
| self._do_heart_beat(namespace) | |||
| time.sleep(60 * 10) # 10 minutes | |||
| def _do_heart_beat(self, namespace): | |||
| def _do_heart_beat(self, namespace: str) -> None: | |||
| url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" | |||
| try: | |||
| code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) | |||
| @@ -292,7 +294,7 @@ class ApolloClient: | |||
| logger.exception("an error occurred in _do_heart_beat") | |||
| return None | |||
| def get_all_dicts(self, namespace): | |||
| def get_all_dicts(self, namespace: str) -> dict[str, Any] | None: | |||
| namespace_data = self._cache.get(namespace) | |||
| if namespace_data is None: | |||
| net_namespace_data = self.get_json_from_net(namespace) | |||
| @@ -2,6 +2,8 @@ import logging | |||
| import os | |||
| import ssl | |||
| import urllib.request | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from urllib import parse | |||
| from urllib.error import HTTPError | |||
| @@ -19,9 +21,9 @@ urllib.request.install_opener(opener) | |||
| logger = logging.getLogger(__name__) | |||
| def http_request(url, timeout, headers={}): | |||
| def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]: | |||
| try: | |||
| request = urllib.request.Request(url, headers=headers) | |||
| request = urllib.request.Request(url, headers=dict(headers)) | |||
| res = urllib.request.urlopen(request, timeout=timeout) | |||
| body = res.read().decode("utf-8") | |||
| return res.code, body | |||
| @@ -33,9 +35,9 @@ def http_request(url, timeout, headers={}): | |||
| raise e | |||
| def url_encode(params): | |||
| def url_encode(params: dict[str, Any]) -> str: | |||
| return parse.urlencode(params) | |||
| def makedirs_wrapper(path): | |||
| def makedirs_wrapper(path: str) -> None: | |||
| os.makedirs(path, exist_ok=True) | |||
| @@ -1,5 +1,6 @@ | |||
| import hashlib | |||
| import socket | |||
| from typing import Any | |||
| from .python_3x import url_encode | |||
| @@ -10,7 +11,7 @@ NAMESPACE_NAME = "namespaceName" | |||
| # add timestamps uris and keys | |||
| def signature(timestamp, uri, secret): | |||
| def signature(timestamp: str, uri: str, secret: str) -> str: | |||
| import base64 | |||
| import hmac | |||
| @@ -19,16 +20,16 @@ def signature(timestamp, uri, secret): | |||
| return base64.b64encode(hmac_code).decode() | |||
| def url_encode_wrapper(params): | |||
| def url_encode_wrapper(params: dict[str, Any]) -> str: | |||
| return url_encode(params) | |||
| def no_key_cache_key(namespace, key): | |||
| def no_key_cache_key(namespace: str, key: str) -> str: | |||
| return f"{namespace}{len(namespace)}{key}" | |||
| # Returns whether the obtained value is obtained, and None if it does not | |||
| def get_value_from_dict(namespace_cache, key): | |||
| def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None: | |||
| if namespace_cache: | |||
| kv_data = namespace_cache.get(CONFIGURATIONS) | |||
| if kv_data is None: | |||
| @@ -38,7 +39,7 @@ def get_value_from_dict(namespace_cache, key): | |||
| return None | |||
| def init_ip(): | |||
| def init_ip() -> str: | |||
| ip = "" | |||
| s = None | |||
| try: | |||
| @@ -11,5 +11,5 @@ class RemoteSettingsSource: | |||
| def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: | |||
| raise NotImplementedError | |||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: | |||
| def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool): | |||
| return value | |||
| @@ -11,16 +11,16 @@ logger = logging.getLogger(__name__) | |||
| from configs.remote_settings_sources.base import RemoteSettingsSource | |||
| from .utils import _parse_config | |||
| from .utils import parse_config | |||
| class NacosSettingsSource(RemoteSettingsSource): | |||
| def __init__(self, configs: Mapping[str, Any]): | |||
| self.configs = configs | |||
| self.remote_configs: dict[str, Any] = {} | |||
| self.remote_configs: dict[str, str] = {} | |||
| self.async_init() | |||
| def async_init(self): | |||
| def async_init(self) -> None: | |||
| data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") | |||
| group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") | |||
| tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") | |||
| @@ -33,18 +33,15 @@ class NacosSettingsSource(RemoteSettingsSource): | |||
| logger.exception("[get-access-token] exception occurred") | |||
| raise | |||
| def _parse_config(self, content: str) -> dict: | |||
| def _parse_config(self, content: str) -> dict[str, str]: | |||
| if not content: | |||
| return {} | |||
| try: | |||
| return _parse_config(self, content) | |||
| return parse_config(content) | |||
| except Exception as e: | |||
| raise RuntimeError(f"Failed to parse config: {e}") | |||
| def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: | |||
| if not isinstance(self.remote_configs, dict): | |||
| raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") | |||
| field_value = self.remote_configs.get(field_name) | |||
| if field_value is None: | |||
| return None, field_name, False | |||
| @@ -17,11 +17,17 @@ class NacosHttpClient: | |||
| self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY") | |||
| self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY") | |||
| self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848") | |||
| self.token = None | |||
| self.token: str | None = None | |||
| self.token_ttl = 18000 | |||
| self.token_expire_time: float = 0 | |||
| def http_request(self, url, method="GET", headers=None, params=None): | |||
| def http_request( | |||
| self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None | |||
| ) -> str: | |||
| if headers is None: | |||
| headers = {} | |||
| if params is None: | |||
| params = {} | |||
| try: | |||
| self._inject_auth_info(headers, params) | |||
| response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) | |||
| @@ -30,7 +36,7 @@ class NacosHttpClient: | |||
| except requests.RequestException as e: | |||
| return f"Request to Nacos failed: {e}" | |||
| def _inject_auth_info(self, headers, params, module="config"): | |||
| def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: | |||
| headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) | |||
| if module == "login": | |||
| @@ -45,16 +51,17 @@ class NacosHttpClient: | |||
| headers["timeStamp"] = ts | |||
| if self.username and self.password: | |||
| self.get_access_token(force_refresh=False) | |||
| params["accessToken"] = self.token | |||
| if self.token is not None: | |||
| params["accessToken"] = self.token | |||
| def __do_sign(self, sign_str, sk): | |||
| def __do_sign(self, sign_str: str, sk: str) -> str: | |||
| return ( | |||
| base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest()) | |||
| .decode() | |||
| .strip() | |||
| ) | |||
| def get_sign_str(self, group, tenant, ts): | |||
| def get_sign_str(self, group: str, tenant: str, ts: str) -> str: | |||
| sign_str = "" | |||
| if tenant: | |||
| sign_str = tenant + "+" | |||
| @@ -63,7 +70,7 @@ class NacosHttpClient: | |||
| sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. | |||
| return sign_str | |||
| def get_access_token(self, force_refresh=False): | |||
| def get_access_token(self, force_refresh: bool = False) -> str | None: | |||
| current_time = time.time() | |||
| if self.token and not force_refresh and self.token_expire_time > current_time: | |||
| return self.token | |||
| @@ -77,6 +84,7 @@ class NacosHttpClient: | |||
| self.token = response_data.get("accessToken") | |||
| self.token_ttl = response_data.get("tokenTtl", 18000) | |||
| self.token_expire_time = current_time + self.token_ttl - 10 | |||
| return self.token | |||
| except Exception: | |||
| logger.exception("[get-access-token] exception occur") | |||
| raise | |||
| @@ -1,4 +1,4 @@ | |||
| def _parse_config(self, content: str) -> dict[str, str]: | |||
| def parse_config(content: str) -> dict[str, str]: | |||
| config: dict[str, str] = {} | |||
| if not content: | |||
| return config | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource, reqparse | |||
| @@ -6,6 +8,8 @@ from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| @@ -14,9 +18,9 @@ from extensions.ext_database import db | |||
| from models.model import App, InstalledApp, RecommendedApp | |||
| def admin_required(view): | |||
| def admin_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ADMIN_API_KEY: | |||
| raise Unauthorized("API key is invalid.") | |||
| @@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource): | |||
| custom="max_keys_exceeded", | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| key = ApiToken.generate_api_key(self.token_prefix or "", 24) | |||
| api_token = ApiToken() | |||
| setattr(api_token, self.resource_id_field, resource_id) | |||
| api_token.tenant_id = current_user.current_tenant_id | |||
| @@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self) -> dict: | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("type", type=str, required=True, default=False, location="json") | |||
| args = parser.parse_args() | |||
| @@ -1,5 +1,5 @@ | |||
| import logging | |||
| from typing import Any, NoReturn | |||
| from typing import NoReturn | |||
| from flask import Response | |||
| from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| @@ -31,7 +31,7 @@ from services.workflow_service import WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| def _convert_values_to_json_serializable_object(value: Segment): | |||
| if isinstance(value, FileSegment): | |||
| return value.value.model_dump() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| @@ -42,8 +42,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| return value.value | |||
| def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: | |||
| """Serialize variable value. If variable is truncated, return the truncated value.""" | |||
| def _serialize_var_value(variable: WorkflowDraftVariable): | |||
| value = variable.get_value() | |||
| # create a copy of the value to avoid affecting the model cache. | |||
| value = value.model_copy(deep=True) | |||
| @@ -1,5 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import cast | |||
| from typing import Concatenate, ParamSpec, TypeVar, cast | |||
| import flask_login | |||
| from flask import jsonify, request | |||
| @@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, | |||
| from .. import api | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def oauth_server_client_id_required(view): | |||
| def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(self: T, *args: P.args, **kwargs: P.kwargs): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("client_id", type=str, required=True, location="json") | |||
| parsed_args = parser.parse_args() | |||
| @@ -30,18 +35,15 @@ def oauth_server_client_id_required(view): | |||
| if not oauth_provider_app: | |||
| raise NotFound("client_id is invalid") | |||
| kwargs["oauth_provider_app"] = oauth_provider_app | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, *args, **kwargs) | |||
| return decorated | |||
| def oauth_server_access_token_required(view): | |||
| def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| oauth_provider_app = kwargs.get("oauth_provider_app") | |||
| if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): | |||
| if not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| raise BadRequest("Invalid oauth_provider_app") | |||
| authorization_header = request.headers.get("Authorization") | |||
| @@ -79,9 +81,7 @@ def oauth_server_access_token_required(view): | |||
| response.headers["WWW-Authenticate"] = "Bearer" | |||
| return response | |||
| kwargs["account"] = account | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, account, *args, **kwargs) | |||
| return decorated | |||
| @@ -1,9 +1,9 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from models.model import Account | |||
| from services.billing_service import BillingService | |||
| @@ -17,9 +17,10 @@ class Subscription(Resource): | |||
| parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | |||
| parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | |||
| args = parser.parse_args() | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_subscription( | |||
| args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | |||
| ) | |||
| @@ -31,7 +32,9 @@ class Invoices(Resource): | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | |||
| @@ -477,6 +477,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| data_source_info = document.data_source_info_dict | |||
| if document.data_source_type == "upload_file": | |||
| if not data_source_info: | |||
| continue | |||
| file_id = data_source_info["upload_file_id"] | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| @@ -493,6 +495,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "notion_import": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| @@ -506,6 +510,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "website_crawl": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| @@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource): | |||
| def get(self, installed_app: InstalledApp): | |||
| """Get app meta""" | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise ValueError("App not found") | |||
| return AppService().get_app_meta(app_model) | |||
| @@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| Run workflow | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| Stop workflow task | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Concatenate, Optional, ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from flask_restx import Resource | |||
| @@ -13,19 +15,15 @@ from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def installed_app_required(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not kwargs.get("installed_app_id"): | |||
| raise ValueError("missing installed_app_id in path parameters") | |||
| installed_app_id = kwargs.get("installed_app_id") | |||
| installed_app_id = str(installed_app_id) | |||
| del kwargs["installed_app_id"] | |||
| def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .where( | |||
| @@ -52,10 +50,10 @@ def installed_app_required(view=None): | |||
| return decorator | |||
| def user_allowed_to_access_app(view=None): | |||
| def decorator(view): | |||
| def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app: InstalledApp, *args, **kwargs): | |||
| def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): | |||
| feature = FeatureService.get_system_features() | |||
| if feature.webapp_auth.enabled: | |||
| app_id = installed_app.app_id | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from sqlalchemy.orm import Session | |||
| @@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden | |||
| from extensions.ext_database import db | |||
| from models.account import TenantPluginPermission | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def plugin_permission_required( | |||
| install_required: bool = False, | |||
| debug_required: bool = False, | |||
| ): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| @@ -2,7 +2,9 @@ import contextlib | |||
| import json | |||
| import os | |||
| import time | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import abort, request | |||
| from flask_login import current_user | |||
| @@ -19,10 +21,13 @@ from services.operation_service import OperationService | |||
| from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def account_initialization_required(view): | |||
| def account_initialization_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check account initialization | |||
| account = current_user | |||
| @@ -34,9 +39,9 @@ def account_initialization_required(view): | |||
| return decorated | |||
| def only_edition_cloud(view): | |||
| def only_edition_cloud(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "CLOUD": | |||
| abort(404) | |||
| @@ -45,9 +50,9 @@ def only_edition_cloud(view): | |||
| return decorated | |||
| def only_edition_enterprise(view): | |||
| def only_edition_enterprise(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ENTERPRISE_ENABLED: | |||
| abort(404) | |||
| @@ -56,9 +61,9 @@ def only_edition_enterprise(view): | |||
| return decorated | |||
| def only_edition_self_hosted(view): | |||
| def only_edition_self_hosted(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| abort(404) | |||
| @@ -67,9 +72,9 @@ def only_edition_self_hosted(view): | |||
| return decorated | |||
| def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled: | |||
| abort(403, "Billing feature is not enabled.") | |||
| @@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_resource_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| members = features.members | |||
| @@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| if resource == "add_segment": | |||
| @@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def cloud_edition_billing_rate_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if resource == "knowledge": | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| @@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str): | |||
| return interceptor | |||
| def cloud_utm_record(view): | |||
| def cloud_utm_record(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| with contextlib.suppress(Exception): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| @@ -194,9 +199,9 @@ def cloud_utm_record(view): | |||
| return decorated | |||
| def setup_required(view): | |||
| def setup_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check setup | |||
| if ( | |||
| dify_config.EDITION == "SELF_HOSTED" | |||
| @@ -212,9 +217,9 @@ def setup_required(view): | |||
| return decorated | |||
| def enterprise_license_required(view): | |||
| def enterprise_license_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| settings = FeatureService.get_system_features() | |||
| if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: | |||
| raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") | |||
| @@ -224,9 +229,9 @@ def enterprise_license_required(view): | |||
| return decorated | |||
| def email_password_login_enabled(view): | |||
| def email_password_login_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_email_password_login: | |||
| return view(*args, **kwargs) | |||
| @@ -237,9 +242,9 @@ def email_password_login_enabled(view): | |||
| return decorated | |||
| def enable_change_email(view): | |||
| def enable_change_email(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_change_email: | |||
| return view(*args, **kwargs) | |||
| @@ -250,9 +255,9 @@ def enable_change_email(view): | |||
| return decorated | |||
| def is_allow_transfer_owner(view): | |||
| def is_allow_transfer_owner(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.is_allow_transfer_workspace: | |||
| return view(*args, **kwargs) | |||
| @@ -99,7 +99,7 @@ class MCPAppApi(Resource): | |||
| return mcp_server, app | |||
| def _validate_server_status(self, mcp_server: AppMCPServer) -> None: | |||
| def _validate_server_status(self, mcp_server: AppMCPServer): | |||
| """Validate MCP server status""" | |||
| if mcp_server.status != AppMCPServerStatus.ACTIVE: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") | |||
| @@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Segment not found.") | |||
| # validate segment belongs to the specified document | |||
| if segment.document_id != document_id: | |||
| if str(segment.document_id) != str(document_id): | |||
| raise NotFound("Document not found.") | |||
| # check child chunk | |||
| @@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Child chunk not found.") | |||
| # validate child chunk belongs to the specified segment | |||
| if child_chunk.segment_id != segment.id: | |||
| if str(child_chunk.segment_id) != str(segment.id): | |||
| raise NotFound("Child chunk not found.") | |||
| try: | |||
| @@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Segment not found.") | |||
| # validate segment belongs to the specified document | |||
| if segment.document_id != document_id: | |||
| if str(segment.document_id) != str(document_id): | |||
| raise NotFound("Segment not found.") | |||
| # get child chunk | |||
| @@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Child chunk not found.") | |||
| # validate child chunk belongs to the specified segment | |||
| if child_chunk.segment_id != segment.id: | |||
| if str(child_chunk.segment_id) != str(segment.id): | |||
| raise NotFound("Child chunk not found.") | |||
| # validate args | |||
| @@ -3,7 +3,7 @@ from collections.abc import Callable | |||
| from datetime import timedelta | |||
| from enum import StrEnum, auto | |||
| from functools import wraps | |||
| from typing import Optional | |||
| from typing import Optional, ParamSpec, TypeVar | |||
| from flask import current_app, request | |||
| from flask_login import user_logged_in | |||
| @@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog | |||
| from models.model import ApiToken, App, EndUser | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class WhereisUserArg(StrEnum): | |||
| """ | |||
| @@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| raise Forbidden("The workspace's status is archived.") | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .where(Tenant.id == api_token.tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.role.in_(["owner"])) | |||
| .where(Tenant.status == TenantStatus.NORMAL) | |||
| .one_or_none() | |||
| ) # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = db.session.query(Account).where(Account.id == ta.account_id).first() | |||
| # Login admin | |||
| if account: | |||
| account.current_tenant = tenant | |||
| current_app.login_manager._update_request_context_with_user(account) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore | |||
| else: | |||
| raise Unauthorized("Tenant owner account does not exist.") | |||
| else: | |||
| raise Unauthorized("Tenant does not exist.") | |||
| kwargs["app_model"] = app_model | |||
| if fetch_user_arg: | |||
| @@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def decorated(*args, **kwargs): | |||
| def interceptor(view: Callable[P, R]): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| @@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s | |||
| def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| if resource == "knowledge": | |||
| @@ -1,5 +1,6 @@ | |||
| from datetime import UTC, datetime | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource | |||
| @@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def validate_jwt_token(view=None): | |||
| def decorator(view): | |||
| @@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner): | |||
| model_instance: ModelInstance, | |||
| memory: Optional[TokenBufferMemory] = None, | |||
| prompt_messages: Optional[list[PromptMessage]] = None, | |||
| ) -> None: | |||
| ): | |||
| self.tenant_id = tenant_id | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| @@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| return instruction | |||
| def _init_react_state(self, query) -> None: | |||
| def _init_react_state(self, query): | |||
| """ | |||
| init agent scratchpad | |||
| """ | |||
| @@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel): | |||
| action_name: str | |||
| action_input: Union[dict, str] | |||
| def to_dict(self) -> dict: | |||
| def to_dict(self): | |||
| """ | |||
| Convert to dictionary. | |||
| """ | |||
| @@ -158,7 +158,7 @@ class DatasetConfigManager: | |||
| return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] | |||
| @classmethod | |||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: | |||
| def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict): | |||
| """ | |||
| Extract dataset config for legacy compatibility | |||
| @@ -105,7 +105,7 @@ class ModelConfigManager: | |||
| return dict(config), ["model"] | |||
| @classmethod | |||
| def validate_model_completion_params(cls, cp: dict) -> dict: | |||
| def validate_model_completion_params(cls, cp: dict): | |||
| # model.completion_params | |||
| if not isinstance(cp, dict): | |||
| raise ValueError("model.completion_params must be of object type") | |||
| @@ -122,7 +122,7 @@ class PromptTemplateConfigManager: | |||
| return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] | |||
| @classmethod | |||
| def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: | |||
| def validate_post_prompt_and_set_defaults(cls, config: dict): | |||
| """ | |||
| Validate post_prompt and set defaults for prompt feature | |||
| @@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||
| """ | |||
| Validate for advanced chat app model config | |||
| @@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id: str, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -55,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| app: App, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| @@ -69,7 +69,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| self.system_user_id = system_user_id | |||
| self._app = app | |||
| def run(self) -> None: | |||
| def run(self): | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| @@ -184,6 +184,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| graph_runtime_state=graph_runtime_state, | |||
| command_channel=command_channel, | |||
| ) | |||
| @@ -238,7 +239,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| return False | |||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: | |||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy): | |||
| """ | |||
| Direct output | |||
| """ | |||
| @@ -96,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| draft_var_saver_factory: DraftVariableSaverFactory, | |||
| ) -> None: | |||
| ): | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -284,7 +284,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| session.rollback() | |||
| raise | |||
| def _ensure_workflow_initialized(self) -> None: | |||
| def _ensure_workflow_initialized(self): | |||
| """Fluent validation for workflow state.""" | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| @@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): | |||
| message = self._get_message(session=session) | |||
| # If there are assistant files, remove markdown image links from answer | |||
| @@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): | |||
| """ | |||
| Validate for agent chat app model config | |||
| @@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner): | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run assistant application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = ChatbotAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC): | |||
| return metadata | |||
| @classmethod | |||
| def _error_to_stream_response(cls, e: Exception) -> dict: | |||
| def _error_to_stream_response(cls, e: Exception): | |||
| """ | |||
| Error to stream response. | |||
| :param e: exception | |||
| @@ -158,7 +158,7 @@ class BaseAppGenerator: | |||
| return value | |||
| def _sanitize_value(self, value: Any) -> Any: | |||
| def _sanitize_value(self, value: Any): | |||
| if isinstance(value, str): | |||
| return value.replace("\x00", "") | |||
| return value | |||
| @@ -25,7 +25,7 @@ class PublishFrom(IntEnum): | |||
| class AppQueueManager: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom): | |||
| if not user_id: | |||
| raise ValueError("user is required") | |||
| @@ -73,14 +73,14 @@ class AppQueueManager: | |||
| self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) | |||
| last_ping_time = elapsed_time // 10 | |||
| def stop_listen(self) -> None: | |||
| def stop_listen(self): | |||
| """ | |||
| Stop listen to queue | |||
| :return: | |||
| """ | |||
| self._q.put(None) | |||
| def publish_error(self, e, pub_from: PublishFrom) -> None: | |||
| def publish_error(self, e, pub_from: PublishFrom): | |||
| """ | |||
| Publish error | |||
| :param e: error | |||
| @@ -89,7 +89,7 @@ class AppQueueManager: | |||
| """ | |||
| self.publish(QueueErrorEvent(error=e), pub_from) | |||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -100,7 +100,7 @@ class AppQueueManager: | |||
| self._publish(event, pub_from) | |||
| @abstractmethod | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -110,7 +110,7 @@ class AppQueueManager: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: | |||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str): | |||
| """ | |||
| Set task stop flag | |||
| :return: | |||
| @@ -162,7 +162,7 @@ class AppRunner: | |||
| text: str, | |||
| stream: bool, | |||
| usage: Optional[LLMUsage] = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Direct output | |||
| :param queue_manager: application queue manager | |||
| @@ -204,7 +204,7 @@ class AppRunner: | |||
| queue_manager: AppQueueManager, | |||
| stream: bool, | |||
| agent: bool = False, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| @@ -220,9 +220,7 @@ class AppRunner: | |||
| else: | |||
| raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") | |||
| def _handle_invoke_result_direct( | |||
| self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool | |||
| ) -> None: | |||
| def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): | |||
| """ | |||
| Handle invoke result direct | |||
| :param invoke_result: invoke result | |||
| @@ -239,7 +237,7 @@ class AppRunner: | |||
| def _handle_invoke_result_stream( | |||
| self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| @@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate for chat app model config | |||
| @@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner): | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = ChatbotAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -56,7 +56,7 @@ class WorkflowResponseConverter: | |||
| *, | |||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||
| user: Union[Account, EndUser], | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self._user = user | |||
| self._truncator = VariableTruncator.default() | |||
| @@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate for completion app model config | |||
| @@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| application_generate_entity: CompletionAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| message_id: str, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| @@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| raise MessageNotExistsError() | |||
| current_app_model_config = app_model.app_model_config | |||
| if not current_app_model_config: | |||
| raise MoreLikeThisDisabledError() | |||
| more_like_this = current_app_model_config.more_like_this_dict | |||
| if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: | |||
| @@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner): | |||
| def run( | |||
| self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = CompletionAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -14,14 +14,14 @@ from core.app.entities.queue_entities import ( | |||
| class MessageBasedAppQueueManager(AppQueueManager): | |||
| def __init__( | |||
| self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str | |||
| ) -> None: | |||
| ): | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._conversation_id = str(conversation_id) | |||
| self._app_mode = app_mode | |||
| self._message_id = str(message_id) | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): | |||
| return app_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): | |||
| """ | |||
| Validate for workflow app model config | |||
| @@ -14,12 +14,12 @@ from core.app.entities.queue_entities import ( | |||
| class WorkflowAppQueueManager(AppQueueManager): | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str): | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._app_mode = app_mode | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| @@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| variable_loader: VariableLoader, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| @@ -44,7 +44,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| self._workflow = workflow | |||
| self._sys_user_id = system_user_id | |||
| def run(self) -> None: | |||
| def run(self): | |||
| """ | |||
| Run application | |||
| """ | |||
| @@ -127,6 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| graph_runtime_state=graph_runtime_state, | |||
| command_channel=command_channel, | |||
| ) | |||
| @@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = WorkflowAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| @@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return dict(blocking_response.to_dict()) | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| @@ -88,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| draft_var_saver_factory: DraftVariableSaverFactory, | |||
| ) -> None: | |||
| ): | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -259,7 +259,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| session.rollback() | |||
| raise | |||
| def _ensure_workflow_initialized(self) -> None: | |||
| def _ensure_workflow_initialized(self): | |||
| """Fluent validation for workflow state.""" | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| @@ -697,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution): | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | |||
| @@ -67,7 +67,7 @@ class WorkflowBasedAppRunner: | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| app_id: str, | |||
| ) -> None: | |||
| ): | |||
| self._queue_manager = queue_manager | |||
| self._variable_loader = variable_loader | |||
| self._app_id = app_id | |||
| @@ -348,7 +348,7 @@ class WorkflowBasedAppRunner: | |||
| return graph, variable_pool | |||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: | |||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): | |||
| """ | |||
| Handle event | |||
| :param workflow_entry: workflow entry | |||
| @@ -580,5 +580,5 @@ class WorkflowBasedAppRunner: | |||
| ) | |||
| ) | |||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||
| def _publish_event(self, event: AppQueueEvent): | |||
| self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline: | |||
| application_generate_entity: AppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| stream: bool, | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self._start_at = time.perf_counter() | |||
| @@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| conversation: Conversation, | |||
| message: Message, | |||
| stream: bool, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: | |||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): | |||
| """ | |||
| Save message. | |||
| :return: | |||
| @@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| application_generate_entity=self._application_generate_entity, | |||
| ) | |||
| def _handle_stop(self, event: QueueStopEvent) -> None: | |||
| def _handle_stop(self, event: QueueStopEvent): | |||
| """ | |||
| Handle stop. | |||
| :return: | |||
| @@ -48,7 +48,7 @@ class MessageCycleManager: | |||
| AdvancedChatAppGenerateEntity, | |||
| ], | |||
| task_state: Union[EasyUITaskState, WorkflowTaskState], | |||
| ) -> None: | |||
| ): | |||
| self._application_generate_entity = application_generate_entity | |||
| self._task_state = task_state | |||
| @@ -132,7 +132,7 @@ class MessageCycleManager: | |||
| return None | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent): | |||
| """ | |||
| Handle retriever resources. | |||
| :param event: event | |||
| @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: | |||
| return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" | |||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: | |||
| def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): | |||
| """Print text with highlighting and no end characters.""" | |||
| text_to_print = get_colored_text(text, color) if color else text | |||
| print(text_to_print, end=end, file=file) | |||
| @@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| color: Optional[str] = "" | |||
| current_loop: int = 1 | |||
| def __init__(self, color: Optional[str] = None) -> None: | |||
| def __init__(self, color: Optional[str] = None): | |||
| super().__init__() | |||
| """Initialize callback handler.""" | |||
| # use a specific color is not specified | |||
| @@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Mapping[str, Any], | |||
| ) -> None: | |||
| ): | |||
| """Do nothing.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) | |||
| @@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| message_id: Optional[str] = None, | |||
| timer: Optional[Any] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> None: | |||
| ): | |||
| """If not the final action, print out observation.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_end]\n", color=self.color) | |||
| @@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| ) | |||
| ) | |||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: | |||
| def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): | |||
| """Do nothing.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") | |||
| def on_agent_start(self, thought: str) -> None: | |||
| def on_agent_start(self, thought: str): | |||
| """Run on agent start.""" | |||
| if dify_config.DEBUG: | |||
| if thought: | |||
| @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| else: | |||
| print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) | |||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: | |||
| def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): | |||
| """Run on agent end.""" | |||
| if dify_config.DEBUG: | |||
| print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) | |||
| @@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler: | |||
| def __init__( | |||
| self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom | |||
| ) -> None: | |||
| ): | |||
| self._queue_manager = queue_manager | |||
| self._app_id = app_id | |||
| self._message_id = message_id | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| def on_query(self, query: str, dataset_id: str) -> None: | |||
| def on_query(self, query: str, dataset_id: str): | |||
| """ | |||
| Handle query. | |||
| """ | |||
| @@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def on_tool_end(self, documents: list[Document]) -> None: | |||
| def on_tool_end(self, documents: list[Document]): | |||
| """Handle tool end.""" | |||
| for document in documents: | |||
| if document.metadata is not None: | |||
| @@ -30,8 +30,6 @@ class FakeDatasourceRuntime(DatasourceRuntime): | |||
| """ | |||
| def __init__(self): | |||
| super().__init__( | |||
| tenant_id="fake_tenant_id", | |||
| datasource_id="fake_datasource_id", | |||
| @@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel): | |||
| icon_large: Optional[I18nObject] = None | |||
| supported_model_types: list[ModelType] | |||
| def __init__(self, provider_entity: ProviderEntity) -> None: | |||
| def __init__(self, provider_entity: ProviderEntity): | |||
| """ | |||
| Init simple provider. | |||
| @@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel): | |||
| load_balancing_enabled: bool = False | |||
| has_invalid_load_balancing_configs: bool = False | |||
| def raise_for_status(self) -> None: | |||
| def raise_for_status(self): | |||
| """ | |||
| Check model status and raise ValueError if not active. | |||
| @@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel): | |||
| else [], | |||
| ) | |||
| def validate_provider_credentials( | |||
| self, credentials: dict, credential_id: str = "", session: Session | None = None | |||
| ) -> dict: | |||
| def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| @@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| def _validate(s: Session) -> dict: | |||
| def _validate(s: Session): | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| self.provider.provider_credential_schema.credential_form_schemas | |||
| @@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel): | |||
| logger.warning("Error generating next credential name: %s", str(e)) | |||
| return "API KEY 1" | |||
| def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: | |||
| def create_provider_credential(self, credentials: dict, credential_name: str | None): | |||
| """ | |||
| Add custom provider credentials. | |||
| :param credentials: provider credentials | |||
| @@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel): | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str | None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| update a saved provider credential (by credential_id). | |||
| @@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel): | |||
| credential_record: ProviderCredential | ProviderModelCredential, | |||
| credential_source: str, | |||
| session: Session, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Update load balancing configurations that reference the given credential_id. | |||
| @@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.commit() | |||
| def delete_provider_credential(self, credential_id: str) -> None: | |||
| def delete_provider_credential(self, credential_id: str): | |||
| """ | |||
| Delete a saved provider credential (by credential_id). | |||
| @@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def switch_active_provider_credential(self, credential_id: str) -> None: | |||
| def switch_active_provider_credential(self, credential_id: str): | |||
| """ | |||
| Switch active provider credential (copy the selected one into current active snapshot). | |||
| @@ -815,7 +813,7 @@ class ProviderConfiguration(BaseModel): | |||
| credentials: dict, | |||
| credential_id: str = "", | |||
| session: Session | None = None, | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Validate custom model credentials. | |||
| @@ -826,7 +824,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| def _validate(s: Session) -> dict: | |||
| def _validate(s: Session): | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| self.provider.model_credential_schema.credential_form_schemas | |||
| @@ -1010,7 +1008,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| Delete a saved provider credential (by credential_id). | |||
| @@ -1080,7 +1078,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.rollback() | |||
| raise | |||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| if model list exist this custom model, switch the custom model credential. | |||
| if model list not exist this custom model, use the credential to add a new custom model record. | |||
| @@ -1123,7 +1121,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.add(provider_model_record) | |||
| session.commit() | |||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: | |||
| def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): | |||
| """ | |||
| switch the custom model credential. | |||
| @@ -1153,7 +1151,7 @@ class ProviderConfiguration(BaseModel): | |||
| session.add(provider_model_record) | |||
| session.commit() | |||
| def delete_custom_model(self, model_type: ModelType, model: str) -> None: | |||
| def delete_custom_model(self, model_type: ModelType, model: str): | |||
| """ | |||
| Delete custom model. | |||
| :param model_type: model type | |||
| @@ -1350,7 +1348,7 @@ class ProviderConfiguration(BaseModel): | |||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | |||
| ) | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): | |||
| """ | |||
| Switch preferred provider type. | |||
| :param provider_type: | |||
| @@ -1362,7 +1360,7 @@ class ProviderConfiguration(BaseModel): | |||
| if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | |||
| return | |||
| def _switch(s: Session) -> None: | |||
| def _switch(s: Session): | |||
| # get preferred provider | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| @@ -1406,7 +1404,7 @@ class ProviderConfiguration(BaseModel): | |||
| return secret_input_form_variables | |||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: | |||
| def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): | |||
| """ | |||
| Obfuscated credentials. | |||
| @@ -6,7 +6,7 @@ class LLMError(ValueError): | |||
| description: Optional[str] = None | |||
| def __init__(self, description: Optional[str] = None) -> None: | |||
| def __init__(self, description: Optional[str] = None): | |||
| self.description = description | |||
| @@ -10,11 +10,11 @@ class APIBasedExtensionRequestor: | |||
| timeout: tuple[int, int] = (5, 60) | |||
| """timeout for request connect and read""" | |||
| def __init__(self, api_endpoint: str, api_key: str) -> None: | |||
| def __init__(self, api_endpoint: str, api_key: str): | |||
| self.api_endpoint = api_endpoint | |||
| self.api_key = api_key | |||
| def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: | |||
| def request(self, point: APIBasedExtensionPoint, params: dict): | |||
| """ | |||
| Request the api. | |||
| @@ -34,7 +34,7 @@ class Extensible: | |||
| tenant_id: str | |||
| config: Optional[dict] = None | |||
| def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: | |||
| def __init__(self, tenant_id: str, config: Optional[dict] = None): | |||
| self.tenant_id = tenant_id | |||
| self.config = config | |||
| @@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| """the unique name of external data tool""" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC): | |||
| variable: str | |||
| """the tool variable name of app tool""" | |||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: | |||
| def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): | |||
| super().__init__(tenant_id, config) | |||
| self.app_id = app_id | |||
| self.variable = variable | |||
| @classmethod | |||
| @abstractmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension | |||
| class ExternalDataToolFactory: | |||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: | |||
| def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): | |||
| extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) | |||
| self.__extension_instance = extension_class( | |||
| tenant_id=tenant_id, app_id=app_id, variable=variable, config=config | |||
| ) | |||
| @classmethod | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: | |||
| def validate_config(cls, name: str, tenant_id: str, config: dict): | |||
| """ | |||
| Validate the incoming form config data. | |||
| @@ -7,6 +7,6 @@ if TYPE_CHECKING: | |||
| _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None | |||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: | |||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]): | |||
| global _tool_file_manager_factory | |||
| _tool_file_manager_factory = factory | |||
| @@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel): | |||
| pass | |||
| @classmethod | |||
| def get_default_config(cls) -> dict: | |||
| def get_default_config(cls): | |||
| return { | |||
| "type": "code", | |||
| "config": { | |||
| @@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer | |||
| class Jinja2TemplateTransformer(TemplateTransformer): | |||
| @classmethod | |||
| def transform_response(cls, response: str) -> dict: | |||
| def transform_response(cls, response: str): | |||
| """ | |||
| Transform response to dict | |||
| :param response: response | |||
| @@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider): | |||
| def get_default_code(cls) -> str: | |||
| return dedent( | |||
| """ | |||
| def main(arg1: str, arg2: str) -> dict: | |||
| def main(arg1: str, arg2: str): | |||
| return { | |||
| "result": arg1 + arg2, | |||
| } | |||
| @@ -34,7 +34,7 @@ class ProviderCredentialsCache: | |||
| else: | |||
| return None | |||
| def set(self, credentials: dict) -> None: | |||
| def set(self, credentials: dict): | |||
| """ | |||
| Cache model provider credentials. | |||
| @@ -43,7 +43,7 @@ class ProviderCredentialsCache: | |||
| """ | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """ | |||
| Delete cached model provider credentials. | |||
| @@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC): | |||
| return None | |||
| return None | |||
| def set(self, config: dict[str, Any]) -> None: | |||
| def set(self, config: dict[str, Any]): | |||
| """Cache provider credentials""" | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(config)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """Delete cached provider credentials""" | |||
| redis_client.delete(self.cache_key) | |||
| @@ -75,10 +75,10 @@ class NoOpProviderCredentialCache: | |||
| """Get cached provider credentials""" | |||
| return None | |||
| def set(self, config: dict[str, Any]) -> None: | |||
| def set(self, config: dict[str, Any]): | |||
| """Cache provider credentials""" | |||
| pass | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """Delete cached provider credentials""" | |||
| pass | |||
| @@ -37,11 +37,11 @@ class ToolParameterCache: | |||
| else: | |||
| return None | |||
| def set(self, parameters: dict) -> None: | |||
| def set(self, parameters: dict): | |||
| """Cache model provider credentials.""" | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) | |||
| def delete(self) -> None: | |||
| def delete(self): | |||
| """ | |||
| Delete cached model provider credentials. | |||
| @@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]: | |||
| return None | |||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: | |||
| def extract_external_trace_id_from_args(args: Mapping[str, Any]): | |||
| """ | |||
| Extract 'external_trace_id' from args. | |||
| @@ -44,11 +44,11 @@ class HostingConfiguration: | |||
| provider_map: dict[str, HostingProvider] | |||
| moderation_config: Optional[HostedModerationConfig] = None | |||
| def __init__(self) -> None: | |||
| def __init__(self): | |||
| self.provider_map = {} | |||
| self.moderation_config = None | |||
| def init_app(self, app: Flask) -> None: | |||
| def init_app(self, app: Flask): | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| @@ -270,7 +270,9 @@ class IndexingRunner: | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| preview_texts = [] # type: ignore | |||
| # keep separate, avoid union-list ambiguity | |||
| preview_texts: list[PreviewDetail] = [] | |||
| qa_preview_texts: list[QAPreviewDetail] = [] | |||
| total_segments = 0 | |||
| index_type = doc_form | |||
| @@ -293,14 +295,14 @@ class IndexingRunner: | |||
| for document in documents: | |||
| if len(preview_texts) < 10: | |||
| if doc_form and doc_form == "qa_model": | |||
| preview_detail = QAPreviewDetail( | |||
| qa_detail = QAPreviewDetail( | |||
| question=document.page_content, answer=document.metadata.get("answer") or "" | |||
| ) | |||
| preview_texts.append(preview_detail) | |||
| qa_preview_texts.append(qa_detail) | |||
| else: | |||
| preview_detail = PreviewDetail(content=document.page_content) # type: ignore | |||
| preview_detail = PreviewDetail(content=document.page_content) | |||
| if document.children: | |||
| preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore | |||
| preview_detail.child_chunks = [child.page_content for child in document.children] | |||
| preview_texts.append(preview_detail) | |||
| # delete image files and related db records | |||
| @@ -321,8 +323,8 @@ class IndexingRunner: | |||
| db.session.delete(image_file) | |||
| if doc_form and doc_form == "qa_model": | |||
| return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) | |||
| return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore | |||
| return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) | |||
| return IndexingEstimate(total_segments=total_segments, preview=preview_texts) | |||
| def _extract( | |||
| self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | |||
| @@ -425,6 +427,7 @@ class IndexingRunner: | |||
| """ | |||
| Get the NodeParser object according to the processing rule. | |||
| """ | |||
| character_splitter: TextSplitter | |||
| if processing_rule_mode in ["custom", "hierarchical"]: | |||
| # The user-defined segmentation rule | |||
| max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | |||
| @@ -451,7 +454,7 @@ class IndexingRunner: | |||
| embedding_model_instance=embedding_model_instance, | |||
| ) | |||
| return character_splitter # type: ignore | |||
| return character_splitter | |||
| def _split_to_documents_for_estimate( | |||
| self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule | |||
| @@ -510,7 +513,7 @@ class IndexingRunner: | |||
| dataset: Dataset, | |||
| dataset_document: DatasetDocument, | |||
| documents: list[Document], | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| insert index and update document/segment status to completed | |||
| """ | |||
| @@ -649,7 +652,7 @@ class IndexingRunner: | |||
| @staticmethod | |||
| def _update_document_index_status( | |||
| document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Update the document indexing status. | |||
| """ | |||
| @@ -668,7 +671,7 @@ class IndexingRunner: | |||
| db.session.commit() | |||
| @staticmethod | |||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: | |||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict): | |||
| """ | |||
| Update the document segment by document id. | |||
| """ | |||
| @@ -129,7 +129,7 @@ class LLMGenerator: | |||
| return questions | |||
| @classmethod | |||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: | |||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): | |||
| output_parser = RuleConfigGeneratorOutputParser() | |||
| error = "" | |||
| @@ -264,9 +264,7 @@ class LLMGenerator: | |||
| return rule_config | |||
| @classmethod | |||
| def generate_code( | |||
| cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" | |||
| ) -> dict: | |||
| def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): | |||
| if code_language == "python": | |||
| prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) | |||
| else: | |||
| @@ -375,7 +373,7 @@ class LLMGenerator: | |||
| @staticmethod | |||
| def instruction_modify_legacy( | |||
| tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None | |||
| ) -> dict: | |||
| ): | |||
| last_run: Message | None = ( | |||
| db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() | |||
| ) | |||
| @@ -415,7 +413,7 @@ class LLMGenerator: | |||
| instruction: str, | |||
| model_config: dict, | |||
| ideal_output: str | None, | |||
| ) -> dict: | |||
| ): | |||
| from services.workflow_service import WorkflowService | |||
| session = db.session() | |||
| @@ -455,7 +453,7 @@ class LLMGenerator: | |||
| return [] | |||
| parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) | |||
| def dict_of_event(event: AgentLogEvent) -> dict: | |||
| def dict_of_event(event: AgentLogEvent): | |||
| return { | |||
| "status": event.status, | |||
| "error": event.error, | |||
| @@ -493,7 +491,7 @@ class LLMGenerator: | |||
| instruction: str, | |||
| node_type: str, | |||
| ideal_output: str | None, | |||
| ) -> dict: | |||
| ): | |||
| LAST_RUN = "{{#last_run#}}" | |||
| CURRENT = "{{#current#}}" | |||
| ERROR_MESSAGE = "{{#error_message#}}" | |||
| @@ -1,5 +1,3 @@ | |||
| from typing import Any | |||
| from core.llm_generator.output_parser.errors import OutputParserError | |||
| from core.llm_generator.prompts import ( | |||
| RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, | |||
| @@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser: | |||
| RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, | |||
| ) | |||
| def parse(self, text: str) -> Any: | |||
| def parse(self, text: str): | |||
| try: | |||
| expected_keys = ["prompt", "variables", "opening_statement"] | |||
| parsed = parse_and_check_json_markdown(text, expected_keys) | |||
| @@ -210,7 +210,7 @@ def _handle_native_json_schema( | |||
| structured_output_schema: Mapping, | |||
| model_parameters: dict, | |||
| rules: list[ParameterRule], | |||
| ) -> dict: | |||
| ): | |||
| """ | |||
| Handle structured output for models with native JSON schema support. | |||
| @@ -232,7 +232,7 @@ def _handle_native_json_schema( | |||
| return model_parameters | |||
| def _set_response_format(model_parameters: dict, rules: list) -> None: | |||
| def _set_response_format(model_parameters: dict, rules: list): | |||
| """ | |||
| Set the appropriate response format parameter based on model rules. | |||
| @@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]: | |||
| return structured_output | |||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: | |||
| def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping): | |||
| """ | |||
| Prepare JSON schema based on model requirements. | |||
| @@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema | |||
| return {"schema": processed_schema, "name": "llm_response"} | |||
| def remove_additional_properties(schema: dict) -> None: | |||
| def remove_additional_properties(schema: dict): | |||
| """ | |||
| Remove additionalProperties fields from JSON schema. | |||
| Used for models like Gemini that don't support this property. | |||
| @@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None: | |||
| remove_additional_properties(item) | |||
| def convert_boolean_to_string(schema: dict) -> None: | |||
| def convert_boolean_to_string(schema: dict): | |||
| """ | |||
| Convert boolean type specifications to string in JSON schema. | |||
| @@ -1,6 +1,5 @@ | |||
| import json | |||
| import re | |||
| from typing import Any | |||
| from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | |||
| @@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: | |||
| def get_format_instructions(self) -> str: | |||
| return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT | |||
| def parse(self, text: str) -> Any: | |||
| def parse(self, text: str): | |||
| action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) | |||
| if action_match is not None: | |||
| json_obj = json.loads(action_match.group(0).strip()) | |||
| @@ -44,7 +44,7 @@ class OAuthClientProvider: | |||
| return None | |||
| return OAuthClientInformation.model_validate(client_information) | |||
| def save_client_information(self, client_information: OAuthClientInformationFull) -> None: | |||
| def save_client_information(self, client_information: OAuthClientInformationFull): | |||
| """Saves client information after dynamic registration.""" | |||
| MCPToolManageService.update_mcp_provider_credentials( | |||
| self.mcp_provider, | |||
| @@ -63,13 +63,13 @@ class OAuthClientProvider: | |||
| refresh_token=credentials.get("refresh_token", ""), | |||
| ) | |||
| def save_tokens(self, tokens: OAuthTokens) -> None: | |||
| def save_tokens(self, tokens: OAuthTokens): | |||
| """Stores new OAuth tokens for the current session.""" | |||
| # update mcp provider credentials | |||
| token_dict = tokens.model_dump() | |||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) | |||
| def save_code_verifier(self, code_verifier: str) -> None: | |||
| def save_code_verifier(self, code_verifier: str): | |||
| """Saves a PKCE code verifier for the current session.""" | |||
| MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) | |||
| @@ -47,7 +47,7 @@ class SSETransport: | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: float = 5.0, | |||
| sse_read_timeout: float = 5 * 60, | |||
| ) -> None: | |||
| ): | |||
| """Initialize the SSE transport. | |||
| Args: | |||
| @@ -76,7 +76,7 @@ class SSETransport: | |||
| return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme | |||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: | |||
| def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue): | |||
| """Handle an 'endpoint' SSE event. | |||
| Args: | |||
| @@ -94,7 +94,7 @@ class SSETransport: | |||
| status_queue.put(_StatusReady(endpoint_url)) | |||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: | |||
| def _handle_message_event(self, sse_data: str, read_queue: ReadQueue): | |||
| """Handle a 'message' SSE event. | |||
| Args: | |||
| @@ -110,7 +110,7 @@ class SSETransport: | |||
| logger.exception("Error parsing server message") | |||
| read_queue.put(exc) | |||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||
| def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue): | |||
| """Handle a single SSE event. | |||
| Args: | |||
| @@ -126,7 +126,7 @@ class SSETransport: | |||
| case _: | |||
| logger.warning("Unknown SSE event: %s", sse.event) | |||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: | |||
| def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue): | |||
| """Read and process SSE events. | |||
| Args: | |||
| @@ -144,7 +144,7 @@ class SSETransport: | |||
| finally: | |||
| read_queue.put(None) | |||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: | |||
| def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage): | |||
| """Send a single message to the server. | |||
| Args: | |||
| @@ -163,7 +163,7 @@ class SSETransport: | |||
| response.raise_for_status() | |||
| logger.debug("Client message sent successfully: %s", response.status_code) | |||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: | |||
| def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue): | |||
| """Handle writing messages to the server. | |||
| Args: | |||
| @@ -303,7 +303,7 @@ def sse_client( | |||
| write_queue.put(None) | |||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: | |||
| def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage): | |||
| """ | |||
| Send a message to the server using the provided HTTP client. | |||
| @@ -82,7 +82,7 @@ class StreamableHTTPTransport: | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: float | timedelta = 30, | |||
| sse_read_timeout: float | timedelta = 60 * 5, | |||
| ) -> None: | |||
| ): | |||
| """Initialize the StreamableHTTP transport. | |||
| Args: | |||
| @@ -122,7 +122,7 @@ class StreamableHTTPTransport: | |||
| def _maybe_extract_session_id_from_response( | |||
| self, | |||
| response: httpx.Response, | |||
| ) -> None: | |||
| ): | |||
| """Extract and store session ID from response headers.""" | |||
| new_session_id = response.headers.get(MCP_SESSION_ID) | |||
| if new_session_id: | |||
| @@ -173,7 +173,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| client: httpx.Client, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle GET stream for server-initiated messages.""" | |||
| try: | |||
| if not self.session_id: | |||
| @@ -197,7 +197,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| logger.debug("GET stream error (non-fatal): %s", exc) | |||
| def _handle_resumption_request(self, ctx: RequestContext) -> None: | |||
| def _handle_resumption_request(self, ctx: RequestContext): | |||
| """Handle a resumption request using GET with SSE.""" | |||
| headers = self._update_headers_with_session(ctx.headers) | |||
| if ctx.metadata and ctx.metadata.resumption_token: | |||
| @@ -230,7 +230,7 @@ class StreamableHTTPTransport: | |||
| if is_complete: | |||
| break | |||
| def _handle_post_request(self, ctx: RequestContext) -> None: | |||
| def _handle_post_request(self, ctx: RequestContext): | |||
| """Handle a POST request with response processing.""" | |||
| headers = self._update_headers_with_session(ctx.headers) | |||
| message = ctx.session_message.message | |||
| @@ -278,7 +278,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| response: httpx.Response, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle JSON response from the server.""" | |||
| try: | |||
| content = response.read() | |||
| @@ -288,7 +288,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| server_to_client_queue.put(exc) | |||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: | |||
| def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): | |||
| """Handle SSE response from the server.""" | |||
| try: | |||
| event_source = EventSource(response) | |||
| @@ -307,7 +307,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| content_type: str, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| ) -> None: | |||
| ): | |||
| """Handle unexpected content type in response.""" | |||
| error_msg = f"Unexpected content type: {content_type}" | |||
| logger.error(error_msg) | |||
| @@ -317,7 +317,7 @@ class StreamableHTTPTransport: | |||
| self, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| request_id: RequestId, | |||
| ) -> None: | |||
| ): | |||
| """Send a session terminated error response.""" | |||
| jsonrpc_error = JSONRPCError( | |||
| jsonrpc="2.0", | |||
| @@ -333,7 +333,7 @@ class StreamableHTTPTransport: | |||
| client_to_server_queue: ClientToServerQueue, | |||
| server_to_client_queue: ServerToClientQueue, | |||
| start_get_stream: Callable[[], None], | |||
| ) -> None: | |||
| ): | |||
| """Handle writing requests to the server. | |||
| This method processes messages from the client_to_server_queue and sends them to the server. | |||
| @@ -379,7 +379,7 @@ class StreamableHTTPTransport: | |||
| except Exception as exc: | |||
| server_to_client_queue.put(exc) | |||
| def terminate_session(self, client: httpx.Client) -> None: | |||
| def terminate_session(self, client: httpx.Client): | |||
| """Terminate the session by sending a DELETE request.""" | |||
| if not self.session_id: | |||
| return | |||
| @@ -441,7 +441,7 @@ def streamablehttp_client( | |||
| timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | |||
| ) as client: | |||
| # Define callbacks that need access to thread pool | |||
| def start_get_stream() -> None: | |||
| def start_get_stream(): | |||
| """Start a worker thread to handle server-initiated messages.""" | |||
| executor.submit(transport.handle_get_stream, client, server_to_client_queue) | |||
| @@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| ReceiveNotificationT | |||
| ]""", | |||
| on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], | |||
| ) -> None: | |||
| ): | |||
| self.request_id = request_id | |||
| self.request_meta = request_meta | |||
| self.request = request | |||
| @@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| exc_type: type[BaseException] | None, | |||
| exc_val: BaseException | None, | |||
| exc_tb: TracebackType | None, | |||
| ) -> None: | |||
| ): | |||
| """Exit the context manager, performing cleanup and notifying completion.""" | |||
| try: | |||
| if self._completed: | |||
| @@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| finally: | |||
| self._entered = False | |||
| def respond(self, response: SendResultT | ErrorData) -> None: | |||
| def respond(self, response: SendResultT | ErrorData): | |||
| """Send a response for this request. | |||
| Must be called within a context manager block. | |||
| @@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): | |||
| self._session._send_response(request_id=self.request_id, response=response) | |||
| def cancel(self) -> None: | |||
| def cancel(self): | |||
| """Cancel this request and mark it as completed.""" | |||
| if not self._entered: | |||
| raise RuntimeError("RequestResponder must be used as a context manager") | |||
| @@ -163,7 +163,7 @@ class BaseSession( | |||
| receive_notification_type: type[ReceiveNotificationT], | |||
| # If none, reading will never time out | |||
| read_timeout_seconds: timedelta | None = None, | |||
| ) -> None: | |||
| ): | |||
| self._read_stream = read_stream | |||
| self._write_stream = write_stream | |||
| self._response_streams = {} | |||
| @@ -183,7 +183,7 @@ class BaseSession( | |||
| self._receiver_future = self._executor.submit(self._receive_loop) | |||
| return self | |||
| def check_receiver_status(self) -> None: | |||
| def check_receiver_status(self): | |||
| """`check_receiver_status` ensures that any exceptions raised during the | |||
| execution of `_receive_loop` are retrieved and propagated.""" | |||
| if self._receiver_future and self._receiver_future.done(): | |||
| @@ -191,7 +191,7 @@ class BaseSession( | |||
| def __exit__( | |||
| self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | |||
| ) -> None: | |||
| ): | |||
| self._read_stream.put(None) | |||
| self._write_stream.put(None) | |||
| @@ -277,7 +277,7 @@ class BaseSession( | |||
| self, | |||
| notification: SendNotificationT, | |||
| related_request_id: RequestId | None = None, | |||
| ) -> None: | |||
| ): | |||
| """ | |||
| Emits a notification, which is a one-way message that does not expect | |||
| a response. | |||
| @@ -296,7 +296,7 @@ class BaseSession( | |||
| ) | |||
| self._write_stream.put(session_message) | |||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: | |||
| def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData): | |||
| if isinstance(response, ErrorData): | |||
| jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) | |||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) | |||
| @@ -310,7 +310,7 @@ class BaseSession( | |||
| session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) | |||
| self._write_stream.put(session_message) | |||
| def _receive_loop(self) -> None: | |||
| def _receive_loop(self): | |||
| """ | |||
| Main message processing loop. | |||
| In a real synchronous implementation, this would likely run in a separate thread. | |||
| @@ -382,7 +382,7 @@ class BaseSession( | |||
| logger.exception("Error in message processing loop") | |||
| raise | |||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: | |||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]): | |||
| """ | |||
| Can be overridden by subclasses to handle a request without needing to | |||
| listen on the message stream. | |||
| @@ -391,15 +391,13 @@ class BaseSession( | |||
| forwarded on to the message stream. | |||
| """ | |||
| def _received_notification(self, notification: ReceiveNotificationT) -> None: | |||
| def _received_notification(self, notification: ReceiveNotificationT): | |||
| """ | |||
| Can be overridden by subclasses to handle a notification without needing | |||
| to listen on the message stream. | |||
| """ | |||
| def send_progress_notification( | |||
| self, progress_token: str | int, progress: float, total: float | None = None | |||
| ) -> None: | |||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||
| """ | |||
| Sends a progress notification for a request that is currently being | |||
| processed. | |||
| @@ -408,5 +406,5 @@ class BaseSession( | |||
| def _handle_incoming( | |||
| self, | |||
| req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | |||
| ) -> None: | |||
| ): | |||
| """A generic handler for incoming messages. Overwritten by subclasses.""" | |||
| @@ -28,19 +28,19 @@ class LoggingFnT(Protocol): | |||
| def __call__( | |||
| self, | |||
| params: types.LoggingMessageNotificationParams, | |||
| ) -> None: ... | |||
| ): ... | |||
| class MessageHandlerFnT(Protocol): | |||
| def __call__( | |||
| self, | |||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: ... | |||
| ): ... | |||
| def _default_message_handler( | |||
| message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: | |||
| ): | |||
| if isinstance(message, Exception): | |||
| raise ValueError(str(message)) | |||
| elif isinstance(message, (types.ServerNotification | RequestResponder)): | |||
| @@ -68,7 +68,7 @@ def _default_list_roots_callback( | |||
| def _default_logging_callback( | |||
| params: types.LoggingMessageNotificationParams, | |||
| ) -> None: | |||
| ): | |||
| pass | |||
| @@ -94,7 +94,7 @@ class ClientSession( | |||
| logging_callback: LoggingFnT | None = None, | |||
| message_handler: MessageHandlerFnT | None = None, | |||
| client_info: types.Implementation | None = None, | |||
| ) -> None: | |||
| ): | |||
| super().__init__( | |||
| read_stream, | |||
| write_stream, | |||
| @@ -155,9 +155,7 @@ class ClientSession( | |||
| types.EmptyResult, | |||
| ) | |||
| def send_progress_notification( | |||
| self, progress_token: str | int, progress: float, total: float | None = None | |||
| ) -> None: | |||
| def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): | |||
| """Send a progress notification.""" | |||
| self.send_notification( | |||
| types.ClientNotification( | |||
| @@ -314,7 +312,7 @@ class ClientSession( | |||
| types.ListToolsResult, | |||
| ) | |||
| def send_roots_list_changed(self) -> None: | |||
| def send_roots_list_changed(self): | |||
| """Send a roots/list_changed notification.""" | |||
| self.send_notification( | |||
| types.ClientNotification( | |||
| @@ -324,7 +322,7 @@ class ClientSession( | |||
| ) | |||
| ) | |||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: | |||
| def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): | |||
| ctx = RequestContext[ClientSession, Any]( | |||
| request_id=responder.request_id, | |||
| meta=responder.request_meta, | |||
| @@ -352,11 +350,11 @@ class ClientSession( | |||
| def _handle_incoming( | |||
| self, | |||
| req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, | |||
| ) -> None: | |||
| ): | |||
| """Handle incoming messages by forwarding to the message handler.""" | |||
| self._message_handler(req) | |||
| def _received_notification(self, notification: types.ServerNotification) -> None: | |||
| def _received_notification(self, notification: types.ServerNotification): | |||
| """Handle notifications from the server.""" | |||
| # Process specific notification types | |||
| match notification.root: | |||
| @@ -27,7 +27,7 @@ class TokenBufferMemory: | |||
| self, | |||
| conversation: Conversation, | |||
| model_instance: ModelInstance, | |||
| ) -> None: | |||
| ): | |||
| self.conversation = conversation | |||
| self.model_instance = model_instance | |||
| @@ -124,6 +124,7 @@ class TokenBufferMemory: | |||
| messages = list(reversed(thread_messages)) | |||
| curr_message_tokens = 0 | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| # Process user message with files | |||