| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from abc import ABC
- import builtins
- import json
- import os
- import logging
- from functools import partial
- from typing import Tuple, Union
-
- import pandas as pd
-
- from agent import settings
-
# Instance attribute name: set of deprecated parameter keys that the user
# actually supplied (filled in by ComponentParamBase.update on raw confs).
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
# Class attribute name: set of parameter keys that are considered deprecated.
_DEPRECATED_PARAMS = "_deprecated_params"
# Instance attribute name: set of (dotted) parameter keys supplied by the user.
_USER_FEEDED_PARAMS = "_user_feeded_params"
# Conf key / instance attribute name: a conf without this key is treated as
# "raw"; the instance attribute is set to False once a raw conf was applied.
_IS_RAW_CONF = "_is_raw_conf"
-
-
- class ComponentParamBase(ABC):
- def __init__(self):
- self.output_var_name = "output"
- self.message_history_window_size = 22
- self.query = []
- self.inputs = []
-
- def set_name(self, name: str):
- self._name = name
- return self
-
- def check(self):
- raise NotImplementedError("Parameter Object should be checked.")
-
- @classmethod
- def _get_or_init_deprecated_params_set(cls):
- if not hasattr(cls, _DEPRECATED_PARAMS):
- setattr(cls, _DEPRECATED_PARAMS, set())
- return getattr(cls, _DEPRECATED_PARAMS)
-
- def _get_or_init_feeded_deprecated_params_set(self, conf=None):
- if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
- if conf is None:
- setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
- else:
- setattr(
- self,
- _FEEDED_DEPRECATED_PARAMS,
- set(conf[_FEEDED_DEPRECATED_PARAMS]),
- )
- return getattr(self, _FEEDED_DEPRECATED_PARAMS)
-
- def _get_or_init_user_feeded_params_set(self, conf=None):
- if not hasattr(self, _USER_FEEDED_PARAMS):
- if conf is None:
- setattr(self, _USER_FEEDED_PARAMS, set())
- else:
- setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
- return getattr(self, _USER_FEEDED_PARAMS)
-
- def get_user_feeded(self):
- return self._get_or_init_user_feeded_params_set()
-
- def get_feeded_deprecated_params(self):
- return self._get_or_init_feeded_deprecated_params_set()
-
- @property
- def _deprecated_params_set(self):
- return {name: True for name in self.get_feeded_deprecated_params()}
-
- def __str__(self):
- return json.dumps(self.as_dict(), ensure_ascii=False)
-
- def as_dict(self):
- def _recursive_convert_obj_to_dict(obj):
- ret_dict = {}
- for attr_name in list(obj.__dict__):
- if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
- continue
- # get attr
- attr = getattr(obj, attr_name)
- if isinstance(attr, pd.DataFrame):
- ret_dict[attr_name] = attr.to_dict()
- continue
- if attr and type(attr).__name__ not in dir(builtins):
- ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
- else:
- ret_dict[attr_name] = attr
-
- return ret_dict
-
- return _recursive_convert_obj_to_dict(self)
-
- def update(self, conf, allow_redundant=False):
- update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
- if update_from_raw_conf:
- deprecated_params_set = self._get_or_init_deprecated_params_set()
- feeded_deprecated_params_set = (
- self._get_or_init_feeded_deprecated_params_set()
- )
- user_feeded_params_set = self._get_or_init_user_feeded_params_set()
- setattr(self, _IS_RAW_CONF, False)
- else:
- feeded_deprecated_params_set = (
- self._get_or_init_feeded_deprecated_params_set(conf)
- )
- user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
-
- def _recursive_update_param(param, config, depth, prefix):
- if depth > settings.PARAM_MAXDEPTH:
- raise ValueError("Param define nesting too deep!!!, can not parse it")
-
- inst_variables = param.__dict__
- redundant_attrs = []
- for config_key, config_value in config.items():
- # redundant attr
- if config_key not in inst_variables:
- if not update_from_raw_conf and config_key.startswith("_"):
- setattr(param, config_key, config_value)
- else:
- setattr(param, config_key, config_value)
- # redundant_attrs.append(config_key)
- continue
-
- full_config_key = f"{prefix}{config_key}"
-
- if update_from_raw_conf:
- # add user feeded params
- user_feeded_params_set.add(full_config_key)
-
- # update user feeded deprecated param set
- if full_config_key in deprecated_params_set:
- feeded_deprecated_params_set.add(full_config_key)
-
- # supported attr
- attr = getattr(param, config_key)
- if type(attr).__name__ in dir(builtins) or attr is None:
- setattr(param, config_key, config_value)
-
- else:
- # recursive set obj attr
- sub_params = _recursive_update_param(
- attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
- )
- setattr(param, config_key, sub_params)
-
- if not allow_redundant and redundant_attrs:
- raise ValueError(
- f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
- )
-
- return param
-
- return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
-
- def extract_not_builtin(self):
- def _get_not_builtin_types(obj):
- ret_dict = {}
- for variable in obj.__dict__:
- attr = getattr(obj, variable)
- if attr and type(attr).__name__ not in dir(builtins):
- ret_dict[variable] = _get_not_builtin_types(attr)
-
- return ret_dict
-
- return _get_not_builtin_types(self)
-
- def validate(self):
- self.builtin_types = dir(builtins)
- self.func = {
- "ge": self._greater_equal_than,
- "le": self._less_equal_than,
- "in": self._in,
- "not_in": self._not_in,
- "range": self._range,
- }
- home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
- param_validation_path_prefix = home_dir + "/param_validation/"
-
- param_name = type(self).__name__
- param_validation_path = "/".join(
- [param_validation_path_prefix, param_name + ".json"]
- )
-
- validation_json = None
-
- try:
- with open(param_validation_path, "r") as fin:
- validation_json = json.loads(fin.read())
- except BaseException:
- return
-
- self._validate_param(self, validation_json)
-
- def _validate_param(self, param_obj, validation_json):
- default_section = type(param_obj).__name__
- var_list = param_obj.__dict__
-
- for variable in var_list:
- attr = getattr(param_obj, variable)
-
- if type(attr).__name__ in self.builtin_types or attr is None:
- if variable not in validation_json:
- continue
-
- validation_dict = validation_json[default_section][variable]
- value = getattr(param_obj, variable)
- value_legal = False
-
- for op_type in validation_dict:
- if self.func[op_type](value, validation_dict[op_type]):
- value_legal = True
- break
-
- if not value_legal:
- raise ValueError(
- "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
- variable, value
- )
- )
-
- elif variable in validation_json:
- self._validate_param(attr, validation_json)
-
- @staticmethod
- def check_string(param, descr):
- if type(param).__name__ not in ["str"]:
- raise ValueError(
- descr + " {} not supported, should be string type".format(param)
- )
-
- @staticmethod
- def check_empty(param, descr):
- if not param:
- raise ValueError(
- descr + " does not support empty value."
- )
-
- @staticmethod
- def check_positive_integer(param, descr):
- if type(param).__name__ not in ["int", "long"] or param <= 0:
- raise ValueError(
- descr + " {} not supported, should be positive integer".format(param)
- )
-
- @staticmethod
- def check_positive_number(param, descr):
- if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
- raise ValueError(
- descr + " {} not supported, should be positive numeric".format(param)
- )
-
- @staticmethod
- def check_nonnegative_number(param, descr):
- if type(param).__name__ not in ["float", "int", "long"] or param < 0:
- raise ValueError(
- descr
- + " {} not supported, should be non-negative numeric".format(param)
- )
-
- @staticmethod
- def check_decimal_float(param, descr):
- if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
- raise ValueError(
- descr
- + " {} not supported, should be a float number in range [0, 1]".format(
- param
- )
- )
-
- @staticmethod
- def check_boolean(param, descr):
- if type(param).__name__ != "bool":
- raise ValueError(
- descr + " {} not supported, should be bool type".format(param)
- )
-
- @staticmethod
- def check_open_unit_interval(param, descr):
- if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
- raise ValueError(
- descr + " should be a numeric number between 0 and 1 exclusively"
- )
-
- @staticmethod
- def check_valid_value(param, descr, valid_values):
- if param not in valid_values:
- raise ValueError(
- descr
- + " {} is not supported, it should be in {}".format(param, valid_values)
- )
-
- @staticmethod
- def check_defined_type(param, descr, types):
- if type(param).__name__ not in types:
- raise ValueError(
- descr + " {} not supported, should be one of {}".format(param, types)
- )
-
- @staticmethod
- def check_and_change_lower(param, valid_list, descr=""):
- if type(param).__name__ != "str":
- raise ValueError(
- descr
- + " {} not supported, should be one of {}".format(param, valid_list)
- )
-
- lower_param = param.lower()
- if lower_param in valid_list:
- return lower_param
- else:
- raise ValueError(
- descr
- + " {} not supported, should be one of {}".format(param, valid_list)
- )
-
- @staticmethod
- def _greater_equal_than(value, limit):
- return value >= limit - settings.FLOAT_ZERO
-
- @staticmethod
- def _less_equal_than(value, limit):
- return value <= limit + settings.FLOAT_ZERO
-
- @staticmethod
- def _range(value, ranges):
- in_range = False
- for left_limit, right_limit in ranges:
- if (
- left_limit - settings.FLOAT_ZERO
- <= value
- <= right_limit + settings.FLOAT_ZERO
- ):
- in_range = True
- break
-
- return in_range
-
- @staticmethod
- def _in(value, right_value_list):
- return value in right_value_list
-
- @staticmethod
- def _not_in(value, wrong_value_list):
- return value not in wrong_value_list
-
- def _warn_deprecated_param(self, param_name, descr):
- if self._deprecated_params_set.get(param_name):
- logging.warning(
- f"{descr} {param_name} is deprecated and ignored in this version."
- )
-
- def _warn_to_deprecate_param(self, param_name, descr, new_param):
- if self._deprecated_params_set.get(param_name):
- logging.warning(
- f"{descr} {param_name} will be deprecated in future release; "
- f"please use {new_param} instead."
- )
- return True
- return False
-
-
class ComponentBase(ABC):
    """Base class of an agent component living on a canvas.

    A component owns a parameter object (``self._param``) that also stores its
    output under ``self._param.output_var_name``.  Subclasses implement
    :meth:`_run`; :meth:`run` wraps it and records the result.
    """

    component_name: str

    def __str__(self):
        """
        {
            "component_name": "Begin",
            "params": {}
        }
        """
        param_dict = json.loads(str(self._param))
        out = param_dict.get("output", {})
        # The "vector" entry is dropped from the summarized output only;
        # "params" below is rendered from the parameter object itself.
        if isinstance(out, dict) and "vector" in out:
            del out["vector"]
        return """{{
            "component_name": "{}",
            "params": {},
            "output": {},
            "inputs": {}
        }}""".format(self.component_name,
                     self._param,
                     json.dumps(out, ensure_ascii=False),
                     json.dumps(param_dict.get("inputs", []), ensure_ascii=False)
                     )

    def __init__(self, canvas, id, param: "ComponentParamBase"):
        """Bind the component to ``canvas`` under ``id`` and check ``param``.

        ``param.check()`` is invoked immediately, so whatever it raises
        propagates out of the constructor.
        """
        self._canvas = canvas
        self._id = id
        self._param = param
        self._param.check()

    def get_dependent_components(self):
        """Return the ids of components referenced by ``self._param.query``.

        The part before ``@`` of each non-empty ``component_id`` is collected,
        skipping ids that contain "answer" or "begin" (case-insensitive).
        Duplicates are removed; the order of the result is unspecified.
        """
        cpnts = {
            para["component_id"].split("@")[0]
            for para in self._param.query
            if para.get("component_id")
            and "answer" not in para["component_id"].lower()
            and "begin" not in para["component_id"].lower()
        }
        return list(cpnts)

    def run(self, history, **kwargs):
        """Execute :meth:`_run`, store its result as the output and return it.

        If ``_run`` (or storing its result) raises an ``Exception``, the output
        is set to a one-row DataFrame holding the error text and the exception
        is re-raised.
        """
        logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
                                                           json.dumps(kwargs, ensure_ascii=False)))
        try:
            res = self._run(history, **kwargs)
            self.set_output(res)
        except Exception as e:
            self.set_output(pd.DataFrame([{"content": str(e)}]))
            raise

        return res

    def _run(self, history, **kwargs):
        """Do the component's work; must be implemented by subclasses."""
        raise NotImplementedError()

    def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
        """Return ``(output_var_name, output)``.

        A stored value that is neither a ``partial`` nor a DataFrame is wrapped
        into a DataFrame (lists row-wise, anything else as a single row).

        Args:
            allow_partial: when True a stored ``partial`` is returned as-is.
                When False it is called and iterated to the end, and the last
                yielded item (converted to a DataFrame if needed) is returned;
                the result is ``None`` if nothing was yielded.
        """
        var_name = self._param.output_var_name
        o = getattr(self._param, var_name)
        if not isinstance(o, (partial, pd.DataFrame)):
            o = pd.DataFrame(o if isinstance(o, list) else [o])

        if allow_partial or not isinstance(o, partial):
            return var_name, o

        outs = None
        for oo in o():
            if isinstance(oo, pd.DataFrame):
                outs = oo
            else:
                outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
        return var_name, outs

    def reset(self):
        """Clear the stored output (set to ``None``) and the recorded inputs."""
        setattr(self._param, self._param.output_var_name, None)
        self._param.inputs = []

    def set_output(self, v: Union[partial, pd.DataFrame]):
        """Store ``v`` as the component output on the parameter object."""
        setattr(self._param, self._param.output_var_name, v)

    def _recent_path_component_ids(self):
        """Return the ids in the last two entries of the canvas path, oldest first."""
        cpnt_ids = []
        if len(self._canvas.path) > 1:
            cpnt_ids.extend(self._canvas.path[-2])
        cpnt_ids.extend(self._canvas.path[-1])
        return cpnt_ids

    def get_input(self):
        """Collect this component's input as a DataFrame.

        If ``self._param.query`` yields any content, that content is used:
        entries referencing a "begin" component (``<id>@<key>``) take the value
        of the matching begin parameter, other referenced components contribute
        their final output, and entries without a component id contribute their
        literal ``value``.  Otherwise the input is inferred from the components
        in the last two entries of the canvas path, walking them newest first.

        In both cases rows with duplicate ``content`` are dropped and
        ``self._param.inputs`` is rewritten to describe what was used.

        Raises:
            AssertionError: if a referenced begin parameter does not exist, or
                if no upstream output could be determined.
        """
        reversed_cpnts = self._recent_path_component_ids()

        if self._param.query:
            self._param.inputs = []
            outs = []
            for q in self._param.query:
                component_id = q.get("component_id")
                if component_id:
                    if "begin" in component_id.split("@")[0].lower():
                        cpn_id, key = component_id.split("@")
                        for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
                            if p["key"] == key:
                                value = p.get("value", "")
                                outs.append(pd.DataFrame([{"content": value}]))
                                self._param.inputs.append({"component_id": component_id,
                                                           "content": value})
                                break
                        else:
                            # Explicit raise: a bare ``assert`` vanishes under -O.
                            raise AssertionError(f"Can't find parameter '{key}' for {cpn_id}")
                        continue

                    out = self._canvas.get_component(component_id)["obj"].output(allow_partial=False)[1]
                    outs.append(out)
                    self._param.inputs.append({"component_id": component_id,
                                               "content": "\n".join(
                                                   [str(d["content"]) for d in out.to_dict('records')])})
                elif q.get("value"):
                    self._param.inputs.append({"component_id": None, "content": q["value"]})
                    outs.append(pd.DataFrame([{"content": q["value"]}]))
            if outs:
                df = pd.concat(outs, ignore_index=True)
                if "content" in df:
                    df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
                return df

        upstream_outs = []

        for u in reversed_cpnts[::-1]:
            upstream_name = self.get_component_name(u)
            if upstream_name in ["switch", "concentrator"]:
                continue
            own_name = self.component_name.lower()
            if own_name == "generate" and upstream_name == "retrieval":
                o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                if o is not None:
                    # Adds a column to the object returned by the upstream
                    # component's output().
                    o["component_id"] = u
                    upstream_outs.append(o)
                    # NOTE(review): nesting of this ``continue`` was
                    # reconstructed from de-indented source - confirm.
                    continue
            if "switch" not in own_name and upstream_name in ["relevant", "categorize"]:
                continue
            if "answer" in u.lower():
                # Use the most recent user message from the canvas history.
                for r, c in self._canvas.history[::-1]:
                    if r == "user":
                        upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
                        break
                break
            if "answer" in own_name and upstream_name in ["relevant"]:
                continue
            o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
            if o is not None:
                o["component_id"] = u
                upstream_outs.append(o)
            # NOTE(review): loop-level ``break`` reconstructed from
            # de-indented source - confirm.
            break

        if not upstream_outs:
            raise AssertionError(
                "Can't inference the where the component input is. Please identify whose output is this component's input."
            )

        df = pd.concat(upstream_outs, ignore_index=True)
        if "content" in df:
            df = df.drop_duplicates(subset=['content']).reset_index(drop=True)

        self._param.inputs = []
        for _, r in df.iterrows():
            self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})

        return df

    def get_stream_input(self):
        """Return the output of the most recent non-switch/answer path component.

        The components of the last two canvas path entries are walked newest
        first; the (possibly partial) output of the first one that is neither a
        "switch" nor an "answer" is returned.  Returns ``None`` if there is no
        such component.
        """
        for u in self._recent_path_component_ids()[::-1]:
            if self.get_component_name(u) in ["switch", "answer"]:
                continue
            return self._canvas.get_component(u)["obj"].output()[1]

    @staticmethod
    def be_output(v):
        """Wrap ``v`` into a one-row DataFrame with a single ``content`` column."""
        return pd.DataFrame([{"content": v}])

    def get_component_name(self, cpn_id):
        """Return the lower-cased ``component_name`` of canvas component ``cpn_id``."""
        return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
|