from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
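# Illustrative sketch (not part of the diff): _parse_ai_message yields an AgentAction when
# the model returned a function_call, otherwise an AgentFinish. The `agent` and `input`
# names below are assumptions for the example.
decision = agent.real_plan(intermediate_steps=[], input="Which dataset answers pricing questions?")
if isinstance(decision, AgentAction):
    print(decision.tool, decision.tool_input)    # selected dataset tool and its query input
else:
    print(decision.return_values["output"])      # direct answer, no tool selected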
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
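# Illustrative usage sketch (not part of the diff): construction now takes the Dify
# model_instance rather than a LangChain llm; FakeLLM only satisfies the parent class,
# real calls go through model_instance.run(). The variable names below are assumptions.
router_agent = MultiDatasetRouterAgent.from_llm_and_tools(
    model_instance=agent_model_instance,   # core.model_providers BaseLLM wrapper
    tools=dataset_tools,                   # DatasetRetrieverTool instances
    extra_prompt_messages=None,
    verbose=True,
)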
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 40
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
function_call = result.function_call
self.llm.max_tokens = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens
return True if function_call else False
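# Illustrative sketch (not part of the diff): should_use_agent temporarily caps
# max_tokens because only the presence of a function_call matters, not the full answer.
if agent.should_use_agent("What is our refund policy?"):
    ...  # run the full function-calling agent loop with tools
else:
    ...  # answer directly with a plain completion, no tools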
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(predicted_message)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
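# Illustrative sketch (not part of the diff): rough token accounting for a
# function-calling prompt; the message list and function schema below are assumptions.
probe_messages = [
    SystemMessage(content="You are a helpful AI assistant."),
    HumanMessage(content="What's the weather in Berlin?"),
]
probe_functions = [{
    "name": "get_weather",
    "description": "Look up current weather for a city",
    "parameters": {"type": "object",
                   "properties": {"city": {"type": "string"}},
                   "required": ["city"]},
}]
used_tokens = agent.get_num_tokens_from_messages(model_instance, probe_messages, functions=probe_functions)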
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, model_instance.client)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)
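# Illustrative usage sketch (not part of the diff), mirroring the executor code further
# below: the router is now built around a Dify LLMChain instead of a raw LangChain llm.
react_router_agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
    model_instance=agent_model_instance,    # assumed BaseLLM wrapper
    tools=dataset_tools,                    # DatasetRetrieverTool instances only
    output_parser=StructuredChatOutputParser(),
    verbose=True,
)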
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
summary_model_instance: BaseLLM = None
class Config:
"""Configuration for this pydantic object."""
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_llm:
if len(intermediate_steps) >= 2 and self.summary_model_instance:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
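# Illustrative sketch (not part of the diff): the rolling summary keeps LangChain's
# SUMMARY_PROMPT (variables `summary` and `new_lines`) but generates through the Dify
# LLMChain, so the summarizer is a summary_model_instance rather than a raw LangChain LLM.
first_summary = agent.predict_new_summary(
    messages=[AIMessage(content="Observation: 3 matching documents found ...")],
    existing_summary="",   # the first round starts from an empty summary
)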
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
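# Illustrative sketch (not part of the diff): callers now hand the executor model
# instances directly; the `llm=...client` and `summary_llm=...client` keyword arguments
# above are the removed style. Only fields referenced by this module are shown here,
# and other required AgentConfiguration fields may exist.
config = AgentConfiguration(
    strategy=PlanningStrategy.FUNCTION_CALL,
    model_instance=agent_model_instance,
    summary_model_instance=summary_model_instance,
    tools=tools,
    memory=None,
)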
from typing import List, Dict, Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
model_instance: BaseLLM
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
stop=stop
)
generations = [
[Generation(text=result.content)]
]
return LLMResult(generations=generations)
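# Illustrative usage sketch (not part of the diff): this LLMChain keeps the LangChain
# chain interface (prompt in, text out) while generation is delegated to
# model_instance.run(), as in the summarizers above.
summary_chain = LLMChain(
    model_instance=summary_model_instance,   # assumed BaseLLM instance
    prompt=SUMMARY_PROMPT,                   # from langchain.memory.prompt
)
new_summary = summary_chain.predict(summary="", new_lines="Human: hi\nAI: hello")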
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None
class MessageType(enum.Enum):
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
content: str = ''
function_call: dict = None
def to_lc_messages(messages: list[PromptMessage]):
if message.type == MessageType.HUMAN:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}
if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']
prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
return prompt_messages
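# Illustrative sketch (not part of the diff): function_call now survives the round trip
# between LangChain messages and Dify PromptMessages in both directions.
lc_message = AIMessage(
    content="",
    additional_kwargs={"function_call": {"name": "dataset", "arguments": '{"query": "pricing"}'}},
)
converted = to_prompt_messages([lc_message])
assert converted[0].function_call["name"] == "dataset"
assert to_lc_messages(converted)[0].additional_kwargs["function_call"]["name"] == "dataset"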
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
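# Illustrative sketch (not part of the diff): completion models receive prompts=[...],
# chat models receive messages=[...], and a functions list is forwarded only when the
# caller supplied one, e.g. from the function-calling agents above:
result = model_instance.run(
    messages=to_prompt_messages(messages),
    functions=[{"name": "dataset", "description": "Query a dataset",
                "parameters": {"type": "object",
                               "properties": {"query": {"type": "string"}}}}],
)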
@property
def base_model_name(self) -> str:
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
except Exception as ex:
raise self.handle_exceptions(ex)
function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
if 'function_call' in result.generations[0][0].message.additional_kwargs:
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
else:
completion_content = result.generations[0][0].text
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
completion_tokens=completion_tokens,
function_call=function_call
)
@abstractmethod
if len(messages) == 0:
return []
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
return to_lc_messages(messages)
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
# only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
tool = self.to_current_datetime_tool()
if tool:
tool.callbacks.extend(callbacks)
if tool.callbacks is not None:
tool.callbacks.extend(callbacks)
else:
tool.callbacks = callbacks
tools.append(tool)
return tools
summary_model_instance = None
tool = WebReaderTool(
llm=summary_model_instance.client if summary_model_instance else None,
model_instance=summary_model_instance if summary_model_instance else None,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
continue_reading=True
)
return tool
"is not up to date. "
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
args_schema=OptimizedSerpAPIInput,
callbacks=[DifyStdOutCallbackHandler()]
args_schema=OptimizedSerpAPIInput
)
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
tool = DatetimeTool(
callbacks=[DifyStdOutCallbackHandler()]
)
tool = DatetimeTool()
return tool
return WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput,
callbacks=[DifyStdOutCallbackHandler()]
args_schema=WikipediaInput
)
@classmethod
import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.model_providers.models.llm.base import BaseLLM
FULL_TEMPLATE = """
TITLE: {title}
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
llm: BaseLanguageModel = None
model_instance: BaseLLM = None
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary and self.llm:
if summary and self.model_instance:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
if len(docs) > 5:
docs = docs[:5]
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
# todo use cache
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_instance=self.model_instance,
prompt=refine_prompts.PROMPT
)
refine_chain = LLMChain(
model_instance=self.model_instance,
prompt=refine_prompts.REFINE_PROMPT
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
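# Illustrative usage sketch (not part of the diff): the refine-style summary is now built
# from two Dify LLMChains instead of load_summarize_chain(self.llm, ...), so the web
# reader also summarizes through a model_instance. Invocation shape below is an assumption.
reader = WebReaderTool(model_instance=summary_model_instance, continue_reading=True)
print(reader.run({"url": "https://example.com/post", "summary": True}))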
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""