| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import logging
- import os
- import re
- from abc import ABC
-
- from api.db import LLMType
- from api.db.services.llm_service import LLMBundle
- from agent.component.llm import LLMParam, LLM
- from api.utils.api_utils import timeout
- from rag.llm.chat_model import ERROR_PREFIX
-
-
class CategorizeParam(LLMParam):
    """
    Parameters of the Categorize component.

    ``category_description`` maps a category name to a settings dict. This
    class reads three keys from each settings dict: ``"to"`` (required by
    ``check``), and the optional ``"description"`` and ``"examples"`` used
    to build the system prompt.
    """

    def __init__(self):
        super().__init__()
        self.category_description = {}
        self.query = "sys.query"
        self.message_history_window_size = 1
        # Build an initial (empty-category) system prompt right away.
        self.update_prompt()

    def check(self):
        """Validate the window size and every configured category.

        Raises:
            ValueError: if a category name is empty or a category has no
                ``"to"`` value.
        """
        self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
        self.check_empty(self.category_description, "[Categorize] Category examples")
        for name, settings in self.category_description.items():
            if not name:
                raise ValueError("[Categorize] Category name can not be empty!")
            if settings.get("to"):
                continue
            raise ValueError(f"[Categorize] 'To' of category {name} can not be empty!")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description: a single line field named "Query"."""
        query_field = {"type": "line", "name": "Query"}
        return {"query": query_field}

    def update_prompt(self):
        """Rebuild ``self.sys_prompt`` from the current ``category_description``.

        The prompt lists the category names, then the per-category
        descriptions (only for categories that define one), then fixed
        instructions. An "Examples" section is appended only when at least
        one non-empty example exists.
        """
        example_lines = []
        category_notes = []
        for name, settings in self.category_description.items():
            # One "USER: "<text>" → <category>" line per non-empty example;
            # newlines inside an example are flattened to spaces.
            example_lines.extend(
                "USER: \"" + re.sub(r"\n", " ", sample, flags=re.DOTALL) + "\" → " + name
                for sample in settings.get("examples", [])
                if sample
            )
            if settings.get("description"):
                category_notes.append(
                    "\n------\nCategory: {}\nDescription: {}".format(name, settings["description"]))

        category_names = "\n - ".join(list(self.category_description.keys()))
        joined_notes = "\n".join(category_notes)
        template = """
You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
{}

Here's description of each category:
- {}

---- Instructions ----
- Consider both explicit mentions and implied context
- Prioritize the most specific applicable category
- Return only the category name without explanations
- Use "Other" only when no other category fits

"""
        self.sys_prompt = template.format(category_names, joined_notes)

        if not example_lines:
            return
        self.sys_prompt += """
---- Examples ----
{}
""".format("\n".join(example_lines))
-
-
class Categorize(LLM, ABC):
    """
    Component that asks a chat model to classify the current query into one
    of the configured categories and publishes the chosen category name and
    its ``"to"`` targets as outputs.
    """
    component_name = "Categorize"

    # Environment variables are strings; convert so the decorator always
    # receives an int, whether or not COMPONENT_EXEC_TIMEOUT is set.
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
    def _invoke(self, **kwargs):
        """Classify the latest message and set the routing outputs.

        Steps:
          1. Take the last ``message_history_window_size`` history messages
             from the canvas (or a single empty user message if there are none).
          2. Overwrite the content of the last message with ``kwargs["sys.query"]``
             when it is truthy, otherwise with the value of the variable named
             by ``self._param.query``.
          3. Rebuild the system prompt, query the chat model, and pick the
             category (see ``_pick_category``).

        Outputs set: ``category_name`` and ``_next``.

        Raises:
            Exception: when the model answer contains ``ERROR_PREFIX``.
        """
        msg = self._canvas.get_history(self._param.message_history_window_size)
        if not msg:
            msg = [{"role": "user", "content": ""}]
        if kwargs.get("sys.query"):
            msg[-1]["content"] = kwargs["sys.query"]
            self.set_input_value("sys.query", kwargs["sys.query"])
        else:
            msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
            self.set_input_value(self._param.query, msg[-1]["content"])
        self._param.update_prompt()
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

        # Flatten the conversation into one line: ROLE: "content" | ROLE: "content" ...
        conversation = " | ".join(
            "{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL))
            for c in msg
        )
        user_prompt = """
---- Real Data ----
{} →
""".format(conversation)
        ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
        logging.info("input: %s, answer: %s", user_prompt, ans)
        if ERROR_PREFIX in ans:
            raise Exception(ans)

        max_category, cpn_ids = self._pick_category(ans)
        self.set_output("category_name", max_category)
        self.set_output("_next", cpn_ids)

    def _pick_category(self, ans):
        """Return ``(category_name, to_targets)`` for the model answer ``ans``.

        Each category name is counted as a case-insensitive substring of the
        answer; the category with the highest count wins (on ties the one
        configured first). When no category name occurs at all, the LAST
        configured category is used as the fallback for both the reported
        name and the route, so the two outputs always describe the same
        category.
        """
        categories = self._param.category_description
        lowered_answer = ans.lower()
        category_counts = {c: lowered_answer.count(c.lower()) for c in categories}

        if any(category_counts.values()):
            # max() returns the first maximal item, i.e. insertion order on ties.
            chosen = max(category_counts.items(), key=lambda x: x[1])[0]
        else:
            chosen = list(categories)[-1]
        return chosen, categories[chosen]["to"]

    def thoughts(self) -> str:
        """Return a short status message listing the candidate categories."""
        return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|