| """Base classes for LLM-powered router chains.""" | |||||
| from __future__ import annotations | |||||
| import json | |||||
| from typing import Any, Dict, List, Optional, Type, cast, NamedTuple | |||||
| from langchain.chains.base import Chain | |||||
| from pydantic import root_validator | |||||
| from langchain.chains import LLMChain | |||||
| from langchain.prompts import BasePromptTemplate | |||||
| from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel | |||||


class Route(NamedTuple):
    destination: Optional[str]
    next_inputs: Dict[str, Any]


class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing."""

    llm_chain: LLMChain
    """LLM chain used to perform routing."""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        prompt = values["llm_chain"].prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        super()._validate_outputs(outputs)
        if not isinstance(outputs["next_inputs"], dict):
            raise ValueError("Expected output key 'next_inputs' to be a dictionary.")

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    @property
    def output_keys(self) -> List[str]:
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any]) -> Route:
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])


class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for the output of the router chain in the multi-prompt chain."""

    default_destination: str = "DEFAULT"
    next_inputs_type: Type = str
    next_inputs_inner_key: str = "input"

    def parse_json_markdown(self, json_string: str) -> dict:
        # Remove the triple backticks if present
        json_string = json_string.replace("```json", "").replace("```", "")

        # Strip whitespace and newlines from the start and end
        json_string = json_string.strip()

        # Parse the JSON string into a Python dictionary
        parsed = json.loads(json_string)

        return parsed

    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
        try:
            json_obj = self.parse_json_markdown(text)
        except json.JSONDecodeError as e:
            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
        for key in expected_keys:
            if key not in json_obj:
                raise OutputParserException(
                    f"Got invalid return object. Expected key `{key}` "
                    f"to be present, but got {json_obj}"
                )
        return json_obj

    def parse(self, text: str) -> Dict[str, Any]:
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = self.parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            raise OutputParserException(
                f"Parsing text\n{text}\nraised the following error:\n{e}"
            )
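
To make the parsing contract concrete, here is a minimal sketch of `RouterOutputParser.parse` on a typical router reply (the sample reply below is invented for illustration, not taken from the codebase): the markdown fence is stripped, both expected keys are checked, the bare `next_inputs` string is wrapped under the `input` key, and a `DEFAULT` destination is normalized to `None`.

parser = RouterOutputParser()

# Hypothetical LLM reply in the format the router prompt asks for.
sample_reply = '''```json
{
    "destination": "DEFAULT",
    "next_inputs": "What is a dataset?"
}
```'''

result = parser.parse(sample_reply)
# "DEFAULT" maps to None; the bare string is wrapped under the "input" key.
assert result == {"destination": None, "next_inputs": {"input": "What is a dataset?"}}
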
 from typing import Optional, List
-from langchain.callbacks import SharedCallbackManager
+from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
-from core.agent.agent_builder import AgentBuilder
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
-from core.constant import llm_constant
+from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.conversation_message_task import ConversationMessageTask
-from core.tool.dataset_tool_builder import DatasetToolBuilder
+from extensions.ext_database import db
+from models.dataset import Dataset
 class MainChainBuilder:

             tenant_id=tenant_id,
             agent_mode=agent_mode,
             memory=memory,
-            dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
-            agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
+            conversation_message_task=conversation_message_task
         )
         chains += tool_chains
     @classmethod
     def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
-                         dataset_tool_callback_handler: DatasetToolCallbackHandler,
-                         agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
+                         conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
         if agent_mode and agent_mode.get('enabled'):
             tools = agent_mode.get('tools', [])

             pre_fixed_chains = []
-            agent_tools = []
+            datasets = []
             for tool in tools:
                 tool_type = list(tool.keys())[0]
                 tool_config = list(tool.values())[0]
                 if chain:
                     pre_fixed_chains.append(chain)
                 elif tool_type == "dataset":
-                    dataset_tool = DatasetToolBuilder.build_dataset_tool(
-                        tenant_id=tenant_id,
-                        dataset_id=tool_config.get("id"),
-                        response_mode='no_synthesizer',  # "compact"
-                        callback_handler=dataset_tool_callback_handler
-                    )
+                    # get dataset from dataset id
+                    dataset = db.session.query(Dataset).filter(
+                        Dataset.tenant_id == tenant_id,
+                        Dataset.id == tool_config.get("id")
+                    ).first()

-                    if dataset_tool:
-                        agent_tools.append(dataset_tool)
+                    if dataset:
+                        datasets.append(dataset)
             # add pre-fixed chains
             chains += pre_fixed_chains

-            if len(agent_tools) == 1:
+            if len(datasets) > 0:
                 # tool to chain
-                tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
-                chains.append(tool_chain)
-            elif len(agent_tools) > 1:
-                # build agent config
-                agent_chain = AgentBuilder.to_agent_chain(
+                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                     tenant_id=tenant_id,
-                    tools=agent_tools,
-                    memory=memory,
-                    dataset_tool_callback_handler=dataset_tool_callback_handler,
-                    agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
+                    datasets=datasets,
+                    conversation_message_task=conversation_message_task,
+                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
                 )
-                chains.append(agent_chain)
+                chains.append(multi_dataset_router_chain)

         final_output_key = cls.get_chains_output_key(chains)
from typing import Mapping, List, Dict, Any, Optional

from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset


MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""


class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        **kwargs: Any,
    ):
| """Convenience constructor for instantiating from destination prompts.""" | |||||
        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        destinations = [f"{d.id}: {d.description}" for d in datasets]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            dataset_tool = DatasetToolBuilder.build_dataset_tool(
                dataset=dataset,
                response_mode='no_synthesizer',  # "compact"
                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
            )

            dataset_tools[dataset.id] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        route = self.router_chain.route(inputs)

        if not route.destination:
            return {"text": ''}
        elif route.destination in self.dataset_tools:
            return {"text": self.dataset_tools[route.destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{route.destination}'"
            )
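
For orientation, a rough usage sketch of the chain above, with hypothetical placeholders (`products_dataset`, `faq_dataset`, and `task` are invented names; in Dify this chain is wired up by `MainChainBuilder` rather than constructed by hand):

router = MultiDatasetRouterChain.from_datasets(
    tenant_id="tenant-1",                      # hypothetical tenant id
    datasets=[products_dataset, faq_dataset],  # hypothetical Dataset records
    conversation_message_task=task             # hypothetical task object
)

# _call dispatches on the router's decision: zero tools or a "DEFAULT"
# destination yield {"text": ''}; a single tool is run directly without
# consulting the router LLM; otherwise the routed dataset tool is run.
result = router({"input": "How do I reset my password?"})
print(result["text"])
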
 from core.index.vector_index import VectorIndex
 from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
 from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from extensions.ext_database import db
 from models.dataset import Dataset


 class DatasetToolBuilder:
     @classmethod
-    def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
+    def build_dataset_tool(cls, dataset: Dataset,
                            response_mode: str = "no_synthesizer",
                            callback_handler: Optional[DatasetToolCallbackHandler] = None):
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
-
-        if not dataset:
-            return None
         if dataset.indexing_technique == "economy":
             # use keyword table query
             index = KeywordTableIndex(dataset=dataset).query_index

         index_tool_config = IndexToolConfig(
             index=index,
-            name=f"dataset-{dataset_id}",
+            name=f"dataset-{dataset.id}",
             description=description,
             index_query_kwargs=query_kwargs,
             tool_kwargs={
                 # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
         )

-        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)
+        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
         return EnhanceLlamaIndexTool.from_tool_config(
             tool_config=index_tool_config,
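
With this change the dataset lookup moves out of `build_dataset_tool` and into the caller, as `MainChainBuilder.get_agent_chains` now does. A minimal sketch of the new calling convention, mirroring the block removed above (`tenant_id`, `dataset_id`, and `conversation_message_task` are placeholders):

# Caller-side lookup, matching the query removed from build_dataset_tool.
dataset = db.session.query(Dataset).filter(
    Dataset.tenant_id == tenant_id,
    Dataset.id == dataset_id
).first()

if dataset:  # the None check also moves to the caller
    tool = DatasetToolBuilder.build_dataset_tool(
        dataset=dataset,
        response_mode='no_synthesizer',
        callback_handler=DatasetToolCallbackHandler(conversation_message_task)
    )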