You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from abc import ABC
  17. import builtins
  18. import json
  19. import os
  20. import logging
  21. from functools import partial
  22. from typing import Tuple, Union
  23. import pandas as pd
  24. from agent import settings
  # Names of the bookkeeping attributes that ComponentParamBase stores on its
  # instances (or, for _DEPRECATED_PARAMS, on the class). as_dict() leaves
  # attributes with these names out of its result.
  _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
  _DEPRECATED_PARAMS = "_deprecated_params"
  _USER_FEEDED_PARAMS = "_user_feeded_params"
  _IS_RAW_CONF = "_is_raw_conf"


  class ComponentParamBase(ABC):
      """Base class for a component's parameter object.

      Parameters are plain instance attributes. The class offers helpers to
      export them (``as_dict``), load them from a config mapping (``update``)
      and validate individual values (the ``check_*`` static methods).
      """

      def __init__(self):
          # Name of the attribute under which the output value is stored.
          self.output_var_name = "output"
          self.message_history_window_size = 22
          self.query = []
          self.inputs = []

      def set_name(self, name: str):
          """Store ``name`` as ``_name`` and return ``self`` for chaining."""
          self._name = name
          return self

      def check(self):
          """Validate the parameters; subclasses are expected to override this.

          Raises:
              NotImplementedError: always, in this base implementation.
          """
          raise NotImplementedError("Parameter Object should be checked.")

      @classmethod
      def _get_or_init_deprecated_params_set(cls):
          """Return the class-level set of deprecated names, creating it if absent."""
          if not hasattr(cls, _DEPRECATED_PARAMS):
              setattr(cls, _DEPRECATED_PARAMS, set())
          return getattr(cls, _DEPRECATED_PARAMS)

      def _get_or_init_feeded_deprecated_params_set(self, conf=None):
          """Return the set of deprecated names that were fed to this instance.

          On first use the set starts empty, or is restored from
          ``conf[_FEEDED_DEPRECATED_PARAMS]`` when ``conf`` is given.
          """
          if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
              initial = set() if conf is None else set(conf[_FEEDED_DEPRECATED_PARAMS])
              setattr(self, _FEEDED_DEPRECATED_PARAMS, initial)
          return getattr(self, _FEEDED_DEPRECATED_PARAMS)

      def _get_or_init_user_feeded_params_set(self, conf=None):
          """Return the set of names that were fed by the user.

          On first use the set starts empty, or is restored from
          ``conf[_USER_FEEDED_PARAMS]`` when ``conf`` is given.
          """
          if not hasattr(self, _USER_FEEDED_PARAMS):
              initial = set() if conf is None else set(conf[_USER_FEEDED_PARAMS])
              setattr(self, _USER_FEEDED_PARAMS, initial)
          return getattr(self, _USER_FEEDED_PARAMS)

      def get_user_feeded(self):
          """Return the set of dotted parameter names recorded by ``update``."""
          return self._get_or_init_user_feeded_params_set()

      def get_feeded_deprecated_params(self):
          """Return the set of deprecated parameter names recorded by ``update``."""
          return self._get_or_init_feeded_deprecated_params_set()

      @property
      def _deprecated_params_set(self):
          """Mapping ``{name: True}`` for every fed deprecated parameter."""
          return {name: True for name in self.get_feeded_deprecated_params()}

      def __str__(self):
          """Return ``as_dict()`` serialised as JSON (non-ASCII kept as is)."""
          return json.dumps(self.as_dict(), ensure_ascii=False)

      def as_dict(self):
          """Export the instance attributes as a (possibly nested) dict.

          Bookkeeping attributes are skipped, DataFrames are exported through
          ``DataFrame.to_dict()``, truthy values whose type name is not a
          builtin name are converted recursively through their ``__dict__``,
          and everything else is stored unchanged.
          """
          hidden = {
              _FEEDED_DEPRECATED_PARAMS,
              _DEPRECATED_PARAMS,
              _USER_FEEDED_PARAMS,
              _IS_RAW_CONF,
          }
          # Computed once per call instead of once per attribute.
          builtin_names = set(dir(builtins))

          def _recursive_convert_obj_to_dict(obj):
              ret_dict = {}
              # Iterate over a snapshot of the attribute names.
              for attr_name in list(obj.__dict__):
                  if attr_name in hidden:
                      continue
                  attr = getattr(obj, attr_name)
                  # Handled before the truthiness test below, which a
                  # DataFrame does not support.
                  if isinstance(attr, pd.DataFrame):
                      ret_dict[attr_name] = attr.to_dict()
                  elif attr and type(attr).__name__ not in builtin_names:
                      ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                  else:
                      ret_dict[attr_name] = attr
              return ret_dict

          return _recursive_convert_obj_to_dict(self)

      def update(self, conf, allow_redundant=False):
          """Load values from the mapping ``conf`` into this object.

          ``conf`` counts as "raw" unless it carries ``_is_raw_conf`` set to a
          false value. For a raw conf, every key that matches an existing
          attribute is recorded (dotted for nested objects) in the user-fed
          set, and also in the fed-deprecated set when the class lists it as
          deprecated. For a non-raw conf, both sets are restored from ``conf``.

          Keys with no matching attribute are set as new attributes. Matching
          attributes whose current value is ``None`` or has a builtin type name
          are overwritten; any other value is updated recursively.

          Args:
              conf: mapping of parameter names to values.
              allow_redundant: kept for interface compatibility. The redundancy
                  check it used to control is disabled in the original code
                  (unknown keys are never collected), so it has no effect.

          Returns:
              This object.

          Raises:
              ValueError: if nesting goes deeper than ``settings.PARAM_MAXDEPTH``.
          """
          update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
          if update_from_raw_conf:
              deprecated_params_set = self._get_or_init_deprecated_params_set()
              feeded_deprecated_params_set = (
                  self._get_or_init_feeded_deprecated_params_set()
              )
              user_feeded_params_set = self._get_or_init_user_feeded_params_set()
              setattr(self, _IS_RAW_CONF, False)
          else:
              feeded_deprecated_params_set = (
                  self._get_or_init_feeded_deprecated_params_set(conf)
              )
              user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

          builtin_names = set(dir(builtins))

          def _recursive_update_param(param, config, depth, prefix):
              if depth > settings.PARAM_MAXDEPTH:
                  raise ValueError("Param define nesting too deep!!!, can not parse it")

              inst_variables = param.__dict__
              for config_key, config_value in config.items():
                  if config_key not in inst_variables:
                      # Unknown key: stored as a new attribute. The original
                      # had two branches here that did exactly the same thing.
                      setattr(param, config_key, config_value)
                      continue

                  full_config_key = f"{prefix}{config_key}"
                  if update_from_raw_conf:
                      user_feeded_params_set.add(full_config_key)
                      if full_config_key in deprecated_params_set:
                          feeded_deprecated_params_set.add(full_config_key)

                  attr = getattr(param, config_key)
                  if type(attr).__name__ in builtin_names or attr is None:
                      setattr(param, config_key, config_value)
                  else:
                      # Nested parameter object: update it in place.
                      sub_params = _recursive_update_param(
                          attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                      )
                      setattr(param, config_key, sub_params)
              return param

          return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

      def extract_not_builtin(self):
          """Return a nested dict of the truthy attributes with non-builtin type names.

          Each such attribute maps to the same extraction applied to it, so
          leaves are empty dicts.
          """
          builtin_names = set(dir(builtins))

          def _get_not_builtin_types(obj):
              ret_dict = {}
              for variable in obj.__dict__:
                  attr = getattr(obj, variable)
                  if attr and type(attr).__name__ not in builtin_names:
                      ret_dict[variable] = _get_not_builtin_types(attr)
              return ret_dict

          return _get_not_builtin_types(self)

      def _validation_funcs(self):
          """Map each rule name usable in a validation file to its checker."""
          return {
              "ge": self._greater_equal_than,
              "le": self._less_equal_than,
              "in": self._in,
              "not_in": self._not_in,
              "range": self._range,
          }

      def validate(self):
          """Validate against ``param_validation/<ClassName>.json`` beside this module.

          If the rule file cannot be opened or parsed, validation is skipped
          silently. The lookup tables used during validation are no longer
          stored on the instance: the original set ``self.func`` and
          ``self.builtin_types``, which then leaked into ``as_dict()`` and made
          ``str(self)`` fail because bound methods are not JSON-serialisable.

          Raises:
              ValueError: if a value violates every rule listed for it.
          """
          home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
          param_validation_path = os.path.join(
              home_dir, "param_validation", type(self).__name__ + ".json"
          )
          try:
              with open(param_validation_path, "r", encoding="utf-8") as fin:
                  validation_json = json.loads(fin.read())
          except (OSError, ValueError):
              # Missing, unreadable or malformed rule file: nothing to enforce.
              # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
              return
          self._validate_param(self, validation_json)

      def _validate_param(self, param_obj, validation_json):
          """Check the attributes of ``param_obj`` against ``validation_json``.

          A value is accepted as soon as one of the rules listed for it holds.

          Raises:
              ValueError: if none of the rules for a value holds.
          """
          builtin_names = set(dir(builtins))
          checkers = self._validation_funcs()
          default_section = type(param_obj).__name__

          for variable in param_obj.__dict__:
              attr = getattr(param_obj, variable)
              if type(attr).__name__ in builtin_names or attr is None:
                  # NOTE(review): presence is tested against the top level of
                  # the rule file, but the rules are read from the section
                  # named after the class. Kept as in the original; confirm
                  # the intended schema before changing it.
                  if variable not in validation_json:
                      continue
                  validation_dict = validation_json[default_section][variable]
                  value_legal = any(
                      checkers[op_type](attr, validation_dict[op_type])
                      for op_type in validation_dict
                  )
                  if not value_legal:
                      raise ValueError(
                          "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
                              variable, attr
                          )
                      )
              elif variable in validation_json:
                  self._validate_param(attr, validation_json)

      # The check_* helpers compare type *names*, so subclasses of the listed
      # types (for example ``bool`` where ``int`` is expected) are rejected.

      @staticmethod
      def check_string(param, descr):
          """Raise ``ValueError`` unless ``param`` is exactly a ``str``."""
          if type(param).__name__ != "str":
              raise ValueError(
                  descr + " {} not supported, should be string type".format(param)
              )

      @staticmethod
      def check_empty(param, descr):
          """Raise ``ValueError`` if ``param`` is falsy."""
          if not param:
              raise ValueError(
                  descr + " does not support empty value."
              )

      @staticmethod
      def check_positive_integer(param, descr):
          """Raise ``ValueError`` unless ``param`` is an integer greater than zero."""
          if type(param).__name__ not in ["int", "long"] or param <= 0:
              raise ValueError(
                  descr + " {} not supported, should be positive integer".format(param)
              )

      @staticmethod
      def check_positive_number(param, descr):
          """Raise ``ValueError`` unless ``param`` is an int or float greater than zero."""
          if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
              raise ValueError(
                  descr + " {} not supported, should be positive numeric".format(param)
              )

      @staticmethod
      def check_nonnegative_number(param, descr):
          """Raise ``ValueError`` unless ``param`` is an int or float >= 0."""
          if type(param).__name__ not in ["float", "int", "long"] or param < 0:
              raise ValueError(
                  descr
                  + " {} not supported, should be non-negative numeric".format(param)
              )

      @staticmethod
      def check_decimal_float(param, descr):
          """Raise ``ValueError`` unless ``param`` is an int or float in [0, 1]."""
          if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
              raise ValueError(
                  descr
                  + " {} not supported, should be a float number in range [0, 1]".format(
                      param
                  )
              )

      @staticmethod
      def check_boolean(param, descr):
          """Raise ``ValueError`` unless ``param`` is exactly a ``bool``."""
          if type(param).__name__ != "bool":
              raise ValueError(
                  descr + " {} not supported, should be bool type".format(param)
              )

      @staticmethod
      def check_open_unit_interval(param, descr):
          """Raise ``ValueError`` unless ``param`` is a float strictly between 0 and 1."""
          if type(param).__name__ != "float" or param <= 0 or param >= 1:
              raise ValueError(
                  descr + " should be a numeric number between 0 and 1 exclusively"
              )

      @staticmethod
      def check_valid_value(param, descr, valid_values):
          """Raise ``ValueError`` unless ``param`` is contained in ``valid_values``."""
          if param not in valid_values:
              raise ValueError(
                  descr
                  + " {} is not supported, it should be in {}".format(param, valid_values)
              )

      @staticmethod
      def check_defined_type(param, descr, types):
          """Raise ``ValueError`` unless the type name of ``param`` is in ``types``."""
          if type(param).__name__ not in types:
              raise ValueError(
                  descr + " {} not supported, should be one of {}".format(param, types)
              )

      @staticmethod
      def check_and_change_lower(param, valid_list, descr=""):
          """Return ``param`` lower-cased if that value is in ``valid_list``.

          Raises:
              ValueError: if ``param`` is not a ``str`` or its lower-cased form
                  is not in ``valid_list``.
          """
          if type(param).__name__ == "str":
              lower_param = param.lower()
              if lower_param in valid_list:
                  return lower_param
          raise ValueError(
              descr
              + " {} not supported, should be one of {}".format(param, valid_list)
          )

      # Rule checkers used by _validate_param. The numeric ones widen the
      # limits by settings.FLOAT_ZERO.

      @staticmethod
      def _greater_equal_than(value, limit):
          return value >= limit - settings.FLOAT_ZERO

      @staticmethod
      def _less_equal_than(value, limit):
          return value <= limit + settings.FLOAT_ZERO

      @staticmethod
      def _range(value, ranges):
          """Return True if ``value`` lies in any ``(left, right)`` pair of ``ranges``."""
          return any(
              left_limit - settings.FLOAT_ZERO
              <= value
              <= right_limit + settings.FLOAT_ZERO
              for left_limit, right_limit in ranges
          )

      @staticmethod
      def _in(value, right_value_list):
          return value in right_value_list

      @staticmethod
      def _not_in(value, wrong_value_list):
          return value not in wrong_value_list

      def _warn_deprecated_param(self, param_name, descr):
          """Log a warning if ``param_name`` is a fed deprecated parameter."""
          if self._deprecated_params_set.get(param_name):
              logging.warning(
                  f"{descr} {param_name} is deprecated and ignored in this version."
              )

      def _warn_to_deprecate_param(self, param_name, descr, new_param):
          """Log a warning if ``param_name`` is a fed deprecated parameter.

          Returns:
              True if the warning was emitted, False otherwise.
          """
          if self._deprecated_params_set.get(param_name):
              logging.warning(
                  f"{descr} {param_name} will be deprecated in future release; "
                  f"please use {new_param} instead."
              )
              return True
          return False
  class ComponentBase(ABC):
      """Base class for a component that runs on a canvas.

      A component wraps a parameter object, looks up other components through
      its canvas, and stores its result on the parameter object under the
      attribute named by ``param.output_var_name``.
      """

      component_name: str

      def __str__(self):
          """Return a JSON document describing the component.

          Shape::

              {
                  "component_name": "Begin",
                  "params": {},
                  "output": {},
                  "inputs": []
              }

          ``params`` is the parameter object's own string form. ``output`` is
          its ``output`` entry with any ``vector`` key removed.
          """
          param_text = str(self._param)
          # Parsed once; the original parsed the same text twice.
          param_json = json.loads(param_text)

          out = param_json.get("output", {})
          if isinstance(out, dict) and "vector" in out:
              del out["vector"]

          template = (
              "{{\n"
              '"component_name": "{}",\n'
              '"params": {},\n'
              '"output": {},\n'
              '"inputs": {}\n'
              "}}"
          )
          return template.format(
              self.component_name,
              param_text,
              json.dumps(out, ensure_ascii=False),
              json.dumps(param_json.get("inputs", []), ensure_ascii=False),
          )

      def __init__(self, canvas, id, param: "ComponentParamBase"):
          """Bind the component to ``canvas`` under ``id`` and check ``param``.

          ``param.check()`` is called immediately, so whatever it raises
          propagates out of the constructor.
          """
          self._canvas = canvas
          self._id = id
          self._param = param
          self._param.check()

      def get_dependent_components(self):
          """Return the ids of the components referenced by ``param.query``.

          The part of each ``component_id`` before ``@`` is used. References
          whose id contains "answer" or "begin" (case-insensitive) are left
          out. The order of the returned list is not defined.
          """
          cpnts = {
              para["component_id"].split("@")[0]
              for para in self._param.query
              if para.get("component_id")
              and "answer" not in para["component_id"].lower()
              and "begin" not in para["component_id"].lower()
          }
          return list(cpnts)

      def run(self, history, **kwargs):
          """Execute ``_run`` and store its result as the component output.

          If ``_run`` (or storing its result) raises, the output is replaced by
          a one-row DataFrame holding the error text and the exception is
          re-raised.
          """
          # Serialise the arguments only when the message will be emitted, so
          # a value json cannot encode does not abort the run while debug
          # logging is off.
          if logging.getLogger().isEnabledFor(logging.DEBUG):
              logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
                                                                 json.dumps(kwargs, ensure_ascii=False)))
          try:
              res = self._run(history, **kwargs)
              self.set_output(res)
          except Exception as e:
              self.set_output(pd.DataFrame([{"content": str(e)}]))
              raise

          return res

      def _run(self, history, **kwargs):
          """Do the component's actual work; subclasses must override this."""
          raise NotImplementedError()

      def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
          """Return ``(output_var_name, output)``.

          A stored value that is neither a ``partial`` nor a DataFrame is
          wrapped into a DataFrame for the return value (the stored value is
          left untouched). With ``allow_partial=False`` a stored ``partial`` is
          called and iterated to the end, and the last item is returned as a
          DataFrame (``None`` if nothing was yielded).
          """
          o = getattr(self._param, self._param.output_var_name)
          if not isinstance(o, (partial, pd.DataFrame)):
              if not isinstance(o, list):
                  o = [o]
              o = pd.DataFrame(o)

          # From here on ``o`` is a partial or a DataFrame; the original had a
          # further re-wrapping branch at this point that could never run.
          if allow_partial or not isinstance(o, partial):
              return self._param.output_var_name, o

          outs = None
          for oo in o():
              if not isinstance(oo, pd.DataFrame):
                  outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
              else:
                  outs = oo
          return self._param.output_var_name, outs

      def reset(self):
          """Clear the stored output and the recorded inputs."""
          setattr(self._param, self._param.output_var_name, None)
          self._param.inputs = []

      def set_output(self, v: Union[partial, pd.DataFrame]):
          """Store ``v`` as the component output."""
          setattr(self._param, self._param.output_var_name, v)

      def get_input(self):
          """Collect this component's input as a DataFrame.

          If ``param.query`` yields any input, that is used: each item either
          references another component's output (``component_id``; ids whose
          first part contains "begin" read a keyed value from that component's
          own query) or carries a literal ``value``. Otherwise the input is
          inferred from the components on the last two entries of
          ``canvas.path``. In both cases rows are de-duplicated on ``content``
          and ``param.inputs`` is rewritten to describe what was used.

          Raises:
              AssertionError: if a referenced "begin" parameter does not exist,
                  or if no input can be inferred. Raised explicitly so the
                  checks also hold under ``python -O``.
          """
          reversed_cpnts = []
          if len(self._canvas.path) > 1:
              reversed_cpnts.extend(self._canvas.path[-2])
          reversed_cpnts.extend(self._canvas.path[-1])

          if self._param.query:
              self._param.inputs = []
              outs = []
              for q in self._param.query:
                  # .get() tolerates items that carry only one of the two keys.
                  ref = q.get("component_id")
                  if ref:
                      if "begin" in ref.split("@")[0].lower():
                          cpn_id, key = ref.split("@")
                          for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
                              if p["key"] == key:
                                  outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
                                  self._param.inputs.append({"component_id": ref,
                                                             "content": p.get("value", "")})
                                  break
                          else:
                              raise AssertionError(f"Can't find parameter '{key}' for {cpn_id}")
                          continue

                      outs.append(self._canvas.get_component(ref)["obj"].output(allow_partial=False)[1])
                      self._param.inputs.append({"component_id": ref,
                                                 "content": "\n".join(
                                                     [str(d["content"]) for d in outs[-1].to_dict('records')])})
                  elif q.get("value"):
                      self._param.inputs.append({"component_id": None, "content": q["value"]})
                      outs.append(pd.DataFrame([{"content": q["value"]}]))
              if outs:
                  df = pd.concat(outs, ignore_index=True)
                  if "content" in df:
                      df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
                  return df
              # A query that produced nothing falls through to inference.

          upstream_outs = []

          # Walk the collected path components from the newest backwards.
          for u in reversed_cpnts[::-1]:
              if self.get_component_name(u) in ["switch", "concentrator"]:
                  continue
              if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                  o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                  if o is not None:
                      # NOTE(review): when the upstream output is stored as a
                      # DataFrame this is that same object, so the column is
                      # added to the upstream component's output as well.
                      o["component_id"] = u
                      upstream_outs.append(o)
                      continue
              if self.component_name.lower().find("switch") < 0 \
                      and self.get_component_name(u) in ["relevant", "categorize"]:
                  continue
              if u.lower().find("answer") >= 0:
                  # Use the most recent user message from the history.
                  for r, c in self._canvas.history[::-1]:
                      if r == "user":
                          upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
                          break
                  break
              if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
                  continue
              o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
              if o is not None:
                  o["component_id"] = u
                  upstream_outs.append(o)
                  break

          if not upstream_outs:
              raise AssertionError(
                  "Can't inference the where the component input is. Please identify whose output is this component's input."
              )

          df = pd.concat(upstream_outs, ignore_index=True)
          if "content" in df:
              df = df.drop_duplicates(subset=['content']).reset_index(drop=True)

          self._param.inputs = []
          for _, r in df.iterrows():
              self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})

          return df

      def get_stream_input(self):
          """Return the output of the newest path component that is not a switch or answer.

          Components are taken from the last two entries of ``canvas.path``.
          The output is returned as stored (it may be a ``partial``). Returns
          ``None`` when no such component exists.
          """
          reversed_cpnts = []
          if len(self._canvas.path) > 1:
              reversed_cpnts.extend(self._canvas.path[-2])
          reversed_cpnts.extend(self._canvas.path[-1])

          for u in reversed_cpnts[::-1]:
              if self.get_component_name(u) in ["switch", "answer"]:
                  continue
              return self._canvas.get_component(u)["obj"].output()[1]

      @staticmethod
      def be_output(v):
          """Wrap ``v`` into a one-row DataFrame with a ``content`` column."""
          return pd.DataFrame([{"content": v}])

      def get_component_name(self, cpn_id):
          """Return the lower-cased ``component_name`` of the component ``cpn_id``."""
          return self._canvas.get_component(cpn_id)["obj"].component_name.lower()