| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 | 
							- from typing import Optional, List
 - 
 - from langchain.callbacks import SharedCallbackManager, CallbackManager
 - from langchain.chains import SequentialChain
 - from langchain.chains.base import Chain
 - from langchain.memory.chat_memory import BaseChatMemory
 - 
 - from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 - from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 - from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 - from core.chain.chain_builder import ChainBuilder
 - from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 - from core.conversation_message_task import ConversationMessageTask
 - from extensions.ext_database import db
 - from models.dataset import Dataset
 - 
 - 
 - class MainChainBuilder:
 -     @classmethod
 -     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
 -                                 conversation_message_task: ConversationMessageTask):
 -         first_input_key = "input"
 -         final_output_key = "output"
 - 
 -         chains = []
 - 
 -         chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
 - 
 -         # agent mode
 -         tool_chains, chains_output_key = cls.get_agent_chains(
 -             tenant_id=tenant_id,
 -             agent_mode=agent_mode,
 -             memory=memory,
 -             conversation_message_task=conversation_message_task
 -         )
 -         chains += tool_chains
 - 
 -         if chains_output_key:
 -             final_output_key = chains_output_key
 - 
 -         if len(chains) == 0:
 -             return None
 - 
 -         for chain in chains:
 -             # do not add handler into singleton callback manager
 -             if not isinstance(chain.callback_manager, SharedCallbackManager):
 -                 chain.callback_manager.add_handler(chain_callback_handler)
 - 
 -         # build main chain
 -         overall_chain = SequentialChain(
 -             chains=chains,
 -             input_variables=[first_input_key],
 -             output_variables=[final_output_key],
 -             memory=memory,  # only for use the memory prompt input key
 -         )
 - 
 -         return overall_chain
 - 
 -     @classmethod
 -     def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
 -                          conversation_message_task: ConversationMessageTask):
 -         # agent mode
 -         chains = []
 -         if agent_mode and agent_mode.get('enabled'):
 -             tools = agent_mode.get('tools', [])
 - 
 -             pre_fixed_chains = []
 -             # agent_tools = []
 -             datasets = []
 -             for tool in tools:
 -                 tool_type = list(tool.keys())[0]
 -                 tool_config = list(tool.values())[0]
 -                 if tool_type == 'sensitive-word-avoidance':
 -                     chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
 -                     if chain:
 -                         pre_fixed_chains.append(chain)
 -                 elif tool_type == "dataset":
 -                     # get dataset from dataset id
 -                     dataset = db.session.query(Dataset).filter(
 -                         Dataset.tenant_id == tenant_id,
 -                         Dataset.id == tool_config.get("id")
 -                     ).first()
 - 
 -                     if dataset:
 -                         datasets.append(dataset)
 - 
 -             # add pre-fixed chains
 -             chains += pre_fixed_chains
 - 
 -             if len(datasets) > 0:
 -                 # tool to chain
 -                 multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
 -                     tenant_id=tenant_id,
 -                     datasets=datasets,
 -                     conversation_message_task=conversation_message_task,
 -                     callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
 -                 )
 -                 chains.append(multi_dataset_router_chain)
 - 
 -         final_output_key = cls.get_chains_output_key(chains)
 - 
 -         return chains, final_output_key
 - 
 -     @classmethod
 -     def get_chains_output_key(cls, chains: List[Chain]):
 -         if len(chains) > 0:
 -             return chains[-1].output_keys[0]
 -         return None
 
 
  |