client.py
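"""Apollo configuration client.

Fetches configuration for an app id from an Apollo config service, keeps an
in-memory cache backed by a local file cache, long-polls the notifications
endpoint for hot updates, and refreshes watched namespaces via a periodic
heartbeat thread.
"""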
import hashlib
import json
import logging
import os
import threading
import time
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import Any

from .python_3x import http_request, makedirs_wrapper
from .utils import (
    CONFIGURATIONS,
    NAMESPACE_NAME,
    NOTIFICATION_ID,
    get_value_from_dict,
    init_ip,
    no_key_cache_key,
    signature,
    url_encode_wrapper,
)

logger = logging.getLogger(__name__)

class ApolloClient:
    def __init__(
        self,
        config_url: str,
        app_id: str,
        cluster: str = "default",
        secret: str = "",
        start_hot_update: bool = True,
        change_listener: Callable[[str, str, str, Any], None] | None = None,
        _notification_map: dict[str, int] | None = None,
    ):
        # Core routing parameters
        self.config_url = config_url
        self.cluster = cluster
        self.app_id = app_id

        # Non-core parameters
        self.ip = init_ip()
        self.secret = secret

        # Private control variables
        self._cycle_time = 5
        self._stopping = False
        self._cache: dict[str, dict[str, Any]] = {}
        self._no_key: dict[str, str] = {}
        self._hash: dict[str, str] = {}
        self._pull_timeout = 75
        self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
        self._long_poll_thread: threading.Thread | None = None
        self._change_listener = change_listener  # "add" "delete" "update"
        if _notification_map is None:
            _notification_map = {"application": -1}
        self._notification_map = _notification_map
        self.last_release_key: str | None = None

        # Private startup methods
        self._path_checker()
        if start_hot_update:
            self._start_hot_update()

        # start the heartbeat thread
        heartbeat = threading.Thread(target=self._heart_beat)
        heartbeat.daemon = True
        heartbeat.start()

    def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
        url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
            self.config_url, self.app_id, self.cluster, namespace, "", self.ip
        )
        try:
            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
            if code == 200:
                if not body:
                    logger.error("get_json_from_net load configs failed, body is %s", body)
                    return None
                data = json.loads(body)
                data = data["configurations"]
                return_data = {CONFIGURATIONS: data}
                return return_data
            else:
                return None
        except Exception:
            logger.exception("an error occurred in get_json_from_net")
            return None

    def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
        try:
            # read the in-memory configuration
            namespace_cache = self._cache.get(namespace)
            val = get_value_from_dict(namespace_cache, key)
            if val is not None:
                return val
            no_key = no_key_cache_key(namespace, key)
            if no_key in self._no_key:
                return default_val
            # read the network configuration
            namespace_data = self.get_json_from_net(namespace)
            val = get_value_from_dict(namespace_data, key)
            if val is not None:
                if namespace_data is not None:
                    self._update_cache_and_file(namespace_data, namespace)
                return val
            # read the file configuration
            namespace_cache = self._get_local_cache(namespace)
            val = get_value_from_dict(namespace_cache, key)
            if val is not None:
                self._update_cache_and_file(namespace_cache, namespace)
                return val
            # If no source yields a value, mark the key as missing locally
            # and return the default value.
            self._set_local_cache_none(namespace, key)
            return default_val
        except Exception:
            logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
            return default_val

    # Mark a key of a namespace as missing without caching the default value,
    # so that each call stays correct in real time: if the caller later passes
    # a different default value, a cached default would return the wrong result.
    def _set_local_cache_none(self, namespace: str, key: str) -> None:
        no_key = no_key_cache_key(namespace, key)
        self._no_key[no_key] = key

    def _start_hot_update(self) -> None:
        self._long_poll_thread = threading.Thread(target=self._listener)
        # Run the long-poll thread as a daemon so it exits automatically
        # when the main thread exits.
        self._long_poll_thread.daemon = True
        self._long_poll_thread.start()

    def stop(self) -> None:
        self._stopping = True
        logger.info("Stopping listener...")

    # Call the registered change listener; if it raises, log a warning.
    def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
        if self._change_listener is None:
            return
        if old_kv is None:
            old_kv = {}
        if new_kv is None:
            new_kv = {}
        try:
            for key in old_kv:
                new_value = new_kv.get(key)
                old_value = old_kv.get(key)
                if new_value is None:
                    # If the new value is None, the key and its value have been deleted.
                    self._change_listener("delete", namespace, key, old_value)
                    continue
                if new_value != old_value:
                    self._change_listener("update", namespace, key, new_value)
                    continue
            for key in new_kv:
                new_value = new_kv.get(key)
                old_value = old_kv.get(key)
                if old_value is None:
                    self._change_listener("add", namespace, key, new_value)
        except BaseException as e:
            logger.warning(str(e))

    def _path_checker(self) -> None:
        if not os.path.isdir(self._cache_file_path):
            makedirs_wrapper(self._cache_file_path)

    # update the local cache and file cache
    def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
        # update the local cache
        self._cache[namespace] = namespace_data
        # update the file cache
        new_string = json.dumps(namespace_data)
        new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
        if self._hash.get(namespace) != new_hash:
            file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
            file_path.write_text(new_string)
            self._hash[namespace] = new_hash

    # get the configuration from the local file
    def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
        cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
        if os.path.isfile(cache_file_path):
            with open(cache_file_path) as f:
                result = json.loads(f.readline())
                return result
        return {}

    def _long_poll(self) -> None:
        notifications: list[dict[str, Any]] = []
        for key in self._cache:
            namespace_data = self._cache[key]
            notification_id = -1
            if NOTIFICATION_ID in namespace_data:
                notification_id = self._cache[key][NOTIFICATION_ID]
            notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
        try:
            # if there are no namespaces to watch, return directly
            if len(notifications) == 0:
                return
            url = f"{self.config_url}/notifications/v2"
            params = {
                "appId": self.app_id,
                "cluster": self.cluster,
                "notifications": json.dumps(notifications, ensure_ascii=False),
            }
            param_str = url_encode_wrapper(params)
            url = url + "?" + param_str
            code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
            http_code = code
            if http_code == 304:
                logger.debug("No change, loop...")
                return
            if http_code == 200:
                if not body:
                    logger.error("_long_poll load configs failed, body is %s", body)
                    return
                data = json.loads(body)
                for entry in data:
                    namespace = entry[NAMESPACE_NAME]
                    n_id = entry[NOTIFICATION_ID]
                    logger.info("%s has changes: notificationId=%d", namespace, n_id)
                    self._get_net_and_set_local(namespace, n_id, call_change=True)
                return
            else:
                logger.warning("Sleep...")
        except Exception as e:
            logger.warning(str(e))

    def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
        namespace_data = self.get_json_from_net(namespace)
        if not namespace_data:
            return
        namespace_data[NOTIFICATION_ID] = n_id
        old_namespace = self._cache.get(namespace)
        self._update_cache_and_file(namespace_data, namespace)
        if self._change_listener is not None and call_change and old_namespace:
            old_kv = old_namespace.get(CONFIGURATIONS)
            new_kv = namespace_data.get(CONFIGURATIONS)
            self._call_listener(namespace, old_kv, new_kv)

    def _listener(self) -> None:
        logger.info("start long_poll")
        while not self._stopping:
            self._long_poll()
            time.sleep(self._cycle_time)
        logger.info("stopped, long_poll")

    # add the signature authentication headers when a secret is configured
    def _sign_headers(self, url: str) -> Mapping[str, str]:
        headers: dict[str, str] = {}
        if self.secret == "":
            return headers
        uri = url[len(self.config_url):]
        time_unix_now = str(int(round(time.time() * 1000)))
        headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
        headers["Timestamp"] = time_unix_now
        return headers

    def _heart_beat(self) -> None:
        while not self._stopping:
            for namespace in self._notification_map:
                self._do_heart_beat(namespace)
            time.sleep(60 * 10)  # 10 minutes

    def _do_heart_beat(self, namespace: str) -> None:
        url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
        try:
            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
            if code == 200:
                if not body:
                    logger.error("_do_heart_beat load configs failed, body is %s", body)
                    return None
                data = json.loads(body)
                if self.last_release_key == data["releaseKey"]:
                    return None
                self.last_release_key = data["releaseKey"]
                data = data["configurations"]
                self._update_cache_and_file(data, namespace)
            else:
                return None
        except Exception:
            logger.exception("an error occurred in _do_heart_beat")
            return None

    def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
        namespace_data = self._cache.get(namespace)
        if namespace_data is None:
            net_namespace_data = self.get_json_from_net(namespace)
            if not net_namespace_data:
                return namespace_data
            namespace_data = net_namespace_data.get(CONFIGURATIONS)
            if namespace_data:
                self._update_cache_and_file(namespace_data, namespace)
        return namespace_data
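

# Illustrative usage sketch (an assumption for demonstration, not part of the
# client above): the config service URL, app id, and key name are placeholders.
def _example_usage() -> None:
    def on_change(action: str, namespace: str, key: str, value: Any) -> None:
        # The listener receives "add", "delete", or "update" plus the affected
        # namespace, key, and value (see _call_listener above).
        logger.info("config %s: %s/%s = %s", action, namespace, key, value)

    client = ApolloClient(
        config_url="http://localhost:8080",  # placeholder Apollo config service address
        app_id="demo-app",                   # placeholder application id
        secret="",                           # a non-empty secret enables signed requests
        change_listener=on_change,
    )
    # get_value falls back from the memory cache to the network, then to the
    # file cache, and finally to the default value.
    timeout = client.get_value("timeout", default_val="30", namespace="application")
    logger.info("timeout = %s", timeout)
    client.stop()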