You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

base.py 23KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from abc import ABC
  17. import builtins
  18. import json
  19. import os
  20. import logging
  21. from functools import partial
  22. from typing import Any, Tuple, Union
  23. import pandas as pd
  24. from agent import settings
  25. _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
  26. _DEPRECATED_PARAMS = "_deprecated_params"
  27. _USER_FEEDED_PARAMS = "_user_feeded_params"
  28. _IS_RAW_CONF = "_is_raw_conf"
  29. class ComponentParamBase(ABC):
  30. def __init__(self):
  31. self.output_var_name = "output"
  32. self.infor_var_name = "infor"
  33. self.message_history_window_size = 22
  34. self.query = []
  35. self.inputs = []
  36. self.debug_inputs = []
  37. def set_name(self, name: str):
  38. self._name = name
  39. return self
  40. def check(self):
  41. raise NotImplementedError("Parameter Object should be checked.")
  42. @classmethod
  43. def _get_or_init_deprecated_params_set(cls):
  44. if not hasattr(cls, _DEPRECATED_PARAMS):
  45. setattr(cls, _DEPRECATED_PARAMS, set())
  46. return getattr(cls, _DEPRECATED_PARAMS)
  47. def _get_or_init_feeded_deprecated_params_set(self, conf=None):
  48. if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
  49. if conf is None:
  50. setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
  51. else:
  52. setattr(
  53. self,
  54. _FEEDED_DEPRECATED_PARAMS,
  55. set(conf[_FEEDED_DEPRECATED_PARAMS]),
  56. )
  57. return getattr(self, _FEEDED_DEPRECATED_PARAMS)
  58. def _get_or_init_user_feeded_params_set(self, conf=None):
  59. if not hasattr(self, _USER_FEEDED_PARAMS):
  60. if conf is None:
  61. setattr(self, _USER_FEEDED_PARAMS, set())
  62. else:
  63. setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
  64. return getattr(self, _USER_FEEDED_PARAMS)
  65. def get_user_feeded(self):
  66. return self._get_or_init_user_feeded_params_set()
  67. def get_feeded_deprecated_params(self):
  68. return self._get_or_init_feeded_deprecated_params_set()
  69. @property
  70. def _deprecated_params_set(self):
  71. return {name: True for name in self.get_feeded_deprecated_params()}
  72. def __str__(self):
  73. return json.dumps(self.as_dict(), ensure_ascii=False)
  74. def as_dict(self):
  75. def _recursive_convert_obj_to_dict(obj):
  76. ret_dict = {}
  77. for attr_name in list(obj.__dict__):
  78. if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
  79. continue
  80. # get attr
  81. attr = getattr(obj, attr_name)
  82. if isinstance(attr, pd.DataFrame):
  83. ret_dict[attr_name] = attr.to_dict()
  84. continue
  85. if attr and type(attr).__name__ not in dir(builtins):
  86. ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
  87. else:
  88. ret_dict[attr_name] = attr
  89. return ret_dict
  90. return _recursive_convert_obj_to_dict(self)
  91. def update(self, conf, allow_redundant=False):
  92. update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
  93. if update_from_raw_conf:
  94. deprecated_params_set = self._get_or_init_deprecated_params_set()
  95. feeded_deprecated_params_set = (
  96. self._get_or_init_feeded_deprecated_params_set()
  97. )
  98. user_feeded_params_set = self._get_or_init_user_feeded_params_set()
  99. setattr(self, _IS_RAW_CONF, False)
  100. else:
  101. feeded_deprecated_params_set = (
  102. self._get_or_init_feeded_deprecated_params_set(conf)
  103. )
  104. user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
  105. def _recursive_update_param(param, config, depth, prefix):
  106. if depth > settings.PARAM_MAXDEPTH:
  107. raise ValueError("Param define nesting too deep!!!, can not parse it")
  108. inst_variables = param.__dict__
  109. redundant_attrs = []
  110. for config_key, config_value in config.items():
  111. # redundant attr
  112. if config_key not in inst_variables:
  113. if not update_from_raw_conf and config_key.startswith("_"):
  114. setattr(param, config_key, config_value)
  115. else:
  116. setattr(param, config_key, config_value)
  117. # redundant_attrs.append(config_key)
  118. continue
  119. full_config_key = f"{prefix}{config_key}"
  120. if update_from_raw_conf:
  121. # add user feeded params
  122. user_feeded_params_set.add(full_config_key)
  123. # update user feeded deprecated param set
  124. if full_config_key in deprecated_params_set:
  125. feeded_deprecated_params_set.add(full_config_key)
  126. # supported attr
  127. attr = getattr(param, config_key)
  128. if type(attr).__name__ in dir(builtins) or attr is None:
  129. setattr(param, config_key, config_value)
  130. else:
  131. # recursive set obj attr
  132. sub_params = _recursive_update_param(
  133. attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
  134. )
  135. setattr(param, config_key, sub_params)
  136. if not allow_redundant and redundant_attrs:
  137. raise ValueError(
  138. f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
  139. )
  140. return param
  141. return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
  142. def extract_not_builtin(self):
  143. def _get_not_builtin_types(obj):
  144. ret_dict = {}
  145. for variable in obj.__dict__:
  146. attr = getattr(obj, variable)
  147. if attr and type(attr).__name__ not in dir(builtins):
  148. ret_dict[variable] = _get_not_builtin_types(attr)
  149. return ret_dict
  150. return _get_not_builtin_types(self)
  151. def validate(self):
  152. self.builtin_types = dir(builtins)
  153. self.func = {
  154. "ge": self._greater_equal_than,
  155. "le": self._less_equal_than,
  156. "in": self._in,
  157. "not_in": self._not_in,
  158. "range": self._range,
  159. }
  160. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  161. param_validation_path_prefix = home_dir + "/param_validation/"
  162. param_name = type(self).__name__
  163. param_validation_path = "/".join(
  164. [param_validation_path_prefix, param_name + ".json"]
  165. )
  166. validation_json = None
  167. try:
  168. with open(param_validation_path, "r") as fin:
  169. validation_json = json.loads(fin.read())
  170. except BaseException:
  171. return
  172. self._validate_param(self, validation_json)
  173. def _validate_param(self, param_obj, validation_json):
  174. default_section = type(param_obj).__name__
  175. var_list = param_obj.__dict__
  176. for variable in var_list:
  177. attr = getattr(param_obj, variable)
  178. if type(attr).__name__ in self.builtin_types or attr is None:
  179. if variable not in validation_json:
  180. continue
  181. validation_dict = validation_json[default_section][variable]
  182. value = getattr(param_obj, variable)
  183. value_legal = False
  184. for op_type in validation_dict:
  185. if self.func[op_type](value, validation_dict[op_type]):
  186. value_legal = True
  187. break
  188. if not value_legal:
  189. raise ValueError(
  190. "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
  191. variable, value
  192. )
  193. )
  194. elif variable in validation_json:
  195. self._validate_param(attr, validation_json)
  196. @staticmethod
  197. def check_string(param, descr):
  198. if type(param).__name__ not in ["str"]:
  199. raise ValueError(
  200. descr + " {} not supported, should be string type".format(param)
  201. )
  202. @staticmethod
  203. def check_empty(param, descr):
  204. if not param:
  205. raise ValueError(
  206. descr + " does not support empty value."
  207. )
  208. @staticmethod
  209. def check_positive_integer(param, descr):
  210. if type(param).__name__ not in ["int", "long"] or param <= 0:
  211. raise ValueError(
  212. descr + " {} not supported, should be positive integer".format(param)
  213. )
  214. @staticmethod
  215. def check_positive_number(param, descr):
  216. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  217. raise ValueError(
  218. descr + " {} not supported, should be positive numeric".format(param)
  219. )
  220. @staticmethod
  221. def check_nonnegative_number(param, descr):
  222. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  223. raise ValueError(
  224. descr
  225. + " {} not supported, should be non-negative numeric".format(param)
  226. )
  227. @staticmethod
  228. def check_decimal_float(param, descr):
  229. if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
  230. raise ValueError(
  231. descr
  232. + " {} not supported, should be a float number in range [0, 1]".format(
  233. param
  234. )
  235. )
  236. @staticmethod
  237. def check_boolean(param, descr):
  238. if type(param).__name__ != "bool":
  239. raise ValueError(
  240. descr + " {} not supported, should be bool type".format(param)
  241. )
  242. @staticmethod
  243. def check_open_unit_interval(param, descr):
  244. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  245. raise ValueError(
  246. descr + " should be a numeric number between 0 and 1 exclusively"
  247. )
  248. @staticmethod
  249. def check_valid_value(param, descr, valid_values):
  250. if param not in valid_values:
  251. raise ValueError(
  252. descr
  253. + " {} is not supported, it should be in {}".format(param, valid_values)
  254. )
  255. @staticmethod
  256. def check_defined_type(param, descr, types):
  257. if type(param).__name__ not in types:
  258. raise ValueError(
  259. descr + " {} not supported, should be one of {}".format(param, types)
  260. )
  261. @staticmethod
  262. def check_and_change_lower(param, valid_list, descr=""):
  263. if type(param).__name__ != "str":
  264. raise ValueError(
  265. descr
  266. + " {} not supported, should be one of {}".format(param, valid_list)
  267. )
  268. lower_param = param.lower()
  269. if lower_param in valid_list:
  270. return lower_param
  271. else:
  272. raise ValueError(
  273. descr
  274. + " {} not supported, should be one of {}".format(param, valid_list)
  275. )
  276. @staticmethod
  277. def _greater_equal_than(value, limit):
  278. return value >= limit - settings.FLOAT_ZERO
  279. @staticmethod
  280. def _less_equal_than(value, limit):
  281. return value <= limit + settings.FLOAT_ZERO
  282. @staticmethod
  283. def _range(value, ranges):
  284. in_range = False
  285. for left_limit, right_limit in ranges:
  286. if (
  287. left_limit - settings.FLOAT_ZERO
  288. <= value
  289. <= right_limit + settings.FLOAT_ZERO
  290. ):
  291. in_range = True
  292. break
  293. return in_range
  294. @staticmethod
  295. def _in(value, right_value_list):
  296. return value in right_value_list
  297. @staticmethod
  298. def _not_in(value, wrong_value_list):
  299. return value not in wrong_value_list
  300. def _warn_deprecated_param(self, param_name, descr):
  301. if self._deprecated_params_set.get(param_name):
  302. logging.warning(
  303. f"{descr} {param_name} is deprecated and ignored in this version."
  304. )
  305. def _warn_to_deprecate_param(self, param_name, descr, new_param):
  306. if self._deprecated_params_set.get(param_name):
  307. logging.warning(
  308. f"{descr} {param_name} will be deprecated in future release; "
  309. f"please use {new_param} instead."
  310. )
  311. return True
  312. return False
  313. class ComponentBase(ABC):
  314. component_name: str
  315. def __str__(self):
  316. """
  317. {
  318. "component_name": "Begin",
  319. "params": {}
  320. }
  321. """
  322. out = getattr(self._param, self._param.output_var_name)
  323. if isinstance(out, pd.DataFrame) and "chunks" in out:
  324. del out["chunks"]
  325. setattr(self._param, self._param.output_var_name, out)
  326. return """{{
  327. "component_name": "{}",
  328. "params": {},
  329. "output": {},
  330. "inputs": {}
  331. }}""".format(self.component_name,
  332. self._param,
  333. json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
  334. json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False)
  335. )
  336. def __init__(self, canvas, id, param: ComponentParamBase):
  337. from agent.canvas import Canvas # Local import to avoid cyclic dependency
  338. assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
  339. self._canvas = canvas
  340. self._id = id
  341. self._param = param
  342. self._param.check()
  343. def get_dependent_components(self):
  344. cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
  345. if para.get("component_id") \
  346. and para["component_id"].lower().find("answer") < 0 \
  347. and para["component_id"].lower().find("begin") < 0])
  348. return list(cpnts)
  349. def run(self, history, **kwargs):
  350. logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
  351. json.dumps(kwargs, ensure_ascii=False)))
  352. self._param.debug_inputs = []
  353. try:
  354. res = self._run(history, **kwargs)
  355. self.set_output(res)
  356. except Exception as e:
  357. self.set_output(pd.DataFrame([{"content": str(e)}]))
  358. raise e
  359. return res
  360. def _run(self, history, **kwargs):
  361. raise NotImplementedError()
  362. def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
  363. o = getattr(self._param, self._param.output_var_name)
  364. if not isinstance(o, partial):
  365. if not isinstance(o, pd.DataFrame):
  366. if isinstance(o, list):
  367. return self._param.output_var_name, pd.DataFrame(o).dropna()
  368. if o is None:
  369. return self._param.output_var_name, pd.DataFrame()
  370. return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
  371. return self._param.output_var_name, o
  372. if allow_partial or not isinstance(o, partial):
  373. if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
  374. return pd.DataFrame(o if isinstance(o, list) else [o]).dropna()
  375. return self._param.output_var_name, o
  376. outs = None
  377. for oo in o():
  378. if not isinstance(oo, pd.DataFrame):
  379. outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]).dropna()
  380. else:
  381. outs = oo.dropna()
  382. return self._param.output_var_name, outs
  383. def reset(self):
  384. setattr(self._param, self._param.output_var_name, None)
  385. self._param.inputs = []
  386. def set_output(self, v):
  387. setattr(self._param, self._param.output_var_name, v)
  388. def set_infor(self, v):
  389. setattr(self._param, self._param.infor_var_name, v)
  390. def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
  391. outs = []
  392. for q in sources:
  393. if q.get("component_id"):
  394. if "@" in q["component_id"] and q["component_id"].split("@")[0].lower().find("begin") >= 0:
  395. cpn_id, key = q["component_id"].split("@")
  396. for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
  397. if p["key"] == key:
  398. outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
  399. break
  400. else:
  401. assert False, f"Can't find parameter '{key}' for {cpn_id}"
  402. continue
  403. if q["component_id"].lower().find("answer") == 0:
  404. txt = []
  405. for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
  406. txt.append(f"{r.upper()}:{c}")
  407. txt = "\n".join(txt)
  408. outs.append(pd.DataFrame([{"content": txt}]))
  409. continue
  410. outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
  411. elif q.get("value"):
  412. outs.append(pd.DataFrame([{"content": q["value"]}]))
  413. return outs
  414. def get_input(self):
  415. if self._param.debug_inputs:
  416. return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
  417. reversed_cpnts = []
  418. if len(self._canvas.path) > 1:
  419. reversed_cpnts.extend(self._canvas.path[-2])
  420. reversed_cpnts.extend(self._canvas.path[-1])
  421. up_cpns = self.get_upstream()
  422. reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
  423. if self._param.query:
  424. self._param.inputs = []
  425. outs = self._fetch_outputs_from(self._param.query)
  426. for out in outs:
  427. records = out.to_dict("records")
  428. content: str
  429. if len(records) > 1:
  430. content = "\n".join(
  431. [str(d["content"]) for d in records]
  432. )
  433. else:
  434. content = records[0]["content"]
  435. self._param.inputs.append({
  436. "component_id": records[0].get("component_id"),
  437. "content": content
  438. })
  439. if outs:
  440. df = pd.concat(outs, ignore_index=True)
  441. if "content" in df:
  442. df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
  443. return df
  444. upstream_outs = []
  445. for u in reversed_up_cpnts[::-1]:
  446. if self.get_component_name(u) in ["switch", "concentrator"]:
  447. continue
  448. if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
  449. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  450. if o is not None:
  451. o["component_id"] = u
  452. upstream_outs.append(o)
  453. continue
  454. #if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
  455. if self.component_name.lower().find("switch") < 0 \
  456. and self.get_component_name(u) in ["relevant", "categorize"]:
  457. continue
  458. if u.lower().find("answer") >= 0:
  459. for r, c in self._canvas.history[::-1]:
  460. if r == "user":
  461. upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
  462. break
  463. break
  464. if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
  465. continue
  466. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  467. if o is not None:
  468. o["component_id"] = u
  469. upstream_outs.append(o)
  470. break
  471. assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
  472. df = pd.concat(upstream_outs, ignore_index=True)
  473. if "content" in df:
  474. df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
  475. self._param.inputs = []
  476. for _, r in df.iterrows():
  477. self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
  478. return df
  479. def get_input_elements(self):
  480. assert self._param.query, "Please verify the input parameters first."
  481. eles = []
  482. for q in self._param.query:
  483. if q.get("component_id"):
  484. cpn_id = q["component_id"]
  485. if cpn_id.split("@")[0].lower().find("begin") >= 0:
  486. cpn_id, key = cpn_id.split("@")
  487. eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
  488. continue
  489. eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
  490. else:
  491. eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
  492. return eles
  493. def get_stream_input(self):
  494. reversed_cpnts = []
  495. if len(self._canvas.path) > 1:
  496. reversed_cpnts.extend(self._canvas.path[-2])
  497. reversed_cpnts.extend(self._canvas.path[-1])
  498. up_cpns = self.get_upstream()
  499. reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
  500. for u in reversed_up_cpnts[::-1]:
  501. if self.get_component_name(u) in ["switch", "answer"]:
  502. continue
  503. return self._canvas.get_component(u)["obj"].output()[1]
  504. @staticmethod
  505. def be_output(v):
  506. return pd.DataFrame([{"content": v}])
  507. def get_component_name(self, cpn_id):
  508. return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
  509. def debug(self, **kwargs):
  510. return self._run([], **kwargs)
  511. def get_parent(self):
  512. pid = self._canvas.get_component(self._id)["parent_id"]
  513. return self._canvas.get_component(pid)["obj"]
  514. def get_upstream(self):
  515. cpn_nms = self._canvas.get_component(self._id)['upstream']
  516. return cpn_nms