| @@ -1,49 +0,0 @@ | |||
| from typing import cast | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| class CalcTokenMixin: | |||
| def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int: | |||
| """ | |||
| Got the rest tokens available for the model after excluding messages tokens and completion max tokens | |||
| :param model_config: | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| max_tokens = (model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| if model_context_tokens is None: | |||
| return 0 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model_config.model, | |||
| model_config.credentials, | |||
| messages | |||
| ) | |||
| rest_tokens = model_context_tokens - max_tokens - prompt_tokens | |||
| return rest_tokens | |||
| class ExceededLLMTokensLimitError(Exception): | |||
| pass | |||
| @@ -1,361 +0,0 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional, Union | |||
| from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent | |||
| from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken | |||
| from langchain.memory.prompt import SUMMARY_PROMPT | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import ( | |||
| AgentAction, | |||
| AgentFinish, | |||
| AIMessage, | |||
| BaseMessage, | |||
| HumanMessage, | |||
| SystemMessage, | |||
| get_buffer_string, | |||
| ) | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_model_config: ModelConfigEntity = None | |||
| model_config: ModelConfigEntity | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| @root_validator | |||
| def validate_llm(cls, values: dict) -> dict: | |||
| return values | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, | |||
| system_message: Optional[SystemMessage] = SystemMessage( | |||
| content="You are a helpful AI assistant." | |||
| ), | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None, | |||
| **kwargs: Any, | |||
| ) -> BaseSingleActionAgent: | |||
| prompt = cls.create_prompt( | |||
| extra_prompt_messages=extra_prompt_messages, | |||
| system_message=system_message, | |||
| ) | |||
| return cls( | |||
| model_config=model_config, | |||
| llm=FakeLLM(response=''), | |||
| prompt=prompt, | |||
| tools=tools, | |||
| callback_manager=callback_manager, | |||
| agent_llm_callback=agent_llm_callback, | |||
| **kwargs, | |||
| ) | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| return should use agent | |||
| :param query: | |||
| :return: | |||
| """ | |||
| original_max_tokens = 0 | |||
| for parameter_rule in self.model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| original_max_tokens = (self.model_config.parameters.get(parameter_rule.name) | |||
| or self.model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| self.model_config.parameters['max_tokens'] = 40 | |||
| prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) | |||
| messages = prompt.to_messages() | |||
| try: | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| except Exception as e: | |||
| raise e | |||
| self.model_config.parameters['max_tokens'] = original_max_tokens | |||
| return True if result.message.tool_calls else False | |||
| def plan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, along with observations | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| agent_scratchpad = _format_intermediate_steps(intermediate_steps) | |||
| selected_inputs = { | |||
| k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" | |||
| } | |||
| full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) | |||
| prompt = self.prompt.format_prompt(**full_inputs) | |||
| messages = prompt.to_messages() | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| # summarize messages if rest_tokens < 0 | |||
| try: | |||
| prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions) | |||
| except ExceededLLMTokensLimitError as e: | |||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [], | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| ai_message = AIMessage( | |||
| content=result.message.content or "", | |||
| additional_kwargs={ | |||
| 'function_call': { | |||
| 'id': result.message.tool_calls[0].id, | |||
| **result.message.tool_calls[0].function.dict() | |||
| } if result.message.tool_calls else None | |||
| } | |||
| ) | |||
| agent_decision = _parse_ai_message(ai_message) | |||
| if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': | |||
| tool_inputs = agent_decision.tool_input | |||
| if isinstance(tool_inputs, dict) and 'query' in tool_inputs: | |||
| tool_inputs['query'] = kwargs['input'] | |||
| agent_decision.tool_input = tool_inputs | |||
| return agent_decision | |||
| @classmethod | |||
| def get_system_message(cls): | |||
| return SystemMessage(content="You are a helpful AI assistant.\n" | |||
| "The current date or current time you know is wrong.\n" | |||
| "Respond directly if appropriate.") | |||
| def return_stopped_response( | |||
| self, | |||
| early_stopping_method: str, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| **kwargs: Any, | |||
| ) -> AgentFinish: | |||
| try: | |||
| return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs) | |||
| except ValueError: | |||
| return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") | |||
| def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]: | |||
| # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 | |||
| rest_tokens = self.get_message_rest_tokens( | |||
| self.model_config, | |||
| messages, | |||
| **kwargs | |||
| ) | |||
| rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens | |||
| if rest_tokens >= 0: | |||
| return messages | |||
| system_message = None | |||
| human_message = None | |||
| should_summary_messages = [] | |||
| for message in messages: | |||
| if isinstance(message, SystemMessage): | |||
| system_message = message | |||
| elif isinstance(message, HumanMessage): | |||
| human_message = message | |||
| else: | |||
| should_summary_messages.append(message) | |||
| if len(should_summary_messages) > 2: | |||
| ai_message = should_summary_messages[-2] | |||
| function_message = should_summary_messages[-1] | |||
| should_summary_messages = should_summary_messages[self.moving_summary_index:-2] | |||
| self.moving_summary_index = len(should_summary_messages) | |||
| else: | |||
| error_msg = "Exceeded LLM tokens limit, stopped." | |||
| raise ExceededLLMTokensLimitError(error_msg) | |||
| new_messages = [system_message, human_message] | |||
| if self.moving_summary_index == 0: | |||
| should_summary_messages.insert(0, human_message) | |||
| self.moving_summary_buffer = self.predict_new_summary( | |||
| messages=should_summary_messages, | |||
| existing_summary=self.moving_summary_buffer | |||
| ) | |||
| new_messages.append(AIMessage(content=self.moving_summary_buffer)) | |||
| new_messages.append(ai_message) | |||
| new_messages.append(function_message) | |||
| return new_messages | |||
| def predict_new_summary( | |||
| self, messages: list[BaseMessage], existing_summary: str | |||
| ) -> str: | |||
| new_lines = get_buffer_string( | |||
| messages, | |||
| human_prefix="Human", | |||
| ai_prefix="AI", | |||
| ) | |||
| chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) | |||
| return chain.predict(summary=existing_summary, new_lines=new_lines) | |||
| def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int: | |||
| """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. | |||
| Official documentation: https://github.com/openai/openai-cookbook/blob/ | |||
| main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" | |||
| if model_config.provider == 'azure_openai': | |||
| model = model_config.model | |||
| model = model.replace("gpt-35", "gpt-3.5") | |||
| else: | |||
| model = model_config.credentials.get("base_model_name") | |||
| tiktoken_ = _import_tiktoken() | |||
| try: | |||
| encoding = tiktoken_.encoding_for_model(model) | |||
| except KeyError: | |||
| model = "cl100k_base" | |||
| encoding = tiktoken_.get_encoding(model) | |||
| if model.startswith("gpt-3.5-turbo"): | |||
| # every message follows <im_start>{role/name}\n{content}<im_end>\n | |||
| tokens_per_message = 4 | |||
| # if there's a name, the role is omitted | |||
| tokens_per_name = -1 | |||
| elif model.startswith("gpt-4"): | |||
| tokens_per_message = 3 | |||
| tokens_per_name = 1 | |||
| else: | |||
| raise NotImplementedError( | |||
| f"get_num_tokens_from_messages() is not presently implemented " | |||
| f"for model {model}." | |||
| "See https://github.com/openai/openai-python/blob/main/chatml.md for " | |||
| "information on how messages are converted to tokens." | |||
| ) | |||
| num_tokens = 0 | |||
| for m in messages: | |||
| message = _convert_message_to_dict(m) | |||
| num_tokens += tokens_per_message | |||
| for key, value in message.items(): | |||
| if key == "function_call": | |||
| for f_key, f_value in value.items(): | |||
| num_tokens += len(encoding.encode(f_key)) | |||
| num_tokens += len(encoding.encode(f_value)) | |||
| else: | |||
| num_tokens += len(encoding.encode(value)) | |||
| if key == "name": | |||
| num_tokens += tokens_per_name | |||
| # every reply is primed with <im_start>assistant | |||
| num_tokens += 3 | |||
| if kwargs.get('functions'): | |||
| for function in kwargs.get('functions'): | |||
| num_tokens += len(encoding.encode('name')) | |||
| num_tokens += len(encoding.encode(function.get("name"))) | |||
| num_tokens += len(encoding.encode('description')) | |||
| num_tokens += len(encoding.encode(function.get("description"))) | |||
| parameters = function.get("parameters") | |||
| num_tokens += len(encoding.encode('parameters')) | |||
| if 'title' in parameters: | |||
| num_tokens += len(encoding.encode('title')) | |||
| num_tokens += len(encoding.encode(parameters.get("title"))) | |||
| num_tokens += len(encoding.encode('type')) | |||
| num_tokens += len(encoding.encode(parameters.get("type"))) | |||
| if 'properties' in parameters: | |||
| num_tokens += len(encoding.encode('properties')) | |||
| for key, value in parameters.get('properties').items(): | |||
| num_tokens += len(encoding.encode(key)) | |||
| for field_key, field_value in value.items(): | |||
| num_tokens += len(encoding.encode(field_key)) | |||
| if field_key == 'enum': | |||
| for enum_field in field_value: | |||
| num_tokens += 3 | |||
| num_tokens += len(encoding.encode(enum_field)) | |||
| else: | |||
| num_tokens += len(encoding.encode(field_key)) | |||
| num_tokens += len(encoding.encode(str(field_value))) | |||
| if 'required' in parameters: | |||
| num_tokens += len(encoding.encode('required')) | |||
| for required_field in parameters['required']: | |||
| num_tokens += 3 | |||
| num_tokens += len(encoding.encode(required_field)) | |||
| return num_tokens | |||
| @@ -1,306 +0,0 @@ | |||
| import re | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from langchain import BasePromptTemplate, PromptTemplate | |||
| from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent | |||
| from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.memory.prompt import SUMMARY_PROMPT | |||
| from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate | |||
| from langchain.schema import ( | |||
| AgentAction, | |||
| AgentFinish, | |||
| AIMessage, | |||
| BaseMessage, | |||
| HumanMessage, | |||
| OutputParserException, | |||
| get_buffer_string, | |||
| ) | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| Valid "action" values: "Final Answer" or {tool_names} | |||
| Provide only ONE action per $JSON_BLOB, as shown: | |||
| ``` | |||
| {{{{ | |||
| "action": $TOOL_NAME, | |||
| "action_input": $INPUT | |||
| }}}} | |||
| ``` | |||
| Follow this format: | |||
| Question: input question to answer | |||
| Thought: consider previous and subsequent steps | |||
| Action: | |||
| ``` | |||
| $JSON_BLOB | |||
| ``` | |||
| Observation: action result | |||
| ... (repeat Thought/Action/Observation N times) | |||
| Thought: I know what to respond | |||
| Action: | |||
| ``` | |||
| {{{{ | |||
| "action": "Final Answer", | |||
| "action_input": "Final response to human" | |||
| }}}} | |||
| ```""" | |||
| class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_model_config: ModelConfigEntity = None | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| return should use agent | |||
| Using the ReACT mode to determine whether an agent is needed is costly, | |||
| so it's better to just use an Agent for reasoning, which is cheaper. | |||
| :param query: | |||
| :return: | |||
| """ | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: list[tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, | |||
| along with observatons | |||
| callbacks: Callbacks to run. | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) | |||
| prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)]) | |||
| messages = [] | |||
| if prompts: | |||
| messages = prompts[0].to_messages() | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages) | |||
| if rest_tokens < 0: | |||
| full_inputs = self.summarize_messages(intermediate_steps, **kwargs) | |||
| try: | |||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||
| except Exception as e: | |||
| raise e | |||
| try: | |||
| agent_decision = self.output_parser.parse(full_output) | |||
| if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': | |||
| tool_inputs = agent_decision.tool_input | |||
| if isinstance(tool_inputs, dict) and 'query' in tool_inputs: | |||
| tool_inputs['query'] = kwargs['input'] | |||
| agent_decision.tool_input = tool_inputs | |||
| return agent_decision | |||
| except OutputParserException: | |||
| return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " | |||
| "I don't know how to respond to that."}, "") | |||
| def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs): | |||
| if len(intermediate_steps) >= 2 and self.summary_model_config: | |||
| should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] | |||
| should_summary_messages = [AIMessage(content=observation) | |||
| for _, observation in should_summary_intermediate_steps] | |||
| if self.moving_summary_index == 0: | |||
| should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input"))) | |||
| self.moving_summary_index = len(intermediate_steps) | |||
| else: | |||
| error_msg = "Exceeded LLM tokens limit, stopped." | |||
| raise ExceededLLMTokensLimitError(error_msg) | |||
| if self.moving_summary_buffer and 'chat_history' in kwargs: | |||
| kwargs["chat_history"].pop() | |||
| self.moving_summary_buffer = self.predict_new_summary( | |||
| messages=should_summary_messages, | |||
| existing_summary=self.moving_summary_buffer | |||
| ) | |||
| if 'chat_history' in kwargs: | |||
| kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer)) | |||
| return self.get_full_inputs([intermediate_steps[-1]], **kwargs) | |||
| def predict_new_summary( | |||
| self, messages: list[BaseMessage], existing_summary: str | |||
| ) -> str: | |||
| new_lines = get_buffer_string( | |||
| messages, | |||
| human_prefix="Human", | |||
| ai_prefix="AI", | |||
| ) | |||
| chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) | |||
| return chain.predict(summary=existing_summary, new_lines=new_lines) | |||
| @classmethod | |||
| def create_prompt( | |||
| cls, | |||
| tools: Sequence[BaseTool], | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| memory_prompts: Optional[list[BasePromptTemplate]] = None, | |||
| ) -> BasePromptTemplate: | |||
| tool_strings = [] | |||
| for tool in tools: | |||
| args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) | |||
| tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") | |||
| formatted_tools = "\n".join(tool_strings) | |||
| tool_names = ", ".join([('"' + tool.name + '"') for tool in tools]) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |||
| if input_variables is None: | |||
| input_variables = ["input", "agent_scratchpad"] | |||
| _memory_prompts = memory_prompts or [] | |||
| messages = [ | |||
| SystemMessagePromptTemplate.from_template(template), | |||
| *_memory_prompts, | |||
| HumanMessagePromptTemplate.from_template(human_message_template), | |||
| ] | |||
| return ChatPromptTemplate(input_variables=input_variables, messages=messages) | |||
| @classmethod | |||
| def create_completion_prompt( | |||
| cls, | |||
| tools: Sequence[BaseTool], | |||
| prefix: str = PREFIX, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| ) -> PromptTemplate: | |||
| """Create prompt in the style of the zero shot agent. | |||
| Args: | |||
| tools: List of tools the agent will have access to, used to format the | |||
| prompt. | |||
| prefix: String to put before the list of tools. | |||
| input_variables: List of input variables the final prompt will expect. | |||
| Returns: | |||
| A PromptTemplate with the template assembled from the pieces here. | |||
| """ | |||
| suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. | |||
| Question: {input} | |||
| Thought: {agent_scratchpad} | |||
| """ | |||
| tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) | |||
| tool_names = ", ".join([tool.name for tool in tools]) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) | |||
| if input_variables is None: | |||
| input_variables = ["input", "agent_scratchpad"] | |||
| return PromptTemplate(template=template, input_variables=input_variables) | |||
| def _construct_scratchpad( | |||
| self, intermediate_steps: list[tuple[AgentAction, str]] | |||
| ) -> str: | |||
| agent_scratchpad = "" | |||
| for action, observation in intermediate_steps: | |||
| agent_scratchpad += action.log | |||
| agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" | |||
| if not isinstance(agent_scratchpad, str): | |||
| raise ValueError("agent_scratchpad should be of type string.") | |||
| if agent_scratchpad: | |||
| llm_chain = cast(LLMChain, self.llm_chain) | |||
| if llm_chain.model_config.mode == "chat": | |||
| return ( | |||
| f"This was your previous work " | |||
| f"(but I haven't seen any of it! I only see what " | |||
| f"you return as final answer):\n{agent_scratchpad}" | |||
| ) | |||
| else: | |||
| return agent_scratchpad | |||
| else: | |||
| return agent_scratchpad | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| output_parser: Optional[AgentOutputParser] = None, | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[list[str]] = None, | |||
| memory_prompts: Optional[list[BasePromptTemplate]] = None, | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None, | |||
| **kwargs: Any, | |||
| ) -> Agent: | |||
| """Construct an agent from an LLM and tools.""" | |||
| cls._validate_tools(tools) | |||
| if model_config.mode == "chat": | |||
| prompt = cls.create_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| suffix=suffix, | |||
| human_message_template=human_message_template, | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables, | |||
| memory_prompts=memory_prompts, | |||
| ) | |||
| else: | |||
| prompt = cls.create_completion_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables, | |||
| ) | |||
| llm_chain = LLMChain( | |||
| model_config=model_config, | |||
| prompt=prompt, | |||
| callback_manager=callback_manager, | |||
| agent_llm_callback=agent_llm_callback, | |||
| parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| tool_names = [tool.name for tool in tools] | |||
| _output_parser = output_parser | |||
| return cls( | |||
| llm_chain=llm_chain, | |||
| allowed_tools=tool_names, | |||
| output_parser=_output_parser, | |||
| **kwargs, | |||
| ) | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| from typing import cast | |||
| @@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large | |||
| from core.moderation.base import ModerationException | |||
| from core.tools.entities.tool_entities import ToolRuntimeVariablePool | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, Message, MessageAgentThought, MessageChain | |||
| from models.model import App, Conversation, Message, MessageAgentThought | |||
| from models.tools import ToolConversationVariables | |||
| logger = logging.getLogger(__name__) | |||
| @@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner): | |||
| # convert db variables to tool variables | |||
| tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) | |||
| message_chain = self._init_message_chain( | |||
| message=message, | |||
| query=query | |||
| ) | |||
| # init model instance | |||
| model_instance = ModelInstance( | |||
| @@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner): | |||
| 'pool': db_variables.variables | |||
| }) | |||
| def _init_message_chain(self, message: Message, query: str) -> MessageChain: | |||
| """ | |||
| Init MessageChain | |||
| :param message: message | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| message_chain = MessageChain( | |||
| message_id=message.id, | |||
| type="AgentExecutor", | |||
| input=json.dumps({ | |||
| "input": query | |||
| }) | |||
| ) | |||
| db.session.add(message_chain) | |||
| db.session.commit() | |||
| return message_chain | |||
| def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None: | |||
| """ | |||
| Save MessageChain | |||
| :param message_chain: message chain | |||
| :param output_text: output text | |||
| :return: | |||
| """ | |||
| message_chain.output = json.dumps({ | |||
| "output": output_text | |||
| }) | |||
| db.session.commit() | |||
| def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, | |||
| message: Message) -> LLMUsage: | |||
| """ | |||
| @@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner | |||
| from core.application_queue_manager import ApplicationQueueManager, PublishFrom | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity | |||
| from core.features.dataset_retrieval import DatasetRetrievalFeature | |||
| from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.moderation.base import ModerationException | |||
| @@ -0,0 +1,8 @@ | |||
| from enum import Enum | |||
| class PlanningStrategy(Enum): | |||
| ROUTER = 'router' | |||
| REACT_ROUTER = 'react_router' | |||
| REACT = 'react' | |||
| FUNCTION_CALL = 'function_call' | |||
| @@ -1,199 +0,0 @@ | |||
| import logging | |||
| from typing import Optional, cast | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.entities.application_entities import ( | |||
| AgentEntity, | |||
| AppOrchestrationConfigEntity, | |||
| InvokeFrom, | |||
| ModelConfigEntity, | |||
| ) | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.model import Message | |||
| logger = logging.getLogger(__name__) | |||
| class AgentRunnerFeature: | |||
| def __init__(self, tenant_id: str, | |||
| app_orchestration_config: AppOrchestrationConfigEntity, | |||
| model_config: ModelConfigEntity, | |||
| config: AgentEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| message: Message, | |||
| user_id: str, | |||
| agent_llm_callback: AgentLLMCallback, | |||
| callback: AgentLoopGatherCallbackHandler, | |||
| memory: Optional[TokenBufferMemory] = None,) -> None: | |||
| """ | |||
| Agent runner | |||
| :param tenant_id: tenant id | |||
| :param app_orchestration_config: app orchestration config | |||
| :param model_config: model config | |||
| :param config: dataset config | |||
| :param queue_manager: queue manager | |||
| :param message: message | |||
| :param user_id: user id | |||
| :param agent_llm_callback: agent llm callback | |||
| :param callback: callback | |||
| :param memory: memory | |||
| """ | |||
| self.tenant_id = tenant_id | |||
| self.app_orchestration_config = app_orchestration_config | |||
| self.model_config = model_config | |||
| self.config = config | |||
| self.queue_manager = queue_manager | |||
| self.message = message | |||
| self.user_id = user_id | |||
| self.agent_llm_callback = agent_llm_callback | |||
| self.callback = callback | |||
| self.memory = memory | |||
| def run(self, query: str, | |||
| invoke_from: InvokeFrom) -> Optional[str]: | |||
| """ | |||
| Retrieve agent loop result. | |||
| :param query: query | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| provider = self.config.provider | |||
| model = self.config.model | |||
| tool_configs = self.config.tools | |||
| # check model is support tool calling | |||
| provider_instance = model_provider_factory.get_provider_instance(provider=provider) | |||
| model_type_instance = provider_instance.get_model_instance(ModelType.LLM) | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model, | |||
| credentials=self.model_config.credentials | |||
| ) | |||
| if not model_schema: | |||
| return None | |||
| planning_strategy = PlanningStrategy.REACT | |||
| features = model_schema.features | |||
| if features: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.FUNCTION_CALL | |||
| tools = self.to_tools( | |||
| tool_configs=tool_configs, | |||
| invoke_from=invoke_from, | |||
| callbacks=[self.callback, DifyStdOutCallbackHandler()], | |||
| ) | |||
| if len(tools) == 0: | |||
| return None | |||
| agent_configuration = AgentConfiguration( | |||
| strategy=planning_strategy, | |||
| model_config=self.model_config, | |||
| tools=tools, | |||
| memory=self.memory, | |||
| max_iterations=10, | |||
| max_execution_time=400.0, | |||
| early_stopping_method="generate", | |||
| agent_llm_callback=self.agent_llm_callback, | |||
| callbacks=[self.callback, DifyStdOutCallbackHandler()] | |||
| ) | |||
| agent_executor = AgentExecutor(agent_configuration) | |||
| try: | |||
| # check if should use agent | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if not should_use_agent: | |||
| return None | |||
| result = agent_executor.run(query) | |||
| return result.output | |||
| except Exception as ex: | |||
| logger.exception("agent_executor run failed") | |||
| return None | |||
| def to_dataset_retriever_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) \ | |||
| -> Optional[BaseTool]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| """ | |||
| show_retrieve_source = self.app_orchestration_config.show_retrieve_source | |||
| hit_callback = DatasetIndexToolCallbackHandler( | |||
| queue_manager=self.queue_manager, | |||
| app_id=self.message.app_id, | |||
| message_id=self.message.id, | |||
| user_id=self.user_id, | |||
| invoke_from=invoke_from | |||
| ) | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == self.tenant_id, | |||
| Dataset.id == tool_config.get("id") | |||
| ).first() | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| return None | |||
| # pass if dataset is not available | |||
| if (dataset and dataset.available_document_count == 0 | |||
| and dataset.available_document_count == 0): | |||
| return None | |||
| # get retrieval model config | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| retrieval_model_config = dataset.retrieval_model \ | |||
| if dataset.retrieval_model else default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config['top_k'] | |||
| # get score threshold | |||
| score_threshold = None | |||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||
| if score_threshold_enabled: | |||
| score_threshold = retrieval_model_config.get("score_threshold") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=show_retrieve_source, | |||
| retriever_from=invoke_from.to_source() | |||
| ) | |||
| return tool | |||
| @@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.schema import Generation, LLMResult | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.features.dataset_retrieval.agent.fake_llm import FakeLLM | |||
| from core.model_manager import ModelInstance | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class LLMChain(LCLLMChain): | |||
| @@ -12,9 +12,9 @@ from pydantic import root_validator | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.features.dataset_retrieval.agent.fake_llm import FakeLLM | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| @@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy | |||
| from langchain.schema import AgentAction, AgentFinish, OutputParserException | |||
| from langchain.tools import BaseTool | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.features.dataset_retrieval.agent.llm_chain import LLMChain | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| @@ -1,4 +1,3 @@ | |||
| import enum | |||
| import logging | |||
| from typing import Optional, Union | |||
| @@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks | |||
| from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Extra | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent | |||
| from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||
| from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import prompt_messages_to_lc_messages | |||
| from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||
| from core.helper import moderation | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| @@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool | |||
| class PlanningStrategy(str, enum.Enum): | |||
| ROUTER = 'router' | |||
| REACT_ROUTER = 'react_router' | |||
| REACT = 'react' | |||
| FUNCTION_CALL = 'function_call' | |||
| class AgentConfiguration(BaseModel): | |||
| strategy: PlanningStrategy | |||
| model_config: ModelConfigEntity | |||
| @@ -62,28 +53,7 @@ class AgentExecutor: | |||
| self.agent = self._init_agent() | |||
| def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: | |||
| if self.configuration.strategy == PlanningStrategy.REACT: | |||
| agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| summary_model_config=self.configuration.summary_model_config | |||
| if self.configuration.summary_model_config else None, | |||
| agent_llm_callback=self.configuration.agent_llm_callback, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: | |||
| agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) | |||
| if self.configuration.memory else None, # used for read chat histories memory | |||
| summary_model_config=self.configuration.summary_model_config | |||
| if self.configuration.summary_model_config else None, | |||
| agent_llm_callback=self.configuration.agent_llm_callback, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| if self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools | |||
| if isinstance(t, DatasetRetrieverTool) | |||
| or isinstance(t, DatasetMultiRetrieverTool)] | |||
| @@ -2,9 +2,10 @@ from typing import Optional, cast | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity | |||
| from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| @@ -1,189 +0,0 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import json | |||
| import queue | |||
| import ssl | |||
| from datetime import datetime | |||
| from time import mktime | |||
| from typing import Optional | |||
| from urllib.parse import urlencode, urlparse | |||
| from wsgiref.handlers import format_date_time | |||
| import websocket | |||
| class SparkLLMClient: | |||
| def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): | |||
| domain = 'spark-api.xf-yun.com' | |||
| endpoint = 'chat' | |||
| if api_domain: | |||
| domain = api_domain | |||
| if model_name == 'spark-v3': | |||
| endpoint = 'multimodal' | |||
| model_api_configs = { | |||
| 'spark': { | |||
| 'version': 'v1.1', | |||
| 'chat_domain': 'general' | |||
| }, | |||
| 'spark-v2': { | |||
| 'version': 'v2.1', | |||
| 'chat_domain': 'generalv2' | |||
| }, | |||
| 'spark-v3': { | |||
| 'version': 'v3.1', | |||
| 'chat_domain': 'generalv3' | |||
| }, | |||
| 'spark-v3.5': { | |||
| 'version': 'v3.5', | |||
| 'chat_domain': 'generalv3.5' | |||
| } | |||
| } | |||
| api_version = model_api_configs[model_name]['version'] | |||
| self.chat_domain = model_api_configs[model_name]['chat_domain'] | |||
| self.api_base = f"wss://{domain}/{api_version}/{endpoint}" | |||
| self.app_id = app_id | |||
| self.ws_url = self.create_url( | |||
| urlparse(self.api_base).netloc, | |||
| urlparse(self.api_base).path, | |||
| self.api_base, | |||
| api_key, | |||
| api_secret | |||
| ) | |||
| self.queue = queue.Queue() | |||
| self.blocking_message = '' | |||
| def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: | |||
| # generate timestamp by RFC1123 | |||
| now = datetime.now() | |||
| date = format_date_time(mktime(now.timetuple())) | |||
| signature_origin = "host: " + host + "\n" | |||
| signature_origin += "date: " + date + "\n" | |||
| signature_origin += "GET " + path + " HTTP/1.1" | |||
| # encrypt using hmac-sha256 | |||
| signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), | |||
| digestmod=hashlib.sha256).digest() | |||
| signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') | |||
| authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' | |||
| authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') | |||
| v = { | |||
| "authorization": authorization, | |||
| "date": date, | |||
| "host": host | |||
| } | |||
| # generate url | |||
| url = api_base + '?' + urlencode(v) | |||
| return url | |||
| def run(self, messages: list, user_id: str, | |||
| model_kwargs: Optional[dict] = None, streaming: bool = False): | |||
| websocket.enableTrace(False) | |||
| ws = websocket.WebSocketApp( | |||
| self.ws_url, | |||
| on_message=self.on_message, | |||
| on_error=self.on_error, | |||
| on_close=self.on_close, | |||
| on_open=self.on_open | |||
| ) | |||
| ws.messages = messages | |||
| ws.user_id = user_id | |||
| ws.model_kwargs = model_kwargs | |||
| ws.streaming = streaming | |||
| ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) | |||
| def on_error(self, ws, error): | |||
| self.queue.put({ | |||
| 'status_code': error.status_code, | |||
| 'error': error.resp_body.decode('utf-8') | |||
| }) | |||
| ws.close() | |||
| def on_close(self, ws, close_status_code, close_reason): | |||
| self.queue.put({'done': True}) | |||
| def on_open(self, ws): | |||
| self.blocking_message = '' | |||
| data = json.dumps(self.gen_params( | |||
| messages=ws.messages, | |||
| user_id=ws.user_id, | |||
| model_kwargs=ws.model_kwargs | |||
| )) | |||
| ws.send(data) | |||
| def on_message(self, ws, message): | |||
| data = json.loads(message) | |||
| code = data['header']['code'] | |||
| if code != 0: | |||
| self.queue.put({ | |||
| 'status_code': 400, | |||
| 'error': f"Code: {code}, Error: {data['header']['message']}" | |||
| }) | |||
| ws.close() | |||
| else: | |||
| choices = data["payload"]["choices"] | |||
| status = choices["status"] | |||
| content = choices["text"][0]["content"] | |||
| if ws.streaming: | |||
| self.queue.put({'data': content}) | |||
| else: | |||
| self.blocking_message += content | |||
| if status == 2: | |||
| if not ws.streaming: | |||
| self.queue.put({'data': self.blocking_message}) | |||
| ws.close() | |||
| def gen_params(self, messages: list, user_id: str, | |||
| model_kwargs: Optional[dict] = None) -> dict: | |||
| data = { | |||
| "header": { | |||
| "app_id": self.app_id, | |||
| "uid": user_id | |||
| }, | |||
| "parameter": { | |||
| "chat": { | |||
| "domain": self.chat_domain | |||
| } | |||
| }, | |||
| "payload": { | |||
| "message": { | |||
| "text": messages | |||
| } | |||
| } | |||
| } | |||
| if model_kwargs: | |||
| data['parameter']['chat'].update(model_kwargs) | |||
| return data | |||
| def subscribe(self): | |||
| while True: | |||
| content = self.queue.get() | |||
| if 'error' in content: | |||
| if content['status_code'] == 401: | |||
| raise SparkError('[Spark] The credentials you provided are incorrect. ' | |||
| 'Please double-check and fill them in again.') | |||
| elif content['status_code'] == 403: | |||
| raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " | |||
| "Please try again after obtaining the necessary permissions.") | |||
| else: | |||
| raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") | |||
| if 'data' not in content: | |||
| break | |||
| yield content | |||
| class SparkError(Exception): | |||
| pass | |||
| @@ -1,24 +0,0 @@ | |||
| from datetime import datetime | |||
| from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Field | |||
| class DatetimeToolInput(BaseModel): | |||
| type: str = Field(..., description="Type for current time, must be: datetime.") | |||
| class DatetimeTool(BaseTool): | |||
| """Tool for querying current datetime.""" | |||
| name: str = "current_datetime" | |||
| args_schema: type[BaseModel] = DatetimeToolInput | |||
| description: str = "A tool when you want to get the current date, time, week, month or year, " \ | |||
| "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"." | |||
| def _run(self, type: str) -> str: | |||
| # get current time | |||
| current_time = datetime.utcnow() | |||
| return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A") | |||
| async def _arun(self, tool_input: str) -> str: | |||
| raise NotImplementedError() | |||
| @@ -1,63 +0,0 @@ | |||
| import base64 | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| from extensions.ext_database import db | |||
| from libs import rsa | |||
| from models.account import Tenant | |||
| from models.tool import ToolProvider, ToolProviderName | |||
| class BaseToolProvider(ABC): | |||
| def __init__(self, tenant_id: str): | |||
| self.tenant_id = tenant_id | |||
| @abstractmethod | |||
| def get_provider_name(self) -> ToolProviderName: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def encrypt_credentials(self, credentials: dict) -> Optional[dict]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def credentials_to_func_kwargs(self) -> Optional[dict]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def credentials_validate(self, credentials: dict): | |||
| raise NotImplementedError | |||
| def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]: | |||
| """ | |||
| Returns the Provider instance for the given tenant_id and tool_name. | |||
| """ | |||
| query = db.session.query(ToolProvider).filter( | |||
| ToolProvider.tenant_id == self.tenant_id, | |||
| ToolProvider.tool_name == self.get_provider_name().value | |||
| ) | |||
| if must_enabled: | |||
| query = query.filter(ToolProvider.is_enabled == True) | |||
| return query.first() | |||
| def encrypt_token(self, token) -> str: | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||
| return base64.b64encode(encrypted_token).decode() | |||
| def decrypt_token(self, token: str, obfuscated: bool = False) -> str: | |||
| token = rsa.decrypt(base64.b64decode(token), self.tenant_id) | |||
| if obfuscated: | |||
| return self._obfuscated_token(token) | |||
| return token | |||
| def _obfuscated_token(self, token: str) -> str: | |||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] | |||
| @@ -1,2 +0,0 @@ | |||
| class ToolValidateFailedError(Exception): | |||
| description = "Tool Provider Validate failed" | |||
| @@ -1,77 +0,0 @@ | |||
| from typing import Optional | |||
| from core.tool.provider.base import BaseToolProvider | |||
| from core.tool.provider.errors import ToolValidateFailedError | |||
| from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper | |||
| from models.tool import ToolProviderName | |||
| class SerpAPIToolProvider(BaseToolProvider): | |||
| def get_provider_name(self) -> ToolProviderName: | |||
| """ | |||
| Returns the name of the provider. | |||
| :return: | |||
| """ | |||
| return ToolProviderName.SERPAPI | |||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||
| """ | |||
| Returns the credentials for SerpAPI as a dictionary. | |||
| :param obfuscated: obfuscate credentials if True | |||
| :return: | |||
| """ | |||
| tool_provider = self.get_provider(must_enabled=True) | |||
| if not tool_provider: | |||
| return None | |||
| credentials = tool_provider.credentials | |||
| if not credentials: | |||
| return None | |||
| if credentials.get('api_key'): | |||
| credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated) | |||
| return credentials | |||
| def credentials_to_func_kwargs(self) -> Optional[dict]: | |||
| """ | |||
| Returns the credentials function kwargs as a dictionary. | |||
| :return: | |||
| """ | |||
| credentials = self.get_credentials() | |||
| if not credentials: | |||
| return None | |||
| return { | |||
| 'serpapi_api_key': credentials.get('api_key') | |||
| } | |||
| def credentials_validate(self, credentials: dict): | |||
| """ | |||
| Validates the given credentials. | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| if 'api_key' not in credentials or not credentials.get('api_key'): | |||
| raise ToolValidateFailedError("SerpAPI api_key is required.") | |||
| api_key = credentials.get('api_key') | |||
| try: | |||
| OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test') | |||
| except Exception as e: | |||
| raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e)) | |||
| def encrypt_credentials(self, credentials: dict) -> Optional[dict]: | |||
| """ | |||
| Encrypts the given credentials. | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| credentials['api_key'] = self.encrypt_token(credentials.get('api_key')) | |||
| return credentials | |||
| @@ -1,43 +0,0 @@ | |||
| from typing import Optional | |||
| from core.tool.provider.base import BaseToolProvider | |||
| from core.tool.provider.serpapi_provider import SerpAPIToolProvider | |||
| class ToolProviderService: | |||
| def __init__(self, tenant_id: str, provider_name: str): | |||
| self.provider = self._init_provider(tenant_id, provider_name) | |||
| def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider: | |||
| if provider_name == 'serpapi': | |||
| return SerpAPIToolProvider(tenant_id) | |||
| else: | |||
| raise Exception('tool provider {} not found'.format(provider_name)) | |||
| def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||
| """ | |||
| Returns the credentials for Tool as a dictionary. | |||
| :param obfuscated: | |||
| :return: | |||
| """ | |||
| return self.provider.get_credentials(obfuscated) | |||
| def credentials_validate(self, credentials: dict): | |||
| """ | |||
| Validates the given credentials. | |||
| :param credentials: | |||
| :raises: ValidateFailedError | |||
| """ | |||
| return self.provider.credentials_validate(credentials) | |||
| def encrypt_credentials(self, credentials: dict): | |||
| """ | |||
| Encrypts the given credentials. | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| return self.provider.encrypt_credentials(credentials) | |||
| @@ -1,51 +0,0 @@ | |||
| from langchain import SerpAPIWrapper | |||
| from pydantic import BaseModel, Field | |||
| class OptimizedSerpAPIInput(BaseModel): | |||
| query: str = Field(..., description="search query.") | |||
| class OptimizedSerpAPIWrapper(SerpAPIWrapper): | |||
| @staticmethod | |||
| def _process_response(res: dict, num_results: int = 5) -> str: | |||
| """Process response from SerpAPI.""" | |||
| if "error" in res.keys(): | |||
| raise ValueError(f"Got error from SerpAPI: {res['error']}") | |||
| if "answer_box" in res.keys() and type(res["answer_box"]) == list: | |||
| res["answer_box"] = res["answer_box"][0] | |||
| if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): | |||
| toret = res["answer_box"]["answer"] | |||
| elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): | |||
| toret = res["answer_box"]["snippet"] | |||
| elif ( | |||
| "answer_box" in res.keys() | |||
| and "snippet_highlighted_words" in res["answer_box"].keys() | |||
| ): | |||
| toret = res["answer_box"]["snippet_highlighted_words"][0] | |||
| elif ( | |||
| "sports_results" in res.keys() | |||
| and "game_spotlight" in res["sports_results"].keys() | |||
| ): | |||
| toret = res["sports_results"]["game_spotlight"] | |||
| elif ( | |||
| "shopping_results" in res.keys() | |||
| and "title" in res["shopping_results"][0].keys() | |||
| ): | |||
| toret = res["shopping_results"][:3] | |||
| elif ( | |||
| "knowledge_graph" in res.keys() | |||
| and "description" in res["knowledge_graph"].keys() | |||
| ): | |||
| toret = res["knowledge_graph"]["description"] | |||
| elif 'organic_results' in res.keys() and len(res['organic_results']) > 0: | |||
| toret = "" | |||
| for result in res["organic_results"][:num_results]: | |||
| if "link" in result: | |||
| toret += "----------------\nlink: " + result["link"] + "\n" | |||
| if "snippet" in result: | |||
| toret += "snippet: " + result["snippet"] + "\n" | |||
| else: | |||
| toret = "No good search result found" | |||
| return "search result:\n" + toret | |||
| @@ -1,443 +0,0 @@ | |||
| import hashlib | |||
| import json | |||
| import os | |||
| import re | |||
| import site | |||
| import subprocess | |||
| import tempfile | |||
| import unicodedata | |||
| from contextlib import contextmanager | |||
| from typing import Any | |||
| import requests | |||
| from bs4 import BeautifulSoup, CData, Comment, NavigableString | |||
| from langchain.chains import RefineDocumentsChain | |||
| from langchain.chains.summarize import refine_prompts | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||
| from langchain.tools.base import BaseTool | |||
| from newspaper import Article | |||
| from pydantic import BaseModel, Field | |||
| from regex import regex | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.rag.extractor import extract_processor | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from core.rag.models.document import Document | |||
| FULL_TEMPLATE = """ | |||
| TITLE: {title} | |||
| AUTHORS: {authors} | |||
| PUBLISH DATE: {publish_date} | |||
| TOP_IMAGE_URL: {top_image} | |||
| TEXT: | |||
| {text} | |||
| """ | |||
| class WebReaderToolInput(BaseModel): | |||
| url: str = Field(..., description="URL of the website to read") | |||
| summary: bool = Field( | |||
| default=False, | |||
| description="When the user's question requires extracting the summarizing content of the webpage, " | |||
| "set it to true." | |||
| ) | |||
| cursor: int = Field( | |||
| default=0, | |||
| description="Start reading from this character." | |||
| "Use when the first response was truncated" | |||
| "and you want to continue reading the page." | |||
| "The value cannot exceed 24000.", | |||
| ) | |||
| class WebReaderTool(BaseTool): | |||
| """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" | |||
| name: str = "web_reader" | |||
| args_schema: type[BaseModel] = WebReaderToolInput | |||
| description: str = "use this to read a website. " \ | |||
| "If you can answer the question based on the information provided, " \ | |||
| "there is no need to use." | |||
| page_contents: str = None | |||
| url: str = None | |||
| max_chunk_length: int = 4000 | |||
| summary_chunk_tokens: int = 4000 | |||
| summary_chunk_overlap: int = 0 | |||
| summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] | |||
| continue_reading: bool = True | |||
| model_config: ModelConfigEntity | |||
| model_parameters: dict[str, Any] | |||
| def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: | |||
| try: | |||
| if not self.page_contents or self.url != url: | |||
| page_contents = get_url(url) | |||
| self.page_contents = page_contents | |||
| self.url = url | |||
| else: | |||
| page_contents = self.page_contents | |||
| except Exception as e: | |||
| return f'Read this website failed, caused by: {str(e)}.' | |||
| if summary: | |||
| character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( | |||
| chunk_size=self.summary_chunk_tokens, | |||
| chunk_overlap=self.summary_chunk_overlap, | |||
| separators=self.summary_separators | |||
| ) | |||
| texts = character_splitter.split_text(page_contents) | |||
| docs = [Document(page_content=t) for t in texts] | |||
| if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): | |||
| return "No content found." | |||
| # only use first 5 docs | |||
| if len(docs) > 5: | |||
| docs = docs[:5] | |||
| chain = self.get_summary_chain() | |||
| try: | |||
| page_contents = chain.run(docs) | |||
| except Exception as e: | |||
| return f'Read this website failed, caused by: {str(e)}.' | |||
| else: | |||
| page_contents = page_result(page_contents, cursor, self.max_chunk_length) | |||
| if self.continue_reading and len(page_contents) >= self.max_chunk_length: | |||
| page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ | |||
| f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ | |||
| f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." | |||
| return page_contents | |||
| async def _arun(self, url: str) -> str: | |||
| raise NotImplementedError | |||
| def get_summary_chain(self) -> RefineDocumentsChain: | |||
| initial_chain = LLMChain( | |||
| model_config=self.model_config, | |||
| prompt=refine_prompts.PROMPT, | |||
| parameters=self.model_parameters | |||
| ) | |||
| refine_chain = LLMChain( | |||
| model_config=self.model_config, | |||
| prompt=refine_prompts.REFINE_PROMPT, | |||
| parameters=self.model_parameters | |||
| ) | |||
| return RefineDocumentsChain( | |||
| initial_llm_chain=initial_chain, | |||
| refine_llm_chain=refine_chain, | |||
| document_variable_name="text", | |||
| initial_response_name="existing_answer", | |||
| callbacks=self.callbacks | |||
| ) | |||
| def page_result(text: str, cursor: int, max_length: int) -> str: | |||
| """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" | |||
| return text[cursor: cursor + max_length] | |||
| def get_url(url: str) -> str: | |||
| """Fetch URL and return the contents as a string.""" | |||
| headers = { | |||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||
| } | |||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | |||
| if head_response.status_code != 200: | |||
| return "URL returned status code {}.".format(head_response.status_code) | |||
| # check content-type | |||
| main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip() | |||
| if main_content_type not in supported_content_types: | |||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | |||
| a = extract_using_readabilipy(response.text) | |||
| if not a['plain_text'] or not a['plain_text'].strip(): | |||
| return get_url_from_newspaper3k(url) | |||
| res = FULL_TEMPLATE.format( | |||
| title=a['title'], | |||
| authors=a['byline'], | |||
| publish_date=a['date'], | |||
| top_image="", | |||
| text=a['plain_text'] if a['plain_text'] else "", | |||
| ) | |||
| return res | |||
| def get_url_from_newspaper3k(url: str) -> str: | |||
| a = Article(url) | |||
| a.download() | |||
| a.parse() | |||
| res = FULL_TEMPLATE.format( | |||
| title=a.title, | |||
| authors=a.authors, | |||
| publish_date=a.publish_date, | |||
| top_image=a.top_image, | |||
| text=a.text, | |||
| ) | |||
| return res | |||
| def extract_using_readabilipy(html): | |||
| with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: | |||
| f_html.write(html) | |||
| f_html.close() | |||
| html_path = f_html.name | |||
| # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file | |||
| article_json_path = html_path + ".json" | |||
| jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') | |||
| with chdir(jsdir): | |||
| subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) | |||
| # Read output of call to Readability.parse() from JSON file and return as Python dictionary | |||
| with open(article_json_path, encoding="utf-8") as json_file: | |||
| input_json = json.loads(json_file.read()) | |||
| # Deleting files after processing | |||
| os.unlink(article_json_path) | |||
| os.unlink(html_path) | |||
| article_json = { | |||
| "title": None, | |||
| "byline": None, | |||
| "date": None, | |||
| "content": None, | |||
| "plain_content": None, | |||
| "plain_text": None | |||
| } | |||
| # Populate article fields from readability fields where present | |||
| if input_json: | |||
| if "title" in input_json and input_json["title"]: | |||
| article_json["title"] = input_json["title"] | |||
| if "byline" in input_json and input_json["byline"]: | |||
| article_json["byline"] = input_json["byline"] | |||
| if "date" in input_json and input_json["date"]: | |||
| article_json["date"] = input_json["date"] | |||
| if "content" in input_json and input_json["content"]: | |||
| article_json["content"] = input_json["content"] | |||
| article_json["plain_content"] = plain_content(article_json["content"], False, False) | |||
| article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) | |||
| if "textContent" in input_json and input_json["textContent"]: | |||
| article_json["plain_text"] = input_json["textContent"] | |||
| article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) | |||
| return article_json | |||
| def find_module_path(module_name): | |||
| for package_path in site.getsitepackages(): | |||
| potential_path = os.path.join(package_path, module_name) | |||
| if os.path.exists(potential_path): | |||
| return potential_path | |||
| return None | |||
| @contextmanager | |||
| def chdir(path): | |||
| """Change directory in context and return to original on exit""" | |||
| # From https://stackoverflow.com/a/37996581, couldn't find a built-in | |||
| original_path = os.getcwd() | |||
| os.chdir(path) | |||
| try: | |||
| yield | |||
| finally: | |||
| os.chdir(original_path) | |||
| def extract_text_blocks_as_plain_text(paragraph_html): | |||
| # Load article as DOM | |||
| soup = BeautifulSoup(paragraph_html, 'html.parser') | |||
| # Select all lists | |||
| list_elements = soup.find_all(['ul', 'ol']) | |||
| # Prefix text in all list items with "* " and make lists paragraphs | |||
| for list_element in list_elements: | |||
| plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) | |||
| list_element.string = plain_items | |||
| list_element.name = "p" | |||
| # Select all text blocks | |||
| text_blocks = [s.parent for s in soup.find_all(string=True)] | |||
| text_blocks = [plain_text_leaf_node(block) for block in text_blocks] | |||
| # Drop empty paragraphs | |||
| text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) | |||
| return text_blocks | |||
| def plain_text_leaf_node(element): | |||
| # Extract all text, stripped of any child HTML elements and normalise it | |||
| plain_text = normalise_text(element.get_text()) | |||
| if plain_text != "" and element.name == "li": | |||
| plain_text = "* {}, ".format(plain_text) | |||
| if plain_text == "": | |||
| plain_text = None | |||
| if "data-node-index" in element.attrs: | |||
| plain = {"node_index": element["data-node-index"], "text": plain_text} | |||
| else: | |||
| plain = {"text": plain_text} | |||
| return plain | |||
| def plain_content(readability_content, content_digests, node_indexes): | |||
| # Load article as DOM | |||
| soup = BeautifulSoup(readability_content, 'html.parser') | |||
| # Make all elements plain | |||
| elements = plain_elements(soup.contents, content_digests, node_indexes) | |||
| if node_indexes: | |||
| # Add node index attributes to nodes | |||
| elements = [add_node_indexes(element) for element in elements] | |||
| # Replace article contents with plain elements | |||
| soup.contents = elements | |||
| return str(soup) | |||
| def plain_elements(elements, content_digests, node_indexes): | |||
| # Get plain content versions of all elements | |||
| elements = [plain_element(element, content_digests, node_indexes) | |||
| for element in elements] | |||
| if content_digests: | |||
| # Add content digest attribute to nodes | |||
| elements = [add_content_digest(element) for element in elements] | |||
| return elements | |||
| def plain_element(element, content_digests, node_indexes): | |||
| # For lists, we make each item plain text | |||
| if is_leaf(element): | |||
| # For leaf node elements, extract the text content, discarding any HTML tags | |||
| # 1. Get element contents as text | |||
| plain_text = element.get_text() | |||
| # 2. Normalise the extracted text string to a canonical representation | |||
| plain_text = normalise_text(plain_text) | |||
| # 3. Update element content to be plain text | |||
| element.string = plain_text | |||
| elif is_text(element): | |||
| if is_non_printing(element): | |||
| # The simplified HTML may have come from Readability.js so might | |||
| # have non-printing text (e.g. Comment or CData). In this case, we | |||
| # keep the structure, but ensure that the string is empty. | |||
| element = type(element)("") | |||
| else: | |||
| plain_text = element.string | |||
| plain_text = normalise_text(plain_text) | |||
| element = type(element)(plain_text) | |||
| else: | |||
| # If not a leaf node or leaf type call recursively on child nodes, replacing | |||
| element.contents = plain_elements(element.contents, content_digests, node_indexes) | |||
| return element | |||
| def add_node_indexes(element, node_index="0"): | |||
| # Can't add attributes to string types | |||
| if is_text(element): | |||
| return element | |||
| # Add index to current element | |||
| element["data-node-index"] = node_index | |||
| # Add index to child elements | |||
| for local_idx, child in enumerate( | |||
| [c for c in element.contents if not is_text(c)], start=1): | |||
| # Can't add attributes to leaf string types | |||
| child_index = "{stem}.{local}".format( | |||
| stem=node_index, local=local_idx) | |||
| add_node_indexes(child, node_index=child_index) | |||
| return element | |||
| def normalise_text(text): | |||
| """Normalise unicode and whitespace.""" | |||
| # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them | |||
| text = strip_control_characters(text) | |||
| text = normalise_unicode(text) | |||
| text = normalise_whitespace(text) | |||
| return text | |||
| def strip_control_characters(text): | |||
| """Strip out unicode control characters which might break the parsing.""" | |||
| # Unicode control characters | |||
| # [Cc]: Other, Control [includes new lines] | |||
| # [Cf]: Other, Format | |||
| # [Cn]: Other, Not Assigned | |||
| # [Co]: Other, Private Use | |||
| # [Cs]: Other, Surrogate | |||
| control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) | |||
| retained_chars = ['\t', '\n', '\r', '\f'] | |||
| # Remove non-printing control characters | |||
| return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) | |||
| def normalise_unicode(text): | |||
| """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible.""" | |||
| normal_form = "NFKC" | |||
| text = unicodedata.normalize(normal_form, text) | |||
| return text | |||
| def normalise_whitespace(text): | |||
| """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" | |||
| text = regex.sub(r"\s+", " ", text) | |||
| # Remove leading and trailing whitespace | |||
| text = text.strip() | |||
| return text | |||
| def is_leaf(element): | |||
| return (element.name in ['p', 'li']) | |||
| def is_text(element): | |||
| return isinstance(element, NavigableString) | |||
| def is_non_printing(element): | |||
| return any(isinstance(element, _e) for _e in [Comment, CData]) | |||
| def add_content_digest(element): | |||
| if not is_text(element): | |||
| element["data-content-digest"] = content_digest(element) | |||
| return element | |||
| def content_digest(element): | |||
| if is_text(element): | |||
| # Hash | |||
| trimmed_string = element.string.strip() | |||
| if trimmed_string == "": | |||
| digest = "" | |||
| else: | |||
| digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() | |||
| else: | |||
| contents = element.contents | |||
| num_contents = len(contents) | |||
| if num_contents == 0: | |||
| # No hash when no child elements exist | |||
| digest = "" | |||
| elif num_contents == 1: | |||
| # If single child, use digest of child | |||
| digest = content_digest(contents[0]) | |||
| else: | |||
| # Build content digest from the "non-empty" digests of child nodes | |||
| digest = hashlib.sha256() | |||
| child_digests = list( | |||
| filter(lambda x: x != "", [content_digest(content) for content in contents])) | |||
| for child in child_digests: | |||
| digest.update(child.encode('utf-8')) | |||
| digest = digest.hexdigest() | |||
| return digest | |||
| @@ -4,7 +4,7 @@ from langchain.tools import BaseTool | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom | |||
| from core.features.dataset_retrieval import DatasetRetrievalFeature | |||
| from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter | |||
| from core.tools.tool.tool import Tool | |||
| @@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool): | |||
| @staticmethod | |||
| def get_dataset_tools(tenant_id: str, | |||
| dataset_ids: list[str], | |||
| retrieve_config: DatasetRetrieveConfigEntity, | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler | |||
| ) -> list['DatasetRetrieverTool']: | |||
| dataset_ids: list[str], | |||
| retrieve_config: DatasetRetrieveConfigEntity, | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler | |||
| ) -> list['DatasetRetrieverTool']: | |||
| """ | |||
| get dataset tool | |||
| """ | |||
| @@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool): | |||
| ) | |||
| # restore retrieve strategy | |||
| retrieve_config.retrieve_strategy = original_retriever_mode | |||
| # convert langchain tools to Tools | |||
| tools = [] | |||
| for langchain_tool in langchain_tools: | |||
| @@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool): | |||
| llm=langchain_tool.description), | |||
| runtime=DatasetRetrieverTool.Runtime() | |||
| ) | |||
| tools.append(tool) | |||
| return tools | |||
| @@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool): | |||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||
| return [ | |||
| ToolParameter(name='query', | |||
| label=I18nObject(en_US='', zh_Hans=''), | |||
| human_description=I18nObject(en_US='', zh_Hans=''), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description='Query for the dataset to be used to retrieve the dataset.', | |||
| required=True, | |||
| default=''), | |||
| label=I18nObject(en_US='', zh_Hans=''), | |||
| human_description=I18nObject(en_US='', zh_Hans=''), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description='Query for the dataset to be used to retrieve the dataset.', | |||
| required=True, | |||
| default=''), | |||
| ] | |||
| def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: | |||
| @@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool): | |||
| query = tool_parameters.get('query', None) | |||
| if not query: | |||
| return self.create_text_message(text='please input query') | |||
| # invoke dataset retriever tool | |||
| result = self.langchain_tool._run(query=query) | |||
| @@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool): | |||
| """ | |||
| validate the credentials for dataset retriever tool | |||
| """ | |||
| pass | |||
| pass | |||
| @@ -7,23 +7,14 @@ import subprocess | |||
| import tempfile | |||
| import unicodedata | |||
| from contextlib import contextmanager | |||
| from typing import Any | |||
| import requests | |||
| from bs4 import BeautifulSoup, CData, Comment, NavigableString | |||
| from langchain.chains import RefineDocumentsChain | |||
| from langchain.chains.summarize import refine_prompts | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||
| from langchain.tools.base import BaseTool | |||
| from newspaper import Article | |||
| from pydantic import BaseModel, Field | |||
| from regex import regex | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.rag.extractor import extract_processor | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from core.rag.models.document import Document | |||
| FULL_TEMPLATE = """ | |||
| TITLE: {title} | |||
| @@ -36,106 +27,6 @@ TEXT: | |||
| """ | |||
| class WebReaderToolInput(BaseModel): | |||
| url: str = Field(..., description="URL of the website to read") | |||
| summary: bool = Field( | |||
| default=False, | |||
| description="When the user's question requires extracting the summarizing content of the webpage, " | |||
| "set it to true." | |||
| ) | |||
| cursor: int = Field( | |||
| default=0, | |||
| description="Start reading from this character." | |||
| "Use when the first response was truncated" | |||
| "and you want to continue reading the page." | |||
| "The value cannot exceed 24000.", | |||
| ) | |||
| class WebReaderTool(BaseTool): | |||
| """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" | |||
| name: str = "web_reader" | |||
| args_schema: type[BaseModel] = WebReaderToolInput | |||
| description: str = "use this to read a website. " \ | |||
| "If you can answer the question based on the information provided, " \ | |||
| "there is no need to use." | |||
| page_contents: str = None | |||
| url: str = None | |||
| max_chunk_length: int = 4000 | |||
| summary_chunk_tokens: int = 4000 | |||
| summary_chunk_overlap: int = 0 | |||
| summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] | |||
| continue_reading: bool = True | |||
| model_config: ModelConfigEntity | |||
| model_parameters: dict[str, Any] | |||
| def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: | |||
| try: | |||
| if not self.page_contents or self.url != url: | |||
| page_contents = get_url(url) | |||
| self.page_contents = page_contents | |||
| self.url = url | |||
| else: | |||
| page_contents = self.page_contents | |||
| except Exception as e: | |||
| return f'Read this website failed, caused by: {str(e)}.' | |||
| if summary: | |||
| character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( | |||
| chunk_size=self.summary_chunk_tokens, | |||
| chunk_overlap=self.summary_chunk_overlap, | |||
| separators=self.summary_separators | |||
| ) | |||
| texts = character_splitter.split_text(page_contents) | |||
| docs = [Document(page_content=t) for t in texts] | |||
| if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): | |||
| return "No content found." | |||
| # only use first 5 docs | |||
| if len(docs) > 5: | |||
| docs = docs[:5] | |||
| chain = self.get_summary_chain() | |||
| try: | |||
| page_contents = chain.run(docs) | |||
| except Exception as e: | |||
| return f'Read this website failed, caused by: {str(e)}.' | |||
| else: | |||
| page_contents = page_result(page_contents, cursor, self.max_chunk_length) | |||
| if self.continue_reading and len(page_contents) >= self.max_chunk_length: | |||
| page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ | |||
| f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ | |||
| f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." | |||
| return page_contents | |||
| async def _arun(self, url: str) -> str: | |||
| raise NotImplementedError | |||
| def get_summary_chain(self) -> RefineDocumentsChain: | |||
| initial_chain = LLMChain( | |||
| model_config=self.model_config, | |||
| prompt=refine_prompts.PROMPT, | |||
| parameters=self.model_parameters | |||
| ) | |||
| refine_chain = LLMChain( | |||
| model_config=self.model_config, | |||
| prompt=refine_prompts.REFINE_PROMPT, | |||
| parameters=self.model_parameters | |||
| ) | |||
| return RefineDocumentsChain( | |||
| initial_llm_chain=initial_chain, | |||
| refine_llm_chain=refine_chain, | |||
| document_variable_name="text", | |||
| initial_response_name="existing_answer", | |||
| callbacks=self.callbacks | |||
| ) | |||
| def page_result(text: str, cursor: int, max_length: int) -> str: | |||
| """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" | |||
| return text[cursor: cursor + max_length] | |||
| @@ -1,7 +1,7 @@ | |||
| import re | |||
| import uuid | |||
| from core.agent.agent_executor import PlanningStrategy | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||