| @@ -156,6 +156,8 @@ class ChatAppRunner(AppRunner): | |||
| dataset_retrieval = DatasetRetrieval() | |||
| context = dataset_retrieval.retrieve( | |||
| app_id=app_record.id, | |||
| user_id=application_generate_entity.user_id, | |||
| tenant_id=app_record.tenant_id, | |||
| model_config=application_generate_entity.model_config, | |||
| config=app_config.dataset, | |||
| @@ -116,6 +116,8 @@ class CompletionAppRunner(AppRunner): | |||
| dataset_retrieval = DatasetRetrieval() | |||
| context = dataset_retrieval.retrieve( | |||
| app_id=app_record.id, | |||
| user_id=application_generate_entity.user_id, | |||
| tenant_id=app_record.tenant_id, | |||
| model_config=application_generate_entity.model_config, | |||
| config=dataset_config, | |||
| @@ -1,59 +0,0 @@ | |||
| import time | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |||
| from langchain.chat_models.base import SimpleChatModel | |||
| from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult | |||
| class FakeLLM(SimpleChatModel): | |||
| """Fake ChatModel for testing purposes.""" | |||
| streaming: bool = False | |||
| """Whether to stream the results or not.""" | |||
| response: str | |||
| @property | |||
| def _llm_type(self) -> str: | |||
| return "fake-chat-model" | |||
| def _call( | |||
| self, | |||
| messages: list[BaseMessage], | |||
| stop: Optional[list[str]] = None, | |||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||
| **kwargs: Any, | |||
| ) -> str: | |||
| """First try to lookup in queries, else return 'foo' or 'bar'.""" | |||
| return self.response | |||
| @property | |||
| def _identifying_params(self) -> Mapping[str, Any]: | |||
| return {"response": self.response} | |||
| def get_num_tokens(self, text: str) -> int: | |||
| return 0 | |||
| def _generate( | |||
| self, | |||
| messages: list[BaseMessage], | |||
| stop: Optional[list[str]] = None, | |||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||
| **kwargs: Any, | |||
| ) -> ChatResult: | |||
| output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) | |||
| if self.streaming: | |||
| for token in output_str: | |||
| if run_manager: | |||
| run_manager.on_llm_new_token(token) | |||
| time.sleep(0.01) | |||
| message = AIMessage(content=output_str) | |||
| generation = ChatGeneration(message=message) | |||
| llm_output = {"token_usage": { | |||
| 'prompt_tokens': 0, | |||
| 'completion_tokens': 0, | |||
| 'total_tokens': 0, | |||
| }} | |||
| return ChatResult(generations=[generation], llm_output=llm_output) | |||
| @@ -1,46 +0,0 @@ | |||
| from typing import Any, Optional | |||
| from langchain import LLMChain as LCLLMChain | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.schema import Generation, LLMResult | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_manager import ModelInstance | |||
| from core.rag.retrieval.agent.fake_llm import FakeLLM | |||
| class LLMChain(LCLLMChain): | |||
| model_config: ModelConfigWithCredentialsEntity | |||
| """The language model instance to use.""" | |||
| llm: BaseLanguageModel = FakeLLM(response="") | |||
| parameters: dict[str, Any] = {} | |||
| def generate( | |||
| self, | |||
| input_list: list[dict[str, Any]], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> LLMResult: | |||
| """Generate LLM result from inputs.""" | |||
| prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) | |||
| messages = prompts[0].to_messages() | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| stream=False, | |||
| stop=stop, | |||
| model_parameters=self.parameters | |||
| ) | |||
| generations = [ | |||
| [Generation(text=result.message.content)] | |||
| ] | |||
| return LLMResult(generations=generations) | |||
| @@ -1,179 +0,0 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional, Union | |||
| from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent | |||
| from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.rag.retrieval.agent.fake_llm import FakeLLM | |||
| class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| """ | |||
| An Multi Dataset Retrieve Agent driven by Router. | |||
| """ | |||
| model_config: ModelConfigWithCredentialsEntity | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| @root_validator | |||
| def validate_llm(cls, values: dict) -> dict: | |||
| return values | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| return should use agent | |||
| :param query: | |||
| :return: | |||
| """ | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, along with observations | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| if len(self.tools) == 0: | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.tools) == 1: | |||
| tool = next(iter(self.tools)) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| # output = '' | |||
| # rst_json = json.loads(rst) | |||
| # for item in rst_json: | |||
| # output += f'{item["content"]}\n' | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| if intermediate_steps: | |||
| _, observation = intermediate_steps[-1] | |||
| return AgentFinish(return_values={"output": observation}, log=observation) | |||
| try: | |||
| agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) | |||
| if isinstance(agent_decision, AgentAction): | |||
| tool_inputs = agent_decision.tool_input | |||
| if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs: | |||
| tool_inputs['query'] = kwargs['input'] | |||
| agent_decision.tool_input = tool_inputs | |||
| else: | |||
| agent_decision.return_values['output'] = '' | |||
| return agent_decision | |||
| except Exception as e: | |||
| raise e | |||
| def real_plan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, along with observations | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| agent_scratchpad = _format_intermediate_steps(intermediate_steps) | |||
| selected_inputs = { | |||
| k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" | |||
| } | |||
| full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) | |||
| prompt = self.prompt.format_prompt(**full_inputs) | |||
| messages = prompt.to_messages() | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| ai_message = AIMessage( | |||
| content=result.message.content or "", | |||
| additional_kwargs={ | |||
| 'function_call': { | |||
| 'id': result.message.tool_calls[0].id, | |||
| **result.message.tool_calls[0].function.dict() | |||
| } if result.message.tool_calls else None | |||
| } | |||
| ) | |||
| agent_decision = _parse_ai_message(ai_message) | |||
| return agent_decision | |||
| async def aplan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| raise NotImplementedError() | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, | |||
| system_message: Optional[SystemMessage] = SystemMessage( | |||
| content="You are a helpful AI assistant." | |||
| ), | |||
| **kwargs: Any, | |||
| ) -> BaseSingleActionAgent: | |||
| prompt = cls.create_prompt( | |||
| extra_prompt_messages=extra_prompt_messages, | |||
| system_message=system_message, | |||
| ) | |||
| return cls( | |||
| model_config=model_config, | |||
| llm=FakeLLM(response=''), | |||
| prompt=prompt, | |||
| tools=tools, | |||
| callback_manager=callback_manager, | |||
| **kwargs, | |||
| ) | |||
| @@ -1,259 +0,0 @@ | |||
| import re | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from langchain import BasePromptTemplate, PromptTemplate | |||
| from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent | |||
| from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, OutputParserException | |||
| from langchain.tools import BaseTool | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.rag.retrieval.agent.llm_chain import LLMChain | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| Valid "action" values: "Final Answer" or {tool_names} | |||
| Provide only ONE action per $JSON_BLOB, as shown: | |||
| ``` | |||
| {{{{ | |||
| "action": $TOOL_NAME, | |||
| "action_input": $INPUT | |||
| }}}} | |||
| ``` | |||
| Follow this format: | |||
| Question: input question to answer | |||
| Thought: consider previous and subsequent steps | |||
| Action: | |||
| ``` | |||
| $JSON_BLOB | |||
| ``` | |||
| Observation: action result | |||
| ... (repeat Thought/Action/Observation N times) | |||
| Thought: I know what to respond | |||
| Action: | |||
| ``` | |||
| {{{{ | |||
| "action": "Final Answer", | |||
| "action_input": "Final response to human" | |||
| }}}} | |||
| ```""" | |||
| class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| dataset_tools: Sequence[BaseTool] | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| return should use agent | |||
| Using the ReACT mode to determine whether an agent is needed is costly, | |||
| so it's better to just use an Agent for reasoning, which is cheaper. | |||
| :param query: | |||
| :return: | |||
| """ | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, | |||
| along with observations | |||
| callbacks: Callbacks to run. | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| if len(self.dataset_tools) == 0: | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.dataset_tools) == 1: | |||
| tool = next(iter(self.dataset_tools)) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| if intermediate_steps: | |||
| _, observation = intermediate_steps[-1] | |||
| return AgentFinish(return_values={"output": observation}, log=observation) | |||
| full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) | |||
| try: | |||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||
| except Exception as e: | |||
| raise e | |||
| try: | |||
| agent_decision = self.output_parser.parse(full_output) | |||
| if isinstance(agent_decision, AgentAction): | |||
| tool_inputs = agent_decision.tool_input | |||
| if isinstance(tool_inputs, dict) and 'query' in tool_inputs: | |||
| tool_inputs['query'] = kwargs['input'] | |||
| agent_decision.tool_input = tool_inputs | |||
| elif isinstance(tool_inputs, str): | |||
| agent_decision.tool_input = kwargs['input'] | |||
| else: | |||
| agent_decision.return_values['output'] = '' | |||
| return agent_decision | |||
| except OutputParserException: | |||
| return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " | |||
| "I don't know how to respond to that."}, "") | |||
| @classmethod | |||
| def create_prompt( | |||
| cls, | |||
| tools: Sequence[BaseTool], | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| memory_prompts: Optional[list[BasePromptTemplate]] = None, | |||
| ) -> BasePromptTemplate: | |||
| tool_strings = [] | |||
| for tool in tools: | |||
| args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) | |||
| tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") | |||
| formatted_tools = "\n".join(tool_strings) | |||
| unique_tool_names = set(tool.name for tool in tools) | |||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |||
| if input_variables is None: | |||
| input_variables = ["input", "agent_scratchpad"] | |||
| _memory_prompts = memory_prompts or [] | |||
| messages = [ | |||
| SystemMessagePromptTemplate.from_template(template), | |||
| *_memory_prompts, | |||
| HumanMessagePromptTemplate.from_template(human_message_template), | |||
| ] | |||
| return ChatPromptTemplate(input_variables=input_variables, messages=messages) | |||
| @classmethod | |||
| def create_completion_prompt( | |||
| cls, | |||
| tools: Sequence[BaseTool], | |||
| prefix: str = PREFIX, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| ) -> PromptTemplate: | |||
| """Create prompt in the style of the zero shot agent. | |||
| Args: | |||
| tools: List of tools the agent will have access to, used to format the | |||
| prompt. | |||
| prefix: String to put before the list of tools. | |||
| input_variables: List of input variables the final prompt will expect. | |||
| Returns: | |||
| A PromptTemplate with the template assembled from the pieces here. | |||
| """ | |||
| suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. | |||
| Question: {input} | |||
| Thought: {agent_scratchpad} | |||
| """ | |||
| tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) | |||
| tool_names = ", ".join([tool.name for tool in tools]) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) | |||
| if input_variables is None: | |||
| input_variables = ["input", "agent_scratchpad"] | |||
| return PromptTemplate(template=template, input_variables=input_variables) | |||
| def _construct_scratchpad( | |||
| self, intermediate_steps: list[tuple[AgentAction, str]] | |||
| ) -> str: | |||
| agent_scratchpad = "" | |||
| for action, observation in intermediate_steps: | |||
| agent_scratchpad += action.log | |||
| agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" | |||
| if not isinstance(agent_scratchpad, str): | |||
| raise ValueError("agent_scratchpad should be of type string.") | |||
| if agent_scratchpad: | |||
| llm_chain = cast(LLMChain, self.llm_chain) | |||
| if llm_chain.model_config.mode == "chat": | |||
| return ( | |||
| f"This was your previous work " | |||
| f"(but I haven't seen any of it! I only see what " | |||
| f"you return as final answer):\n{agent_scratchpad}" | |||
| ) | |||
| else: | |||
| return agent_scratchpad | |||
| else: | |||
| return agent_scratchpad | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| output_parser: Optional[AgentOutputParser] = None, | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| memory_prompts: Optional[list[BasePromptTemplate]] = None, | |||
| **kwargs: Any, | |||
| ) -> Agent: | |||
| """Construct an agent from an LLM and tools.""" | |||
| cls._validate_tools(tools) | |||
| if model_config.mode == "chat": | |||
| prompt = cls.create_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| suffix=suffix, | |||
| human_message_template=human_message_template, | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables, | |||
| memory_prompts=memory_prompts, | |||
| ) | |||
| else: | |||
| prompt = cls.create_completion_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables | |||
| ) | |||
| llm_chain = LLMChain( | |||
| model_config=model_config, | |||
| prompt=prompt, | |||
| callback_manager=callback_manager, | |||
| parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| tool_names = [tool.name for tool in tools] | |||
| _output_parser = output_parser | |||
| return cls( | |||
| llm_chain=llm_chain, | |||
| allowed_tools=tool_names, | |||
| output_parser=_output_parser, | |||
| dataset_tools=tools, | |||
| **kwargs, | |||
| ) | |||
| @@ -1,117 +0,0 @@ | |||
| import logging | |||
| from typing import Optional, Union | |||
| from langchain.agents import AgentExecutor as LCAgentExecutor | |||
| from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Extra | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.message_entities import prompt_messages_to_lc_messages | |||
| from core.helper import moderation | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||
| from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool | |||
| class AgentConfiguration(BaseModel): | |||
| strategy: PlanningStrategy | |||
| model_config: ModelConfigWithCredentialsEntity | |||
| tools: list[BaseTool] | |||
| summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None | |||
| memory: Optional[TokenBufferMemory] = None | |||
| callbacks: Callbacks = None | |||
| max_iterations: int = 6 | |||
| max_execution_time: Optional[float] = None | |||
| early_stopping_method: str = "generate" | |||
| # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| arbitrary_types_allowed = True | |||
| class AgentExecuteResult(BaseModel): | |||
| strategy: PlanningStrategy | |||
| output: Optional[str] | |||
| configuration: AgentConfiguration | |||
| class AgentExecutor: | |||
| def __init__(self, configuration: AgentConfiguration): | |||
| self.configuration = configuration | |||
| self.agent = self._init_agent() | |||
| def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: | |||
| if self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools | |||
| if isinstance(t, DatasetRetrieverTool) | |||
| or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = MultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) | |||
| if self.configuration.memory else None, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools | |||
| if isinstance(t, DatasetRetrieverTool) | |||
| or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| verbose=True | |||
| ) | |||
| else: | |||
| raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") | |||
| return agent | |||
| def should_use_agent(self, query: str) -> bool: | |||
| return self.agent.should_use_agent(query) | |||
| def run(self, query: str) -> AgentExecuteResult: | |||
| moderation_result = moderation.check_moderation( | |||
| self.configuration.model_config, | |||
| query | |||
| ) | |||
| if moderation_result: | |||
| return AgentExecuteResult( | |||
| output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", | |||
| strategy=self.configuration.strategy, | |||
| configuration=self.configuration | |||
| ) | |||
| agent_executor = LCAgentExecutor.from_agent_and_tools( | |||
| agent=self.agent, | |||
| tools=self.configuration.tools, | |||
| max_iterations=self.configuration.max_iterations, | |||
| max_execution_time=self.configuration.max_execution_time, | |||
| early_stopping_method=self.configuration.early_stopping_method, | |||
| callbacks=self.configuration.callbacks | |||
| ) | |||
| try: | |||
| output = agent_executor.run(input=query) | |||
| except InvokeError as ex: | |||
| raise ex | |||
| except Exception as ex: | |||
| logging.exception("agent_executor run failed") | |||
| output = None | |||
| return AgentExecuteResult( | |||
| output=output, | |||
| strategy=self.configuration.strategy, | |||
| configuration=self.configuration | |||
| ) | |||
| @@ -1,23 +1,40 @@ | |||
| import threading | |||
| from typing import Optional, cast | |||
| from langchain.tools import BaseTool | |||
| from flask import Flask, current_app | |||
| from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor | |||
| from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document | |||
| from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter | |||
| from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter | |||
| from core.rerank.rerank import RerankRunner | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| class DatasetRetrieval: | |||
| def retrieve(self, tenant_id: str, | |||
| def retrieve(self, app_id: str, user_id: str, tenant_id: str, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| config: DatasetEntity, | |||
| query: str, | |||
| @@ -27,6 +44,8 @@ class DatasetRetrieval: | |||
| memory: Optional[TokenBufferMemory] = None) -> Optional[str]: | |||
| """ | |||
| Retrieve dataset. | |||
| :param app_id: app_id | |||
| :param user_id: user_id | |||
| :param tenant_id: tenant id | |||
| :param model_config: model config | |||
| :param config: dataset config | |||
| @@ -38,12 +57,22 @@ class DatasetRetrieval: | |||
| :return: | |||
| """ | |||
| dataset_ids = config.dataset_ids | |||
| if len(dataset_ids) == 0: | |||
| return None | |||
| retrieve_config = config.retrieve_config | |||
| # check model is support tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=model_config.provider, | |||
| model=model_config.model | |||
| ) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model_config.model, | |||
| @@ -59,56 +88,6 @@ class DatasetRetrieval: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.ROUTER | |||
| dataset_retriever_tools = self.to_dataset_retriever_tool( | |||
| tenant_id=tenant_id, | |||
| dataset_ids=dataset_ids, | |||
| retrieve_config=retrieve_config, | |||
| return_resource=show_retrieve_source, | |||
| invoke_from=invoke_from, | |||
| hit_callback=hit_callback | |||
| ) | |||
| if len(dataset_retriever_tools) == 0: | |||
| return None | |||
| agent_configuration = AgentConfiguration( | |||
| strategy=planning_strategy, | |||
| model_config=model_config, | |||
| tools=dataset_retriever_tools, | |||
| memory=memory, | |||
| max_iterations=10, | |||
| max_execution_time=400.0, | |||
| early_stopping_method="generate" | |||
| ) | |||
| agent_executor = AgentExecutor(agent_configuration) | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if not should_use_agent: | |||
| return None | |||
| result = agent_executor.run(query) | |||
| return result.output | |||
| def to_dataset_retriever_tool(self, tenant_id: str, | |||
| dataset_ids: list[str], | |||
| retrieve_config: DatasetRetrieveConfigEntity, | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler) \ | |||
| -> Optional[list[BaseTool]]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param tenant_id: tenant id | |||
| :param dataset_ids: dataset ids | |||
| :param retrieve_config: retrieve config | |||
| :param return_resource: return resource | |||
| :param invoke_from: invoke from | |||
| :param hit_callback: hit callback | |||
| """ | |||
| tools = [] | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| @@ -127,56 +106,270 @@ class DatasetRetrieval: | |||
| continue | |||
| available_datasets.append(dataset) | |||
| all_documents = [] | |||
| user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' | |||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | |||
| all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query, | |||
| model_instance, | |||
| model_config, planning_strategy) | |||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | |||
| all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from, | |||
| available_datasets, query, retrieve_config.top_k, | |||
| retrieve_config.score_threshold, | |||
| retrieve_config.reranking_model.get('reranking_provider_name'), | |||
| retrieve_config.reranking_model.get('reranking_model_name')) | |||
| document_score_list = {} | |||
| for item in all_documents: | |||
| if 'score' in item.metadata and item.metadata['score']: | |||
| document_score_list[item.metadata['doc_id']] = item.metadata['score'] | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata['doc_id'] for document in all_documents] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id.in_(dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids) | |||
| ).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| sorted_segments = sorted(segments, | |||
| key=lambda segment: index_node_id_to_position.get(segment.index_node_id, | |||
| float('inf'))) | |||
| for segment in sorted_segments: | |||
| if segment.answer: | |||
| document_context_list.append(f'question:{segment.content} answer:{segment.answer}') | |||
| else: | |||
| document_context_list.append(segment.content) | |||
| if show_retrieve_source: | |||
| context_list = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = Dataset.query.filter_by( | |||
| id=segment.dataset_id | |||
| ).first() | |||
| document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).first() | |||
| if dataset and document: | |||
| source = { | |||
| 'position': resource_number, | |||
| 'dataset_id': dataset.id, | |||
| 'dataset_name': dataset.name, | |||
| 'document_id': document.id, | |||
| 'document_name': document.name, | |||
| 'data_source_type': document.data_source_type, | |||
| 'segment_id': segment.id, | |||
| 'retriever_from': invoke_from.to_source(), | |||
| 'score': document_score_list.get(segment.index_node_id, None) | |||
| } | |||
| if invoke_from.to_source() == 'dev': | |||
| source['hit_count'] = segment.hit_count | |||
| source['word_count'] = segment.word_count | |||
| source['segment_position'] = segment.position | |||
| source['index_node_hash'] = segment.index_node_hash | |||
| if segment.answer: | |||
| source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' | |||
| else: | |||
| source['content'] = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| if hit_callback: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| return str("\n".join(document_context_list)) | |||
| return '' | |||
| def single_retrieve(self, app_id: str, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| user_from: str, | |||
| available_datasets: list, | |||
| query: str, | |||
| model_instance: ModelInstance, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| planning_strategy: PlanningStrategy, | |||
| ): | |||
| tools = [] | |||
| for dataset in available_datasets: | |||
| description = dataset.description | |||
| if not description: | |||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||
| description = description.replace('\n', '').replace('\r', '') | |||
| message_tool = PromptMessageTool( | |||
| name=dataset.id, | |||
| description=description, | |||
| parameters={ | |||
| "type": "object", | |||
| "properties": {}, | |||
| "required": [], | |||
| } | |||
| ) | |||
| tools.append(message_tool) | |||
| dataset_id = None | |||
| if planning_strategy == PlanningStrategy.REACT_ROUTER: | |||
| react_multi_dataset_router = ReactMultiDatasetRouter() | |||
| dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, | |||
| user_id, tenant_id) | |||
| elif planning_strategy == PlanningStrategy.ROUTER: | |||
| function_call_router = FunctionCallMultiDatasetRouter() | |||
| dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) | |||
| if dataset_id: | |||
| # get retrieval model config | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| for dataset in available_datasets: | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if dataset: | |||
| retrieval_model_config = dataset.retrieval_model \ | |||
| if dataset.retrieval_model else default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config['top_k'] | |||
| # get retrieval method | |||
| if dataset.indexing_technique == "economy": | |||
| retrival_method = 'keyword_search' | |||
| else: | |||
| retrival_method = retrieval_model_config['search_method'] | |||
| # get reranking model | |||
| reranking_model = retrieval_model_config['reranking_model'] \ | |||
| if retrieval_model_config['reranking_enable'] else None | |||
| # get score threshold | |||
| score_threshold = None | |||
| score_threshold = .0 | |||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||
| if score_threshold_enabled: | |||
| score_threshold = retrieval_model_config.get("score_threshold") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=return_resource, | |||
| retriever_from=invoke_from.to_source() | |||
| ) | |||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k, score_threshold=score_threshold, | |||
| reranking_model=reranking_model) | |||
| self._on_query(query, [dataset_id], app_id, user_from, user_id) | |||
| if results: | |||
| self._on_retrival_end(results) | |||
| return results | |||
| return [] | |||
| tools.append(tool) | |||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | |||
| tool = DatasetMultiRetrieverTool.from_dataset( | |||
| dataset_ids=[dataset.id for dataset in available_datasets], | |||
| tenant_id=tenant_id, | |||
| top_k=retrieve_config.top_k or 2, | |||
| score_threshold=retrieve_config.score_threshold, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=return_resource, | |||
| retriever_from=invoke_from.to_source(), | |||
| reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), | |||
| reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') | |||
| def multiple_retrieve(self, | |||
| app_id: str, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| user_from: str, | |||
| available_datasets: list, | |||
| query: str, | |||
| top_k: int, | |||
| score_threshold: float, | |||
| reranking_provider_name: str, | |||
| reranking_model_name: str): | |||
| threads = [] | |||
| all_documents = [] | |||
| dataset_ids = [dataset.id for dataset in available_datasets] | |||
| for dataset in available_datasets: | |||
| retrieval_thread = threading.Thread(target=self._retriever, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': dataset.id, | |||
| 'query': query, | |||
| 'top_k': top_k, | |||
| 'all_documents': all_documents, | |||
| }) | |||
| threads.append(retrieval_thread) | |||
| retrieval_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # do rerank for searched documents | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=reranking_provider_name, | |||
| model_type=ModelType.RERANK, | |||
| model=reranking_model_name | |||
| ) | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| all_documents = rerank_runner.run(query, all_documents, | |||
| score_threshold, | |||
| top_k) | |||
| self._on_query(query, dataset_ids, app_id, user_from, user_id) | |||
| if all_documents: | |||
| self._on_retrival_end(all_documents) | |||
| return all_documents | |||
| def _on_retrival_end(self, documents: list[Document]) -> None: | |||
| """Handle retrival end.""" | |||
| for document in documents: | |||
| query = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.index_node_id == document.metadata['doc_id'] | |||
| ) | |||
| tools.append(tool) | |||
| # if 'dataset_id' in document.metadata: | |||
| if 'dataset_id' in document.metadata: | |||
| query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) | |||
| # add hit count to document segment | |||
| query.update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| synchronize_session=False | |||
| ) | |||
| db.session.commit() | |||
| def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None: | |||
| """ | |||
| Handle query. | |||
| """ | |||
| if not query: | |||
| return | |||
| for dataset_id in dataset_ids: | |||
| dataset_query = DatasetQuery( | |||
| dataset_id=dataset_id, | |||
| content=query, | |||
| source='app', | |||
| source_app_id=app_id, | |||
| created_by_role=user_from, | |||
| created_by=user_id | |||
| ) | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| return [] | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve(retrival_method='keyword_search', | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k | |||
| ) | |||
| if documents: | |||
| all_documents.extend(documents) | |||
| else: | |||
| if top_k > 0: | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k, | |||
| score_threshold=retrieval_model['score_threshold'] | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model['reranking_model'] | |||
| if retrieval_model['reranking_enable'] else None | |||
| ) | |||
| return tools | |||
| all_documents.extend(documents) | |||
| @@ -12,8 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool | |||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage | |||
| from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | |||
| from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| @@ -55,11 +54,10 @@ class ReactMultiDatasetRouter: | |||
| self, | |||
| query: str, | |||
| dataset_tools: list[PromptMessageTool], | |||
| node_data: KnowledgeRetrievalNodeData, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| model_instance: ModelInstance, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| tenant_id: str | |||
| ) -> Union[str, None]: | |||
| """Given input, decided what to do. | |||
| @@ -72,7 +70,8 @@ class ReactMultiDatasetRouter: | |||
| return dataset_tools[0].name | |||
| try: | |||
| return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance, | |||
| return self._react_invoke(query=query, model_config=model_config, | |||
| model_instance=model_instance, | |||
| tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) | |||
| except Exception as e: | |||
| return None | |||
| @@ -80,7 +79,6 @@ class ReactMultiDatasetRouter: | |||
| def _react_invoke( | |||
| self, | |||
| query: str, | |||
| node_data: KnowledgeRetrievalNodeData, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| model_instance: ModelInstance, | |||
| tools: Sequence[PromptMessageTool], | |||
| @@ -121,7 +119,7 @@ class ReactMultiDatasetRouter: | |||
| model_config=model_config | |||
| ) | |||
| result_text, usage = self._invoke_llm( | |||
| node_data=node_data, | |||
| completion_param=model_config.parameters, | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| stop=stop, | |||
| @@ -134,10 +132,11 @@ class ReactMultiDatasetRouter: | |||
| return agent_decision.tool | |||
| return None | |||
| def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData, | |||
| def _invoke_llm(self, completion_param: dict, | |||
| model_instance: ModelInstance, | |||
| prompt_messages: list[PromptMessage], | |||
| stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]: | |||
| stop: list[str], user_id: str, tenant_id: str | |||
| ) -> tuple[str, LLMUsage]: | |||
| """ | |||
| Invoke large language model | |||
| :param node_data: node data | |||
| @@ -148,7 +147,7 @@ class ReactMultiDatasetRouter: | |||
| """ | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=node_data.single_retrieval_config.model.completion_params, | |||
| model_parameters=completion_param, | |||
| stop=stop, | |||
| stream=True, | |||
| user=user_id, | |||
| @@ -203,7 +202,8 @@ class ReactMultiDatasetRouter: | |||
| ) -> list[ChatModelMessage]: | |||
| tool_strings = [] | |||
| for tool in tools: | |||
| tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") | |||
| tool_strings.append( | |||
| f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") | |||
| formatted_tools = "\n".join(tool_strings) | |||
| unique_tool_names = set(tool.name for tool in tools) | |||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | |||
| @@ -1,28 +1,21 @@ | |||
| import threading | |||
| from typing import Any, cast | |||
| from flask import Flask, current_app | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rerank.rerank import RerankRunner | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | |||
| from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter | |||
| from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| default_retrieval_model = { | |||
| @@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| available_datasets.append(dataset) | |||
| all_documents = [] | |||
| dataset_retrieval = DatasetRetrieval() | |||
| if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: | |||
| all_documents = self._single_retrieve(available_datasets, node_data, query) | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data) | |||
| # check model is support tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model_config.model, | |||
| credentials=model_config.credentials | |||
| ) | |||
| if model_schema: | |||
| planning_strategy = PlanningStrategy.REACT_ROUTER | |||
| features = model_schema.features | |||
| if features: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.ROUTER | |||
| all_documents = dataset_retrieval.single_retrieve( | |||
| available_datasets=available_datasets, | |||
| tenant_id=self.tenant_id, | |||
| user_id=self.user_id, | |||
| app_id=self.app_id, | |||
| user_from=self.user_from.value, | |||
| query=query, | |||
| model_config=model_config, | |||
| model_instance=model_instance, | |||
| planning_strategy=planning_strategy | |||
| ) | |||
| elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: | |||
| all_documents = self._multiple_retrieve(available_datasets, node_data, query) | |||
| all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, | |||
| self.user_from.value, | |||
| available_datasets, query, | |||
| node_data.multiple_retrieval_config.top_k, | |||
| node_data.multiple_retrieval_config.score_threshold, | |||
| node_data.multiple_retrieval_config.reranking_model.provider, | |||
| node_data.multiple_retrieval_config.reranking_model.model) | |||
| context_list = [] | |||
| if all_documents: | |||
| @@ -184,87 +212,6 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| variable_mapping['query'] = node_data.query_variable_selector | |||
| return variable_mapping | |||
| def _single_retrieve(self, available_datasets, node_data, query): | |||
| tools = [] | |||
| for dataset in available_datasets: | |||
| description = dataset.description | |||
| if not description: | |||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||
| description = description.replace('\n', '').replace('\r', '') | |||
| message_tool = PromptMessageTool( | |||
| name=dataset.id, | |||
| description=description, | |||
| parameters={ | |||
| "type": "object", | |||
| "properties": {}, | |||
| "required": [], | |||
| } | |||
| ) | |||
| tools.append(message_tool) | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data) | |||
| # check model is support tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model_config.model, | |||
| credentials=model_config.credentials | |||
| ) | |||
| if not model_schema: | |||
| return None | |||
| planning_strategy = PlanningStrategy.REACT_ROUTER | |||
| features = model_schema.features | |||
| if features: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.ROUTER | |||
| dataset_id = None | |||
| if planning_strategy == PlanningStrategy.REACT_ROUTER: | |||
| react_multi_dataset_router = ReactMultiDatasetRouter() | |||
| dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance, | |||
| self.user_id, self.tenant_id) | |||
| elif planning_strategy == PlanningStrategy.ROUTER: | |||
| function_call_router = FunctionCallMultiDatasetRouter() | |||
| dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) | |||
| if dataset_id: | |||
| # get retrieval model config | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if dataset: | |||
| retrieval_model_config = dataset.retrieval_model \ | |||
| if dataset.retrieval_model else default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config['top_k'] | |||
| # get retrieval method | |||
| if dataset.indexing_technique == "economy": | |||
| retrival_method = 'keyword_search' | |||
| else: | |||
| retrival_method = retrieval_model_config['search_method'] | |||
| # get reranking model | |||
| reranking_model=retrieval_model_config['reranking_model'] \ | |||
| if retrieval_model_config['reranking_enable'] else None | |||
| # get score threshold | |||
| score_threshold = .0 | |||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||
| if score_threshold_enabled: | |||
| score_threshold = retrieval_model_config.get("score_threshold") | |||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k, score_threshold=score_threshold, | |||
| reranking_model=reranking_model) | |||
| self._on_query(query, [dataset_id]) | |||
| if results: | |||
| self._on_retrival_end(results) | |||
| return results | |||
| return [] | |||
| def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ | |||
| ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| """ | |||
| @@ -335,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| parameters=completion_params, | |||
| stop=stop, | |||
| ) | |||
| def _multiple_retrieve(self, available_datasets, node_data, query): | |||
| threads = [] | |||
| all_documents = [] | |||
| dataset_ids = [dataset.id for dataset in available_datasets] | |||
| for dataset in available_datasets: | |||
| retrieval_thread = threading.Thread(target=self._retriever, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': dataset.id, | |||
| 'query': query, | |||
| 'top_k': node_data.multiple_retrieval_config.top_k, | |||
| 'all_documents': all_documents, | |||
| }) | |||
| threads.append(retrieval_thread) | |||
| retrieval_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # do rerank for searched documents | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=self.tenant_id, | |||
| provider=node_data.multiple_retrieval_config.reranking_model.provider, | |||
| model_type=ModelType.RERANK, | |||
| model=node_data.multiple_retrieval_config.reranking_model.model | |||
| ) | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| all_documents = rerank_runner.run(query, all_documents, | |||
| node_data.multiple_retrieval_config.score_threshold, | |||
| node_data.multiple_retrieval_config.top_k) | |||
| self._on_query(query, dataset_ids) | |||
| if all_documents: | |||
| self._on_retrival_end(all_documents) | |||
| return all_documents | |||
| def _on_retrival_end(self, documents: list[Document]) -> None: | |||
| """Handle retrival end.""" | |||
| for document in documents: | |||
| query = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.index_node_id == document.metadata['doc_id'] | |||
| ) | |||
| # if 'dataset_id' in document.metadata: | |||
| if 'dataset_id' in document.metadata: | |||
| query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) | |||
| # add hit count to document segment | |||
| query.update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| synchronize_session=False | |||
| ) | |||
| db.session.commit() | |||
| def _on_query(self, query: str, dataset_ids: list[str]) -> None: | |||
| """ | |||
| Handle query. | |||
| """ | |||
| if not query: | |||
| return | |||
| for dataset_id in dataset_ids: | |||
| dataset_query = DatasetQuery( | |||
| dataset_id=dataset_id, | |||
| content=query, | |||
| source='app', | |||
| source_app_id=self.app_id, | |||
| created_by_role=self.user_from.value, | |||
| created_by=self.user_id | |||
| ) | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == self.tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| return [] | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve(retrival_method='keyword_search', | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k | |||
| ) | |||
| if documents: | |||
| all_documents.extend(documents) | |||
| else: | |||
| if top_k > 0: | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k, | |||
| score_threshold=retrieval_model['score_threshold'] | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model['reranking_model'] | |||
| if retrieval_model['reranking_enable'] else None | |||
| ) | |||
| all_documents.extend(documents) | |||