You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. import json
  2. from datetime import datetime
  3. from typing import Any, cast
  4. from urllib.parse import urlparse
  5. import sqlalchemy as sa
  6. from deprecated import deprecated
  7. from sqlalchemy import ForeignKey, func
  8. from sqlalchemy.orm import Mapped, mapped_column
  9. from core.file import helpers as file_helpers
  10. from core.helper import encrypter
  11. from core.mcp.types import Tool
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_bundle import ApiToolBundle
  14. from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
  15. from models.base import Base
  16. from .engine import db
  17. from .model import Account, App, Tenant
  18. from .types import StringUUID
  19. # system level tool oauth client params (client_id, client_secret, etc.)
  20. class ToolOAuthSystemClient(Base):
  21. __tablename__ = "tool_oauth_system_clients"
  22. __table_args__ = (
  23. db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
  24. db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
  25. )
  26. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  27. plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
  28. provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
  29. # oauth params of the tool provider
  30. encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
  31. # tenant level tool oauth client params (client_id, client_secret, etc.)
  32. class ToolOAuthTenantClient(Base):
  33. __tablename__ = "tool_oauth_tenant_clients"
  34. __table_args__ = (
  35. db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
  36. db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
  37. )
  38. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  39. # tenant id
  40. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  41. plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
  42. provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
  43. enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
  44. # oauth params of the tool provider
  45. encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
  46. @property
  47. def oauth_params(self) -> dict:
  48. return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
  49. class BuiltinToolProvider(Base):
  50. """
  51. This table stores the tool provider information for built-in tools for each tenant.
  52. """
  53. __tablename__ = "tool_builtin_providers"
  54. __table_args__ = (
  55. db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
  56. db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
  57. )
  58. # id of the tool provider
  59. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  60. name: Mapped[str] = mapped_column(
  61. db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
  62. )
  63. # id of the tenant
  64. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
  65. # who created this tool provider
  66. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  67. # name of the tool provider
  68. provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
  69. # credential of the tool provider
  70. encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
  71. created_at: Mapped[datetime] = mapped_column(
  72. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  73. )
  74. updated_at: Mapped[datetime] = mapped_column(
  75. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  76. )
  77. is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
  78. # credential type, e.g., "api-key", "oauth2"
  79. credential_type: Mapped[str] = mapped_column(
  80. db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
  81. )
  82. expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
  83. @property
  84. def credentials(self) -> dict:
  85. return cast(dict, json.loads(self.encrypted_credentials))
  86. class ApiToolProvider(Base):
  87. """
  88. The table stores the api providers.
  89. """
  90. __tablename__ = "tool_api_providers"
  91. __table_args__ = (
  92. db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
  93. db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
  94. )
  95. id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  96. # name of the api provider
  97. name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
  98. # icon
  99. icon = mapped_column(db.String(255), nullable=False)
  100. # original schema
  101. schema = mapped_column(db.Text, nullable=False)
  102. schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False)
  103. # who created this tool
  104. user_id = mapped_column(StringUUID, nullable=False)
  105. # tenant id
  106. tenant_id = mapped_column(StringUUID, nullable=False)
  107. # description of the provider
  108. description = mapped_column(db.Text, nullable=False)
  109. # json format tools
  110. tools_str = mapped_column(db.Text, nullable=False)
  111. # json format credentials
  112. credentials_str = mapped_column(db.Text, nullable=False)
  113. # privacy policy
  114. privacy_policy = mapped_column(db.String(255), nullable=True)
  115. # custom_disclaimer
  116. custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
  117. created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  118. updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  119. @property
  120. def schema_type(self) -> ApiProviderSchemaType:
  121. return ApiProviderSchemaType.value_of(self.schema_type_str)
  122. @property
  123. def tools(self) -> list[ApiToolBundle]:
  124. return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
  125. @property
  126. def credentials(self) -> dict:
  127. return dict(json.loads(self.credentials_str))
  128. @property
  129. def user(self) -> Account | None:
  130. if not self.user_id:
  131. return None
  132. return db.session.query(Account).filter(Account.id == self.user_id).first()
  133. @property
  134. def tenant(self) -> Tenant | None:
  135. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  136. class ToolLabelBinding(Base):
  137. """
  138. The table stores the labels for tools.
  139. """
  140. __tablename__ = "tool_label_bindings"
  141. __table_args__ = (
  142. db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
  143. db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
  144. )
  145. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  146. # tool id
  147. tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
  148. # tool type
  149. tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
  150. # label name
  151. label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
  152. class WorkflowToolProvider(Base):
  153. """
  154. The table stores the workflow providers.
  155. """
  156. __tablename__ = "tool_workflow_providers"
  157. __table_args__ = (
  158. db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
  159. db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
  160. db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
  161. )
  162. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  163. # name of the workflow provider
  164. name: Mapped[str] = mapped_column(db.String(255), nullable=False)
  165. # label of the workflow provider
  166. label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
  167. # icon
  168. icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
  169. # app id of the workflow provider
  170. app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  171. # version of the workflow provider
  172. version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
  173. # who created this tool
  174. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  175. # tenant id
  176. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  177. # description of the provider
  178. description: Mapped[str] = mapped_column(db.Text, nullable=False)
  179. # parameter configuration
  180. parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
  181. # privacy policy
  182. privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="")
  183. created_at: Mapped[datetime] = mapped_column(
  184. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  185. )
  186. updated_at: Mapped[datetime] = mapped_column(
  187. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  188. )
  189. @property
  190. def user(self) -> Account | None:
  191. return db.session.query(Account).filter(Account.id == self.user_id).first()
  192. @property
  193. def tenant(self) -> Tenant | None:
  194. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  195. @property
  196. def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
  197. return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
  198. @property
  199. def app(self) -> App | None:
  200. return db.session.query(App).filter(App.id == self.app_id).first()
  201. class MCPToolProvider(Base):
  202. """
  203. The table stores the mcp providers.
  204. """
  205. __tablename__ = "tool_mcp_providers"
  206. __table_args__ = (
  207. db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
  208. db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
  209. db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
  210. db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
  211. )
  212. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  213. # name of the mcp provider
  214. name: Mapped[str] = mapped_column(db.String(40), nullable=False)
  215. # server identifier of the mcp provider
  216. server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
  217. # encrypted url of the mcp provider
  218. server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
  219. # hash of server_url for uniqueness check
  220. server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
  221. # icon of the mcp provider
  222. icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
  223. # tenant id
  224. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  225. # who created this tool
  226. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  227. # encrypted credentials
  228. encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
  229. # authed
  230. authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
  231. # tools
  232. tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
  233. created_at: Mapped[datetime] = mapped_column(
  234. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  235. )
  236. updated_at: Mapped[datetime] = mapped_column(
  237. db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
  238. )
  239. def load_user(self) -> Account | None:
  240. return db.session.query(Account).filter(Account.id == self.user_id).first()
  241. @property
  242. def tenant(self) -> Tenant | None:
  243. return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  244. @property
  245. def credentials(self) -> dict:
  246. try:
  247. return cast(dict, json.loads(self.encrypted_credentials)) or {}
  248. except Exception:
  249. return {}
  250. @property
  251. def mcp_tools(self) -> list[Tool]:
  252. return [Tool(**tool) for tool in json.loads(self.tools)]
  253. @property
  254. def provider_icon(self) -> dict[str, str] | str:
  255. try:
  256. return cast(dict[str, str], json.loads(self.icon))
  257. except json.JSONDecodeError:
  258. return file_helpers.get_signed_file_url(self.icon)
  259. @property
  260. def decrypted_server_url(self) -> str:
  261. return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
  262. @property
  263. def masked_server_url(self) -> str:
  264. def mask_url(url: str, mask_char: str = "*") -> str:
  265. """
  266. mask the url to a simple string
  267. """
  268. parsed = urlparse(url)
  269. base_url = f"{parsed.scheme}://{parsed.netloc}"
  270. if parsed.path and parsed.path != "/":
  271. return f"{base_url}/{mask_char * 6}"
  272. else:
  273. return base_url
  274. return mask_url(self.decrypted_server_url)
  275. @property
  276. def decrypted_credentials(self) -> dict:
  277. from core.helper.provider_cache import NoOpProviderCredentialCache
  278. from core.tools.mcp_tool.provider import MCPToolProviderController
  279. from core.tools.utils.encryption import create_provider_encrypter
  280. provider_controller = MCPToolProviderController._from_db(self)
  281. encrypter, _ = create_provider_encrypter(
  282. tenant_id=self.tenant_id,
  283. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  284. cache=NoOpProviderCredentialCache(),
  285. )
  286. return encrypter.decrypt(self.credentials) # type: ignore
  287. class ToolModelInvoke(Base):
  288. """
  289. store the invoke logs from tool invoke
  290. """
  291. __tablename__ = "tool_model_invokes"
  292. __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
  293. id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  294. # who invoke this tool
  295. user_id = mapped_column(StringUUID, nullable=False)
  296. # tenant id
  297. tenant_id = mapped_column(StringUUID, nullable=False)
  298. # provider
  299. provider = mapped_column(db.String(255), nullable=False)
  300. # type
  301. tool_type = mapped_column(db.String(40), nullable=False)
  302. # tool name
  303. tool_name = mapped_column(db.String(128), nullable=False)
  304. # invoke parameters
  305. model_parameters = mapped_column(db.Text, nullable=False)
  306. # prompt messages
  307. prompt_messages = mapped_column(db.Text, nullable=False)
  308. # invoke response
  309. model_response = mapped_column(db.Text, nullable=False)
  310. prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
  311. answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
  312. answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
  313. answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
  314. provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
  315. total_price = mapped_column(db.Numeric(10, 7))
  316. currency = mapped_column(db.String(255), nullable=False)
  317. created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  318. updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  319. @deprecated
  320. class ToolConversationVariables(Base):
  321. """
  322. store the conversation variables from tool invoke
  323. """
  324. __tablename__ = "tool_conversation_variables"
  325. __table_args__ = (
  326. db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
  327. # add index for user_id and conversation_id
  328. db.Index("user_id_idx", "user_id"),
  329. db.Index("conversation_id_idx", "conversation_id"),
  330. )
  331. id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  332. # conversation user id
  333. user_id = mapped_column(StringUUID, nullable=False)
  334. # tenant id
  335. tenant_id = mapped_column(StringUUID, nullable=False)
  336. # conversation id
  337. conversation_id = mapped_column(StringUUID, nullable=False)
  338. # variables pool
  339. variables_str = mapped_column(db.Text, nullable=False)
  340. created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  341. updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  342. @property
  343. def variables(self) -> Any:
  344. return json.loads(self.variables_str)
  345. class ToolFile(Base):
  346. """This table stores file metadata generated in workflows,
  347. not only files created by agent.
  348. """
  349. __tablename__ = "tool_files"
  350. __table_args__ = (
  351. db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
  352. db.Index("tool_file_conversation_id_idx", "conversation_id"),
  353. )
  354. id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  355. # conversation user id
  356. user_id: Mapped[str] = mapped_column(StringUUID)
  357. # tenant id
  358. tenant_id: Mapped[str] = mapped_column(StringUUID)
  359. # conversation id
  360. conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
  361. # file key
  362. file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
  363. # mime type
  364. mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
  365. # original url
  366. original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
  367. # name
  368. name: Mapped[str] = mapped_column(default="")
  369. # size
  370. size: Mapped[int] = mapped_column(default=-1)
  371. @deprecated
  372. class DeprecatedPublishedAppTool(Base):
  373. """
  374. The table stores the apps published as a tool for each person.
  375. """
  376. __tablename__ = "tool_published_apps"
  377. __table_args__ = (
  378. db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
  379. db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
  380. )
  381. id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
  382. # id of the app
  383. app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
  384. user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  385. # who published this tool
  386. description = mapped_column(db.Text, nullable=False)
  387. # llm_description of the tool, for LLM
  388. llm_description = mapped_column(db.Text, nullable=False)
  389. # query description, query will be seem as a parameter of the tool,
  390. # to describe this parameter to llm, we need this field
  391. query_description = mapped_column(db.Text, nullable=False)
  392. # query name, the name of the query parameter
  393. query_name = mapped_column(db.String(40), nullable=False)
  394. # name of the tool provider
  395. tool_name = mapped_column(db.String(40), nullable=False)
  396. # author
  397. author = mapped_column(db.String(40), nullable=False)
  398. created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
  399. updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
  400. @property
  401. def description_i18n(self) -> I18nObject:
  402. return I18nObject(**json.loads(self.description))