Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder mit einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import builtins
  17. import json
  18. import logging
  19. import os
  20. from abc import ABC
  21. from functools import partial
  22. from typing import Any, Tuple, Union
  23. import pandas as pd
  24. from agent import settings
# Attribute names used by ComponentParamBase for bookkeeping. They are read and
# written with getattr/setattr on parameter instances (or, for
# _DEPRECATED_PARAMS, on the parameter class) and are skipped by as_dict().
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"  # deprecated param names that were actually supplied
_DEPRECATED_PARAMS = "_deprecated_params"  # class-level set of deprecated param names
_USER_FEEDED_PARAMS = "_user_feeded_params"  # dotted names of params supplied through update()
_IS_RAW_CONF = "_is_raw_conf"  # conf flag read by update(); defaults to True (conf treated as raw)
  29. class ComponentParamBase(ABC):
  30. def __init__(self):
  31. self.output_var_name = "output"
  32. self.infor_var_name = "infor"
  33. self.message_history_window_size = 22
  34. self.query = []
  35. self.inputs = []
  36. self.debug_inputs = []
  37. def set_name(self, name: str):
  38. self._name = name
  39. return self
  40. def check(self):
  41. raise NotImplementedError("Parameter Object should be checked.")
  42. @classmethod
  43. def _get_or_init_deprecated_params_set(cls):
  44. if not hasattr(cls, _DEPRECATED_PARAMS):
  45. setattr(cls, _DEPRECATED_PARAMS, set())
  46. return getattr(cls, _DEPRECATED_PARAMS)
  47. def _get_or_init_feeded_deprecated_params_set(self, conf=None):
  48. if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
  49. if conf is None:
  50. setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
  51. else:
  52. setattr(
  53. self,
  54. _FEEDED_DEPRECATED_PARAMS,
  55. set(conf[_FEEDED_DEPRECATED_PARAMS]),
  56. )
  57. return getattr(self, _FEEDED_DEPRECATED_PARAMS)
  58. def _get_or_init_user_feeded_params_set(self, conf=None):
  59. if not hasattr(self, _USER_FEEDED_PARAMS):
  60. if conf is None:
  61. setattr(self, _USER_FEEDED_PARAMS, set())
  62. else:
  63. setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
  64. return getattr(self, _USER_FEEDED_PARAMS)
  65. def get_user_feeded(self):
  66. return self._get_or_init_user_feeded_params_set()
  67. def get_feeded_deprecated_params(self):
  68. return self._get_or_init_feeded_deprecated_params_set()
  69. @property
  70. def _deprecated_params_set(self):
  71. return {name: True for name in self.get_feeded_deprecated_params()}
  72. def __str__(self):
  73. return json.dumps(self.as_dict(), ensure_ascii=False)
  74. def as_dict(self):
  75. def _recursive_convert_obj_to_dict(obj):
  76. ret_dict = {}
  77. for attr_name in list(obj.__dict__):
  78. if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
  79. continue
  80. # get attr
  81. attr = getattr(obj, attr_name)
  82. if isinstance(attr, pd.DataFrame):
  83. ret_dict[attr_name] = attr.to_dict()
  84. continue
  85. if attr and type(attr).__name__ not in dir(builtins):
  86. ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
  87. else:
  88. ret_dict[attr_name] = attr
  89. return ret_dict
  90. return _recursive_convert_obj_to_dict(self)
  91. def update(self, conf, allow_redundant=False):
  92. update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
  93. if update_from_raw_conf:
  94. deprecated_params_set = self._get_or_init_deprecated_params_set()
  95. feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set()
  96. user_feeded_params_set = self._get_or_init_user_feeded_params_set()
  97. setattr(self, _IS_RAW_CONF, False)
  98. else:
  99. feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf)
  100. user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
  101. def _recursive_update_param(param, config, depth, prefix):
  102. if depth > settings.PARAM_MAXDEPTH:
  103. raise ValueError("Param define nesting too deep!!!, can not parse it")
  104. inst_variables = param.__dict__
  105. redundant_attrs = []
  106. for config_key, config_value in config.items():
  107. # redundant attr
  108. if config_key not in inst_variables:
  109. if not update_from_raw_conf and config_key.startswith("_"):
  110. setattr(param, config_key, config_value)
  111. else:
  112. setattr(param, config_key, config_value)
  113. # redundant_attrs.append(config_key)
  114. continue
  115. full_config_key = f"{prefix}{config_key}"
  116. if update_from_raw_conf:
  117. # add user feeded params
  118. user_feeded_params_set.add(full_config_key)
  119. # update user feeded deprecated param set
  120. if full_config_key in deprecated_params_set:
  121. feeded_deprecated_params_set.add(full_config_key)
  122. # supported attr
  123. attr = getattr(param, config_key)
  124. if type(attr).__name__ in dir(builtins) or attr is None:
  125. setattr(param, config_key, config_value)
  126. else:
  127. # recursive set obj attr
  128. sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.")
  129. setattr(param, config_key, sub_params)
  130. if not allow_redundant and redundant_attrs:
  131. raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`")
  132. return param
  133. return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
  134. def extract_not_builtin(self):
  135. def _get_not_builtin_types(obj):
  136. ret_dict = {}
  137. for variable in obj.__dict__:
  138. attr = getattr(obj, variable)
  139. if attr and type(attr).__name__ not in dir(builtins):
  140. ret_dict[variable] = _get_not_builtin_types(attr)
  141. return ret_dict
  142. return _get_not_builtin_types(self)
  143. def validate(self):
  144. self.builtin_types = dir(builtins)
  145. self.func = {
  146. "ge": self._greater_equal_than,
  147. "le": self._less_equal_than,
  148. "in": self._in,
  149. "not_in": self._not_in,
  150. "range": self._range,
  151. }
  152. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  153. param_validation_path_prefix = home_dir + "/param_validation/"
  154. param_name = type(self).__name__
  155. param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"])
  156. validation_json = None
  157. try:
  158. with open(param_validation_path, "r") as fin:
  159. validation_json = json.loads(fin.read())
  160. except BaseException:
  161. return
  162. self._validate_param(self, validation_json)
  163. def _validate_param(self, param_obj, validation_json):
  164. default_section = type(param_obj).__name__
  165. var_list = param_obj.__dict__
  166. for variable in var_list:
  167. attr = getattr(param_obj, variable)
  168. if type(attr).__name__ in self.builtin_types or attr is None:
  169. if variable not in validation_json:
  170. continue
  171. validation_dict = validation_json[default_section][variable]
  172. value = getattr(param_obj, variable)
  173. value_legal = False
  174. for op_type in validation_dict:
  175. if self.func[op_type](value, validation_dict[op_type]):
  176. value_legal = True
  177. break
  178. if not value_legal:
  179. raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value))
  180. elif variable in validation_json:
  181. self._validate_param(attr, validation_json)
  182. @staticmethod
  183. def check_string(param, descr):
  184. if type(param).__name__ not in ["str"]:
  185. raise ValueError(descr + " {} not supported, should be string type".format(param))
  186. @staticmethod
  187. def check_empty(param, descr):
  188. if not param:
  189. raise ValueError(descr + " does not support empty value.")
  190. @staticmethod
  191. def check_positive_integer(param, descr):
  192. if type(param).__name__ not in ["int", "long"] or param <= 0:
  193. raise ValueError(descr + " {} not supported, should be positive integer".format(param))
  194. @staticmethod
  195. def check_positive_number(param, descr):
  196. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  197. raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
  198. @staticmethod
  199. def check_nonnegative_number(param, descr):
  200. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  201. raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param))
  202. @staticmethod
  203. def check_decimal_float(param, descr):
  204. if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
  205. raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param))
  206. @staticmethod
  207. def check_boolean(param, descr):
  208. if type(param).__name__ != "bool":
  209. raise ValueError(descr + " {} not supported, should be bool type".format(param))
  210. @staticmethod
  211. def check_open_unit_interval(param, descr):
  212. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  213. raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
  214. @staticmethod
  215. def check_valid_value(param, descr, valid_values):
  216. if param not in valid_values:
  217. raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values))
  218. @staticmethod
  219. def check_defined_type(param, descr, types):
  220. if type(param).__name__ not in types:
  221. raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
  222. @staticmethod
  223. def check_and_change_lower(param, valid_list, descr=""):
  224. if type(param).__name__ != "str":
  225. raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
  226. lower_param = param.lower()
  227. if lower_param in valid_list:
  228. return lower_param
  229. else:
  230. raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
  231. @staticmethod
  232. def _greater_equal_than(value, limit):
  233. return value >= limit - settings.FLOAT_ZERO
  234. @staticmethod
  235. def _less_equal_than(value, limit):
  236. return value <= limit + settings.FLOAT_ZERO
  237. @staticmethod
  238. def _range(value, ranges):
  239. in_range = False
  240. for left_limit, right_limit in ranges:
  241. if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO:
  242. in_range = True
  243. break
  244. return in_range
  245. @staticmethod
  246. def _in(value, right_value_list):
  247. return value in right_value_list
  248. @staticmethod
  249. def _not_in(value, wrong_value_list):
  250. return value not in wrong_value_list
  251. def _warn_deprecated_param(self, param_name, descr):
  252. if self._deprecated_params_set.get(param_name):
  253. logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.")
  254. def _warn_to_deprecate_param(self, param_name, descr, new_param):
  255. if self._deprecated_params_set.get(param_name):
  256. logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.")
  257. return True
  258. return False
  259. class ComponentBase(ABC):
  260. component_name: str
  261. def __str__(self):
  262. """
  263. {
  264. "component_name": "Begin",
  265. "params": {}
  266. }
  267. """
  268. out = getattr(self._param, self._param.output_var_name)
  269. if isinstance(out, pd.DataFrame) and "chunks" in out:
  270. del out["chunks"]
  271. setattr(self._param, self._param.output_var_name, out)
  272. return """{{
  273. "component_name": "{}",
  274. "params": {},
  275. "output": {},
  276. "inputs": {}
  277. }}""".format(
  278. self.component_name,
  279. self._param,
  280. json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
  281. json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False),
  282. )
  283. def __init__(self, canvas, id, param: ComponentParamBase):
  284. from agent.canvas import Canvas # Local import to avoid cyclic dependency
  285. assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
  286. self._canvas = canvas
  287. self._id = id
  288. self._param = param
  289. self._param.check()
  290. def get_dependent_components(self):
  291. cpnts = set(
  292. [
  293. para["component_id"].split("@")[0]
  294. for para in self._param.query
  295. if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0
  296. ]
  297. )
  298. return list(cpnts)
  299. def run(self, history, **kwargs):
  300. logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False)))
  301. self._param.debug_inputs = []
  302. try:
  303. res = self._run(history, **kwargs)
  304. self.set_output(res)
  305. except Exception as e:
  306. self.set_output(pd.DataFrame([{"content": str(e)}]))
  307. raise e
  308. return res
  309. def _run(self, history, **kwargs):
  310. raise NotImplementedError()
  311. def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
  312. o = getattr(self._param, self._param.output_var_name)
  313. if not isinstance(o, partial):
  314. if not isinstance(o, pd.DataFrame):
  315. if isinstance(o, list):
  316. return self._param.output_var_name, pd.DataFrame(o).dropna()
  317. if o is None:
  318. return self._param.output_var_name, pd.DataFrame()
  319. return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
  320. return self._param.output_var_name, o
  321. if allow_partial or not isinstance(o, partial):
  322. if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
  323. return pd.DataFrame(o if isinstance(o, list) else [o]).dropna()
  324. return self._param.output_var_name, o
  325. outs = None
  326. for oo in o():
  327. if not isinstance(oo, pd.DataFrame):
  328. outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]).dropna()
  329. else:
  330. outs = oo.dropna()
  331. return self._param.output_var_name, outs
  332. def reset(self):
  333. setattr(self._param, self._param.output_var_name, None)
  334. self._param.inputs = []
  335. def set_output(self, v):
  336. setattr(self._param, self._param.output_var_name, v)
  337. def set_infor(self, v):
  338. setattr(self._param, self._param.infor_var_name, v)
  339. def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
  340. outs = []
  341. for q in sources:
  342. if q.get("component_id"):
  343. if "@" in q["component_id"] and q["component_id"].split("@")[0].lower().find("begin") >= 0:
  344. cpn_id, key = q["component_id"].split("@")
  345. for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
  346. if p["key"] == key:
  347. outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
  348. break
  349. else:
  350. assert False, f"Can't find parameter '{key}' for {cpn_id}"
  351. continue
  352. if q["component_id"].lower().find("answer") == 0:
  353. txt = []
  354. for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]:
  355. txt.append(f"{r.upper()}:{c}")
  356. txt = "\n".join(txt)
  357. outs.append(pd.DataFrame([{"content": txt}]))
  358. continue
  359. outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
  360. elif q.get("value"):
  361. outs.append(pd.DataFrame([{"content": q["value"]}]))
  362. return outs
  363. def get_input(self):
  364. if self._param.debug_inputs:
  365. return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
  366. reversed_cpnts = []
  367. if len(self._canvas.path) > 1:
  368. reversed_cpnts.extend(self._canvas.path[-2])
  369. reversed_cpnts.extend(self._canvas.path[-1])
  370. up_cpns = self.get_upstream()
  371. reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
  372. if self._param.query:
  373. self._param.inputs = []
  374. outs = self._fetch_outputs_from(self._param.query)
  375. for out in outs:
  376. records = out.to_dict("records")
  377. content: str
  378. if len(records) > 1:
  379. content = "\n".join([str(d["content"]) for d in records])
  380. else:
  381. content = records[0]["content"]
  382. self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content})
  383. if outs:
  384. df = pd.concat(outs, ignore_index=True)
  385. if "content" in df:
  386. df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
  387. return df
  388. upstream_outs = []
  389. for u in reversed_up_cpnts[::-1]:
  390. if self.get_component_name(u) in ["switch", "concentrator"]:
  391. continue
  392. if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
  393. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  394. if o is not None:
  395. o["component_id"] = u
  396. upstream_outs.append(o)
  397. continue
  398. # if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
  399. if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]:
  400. continue
  401. if u.lower().find("answer") >= 0:
  402. for r, c in self._canvas.history[::-1]:
  403. if r == "user":
  404. upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
  405. break
  406. break
  407. if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
  408. continue
  409. o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
  410. if o is not None:
  411. o["component_id"] = u
  412. upstream_outs.append(o)
  413. break
  414. assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
  415. df = pd.concat(upstream_outs, ignore_index=True)
  416. if "content" in df:
  417. df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
  418. self._param.inputs = []
  419. for _, r in df.iterrows():
  420. self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
  421. return df
  422. def get_input_elements(self):
  423. assert self._param.query, "Please verify the input parameters first."
  424. eles = []
  425. for q in self._param.query:
  426. if q.get("component_id"):
  427. cpn_id = q["component_id"]
  428. if cpn_id.split("@")[0].lower().find("begin") >= 0:
  429. cpn_id, key = cpn_id.split("@")
  430. eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
  431. continue
  432. eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
  433. else:
  434. eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
  435. return eles
  436. def get_stream_input(self):
  437. reversed_cpnts = []
  438. if len(self._canvas.path) > 1:
  439. reversed_cpnts.extend(self._canvas.path[-2])
  440. reversed_cpnts.extend(self._canvas.path[-1])
  441. up_cpns = self.get_upstream()
  442. reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
  443. for u in reversed_up_cpnts[::-1]:
  444. if self.get_component_name(u) in ["switch", "answer"]:
  445. continue
  446. return self._canvas.get_component(u)["obj"].output()[1]
  447. @staticmethod
  448. def be_output(v):
  449. return pd.DataFrame([{"content": v}])
  450. def get_component_name(self, cpn_id):
  451. return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
  452. def debug(self, **kwargs):
  453. return self._run([], **kwargs)
  454. def get_parent(self):
  455. pid = self._canvas.get_component(self._id)["parent_id"]
  456. return self._canvas.get_component(pid)["obj"]
  457. def get_upstream(self):
  458. cpn_nms = self._canvas.get_component(self._id)["upstream"]
  459. return cpn_nms