
feat: remove llm client use (#1316)

tags/0.3.27
takatost 2 years ago
parent commit cbf095465c

api/core/agent/agent/multi_dataset_router_agent.py (+57 -7)

 from typing import Tuple, List, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
+from pydantic import root_validator

+from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool






         arbitrary_types_allowed = True

+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
+
     def should_use_agent(self, query: str):
         """
         return should use agent

             return AgentFinish(return_values={"output": observation}, log=observation)

         try:
-            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:

             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception


+    def real_plan(
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
+        )
+
+        agent_decision = _parse_ai_message(ai_message)
+        return agent_decision

     async def aplan(
             self,
             intermediate_steps: List[Tuple[AgentAction, str]],

     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
+        prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
+            tools=tools,
+            callback_manager=callback_manager,
             **kwargs,
         )
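A minimal sketch of how the refactored factory is now invoked (obtaining `model_instance` and the dataset tools is assumed here, not shown by this commit; the `input` kwarg mirrors the prompt's input variable):

    # assumed: model_instance is a BaseLLM, dataset_tools a list of DatasetRetrieverTool
    agent = MultiDatasetRouterAgent.from_llm_and_tools(
        model_instance=model_instance,
        tools=dataset_tools,
        extra_prompt_messages=None,
        verbose=True,
    )
    # plan() now delegates to real_plan(), which calls model_instance.run()
    # with the function schemas instead of an OpenAI client bound to `llm`
    decision = agent.plan(intermediate_steps=[], input="Which dataset covers pricing?")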

api/core/agent/agent/openai_function_call.py (+193 -19)

     _format_intermediate_steps
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
+from pydantic import root_validator

-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
+from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM




-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
+class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
+    moving_summary_buffer: str = ""
+    moving_summary_index: int = 0
+    summary_model_instance: BaseLLM = None
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values


     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
+        prompt = cls.create_prompt(
+            extra_prompt_messages=extra_prompt_messages,
+            system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
             **kwargs,
         )


         :param query:
         :return:
         """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 40
+        original_max_tokens = self.model_instance.model_kwargs.max_tokens
+        self.model_instance.model_kwargs.max_tokens = 40

         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()

         try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
+            prompt_messages = to_prompt_messages(messages)
+            result = self.model_instance.run(
+                messages=prompt_messages,
+                functions=self.functions,
+                callbacks=None
             )
         except Exception as e:
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception

-        function_call = predicted_message.additional_kwargs.get("function_call", {})
+        function_call = result.function_call

-        self.llm.max_tokens = original_max_tokens
+        self.model_instance.model_kwargs.max_tokens = original_max_tokens

         return True if function_call else False


         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))

-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
         )
-        agent_decision = _parse_ai_message(predicted_message)
+        agent_decision = _parse_ai_message(ai_message)

         if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
             tool_inputs = agent_decision.tool_input

         return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
+        if rest_tokens >= 0:
+            return messages
+
+        system_message = None
+        human_message = None
+        should_summary_messages = []
+        for message in messages:
+            if isinstance(message, SystemMessage):
+                system_message = message
+            elif isinstance(message, HumanMessage):
+                human_message = message
+            else:
+                should_summary_messages.append(message)
+
+        if len(should_summary_messages) > 2:
+            ai_message = should_summary_messages[-2]
+            function_message = should_summary_messages[-1]
+            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
+            self.moving_summary_index = len(should_summary_messages)
+        else:
+            error_msg = "Exceeded LLM tokens limit, stopped."
+            raise ExceededLLMTokensLimitError(error_msg)
+
+        new_messages = [system_message, human_message]
+
+        if self.moving_summary_index == 0:
+            should_summary_messages.insert(0, human_message)
+
+        self.moving_summary_buffer = self.predict_new_summary(
+            messages=should_summary_messages,
+            existing_summary=self.moving_summary_buffer
+        )
+
+        new_messages.append(AIMessage(content=self.moving_summary_buffer))
+        new_messages.append(ai_message)
+        new_messages.append(function_message)
+
+        return new_messages
+
+    def predict_new_summary(
+            self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if model_instance.model_provider.provider_name == 'azure_openai':
+            model = model_instance.base_model_name
+            model = model.replace("gpt-35", "gpt-3.5")
+        else:
+            model = model_instance.base_model_name
+
+        tiktoken_ = _import_tiktoken()
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except KeyError:
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+
+        if model.startswith("gpt-3.5-turbo"):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        elif model.startswith("gpt-4"):
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}."
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
+        num_tokens = 0
+        for m in messages:
+            message = _convert_message_to_dict(m)
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if key == "function_call":
+                    for f_key, f_value in value.items():
+                        num_tokens += len(encoding.encode(f_key))
+                        num_tokens += len(encoding.encode(f_value))
+                else:
+                    num_tokens += len(encoding.encode(value))
+
+                if key == "name":
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+
+        if kwargs.get('functions'):
+            for function in kwargs.get('functions'):
+                num_tokens += len(encoding.encode('name'))
+                num_tokens += len(encoding.encode(function.get("name")))
+                num_tokens += len(encoding.encode('description'))
+                num_tokens += len(encoding.encode(function.get("description")))
+                parameters = function.get("parameters")
+                num_tokens += len(encoding.encode('parameters'))
+                if 'title' in parameters:
+                    num_tokens += len(encoding.encode('title'))
+                    num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode('type'))
+                num_tokens += len(encoding.encode(parameters.get("type")))
+                if 'properties' in parameters:
+                    num_tokens += len(encoding.encode('properties'))
+                    for key, value in parameters.get('properties').items():
+                        num_tokens += len(encoding.encode(key))
+                        for field_key, field_value in value.items():
+                            num_tokens += len(encoding.encode(field_key))
+                            if field_key == 'enum':
+                                for enum_field in field_value:
+                                    num_tokens += 3
+                                    num_tokens += len(encoding.encode(enum_field))
+                            else:
+                                num_tokens += len(encoding.encode(field_key))
+                                num_tokens += len(encoding.encode(str(field_value)))
+                if 'required' in parameters:
+                    num_tokens += len(encoding.encode('required'))
+                    for required_field in parameters['required']:
+                        num_tokens += 3
+                        num_tokens += len(encoding.encode(required_field))
+
+        return num_tokens
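The function-schema arithmetic above follows OpenAI's cookbook counting recipe; a standalone toy check of the same idea (the model name and schema here are illustrative, not from this commit):

    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    function = {
        "name": "dataset",
        "description": "Query a dataset",
        "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
    }

    # keys and values are both encoded, as in get_num_tokens_from_messages() above
    num_tokens = sum(
        len(encoding.encode(key)) + len(encoding.encode(function[key]))
        for key in ("name", "description")
    )
    print(num_tokens)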

api/core/agent/agent/openai_function_call_summarize_mixin.py (+0 -140, file deleted)

-from typing import cast, List
-
-from langchain.chat_models import ChatOpenAI
-from langchain.chat_models.openai import _convert_message_to_dict
-from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from pydantic import BaseModel
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
-from core.model_providers.models.llm.base import BaseLLM
-
-
-class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        summary_handler = SummarizerMixin(llm=self.summary_llm)
-        self.moving_summary_buffer = summary_handler.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        llm = cast(ChatOpenAI, model_instance.client)
-        model, encoding = llm._get_encoding_model()
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens

api/core/agent/agent/openai_multi_function_call.py (+0 -107, file deleted)

-from typing import List, Tuple, Any, Union, Sequence, Optional
-
-from langchain.agents import BaseMultiActionAgent
-from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
-    _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
-
-
-class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseMultiActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 15
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
-            )
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
-
-        self.llm.max_tokens = original_max_tokens
-
-        return True if function_call else False
-
-    def plan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        # summarize messages if rest_tokens < 0
-        try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
-        )
-        agent_decision = _parse_ai_message(predicted_message)
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        # get current time
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")

api/core/agent/agent/structed_multi_dataset_router_agent.py (+19 -9)

 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool






 class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    model_instance: BaseLLM
     dataset_tools: Sequence[BaseTool]

     class Config:

         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception

         try:
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             dataset_tools=tools,
             **kwargs,
         )

api/core/agent/agent/structured_chat.py (+38 -16)

 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.summary import SummarizerMixin
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM


 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).

 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None

     class Config:
         """Configuration for this pydantic object."""
         if prompts:
             messages = prompts[0].to_messages()

-            rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
+            rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
             if rest_tokens < 0:
                 full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception

         try:

                                      "I don't know how to respond to that."}, "")

     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_llm:
+        if len(intermediate_steps) >= 2 and self.summary_model_instance:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]

             error_msg = "Exceeded LLM tokens limit, stopped."
             raise ExceededLLMTokensLimitError(error_msg)

-        summary_handler = SummarizerMixin(llm=self.summary_llm)
         if self.moving_summary_buffer and 'chat_history' in kwargs:
             kwargs["chat_history"].pop()

-        self.moving_summary_buffer = summary_handler.predict_new_summary(
+        self.moving_summary_buffer = self.predict_new_summary(
             messages=should_summary_messages,
             existing_summary=self.moving_summary_buffer
         )

         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)


+    def predict_new_summary(
+            self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)

     @classmethod
     def create_prompt(
             cls,

     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             **kwargs,
         )
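A sketch of the rolling-summary flow introduced above (the observation text is illustrative; a configured `summary_model_instance` on the agent is assumed):

    # condense old tool observations into the moving summary buffer via
    # the SUMMARY_PROMPT-backed LLMChain in predict_new_summary()
    agent.moving_summary_buffer = agent.predict_new_summary(
        messages=[AIMessage(content="observation: 3 documents retrieved")],
        existing_summary=agent.moving_summary_buffer,
    )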

api/core/agent/agent_executor.py (+2 -18)



 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent

     REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
-    MULTI_FUNCTION_CALL = 'multi_function_call'

 class AgentConfiguration(BaseModel):

         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
-                if self.configuration.summary_model_instance else None,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
-                tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )

             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True

             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True

api/core/chain/llm_chain.py (+36 -0, new file)

+from typing import List, Dict, Any, Optional
+
+from langchain import LLMChain as LCLLMChain
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.schema import LLMResult, Generation
+from langchain.schema.language_model import BaseLanguageModel
+
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
+
+
+class LLMChain(LCLLMChain):
+    model_instance: BaseLLM
+    """The language model instance to use."""
+    llm: BaseLanguageModel = FakeLLM(response="")
+
+    def generate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> LLMResult:
+        """Generate LLM result from inputs."""
+        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
+        messages = prompts[0].to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            stop=stop
+        )
+
+        generations = [
+            [Generation(text=result.content)]
+        ]
+
+        return LLMResult(generations=generations)
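A minimal sketch of driving this wrapper (the prompt text and the way `model_instance` is obtained are assumptions, not part of this commit):

    from langchain import PromptTemplate

    from core.chain.llm_chain import LLMChain

    # assumed: model_instance is a configured BaseLLM
    prompt = PromptTemplate.from_template("Translate to French: {text}")
    chain = LLMChain(model_instance=model_instance, prompt=prompt)

    # predict() is inherited from langchain's LLMChain and funnels through the
    # overridden generate() above; the FakeLLM default on `llm` only satisfies
    # pydantic validation, and no real client is ever called through it
    print(chain.predict(text="good morning"))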

api/core/model_providers/models/entity/message.py (+18 -3)

 import enum

-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel

     prompt_tokens: int
     completion_tokens: int
     source: list = None
+    function_call: dict = None

 class MessageType(enum.Enum):

 class PromptMessage(BaseModel):
     type: MessageType = MessageType.HUMAN
     content: str = ''
+    function_call: dict = None

 def to_lc_messages(messages: list[PromptMessage]):

         if message.type == MessageType.HUMAN:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
-            lc_messages.append(AIMessage(content=message.content))
+            additional_kwargs = {}
+            if message.function_call:
+                additional_kwargs['function_call'] = message.function_call
+            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
         elif message.type == MessageType.SYSTEM:
             lc_messages.append(SystemMessage(content=message.content))

         if isinstance(message, HumanMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
         elif isinstance(message, AIMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+            message_kwargs = {
+                'content': message.content,
+                'type': MessageType.ASSISTANT
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['function_call'] = message.additional_kwargs['function_call']
+
+            prompt_messages.append(PromptMessage(**message_kwargs))
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
     return prompt_messages
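These two converters now round-trip a function_call, which can be checked directly:

    from langchain.schema import AIMessage

    # an assistant message carrying a function_call survives the
    # LangChain -> PromptMessage -> LangChain conversion added above
    lc_in = [AIMessage(content="", additional_kwargs={
        "function_call": {"name": "dataset", "arguments": "{\"query\": \"foo\"}"},
    })]
    prompt_messages = to_prompt_messages(lc_in)
    assert prompt_messages[0].function_call["name"] == "dataset"
    lc_out = to_lc_messages(prompt_messages)
    assert lc_out[0].additional_kwargs["function_call"]["name"] == "dataset"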





api/core/model_providers/models/llm/azure_openai_model.py (+14 -1)

         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     @property
     def base_model_name(self) -> str:

api/core/model_providers/models/llm/base.py (+8 -12)

 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder

         except Exception as ex:
             raise self.handle_exceptions(ex)

+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text

         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )

     @abstractmethod

         if len(messages) == 0:
             return []

-        chat_messages = []
-        for message in messages:
-            if message.type == MessageType.HUMAN:
-                chat_messages.append(HumanMessage(content=message.content))
-            elif message.type == MessageType.ASSISTANT:
-                chat_messages.append(AIMessage(content=message.content))
-            elif message.type == MessageType.SYSTEM:
-                chat_messages.append(SystemMessage(content=message.content))
-
-        return chat_messages
+        return to_lc_messages(messages)

     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """

api/core/model_providers/models/llm/openai_model.py (+15 -1)

             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

api/core/orchestrator_rule_parser.py (+10 -14)

 import math
 from typing import Optional

-from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory

 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
-from models.provider import ProviderType

 class OrchestratorRuleParser:

             # only OpenAI chat model (include Azure) support function call, use ReACT instead
             if agent_model_instance.model_mode != ModelMode.CHAT \
                     or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
-                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
+                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                     planning_strategy = PlanningStrategy.REACT
                 elif planning_strategy == PlanningStrategy.ROUTER:
                     planning_strategy = PlanningStrategy.REACT_ROUTER

                 tool = self.to_current_datetime_tool()

             if tool:
-                tool.callbacks.extend(callbacks)
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
                 tools.append(tool)

         return tools

             summary_model_instance = None

         tool = WebReaderTool(
-            llm=summary_model_instance.client if summary_model_instance else None,
+            model_instance=summary_model_instance if summary_model_instance else None,
             max_chunk_length=4000,
-            continue_reading=True,
-            callbacks=[DifyStdOutCallbackHandler()]
+            continue_reading=True
         )

         return tool

             "is not up to date. "
             "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            args_schema=OptimizedSerpAPIInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=OptimizedSerpAPIInput
         )

         return tool

     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = DatetimeTool(
-            callbacks=[DifyStdOutCallbackHandler()]
-        )
+        tool = DatetimeTool()

         return tool

         return WikipediaQueryRun(
             name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=WikipediaInput
         )

     @classmethod

api/core/tool/web_reader_tool.py (+24 -6)



 import requests
 from bs4 import BeautifulSoup, NavigableString, Comment, CData
-from langchain.base_language import BaseLanguageModel
-from langchain.chains.summarize import load_summarize_chain
+from langchain.chains import RefineDocumentsChain
+from langchain.chains.summarize import refine_prompts
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
 from pydantic import BaseModel, Field
 from regex import regex

+from core.chain.llm_chain import LLMChain
 from core.data_loader import file_extractor
 from core.data_loader.file_extractor import FileExtractor
+from core.model_providers.models.llm.base import BaseLLM

 FULL_TEMPLATE = """
 TITLE: {title}

     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel = None
+    model_instance: BaseLLM = None

     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:

         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'

-        if summary and self.llm:
+        if summary and self.model_instance:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,

             if len(docs) > 5:
                 docs = docs[:5]

-            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
+            chain = self.get_summary_chain()
             try:
                 page_contents = chain.run(docs)
+                # todo use cache
             except Exception as e:
                 return f'Read this website failed, caused by: {str(e)}.'
         else:

     async def _arun(self, url: str) -> str:
         raise NotImplementedError

+    def get_summary_chain(self) -> RefineDocumentsChain:
+        initial_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.PROMPT
+        )
+        refine_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.REFINE_PROMPT
+        )
+        return RefineDocumentsChain(
+            initial_llm_chain=initial_chain,
+            refine_llm_chain=refine_chain,
+            document_variable_name="text",
+            initial_response_name="existing_answer",
+            callbacks=self.callbacks
+        )

 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
