Kaynağa Gözat

fix: resolve typing errors in configs module (#25268)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/2.0.0-beta.2^2
-LAN- 1 ay önce
ebeveyn
işleme
b05245eab0
No account linked to committer's email address

+ 1
- 2
api/configs/middleware/__init__.py Dosyayı Görüntüle



class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig,
DatabaseConfig,
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
KeywordStoreConfig, KeywordStoreConfig,
RedisConfig, RedisConfig,
# configs of storage and storage providers # configs of storage and storage providers

+ 3
- 2
api/configs/middleware/vdb/clickzetta_config.py Dosyayı Görüntüle

from typing import Optional from typing import Optional


from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings




class ClickzettaConfig(BaseModel):
class ClickzettaConfig(BaseSettings):
""" """
Clickzetta Lakehouse vector database configuration Clickzetta Lakehouse vector database configuration
""" """

+ 3
- 2
api/configs/middleware/vdb/matrixone_config.py Dosyayı Görüntüle

from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings




class MatrixoneConfig(BaseModel):
class MatrixoneConfig(BaseSettings):
"""Matrixone vector database configuration.""" """Matrixone vector database configuration."""


MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

+ 1
- 1
api/configs/packaging/__init__.py Dosyayı Görüntüle

from pydantic import Field from pydantic import Field


from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
from configs.packaging.pyproject import PyProjectTomlConfig




class PackagingInfo(PyProjectTomlConfig): class PackagingInfo(PyProjectTomlConfig):

+ 32
- 30
api/configs/remote_settings_sources/apollo/client.py Dosyayı Görüntüle

import os import os
import threading import threading
import time import time
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from pathlib import Path from pathlib import Path
from typing import Any


from .python_3x import http_request, makedirs_wrapper from .python_3x import http_request, makedirs_wrapper
from .utils import ( from .utils import (
class ApolloClient: class ApolloClient:
def __init__( def __init__(
self, self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
config_url: str,
app_id: str,
cluster: str = "default",
secret: str = "",
start_hot_update: bool = True,
change_listener: Callable[[str, str, str, Any], None] | None = None,
_notification_map: dict[str, int] | None = None,
): ):
# Core routing parameters # Core routing parameters
self.config_url = config_url self.config_url = config_url
# Private control variables # Private control variables
self._cycle_time = 5 self._cycle_time = 5
self._stopping = False self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._cache: dict[str, dict[str, Any]] = {}
self._no_key: dict[str, str] = {}
self._hash: dict[str, str] = {}
self._pull_timeout = 75 self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._long_poll_thread: threading.Thread | None = None
self._change_listener = change_listener # "add" "delete" "update" self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None: if _notification_map is None:
_notification_map = {"application": -1} _notification_map = {"application": -1}
self._notification_map = _notification_map self._notification_map = _notification_map
self.last_release_key = None
self.last_release_key: str | None = None
# Private startup method # Private startup method
self._path_checker() self._path_checker()
if start_hot_update: if start_hot_update:
heartbeat.daemon = True heartbeat.daemon = True
heartbeat.start() heartbeat.start()


def get_json_from_net(self, namespace="application"):
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip self.config_url, self.app_id, self.cluster, namespace, "", self.ip
) )
logger.exception("an error occurred in get_json_from_net") logger.exception("an error occurred in get_json_from_net")
return None return None


def get_value(self, key, default_val=None, namespace="application"):
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
try: try:
# read memory configuration # read memory configuration
namespace_cache = self._cache.get(namespace) namespace_cache = self._cache.get(namespace)
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key) val = get_value_from_dict(namespace_data, key)
if val is not None: if val is not None:
self._update_cache_and_file(namespace_data, namespace)
if namespace_data is not None:
self._update_cache_and_file(namespace_data, namespace)
return val return val


# read the file configuration # read the file configuration
# to ensure the real-time correctness of the function call. # to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice # If the user does not have the same default val twice
# and the default val is used here, there may be a problem. # and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key):
def _set_local_cache_none(self, namespace: str, key: str) -> None:
no_key = no_key_cache_key(namespace, key) no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key self._no_key[no_key] = key


def _start_hot_update(self):
def _start_hot_update(self) -> None:
self._long_poll_thread = threading.Thread(target=self._listener) self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit # When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched. # when the main thread is launched.
self._long_poll_thread.daemon = True self._long_poll_thread.daemon = True
self._long_poll_thread.start() self._long_poll_thread.start()


def stop(self):
def stop(self) -> None:
self._stopping = True self._stopping = True
logger.info("Stopping listener...") logger.info("Stopping listener...")


# Call the set callback function, and if it is abnormal, try it out # Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv):
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
if self._change_listener is None: if self._change_listener is None:
return return
if old_kv is None: if old_kv is None:
except BaseException as e: except BaseException as e:
logger.warning(str(e)) logger.warning(str(e))


def _path_checker(self):
def _path_checker(self) -> None:
if not os.path.isdir(self._cache_file_path): if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path) makedirs_wrapper(self._cache_file_path)


# update the local cache and file cache # update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
# update the local cache # update the local cache
self._cache[namespace] = namespace_data self._cache[namespace] = namespace_data
# update the file cache # update the file cache
self._hash[namespace] = new_hash self._hash[namespace] = new_hash


# get the configuration from the local file # get the configuration from the local file
def _get_local_cache(self, namespace="application"):
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path): if os.path.isfile(cache_file_path):
with open(cache_file_path) as f: with open(cache_file_path) as f:
return result return result
return {} return {}


def _long_poll(self):
notifications = []
def _long_poll(self) -> None:
notifications: list[dict[str, Any]] = []
for key in self._cache: for key in self._cache:
namespace_data = self._cache[key] namespace_data = self._cache[key]
notification_id = -1 notification_id = -1
except Exception as e: except Exception as e:
logger.warning(str(e)) logger.warning(str(e))


def _get_net_and_set_local(self, namespace, n_id, call_change=False):
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
if not namespace_data: if not namespace_data:
return return
new_kv = namespace_data.get(CONFIGURATIONS) new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv) self._call_listener(namespace, old_kv, new_kv)


def _listener(self):
def _listener(self) -> None:
logger.info("start long_poll") logger.info("start long_poll")
while not self._stopping: while not self._stopping:
self._long_poll() self._long_poll()
headers["Timestamp"] = time_unix_now headers["Timestamp"] = time_unix_now
return headers return headers


def _heart_beat(self):
def _heart_beat(self) -> None:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes time.sleep(60 * 10) # 10 minutes


def _do_heart_beat(self, namespace):
def _do_heart_beat(self, namespace: str) -> None:
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try: try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
logger.exception("an error occurred in _do_heart_beat") logger.exception("an error occurred in _do_heart_beat")
return None return None


def get_all_dicts(self, namespace):
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
namespace_data = self._cache.get(namespace) namespace_data = self._cache.get(namespace)
if namespace_data is None: if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace) net_namespace_data = self.get_json_from_net(namespace)

+ 6
- 4
api/configs/remote_settings_sources/apollo/python_3x.py Dosyayı Görüntüle

import os import os
import ssl import ssl
import urllib.request import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse from urllib import parse
from urllib.error import HTTPError from urllib.error import HTTPError


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)




def http_request(url, timeout, headers={}):
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
try: try:
request = urllib.request.Request(url, headers=headers)
request = urllib.request.Request(url, headers=dict(headers))
res = urllib.request.urlopen(request, timeout=timeout) res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8") body = res.read().decode("utf-8")
return res.code, body return res.code, body
raise e raise e




def url_encode(params):
def url_encode(params: dict[str, Any]) -> str:
return parse.urlencode(params) return parse.urlencode(params)




def makedirs_wrapper(path):
def makedirs_wrapper(path: str) -> None:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

+ 6
- 5
api/configs/remote_settings_sources/apollo/utils.py Dosyayı Görüntüle

import hashlib import hashlib
import socket import socket
from typing import Any


from .python_3x import url_encode from .python_3x import url_encode






# add timestamps uris and keys # add timestamps uris and keys
def signature(timestamp, uri, secret):
def signature(timestamp: str, uri: str, secret: str) -> str:
import base64 import base64
import hmac import hmac


return base64.b64encode(hmac_code).decode() return base64.b64encode(hmac_code).decode()




def url_encode_wrapper(params):
def url_encode_wrapper(params: dict[str, Any]) -> str:
return url_encode(params) return url_encode(params)




def no_key_cache_key(namespace, key):
def no_key_cache_key(namespace: str, key: str) -> str:
return f"{namespace}{len(namespace)}{key}" return f"{namespace}{len(namespace)}{key}"




# Returns whether the obtained value is obtained, and None if it does not # Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key):
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
if namespace_cache: if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS) kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None: if kv_data is None:
return None return None




def init_ip():
def init_ip() -> str:
ip = "" ip = ""
s = None s = None
try: try:

+ 5
- 8
api/configs/remote_settings_sources/nacos/__init__.py Dosyayı Görüntüle



from configs.remote_settings_sources.base import RemoteSettingsSource from configs.remote_settings_sources.base import RemoteSettingsSource


from .utils import _parse_config
from .utils import parse_config




class NacosSettingsSource(RemoteSettingsSource): class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]): def __init__(self, configs: Mapping[str, Any]):
self.configs = configs self.configs = configs
self.remote_configs: dict[str, Any] = {}
self.remote_configs: dict[str, str] = {}
self.async_init() self.async_init()


def async_init(self):
def async_init(self) -> None:
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise


def _parse_config(self, content: str):
def _parse_config(self, content: str) -> dict[str, str]:
if not content: if not content:
return {} return {}
try: try:
return _parse_config(self, content)
return parse_config(content)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}") raise RuntimeError(f"Failed to parse config: {e}")


def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

field_value = self.remote_configs.get(field_name) field_value = self.remote_configs.get(field_name)
if field_value is None: if field_value is None:
return None, field_name, False return None, field_name, False

+ 15
- 7
api/configs/remote_settings_sources/nacos/http_request.py Dosyayı Görüntüle

self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY") self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY") self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848") self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token = None
self.token: str | None = None
self.token_ttl = 18000 self.token_ttl = 18000
self.token_expire_time: float = 0 self.token_expire_time: float = 0


def http_request(self, url, method="GET", headers=None, params=None):
def http_request(
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
except requests.RequestException as e: except requests.RequestException as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"


def _inject_auth_info(self, headers, params, module="config"):
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})


if module == "login": if module == "login":
headers["timeStamp"] = ts headers["timeStamp"] = ts
if self.username and self.password: if self.username and self.password:
self.get_access_token(force_refresh=False) self.get_access_token(force_refresh=False)
params["accessToken"] = self.token
if self.token is not None:
params["accessToken"] = self.token


def __do_sign(self, sign_str, sk):
def __do_sign(self, sign_str: str, sk: str) -> str:
return ( return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest()) base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode() .decode()
.strip() .strip()
) )


def get_sign_str(self, group, tenant, ts):
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
sign_str = "" sign_str = ""
if tenant: if tenant:
sign_str = tenant + "+" sign_str = tenant + "+"
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str return sign_str


def get_access_token(self, force_refresh=False):
def get_access_token(self, force_refresh: bool = False) -> str | None:
current_time = time.time() current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time: if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token return self.token
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000) self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10 self.token_expire_time = current_time + self.token_ttl - 10
return self.token
except Exception: except Exception:
logger.exception("[get-access-token] exception occur") logger.exception("[get-access-token] exception occur")
raise raise

+ 1
- 1
api/configs/remote_settings_sources/nacos/utils.py Dosyayı Görüntüle

def _parse_config(self, content: str) -> dict[str, str]:
def parse_config(content: str) -> dict[str, str]:
config: dict[str, str] = {} config: dict[str, str] = {}
if not content: if not content:
return config return config

+ 4
- 3
api/pyrightconfig.json Dosyayı Görüntüle

{ {
"include": ["."],
"include": [
"."
],
"exclude": [ "exclude": [
"tests/", "tests/",
"migrations/", "migrations/",
"events/", "events/",
"contexts/", "contexts/",
"constants/", "constants/",
"configs/",
"commands.py" "commands.py"
], ],
"typeCheckingMode": "strict", "typeCheckingMode": "strict",
"pythonVersion": "3.11", "pythonVersion": "3.11",
"pythonPlatform": "All" "pythonPlatform": "All"
}
}

Loading…
İptal
Kaydet