Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

switch.py 2.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from abc import ABC
  17. import pandas as pd
  18. from api.db import LLMType
  19. from api.db.services.knowledgebase_service import KnowledgebaseService
  20. from api.db.services.llm_service import LLMBundle
  21. from api.settings import retrievaler
  22. from graph.component.base import ComponentBase, ComponentParamBase
  23. class SwitchParam(ComponentParamBase):
  24. """
  25. Define the Switch component parameters.
  26. """
  27. def __init__(self):
  28. super().__init__()
  29. """
  30. {
  31. "cpn_id": "categorize:0",
  32. "not": False,
  33. "operator": "gt/gte/lt/lte/eq/in",
  34. "value": "",
  35. "to": ""
  36. }
  37. """
  38. self.conditions = []
  39. self.default = ""
  40. def check(self):
  41. self.check_empty(self.conditions, "[Switch] conditions")
  42. self.check_empty(self.default, "[Switch] Default path")
  43. for cond in self.conditions:
  44. if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
  45. def operators(self, field, op, value):
  46. if op == "gt":
  47. return float(field) > float(value)
  48. if op == "gte":
  49. return float(field) >= float(value)
  50. if op == "lt":
  51. return float(field) < float(value)
  52. if op == "lte":
  53. return float(field) <= float(value)
  54. if op == "eq":
  55. return str(field) == str(value)
  56. if op == "in":
  57. return str(field).find(str(value)) >= 0
  58. return False
  59. class Switch(ComponentBase, ABC):
  60. component_name = "Switch"
  61. def _run(self, history, **kwargs):
  62. for cond in self._param.conditions:
  63. input = self._canvas.get_component(cond["cpn_id"])["obj"].output()[1]
  64. if self._param.operators(input.iloc[0, 0], cond["operator"], cond["value"]):
  65. if not cond["not"]:
  66. return pd.DataFrame([{"content": cond["to"]}])
  67. return pd.DataFrame([{"content": self._param.default}])