You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 19KB


  1. import json
  2. from datetime import datetime
  3. from typing import Any, cast
  4. from urllib.parse import urlparse
  5. import sqlalchemy as sa
  6. from deprecated import deprecated
  7. from sqlalchemy import ForeignKey, func
  8. from sqlalchemy.orm import Mapped, mapped_column
  9. from core.file import helpers as file_helpers
  10. from core.helper import encrypter
  11. from core.mcp.types import Tool
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_bundle import ApiToolBundle
  14. from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
  15. from models.base import Base
  16. from .engine import db
  17. from .model import Account, App, Tenant
  18. from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
    """
    System-level OAuth client parameters for a tool provider.

    At most one row exists per (plugin_id, provider) pair; this is enforced by the
    unique constraint in ``__table_args__``.
    """

    __tablename__ = "tool_oauth_system_clients"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
        db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
    )

    # primary key, generated by the database via uuid_generate_v4()
    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    # plugin identifier; unique together with `provider`
    plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
    # provider name; unique together with `plugin_id`
    provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
    # oauth params of the tool provider, stored as text
    # NOTE(review): the column name says "encrypted"; the encryption scheme is not visible here — confirm with the writer
    encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
  31. # tenant level tool oauth client params (client_id, client_secret, etc.)
  32. class ToolOAuthTenantClient(Base):
  33. __tablename__ = "tool_oauth_tenant_clients"
  34. __table_args__ = (
  35. db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
  36. db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
  37. )
  38. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  39. # tenant id
  40. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  41. plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
  42. provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
  43. enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
  44. # oauth params of the tool provider
  45. encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
  46. @property
  47. def oauth_params(self) -> dict:
  48. return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
  49. class BuiltinToolProvider(Base):
  50. """
  51. This table stores the tool provider information for built-in tools for each tenant.
  52. """
  53. __tablename__ = "tool_builtin_providers"
  54. __table_args__ = (
  55. db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
  56. db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
  57. )
  58. # id of the tool provider
  59. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  60. name: Mapped[str] = mapped_column(
  61. db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
  62. )
  63. # id of the tenant
  64. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
  65. # who created this tool provider
  66. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  67. # name of the tool provider
  68. provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
  69. # credential of the tool provider
  70. encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
  71. created_at: Mapped[datetime] = mapped_column(
  72. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  73. )
  74. updated_at: Mapped[datetime] = mapped_column(
  75. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  76. )
  77. is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
  78. # credential type, e.g., "api-key", "oauth2"
  79. credential_type: Mapped[str] = mapped_column(
  80. db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
  81. )
  82. @property
  83. def credentials(self) -> dict:
  84. return cast(dict, json.loads(self.encrypted_credentials))
  85. class ApiToolProvider(Base):
  86. """
  87. The table stores the api providers.
  88. """
  89. __tablename__ = "tool_api_providers"
  90. __table_args__ = (
  91. db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
  92. db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
  93. )
  94. id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  95. # name of the api provider
  96. name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
  97. # icon
  98. icon = db.Column(db.String(255), nullable=False)
  99. # original schema
  100. schema = db.Column(db.Text, nullable=False)
  101. schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False)
  102. # who created this tool
  103. user_id = db.Column(StringUUID, nullable=False)
  104. # tenant id
  105. tenant_id = db.Column(StringUUID, nullable=False)
  106. # description of the provider
  107. description = db.Column(db.Text, nullable=False)
  108. # json format tools
  109. tools_str = db.Column(db.Text, nullable=False)
  110. # json format credentials
  111. credentials_str = db.Column(db.Text, nullable=False)
  112. # privacy policy
  113. privacy_policy = db.Column(db.String(255), nullable=True)
  114. # custom_disclaimer
  115. custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
  116. created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  117. updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  118. @property
  119. def schema_type(self) -> ApiProviderSchemaType:
  120. return ApiProviderSchemaType.value_of(self.schema_type_str)
  121. @property
  122. def tools(self) -> list[ApiToolBundle]:
  123. return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
  124. @property
  125. def credentials(self) -> dict:
  126. return dict(json.loads(self.credentials_str))
  127. @property
  128. def user(self) -> Account | None:
  129. if not self.user_id:
  130. return None
  131. return db.session.query(Account).filter(Account.id == self.user_id).first()
  132. @property
  133. def tenant(self) -> Tenant | None:
  134. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(Base):
    """
    The table stores the labels for tools.

    Each row binds one label name to one tool; a (tool_id, label_name) pair may
    appear only once.
    """

    __tablename__ = "tool_label_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
        # NOTE(review): uniqueness ignores tool_type; confirm tool_id values cannot collide across tool types
        db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
    )

    # primary key, generated by the database
    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    # tool id (free-form string up to 64 chars, not a UUID column)
    tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
    # tool type
    tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
    # label name
    label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
  151. class WorkflowToolProvider(Base):
  152. """
  153. The table stores the workflow providers.
  154. """
  155. __tablename__ = "tool_workflow_providers"
  156. __table_args__ = (
  157. db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
  158. db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
  159. db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
  160. )
  161. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  162. # name of the workflow provider
  163. name: Mapped[str] = mapped_column(db.String(255), nullable=False)
  164. # label of the workflow provider
  165. label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
  166. # icon
  167. icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
  168. # app id of the workflow provider
  169. app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  170. # version of the workflow provider
  171. version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
  172. # who created this tool
  173. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  174. # tenant id
  175. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  176. # description of the provider
  177. description: Mapped[str] = mapped_column(db.Text, nullable=False)
  178. # parameter configuration
  179. parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
  180. # privacy policy
  181. privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="")
  182. created_at: Mapped[datetime] = mapped_column(
  183. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  184. )
  185. updated_at: Mapped[datetime] = mapped_column(
  186. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  187. )
  188. @property
  189. def user(self) -> Account | None:
  190. return db.session.query(Account).filter(Account.id == self.user_id).first()
  191. @property
  192. def tenant(self) -> Tenant | None:
  193. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  194. @property
  195. def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
  196. return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
  197. @property
  198. def app(self) -> App | None:
  199. return db.session.query(App).filter(App.id == self.app_id).first()
  200. class MCPToolProvider(Base):
  201. """
  202. The table stores the mcp providers.
  203. """
  204. __tablename__ = "tool_mcp_providers"
  205. __table_args__ = (
  206. db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
  207. db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
  208. db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
  209. db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
  210. )
  211. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  212. # name of the mcp provider
  213. name: Mapped[str] = mapped_column(db.String(40), nullable=False)
  214. # server identifier of the mcp provider
  215. server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
  216. # encrypted url of the mcp provider
  217. server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
  218. # hash of server_url for uniqueness check
  219. server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
  220. # icon of the mcp provider
  221. icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
  222. # tenant id
  223. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  224. # who created this tool
  225. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  226. # encrypted credentials
  227. encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
  228. # authed
  229. authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
  230. # tools
  231. tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
  232. created_at: Mapped[datetime] = mapped_column(
  233. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  234. )
  235. updated_at: Mapped[datetime] = mapped_column(
  236. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  237. )
  238. def load_user(self) -> Account | None:
  239. return db.session.query(Account).filter(Account.id == self.user_id).first()
  240. @property
  241. def tenant(self) -> Tenant | None:
  242. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  243. @property
  244. def credentials(self) -> dict:
  245. try:
  246. return cast(dict, json.loads(self.encrypted_credentials)) or {}
  247. except Exception:
  248. return {}
  249. @property
  250. def mcp_tools(self) -> list[Tool]:
  251. return [Tool(**tool) for tool in json.loads(self.tools)]
  252. @property
  253. def provider_icon(self) -> dict[str, str] | str:
  254. try:
  255. return cast(dict[str, str], json.loads(self.icon))
  256. except json.JSONDecodeError:
  257. return file_helpers.get_signed_file_url(self.icon)
  258. @property
  259. def decrypted_server_url(self) -> str:
  260. return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
  261. @property
  262. def masked_server_url(self) -> str:
  263. def mask_url(url: str, mask_char: str = "*") -> str:
  264. """
  265. mask the url to a simple string
  266. """
  267. parsed = urlparse(url)
  268. base_url = f"{parsed.scheme}://{parsed.netloc}"
  269. if parsed.path and parsed.path != "/":
  270. return f"{base_url}/{mask_char * 6}"
  271. else:
  272. return base_url
  273. return mask_url(self.decrypted_server_url)
  274. @property
  275. def decrypted_credentials(self) -> dict:
  276. from core.helper.provider_cache import NoOpProviderCredentialCache
  277. from core.tools.mcp_tool.provider import MCPToolProviderController
  278. from core.tools.utils.encryption import create_provider_encrypter
  279. provider_controller = MCPToolProviderController._from_db(self)
  280. encrypter, _ = create_provider_encrypter(
  281. tenant_id=self.tenant_id,
  282. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  283. cache=NoOpProviderCredentialCache(),
  284. )
  285. return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base):
    """
    store the invoke logs from tool invoke

    One row per logged model invocation: request payload, response, token counts
    and pricing columns.
    """

    __tablename__ = "tool_model_invokes"
    __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    # who invoke this tool
    user_id = db.Column(StringUUID, nullable=False)
    # tenant id
    tenant_id = db.Column(StringUUID, nullable=False)
    # provider
    provider = db.Column(db.String(255), nullable=False)
    # type
    tool_type = db.Column(db.String(40), nullable=False)
    # tool name
    tool_name = db.Column(db.String(128), nullable=False)
    # invoke parameters, stored as text (presumably JSON — verify against the writer)
    model_parameters = db.Column(db.Text, nullable=False)
    # prompt messages, stored as text
    prompt_messages = db.Column(db.Text, nullable=False)
    # invoke response, stored as text
    model_response = db.Column(db.Text, nullable=False)
    # token counts; both default to 0 in the database
    prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    # NOTE(review): exact pricing semantics of unit price vs. price unit are not visible here — confirm with the writer
    answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
    # latency of the provider call; unit not stated here (presumably seconds — verify)
    provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
    # nullable: no nullable=False given
    total_price = db.Column(db.Numeric(10, 7))
    currency = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  318. @deprecated
  319. class ToolConversationVariables(Base):
  320. """
  321. store the conversation variables from tool invoke
  322. """
  323. __tablename__ = "tool_conversation_variables"
  324. __table_args__ = (
  325. db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
  326. # add index for user_id and conversation_id
  327. db.Index("user_id_idx", "user_id"),
  328. db.Index("conversation_id_idx", "conversation_id"),
  329. )
  330. id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  331. # conversation user id
  332. user_id = db.Column(StringUUID, nullable=False)
  333. # tenant id
  334. tenant_id = db.Column(StringUUID, nullable=False)
  335. # conversation id
  336. conversation_id = db.Column(StringUUID, nullable=False)
  337. # variables pool
  338. variables_str = db.Column(db.Text, nullable=False)
  339. created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  340. updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  341. @property
  342. def variables(self) -> Any:
  343. return json.loads(self.variables_str)
class ToolFile(Base):
    """This table stores file metadata generated in workflows,
    not only files created by agent.

    Rows are indexed by conversation_id for per-conversation lookups.
    """

    __tablename__ = "tool_files"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
        db.Index("tool_file_conversation_id_idx", "conversation_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    # conversation user id; no explicit nullable, so NOT NULL is derived from the Mapped[str] annotation
    user_id: Mapped[str] = mapped_column(StringUUID)
    # tenant id; NOT NULL derived from the Mapped[str] annotation
    tenant_id: Mapped[str] = mapped_column(StringUUID)
    # conversation id; may be NULL
    conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
    # file key
    file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
    # mime type
    mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
    # original url; may be NULL
    original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
    # name; column type comes from the Mapped[str] annotation, default "" is applied client-side
    name: Mapped[str] = mapped_column(default="")
    # size; client-side default -1 (presumably "unknown size" — verify with writers)
    size: Mapped[int] = mapped_column(default=-1)
  370. @deprecated
  371. class DeprecatedPublishedAppTool(Base):
  372. """
  373. The table stores the apps published as a tool for each person.
  374. """
  375. __tablename__ = "tool_published_apps"
  376. __table_args__ = (
  377. db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
  378. db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
  379. )
  380. id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  381. # id of the app
  382. app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
  383. user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
  384. # who published this tool
  385. description = db.Column(db.Text, nullable=False)
  386. # llm_description of the tool, for LLM
  387. llm_description = db.Column(db.Text, nullable=False)
  388. # query description, query will be seem as a parameter of the tool,
  389. # to describe this parameter to llm, we need this field
  390. query_description = db.Column(db.Text, nullable=False)
  391. # query name, the name of the query parameter
  392. query_name = db.Column(db.String(40), nullable=False)
  393. # name of the tool provider
  394. tool_name = db.Column(db.String(40), nullable=False)
  395. # author
  396. author = db.Column(db.String(40), nullable=False)
  397. created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
  398. updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
  399. @property
  400. def description_i18n(self) -> I18nObject:
  401. return I18nObject(**json.loads(self.description))