Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

tool.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. import json
  2. from collections.abc import Generator
  3. from dataclasses import dataclass
  4. from os import getenv
  5. from typing import Any, Optional, Union
  6. from urllib.parse import urlencode
  7. import httpx
  8. from core.file.file_manager import download
  9. from core.helper import ssrf_proxy
  10. from core.tools.__base.tool import Tool
  11. from core.tools.__base.tool_runtime import ToolRuntime
  12. from core.tools.entities.tool_bundle import ApiToolBundle
  13. from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
  14. from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
  # Connect and read timeouts (in that order) passed as ``timeout=`` to every API tool request.
  # Each one can be overridden through its environment variable; a value that int() cannot parse
  # raises ValueError when this module is imported.
  # NOTE(review): units are presumably seconds — confirm against ssrf_proxy.
  API_TOOL_DEFAULT_TIMEOUT = tuple(
      int(getenv(env_var, fallback))
      for env_var, fallback in (
          ("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10"),
          ("API_TOOL_DEFAULT_READ_TIMEOUT", "60"),
      )
  )
  @dataclass
  class ParsedResponse:
      """Decoded HTTP response body plus a flag telling whether it was decoded as JSON.

      ``content`` is plain text when ``is_json`` is False. When ``is_json`` is True it is the
      value produced by JSON decoding, which is usually a dict but may also be a list or a
      scalar (string, number, bool, None).
      """

      content: Any
      is_json: bool

      def to_string(self) -> str:
          """Return the content as a string (used for credential validation).

          Text is returned unchanged. Any other value is serialized as JSON, so lists and scalars
          come out as JSON (``[1, 2]``, ``null``) instead of Python reprs (``['a']``, ``None``).
          """
          if isinstance(self.content, str):
              return self.content
          try:
              return json.dumps(self.content, ensure_ascii=False)
          except (TypeError, ValueError):
              # Not JSON-serializable (only possible if constructed directly by a caller).
              return str(self.content)
  class ApiTool(Tool):
      """
      Tool that calls the HTTP API operation described by an ``ApiToolBundle`` (OpenAPI based).
      """

      def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
          super().__init__(entity, runtime)
          self.api_bundle = api_bundle
          self.provider_id = provider_id

      def fork_tool_runtime(self, runtime: ToolRuntime):
          """
          Fork a new tool bound to ``runtime``.

          The entity and provider id are shared; the api bundle is copied with ``model_copy()``.

          :param runtime: runtime for the forked tool
          :return: the new tool
          :raises ValueError: if this tool has no api bundle
          """
          if self.api_bundle is None:
              raise ValueError("api_bundle is required")
          return self.__class__(
              entity=self.entity,
              api_bundle=self.api_bundle.model_copy(),
              runtime=runtime,
              provider_id=self.provider_id,
          )

      def validate_credentials(
          self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
      ) -> str:
          """
          Validate the credentials of this API tool by sending a request built from ``parameters``.

          Authentication comes from ``self.runtime.credentials`` (see ``assembling_request``);
          the ``credentials`` argument is not read by this implementation.

          :param parameters: tool parameters; missing required parameters that declare a default
              are filled in place
          :param format_only: when True, only assemble the request (checking the credential format
              and required parameters) and return "" without sending anything
          :return: the response body as a string ("" when ``format_only``)
          :raises ToolProviderCredentialValidationError: on missing or malformed credentials
          :raises ToolParameterValidationError: on a missing required parameter
          :raises ToolInvokeError: if the server answers with a status code >= 400
          """
          headers = self.assembling_request(parameters)
          if format_only:
              return ""
          response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
          # Credential validation always reports a string, whatever the response type.
          return self.validate_and_parse_response(response).to_string()

      def tool_provider_type(self) -> ToolProviderType:
          """Return the provider type of this tool (always ``ToolProviderType.API``)."""
          return ToolProviderType.API

      def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
          """
          Build the authentication headers and fill defaults for missing required parameters.

          :param parameters: tool parameters; mutated in place: a required parameter that is
              missing but declares a default receives that default
          :return: headers to send (the API key header for header-based auth, otherwise empty)
          :raises ToolProviderCredentialValidationError: if the runtime is missing, ``auth_type``
              or ``api_key_value`` is missing, or ``api_key_value`` is not a string
          :raises ToolParameterValidationError: if a required parameter is missing and has no default
          """
          if self.runtime is None:
              raise ToolProviderCredentialValidationError("runtime not initialized")

          headers: dict[str, Any] = {}
          credentials = self.runtime.credentials or {}
          if "auth_type" not in credentials:
              raise ToolProviderCredentialValidationError("Missing auth_type")

          # "api_key" is accepted for backward compatibility.
          if credentials["auth_type"] in ("api_key_header", "api_key"):
              api_key_header = credentials.get("api_key_header", "Authorization")
              if "api_key_value" not in credentials:
                  raise ToolProviderCredentialValidationError("Missing api_key_value")
              api_key_value = credentials["api_key_value"]
              if not isinstance(api_key_value, str):
                  raise ToolProviderCredentialValidationError("api_key_value must be a string")

              # Build the prefixed value locally. ``credentials`` is the runtime's own dict, so
              # writing the prefixed value back into it would make a later call on the same
              # runtime send "Bearer Bearer <key>".
              prefix = credentials.get("api_key_header_prefix")
              if prefix == "basic" and api_key_value:
                  api_key_value = f"Basic {api_key_value}"
              elif prefix == "bearer" and api_key_value:
                  api_key_value = f"Bearer {api_key_value}"
              # Any other prefix (including "custom") sends the value unchanged.
              headers[api_key_header] = api_key_value
          # For "api_key_query" nothing is added to the headers; do_http_request adds the key
          # as a query parameter instead.

          for parameter in self.api_bundle.parameters or []:
              if parameter.required and parameter.name not in parameters:
                  if parameter.default is None:
                      raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
                  parameters[parameter.name] = parameter.default

          return headers

      def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
          """
          Check the HTTP status and decode the response body.

          :param response: the response returned by ``do_http_request``
          :return: a ParsedResponse holding the decoded JSON value with ``is_json=True`` when the
              content-type contains "application/json" and the body decodes as JSON; otherwise the
              body text with ``is_json=False`` (or a placeholder message for an empty body)
          :raises ToolInvokeError: if the status code is >= 400
          :raises ValueError: if ``response`` is not an ``httpx.Response``
          """
          if not isinstance(response, httpx.Response):
              raise ValueError(f"Invalid response type {type(response)}")
          if response.status_code >= 400:
              raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
          if not response.content:
              return ParsedResponse("Empty response from the tool, please check your parameters and try again.", False)

          content_type = response.headers.get("content-type", "").lower()
          # Only bodies declared as JSON are decoded; anything else is returned as text,
          # even when it happens to be valid JSON.
          if "application/json" not in content_type:
              return ParsedResponse(response.text, False)
          try:
              return ParsedResponse(response.json(), True)
          except (ValueError, RecursionError):
              # Declared as JSON but not decodable (JSONDecodeError and UnicodeDecodeError are
              # ValueError subclasses; RecursionError covers pathologically nested input).
              return ParsedResponse(response.text, False)

      @staticmethod
      def get_parameter_value(parameter, parameters):
          """
          Resolve the value of an OpenAPI ``parameter`` object.

          :param parameter: OpenAPI parameter object (a dict with at least "name")
          :param parameters: tool parameter values keyed by name
          :return: the provided value; otherwise the schema default, or "" when there is none
          :raises ToolParameterValidationError: if the parameter is required and not provided
          """
          if parameter["name"] in parameters:
              return parameters[parameter["name"]]
          if parameter.get("required", False):
              raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
          return (parameter.get("schema", {}) or {}).get("default", "")

      def do_http_request(
          self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
      ) -> httpx.Response:
          """
          Send the HTTP request described by the api bundle's OpenAPI operation.

          :param url: server URL; ``{name}`` placeholders are replaced with path parameter values
          :param method: HTTP method name, case-insensitive
          :param headers: request headers; mutated in place (header parameters and the request
              body Content-Type are written into it)
          :param parameters: tool parameter values keyed by parameter / body property name
          :return: the response (redirects are followed)
          :raises ToolParameterValidationError: if a required parameter or body property is missing
          :raises ValueError: if ``method`` is not a supported HTTP method
          """
          method = method.lower()
          params: dict[str, Any] = {}
          path_params: dict[str, Any] = {}
          cookies: dict[str, Any] = {}

          # api_key_query auth: the key travels as a query parameter instead of a header.
          if self.runtime and self.runtime.credentials:
              credentials = self.runtime.credentials
              if credentials.get("auth_type") == "api_key_query":
                  api_key_value = credentials.get("api_key_value")
                  if api_key_value:
                      params[credentials.get("api_key_query_param", "key")] = api_key_value

          # Route each OpenAPI parameter to its location.
          for parameter in self.api_bundle.openapi.get("parameters", []):
              value = self.get_parameter_value(parameter, parameters)
              location = parameter["in"]
              if location == "path":
                  path_params[parameter["name"]] = value
              elif location == "query":
                  # Empty values (including the "" fallback of get_parameter_value) are not sent.
                  if value != "":
                      params[parameter["name"]] = value
              elif location == "cookie":
                  cookies[parameter["name"]] = value
              elif location == "header":
                  headers[parameter["name"]] = str(value)

          body, files = self._build_request_body(parameters, headers)

          # NOTE: path parameter values are substituted verbatim (not percent-encoded).
          for name, value in path_params.items():
              url = url.replace(f"{{{name}}}", f"{value}")

          encoded_body: Any = body
          if "Content-Type" in headers:
              encoded_body = self._encode_body(body, headers["Content-Type"])

          # With file uploads, drop Content-Type so httpx can generate the multipart/form-data
          # header including the required boundary.
          # issue: https://github.com/langgenius/dify/issues/13684
          # reference: https://stackoverflow.com/questions/39280438/fetch-missing-boundary-in-multipart-form-data-post
          if files:
              headers.pop("Content-Type", None)

          if method not in {"get", "head", "post", "put", "delete", "patch", "options"}:
              raise ValueError(f"Invalid http method {method}")
          response: httpx.Response = getattr(ssrf_proxy, method)(
              url,
              params=params,
              headers=headers,
              cookies=cookies,
              data=encoded_body,
              files=files,
              timeout=API_TOOL_DEFAULT_TIMEOUT,
              follow_redirects=True,
          )
          return response

      def _build_request_body(
          self, parameters: dict[str, Any], headers: dict[str, Any]
      ) -> tuple[dict[str, Any], list[tuple[str, tuple[Any, Any, Any]]]]:
          """
          Collect request body fields and file uploads for the operation's ``requestBody``.

          Only the first content type declared under ``requestBody.content`` is used, and it is
          written to ``headers["Content-Type"]``. A top-level ``$ref`` body schema is resolved
          against ``components.schemas`` by its last path segment.

          :return: ``(body, files)``, where ``files`` holds ``(field, (filename, data, mime_type))``
              entries; both are empty when the operation declares no request body
          :raises ToolParameterValidationError: if a required body property is missing
          """
          body: dict[str, Any] = {}
          files: list[tuple[str, tuple[Any, Any, Any]]] = []
          openapi = self.api_bundle.openapi
          request_body = openapi.get("requestBody")
          if request_body is None or "content" not in request_body:
              return body, files
          content = request_body["content"]
          content_type = next(iter(content), None)
          if content_type is None:
              return body, files

          headers["Content-Type"] = content_type
          body_schema = content[content_type]["schema"]
          if "$ref" in body_schema:
              ref_name = body_schema["$ref"].split("/")[-1]
              schemas = (openapi.get("components") or {}).get("schemas") or {}
              if ref_name in schemas:
                  body_schema = schemas[ref_name]

          required = body_schema.get("required", [])
          for name, prop in body_schema.get("properties", {}).items():
              if name in parameters:
                  value = parameters[name]
                  if prop.get("type") == "array" and (prop.get("items") or {}).get("format") == "binary":
                      # Multiple file upload: the value is a list of file objects.
                      files.extend((name, (f.filename, download(f), f.mime_type)) for f in value)
                  elif prop.get("format") == "binary":
                      files.append((name, (value.filename, download(value), value.mime_type)))
                  elif "$ref" in prop:
                      body[name] = value
                  else:
                      body[name] = self._convert_body_property_type(prop, value)
              elif name in required:
                  raise ToolParameterValidationError(
                      f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                  )
              elif "default" in prop:
                  body[name] = prop["default"]
              # Optional properties that were not provided are omitted instead of being sent as null.
          return body, files

      @staticmethod
      def _encode_body(body: Any, content_type: str) -> Any:
          """
          Serialize ``body`` according to ``content_type``.

          The media type is compared case-insensitively and without parameters, so
          "application/json; charset=utf-8" is still serialized as JSON rather than passing the
          raw dict through as form data. JSON-suffixed types (e.g. "application/merge-patch+json")
          are serialized as JSON too.

          :return: a JSON string, a url-encoded string, or ``body`` unchanged for other types
          """
          media_type = str(content_type).split(";", 1)[0].strip().lower()
          if media_type == "application/json" or media_type.endswith("+json"):
              return json.dumps(body)
          if media_type == "application/x-www-form-urlencoded":
              return urlencode(body)
          return body

      @staticmethod
      def _to_number(value: Any) -> Union[int, float]:
          """
          Convert ``value`` for a JSON schema "number".

          Floats, and values whose text contains ".", become float. Other values become int
          when int() accepts them; anything else goes through float() (e.g. "1e3").

          :raises ValueError: if the value is not numeric text
          :raises TypeError: if the value cannot be converted at all (e.g. None)
          """
          if isinstance(value, float) or "." in str(value):
              return float(value)
          try:
              return int(value)
          except ValueError:
              # e.g. exponent notation such as "1e3"
              return float(value)

      def _convert_body_property_any_of(
          self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
      ) -> Any:
          """
          Convert ``value`` using the first ``anyOf`` option that accepts it.

          Options are tried in order; nested ``anyOf`` lists are followed for at most
          ``max_recursive`` levels. ``value`` is returned unchanged when no option applies.

          :raises Exception: if the ``anyOf`` nesting exceeds ``max_recursive``
          """
          if max_recursive <= 0:
              raise Exception("Max recursion depth reached")
          for option in any_of or []:
              try:
                  if "type" in option:
                      option_type = option["type"]
                      if option_type in ("integer", "int"):
                          return int(value)
                      if option_type == "number":
                          return self._to_number(value)
                      if option_type == "string":
                          return str(value)
                      if option_type == "boolean":
                          lowered = str(value).lower()
                          if lowered in {"true", "1"}:
                              return True
                          if lowered in {"false", "0"}:
                              return False
                          continue  # Not a boolean, try the next option
                      if option_type == "null" and (value is None or value == ""):
                          # Only genuinely empty values map to null; 0 and False must not.
                          return None
                      continue  # Unsupported type (or non-empty value for null), try the next option
                  if "anyOf" in option and isinstance(option["anyOf"], list):
                      # Recursive call to handle nested anyOf
                      return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
              except (TypeError, ValueError):
                  # Conversion failed (e.g. int("abc") or int(None)); try the next option.
                  continue
          return value

      def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
          """
          Coerce ``value`` to the JSON schema type declared by ``property``.

          Values that cannot be converted are returned unchanged, as are values whose schema
          declares neither ``type`` nor ``anyOf``. A "null"-typed property always yields None.
          """
          try:
              if "type" in property:
                  prop_type = property["type"]
                  if prop_type in ("integer", "int"):
                      return int(value)
                  if prop_type == "number":
                      return self._to_number(value)
                  if prop_type == "string":
                      return str(value)
                  if prop_type == "boolean":
                      # bool("false") is True, so textual false values are parsed explicitly.
                      if isinstance(value, str) and value.strip().lower() in {"false", "0"}:
                          return False
                      return bool(value)
                  if prop_type == "null":
                      # A "null"-typed property can only hold null.
                      return None
                  if prop_type in ("object", "array"):
                      if isinstance(value, str):
                          try:
                              return json.loads(value)
                          except ValueError:
                              return value
                      return value
                  raise ValueError(f"Invalid type {property['type']} for property {property}")
              if "anyOf" in property and isinstance(property["anyOf"], list):
                  return self._convert_body_property_any_of(property, value, property["anyOf"])
          except (TypeError, ValueError):
              # Conversion failed (e.g. int("abc") or int(None)): send the value as provided.
              return value
          # No "type"/"anyOf" to convert to (e.g. oneOf/allOf schemas): keep the value rather than dropping it.
          return value

      def _invoke(
          self,
          user_id: str,
          tool_parameters: dict[str, Any],
          conversation_id: Optional[str] = None,
          app_id: Optional[str] = None,
          message_id: Optional[str] = None,
      ) -> Generator[ToolInvokeMessage, None, None]:
          """
          Invoke the HTTP API and yield its response as a single tool message.

          A JSON object response becomes a JSON message. Any other response becomes a text
          message; non-string JSON values (arrays, scalars) are serialized as JSON text.
          """
          headers = self.assembling_request(tool_parameters)
          response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
          parsed_response = self.validate_and_parse_response(response)

          if parsed_response.is_json and isinstance(parsed_response.content, dict):
              yield self.create_json_message(parsed_response.content)
          elif isinstance(parsed_response.content, str):
              yield self.create_text_message(parsed_response.content)
          else:
              # str() would give a Python repr such as "[{'a': 1}]"; emit valid JSON instead.
              yield self.create_text_message(json.dumps(parsed_response.content, ensure_ascii=False))