 dataset_retrieval = DatasetRetrieval()
 context = dataset_retrieval.retrieve(
+    app_id=app_record.id,
+    user_id=application_generate_entity.user_id,
     tenant_id=app_record.tenant_id,
     model_config=application_generate_entity.model_config,
     config=app_config.dataset,
 dataset_retrieval = DatasetRetrieval()
 context = dataset_retrieval.retrieve(
+    app_id=app_record.id,
+    user_id=application_generate_entity.user_id,
     tenant_id=app_record.tenant_id,
     model_config=application_generate_entity.model_config,
     config=dataset_config,

import time
from collections.abc import Mapping
from typing import Any, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the configured canned response, regardless of the input messages."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return 0

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                    time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {"token_usage": {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }}
        return ChatResult(generations=[generation], llm_output=llm_output)
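
FakeLLM exists only to satisfy langchain's model validation; real inference is routed through Dify's own model runtime. A minimal usage sketch (illustrative only, not part of the change; it assumes langchain's legacy chat-model __call__ interface):

    from langchain.schema import HumanMessage

    llm = FakeLLM(response="canned answer")
    reply = llm([HumanMessage(content="anything")])  # AIMessage whose content is always "canned answer"
    assert reply.content == "canned answer"
    assert llm.get_num_tokens("any text") == 0  # token accounting is stubbed to zero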

from typing import Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM


class LLMChain(LCLLMChain):
    model_config: ModelConfigWithCredentialsEntity
    """The model config (with credentials) to use."""
    llm: BaseLanguageModel = FakeLLM(response="")
    parameters: dict[str, Any] = {}

    def generate(
        self,
        input_list: list[dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        messages = prompts[0].to_messages()
        prompt_messages = lc_messages_to_prompt_messages(messages)

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            stream=False,
            stop=stop,
            model_parameters=self.parameters
        )

        generations = [
            [Generation(text=result.message.content)]
        ]

        return LLMResult(generations=generations)
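
This LLMChain subclass keeps langchain's chain interface but overrides generate() so the actual call goes through Dify's ModelInstance; the FakeLLM assigned to `llm` is just a validation placeholder. A hypothetical wiring sketch (model_config and prompt are assumed objects, not defined in this change):

    chain = LLMChain(
        model_config=model_config,   # a ModelConfigWithCredentialsEntity
        prompt=prompt,               # any langchain prompt template
        parameters={'temperature': 0.2, 'max_tokens': 256},
    )
    result = chain.generate([{'input': 'route this query'}])
    print(result.generations[0][0].text)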

from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool
from pydantic import root_validator

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.rag.retrieval.agent.fake_llm import FakeLLM


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by a router.
    """
    model_config: ModelConfigWithCredentialsEntity

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for the given query.

        :param query: user query
        :return: always True for this agent
        """
        return True

    def plan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
            # for item in rst_json:
            #     output += f'{item["content"]}\n'
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        try:
            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            else:
                agent_decision.return_values['output'] = ''
            return agent_decision
        except Exception as e:
            raise e

    def real_plan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        prompt_messages = lc_messages_to_prompt_messages(messages)

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = []
        for function in self.functions:
            tool = PromptMessageTool(
                **function
            )
            tools.append(tool)

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )

        agent_decision = _parse_ai_message(ai_message)
        return agent_decision

    async def aplan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigWithCredentialsEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
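
Note the fast paths in plan(): zero tools returns an empty AgentFinish, a single tool is invoked directly with no LLM round-trip, and only with two or more tools does real_plan() ask the model to pick a dataset via function calling. A sketch of the call flow (model_config and dataset_tools are assumed objects):

    agent = MultiDatasetRouterAgent.from_llm_and_tools(
        model_config=model_config,
        tools=dataset_tools,
    )
    decision = agent.plan(intermediate_steps=[], input="How do refunds work?")
    # With several tools, `decision` is an AgentAction naming the chosen dataset
    # tool; with zero or one tool it is already an AgentFinish.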

import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.rag.retrieval.agent.llm_chain import LLMChain

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
    dataset_tools: Sequence[BaseTool]

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for the given query.

        Deciding whether an agent is needed with a separate ReACT pass is
        costly, so it is cheaper to always run the agent for reasoning.

        :param query: user query
        :return: always True for this agent
        """
        return True

    def plan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.dataset_tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.dataset_tools) == 1:
            tool = next(iter(self.dataset_tools))
            rst = tool.run(tool_input={'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
                elif isinstance(tool_inputs, str):
                    agent_decision.tool_input = kwargs['input']
            else:
                agent_decision.return_values['output'] = ''
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
        memory_prompts: Optional[list[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        unique_tool_names = set(tool.name for tool in tools)
        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
        self, intermediate_steps: list[tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigWithCredentialsEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
        memory_prompts: Optional[list[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            dataset_tools=tools,
            **kwargs,
        )
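
The quadruple braces in FORMAT_INSTRUCTIONS (and in the args schemas escaped by create_prompt) are deliberate: the template text passes through str.format twice, once when {tool_names} is substituted and once when the prompt template renders, and each pass halves the braces. A quick check of that assumption:

    blob = '{{{{"action": "Final Answer"}}}}'
    once = blob.format()    # '{{"action": "Final Answer"}}'
    twice = once.format()   # '{"action": "Final Answer"}'
    assert twice == '{"action": "Final Answer"}'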

import logging
from typing import Optional, Union

from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    model_config: ModelConfigWithCredentialsEntity
    tools: list[BaseTool]
    summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
    memory: Optional[TokenBufferMemory] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
    max_execution_time: Optional[float] = None
    early_stopping_method: str = "generate"
    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True


class AgentExecuteResult(BaseModel):
    strategy: PlanningStrategy
    output: Optional[str]
    configuration: AgentConfiguration


class AgentExecutor:
    def __init__(self, configuration: AgentConfiguration):
        self.configuration = configuration
        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
                if self.configuration.memory else None,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                verbose=True
            )
        else:
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")

        return agent

    def should_use_agent(self, query: str) -> bool:
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> AgentExecuteResult:
        moderation_result = moderation.check_moderation(
            self.configuration.model_config,
            query
        )

        if moderation_result:
            return AgentExecuteResult(
                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                strategy=self.configuration.strategy,
                configuration=self.configuration
            )

        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=self.configuration.tools,
            max_iterations=self.configuration.max_iterations,
            max_execution_time=self.configuration.max_execution_time,
            early_stopping_method=self.configuration.early_stopping_method,
            callbacks=self.configuration.callbacks
        )

        try:
            output = agent_executor.run(input=query)
        except InvokeError as ex:
            raise ex
        except Exception as ex:
            logging.exception("agent_executor run failed")
            output = None

        return AgentExecuteResult(
            output=output,
            strategy=self.configuration.strategy,
            configuration=self.configuration
        )
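
AgentExecutor ties the pieces together: AgentConfiguration selects ROUTER (function calling) or REACT_ROUTER (structured ReACT), moderation can short-circuit with a canned reply, and the actual loop is delegated to langchain's AgentExecutor. A usage sketch (model_config and dataset_tools are assumed objects):

    configuration = AgentConfiguration(
        strategy=PlanningStrategy.ROUTER,
        model_config=model_config,
        tools=dataset_tools,
    )
    executor = AgentExecutor(configuration)
    if executor.should_use_agent("Which plan includes SSO?"):
        result = executor.run("Which plan includes SSO?")
        print(result.output)  # None if the run failed with a non-invoke error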

+import threading
 from typing import Optional, cast
-from langchain.tools import BaseTool
+from flask import Flask, current_app
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
-from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
+from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import Document as DatasetDocument

+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}

 class DatasetRetrieval:
-    def retrieve(self, tenant_id: str,
+    def retrieve(self, app_id: str, user_id: str, tenant_id: str,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: DatasetEntity,
                  query: str,
                  memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
         """
         Retrieve dataset.
+        :param app_id: app_id
+        :param user_id: user_id
         :param tenant_id: tenant id
         :param model_config: model config
         :param config: dataset config
         :return:
         """
         dataset_ids = config.dataset_ids
+        if len(dataset_ids) == 0:
+            return None
         retrieve_config = config.retrieve_config

         # check model is support tool calling
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.provider,
+            model=model_config.model
+        )

         # get model schema
         model_schema = model_type_instance.get_model_schema(
             model=model_config.model,
             if ModelFeature.TOOL_CALL in features \
                     or ModelFeature.MULTI_TOOL_CALL in features:
                 planning_strategy = PlanningStrategy.ROUTER
-        dataset_retriever_tools = self.to_dataset_retriever_tool(
-            tenant_id=tenant_id,
-            dataset_ids=dataset_ids,
-            retrieve_config=retrieve_config,
-            return_resource=show_retrieve_source,
-            invoke_from=invoke_from,
-            hit_callback=hit_callback
-        )
-        if len(dataset_retriever_tools) == 0:
-            return None
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=model_config,
-            tools=dataset_retriever_tools,
-            memory=memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate"
-        )
-        agent_executor = AgentExecutor(agent_configuration)
-        should_use_agent = agent_executor.should_use_agent(query)
-        if not should_use_agent:
-            return None
-        result = agent_executor.run(query)
-        return result.output
-
-    def to_dataset_retriever_tool(self, tenant_id: str,
-                                  dataset_ids: list[str],
-                                  retrieve_config: DatasetRetrieveConfigEntity,
-                                  return_resource: bool,
-                                  invoke_from: InvokeFrom,
-                                  hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[BaseTool]]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param tenant_id: tenant id
-        :param dataset_ids: dataset ids
-        :param retrieve_config: retrieve config
-        :param return_resource: return resource
-        :param invoke_from: invoke from
-        :param hit_callback: hit callback
-        """
-        tools = []
         available_datasets = []
         for dataset_id in dataset_ids:
             # get dataset from dataset id
                 continue
             available_datasets.append(dataset)
+        all_documents = []
+        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
+                                                 model_instance,
+                                                 model_config, planning_strategy)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
+                                                   available_datasets, query, retrieve_config.top_k,
+                                                   retrieve_config.score_threshold,
+                                                   retrieve_config.reranking_model.get('reranking_provider_name'),
+                                                   retrieve_config.reranking_model.get('reranking_model_name'))
+        document_score_list = {}
+        for item in all_documents:
+            if 'score' in item.metadata and item.metadata['score']:
+                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.dataset_id.in_(dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+            if show_retrieve_source:
+                context_list = []
+                resource_number = 1
+                for segment in sorted_segments:
+                    dataset = Dataset.query.filter_by(
+                        id=segment.dataset_id
+                    ).first()
+                    document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
+                                                            DatasetDocument.enabled == True,
+                                                            DatasetDocument.archived == False,
+                                                            ).first()
+                    if dataset and document:
+                        source = {
+                            'position': resource_number,
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': invoke_from.to_source(),
+                            'score': document_score_list.get(segment.index_node_id, None)
+                        }
+                        if invoke_from.to_source() == 'dev':
+                            source['hit_count'] = segment.hit_count
+                            source['word_count'] = segment.word_count
+                            source['segment_position'] = segment.position
+                            source['index_node_hash'] = segment.index_node_hash
+                        if segment.answer:
+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                        else:
+                            source['content'] = segment.content
+                        context_list.append(source)
+                        resource_number += 1
+                if hit_callback:
+                    hit_callback.return_retriever_resource_info(context_list)
+            return str("\n".join(document_context_list))
+        return ''

+    def single_retrieve(self, app_id: str,
+                        tenant_id: str,
+                        user_id: str,
+                        user_from: str,
+                        available_datasets: list,
+                        query: str,
+                        model_instance: ModelInstance,
+                        model_config: ModelConfigWithCredentialsEntity,
+                        planning_strategy: PlanningStrategy,
+                        ):
+        tools = []
+        for dataset in available_datasets:
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+            description = description.replace('\n', '').replace('\r', '')
+            message_tool = PromptMessageTool(
+                name=dataset.id,
+                description=description,
+                parameters={
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                }
+            )
+            tools.append(message_tool)
+        dataset_id = None
+        if planning_strategy == PlanningStrategy.REACT_ROUTER:
+            react_multi_dataset_router = ReactMultiDatasetRouter()
+            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
+                                                           user_id, tenant_id)
+        elif planning_strategy == PlanningStrategy.ROUTER:
+            function_call_router = FunctionCallMultiDatasetRouter()
+            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+        if dataset_id:
             # get retrieval model config
-            default_retrieval_model = {
-                'search_method': 'semantic_search',
-                'reranking_enable': False,
-                'reranking_model': {
-                    'reranking_provider_name': '',
-                    'reranking_model_name': ''
-                },
-                'top_k': 2,
-                'score_threshold_enabled': False
-            }
-            for dataset in available_datasets:
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+            if dataset:
                 retrieval_model_config = dataset.retrieval_model \
                     if dataset.retrieval_model else default_retrieval_model
                 # get top k
                 top_k = retrieval_model_config['top_k']
+                # get retrieval method
+                if dataset.indexing_technique == "economy":
+                    retrival_method = 'keyword_search'
+                else:
+                    retrival_method = retrieval_model_config['search_method']
+                # get reranking model
+                reranking_model = retrieval_model_config['reranking_model'] \
+                    if retrieval_model_config['reranking_enable'] else None
                 # get score threshold
-                score_threshold = None
+                score_threshold = .0
                 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                 if score_threshold_enabled:
                     score_threshold = retrieval_model_config.get("score_threshold")
-                tool = DatasetRetrieverTool.from_dataset(
-                    dataset=dataset,
-                    top_k=top_k,
-                    score_threshold=score_threshold,
-                    hit_callbacks=[hit_callback],
-                    return_resource=return_resource,
-                    retriever_from=invoke_from.to_source()
-                )
+                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
+                                                    query=query,
+                                                    top_k=top_k, score_threshold=score_threshold,
+                                                    reranking_model=reranking_model)
+                self._on_query(query, [dataset_id], app_id, user_from, user_id)
+                if results:
+                    self._on_retrival_end(results)
+                return results
+        return []
-                tools.append(tool)
-        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
-            tool = DatasetMultiRetrieverTool.from_dataset(
-                dataset_ids=[dataset.id for dataset in available_datasets],
-                tenant_id=tenant_id,
-                top_k=retrieve_config.top_k or 2,
-                score_threshold=retrieve_config.score_threshold,
-                hit_callbacks=[hit_callback],
-                return_resource=return_resource,
-                retriever_from=invoke_from.to_source(),
-                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
-                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
+    def multiple_retrieve(self,
+                          app_id: str,
+                          tenant_id: str,
+                          user_id: str,
+                          user_from: str,
+                          available_datasets: list,
+                          query: str,
+                          top_k: int,
+                          score_threshold: float,
+                          reranking_provider_name: str,
+                          reranking_model_name: str):
+        threads = []
+        all_documents = []
+        dataset_ids = [dataset.id for dataset in available_datasets]
+        for dataset in available_datasets:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset.id,
+                'query': query,
+                'top_k': top_k,
+                'all_documents': all_documents,
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # do rerank for searched documents
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            provider=reranking_provider_name,
+            model_type=ModelType.RERANK,
+            model=reranking_model_name
+        )
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(query, all_documents,
+                                          score_threshold,
+                                          top_k)
+        self._on_query(query, dataset_ids, app_id, user_from, user_id)
+        if all_documents:
+            self._on_retrival_end(all_documents)
+        return all_documents
+
+    def _on_retrival_end(self, documents: list[Document]) -> None:
+        """Handle retrival end."""
+        for document in documents:
+            query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.index_node_id == document.metadata['doc_id']
             )
-                tools.append(tool)
+            # if 'dataset_id' in document.metadata:
+            if 'dataset_id' in document.metadata:
+                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
+            # add hit count to document segment
+            query.update(
+                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                synchronize_session=False
+            )
+        db.session.commit()
+
+    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
+        """
+        Handle query.
+        """
+        if not query:
+            return
+        for dataset_id in dataset_ids:
+            dataset_query = DatasetQuery(
+                dataset_id=dataset_id,
+                content=query,
+                source='app',
+                source_app_id=app_id,
+                created_by_role=user_from,
+                created_by=user_id
+            )
+            db.session.add(dataset_query)
+        db.session.commit()
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+            if not dataset:
+                return []
+            # get retrieval model , if the model is not setting , using default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
-        return tools
+                    all_documents.extend(documents)
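
After this refactor the two strategies are ordinary methods on DatasetRetrieval, shared by the app runner and the workflow knowledge-retrieval node: single_retrieve routes the query to one dataset via an LLM router and queries it, while multiple_retrieve queries every dataset in a thread and reranks the merged results. A hedged usage sketch (ids, datasets, and model objects are assumed; the reranking names are placeholders):

    dataset_retrieval = DatasetRetrieval()

    # single: a router picks one dataset, then RetrievalService queries it
    docs = dataset_retrieval.single_retrieve(
        app_id, tenant_id, user_id, 'end_user',
        available_datasets, query,
        model_instance, model_config, PlanningStrategy.ROUTER)

    # multiple: query every dataset in parallel, then rerank the merged results
    docs = dataset_retrieval.multiple_retrieve(
        app_id, tenant_id, user_id, 'end_user',
        available_datasets, query,
        top_k=4, score_threshold=0.5,
        reranking_provider_name='<provider>',
        reranking_model_name='<model>')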

 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
-from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
         self,
         query: str,
         dataset_tools: list[PromptMessageTool],
-        node_data: KnowledgeRetrievalNodeData,
         model_config: ModelConfigWithCredentialsEntity,
         model_instance: ModelInstance,
         user_id: str,
-        tenant_id: str,
+        tenant_id: str
     ) -> Union[str, None]:
         """Given input, decided what to do.
             return dataset_tools[0].name
         try:
-            return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance,
+            return self._react_invoke(query=query, model_config=model_config,
+                                      model_instance=model_instance,
                                       tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
         except Exception as e:
             return None

     def _react_invoke(
         self,
         query: str,
-        node_data: KnowledgeRetrievalNodeData,
         model_config: ModelConfigWithCredentialsEntity,
         model_instance: ModelInstance,
         tools: Sequence[PromptMessageTool],
             model_config=model_config
         )
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            completion_param=model_config.parameters,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
             return agent_decision.tool
         return None

-    def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
+    def _invoke_llm(self, completion_param: dict,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
+                    stop: list[str], user_id: str, tenant_id: str
+                    ) -> tuple[str, LLMUsage]:
         """
         Invoke large language model
         :param node_data: node data
         """
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.single_retrieval_config.model.completion_params,
+            model_parameters=completion_param,
             stop=stop,
             stream=True,
             user=user_id,
     ) -> list[ChatModelMessage]:
         tool_strings = []
         for tool in tools:
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
+            tool_strings.append(
+                f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
         unique_tool_names = set(tool.name for tool in tools)
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)

-import threading
 from typing import Any, cast
-from flask import Flask, current_app
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.datasource.retrieval_service import RetrievalService
-from core.rerank.rerank import RerankRunner
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus

 default_retrieval_model = {
             available_datasets.append(dataset)
         all_documents = []
+        dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
-            all_documents = self._single_retrieve(available_datasets, node_data, query)
+            # fetch model config
+            model_instance, model_config = self._fetch_model_config(node_data)
+            # check model is support tool calling
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            # get model schema
+            model_schema = model_type_instance.get_model_schema(
+                model=model_config.model,
+                credentials=model_config.credentials
+            )
+            if model_schema:
+                planning_strategy = PlanningStrategy.REACT_ROUTER
+                features = model_schema.features
+                if features:
+                    if ModelFeature.TOOL_CALL in features \
+                            or ModelFeature.MULTI_TOOL_CALL in features:
+                        planning_strategy = PlanningStrategy.ROUTER
+                all_documents = dataset_retrieval.single_retrieve(
+                    available_datasets=available_datasets,
+                    tenant_id=self.tenant_id,
+                    user_id=self.user_id,
+                    app_id=self.app_id,
+                    user_from=self.user_from.value,
+                    query=query,
+                    model_config=model_config,
+                    model_instance=model_instance,
+                    planning_strategy=planning_strategy
+                )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
-            all_documents = self._multiple_retrieve(available_datasets, node_data, query)
+            all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
+                                                                self.user_from.value,
+                                                                available_datasets, query,
+                                                                node_data.multiple_retrieval_config.top_k,
+                                                                node_data.multiple_retrieval_config.score_threshold,
+                                                                node_data.multiple_retrieval_config.reranking_model.provider,
+                                                                node_data.multiple_retrieval_config.reranking_model.model)
         context_list = []
         if all_documents:
             variable_mapping['query'] = node_data.query_variable_selector
             return variable_mapping
| def _single_retrieve(self, available_datasets, node_data, query): | |||||
| tools = [] | |||||
| for dataset in available_datasets: | |||||
| description = dataset.description | |||||
| if not description: | |||||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||||
| description = description.replace('\n', '').replace('\r', '') | |||||
| message_tool = PromptMessageTool( | |||||
| name=dataset.id, | |||||
| description=description, | |||||
| parameters={ | |||||
| "type": "object", | |||||
| "properties": {}, | |||||
| "required": [], | |||||
| } | |||||
| ) | |||||
| tools.append(message_tool) | |||||
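Each dataset is exposed to the router model as a parameter-less tool whose name is the dataset id, so "calling" the tool is equivalent to selecting that dataset. An illustrative instance built by the loop above (hypothetical values):

    PromptMessageTool(
        name="5c0dd2f4-...",  # dataset id doubles as the tool name
        description="useful for when you want to answer queries about the Product FAQ",
        parameters={"type": "object", "properties": {}, "required": []},  # no arguments needed
    )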
| # fetch model config | |||||
| model_instance, model_config = self._fetch_model_config(node_data) | |||||
| # check whether the model supports tool calling | |||||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||||
| # get model schema | |||||
| model_schema = model_type_instance.get_model_schema( | |||||
| model=model_config.model, | |||||
| credentials=model_config.credentials | |||||
| ) | |||||
| if not model_schema: | |||||
| return None | |||||
| planning_strategy = PlanningStrategy.REACT_ROUTER | |||||
| features = model_schema.features | |||||
| if features: | |||||
| if ModelFeature.TOOL_CALL in features \ | |||||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||||
| planning_strategy = PlanningStrategy.ROUTER | |||||
| dataset_id = None | |||||
| if planning_strategy == PlanningStrategy.REACT_ROUTER: | |||||
| react_multi_dataset_router = ReactMultiDatasetRouter() | |||||
| dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance, | |||||
| self.user_id, self.tenant_id) | |||||
| elif planning_strategy == PlanningStrategy.ROUTER: | |||||
| function_call_router = FunctionCallMultiDatasetRouter() | |||||
| dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) | |||||
| if dataset_id: | |||||
| # get retrieval model config | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| if dataset: | |||||
| retrieval_model_config = dataset.retrieval_model \ | |||||
| if dataset.retrieval_model else default_retrieval_model | |||||
| # get top k | |||||
| top_k = retrieval_model_config['top_k'] | |||||
| # get retrieval method | |||||
| if dataset.indexing_technique == "economy": | |||||
| retrival_method = 'keyword_search' | |||||
| else: | |||||
| retrival_method = retrieval_model_config['search_method'] | |||||
| # get reranking model | |||||
| reranking_model = retrieval_model_config['reranking_model'] \ | |||||
| if retrieval_model_config['reranking_enable'] else None | |||||
| # get score threshold | |||||
| score_threshold = .0 | |||||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||||
| if score_threshold_enabled: | |||||
| score_threshold = retrieval_model_config.get("score_threshold") | |||||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=top_k, score_threshold=score_threshold, | |||||
| reranking_model=reranking_model) | |||||
| self._on_query(query, [dataset_id]) | |||||
| if results: | |||||
| self._on_retrival_end(results) | |||||
| return results | |||||
| return [] | |||||
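For reference, the keys read from retrieval_model_config above imply a dict of roughly this shape (illustrative values; only keys that appear in the code are shown):

    retrieval_model_config = {
        "search_method": "semantic_search",  # ignored for "economy" datasets (keyword_search is forced)
        "top_k": 4,
        "score_threshold_enabled": False,    # when False, score_threshold defaults to 0.0
        "score_threshold": 0.5,
        "reranking_enable": False,
        "reranking_model": None,             # nested shape is not shown in this diff
    }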
| def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ | def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ | ||||
| ModelInstance, ModelConfigWithCredentialsEntity]: | ModelInstance, ModelConfigWithCredentialsEntity]: | ||||
| """ | """ | ||||
| parameters=completion_params, | parameters=completion_params, | ||||
| stop=stop, | stop=stop, | ||||
| ) | ) | ||||
| def _multiple_retrieve(self, available_datasets, node_data, query): | |||||
| threads = [] | |||||
| all_documents = [] | |||||
| dataset_ids = [dataset.id for dataset in available_datasets] | |||||
| for dataset in available_datasets: | |||||
| retrieval_thread = threading.Thread(target=self._retriever, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': dataset.id, | |||||
| 'query': query, | |||||
| 'top_k': node_data.multiple_retrieval_config.top_k, | |||||
| 'all_documents': all_documents, | |||||
| }) | |||||
| threads.append(retrieval_thread) | |||||
| retrieval_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # rerank the retrieved documents | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=self.tenant_id, | |||||
| provider=node_data.multiple_retrieval_config.reranking_model.provider, | |||||
| model_type=ModelType.RERANK, | |||||
| model=node_data.multiple_retrieval_config.reranking_model.model | |||||
| ) | |||||
| rerank_runner = RerankRunner(rerank_model_instance) | |||||
| all_documents = rerank_runner.run(query, all_documents, | |||||
| node_data.multiple_retrieval_config.score_threshold, | |||||
| node_data.multiple_retrieval_config.top_k) | |||||
| self._on_query(query, dataset_ids) | |||||
| if all_documents: | |||||
| self._on_retrival_end(all_documents) | |||||
| return all_documents | |||||
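The removed method follows a fan-out/join pattern: one thread per dataset appends into a shared list, and a single rerank pass then runs over the merged results. A self-contained sketch of the same pattern, with retrieve_fn standing in for RetrievalService.retrieve:

    import threading

    def fan_out_retrieve(dataset_ids, query, top_k, retrieve_fn):
        all_documents = []
        threads = []
        for dataset_id in dataset_ids:
            # each worker extends the shared list; a single list.extend is
            # atomic under CPython's GIL, so no explicit lock is needed here
            t = threading.Thread(
                target=lambda d=dataset_id: all_documents.extend(retrieve_fn(d, query, top_k))
            )
            threads.append(t)
            t.start()
        for t in threads:
            t.join()
        return all_documents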
| def _on_retrival_end(self, documents: list[Document]) -> None: | |||||
| """Handle retrival end.""" | |||||
| for document in documents: | |||||
| query = db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.index_node_id == document.metadata['doc_id'] | |||||
| ) | |||||
| if 'dataset_id' in document.metadata: | |||||
| query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) | |||||
| # increment the hit count on matched document segments | |||||
| query.update( | |||||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||||
| synchronize_session=False | |||||
| ) | |||||
| db.session.commit() | |||||
| def _on_query(self, query: str, dataset_ids: list[str]) -> None: | |||||
| """ | |||||
| Handle query. | |||||
| """ | |||||
| if not query: | |||||
| return | |||||
| for dataset_id in dataset_ids: | |||||
| dataset_query = DatasetQuery( | |||||
| dataset_id=dataset_id, | |||||
| content=query, | |||||
| source='app', | |||||
| source_app_id=self.app_id, | |||||
| created_by_role=self.user_from.value, | |||||
| created_by=self.user_id | |||||
| ) | |||||
| db.session.add(dataset_query) | |||||
| db.session.commit() | |||||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.tenant_id == self.tenant_id, | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| return [] | |||||
| # get the retrieval model; if not set, fall back to the default | |||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||||
| if dataset.indexing_technique == "economy": | |||||
| # use keyword table query | |||||
| documents = RetrievalService.retrieve(retrival_method='keyword_search', | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=top_k | |||||
| ) | |||||
| if documents: | |||||
| all_documents.extend(documents) | |||||
| else: | |||||
| if top_k > 0: | |||||
| # retrieve with the dataset's configured search method | |||||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=top_k, | |||||
| score_threshold=retrieval_model['score_threshold'] | |||||
| if retrieval_model['score_threshold_enabled'] else None, | |||||
| reranking_model=retrieval_model['reranking_model'] | |||||
| if retrieval_model['reranking_enable'] else None | |||||
| ) | |||||
| all_documents.extend(documents) | |||||
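The economy/high-quality branch in _retriever can be read as a small parameter-resolution step; a hypothetical helper (not part of the diff) distilling it:

    def resolve_retrieval_params(dataset, retrieval_model):
        # "economy" datasets only maintain a keyword index, so keyword search is forced
        if dataset.indexing_technique == "economy":
            return "keyword_search", None, None
        return (
            retrieval_model["search_method"],
            retrieval_model["score_threshold"] if retrieval_model["score_threshold_enabled"] else None,
            retrieval_model["reranking_model"] if retrieval_model["reranking_enable"] else None,
        )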