|
|
|
@@ -87,10 +87,17 @@ class Categorize(Generate, ABC): |
|
|
|
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) |
|
|
|
ans = chat_mdl.chat(self._param.get_prompt(input), [{"role": "user", "content": "\nCategory: "}], |
|
|
|
self._param.gen_conf()) |
|
|
|
logging.debug(f"input: {input}, answer: {str(ans)}") |
|
|
|
logging.debug(f"input: {input}, answer: {str(ans)}") |
|
|
|
# Count the number of times each category appears in the answer. |
|
|
|
category_counts = {} |
|
|
|
for c in self._param.category_description.keys(): |
|
|
|
if ans.lower().find(c.lower()) >= 0: |
|
|
|
return Categorize.be_output(self._param.category_description[c]["to"]) |
|
|
|
count = ans.lower().count(c.lower()) |
|
|
|
category_counts[c] = count |
|
|
|
|
|
|
|
# If a category is found, return the category with the highest count. |
|
|
|
if any(category_counts.values()): |
|
|
|
max_category = max(category_counts.items(), key=lambda x: x[1]) |
|
|
|
return Categorize.be_output(self._param.category_description[max_category[0]]["to"]) |
|
|
|
|
|
|
|
return Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"]) |
|
|
|
|