您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import json
  2. from collections.abc import Generator
  3. from os import getenv
  4. from typing import Any, Optional
  5. from urllib.parse import urlencode
  6. import httpx
  7. from core.file.file_manager import download
  8. from core.helper import ssrf_proxy
  9. from core.tools.__base.tool import Tool
  10. from core.tools.__base.tool_runtime import ToolRuntime
  11. from core.tools.entities.tool_bundle import ApiToolBundle
  12. from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
  13. from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
  14. API_TOOL_DEFAULT_TIMEOUT = (
  15. int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
  16. int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
  17. )
  18. class ApiTool(Tool):
  19. api_bundle: ApiToolBundle
  20. provider_id: str
  21. """
  22. Api tool
  23. """
  24. def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
  25. super().__init__(entity, runtime)
  26. self.api_bundle = api_bundle
  27. self.provider_id = provider_id
  28. def fork_tool_runtime(self, runtime: ToolRuntime):
  29. """
  30. fork a new tool with metadata
  31. :return: the new tool
  32. """
  33. if self.api_bundle is None:
  34. raise ValueError("api_bundle is required")
  35. return self.__class__(
  36. entity=self.entity,
  37. api_bundle=self.api_bundle.model_copy(),
  38. runtime=runtime,
  39. provider_id=self.provider_id,
  40. )
  41. def validate_credentials(
  42. self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
  43. ) -> str:
  44. """
  45. validate the credentials for Api tool
  46. """
  47. # assemble validate request and request parameters
  48. headers = self.assembling_request(parameters)
  49. if format_only:
  50. return ""
  51. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
  52. # validate response
  53. return self.validate_and_parse_response(response)
  54. def tool_provider_type(self) -> ToolProviderType:
  55. return ToolProviderType.API
  56. def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
  57. if self.runtime is None:
  58. raise ToolProviderCredentialValidationError("runtime not initialized")
  59. headers = {}
  60. if self.runtime is None:
  61. raise ValueError("runtime is required")
  62. credentials = self.runtime.credentials or {}
  63. if "auth_type" not in credentials:
  64. raise ToolProviderCredentialValidationError("Missing auth_type")
  65. if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility:
  66. api_key_header = "Authorization"
  67. if "api_key_header" in credentials:
  68. api_key_header = credentials["api_key_header"]
  69. if "api_key_value" not in credentials:
  70. raise ToolProviderCredentialValidationError("Missing api_key_value")
  71. elif not isinstance(credentials["api_key_value"], str):
  72. raise ToolProviderCredentialValidationError("api_key_value must be a string")
  73. if "api_key_header_prefix" in credentials:
  74. api_key_header_prefix = credentials["api_key_header_prefix"]
  75. if api_key_header_prefix == "basic" and credentials["api_key_value"]:
  76. credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
  77. elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
  78. credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
  79. elif api_key_header_prefix == "custom":
  80. pass
  81. headers[api_key_header] = credentials["api_key_value"]
  82. elif credentials["auth_type"] == "api_key_query":
  83. # For query parameter authentication, we don't add anything to headers
  84. # The query parameter will be added in do_http_request method
  85. pass
  86. needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
  87. for parameter in needed_parameters:
  88. if parameter.required and parameter.name not in parameters:
  89. if parameter.default is not None:
  90. parameters[parameter.name] = parameter.default
  91. else:
  92. raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
  93. return headers
  94. def validate_and_parse_response(self, response: httpx.Response) -> str:
  95. """
  96. validate the response
  97. """
  98. if isinstance(response, httpx.Response):
  99. if response.status_code >= 400:
  100. raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
  101. if not response.content:
  102. return "Empty response from the tool, please check your parameters and try again."
  103. try:
  104. response = response.json()
  105. try:
  106. return json.dumps(response, ensure_ascii=False)
  107. except Exception:
  108. return json.dumps(response)
  109. except Exception:
  110. return response.text
  111. else:
  112. raise ValueError(f"Invalid response type {type(response)}")
  113. @staticmethod
  114. def get_parameter_value(parameter, parameters):
  115. if parameter["name"] in parameters:
  116. return parameters[parameter["name"]]
  117. elif parameter.get("required", False):
  118. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  119. else:
  120. return (parameter.get("schema", {}) or {}).get("default", "")
  121. def do_http_request(
  122. self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
  123. ) -> httpx.Response:
  124. """
  125. do http request depending on api bundle
  126. """
  127. method = method.lower()
  128. params = {}
  129. path_params = {}
  130. # FIXME: body should be a dict[str, Any] but it changed a lot in this function
  131. body: Any = {}
  132. cookies = {}
  133. files = []
  134. # Add API key to query parameters if auth_type is api_key_query
  135. if self.runtime and self.runtime.credentials:
  136. credentials = self.runtime.credentials
  137. if credentials.get("auth_type") == "api_key_query":
  138. api_key_query_param = credentials.get("api_key_query_param", "key")
  139. api_key_value = credentials.get("api_key_value")
  140. if api_key_value:
  141. params[api_key_query_param] = api_key_value
  142. # check parameters
  143. for parameter in self.api_bundle.openapi.get("parameters", []):
  144. value = self.get_parameter_value(parameter, parameters)
  145. if parameter["in"] == "path":
  146. path_params[parameter["name"]] = value
  147. elif parameter["in"] == "query":
  148. if value != "":
  149. params[parameter["name"]] = value
  150. elif parameter["in"] == "cookie":
  151. cookies[parameter["name"]] = value
  152. elif parameter["in"] == "header":
  153. headers[parameter["name"]] = str(value)
  154. # check if there is a request body and handle it
  155. if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
  156. # handle json request body
  157. if "content" in self.api_bundle.openapi["requestBody"]:
  158. for content_type in self.api_bundle.openapi["requestBody"]["content"]:
  159. headers["Content-Type"] = content_type
  160. body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
  161. # handle ref schema
  162. if "$ref" in body_schema:
  163. ref_path = body_schema["$ref"].split("/")
  164. ref_name = ref_path[-1]
  165. if (
  166. "components" in self.api_bundle.openapi
  167. and "schemas" in self.api_bundle.openapi["components"]
  168. ):
  169. if ref_name in self.api_bundle.openapi["components"]["schemas"]:
  170. body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name]
  171. required = body_schema.get("required", [])
  172. properties = body_schema.get("properties", {})
  173. for name, property in properties.items():
  174. if name in parameters:
  175. # multiple file upload: if the type is array and the items have format as binary
  176. if property.get("type") == "array" and property.get("items", {}).get("format") == "binary":
  177. # parameters[name] should be a list of file objects.
  178. for f in parameters[name]:
  179. files.append((name, (f.filename, download(f), f.mime_type)))
  180. elif property.get("format") == "binary":
  181. f = parameters[name]
  182. files.append((name, (f.filename, download(f), f.mime_type)))
  183. elif "$ref" in property:
  184. body[name] = parameters[name]
  185. else:
  186. # convert type
  187. body[name] = self._convert_body_property_type(property, parameters[name])
  188. elif name in required:
  189. raise ToolParameterValidationError(
  190. f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
  191. )
  192. elif "default" in property:
  193. body[name] = property["default"]
  194. else:
  195. # omit optional parameters that weren't provided, instead of setting them to None
  196. pass
  197. break
  198. # replace path parameters
  199. for name, value in path_params.items():
  200. url = url.replace(f"{{{name}}}", f"{value}")
  201. # parse http body data if needed
  202. if "Content-Type" in headers:
  203. if headers["Content-Type"] == "application/json":
  204. body = json.dumps(body)
  205. elif headers["Content-Type"] == "application/x-www-form-urlencoded":
  206. body = urlencode(body)
  207. else:
  208. body = body
  209. # if there is a file upload, remove the Content-Type header
  210. # so that httpx can automatically generate the boundary header required for multipart/form-data.
  211. # issue: https://github.com/langgenius/dify/issues/13684
  212. # reference: https://stackoverflow.com/questions/39280438/fetch-missing-boundary-in-multipart-form-data-post
  213. if files:
  214. headers.pop("Content-Type", None)
  215. if method in {
  216. "get",
  217. "head",
  218. "post",
  219. "put",
  220. "delete",
  221. "patch",
  222. "options",
  223. "GET",
  224. "POST",
  225. "PUT",
  226. "PATCH",
  227. "DELETE",
  228. "HEAD",
  229. "OPTIONS",
  230. }:
  231. response: httpx.Response = getattr(ssrf_proxy, method.lower())(
  232. url,
  233. params=params,
  234. headers=headers,
  235. cookies=cookies,
  236. data=body,
  237. files=files,
  238. timeout=API_TOOL_DEFAULT_TIMEOUT,
  239. follow_redirects=True,
  240. )
  241. return response
  242. else:
  243. raise ValueError(f"Invalid http method {method}")
  244. def _convert_body_property_any_of(
  245. self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
  246. ) -> Any:
  247. if max_recursive <= 0:
  248. raise Exception("Max recursion depth reached")
  249. for option in any_of or []:
  250. try:
  251. if "type" in option:
  252. # Attempt to convert the value based on the type.
  253. if option["type"] == "integer" or option["type"] == "int":
  254. return int(value)
  255. elif option["type"] == "number":
  256. if "." in str(value):
  257. return float(value)
  258. else:
  259. return int(value)
  260. elif option["type"] == "string":
  261. return str(value)
  262. elif option["type"] == "boolean":
  263. if str(value).lower() in {"true", "1"}:
  264. return True
  265. elif str(value).lower() in {"false", "0"}:
  266. return False
  267. else:
  268. continue # Not a boolean, try next option
  269. elif option["type"] == "null" and not value:
  270. return None
  271. else:
  272. continue # Unsupported type, try next option
  273. elif "anyOf" in option and isinstance(option["anyOf"], list):
  274. # Recursive call to handle nested anyOf
  275. return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
  276. except ValueError:
  277. continue # Conversion failed, try next option
  278. # If no option succeeded, you might want to return the value as is or raise an error
  279. return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
  280. def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
  281. try:
  282. if "type" in property:
  283. if property["type"] == "integer" or property["type"] == "int":
  284. return int(value)
  285. elif property["type"] == "number":
  286. # check if it is a float
  287. if "." in str(value):
  288. return float(value)
  289. else:
  290. return int(value)
  291. elif property["type"] == "string":
  292. return str(value)
  293. elif property["type"] == "boolean":
  294. return bool(value)
  295. elif property["type"] == "null":
  296. if value is None:
  297. return None
  298. elif property["type"] == "object" or property["type"] == "array":
  299. if isinstance(value, str):
  300. try:
  301. return json.loads(value)
  302. except ValueError:
  303. return value
  304. elif isinstance(value, dict):
  305. return value
  306. else:
  307. return value
  308. else:
  309. raise ValueError(f"Invalid type {property['type']} for property {property}")
  310. elif "anyOf" in property and isinstance(property["anyOf"], list):
  311. return self._convert_body_property_any_of(property, value, property["anyOf"])
  312. except ValueError:
  313. return value
  314. def _invoke(
  315. self,
  316. user_id: str,
  317. tool_parameters: dict[str, Any],
  318. conversation_id: Optional[str] = None,
  319. app_id: Optional[str] = None,
  320. message_id: Optional[str] = None,
  321. ) -> Generator[ToolInvokeMessage, None, None]:
  322. """
  323. invoke http request
  324. """
  325. response: httpx.Response | str = ""
  326. # assemble request
  327. headers = self.assembling_request(tool_parameters)
  328. # do http request
  329. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
  330. # validate response
  331. response = self.validate_and_parse_response(response)
  332. # assemble invoke message
  333. yield self.create_text_message(response)