|
|
|
@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask |
|
|
|
from core.model_providers.error import ProviderTokenNotInitError |
|
|
|
from core.model_providers.model_factory import ModelFactory |
|
|
|
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode |
|
|
|
from core.model_providers.models.llm.base import BaseLLM |
|
|
|
from core.tool.current_datetime_tool import DatetimeTool |
|
|
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool |
|
|
|
from core.tool.provider.serpapi_provider import SerpAPIToolProvider |
|
|
|
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput |
|
|
|
from core.tool.web_reader_tool import WebReaderTool |
|
|
|
from extensions.ext_database import db |
|
|
|
from libs import helper |
|
|
|
from models.dataset import Dataset, DatasetProcessRule |
|
|
|
from models.model import AppModelConfig |
|
|
|
|
|
|
|
@@ -82,15 +83,19 @@ class OrchestratorRuleParser: |
|
|
|
try: |
|
|
|
summary_model_instance = ModelFactory.get_text_generation_model( |
|
|
|
tenant_id=self.tenant_id, |
|
|
|
model_provider_name=agent_provider_name, |
|
|
|
model_name=agent_model_name, |
|
|
|
model_kwargs=ModelKwargs( |
|
|
|
temperature=0, |
|
|
|
max_tokens=500 |
|
|
|
) |
|
|
|
), |
|
|
|
deduct_quota=False |
|
|
|
) |
|
|
|
except ProviderTokenNotInitError as e: |
|
|
|
summary_model_instance = None |
|
|
|
|
|
|
|
tools = self.to_tools( |
|
|
|
agent_model_instance=agent_model_instance, |
|
|
|
tool_configs=tool_configs, |
|
|
|
conversation_message_task=conversation_message_task, |
|
|
|
rest_tokens=rest_tokens, |
|
|
|
@@ -140,11 +145,12 @@ class OrchestratorRuleParser: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask, |
|
|
|
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask, |
|
|
|
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: |
|
|
|
""" |
|
|
|
Convert app agent tool configs to tools |
|
|
|
|
|
|
|
:param agent_model_instance: |
|
|
|
:param rest_tokens: |
|
|
|
:param tool_configs: app agent tool configs |
|
|
|
:param conversation_message_task: |
|
|
|
@@ -162,7 +168,7 @@ class OrchestratorRuleParser: |
|
|
|
if tool_type == "dataset": |
|
|
|
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens) |
|
|
|
elif tool_type == "web_reader": |
|
|
|
tool = self.to_web_reader_tool() |
|
|
|
tool = self.to_web_reader_tool(agent_model_instance) |
|
|
|
elif tool_type == "google_search": |
|
|
|
tool = self.to_google_search_tool() |
|
|
|
elif tool_type == "wikipedia": |
|
|
|
@@ -207,24 +213,28 @@ class OrchestratorRuleParser: |
|
|
|
|
|
|
|
return tool |
|
|
|
|
|
|
|
def to_web_reader_tool(self) -> Optional[BaseTool]: |
|
|
|
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]: |
|
|
|
""" |
|
|
|
A tool for reading web pages |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
summary_model_instance = ModelFactory.get_text_generation_model( |
|
|
|
tenant_id=self.tenant_id, |
|
|
|
model_kwargs=ModelKwargs( |
|
|
|
temperature=0, |
|
|
|
max_tokens=500 |
|
|
|
try: |
|
|
|
summary_model_instance = ModelFactory.get_text_generation_model( |
|
|
|
tenant_id=self.tenant_id, |
|
|
|
model_provider_name=agent_model_instance.model_provider.provider_name, |
|
|
|
model_name=agent_model_instance.name, |
|
|
|
model_kwargs=ModelKwargs( |
|
|
|
temperature=0, |
|
|
|
max_tokens=500 |
|
|
|
), |
|
|
|
deduct_quota=False |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
summary_llm = summary_model_instance.client |
|
|
|
except ProviderTokenNotInitError: |
|
|
|
summary_model_instance = None |
|
|
|
|
|
|
|
tool = WebReaderTool( |
|
|
|
llm=summary_llm, |
|
|
|
llm=summary_model_instance.client if summary_model_instance else None, |
|
|
|
max_chunk_length=4000, |
|
|
|
continue_reading=True, |
|
|
|
callbacks=[DifyStdOutCallbackHandler()] |
|
|
|
@@ -252,11 +262,7 @@ class OrchestratorRuleParser: |
|
|
|
return tool |
|
|
|
|
|
|
|
def to_current_datetime_tool(self) -> Optional[BaseTool]: |
|
|
|
tool = Tool( |
|
|
|
name="current_datetime", |
|
|
|
description="A tool when you want to get the current date, time, week, month or year, " |
|
|
|
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".", |
|
|
|
func=helper.get_current_datetime, |
|
|
|
tool = DatetimeTool( |
|
|
|
callbacks=[DifyStdOutCallbackHandler()] |
|
|
|
) |
|
|
|
|