@@ -2,14 +2,18 @@ import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool

@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values
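
    # The no-op validator above is assumed to exist to bypass the parent
    # OpenAIFunctionsAgent check that `llm` must be a ChatOpenAI instance;
    # actual completions are routed through `model_instance` instead.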
    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used to handle the given query.
@@ -65,7 +73,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
            return AgentFinish(return_values={"output": observation}, log=observation)

        try:
            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
@@ -76,6 +84,44 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception
    def real_plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )
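
        # A hedged illustration (names invented): a result carrying
        # function_call={"name": "dataset-123", "arguments": '{"query": "..."}'}
        # parses to an AgentAction with tool="dataset-123" and
        # tool_input={"query": "..."}; without a function_call,
        # _parse_ai_message returns an AgentFinish wrapping the content.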
        agent_decision = _parse_ai_message(ai_message)
        return agent_decision
    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +133,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +142,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
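
    # Assumption based on the code above: FakeLLM(response='') merely
    # satisfies the inherited pydantic `llm` field; real completions go
    # through `model_instance.run(...)` in real_plan.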
@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
    _format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
    get_buffer_string
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_instance: BaseLLM = None
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 40
        original_max_tokens = self.model_instance.model_kwargs.max_tokens
        self.model_instance.model_kwargs.max_tokens = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            predicted_message = self.llm.predict_messages(
                messages, functions=self.functions, callbacks=None
            prompt_messages = to_prompt_messages(messages)
            result = self.model_instance.run(
                messages=prompt_messages,
                functions=self.functions,
                callbacks=None
            )
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

        function_call = predicted_message.additional_kwargs.get("function_call", {})
        function_call = result.function_call

        self.llm.max_tokens = original_max_tokens
        self.model_instance.model_kwargs.max_tokens = original_max_tokens

        return True if function_call else False
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )

        agent_decision = _parse_ai_message(predicted_message)
        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages
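
        # Over budget: keep the system and human messages plus the latest
        # AI/function message pair verbatim, and fold the older history into
        # the moving summary buffer below.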
        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
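
    # SUMMARY_PROMPT is LangChain's progressive-summarization prompt; it takes
    # the running `summary` plus the `new_lines` to fold in and returns an
    # updated summary.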
    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_instance.model_provider.provider_name == 'azure_openai':
            model = model_instance.base_model_name
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_instance.base_model_name

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
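
    # Rough sanity check (hedged; the exact counts are assumptions about the
    # cl100k_base encoding): a lone user message "hi" on gpt-3.5-turbo would
    # count 4 (per message) + 1 ("user") + 1 ("hi") + 3 (reply priming)
    # = 9 tokens.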
@@ -1,140 +0,0 @@
from typing import cast, List

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM


class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel = None
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        llm = cast(ChatOpenAI, model_instance.client)
        model, encoding = llm._get_encoding_model()
        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
@@ -1,107 +0,0 @@
from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
    _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseMultiActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent
        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 15

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            predicted_message = self.llm.predict_messages(
                messages, functions=self.functions, callbacks=None
            )
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        # get current time
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")
@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -49,7 +49,6 @@ Action:
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
    model_instance: BaseLLM
    dataset_tools: Sequence[BaseTool]

    class Config:
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
            raise new_exception

        try:
@@ -145,7 +144,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +156,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            output_parser=output_parser,
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            model_instance=model_instance,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            dataset_tools=tools,
            **kwargs,
        )
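
    # The method above mirrors StructuredChatAgent.from_llm_and_tools
    # upstream, swapping in the custom core.chain.llm_chain.LLMChain so that
    # calls go through `model_instance` rather than a LangChain LLM object.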
@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
    get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM


FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel = None
    model_instance: BaseLLM
    summary_model_instance: BaseLLM = None

    class Config:
        """Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
        if prompts:
            messages = prompts[0].to_messages()

        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
            raise new_exception

        try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_llm:
        if len(intermediate_steps) >= 2 and self.summary_model_instance:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = summary_handler.predict_new_summary(
        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
        cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            output_parser=output_parser,
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            model_instance=model_instance,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    MULTI_FUNCTION_CALL = 'multi_function_call'


class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_llm=self.configuration.summary_model_instance.client
                summary_model_instance=self.configuration.summary_model_instance
                if self.configuration.summary_model_instance else None,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history memory
                summary_llm=self.configuration.summary_model_instance.client
                if self.configuration.summary_model_instance else None,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history memory
                summary_llm=self.configuration.summary_model_instance.client
                summary_model_instance=self.configuration.summary_model_instance
                if self.configuration.summary_model_instance else None,
                verbose=True
            )
@@ -95,7 +81,6 @@ class AgentExecutor:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                llm=self.configuration.model_instance.client,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                verbose=True
@@ -0,0 +1,36 @@
from typing import List, Dict, Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
    model_instance: BaseLLM
    """The language model instance to use."""
    llm: BaseLanguageModel = FakeLLM(response="")

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        messages = prompts[0].to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            stop=stop
        )

        generations = [
            [Generation(text=result.content)]
        ]

        return LLMResult(generations=generations)
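
# A minimal usage sketch (hedged): `my_model_instance` is assumed to be a
# configured BaseLLM, and PromptTemplate follows the langchain 0.0.x API.
#
#     from langchain import PromptTemplate
#
#     chain = LLMChain(
#         model_instance=my_model_instance,
#         prompt=PromptTemplate.from_template("Summarize: {text}"),
#     )
#     summary = chain.predict(text="...")
#
# Note that generate() above only consumes prompts[0], so this chain
# effectively handles one input dict per call.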
@@ -1,6 +1,6 @@
import enum

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    source: list = None
    function_call: dict = None


class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
class PromptMessage(BaseModel):
    type: MessageType = MessageType.HUMAN
    content: str = ''
    function_call: dict = None


def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
            additional_kwargs = {}
            if message.function_call:
                additional_kwargs['function_call'] = message.function_call
            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
        elif message.type == MessageType.SYSTEM:
            lc_messages.append(SystemMessage(content=message.content))
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
        if isinstance(message, HumanMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
        elif isinstance(message, AIMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
            message_kwargs = {
                'content': message.content,
                'type': MessageType.ASSISTANT
            }

            if 'function_call' in message.additional_kwargs:
                message_kwargs['function_call'] = message.additional_kwargs['function_call']

            prompt_messages.append(PromptMessage(**message_kwargs))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
        elif isinstance(message, FunctionMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))

    return prompt_messages
@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

        generate_kwargs = {
            'stop': stop,
            'callbacks': callbacks
        }

        if isinstance(prompts, str):
            generate_kwargs['prompts'] = [prompts]
        else:
            generate_kwargs['messages'] = [prompts]

        if 'functions' in kwargs:
            generate_kwargs['functions'] = kwargs['functions']

        return self._client.generate(**generate_kwargs)
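
    # Assumed dispatch semantics (langchain 0.0.x): a completion client's
    # generate() takes `prompts` (a list of strings) while a chat client's
    # takes `messages` (a list of message lists); extra kwargs such as
    # `functions` are forwarded to the underlying OpenAI chat completion call.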
    @property
    def base_model_name(self) -> str:
@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
    to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
        except Exception as ex:
            raise self.handle_exceptions(ex)

        function_call = None
        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
            if 'function_call' in result.generations[0][0].message.additional_kwargs:
                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
        else:
            completion_content = result.generations[0][0].text
@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
            completion_tokens=completion_tokens,
            function_call=function_call
        )

    @abstractmethod
@@ -442,16 +447,7 @@ class BaseLLM(BaseProviderModel):
        if len(messages) == 0:
            return []

        chat_messages = []
        for message in messages:
            if message.type == MessageType.HUMAN:
                chat_messages.append(HumanMessage(content=message.content))
            elif message.type == MessageType.ASSISTANT:
                chat_messages.append(AIMessage(content=message.content))
            elif message.type == MessageType.SYSTEM:
                chat_messages.append(SystemMessage(content=message.content))

        return chat_messages
        return to_lc_messages(messages)

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

        generate_kwargs = {
            'stop': stop,
            'callbacks': callbacks
        }

        if isinstance(prompts, str):
            generate_kwargs['prompts'] = [prompts]
        else:
            generate_kwargs['messages'] = [prompts]

        if 'functions' in kwargs:
            generate_kwargs['functions'] = kwargs['functions']

        return self._client.generate(**generate_kwargs)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
@@ -1,7 +1,6 @@ import math
from typing import Optional

from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +26,6 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType


class OrchestratorRuleParser:
@@ -77,7 +75,7 @@ class OrchestratorRuleParser:
        # only OpenAI chat models (including Azure) support function calls; use ReACT instead
        if agent_model_instance.model_mode != ModelMode.CHAT \
                or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
            if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
            if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                planning_strategy = PlanningStrategy.REACT
            elif planning_strategy == PlanningStrategy.ROUTER:
                planning_strategy = PlanningStrategy.REACT_ROUTER
@@ -207,7 +205,10 @@ class OrchestratorRuleParser:
                tool = self.to_current_datetime_tool()

            if tool:
                tool.callbacks.extend(callbacks)
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks
                tools.append(tool)

        return tools
@@ -269,10 +270,9 @@ class OrchestratorRuleParser:
            summary_model_instance = None

        tool = WebReaderTool(
            llm=summary_model_instance.client if summary_model_instance else None,
            model_instance=summary_model_instance if summary_model_instance else None,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
            continue_reading=True
        )

        return tool
@@ -290,16 +290,13 @@ class OrchestratorRuleParser:
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self) -> Optional[BaseTool]:
        tool = DatetimeTool(
            callbacks=[DifyStdOutCallbackHandler()]
        )
        tool = DatetimeTool()

        return tool
@@ -310,8 +307,7 @@ class OrchestratorRuleParser:
        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
            args_schema=WikipediaInput
        )

    @classmethod
@@ -11,8 +11,8 @@ from typing import Type
import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
@@ -20,8 +20,10 @@ from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.model_providers.models.llm.base import BaseLLM

FULL_TEMPLATE = """
TITLE: {title}
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    llm: BaseLanguageModel = None
    model_instance: BaseLLM = None

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'

        if summary and self.llm:
        if summary and self.model_instance:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
            if len(docs) > 5:
                docs = docs[:5]

            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
                # todo use cache
            except Exception as e:
                return f'Read this website failed, caused by: {str(e)}.'
        else:
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_instance=self.model_instance,
            prompt=refine_prompts.PROMPT
        )
        refine_chain = LLMChain(
            model_instance=self.model_instance,
            prompt=refine_prompts.REFINE_PROMPT
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )
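
    # Behavior of the chain above (per langchain's RefineDocumentsChain):
    # the initial prompt summarizes the first document, then the refine
    # prompt folds each subsequent document into the running
    # `existing_answer`, so long pages are summarized incrementally.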

def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""