
Agent plans tasks by referring to its own prompt. (#9315)

### What problem does this PR solve?

Fixes an issue in the `analyze_task` execution flow where the Lead Agent was not using its own `sys_prompt` during task analysis, resulting in incorrect or incomplete task planning.

Fixes https://github.com/infiniflow/ragflow/issues/9294

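At a glance, the change threads the agent's system prompt through the streaming tool loop and into `analyze_task`, so the planner sees the agent's role and instructions. Below is a minimal, self-contained sketch of that pattern; `fake_chat` and the simplified function bodies are illustrative stand-ins, not the real ragflow APIs:

```python
def fake_chat(system: str, user: str) -> str:
    # Stand-in for chat_mdl.chat(...); a real model call goes here.
    return f"[analysis given system={system!r}, user={user!r}]"

def analyze_task(prompt: str, task: str) -> str:
    # The agent prompt is now rendered into the analysis request
    # (cf. the **Agent Prompt** block added to analyze_task_user.md below).
    user_msg = f"Task: {task}\n\n**Agent Prompt**\n{prompt}"
    return fake_chat("You analyze tasks.", user_msg)

def react_with_tools(prompt: str, user_request: str) -> str:
    # Before this PR, the equivalent call site dropped `prompt`;
    # now it is forwarded so planning can reference the agent's role.
    return analyze_task(prompt, user_request)

print(react_with_tools("You are a travel-planning agent.", "Plan a 3-day trip to Kyoto"))
```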
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit 476c56868d by TeslaZY, 2 months ago (tag: v0.20.2)
3 changed files with 11 additions and 8 deletions:

1. agent/component/agent_with_tools.py (+5 -5)
2. rag/prompts/analyze_task_user.md (+3 -0)
3. rag/prompts/prompts.py (+3 -3)

agent/component/agent_with_tools.py (+5 -5)

@@ -165,7 +165,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(msg, use_tools):
+        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
             ans += delta_ans
 
         if ans.find("**ERROR**") >= 0:
@@ -185,7 +185,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
         use_tools = []
-        for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools):
+        for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools):
             if delta_ans.find("**ERROR**") >= 0:
                 if self.get_exception_default_value():
                     self.set_output("content", self.get_exception_default_value())
@@ -208,7 +208,7 @@
         ]):
             yield delta_ans
 
-    def _react_with_tools_streamly(self, history: list[dict], use_tools):
+    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
@@ -221,7 +221,7 @@
 
         def use_tool(name, args):
             nonlocal hist, use_tools, token_count,last_calling,user_request
-            print(f"{last_calling=} == {name=}", )
+            logging.info(f"{last_calling=} == {name=}")
             # Summarize of function calling
             #if all([
             #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
@@ -275,7 +275,7 @@
             else:
                 hist.append({"role": "user", "content": content})
 
-        task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
+        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
         self.callback("analyze_task", {}, task_desc)
         for _ in range(self._param.max_rounds + 1):
             response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
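
Since `_react_with_tools_streamly` is a generator, adding the leading `prompt` parameter means every caller above must forward the system prompt it already builds. A hedged, simplified sketch of that shape (the names are stand-ins for the `Agent` methods in this diff, not the real implementation):

```python
from typing import Iterator

def react_with_tools_streamly(prompt: str, history: list[dict], use_tools: list) -> Iterator[tuple[str, int]]:
    # `prompt` is now in scope here, so the planning step (analyze_task)
    # can reference the agent's own instructions.
    yield f"(planning against prompt: {prompt!r}) ", 0
    yield "final answer", 3

prompt = "You are a research agent."
msg = [{"role": "user", "content": "Find recent LLM papers"}]
use_tools: list = []
ans = ""
for delta_ans, tk in react_with_tools_streamly(prompt, msg, use_tools):
    ans += delta_ans
print(ans)
```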

rag/prompts/analyze_task_user.md (+3 -0)

@@ -4,6 +4,9 @@ Task: {{ task }}
 
 Context: {{ context }}
 
+**Agent Prompt**
+{{ agent_prompt }}
+
 **Analysis Requirements:**
 1. Is it just a small talk? (If yes, no further plan or analysis is needed)
 2. What is the core objective of the task?
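
The new `agent_prompt` variable is filled in by the same Jinja template rendering. A quick standalone check with plain Jinja2 (the template text is abbreviated from `analyze_task_user.md`, and the environment options may differ from ragflow's `PROMPT_JINJA_ENV`):

```python
from jinja2 import Environment

TEMPLATE = """Task: {{ task }}

Context: {{ context }}

**Agent Prompt**
{{ agent_prompt }}"""

rendered = Environment().from_string(TEMPLATE).render(
    task="Summarize the latest sales report",
    context="",
    agent_prompt="You are a data-analysis assistant for the finance team.",
)
print(rendered)
```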

rag/prompts/prompts.py (+3 -3)

@@ -335,13 +335,13 @@ def form_history(history, limit=-6):
     return context
 
 
-def analyze_task(chat_mdl, task_name, tools_description: list[dict]):
+def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
     tools_desc = tool_schema(tools_description)
     context = ""
 
     template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
-    kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}], {})
+    context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
+    kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": context}], {})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
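
The trailing `re.sub` keeps only the text after a `</think>` tag, stripping any chain-of-thought a reasoning model may emit before its answer. A standalone demonstration of that line:

```python
import re

kwd = "<think>plan step by step...</think>1. Core objective: book a flight."
# re.DOTALL lets "." cross newlines, so everything up to the last </think> is removed.
print(re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL))
# -> "1. Core objective: book a flight."
```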
