Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

api_tools_manage_service.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import json
  2. import logging
  3. from collections.abc import Mapping
  4. from typing import Any, cast
  5. from httpx import get
  6. from sqlalchemy import select
  7. from core.entities.provider_entities import ProviderConfig
  8. from core.model_runtime.utils.encoders import jsonable_encoder
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.custom_tool.provider import ApiToolProviderController
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_bundle import ApiToolBundle
  14. from core.tools.entities.tool_entities import (
  15. ApiProviderAuthType,
  16. ApiProviderSchemaType,
  17. )
  18. from core.tools.tool_label_manager import ToolLabelManager
  19. from core.tools.tool_manager import ToolManager
  20. from core.tools.utils.encryption import create_tool_provider_encrypter
  21. from core.tools.utils.parser import ApiBasedToolSchemaParser
  22. from extensions.ext_database import db
  23. from models.tools import ApiToolProvider
  24. from services.tools.tools_transform_service import ToolTransformService
  25. logger = logging.getLogger(__name__)
  26. class ApiToolManageService:
  27. @staticmethod
  28. def parser_api_schema(schema: str) -> Mapping[str, Any]:
  29. """
  30. parse api schema to tool bundle
  31. """
  32. try:
  33. warnings: dict[str, str] = {}
  34. try:
  35. tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
  36. except Exception as e:
  37. raise ValueError(f"invalid schema: {str(e)}")
  38. credentials_schema = [
  39. ProviderConfig(
  40. name="auth_type",
  41. type=ProviderConfig.Type.SELECT,
  42. required=True,
  43. default="none",
  44. options=[
  45. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  46. ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
  47. ],
  48. placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
  49. ),
  50. ProviderConfig(
  51. name="api_key_header",
  52. type=ProviderConfig.Type.TEXT_INPUT,
  53. required=False,
  54. placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
  55. default="api_key",
  56. help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
  57. ),
  58. ProviderConfig(
  59. name="api_key_value",
  60. type=ProviderConfig.Type.TEXT_INPUT,
  61. required=False,
  62. placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
  63. default="",
  64. ),
  65. ]
  66. return cast(
  67. Mapping,
  68. jsonable_encoder(
  69. {
  70. "schema_type": schema_type,
  71. "parameters_schema": tool_bundles,
  72. "credentials_schema": credentials_schema,
  73. "warning": warnings,
  74. }
  75. ),
  76. )
  77. except Exception as e:
  78. raise ValueError(f"invalid schema: {str(e)}")
  79. @staticmethod
  80. def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
  81. """
  82. convert schema to tool bundles
  83. :return: the list of tool bundles, description
  84. """
  85. try:
  86. return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
  87. except Exception as e:
  88. raise ValueError(f"invalid schema: {str(e)}")
  89. @staticmethod
  90. def create_api_tool_provider(
  91. user_id: str,
  92. tenant_id: str,
  93. provider_name: str,
  94. icon: dict,
  95. credentials: dict,
  96. schema_type: str,
  97. schema: str,
  98. privacy_policy: str,
  99. custom_disclaimer: str,
  100. labels: list[str],
  101. ):
  102. """
  103. create api tool provider
  104. """
  105. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  106. raise ValueError(f"invalid schema type {schema}")
  107. provider_name = provider_name.strip()
  108. # check if the provider exists
  109. provider = (
  110. db.session.query(ApiToolProvider)
  111. .where(
  112. ApiToolProvider.tenant_id == tenant_id,
  113. ApiToolProvider.name == provider_name,
  114. )
  115. .first()
  116. )
  117. if provider is not None:
  118. raise ValueError(f"provider {provider_name} already exists")
  119. # parse openapi to tool bundle
  120. extra_info: dict[str, str] = {}
  121. # extra info like description will be set here
  122. tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
  123. if len(tool_bundles) > 100:
  124. raise ValueError("the number of apis should be less than 100")
  125. # create db provider
  126. db_provider = ApiToolProvider(
  127. tenant_id=tenant_id,
  128. user_id=user_id,
  129. name=provider_name,
  130. icon=json.dumps(icon),
  131. schema=schema,
  132. description=extra_info.get("description", ""),
  133. schema_type_str=schema_type,
  134. tools_str=json.dumps(jsonable_encoder(tool_bundles)),
  135. credentials_str={},
  136. privacy_policy=privacy_policy,
  137. custom_disclaimer=custom_disclaimer,
  138. )
  139. if "auth_type" not in credentials:
  140. raise ValueError("auth_type is required")
  141. # get auth type, none or api key
  142. auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
  143. # create provider entity
  144. provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
  145. # load tools into provider entity
  146. provider_controller.load_bundled_tools(tool_bundles)
  147. # encrypt credentials
  148. encrypter, _ = create_tool_provider_encrypter(
  149. tenant_id=tenant_id,
  150. controller=provider_controller,
  151. )
  152. db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
  153. db.session.add(db_provider)
  154. db.session.commit()
  155. # update labels
  156. ToolLabelManager.update_tool_labels(provider_controller, labels)
  157. return {"result": "success"}
  158. @staticmethod
  159. def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
  160. """
  161. get api tool provider remote schema
  162. """
  163. headers = {
  164. "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
  165. " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
  166. "Accept": "*/*",
  167. }
  168. try:
  169. response = get(url, headers=headers, timeout=10)
  170. if response.status_code != 200:
  171. raise ValueError(f"Got status code {response.status_code}")
  172. schema = response.text
  173. # try to parse schema, avoid SSRF attack
  174. ApiToolManageService.parser_api_schema(schema)
  175. except Exception:
  176. logger.exception("parse api schema error")
  177. raise ValueError("invalid schema, please check the url you provided")
  178. return {"schema": schema}
  179. @staticmethod
  180. def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
  181. """
  182. list api tool provider tools
  183. """
  184. provider: ApiToolProvider | None = (
  185. db.session.query(ApiToolProvider)
  186. .where(
  187. ApiToolProvider.tenant_id == tenant_id,
  188. ApiToolProvider.name == provider_name,
  189. )
  190. .first()
  191. )
  192. if provider is None:
  193. raise ValueError(f"you have not added provider {provider_name}")
  194. controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
  195. labels = ToolLabelManager.get_tool_labels(controller)
  196. return [
  197. ToolTransformService.convert_tool_entity_to_api_entity(
  198. tool_bundle,
  199. tenant_id=tenant_id,
  200. labels=labels,
  201. )
  202. for tool_bundle in provider.tools
  203. ]
  204. @staticmethod
  205. def update_api_tool_provider(
  206. user_id: str,
  207. tenant_id: str,
  208. provider_name: str,
  209. original_provider: str,
  210. icon: dict,
  211. credentials: dict,
  212. schema_type: str,
  213. schema: str,
  214. privacy_policy: str,
  215. custom_disclaimer: str,
  216. labels: list[str],
  217. ):
  218. """
  219. update api tool provider
  220. """
  221. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  222. raise ValueError(f"invalid schema type {schema}")
  223. provider_name = provider_name.strip()
  224. # check if the provider exists
  225. provider = (
  226. db.session.query(ApiToolProvider)
  227. .where(
  228. ApiToolProvider.tenant_id == tenant_id,
  229. ApiToolProvider.name == original_provider,
  230. )
  231. .first()
  232. )
  233. if provider is None:
  234. raise ValueError(f"api provider {provider_name} does not exists")
  235. # parse openapi to tool bundle
  236. extra_info: dict[str, str] = {}
  237. # extra info like description will be set here
  238. tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
  239. # update db provider
  240. provider.name = provider_name
  241. provider.icon = json.dumps(icon)
  242. provider.schema = schema
  243. provider.description = extra_info.get("description", "")
  244. provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
  245. provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
  246. provider.privacy_policy = privacy_policy
  247. provider.custom_disclaimer = custom_disclaimer
  248. if "auth_type" not in credentials:
  249. raise ValueError("auth_type is required")
  250. # get auth type, none or api key
  251. auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
  252. # create provider entity
  253. provider_controller = ApiToolProviderController.from_db(provider, auth_type)
  254. # load tools into provider entity
  255. provider_controller.load_bundled_tools(tool_bundles)
  256. # get original credentials if exists
  257. encrypter, cache = create_tool_provider_encrypter(
  258. tenant_id=tenant_id,
  259. controller=provider_controller,
  260. )
  261. original_credentials = encrypter.decrypt(provider.credentials)
  262. masked_credentials = encrypter.mask_tool_credentials(original_credentials)
  263. # check if the credential has changed, save the original credential
  264. for name, value in credentials.items():
  265. if name in masked_credentials and value == masked_credentials[name]:
  266. credentials[name] = original_credentials[name]
  267. credentials = encrypter.encrypt(credentials)
  268. provider.credentials_str = json.dumps(credentials)
  269. db.session.add(provider)
  270. db.session.commit()
  271. # delete cache
  272. cache.delete()
  273. # update labels
  274. ToolLabelManager.update_tool_labels(provider_controller, labels)
  275. return {"result": "success"}
  276. @staticmethod
  277. def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  278. """
  279. delete tool provider
  280. """
  281. provider = (
  282. db.session.query(ApiToolProvider)
  283. .where(
  284. ApiToolProvider.tenant_id == tenant_id,
  285. ApiToolProvider.name == provider_name,
  286. )
  287. .first()
  288. )
  289. if provider is None:
  290. raise ValueError(f"you have not added provider {provider_name}")
  291. db.session.delete(provider)
  292. db.session.commit()
  293. return {"result": "success"}
  294. @staticmethod
  295. def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
  296. """
  297. get api tool provider
  298. """
  299. return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
  300. @staticmethod
  301. def test_api_tool_preview(
  302. tenant_id: str,
  303. provider_name: str,
  304. tool_name: str,
  305. credentials: dict,
  306. parameters: dict,
  307. schema_type: str,
  308. schema: str,
  309. ):
  310. """
  311. test api tool before adding api tool provider
  312. """
  313. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  314. raise ValueError(f"invalid schema type {schema_type}")
  315. try:
  316. tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
  317. except Exception:
  318. raise ValueError("invalid schema")
  319. # get tool bundle
  320. tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
  321. if tool_bundle is None:
  322. raise ValueError(f"invalid tool name {tool_name}")
  323. db_provider = (
  324. db.session.query(ApiToolProvider)
  325. .where(
  326. ApiToolProvider.tenant_id == tenant_id,
  327. ApiToolProvider.name == provider_name,
  328. )
  329. .first()
  330. )
  331. if not db_provider:
  332. # create a fake db provider
  333. db_provider = ApiToolProvider(
  334. tenant_id="",
  335. user_id="",
  336. name="",
  337. icon="",
  338. schema=schema,
  339. description="",
  340. schema_type_str=ApiProviderSchemaType.OPENAPI.value,
  341. tools_str=json.dumps(jsonable_encoder(tool_bundles)),
  342. credentials_str=json.dumps(credentials),
  343. )
  344. if "auth_type" not in credentials:
  345. raise ValueError("auth_type is required")
  346. # get auth type, none or api key
  347. auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
  348. # create provider entity
  349. provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
  350. # load tools into provider entity
  351. provider_controller.load_bundled_tools(tool_bundles)
  352. # decrypt credentials
  353. if db_provider.id:
  354. encrypter, _ = create_tool_provider_encrypter(
  355. tenant_id=tenant_id,
  356. controller=provider_controller,
  357. )
  358. decrypted_credentials = encrypter.decrypt(credentials)
  359. # check if the credential has changed, save the original credential
  360. masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
  361. for name, value in credentials.items():
  362. if name in masked_credentials and value == masked_credentials[name]:
  363. credentials[name] = decrypted_credentials[name]
  364. try:
  365. provider_controller.validate_credentials_format(credentials)
  366. # get tool
  367. tool = provider_controller.get_tool(tool_name)
  368. tool = tool.fork_tool_runtime(
  369. runtime=ToolRuntime(
  370. credentials=credentials,
  371. tenant_id=tenant_id,
  372. )
  373. )
  374. result = tool.validate_credentials(credentials, parameters)
  375. except Exception as e:
  376. return {"error": str(e)}
  377. return {"result": result or "empty response"}
  378. @staticmethod
  379. def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
  380. """
  381. list api tools
  382. """
  383. # get all api providers
  384. db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
  385. result: list[ToolProviderApiEntity] = []
  386. for provider in db_providers:
  387. # convert provider controller to user provider
  388. provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
  389. labels = ToolLabelManager.get_tool_labels(provider_controller)
  390. user_provider = ToolTransformService.api_provider_to_user_provider(
  391. provider_controller, db_provider=provider, decrypt_credentials=True
  392. )
  393. user_provider.labels = labels
  394. # add icon
  395. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
  396. tools = provider_controller.get_tools(tenant_id=tenant_id)
  397. for tool in tools or []:
  398. user_provider.tools.append(
  399. ToolTransformService.convert_tool_entity_to_api_entity(
  400. tenant_id=tenant_id, tool=tool, labels=labels
  401. )
  402. )
  403. result.append(user_provider)
  404. return result