Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

base.py 19KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import time
  18. from abc import ABC, abstractmethod
  19. import builtins
  20. import json
  21. import os
  22. import logging
  23. from typing import Any, List, Union
  24. import pandas as pd
  25. import trio
  26. from agent import settings
  27. from api.utils.api_utils import timeout
# Attribute names ComponentParamBase uses for parameter bookkeeping.
# All four are skipped when a parameter object is serialized with as_dict().

# Instance attribute: set of (dot-joined, for nested params) deprecated
# parameter names that the user actually supplied through update().
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
# Class attribute: set of parameter names declared deprecated; created empty on
# first access (presumably populated by subclasses — verify).
_DEPRECATED_PARAMS = "_deprecated_params"
# Instance attribute: set of (dot-joined) parameter names supplied by the user
# through update().
_USER_FEEDED_PARAMS = "_user_feeded_params"
# Flag read from an incoming conf by update(): missing or truthy means the conf
# is raw user input; update() then stores it as False on the instance.
_IS_RAW_CONF = "_is_raw_conf"
  32. class ComponentParamBase(ABC):
  33. def __init__(self):
  34. self.message_history_window_size = 22
  35. self.inputs = {}
  36. self.outputs = {}
  37. self.description = ""
  38. self.max_retries = 0
  39. self.delay_after_error = 2.0
  40. self.exception_method = None
  41. self.exception_default_value = None
  42. self.exception_goto = None
  43. self.debug_inputs = {}
  44. def set_name(self, name: str):
  45. self._name = name
  46. return self
  47. def check(self):
  48. raise NotImplementedError("Parameter Object should be checked.")
  49. @classmethod
  50. def _get_or_init_deprecated_params_set(cls):
  51. if not hasattr(cls, _DEPRECATED_PARAMS):
  52. setattr(cls, _DEPRECATED_PARAMS, set())
  53. return getattr(cls, _DEPRECATED_PARAMS)
  54. def _get_or_init_feeded_deprecated_params_set(self, conf=None):
  55. if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
  56. if conf is None:
  57. setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
  58. else:
  59. setattr(
  60. self,
  61. _FEEDED_DEPRECATED_PARAMS,
  62. set(conf[_FEEDED_DEPRECATED_PARAMS]),
  63. )
  64. return getattr(self, _FEEDED_DEPRECATED_PARAMS)
  65. def _get_or_init_user_feeded_params_set(self, conf=None):
  66. if not hasattr(self, _USER_FEEDED_PARAMS):
  67. if conf is None:
  68. setattr(self, _USER_FEEDED_PARAMS, set())
  69. else:
  70. setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
  71. return getattr(self, _USER_FEEDED_PARAMS)
  72. def get_user_feeded(self):
  73. return self._get_or_init_user_feeded_params_set()
  74. def get_feeded_deprecated_params(self):
  75. return self._get_or_init_feeded_deprecated_params_set()
  76. @property
  77. def _deprecated_params_set(self):
  78. return {name: True for name in self.get_feeded_deprecated_params()}
  79. def __str__(self):
  80. return json.dumps(self.as_dict(), ensure_ascii=False)
  81. def as_dict(self):
  82. def _recursive_convert_obj_to_dict(obj):
  83. ret_dict = {}
  84. if isinstance(obj, dict):
  85. for k,v in obj.items():
  86. if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)):
  87. ret_dict[k] = _recursive_convert_obj_to_dict(v)
  88. else:
  89. ret_dict[k] = v
  90. return ret_dict
  91. for attr_name in list(obj.__dict__):
  92. if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
  93. continue
  94. # get attr
  95. attr = getattr(obj, attr_name)
  96. if isinstance(attr, pd.DataFrame):
  97. ret_dict[attr_name] = attr.to_dict()
  98. continue
  99. if isinstance(attr, dict) or (attr and type(attr).__name__ not in dir(builtins)):
  100. ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
  101. else:
  102. ret_dict[attr_name] = attr
  103. return ret_dict
  104. return _recursive_convert_obj_to_dict(self)
  105. def update(self, conf, allow_redundant=False):
  106. update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
  107. if update_from_raw_conf:
  108. deprecated_params_set = self._get_or_init_deprecated_params_set()
  109. feeded_deprecated_params_set = (
  110. self._get_or_init_feeded_deprecated_params_set()
  111. )
  112. user_feeded_params_set = self._get_or_init_user_feeded_params_set()
  113. setattr(self, _IS_RAW_CONF, False)
  114. else:
  115. feeded_deprecated_params_set = (
  116. self._get_or_init_feeded_deprecated_params_set(conf)
  117. )
  118. user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
  119. def _recursive_update_param(param, config, depth, prefix):
  120. if depth > settings.PARAM_MAXDEPTH:
  121. raise ValueError("Param define nesting too deep!!!, can not parse it")
  122. inst_variables = param.__dict__
  123. redundant_attrs = []
  124. for config_key, config_value in config.items():
  125. # redundant attr
  126. if config_key not in inst_variables:
  127. if not update_from_raw_conf and config_key.startswith("_"):
  128. setattr(param, config_key, config_value)
  129. else:
  130. setattr(param, config_key, config_value)
  131. # redundant_attrs.append(config_key)
  132. continue
  133. full_config_key = f"{prefix}{config_key}"
  134. if update_from_raw_conf:
  135. # add user feeded params
  136. user_feeded_params_set.add(full_config_key)
  137. # update user feeded deprecated param set
  138. if full_config_key in deprecated_params_set:
  139. feeded_deprecated_params_set.add(full_config_key)
  140. # supported attr
  141. attr = getattr(param, config_key)
  142. if type(attr).__name__ in dir(builtins) or attr is None:
  143. setattr(param, config_key, config_value)
  144. else:
  145. # recursive set obj attr
  146. sub_params = _recursive_update_param(
  147. attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
  148. )
  149. setattr(param, config_key, sub_params)
  150. if not allow_redundant and redundant_attrs:
  151. raise ValueError(
  152. f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
  153. )
  154. return param
  155. return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
  156. def extract_not_builtin(self):
  157. def _get_not_builtin_types(obj):
  158. ret_dict = {}
  159. for variable in obj.__dict__:
  160. attr = getattr(obj, variable)
  161. if attr and type(attr).__name__ not in dir(builtins):
  162. ret_dict[variable] = _get_not_builtin_types(attr)
  163. return ret_dict
  164. return _get_not_builtin_types(self)
  165. def validate(self):
  166. self.builtin_types = dir(builtins)
  167. self.func = {
  168. "ge": self._greater_equal_than,
  169. "le": self._less_equal_than,
  170. "in": self._in,
  171. "not_in": self._not_in,
  172. "range": self._range,
  173. }
  174. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  175. param_validation_path_prefix = home_dir + "/param_validation/"
  176. param_name = type(self).__name__
  177. param_validation_path = "/".join(
  178. [param_validation_path_prefix, param_name + ".json"]
  179. )
  180. validation_json = None
  181. try:
  182. with open(param_validation_path, "r") as fin:
  183. validation_json = json.loads(fin.read())
  184. except BaseException:
  185. return
  186. self._validate_param(self, validation_json)
  187. def _validate_param(self, param_obj, validation_json):
  188. default_section = type(param_obj).__name__
  189. var_list = param_obj.__dict__
  190. for variable in var_list:
  191. attr = getattr(param_obj, variable)
  192. if type(attr).__name__ in self.builtin_types or attr is None:
  193. if variable not in validation_json:
  194. continue
  195. validation_dict = validation_json[default_section][variable]
  196. value = getattr(param_obj, variable)
  197. value_legal = False
  198. for op_type in validation_dict:
  199. if self.func[op_type](value, validation_dict[op_type]):
  200. value_legal = True
  201. break
  202. if not value_legal:
  203. raise ValueError(
  204. "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
  205. variable, value
  206. )
  207. )
  208. elif variable in validation_json:
  209. self._validate_param(attr, validation_json)
  210. @staticmethod
  211. def check_string(param, descr):
  212. if type(param).__name__ not in ["str"]:
  213. raise ValueError(
  214. descr + " {} not supported, should be string type".format(param)
  215. )
  216. @staticmethod
  217. def check_empty(param, descr):
  218. if not param:
  219. raise ValueError(
  220. descr + " does not support empty value."
  221. )
  222. @staticmethod
  223. def check_positive_integer(param, descr):
  224. if type(param).__name__ not in ["int", "long"] or param <= 0:
  225. raise ValueError(
  226. descr + " {} not supported, should be positive integer".format(param)
  227. )
  228. @staticmethod
  229. def check_positive_number(param, descr):
  230. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  231. raise ValueError(
  232. descr + " {} not supported, should be positive numeric".format(param)
  233. )
  234. @staticmethod
  235. def check_nonnegative_number(param, descr):
  236. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  237. raise ValueError(
  238. descr
  239. + " {} not supported, should be non-negative numeric".format(param)
  240. )
  241. @staticmethod
  242. def check_decimal_float(param, descr):
  243. if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
  244. raise ValueError(
  245. descr
  246. + " {} not supported, should be a float number in range [0, 1]".format(
  247. param
  248. )
  249. )
  250. @staticmethod
  251. def check_boolean(param, descr):
  252. if type(param).__name__ != "bool":
  253. raise ValueError(
  254. descr + " {} not supported, should be bool type".format(param)
  255. )
  256. @staticmethod
  257. def check_open_unit_interval(param, descr):
  258. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  259. raise ValueError(
  260. descr + " should be a numeric number between 0 and 1 exclusively"
  261. )
  262. @staticmethod
  263. def check_valid_value(param, descr, valid_values):
  264. if param not in valid_values:
  265. raise ValueError(
  266. descr
  267. + " {} is not supported, it should be in {}".format(param, valid_values)
  268. )
  269. @staticmethod
  270. def check_defined_type(param, descr, types):
  271. if type(param).__name__ not in types:
  272. raise ValueError(
  273. descr + " {} not supported, should be one of {}".format(param, types)
  274. )
  275. @staticmethod
  276. def check_and_change_lower(param, valid_list, descr=""):
  277. if type(param).__name__ != "str":
  278. raise ValueError(
  279. descr
  280. + " {} not supported, should be one of {}".format(param, valid_list)
  281. )
  282. lower_param = param.lower()
  283. if lower_param in valid_list:
  284. return lower_param
  285. else:
  286. raise ValueError(
  287. descr
  288. + " {} not supported, should be one of {}".format(param, valid_list)
  289. )
  290. @staticmethod
  291. def _greater_equal_than(value, limit):
  292. return value >= limit - settings.FLOAT_ZERO
  293. @staticmethod
  294. def _less_equal_than(value, limit):
  295. return value <= limit + settings.FLOAT_ZERO
  296. @staticmethod
  297. def _range(value, ranges):
  298. in_range = False
  299. for left_limit, right_limit in ranges:
  300. if (
  301. left_limit - settings.FLOAT_ZERO
  302. <= value
  303. <= right_limit + settings.FLOAT_ZERO
  304. ):
  305. in_range = True
  306. break
  307. return in_range
  308. @staticmethod
  309. def _in(value, right_value_list):
  310. return value in right_value_list
  311. @staticmethod
  312. def _not_in(value, wrong_value_list):
  313. return value not in wrong_value_list
  314. def _warn_deprecated_param(self, param_name, descr):
  315. if self._deprecated_params_set.get(param_name):
  316. logging.warning(
  317. f"{descr} {param_name} is deprecated and ignored in this version."
  318. )
  319. def _warn_to_deprecate_param(self, param_name, descr, new_param):
  320. if self._deprecated_params_set.get(param_name):
  321. logging.warning(
  322. f"{descr} {param_name} will be deprecated in future release; "
  323. f"please use {new_param} instead."
  324. )
  325. return True
  326. return False
  327. class ComponentBase(ABC):
  328. component_name: str
  329. thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
  330. variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*"
  331. def __str__(self):
  332. """
  333. {
  334. "component_name": "Begin",
  335. "params": {}
  336. }
  337. """
  338. return """{{
  339. "component_name": "{}",
  340. "params": {}
  341. }}""".format(self.component_name,
  342. self._param
  343. )
  344. def __init__(self, canvas, id, param: ComponentParamBase):
  345. from agent.canvas import Canvas # Local import to avoid cyclic dependency
  346. assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
  347. self._canvas = canvas
  348. self._id = id
  349. self._param = param
  350. self._param.check()
  351. def invoke(self, **kwargs) -> dict[str, Any]:
  352. self.set_output("_created_time", time.perf_counter())
  353. try:
  354. self._invoke(**kwargs)
  355. except Exception as e:
  356. if self.get_exception_default_value():
  357. self.set_exception_default_value()
  358. else:
  359. self.set_output("_ERROR", str(e))
  360. logging.exception(e)
  361. self._param.debug_inputs = {}
  362. self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
  363. return self.output()
  364. @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
  365. def _invoke(self, **kwargs):
  366. raise NotImplementedError()
  367. def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
  368. if var_nm:
  369. return self._param.outputs.get(var_nm, {}).get("value", "")
  370. return {k: o.get("value") for k,o in self._param.outputs.items()}
  371. def set_output(self, key: str, value: Any):
  372. if key not in self._param.outputs:
  373. self._param.outputs[key] = {"value": None, "type": str(type(value))}
  374. self._param.outputs[key]["value"] = value
  375. def error(self):
  376. return self._param.outputs.get("_ERROR", {}).get("value")
  377. def reset(self):
  378. for k in self._param.outputs.keys():
  379. self._param.outputs[k]["value"] = None
  380. for k in self._param.inputs.keys():
  381. self._param.inputs[k]["value"] = None
  382. self._param.debug_inputs = {}
  383. def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
  384. if key:
  385. return self._param.inputs.get(key, {}).get("value")
  386. res = {}
  387. for var, o in self.get_input_elements().items():
  388. v = self.get_param(var)
  389. if v is None:
  390. continue
  391. if isinstance(v, str) and self._canvas.is_reff(v):
  392. self.set_input_value(var, self._canvas.get_variable_value(v))
  393. else:
  394. self.set_input_value(var, v)
  395. res[var] = self.get_input_value(var)
  396. return res
  397. def get_input_values(self) -> Union[Any, dict[str, Any]]:
  398. if self._param.debug_inputs:
  399. return self._param.debug_inputs
  400. return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()}
  401. def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
  402. res = {}
  403. for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE):
  404. exp = r.group(1)
  405. cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
  406. res[exp] = {
  407. "name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp,
  408. "value": self._canvas.get_variable_value(exp),
  409. "_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
  410. "_cpn_id": cpn_id
  411. }
  412. return res
  413. def get_input_elements(self) -> dict[str, Any]:
  414. return self._param.inputs
  415. def get_input_form(self) -> dict[str, dict]:
  416. return self._param.get_input_form()
  417. def set_input_value(self, key: str, value: Any) -> None:
  418. if key not in self._param.inputs:
  419. self._param.inputs[key] = {"value": None}
  420. self._param.inputs[key]["value"] = value
  421. def get_input_value(self, key: str) -> Any:
  422. if key not in self._param.inputs:
  423. return None
  424. return self._param.inputs[key].get("value")
  425. def get_component_name(self, cpn_id) -> str:
  426. return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
  427. def get_param(self, name):
  428. if hasattr(self._param, name):
  429. return getattr(self._param, name)
  430. def debug(self, **kwargs):
  431. return self._invoke(**kwargs)
  432. def get_parent(self) -> Union[object, None]:
  433. pid = self._canvas.get_component(self._id).get("parent_id")
  434. if not pid:
  435. return
  436. return self._canvas.get_component(pid)["obj"]
  437. def get_upstream(self) -> List[str]:
  438. cpn_nms = self._canvas.get_component(self._id)['upstream']
  439. return cpn_nms
  440. @staticmethod
  441. def string_format(content: str, kv: dict[str, str]) -> str:
  442. for n, v in kv.items():
  443. def repl(_match, val=v):
  444. return str(val) if val is not None else ""
  445. content = re.sub(
  446. r"\{%s\}" % re.escape(n),
  447. repl,
  448. content
  449. )
  450. return content
  451. def exception_handler(self):
  452. if not self._param.exception_method:
  453. return
  454. return {
  455. "goto": self._param.exception_goto,
  456. "default_value": self._param.exception_default_value
  457. }
  458. def get_exception_default_value(self):
  459. if self._param.exception_method != "comment":
  460. return ""
  461. return self._param.exception_default_value
  462. def set_exception_default_value(self):
  463. self.set_output("result", self.get_exception_default_value())
  464. @abstractmethod
  465. def thoughts(self) -> str:
  466. ...