from typing import cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
        """
        Get the remaining tokens available for the model after subtracting the
        message tokens and the reserved completion max tokens.

        :param model_config: model config entity
        :param messages: prompt messages
        :return: remaining token budget (may be negative)
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
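
# Illustrative sketch (not part of the original module): how the remaining
# token budget above is derived. The numbers are hypothetical.
def _example_rest_tokens() -> int:
    context_size = 4096   # model context window
    max_tokens = 512      # reserved for the completion
    prompt_tokens = 3000  # tokens already consumed by the prompt messages
    # rest = context window - reserved completion tokens - prompt tokens
    return context_size - max_tokens - prompt_tokens  # -> 584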
from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        Determine whether the agent should be used to answer the given query.

        :param query: user query
        :return: True if the model responds with a tool call, else False
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = []
            for function in self.functions:
                tool = PromptMessageTool(
                    **function
                )
                tools.append(tool)

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        finally:
            # restore the original max_tokens even if the probe call fails
            self.model_config.parameters['max_tokens'] = original_max_tokens

        return bool(result.message.tool_calls)
    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = []
        for function in self.functions:
            tool = PromptMessageTool(
                **function
            )
            tools.append(tool)

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )
        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")
    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: list[tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages
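
    # Illustrative walk-through (hypothetical message list, not in the original
    # source): given [system, human, ai_1, func_1, ai_2, func_2], the last
    # AI/function pair (ai_2, func_2) is kept verbatim, [ai_1, func_1] (from
    # moving_summary_index onward) is folded into moving_summary_buffer, and the
    # rebuilt prompt becomes [system, human, AIMessage(summary), ai_2, func_2].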
    def predict_new_summary(
            self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )

        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
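
# Worked example (hypothetical numbers, not in the original source): for a
# gpt-4 chat of two messages, the loop above charges tokens_per_message (3)
# twice, plus the encoded length of every role and content string, plus the
# final 3 tokens that prime the assistant reply; two short messages with,
# say, 1 and 2 content tokens (and 1 token per role) thus cost
# 2*3 + (1+1) + (1+2) + 3 = 14 tokens.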
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        Determine whether the agent should be used for the given query.

        Using a separate ReAct round just to decide whether an agent is needed
        would be costly, so it is cheaper to always reason with the agent.

        :param query: user query
        :return: always True
        """
        return True

    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs

            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the model's answer is invalid, "
                                          "I don't know how to respond to that."}, "")
    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
            self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
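
    # Illustrative note (not in the original source): str(tool.args) can
    # contain literal braces, e.g. {'query': {'type': 'string'}}, which the
    # prompt-template machinery would otherwise read as input variables.
    # Quadrupling each brace ("{" -> "{{{{") matches the escaping already used
    # inside FORMAT_INSTRUCTIONS above, so the braces survive template
    # formatting as literal characters.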
    @classmethod
    def create_completion_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
            self, intermediate_steps: list[tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")

        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad
    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
import json
import logging
from typing import cast

from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

# ...

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # init model instance
        model_instance = ModelInstance(
# ...
            'pool': db_variables.variables
        })

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )
        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
import logging
from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
    AgentEntity,
    AppOrchestrationConfigEntity,
    InvokeFrom,
    ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
    def __init__(self, tenant_id: str,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 agent_llm_callback: AgentLLMCallback,
                 callback: AgentLoopGatherCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param agent_llm_callback: agent llm callback
        :param callback: callback
        :param memory: memory
        """
        self.tenant_id = tenant_id
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.agent_llm_callback = agent_llm_callback
        self.callback = callback
        self.memory = memory
    def run(self, query: str,
            invoke_from: InvokeFrom) -> Optional[str]:
        """
        Retrieve agent loop result.
        :param query: query
        :param invoke_from: invoke from
        :return: agent loop output text, or None if the agent is not used
        """
        provider = self.config.provider
        model = self.config.model
        tool_configs = self.config.tools

        # check whether the model supports tool calling
        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model,
            credentials=self.model_config.credentials
        )

        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.FUNCTION_CALL

        tools = self.to_tools(
            tool_configs=tool_configs,
            invoke_from=invoke_from,
            callbacks=[self.callback, DifyStdOutCallbackHandler()],
        )

        if len(tools) == 0:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_config=self.model_config,
            tools=tools,
            memory=self.memory,
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate",
            agent_llm_callback=self.agent_llm_callback,
            callbacks=[self.callback, DifyStdOutCallbackHandler()]
        )

        agent_executor = AgentExecutor(agent_configuration)

        try:
            # check if the agent should be used at all
            should_use_agent = agent_executor.should_use_agent(query)
            if not should_use_agent:
                return None

            result = agent_executor.run(query)
            return result.output
        except Exception:
            logger.exception("agent_executor run failed")
            return None
    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param tool_config: tool config
        :param invoke_from: invoke from
        """
        show_retrieve_source = self.app_orchestration_config.show_retrieve_source

        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=self.queue_manager,
            app_id=self.message.app_id,
            message_id=self.message.id,
            user_id=self.user_id,
            invoke_from=invoke_from
        )

        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # pass if dataset is not available
        if not dataset:
            return None

        # pass if dataset has no available documents or segments
        if (dataset and dataset.available_document_count == 0
                and dataset.available_segment_count == 0):
            return None

        # get retrieval model config
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        retrieval_model_config = dataset.retrieval_model \
            if dataset.retrieval_model else default_retrieval_model

        # get top k
        top_k = retrieval_model_config['top_k']

        # get score threshold
        score_threshold = None
        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
        if score_threshold_enabled:
            score_threshold = retrieval_model_config.get("score_threshold")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=show_retrieve_source,
            retriever_from=invoke_from.to_source()
        )

        return tool
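
    # Illustrative example (hypothetical values, not in the original source):
    # for a dataset whose stored retrieval_model is
    # {'top_k': 4, 'score_threshold_enabled': True, 'score_threshold': 0.6},
    # the tool is built with top_k=4 and score_threshold=0.6; with no stored
    # config, the defaults above yield top_k=2 and score_threshold=None.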
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel

from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance


class LLMChain(LCLLMChain):
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool

from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
import enum
import logging
from typing import Optional, Union

from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    model_config: ModelConfigEntity
# ...

        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
                if self.configuration.memory else None,  # used for reading chat-history memory
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class SparkLLMClient:
    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
        domain = 'spark-api.xf-yun.com'
        endpoint = 'chat'
        if api_domain:
            domain = api_domain

        if model_name == 'spark-v3':
            endpoint = 'multimodal'

        model_api_configs = {
            'spark': {
                'version': 'v1.1',
                'chat_domain': 'general'
            },
            'spark-v2': {
                'version': 'v2.1',
                'chat_domain': 'generalv2'
            },
            'spark-v3': {
                'version': 'v3.1',
                'chat_domain': 'generalv3'
            },
            'spark-v3.5': {
                'version': 'v3.5',
                'chat_domain': 'generalv3.5'
            }
        }

        api_version = model_api_configs[model_name]['version']
        self.chat_domain = model_api_configs[model_name]['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
        self.app_id = app_id
        self.ws_url = self.create_url(
            urlparse(self.api_base).netloc,
            urlparse(self.api_base).path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''
    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        # generate an RFC 1123 timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        # sign using HMAC-SHA256
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # assemble the signed websocket url
        url = api_base + '?' + urlencode(v)
        return url
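
    # Illustrative note (not in the original source): for host
    # "spark-api.xf-yun.com" and path "/v3.1/chat", the canonical string that
    # gets signed is three lines:
    #   host: spark-api.xf-yun.com
    #   date: Mon, 01 Jan 2024 00:00:00 GMT
    #   GET /v3.1/chat HTTP/1.1
    # Its base64-encoded HMAC-SHA256 digest becomes the `signature` field of
    # the `authorization` parameter built above.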
    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_error(self, ws, error):
        self.queue.put({
            'status_code': error.status_code,
            'error': error.resp_body.decode('utf-8')
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        self.queue.put({'done': True})

    def on_open(self, ws):
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {data['header']['message']}"
            })
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if ws.streaming:
                self.queue.put({'data': content})
            else:
                self.blocking_message += content

            if status == 2:
                if not ws.streaming:
                    self.queue.put({'data': self.blocking_message})
                ws.close()
    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)

        return data

    def subscribe(self):
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, access with the credentials you provided is denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

            if 'data' not in content:
                break

            yield content


class SparkError(Exception):
    pass
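
# Minimal usage sketch (illustrative only; credentials are placeholders, not
# from the original source). `run` blocks until the websocket closes, so
# callers typically drive it from a separate thread and consume `subscribe()`
# concurrently:
#
#   client = SparkLLMClient('spark-v3', app_id='...', api_key='...', api_secret='...')
#   threading.Thread(target=client.run,
#                    args=([{'role': 'user', 'content': 'Hello'}], 'user-1'),
#                    kwargs={'streaming': True}).start()
#   for chunk in client.subscribe():
#       print(chunk['data'])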
from datetime import datetime

from langchain.tools import BaseTool
from pydantic import BaseModel, Field


class DatetimeToolInput(BaseModel):
    type: str = Field(..., description="Type for current time, must be: datetime.")


class DatetimeTool(BaseTool):
    """Tool for querying the current datetime."""
    name: str = "current_datetime"
    args_schema: type[BaseModel] = DatetimeToolInput
    description: str = "A tool when you want to get the current date, time, week, month or year, " \
                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."

    def _run(self, type: str) -> str:
        # get current time in UTC
        current_time = datetime.utcnow()
        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
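
# Usage sketch (illustrative): the `type` argument only satisfies the schema;
# the tool always returns the current UTC time, e.g.
#   DatetimeTool().run({'type': 'datetime'})
#   # -> "2024-01-01 00:00:00 UTC+0000 Monday" (hypothetical output)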
| import base64 | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Optional | |||||
| from extensions.ext_database import db | |||||
| from libs import rsa | |||||
| from models.account import Tenant | |||||
| from models.tool import ToolProvider, ToolProviderName | |||||
| class BaseToolProvider(ABC): | |||||
| def __init__(self, tenant_id: str): | |||||
| self.tenant_id = tenant_id | |||||
| @abstractmethod | |||||
| def get_provider_name(self) -> ToolProviderName: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def encrypt_credentials(self, credentials: dict) -> Optional[dict]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def credentials_to_func_kwargs(self) -> Optional[dict]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def credentials_validate(self, credentials: dict): | |||||
| raise NotImplementedError | |||||
| def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]: | |||||
| """ | |||||
| Returns the Provider instance for the given tenant_id and tool_name. | |||||
| """ | |||||
| query = db.session.query(ToolProvider).filter( | |||||
| ToolProvider.tenant_id == self.tenant_id, | |||||
| ToolProvider.tool_name == self.get_provider_name().value | |||||
| ) | |||||
| if must_enabled: | |||||
| query = query.filter(ToolProvider.is_enabled == True) | |||||
| return query.first() | |||||
| def encrypt_token(self, token) -> str: | |||||
| tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||||
| return base64.b64encode(encrypted_token).decode() | |||||
| def decrypt_token(self, token: str, obfuscated: bool = False) -> str: | |||||
| token = rsa.decrypt(base64.b64decode(token), self.tenant_id) | |||||
| if obfuscated: | |||||
| return self._obfuscated_token(token) | |||||
| return token | |||||
| def _obfuscated_token(self, token: str) -> str: | |||||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] |
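| # Worked example (assumption: the token is at least 8 characters, as the | |||||
| # slicing above presumes): a 16-character key keeps its first 6 and last 2 | |||||
| # characters and masks the 8 in between: | |||||
| #     _obfuscated_token('sk-abcdefghij123')  ->  'sk-abc********23' | |||||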
| class ToolValidateFailedError(Exception): | |||||
| description = "Tool Provider Validate failed" |
| from typing import Optional | |||||
| from core.tool.provider.base import BaseToolProvider | |||||
| from core.tool.provider.errors import ToolValidateFailedError | |||||
| from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper | |||||
| from models.tool import ToolProviderName | |||||
| class SerpAPIToolProvider(BaseToolProvider): | |||||
| def get_provider_name(self) -> ToolProviderName: | |||||
| """ | |||||
| Returns the name of the provider. | |||||
| :return: | |||||
| """ | |||||
| return ToolProviderName.SERPAPI | |||||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||||
| """ | |||||
| Returns the credentials for SerpAPI as a dictionary. | |||||
| :param obfuscated: obfuscate credentials if True | |||||
| :return: | |||||
| """ | |||||
| tool_provider = self.get_provider(must_enabled=True) | |||||
| if not tool_provider: | |||||
| return None | |||||
| credentials = tool_provider.credentials | |||||
| if not credentials: | |||||
| return None | |||||
| if credentials.get('api_key'): | |||||
| credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated) | |||||
| return credentials | |||||
| def credentials_to_func_kwargs(self) -> Optional[dict]: | |||||
| """ | |||||
| Returns the credentials function kwargs as a dictionary. | |||||
| :return: | |||||
| """ | |||||
| credentials = self.get_credentials() | |||||
| if not credentials: | |||||
| return None | |||||
| return { | |||||
| 'serpapi_api_key': credentials.get('api_key') | |||||
| } | |||||
| def credentials_validate(self, credentials: dict): | |||||
| """ | |||||
| Validates the given credentials. | |||||
| :param credentials: | |||||
| :return: | |||||
| """ | |||||
| if 'api_key' not in credentials or not credentials.get('api_key'): | |||||
| raise ToolValidateFailedError("SerpAPI api_key is required.") | |||||
| api_key = credentials.get('api_key') | |||||
| try: | |||||
| OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test') | |||||
| except Exception as e: | |||||
| raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e)) | |||||
| def encrypt_credentials(self, credentials: dict) -> Optional[dict]: | |||||
| """ | |||||
| Encrypts the given credentials. | |||||
| :param credentials: | |||||
| :return: | |||||
| """ | |||||
| credentials['api_key'] = self.encrypt_token(credentials.get('api_key')) | |||||
| return credentials |
| from typing import Optional | |||||
| from core.tool.provider.base import BaseToolProvider | |||||
| from core.tool.provider.serpapi_provider import SerpAPIToolProvider | |||||
| class ToolProviderService: | |||||
| def __init__(self, tenant_id: str, provider_name: str): | |||||
| self.provider = self._init_provider(tenant_id, provider_name) | |||||
| def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider: | |||||
| if provider_name == 'serpapi': | |||||
| return SerpAPIToolProvider(tenant_id) | |||||
| else: | |||||
| raise Exception('tool provider {} not found'.format(provider_name)) | |||||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||||
| """ | |||||
| Returns the credentials for Tool as a dictionary. | |||||
| :param obfuscated: | |||||
| :return: | |||||
| """ | |||||
| return self.provider.get_credentials(obfuscated) | |||||
| def credentials_validate(self, credentials: dict): | |||||
| """ | |||||
| Validates the given credentials. | |||||
| :param credentials: | |||||
| :raises: ValidateFailedError | |||||
| """ | |||||
| return self.provider.credentials_validate(credentials) | |||||
| def encrypt_credentials(self, credentials: dict): | |||||
| """ | |||||
| Encrypts the given credentials. | |||||
| :param credentials: | |||||
| :return: | |||||
| """ | |||||
| return self.provider.encrypt_credentials(credentials) |
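| # Usage sketch (not part of the original file; the tenant id and api key are | |||||
| # placeholders, and a database session must be available): validate the raw | |||||
| # credentials first, then persist only the encrypted form. | |||||
| if __name__ == '__main__': | |||||
|     service = ToolProviderService(tenant_id='tenant-id', provider_name='serpapi') | |||||
|     credentials = {'api_key': 'your-serpapi-key'} | |||||
|     service.credentials_validate(credentials)  # raises ToolValidateFailedError if invalid | |||||
|     encrypted = service.encrypt_credentials(credentials) | |||||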
| from langchain import SerpAPIWrapper | |||||
| from pydantic import BaseModel, Field | |||||
| class OptimizedSerpAPIInput(BaseModel): | |||||
| query: str = Field(..., description="search query.") | |||||
| class OptimizedSerpAPIWrapper(SerpAPIWrapper): | |||||
| @staticmethod | |||||
| def _process_response(res: dict, num_results: int = 5) -> str: | |||||
| """Process response from SerpAPI.""" | |||||
| if "error" in res.keys(): | |||||
| raise ValueError(f"Got error from SerpAPI: {res['error']}") | |||||
| if "answer_box" in res.keys() and type(res["answer_box"]) == list: | |||||
| res["answer_box"] = res["answer_box"][0] | |||||
| if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): | |||||
| toret = res["answer_box"]["answer"] | |||||
| elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): | |||||
| toret = res["answer_box"]["snippet"] | |||||
| elif ( | |||||
| "answer_box" in res.keys() | |||||
| and "snippet_highlighted_words" in res["answer_box"].keys() | |||||
| ): | |||||
| toret = res["answer_box"]["snippet_highlighted_words"][0] | |||||
| elif ( | |||||
| "sports_results" in res.keys() | |||||
| and "game_spotlight" in res["sports_results"].keys() | |||||
| ): | |||||
| toret = res["sports_results"]["game_spotlight"] | |||||
| elif ( | |||||
| "shopping_results" in res.keys() | |||||
| and "title" in res["shopping_results"][0].keys() | |||||
| ): | |||||
| toret = res["shopping_results"][:3] | |||||
| elif ( | |||||
| "knowledge_graph" in res.keys() | |||||
| and "description" in res["knowledge_graph"].keys() | |||||
| ): | |||||
| toret = res["knowledge_graph"]["description"] | |||||
| elif 'organic_results' in res.keys() and len(res['organic_results']) > 0: | |||||
| toret = "" | |||||
| for result in res["organic_results"][:num_results]: | |||||
| if "link" in result: | |||||
| toret += "----------------\nlink: " + result["link"] + "\n" | |||||
| if "snippet" in result: | |||||
| toret += "snippet: " + result["snippet"] + "\n" | |||||
| else: | |||||
| toret = "No good search result found" | |||||
| return "search result:\n" + toret |
| import hashlib | |||||
| import json | |||||
| import os | |||||
| import re | |||||
| import site | |||||
| import subprocess | |||||
| import tempfile | |||||
| import unicodedata | |||||
| from contextlib import contextmanager | |||||
| from typing import Any, Optional | |||||
| import requests | |||||
| from bs4 import BeautifulSoup, CData, Comment, NavigableString | |||||
| from langchain.chains import RefineDocumentsChain | |||||
| from langchain.chains.summarize import refine_prompts | |||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||||
| from langchain.tools.base import BaseTool | |||||
| from newspaper import Article | |||||
| from pydantic import BaseModel, Field | |||||
| from regex import regex | |||||
| from core.chain.llm_chain import LLMChain | |||||
| from core.entities.application_entities import ModelConfigEntity | |||||
| from core.rag.extractor import extract_processor | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| from core.rag.models.document import Document | |||||
| FULL_TEMPLATE = """ | |||||
| TITLE: {title} | |||||
| AUTHORS: {authors} | |||||
| PUBLISH DATE: {publish_date} | |||||
| TOP_IMAGE_URL: {top_image} | |||||
| TEXT: | |||||
| {text} | |||||
| """ | |||||
| class WebReaderToolInput(BaseModel): | |||||
| url: str = Field(..., description="URL of the website to read") | |||||
| summary: bool = Field( | |||||
| default=False, | |||||
| description="When the user's question requires extracting the summarizing content of the webpage, " | |||||
| "set it to true." | |||||
| ) | |||||
| cursor: int = Field( | |||||
| default=0, | |||||
| description="Start reading from this character." | |||||
| "Use when the first response was truncated" | |||||
| "and you want to continue reading the page." | |||||
| "The value cannot exceed 24000.", | |||||
| ) | |||||
| class WebReaderTool(BaseTool): | |||||
| """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" | |||||
| name: str = "web_reader" | |||||
| args_schema: type[BaseModel] = WebReaderToolInput | |||||
| description: str = "use this to read a website. " \ | |||||
| "If you can answer the question based on the information provided, " \ | |||||
| "there is no need to use." | |||||
| page_contents: Optional[str] = None | |||||
| url: Optional[str] = None | |||||
| max_chunk_length: int = 4000 | |||||
| summary_chunk_tokens: int = 4000 | |||||
| summary_chunk_overlap: int = 0 | |||||
| summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] | |||||
| continue_reading: bool = True | |||||
| model_config: ModelConfigEntity | |||||
| model_parameters: dict[str, Any] | |||||
| def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: | |||||
| try: | |||||
| if not self.page_contents or self.url != url: | |||||
| page_contents = get_url(url) | |||||
| self.page_contents = page_contents | |||||
| self.url = url | |||||
| else: | |||||
| page_contents = self.page_contents | |||||
| except Exception as e: | |||||
| return f'Reading this website failed, caused by: {str(e)}.' | |||||
| if summary: | |||||
| character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( | |||||
| chunk_size=self.summary_chunk_tokens, | |||||
| chunk_overlap=self.summary_chunk_overlap, | |||||
| separators=self.summary_separators | |||||
| ) | |||||
| texts = character_splitter.split_text(page_contents) | |||||
| docs = [Document(page_content=t) for t in texts] | |||||
| if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): | |||||
| return "No content found." | |||||
| # only use first 5 docs | |||||
| if len(docs) > 5: | |||||
| docs = docs[:5] | |||||
| chain = self.get_summary_chain() | |||||
| try: | |||||
| page_contents = chain.run(docs) | |||||
| except Exception as e: | |||||
| return f'Reading this website failed, caused by: {str(e)}.' | |||||
| else: | |||||
| page_contents = page_result(page_contents, cursor, self.max_chunk_length) | |||||
| if self.continue_reading and len(page_contents) >= self.max_chunk_length: | |||||
| page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ | |||||
| f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ | |||||
| f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." | |||||
| return page_contents | |||||
| async def _arun(self, url: str) -> str: | |||||
| raise NotImplementedError | |||||
| def get_summary_chain(self) -> RefineDocumentsChain: | |||||
| initial_chain = LLMChain( | |||||
| model_config=self.model_config, | |||||
| prompt=refine_prompts.PROMPT, | |||||
| parameters=self.model_parameters | |||||
| ) | |||||
| refine_chain = LLMChain( | |||||
| model_config=self.model_config, | |||||
| prompt=refine_prompts.REFINE_PROMPT, | |||||
| parameters=self.model_parameters | |||||
| ) | |||||
| return RefineDocumentsChain( | |||||
| initial_llm_chain=initial_chain, | |||||
| refine_llm_chain=refine_chain, | |||||
| document_variable_name="text", | |||||
| initial_response_name="existing_answer", | |||||
| callbacks=self.callbacks | |||||
| ) | |||||
| def page_result(text: str, cursor: int, max_length: int) -> str: | |||||
| """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" | |||||
| return text[cursor: cursor + max_length] | |||||
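| # Worked example (not part of the original file): successive cursors page | |||||
| # through the text without overlap, mirroring how WebReaderTool builds its | |||||
| # CURSOR=... continuation hint. | |||||
| #     page_result('abcdefghij', cursor=0, max_length=4)  ->  'abcd' | |||||
| #     page_result('abcdefghij', cursor=4, max_length=4)  ->  'efgh' | |||||
| #     page_result('abcdefghij', cursor=8, max_length=4)  ->  'ij' | |||||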
| def get_url(url: str) -> str: | |||||
| """Fetch URL and return the contents as a string.""" | |||||
| headers = { | |||||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||||
| } | |||||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | |||||
| if head_response.status_code != 200: | |||||
| return "URL returned status code {}.".format(head_response.status_code) | |||||
| # check content-type | |||||
| # default to '' so a missing Content-Type header does not raise AttributeError | |||||
| main_content_type = head_response.headers.get('Content-Type', '').split(';')[0].strip() | |||||
| if main_content_type not in supported_content_types: | |||||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | |||||
| a = extract_using_readabilipy(response.text) | |||||
| if not a['plain_text'] or not a['plain_text'].strip(): | |||||
| return get_url_from_newspaper3k(url) | |||||
| res = FULL_TEMPLATE.format( | |||||
| title=a['title'], | |||||
| authors=a['byline'], | |||||
| publish_date=a['date'], | |||||
| top_image="", | |||||
| text=a['plain_text'] if a['plain_text'] else "", | |||||
| ) | |||||
| return res | |||||
| def get_url_from_newspaper3k(url: str) -> str: | |||||
| a = Article(url) | |||||
| a.download() | |||||
| a.parse() | |||||
| res = FULL_TEMPLATE.format( | |||||
| title=a.title, | |||||
| authors=a.authors, | |||||
| publish_date=a.publish_date, | |||||
| top_image=a.top_image, | |||||
| text=a.text, | |||||
| ) | |||||
| return res | |||||
| def extract_using_readabilipy(html): | |||||
| with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: | |||||
| f_html.write(html) | |||||
| f_html.close() | |||||
| html_path = f_html.name | |||||
| # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file | |||||
| article_json_path = html_path + ".json" | |||||
| jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') | |||||
| with chdir(jsdir): | |||||
| subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) | |||||
| # Read output of call to Readability.parse() from JSON file and return as Python dictionary | |||||
| with open(article_json_path, encoding="utf-8") as json_file: | |||||
| input_json = json.loads(json_file.read()) | |||||
| # Deleting files after processing | |||||
| os.unlink(article_json_path) | |||||
| os.unlink(html_path) | |||||
| article_json = { | |||||
| "title": None, | |||||
| "byline": None, | |||||
| "date": None, | |||||
| "content": None, | |||||
| "plain_content": None, | |||||
| "plain_text": None | |||||
| } | |||||
| # Populate article fields from readability fields where present | |||||
| if input_json: | |||||
| if "title" in input_json and input_json["title"]: | |||||
| article_json["title"] = input_json["title"] | |||||
| if "byline" in input_json and input_json["byline"]: | |||||
| article_json["byline"] = input_json["byline"] | |||||
| if "date" in input_json and input_json["date"]: | |||||
| article_json["date"] = input_json["date"] | |||||
| if "content" in input_json and input_json["content"]: | |||||
| article_json["content"] = input_json["content"] | |||||
| article_json["plain_content"] = plain_content(article_json["content"], False, False) | |||||
| article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) | |||||
| if "textContent" in input_json and input_json["textContent"]: | |||||
| article_json["plain_text"] = input_json["textContent"] | |||||
| article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) | |||||
| return article_json | |||||
| def find_module_path(module_name): | |||||
| for package_path in site.getsitepackages(): | |||||
| potential_path = os.path.join(package_path, module_name) | |||||
| if os.path.exists(potential_path): | |||||
| return potential_path | |||||
| return None | |||||
| @contextmanager | |||||
| def chdir(path): | |||||
| """Change directory in context and return to original on exit""" | |||||
| # From https://stackoverflow.com/a/37996581, couldn't find a built-in | |||||
| original_path = os.getcwd() | |||||
| os.chdir(path) | |||||
| try: | |||||
| yield | |||||
| finally: | |||||
| os.chdir(original_path) | |||||
| def extract_text_blocks_as_plain_text(paragraph_html): | |||||
| # Load article as DOM | |||||
| soup = BeautifulSoup(paragraph_html, 'html.parser') | |||||
| # Select all lists | |||||
| list_elements = soup.find_all(['ul', 'ol']) | |||||
| # Prefix text in all list items with "* " and make lists paragraphs | |||||
| for list_element in list_elements: | |||||
| plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) | |||||
| list_element.string = plain_items | |||||
| list_element.name = "p" | |||||
| # Select all text blocks | |||||
| text_blocks = [s.parent for s in soup.find_all(string=True)] | |||||
| text_blocks = [plain_text_leaf_node(block) for block in text_blocks] | |||||
| # Drop empty paragraphs | |||||
| text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) | |||||
| return text_blocks | |||||
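| # Worked example (not part of the original file): list items are folded into | |||||
| # a single "* item, " paragraph before text blocks are collected, and empty | |||||
| # blocks are dropped. | |||||
| #     extract_text_blocks_as_plain_text('<p>intro</p><ul><li>a</li><li>b</li></ul>') | |||||
| #     ->  [{'text': 'intro'}, {'text': '* a, * b,'}] | |||||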
| def plain_text_leaf_node(element): | |||||
| # Extract all text, stripped of any child HTML elements and normalise it | |||||
| plain_text = normalise_text(element.get_text()) | |||||
| if plain_text != "" and element.name == "li": | |||||
| plain_text = "* {}, ".format(plain_text) | |||||
| if plain_text == "": | |||||
| plain_text = None | |||||
| if "data-node-index" in element.attrs: | |||||
| plain = {"node_index": element["data-node-index"], "text": plain_text} | |||||
| else: | |||||
| plain = {"text": plain_text} | |||||
| return plain | |||||
| def plain_content(readability_content, content_digests, node_indexes): | |||||
| # Load article as DOM | |||||
| soup = BeautifulSoup(readability_content, 'html.parser') | |||||
| # Make all elements plain | |||||
| elements = plain_elements(soup.contents, content_digests, node_indexes) | |||||
| if node_indexes: | |||||
| # Add node index attributes to nodes | |||||
| elements = [add_node_indexes(element) for element in elements] | |||||
| # Replace article contents with plain elements | |||||
| soup.contents = elements | |||||
| return str(soup) | |||||
| def plain_elements(elements, content_digests, node_indexes): | |||||
| # Get plain content versions of all elements | |||||
| elements = [plain_element(element, content_digests, node_indexes) | |||||
| for element in elements] | |||||
| if content_digests: | |||||
| # Add content digest attribute to nodes | |||||
| elements = [add_content_digest(element) for element in elements] | |||||
| return elements | |||||
| def plain_element(element, content_digests, node_indexes): | |||||
| # For lists, we make each item plain text | |||||
| if is_leaf(element): | |||||
| # For leaf node elements, extract the text content, discarding any HTML tags | |||||
| # 1. Get element contents as text | |||||
| plain_text = element.get_text() | |||||
| # 2. Normalise the extracted text string to a canonical representation | |||||
| plain_text = normalise_text(plain_text) | |||||
| # 3. Update element content to be plain text | |||||
| element.string = plain_text | |||||
| elif is_text(element): | |||||
| if is_non_printing(element): | |||||
| # The simplified HTML may have come from Readability.js so might | |||||
| # have non-printing text (e.g. Comment or CData). In this case, we | |||||
| # keep the structure, but ensure that the string is empty. | |||||
| element = type(element)("") | |||||
| else: | |||||
| plain_text = element.string | |||||
| plain_text = normalise_text(plain_text) | |||||
| element = type(element)(plain_text) | |||||
| else: | |||||
| # If not a leaf node or leaf type call recursively on child nodes, replacing | |||||
| element.contents = plain_elements(element.contents, content_digests, node_indexes) | |||||
| return element | |||||
| def add_node_indexes(element, node_index="0"): | |||||
| # Can't add attributes to string types | |||||
| if is_text(element): | |||||
| return element | |||||
| # Add index to current element | |||||
| element["data-node-index"] = node_index | |||||
| # Add index to child elements | |||||
| for local_idx, child in enumerate( | |||||
| [c for c in element.contents if not is_text(c)], start=1): | |||||
| # Can't add attributes to leaf string types | |||||
| child_index = "{stem}.{local}".format( | |||||
| stem=node_index, local=local_idx) | |||||
| add_node_indexes(child, node_index=child_index) | |||||
| return element | |||||
| def normalise_text(text): | |||||
| """Normalise unicode and whitespace.""" | |||||
| # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them | |||||
| text = strip_control_characters(text) | |||||
| text = normalise_unicode(text) | |||||
| text = normalise_whitespace(text) | |||||
| return text | |||||
| def strip_control_characters(text): | |||||
| """Strip out unicode control characters which might break the parsing.""" | |||||
| # Unicode control characters | |||||
| # [Cc]: Other, Control [includes new lines] | |||||
| # [Cf]: Other, Format | |||||
| # [Cn]: Other, Not Assigned | |||||
| # [Co]: Other, Private Use | |||||
| # [Cs]: Other, Surrogate | |||||
| control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) | |||||
| retained_chars = ['\t', '\n', '\r', '\f'] | |||||
| # Remove non-printing control characters | |||||
| return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) | |||||
| def normalise_unicode(text): | |||||
| """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible.""" | |||||
| normal_form = "NFKC" | |||||
| text = unicodedata.normalize(normal_form, text) | |||||
| return text | |||||
| def normalise_whitespace(text): | |||||
| """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" | |||||
| text = regex.sub(r"\s+", " ", text) | |||||
| # Remove leading and trailing whitespace | |||||
| text = text.strip() | |||||
| return text | |||||
| def is_leaf(element): | |||||
| return (element.name in ['p', 'li']) | |||||
| def is_text(element): | |||||
| return isinstance(element, NavigableString) | |||||
| def is_non_printing(element): | |||||
| return any(isinstance(element, _e) for _e in [Comment, CData]) | |||||
| def add_content_digest(element): | |||||
| if not is_text(element): | |||||
| element["data-content-digest"] = content_digest(element) | |||||
| return element | |||||
| def content_digest(element): | |||||
| if is_text(element): | |||||
| # Hash | |||||
| trimmed_string = element.string.strip() | |||||
| if trimmed_string == "": | |||||
| digest = "" | |||||
| else: | |||||
| digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() | |||||
| else: | |||||
| contents = element.contents | |||||
| num_contents = len(contents) | |||||
| if num_contents == 0: | |||||
| # No hash when no child elements exist | |||||
| digest = "" | |||||
| elif num_contents == 1: | |||||
| # If single child, use digest of child | |||||
| digest = content_digest(contents[0]) | |||||
| else: | |||||
| # Build content digest from the "non-empty" digests of child nodes | |||||
| digest = hashlib.sha256() | |||||
| child_digests = list( | |||||
| filter(lambda x: x != "", [content_digest(content) for content in contents])) | |||||
| for child in child_digests: | |||||
| digest.update(child.encode('utf-8')) | |||||
| digest = digest.hexdigest() | |||||
| return digest |
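| # Worked example (not part of the original file): a leaf string hashes its | |||||
| # stripped text directly; an element with a single child reuses the child's | |||||
| # digest. | |||||
| if __name__ == '__main__': | |||||
|     demo = BeautifulSoup('<p>hello</p>', 'html.parser') | |||||
|     # equals hashlib.sha256(b'hello').hexdigest() | |||||
|     print(content_digest(demo.p)) | |||||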
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||||
| from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom | |||||
| from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter | |||||
| from core.tools.tool.tool import Tool | |||||
| @staticmethod | |||||
| def get_dataset_tools(tenant_id: str, | |||||
| dataset_ids: list[str], | |||||
| retrieve_config: DatasetRetrieveConfigEntity, | |||||
| return_resource: bool, | |||||
| invoke_from: InvokeFrom, | |||||
| hit_callback: DatasetIndexToolCallbackHandler | |||||
| ) -> list['DatasetRetrieverTool']: | |||||
| """ | |||||
| get dataset tools | |||||
| """ | |||||
| ) | |||||
| # restore retrieve strategy | |||||
| retrieve_config.retrieve_strategy = original_retriever_mode | |||||
| # convert langchain tools to Tools | |||||
| tools = [] | |||||
| for langchain_tool in langchain_tools: | |||||
| llm=langchain_tool.description), | |||||
| runtime=DatasetRetrieverTool.Runtime() | |||||
| ) | |||||
| tools.append(tool) | |||||
| return tools | |||||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||||
| return [ | |||||
| ToolParameter(name='query', | |||||
| label=I18nObject(en_US='', zh_Hans=''), | |||||
| human_description=I18nObject(en_US='', zh_Hans=''), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| llm_description='Query used to retrieve relevant documents from the dataset.', | |||||
| required=True, | |||||
| default=''), | |||||
| ] | |||||
| def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: | |||||
| query = tool_parameters.get('query', None) | |||||
| if not query: | |||||
| return self.create_text_message(text='please input query') | |||||
| # invoke dataset retriever tool | |||||
| result = self.langchain_tool._run(query=query) | |||||
| """ | |||||
| validate the credentials for dataset retriever tool | |||||
| """ | |||||
| pass |
| import re | |||||
| import uuid | |||||
| from core.entities.agent_entities import PlanningStrategy | |||||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||||
| from core.model_runtime.model_providers import model_provider_factory | |||||