| @@ -32,7 +32,6 @@ from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolInvokeMessage, | |||
| ToolParameter, | |||
| ToolRuntimeVariablePool, | |||
| ) | |||
| @@ -141,24 +140,6 @@ class BaseAgentRunner(AppRunner): | |||
| app_generate_entity.app_config.prompt_template.simple_prompt_template = '' | |||
| return app_generate_entity | |||
| def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: | |||
| """ | |||
| Handle tool response | |||
| """ | |||
| result = '' | |||
| for response in tool_response: | |||
| if response.type == ToolInvokeMessage.MessageType.TEXT: | |||
| result += response.message | |||
| elif response.type == ToolInvokeMessage.MessageType.LINK: | |||
| result += f"result link: {response.message}. please tell user to check it." | |||
| elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ | |||
| response.type == ToolInvokeMessage.MessageType.IMAGE: | |||
| result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." | |||
| else: | |||
| result += f"tool response: {response.message}." | |||
| return result | |||
| def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: | |||
| """ | |||
| @@ -95,6 +95,7 @@ class ToolInvokeMessage(BaseModel): | |||
| IMAGE = "image" | |||
| LINK = "link" | |||
| BLOB = "blob" | |||
| JSON = "json" | |||
| IMAGE_LINK = "image_link" | |||
| FILE_VAR = "file_var" | |||
| @@ -102,7 +103,7 @@ class ToolInvokeMessage(BaseModel): | |||
| """ | |||
| plain text, image url or link url | |||
| """ | |||
| message: Union[str, bytes] = None | |||
| message: Union[str, bytes, dict] = None | |||
| meta: dict[str, Any] = None | |||
| save_as: str = '' | |||
| @@ -8,99 +8,36 @@ from core.tools.tool.builtin_tool import BuiltinTool | |||
| SERP_API_URL = "https://serpapi.com/search" | |||
| class SerpAPI: | |||
| """ | |||
| SerpAPI tool provider. | |||
| """ | |||
| def __init__(self, api_key: str) -> None: | |||
| """Initialize SerpAPI tool provider.""" | |||
| self.serpapi_api_key = api_key | |||
| def run(self, query: str, **kwargs: Any) -> str: | |||
| """Run query through SerpAPI and parse result.""" | |||
| typ = kwargs.get("result_type", "text") | |||
| return self._process_response(self.results(query), typ=typ) | |||
| def results(self, query: str) -> dict: | |||
| """Run query through SerpAPI and return the raw result.""" | |||
| params = self.get_params(query) | |||
| response = requests.get(url=SERP_API_URL, params=params) | |||
| response.raise_for_status() | |||
| return response.json() | |||
| class GoogleSearchTool(BuiltinTool): | |||
| def get_params(self, query: str) -> dict[str, str]: | |||
| """Get parameters for SerpAPI.""" | |||
| def _parse_response(self, response: dict) -> dict: | |||
| result = {} | |||
| if "knowledge_graph" in response: | |||
| result["title"] = response["knowledge_graph"].get("title", "") | |||
| result["description"] = response["knowledge_graph"].get("description", "") | |||
| if "organic_results" in response: | |||
| result["organic_results"] = [ | |||
| { | |||
| "title": item.get("title", ""), | |||
| "link": item.get("link", ""), | |||
| "snippet": item.get("snippet", "") | |||
| } | |||
| for item in response["organic_results"] | |||
| ] | |||
| return result | |||
| def _invoke(self, | |||
| user_id: str, | |||
| tool_parameters: dict[str, Any], | |||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||
| params = { | |||
| "api_key": self.serpapi_api_key, | |||
| "q": query, | |||
| "api_key": self.runtime.credentials['serpapi_api_key'], | |||
| "q": tool_parameters['query'], | |||
| "engine": "google", | |||
| "google_domain": "google.com", | |||
| "gl": "us", | |||
| "hl": "en" | |||
| } | |||
| return params | |||
| @staticmethod | |||
| def _process_response(res: dict, typ: str) -> str: | |||
| """ | |||
| Process response from SerpAPI. | |||
| SerpAPI doc: https://serpapi.com/search-api | |||
| Google search main results are called organic results | |||
| """ | |||
| if "error" in res: | |||
| raise ValueError(f"Got error from SerpAPI: {res['error']}") | |||
| toret = "" | |||
| if typ == "text": | |||
| if "knowledge_graph" in res and "description" in res["knowledge_graph"]: | |||
| toret += res["knowledge_graph"]["description"] + "\n" | |||
| if "organic_results" in res: | |||
| snippets = [ | |||
| f"content: {item.get('snippet')}\nlink: {item.get('link')}" | |||
| for item in res["organic_results"] | |||
| if "snippet" in item | |||
| ] | |||
| toret += "\n".join(snippets) | |||
| elif typ == "link": | |||
| if "knowledge_graph" in res and "source" in res["knowledge_graph"]: | |||
| toret += res["knowledge_graph"]["source"]["link"] | |||
| elif "organic_results" in res: | |||
| links = [ | |||
| f"[{item['title']}]({item['link']})\n" | |||
| for item in res["organic_results"] | |||
| if "title" in item and "link" in item | |||
| ] | |||
| toret += "\n".join(links) | |||
| elif "related_questions" in res: | |||
| questions = [ | |||
| f"[{item['question']}]({item['link']})\n" | |||
| for item in res["related_questions"] | |||
| if "question" in item and "link" in item | |||
| ] | |||
| toret += "\n".join(questions) | |||
| elif "related_searches" in res: | |||
| searches = [ | |||
| f"[{item['query']}]({item['link']})\n" | |||
| for item in res["related_searches"] | |||
| if "query" in item and "link" in item | |||
| ] | |||
| toret += "\n".join(searches) | |||
| if not toret: | |||
| toret = "No good search result found" | |||
| return toret | |||
| class GoogleSearchTool(BuiltinTool): | |||
| def _invoke(self, | |||
| user_id: str, | |||
| tool_parameters: dict[str, Any], | |||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||
| """ | |||
| invoke tools | |||
| """ | |||
| query = tool_parameters['query'] | |||
| result_type = tool_parameters['result_type'] | |||
| api_key = self.runtime.credentials['serpapi_api_key'] | |||
| result = SerpAPI(api_key).run(query, result_type=result_type) | |||
| if result_type == 'text': | |||
| return self.create_text_message(text=result) | |||
| return self.create_link_message(link=result) | |||
| response = requests.get(url=SERP_API_URL, params=params) | |||
| response.raise_for_status() | |||
| valuable_res = self._parse_response(response.json()) | |||
| return self.create_json_message(valuable_res) | |||
| @@ -25,27 +25,3 @@ parameters: | |||
| pt_BR: used for searching | |||
| llm_description: key words for searching | |||
| form: llm | |||
| - name: result_type | |||
| type: select | |||
| required: true | |||
| options: | |||
| - value: text | |||
| label: | |||
| en_US: text | |||
| zh_Hans: 文本 | |||
| pt_BR: texto | |||
| - value: link | |||
| label: | |||
| en_US: link | |||
| zh_Hans: 链接 | |||
| pt_BR: link | |||
| default: link | |||
| label: | |||
| en_US: Result type | |||
| zh_Hans: 结果类型 | |||
| pt_BR: Result type | |||
| human_description: | |||
| en_US: used for selecting the result type, text or link | |||
| zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 | |||
| pt_BR: used for selecting the result type, text or link | |||
| form: form | |||
| @@ -207,30 +207,7 @@ class Tool(BaseModel, ABC): | |||
| result = [result] | |||
| return result | |||
| def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: | |||
| """ | |||
| Handle tool response | |||
| """ | |||
| result = '' | |||
| for response in tool_response: | |||
| if response.type == ToolInvokeMessage.MessageType.TEXT: | |||
| result += response.message | |||
| elif response.type == ToolInvokeMessage.MessageType.LINK: | |||
| result += f"result link: {response.message}. please tell user to check it. \n" | |||
| elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ | |||
| response.type == ToolInvokeMessage.MessageType.IMAGE: | |||
| result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n" | |||
| elif response.type == ToolInvokeMessage.MessageType.BLOB: | |||
| if len(response.message) > 114: | |||
| result += str(response.message[:114]) + '...' | |||
| else: | |||
| result += str(response.message) | |||
| else: | |||
| result += f"tool response: {response.message}. \n" | |||
| return result | |||
| def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| Transform tool parameters type | |||
| @@ -355,3 +332,12 @@ class Tool(BaseModel, ABC): | |||
| message=blob, meta=meta, | |||
| save_as=save_as | |||
| ) | |||
| def create_json_message(self, object: dict) -> ToolInvokeMessage: | |||
| """ | |||
| create a json message | |||
| """ | |||
| return ToolInvokeMessage( | |||
| type=ToolInvokeMessage.MessageType.JSON, | |||
| message=object | |||
| ) | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| from copy import deepcopy | |||
| from datetime import datetime, timezone | |||
| from mimetypes import guess_type | |||
| @@ -188,6 +189,8 @@ class ToolEngine: | |||
| elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ | |||
| response.type == ToolInvokeMessage.MessageType.IMAGE: | |||
| result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." | |||
| elif response.type == ToolInvokeMessage.MessageType.JSON: | |||
| result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." | |||
| else: | |||
| result += f"tool response: {response.message}." | |||
| @@ -74,13 +74,14 @@ class ToolNode(BaseNode): | |||
| ) | |||
| # convert tool messages | |||
| plain_text, files = self._convert_tool_messages(messages) | |||
| plain_text, files, json = self._convert_tool_messages(messages) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={ | |||
| 'text': plain_text, | |||
| 'files': files | |||
| 'files': files, | |||
| 'json': json | |||
| }, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOOL_INFO: tool_info | |||
| @@ -149,8 +150,9 @@ class ToolNode(BaseNode): | |||
| # extract plain text and files | |||
| files = self._extract_tool_response_binary(messages) | |||
| plain_text = self._extract_tool_response_text(messages) | |||
| json = self._extract_tool_response_json(messages) | |||
| return plain_text, files | |||
| return plain_text, files, json | |||
| def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: | |||
| """ | |||
| @@ -203,7 +205,9 @@ class ToolNode(BaseNode): | |||
| f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' | |||
| for message in tool_response | |||
| ]) | |||
| def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: | |||
| return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: | |||