|
|
|
@@ -1,7 +1,7 @@ |
|
|
|
import re |
|
|
|
from typing import List, Tuple, Any, Union, Sequence, Optional, cast |
|
|
|
|
|
|
|
from langchain import BasePromptTemplate |
|
|
|
from langchain import BasePromptTemplate, PromptTemplate |
|
|
|
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent |
|
|
|
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE |
|
|
|
from langchain.callbacks.base import BaseCallbackManager |
|
|
|
@@ -12,6 +12,7 @@ from langchain.tools import BaseTool |
|
|
|
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX |
|
|
|
|
|
|
|
from core.chain.llm_chain import LLMChain |
|
|
|
from core.model_providers.models.entity.model_params import ModelMode |
|
|
|
from core.model_providers.models.llm.base import BaseLLM |
|
|
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool |
|
|
|
|
|
|
|
@@ -92,6 +93,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): |
|
|
|
rst = tool.run(tool_input={'query': kwargs['input']}) |
|
|
|
return AgentFinish(return_values={"output": rst}, log=rst) |
|
|
|
|
|
|
|
if intermediate_steps: |
|
|
|
_, observation = intermediate_steps[-1] |
|
|
|
return AgentFinish(return_values={"output": observation}, log=observation) |
|
|
|
|
|
|
|
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) |
|
|
|
|
|
|
|
try: |
|
|
|
@@ -107,6 +112,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): |
|
|
|
if isinstance(tool_inputs, dict) and 'query' in tool_inputs: |
|
|
|
tool_inputs['query'] = kwargs['input'] |
|
|
|
agent_decision.tool_input = tool_inputs |
|
|
|
elif isinstance(tool_inputs, str): |
|
|
|
agent_decision.tool_input = kwargs['input'] |
|
|
|
else: |
|
|
|
agent_decision.return_values['output'] = '' |
|
|
|
return agent_decision |
|
|
|
@@ -143,6 +150,61 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): |
|
|
|
] |
|
|
|
return ChatPromptTemplate(input_variables=input_variables, messages=messages) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def create_completion_prompt( |
|
|
|
cls, |
|
|
|
tools: Sequence[BaseTool], |
|
|
|
prefix: str = PREFIX, |
|
|
|
format_instructions: str = FORMAT_INSTRUCTIONS, |
|
|
|
input_variables: Optional[List[str]] = None, |
|
|
|
) -> PromptTemplate: |
|
|
|
"""Create prompt in the style of the zero shot agent. |
|
|
|
|
|
|
|
Args: |
|
|
|
tools: List of tools the agent will have access to, used to format the |
|
|
|
prompt. |
|
|
|
prefix: String to put before the list of tools. |
|
|
|
input_variables: List of input variables the final prompt will expect. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A PromptTemplate with the template assembled from the pieces here. |
|
|
|
""" |
|
|
|
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. |
|
|
|
Question: {input} |
|
|
|
Thought: {agent_scratchpad} |
|
|
|
""" |
|
|
|
|
|
|
|
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) |
|
|
|
tool_names = ", ".join([tool.name for tool in tools]) |
|
|
|
format_instructions = format_instructions.format(tool_names=tool_names) |
|
|
|
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) |
|
|
|
if input_variables is None: |
|
|
|
input_variables = ["input", "agent_scratchpad"] |
|
|
|
return PromptTemplate(template=template, input_variables=input_variables) |
|
|
|
|
|
|
|
def _construct_scratchpad( |
|
|
|
self, intermediate_steps: List[Tuple[AgentAction, str]] |
|
|
|
) -> str: |
|
|
|
agent_scratchpad = "" |
|
|
|
for action, observation in intermediate_steps: |
|
|
|
agent_scratchpad += action.log |
|
|
|
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" |
|
|
|
|
|
|
|
if not isinstance(agent_scratchpad, str): |
|
|
|
raise ValueError("agent_scratchpad should be of type string.") |
|
|
|
if agent_scratchpad: |
|
|
|
llm_chain = cast(LLMChain, self.llm_chain) |
|
|
|
if llm_chain.model_instance.model_mode == ModelMode.CHAT: |
|
|
|
return ( |
|
|
|
f"This was your previous work " |
|
|
|
f"(but I haven't seen any of it! I only see what " |
|
|
|
f"you return as final answer):\n{agent_scratchpad}" |
|
|
|
) |
|
|
|
else: |
|
|
|
return agent_scratchpad |
|
|
|
else: |
|
|
|
return agent_scratchpad |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_llm_and_tools( |
|
|
|
cls, |
|
|
|
@@ -160,15 +222,23 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): |
|
|
|
) -> Agent: |
|
|
|
"""Construct an agent from an LLM and tools.""" |
|
|
|
cls._validate_tools(tools) |
|
|
|
prompt = cls.create_prompt( |
|
|
|
tools, |
|
|
|
prefix=prefix, |
|
|
|
suffix=suffix, |
|
|
|
human_message_template=human_message_template, |
|
|
|
format_instructions=format_instructions, |
|
|
|
input_variables=input_variables, |
|
|
|
memory_prompts=memory_prompts, |
|
|
|
) |
|
|
|
if model_instance.model_mode == ModelMode.CHAT: |
|
|
|
prompt = cls.create_prompt( |
|
|
|
tools, |
|
|
|
prefix=prefix, |
|
|
|
suffix=suffix, |
|
|
|
human_message_template=human_message_template, |
|
|
|
format_instructions=format_instructions, |
|
|
|
input_variables=input_variables, |
|
|
|
memory_prompts=memory_prompts, |
|
|
|
) |
|
|
|
else: |
|
|
|
prompt = cls.create_completion_prompt( |
|
|
|
tools, |
|
|
|
prefix=prefix, |
|
|
|
format_instructions=format_instructions, |
|
|
|
input_variables=input_variables |
|
|
|
) |
|
|
|
llm_chain = LLMChain( |
|
|
|
model_instance=model_instance, |
|
|
|
prompt=prompt, |