You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from abc import ABC
  17. import builtins
  18. import json
  19. import os
  20. from copy import deepcopy
  21. from functools import partial
  22. from typing import List, Dict, Tuple, Union
  23. import pandas as pd
  24. from agent import settings
  25. from agent.settings import flow_logger, DEBUG
  26. _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
  27. _DEPRECATED_PARAMS = "_deprecated_params"
  28. _USER_FEEDED_PARAMS = "_user_feeded_params"
  29. _IS_RAW_CONF = "_is_raw_conf"
  30. class ComponentParamBase(ABC):
  31. def __init__(self):
  32. self.output_var_name = "output"
  33. self.message_history_window_size = 22
  34. def set_name(self, name: str):
  35. self._name = name
  36. return self
  37. def check(self):
  38. raise NotImplementedError("Parameter Object should be checked.")
  39. @classmethod
  40. def _get_or_init_deprecated_params_set(cls):
  41. if not hasattr(cls, _DEPRECATED_PARAMS):
  42. setattr(cls, _DEPRECATED_PARAMS, set())
  43. return getattr(cls, _DEPRECATED_PARAMS)
  44. def _get_or_init_feeded_deprecated_params_set(self, conf=None):
  45. if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
  46. if conf is None:
  47. setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
  48. else:
  49. setattr(
  50. self,
  51. _FEEDED_DEPRECATED_PARAMS,
  52. set(conf[_FEEDED_DEPRECATED_PARAMS]),
  53. )
  54. return getattr(self, _FEEDED_DEPRECATED_PARAMS)
  55. def _get_or_init_user_feeded_params_set(self, conf=None):
  56. if not hasattr(self, _USER_FEEDED_PARAMS):
  57. if conf is None:
  58. setattr(self, _USER_FEEDED_PARAMS, set())
  59. else:
  60. setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
  61. return getattr(self, _USER_FEEDED_PARAMS)
  62. def get_user_feeded(self):
  63. return self._get_or_init_user_feeded_params_set()
  64. def get_feeded_deprecated_params(self):
  65. return self._get_or_init_feeded_deprecated_params_set()
  66. @property
  67. def _deprecated_params_set(self):
  68. return {name: True for name in self.get_feeded_deprecated_params()}
  69. def __str__(self):
  70. return json.dumps(self.as_dict(), ensure_ascii=False)
  71. def as_dict(self):
  72. def _recursive_convert_obj_to_dict(obj):
  73. ret_dict = {}
  74. for attr_name in list(obj.__dict__):
  75. if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
  76. continue
  77. # get attr
  78. attr = getattr(obj, attr_name)
  79. if isinstance(attr, pd.DataFrame):
  80. ret_dict[attr_name] = attr.to_dict()
  81. continue
  82. if attr and type(attr).__name__ not in dir(builtins):
  83. ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
  84. else:
  85. ret_dict[attr_name] = attr
  86. return ret_dict
  87. return _recursive_convert_obj_to_dict(self)
  88. def update(self, conf, allow_redundant=False):
  89. update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
  90. if update_from_raw_conf:
  91. deprecated_params_set = self._get_or_init_deprecated_params_set()
  92. feeded_deprecated_params_set = (
  93. self._get_or_init_feeded_deprecated_params_set()
  94. )
  95. user_feeded_params_set = self._get_or_init_user_feeded_params_set()
  96. setattr(self, _IS_RAW_CONF, False)
  97. else:
  98. feeded_deprecated_params_set = (
  99. self._get_or_init_feeded_deprecated_params_set(conf)
  100. )
  101. user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
  102. def _recursive_update_param(param, config, depth, prefix):
  103. if depth > settings.PARAM_MAXDEPTH:
  104. raise ValueError("Param define nesting too deep!!!, can not parse it")
  105. inst_variables = param.__dict__
  106. redundant_attrs = []
  107. for config_key, config_value in config.items():
  108. # redundant attr
  109. if config_key not in inst_variables:
  110. if not update_from_raw_conf and config_key.startswith("_"):
  111. setattr(param, config_key, config_value)
  112. else:
  113. setattr(param, config_key, config_value)
  114. # redundant_attrs.append(config_key)
  115. continue
  116. full_config_key = f"{prefix}{config_key}"
  117. if update_from_raw_conf:
  118. # add user feeded params
  119. user_feeded_params_set.add(full_config_key)
  120. # update user feeded deprecated param set
  121. if full_config_key in deprecated_params_set:
  122. feeded_deprecated_params_set.add(full_config_key)
  123. # supported attr
  124. attr = getattr(param, config_key)
  125. if type(attr).__name__ in dir(builtins) or attr is None:
  126. setattr(param, config_key, config_value)
  127. else:
  128. # recursive set obj attr
  129. sub_params = _recursive_update_param(
  130. attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
  131. )
  132. setattr(param, config_key, sub_params)
  133. if not allow_redundant and redundant_attrs:
  134. raise ValueError(
  135. f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
  136. )
  137. return param
  138. return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
  139. def extract_not_builtin(self):
  140. def _get_not_builtin_types(obj):
  141. ret_dict = {}
  142. for variable in obj.__dict__:
  143. attr = getattr(obj, variable)
  144. if attr and type(attr).__name__ not in dir(builtins):
  145. ret_dict[variable] = _get_not_builtin_types(attr)
  146. return ret_dict
  147. return _get_not_builtin_types(self)
  148. def validate(self):
  149. self.builtin_types = dir(builtins)
  150. self.func = {
  151. "ge": self._greater_equal_than,
  152. "le": self._less_equal_than,
  153. "in": self._in,
  154. "not_in": self._not_in,
  155. "range": self._range,
  156. }
  157. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  158. param_validation_path_prefix = home_dir + "/param_validation/"
  159. param_name = type(self).__name__
  160. param_validation_path = "/".join(
  161. [param_validation_path_prefix, param_name + ".json"]
  162. )
  163. validation_json = None
  164. try:
  165. with open(param_validation_path, "r") as fin:
  166. validation_json = json.loads(fin.read())
  167. except BaseException:
  168. return
  169. self._validate_param(self, validation_json)
  170. def _validate_param(self, param_obj, validation_json):
  171. default_section = type(param_obj).__name__
  172. var_list = param_obj.__dict__
  173. for variable in var_list:
  174. attr = getattr(param_obj, variable)
  175. if type(attr).__name__ in self.builtin_types or attr is None:
  176. if variable not in validation_json:
  177. continue
  178. validation_dict = validation_json[default_section][variable]
  179. value = getattr(param_obj, variable)
  180. value_legal = False
  181. for op_type in validation_dict:
  182. if self.func[op_type](value, validation_dict[op_type]):
  183. value_legal = True
  184. break
  185. if not value_legal:
  186. raise ValueError(
  187. "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
  188. variable, value
  189. )
  190. )
  191. elif variable in validation_json:
  192. self._validate_param(attr, validation_json)
  193. @staticmethod
  194. def check_string(param, descr):
  195. if type(param).__name__ not in ["str"]:
  196. raise ValueError(
  197. descr + " {} not supported, should be string type".format(param)
  198. )
  199. @staticmethod
  200. def check_empty(param, descr):
  201. if not param:
  202. raise ValueError(
  203. descr + " does not support empty value."
  204. )
  205. @staticmethod
  206. def check_positive_integer(param, descr):
  207. if type(param).__name__ not in ["int", "long"] or param <= 0:
  208. raise ValueError(
  209. descr + " {} not supported, should be positive integer".format(param)
  210. )
  211. @staticmethod
  212. def check_positive_number(param, descr):
  213. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  214. raise ValueError(
  215. descr + " {} not supported, should be positive numeric".format(param)
  216. )
  217. @staticmethod
  218. def check_nonnegative_number(param, descr):
  219. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  220. raise ValueError(
  221. descr
  222. + " {} not supported, should be non-negative numeric".format(param)
  223. )
  224. @staticmethod
  225. def check_decimal_float(param, descr):
  226. if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
  227. raise ValueError(
  228. descr
  229. + " {} not supported, should be a float number in range [0, 1]".format(
  230. param
  231. )
  232. )
  233. @staticmethod
  234. def check_boolean(param, descr):
  235. if type(param).__name__ != "bool":
  236. raise ValueError(
  237. descr + " {} not supported, should be bool type".format(param)
  238. )
  239. @staticmethod
  240. def check_open_unit_interval(param, descr):
  241. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  242. raise ValueError(
  243. descr + " should be a numeric number between 0 and 1 exclusively"
  244. )
  245. @staticmethod
  246. def check_valid_value(param, descr, valid_values):
  247. if param not in valid_values:
  248. raise ValueError(
  249. descr
  250. + " {} is not supported, it should be in {}".format(param, valid_values)
  251. )
  252. @staticmethod
  253. def check_defined_type(param, descr, types):
  254. if type(param).__name__ not in types:
  255. raise ValueError(
  256. descr + " {} not supported, should be one of {}".format(param, types)
  257. )
  258. @staticmethod
  259. def check_and_change_lower(param, valid_list, descr=""):
  260. if type(param).__name__ != "str":
  261. raise ValueError(
  262. descr
  263. + " {} not supported, should be one of {}".format(param, valid_list)
  264. )
  265. lower_param = param.lower()
  266. if lower_param in valid_list:
  267. return lower_param
  268. else:
  269. raise ValueError(
  270. descr
  271. + " {} not supported, should be one of {}".format(param, valid_list)
  272. )
  273. @staticmethod
  274. def _greater_equal_than(value, limit):
  275. return value >= limit - settings.FLOAT_ZERO
  276. @staticmethod
  277. def _less_equal_than(value, limit):
  278. return value <= limit + settings.FLOAT_ZERO
  279. @staticmethod
  280. def _range(value, ranges):
  281. in_range = False
  282. for left_limit, right_limit in ranges:
  283. if (
  284. left_limit - settings.FLOAT_ZERO
  285. <= value
  286. <= right_limit + settings.FLOAT_ZERO
  287. ):
  288. in_range = True
  289. break
  290. return in_range
  291. @staticmethod
  292. def _in(value, right_value_list):
  293. return value in right_value_list
  294. @staticmethod
  295. def _not_in(value, wrong_value_list):
  296. return value not in wrong_value_list
  297. def _warn_deprecated_param(self, param_name, descr):
  298. if self._deprecated_params_set.get(param_name):
  299. flow_logger.warning(
  300. f"{descr} {param_name} is deprecated and ignored in this version."
  301. )
  302. def _warn_to_deprecate_param(self, param_name, descr, new_param):
  303. if self._deprecated_params_set.get(param_name):
  304. flow_logger.warning(
  305. f"{descr} {param_name} will be deprecated in future release; "
  306. f"please use {new_param} instead."
  307. )
  308. return True
  309. return False
  310. class ComponentBase(ABC):
  311. component_name: str
  312. def __str__(self):
  313. """
  314. {
  315. "component_name": "Begin",
  316. "params": {}
  317. }
  318. """
  319. return """{{
  320. "component_name": "{}",
  321. "params": {}
  322. }}""".format(self.component_name,
  323. self._param
  324. )
  325. def __init__(self, canvas, id, param: ComponentParamBase):
  326. self._canvas = canvas
  327. self._id = id
  328. self._param = param
  329. self._param.check()
  330. def run(self, history, **kwargs):
  331. flow_logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
  332. json.dumps(kwargs, ensure_ascii=False)))
  333. try:
  334. res = self._run(history, **kwargs)
  335. self.set_output(res)
  336. except Exception as e:
  337. self.set_output(pd.DataFrame([{"content": str(e)}]))
  338. raise e
  339. return res
  340. def _run(self, history, **kwargs):
  341. raise NotImplementedError()
  342. def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
  343. o = getattr(self._param, self._param.output_var_name)
  344. if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
  345. if not isinstance(o, list): o = [o]
  346. o = pd.DataFrame(o)
  347. if allow_partial or not isinstance(o, partial):
  348. if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
  349. return pd.DataFrame(o if isinstance(o, list) else [o])
  350. return self._param.output_var_name, o
  351. outs = None
  352. for oo in o():
  353. if not isinstance(oo, pd.DataFrame):
  354. outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
  355. else: outs = oo
  356. return self._param.output_var_name, outs
  357. def reset(self):
  358. setattr(self._param, self._param.output_var_name, None)
  359. def set_output(self, v: pd.DataFrame):
  360. setattr(self._param, self._param.output_var_name, v)
  361. def get_input(self):
  362. upstream_outs = []
  363. reversed_cpnts = []
  364. if len(self._canvas.path) > 1:
  365. reversed_cpnts.extend(self._canvas.path[-2])
  366. reversed_cpnts.extend(self._canvas.path[-1])
  367. if DEBUG: print(self.component_name, reversed_cpnts[::-1])
  368. for u in reversed_cpnts[::-1]:
  369. if self.get_component_name(u) in ["switch"]: continue
  370. if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
  371. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  372. if o is not None:
  373. upstream_outs.append(o)
  374. continue
  375. if u not in self._canvas.get_component(self._id)["upstream"]: continue
  376. if self.component_name.lower().find("switch") < 0 \
  377. and self.get_component_name(u) in ["relevant", "categorize"]:
  378. continue
  379. if u.lower().find("answer") >= 0:
  380. for r, c in self._canvas.history[::-1]:
  381. if r == "user":
  382. upstream_outs.append(pd.DataFrame([{"content": c}]))
  383. break
  384. break
  385. if self.component_name.lower().find("answer") >= 0:
  386. if self.get_component_name(u) in ["relevant"]:
  387. continue
  388. else:
  389. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  390. if o is not None:
  391. upstream_outs.append(o)
  392. break
  393. if upstream_outs:
  394. df = pd.concat(upstream_outs, ignore_index=True)
  395. if "content" in df:
  396. df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
  397. return df
  398. return pd.DataFrame()
  399. def get_stream_input(self):
  400. reversed_cpnts = []
  401. if len(self._canvas.path) > 1:
  402. reversed_cpnts.extend(self._canvas.path[-2])
  403. reversed_cpnts.extend(self._canvas.path[-1])
  404. for u in reversed_cpnts[::-1]:
  405. if self.get_component_name(u) in ["switch", "answer"]: continue
  406. return self._canvas.get_component(u)["obj"].output()[1]
  407. @staticmethod
  408. def be_output(v):
  409. return pd.DataFrame([{"content": v}])
  410. def get_component_name(self, cpn_id):
  411. return self._canvas.get_component(cpn_id)["obj"].component_name.lower()