Nevar pievienot vairāk nekā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt defises ('-') un var būt līdz 35 simboliem gara.

categorize.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import re
  19. from abc import ABC
  20. from api.db import LLMType
  21. from api.db.services.llm_service import LLMBundle
  22. from agent.component.llm import LLMParam, LLM
  23. from api.utils.api_utils import timeout
  24. from rag.llm.chat_model import ERROR_PREFIX
  25. class CategorizeParam(LLMParam):
  26. """
  27. Define the Categorize component parameters.
  28. """
  29. def __init__(self):
  30. super().__init__()
  31. self.category_description = {}
  32. self.query = "sys.query"
  33. self.message_history_window_size = 1
  34. self.update_prompt()
  35. def check(self):
  36. self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
  37. self.check_empty(self.category_description, "[Categorize] Category examples")
  38. for k, v in self.category_description.items():
  39. if not k:
  40. raise ValueError("[Categorize] Category name can not be empty!")
  41. if not v.get("to"):
  42. raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
  43. def get_input_form(self) -> dict[str, dict]:
  44. return {
  45. "query": {
  46. "type": "line",
  47. "name": "Query"
  48. }
  49. }
  50. def update_prompt(self):
  51. cate_lines = []
  52. for c, desc in self.category_description.items():
  53. for line in desc.get("examples", []):
  54. if not line:
  55. continue
  56. cate_lines.append("USER: \"" + re.sub(r"\n", " ", line, flags=re.DOTALL) + "\" → "+c)
  57. descriptions = []
  58. for c, desc in self.category_description.items():
  59. if desc.get("description"):
  60. descriptions.append(
  61. "\n------\nCategory: {}\nDescription: {}".format(c, desc["description"]))
  62. self.sys_prompt = """
  63. You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
  64. {}
  65. Here's description of each category:
  66. - {}
  67. ---- Instructions ----
  68. - Consider both explicit mentions and implied context
  69. - Prioritize the most specific applicable category
  70. - Return only the category name without explanations
  71. - Use "Other" only when no other category fits
  72. """.format(
  73. "\n - ".join(list(self.category_description.keys())),
  74. "\n".join(descriptions)
  75. )
  76. if cate_lines:
  77. self.sys_prompt += """
  78. ---- Examples ----
  79. {}
  80. """.format("\n".join(cate_lines))
  81. class Categorize(LLM, ABC):
  82. component_name = "Categorize"
  83. @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
  84. def _invoke(self, **kwargs):
  85. msg = self._canvas.get_history(self._param.message_history_window_size)
  86. if not msg:
  87. msg = [{"role": "user", "content": ""}]
  88. if kwargs.get("sys.query"):
  89. msg[-1]["content"] = kwargs["sys.query"]
  90. self.set_input_value("sys.query", kwargs["sys.query"])
  91. else:
  92. msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
  93. self.set_input_value(self._param.query, msg[-1]["content"])
  94. self._param.update_prompt()
  95. chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
  96. user_prompt = """
  97. ---- Real Data ----
  98. {} →
  99. """.format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
  100. ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
  101. logging.info(f"input: {user_prompt}, answer: {str(ans)}")
  102. if ERROR_PREFIX in ans:
  103. raise Exception(ans)
  104. # Count the number of times each category appears in the answer.
  105. category_counts = {}
  106. for c in self._param.category_description.keys():
  107. count = ans.lower().count(c.lower())
  108. category_counts[c] = count
  109. cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
  110. max_category = list(self._param.category_description.keys())[0]
  111. if any(category_counts.values()):
  112. max_category = max(category_counts.items(), key=lambda x: x[1])[0]
  113. cpn_ids = self._param.category_description[max_category]["to"]
  114. self.set_output("category_name", max_category)
  115. self.set_output("_next", cpn_ids)
  116. def thoughts(self) -> str:
  117. return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))