You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import json
  2. from datetime import datetime
  3. from typing import TYPE_CHECKING, Any, cast
  4. from urllib.parse import urlparse
  5. import sqlalchemy as sa
  6. from deprecated import deprecated
  7. from sqlalchemy import ForeignKey, String, func
  8. from sqlalchemy.orm import Mapped, mapped_column
  9. from core.helper import encrypter
  10. from models.base import Base
  11. from .engine import db
  12. from .model import Account, App, Tenant
  13. from .types import StringUUID
  14. if TYPE_CHECKING:
  15. from core.mcp.types import Tool as MCPTool
  16. from core.tools.entities.common_entities import I18nObject
  17. from core.tools.entities.tool_bundle import ApiToolBundle
  18. from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
  # System-level tool OAuth client params (client_id, client_secret, etc.).
  # Not tenant-scoped: the table has no tenant_id column and rows are unique
  # per (plugin_id, provider).
  class ToolOAuthSystemClient(Base):
      """OAuth client configuration for a plugin tool provider at system level."""

      __tablename__ = "tool_oauth_system_clients"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
          sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # Annotated like every other string column in this module (was untyped);
      # type and nullability are explicit, so the mapped column is unchanged.
      plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
      provider: Mapped[str] = mapped_column(String(255), nullable=False)
      # oauth params of the tool provider
      encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
  # Tenant-level tool OAuth client params (client_id, client_secret, etc.).
  class ToolOAuthTenantClient(Base):
      """Per-tenant OAuth client configuration for a plugin tool provider.

      At most one row exists per (tenant_id, plugin_id, provider).
      """

      __tablename__ = "tool_oauth_tenant_clients"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
          sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # owning tenant
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
      provider: Mapped[str] = mapped_column(String(255), nullable=False)
      # database default is true
      enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
      # oauth params of the tool provider, stored as a JSON string
      encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)

      @property
      def oauth_params(self) -> dict:
          """Return ``encrypted_oauth_params`` parsed as JSON.

          An empty value is treated as ``{}``. The string is parsed as stored;
          no decryption happens in this property.
          """
          payload = self.encrypted_oauth_params or "{}"
          return cast(dict, json.loads(payload))
  class BuiltinToolProvider(Base):
      """
      This table stores the tool provider information for built-in tools for each tenant.

      A tenant may hold several credential sets per provider; each is
      distinguished by ``name`` (unique per tenant_id + provider).
      """

      __tablename__ = "tool_builtin_providers"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
          sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
      )

      # id of the tool provider
      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # display name of this credential set
      name: Mapped[str] = mapped_column(
          String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
      )
      # id of the tenant (column is nullable)
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
      # who created this tool provider
      user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # name of the tool provider
      provider: Mapped[str] = mapped_column(String(256), nullable=False)
      # credential of the tool provider, stored as a JSON string (column is nullable)
      encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
      created_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )
      updated_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )
      is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
      # credential type, e.g., "api-key", "oauth2"
      credential_type: Mapped[str] = mapped_column(
          String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
      )
      # database default is -1
      expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))

      @property
      def credentials(self) -> dict:
          """Return ``encrypted_credentials`` parsed as JSON.

          The column is nullable. A NULL or empty value yields ``{}`` instead of
          raising ``TypeError`` from ``json.loads(None)``. This matches
          ``ToolOAuthTenantClient.oauth_params``.
          """
          return cast(dict, json.loads(self.encrypted_credentials or "{}"))
  class ApiToolProvider(Base):
      """
      The table stores the api providers.

      Provider names are unique per tenant. Tool definitions and credentials
      are kept as JSON text and decoded on access by the properties below.
      """

      __tablename__ = "tool_api_providers"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
          sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
      )

      id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # name of the api provider
      # NOTE(review): the 'API KEY 1' server default looks copied from
      # BuiltinToolProvider.name — confirm it is intended here.
      name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
      # icon
      icon: Mapped[str] = mapped_column(String(255), nullable=False)
      # original schema text
      schema = mapped_column(sa.Text, nullable=False)
      # resolved to an enum by `schema_type`
      schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
      # who created this tool
      user_id = mapped_column(StringUUID, nullable=False)
      # owning tenant
      tenant_id = mapped_column(StringUUID, nullable=False)
      # description of the provider
      description = mapped_column(sa.Text, nullable=False)
      # JSON list of tool bundles, decoded by `tools`
      tools_str = mapped_column(sa.Text, nullable=False)
      # JSON object of credentials, decoded by `credentials`
      credentials_str = mapped_column(sa.Text, nullable=False)
      # privacy policy
      privacy_policy = mapped_column(String(255), nullable=True)
      # custom disclaimer text
      custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
      created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
      updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

      @property
      def schema_type(self) -> "ApiProviderSchemaType":
          """Map ``schema_type_str`` onto its ``ApiProviderSchemaType`` member."""
          # Imported lazily; the module-level import exists only under TYPE_CHECKING.
          from core.tools.entities.tool_entities import ApiProviderSchemaType

          raw_type = self.schema_type_str
          return ApiProviderSchemaType.value_of(raw_type)

      @property
      def tools(self) -> list["ApiToolBundle"]:
          """Decode ``tools_str`` into a list of ``ApiToolBundle`` objects."""
          from core.tools.entities.tool_bundle import ApiToolBundle

          bundle_dicts = json.loads(self.tools_str)
          return [ApiToolBundle(**bundle) for bundle in bundle_dicts]

      @property
      def credentials(self) -> dict:
          """Decode ``credentials_str`` and return it as a plain dict."""
          decoded = json.loads(self.credentials_str)
          return dict(decoded)

      @property
      def user(self) -> Account | None:
          """Look up the creating Account; None when ``user_id`` is empty or unmatched."""
          if self.user_id:
              return db.session.query(Account).where(Account.id == self.user_id).first()
          return None

      @property
      def tenant(self) -> Tenant | None:
          """Look up the owning Tenant, or None if no row matches."""
          tenant_query = db.session.query(Tenant).where(Tenant.id == self.tenant_id)
          return tenant_query.first()
  class ToolLabelBinding(Base):
      """
      The table stores the labels for tools.

      Each (tool_id, label_name) pair may be bound only once.
      """

      __tablename__ = "tool_label_bindings"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
          sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # identifier of the labelled tool (up to 64 chars)
      tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
      # kind of tool the id refers to
      tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
      # the label attached to the tool
      label_name: Mapped[str] = mapped_column(String(40), nullable=False)
  class WorkflowToolProvider(Base):
      """
      The table stores the workflow providers.

      Within a tenant, both the provider name and the backing app_id are unique.
      """

      __tablename__ = "tool_workflow_providers"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
          sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
          sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # name of the workflow provider
      name: Mapped[str] = mapped_column(String(255), nullable=False)
      # label of the workflow provider
      label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
      # icon
      icon: Mapped[str] = mapped_column(String(255), nullable=False)
      # app backing this workflow provider (see `app`)
      app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # version of the workflow provider
      version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
      # who created this tool
      user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # owning tenant
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # description of the provider
      description: Mapped[str] = mapped_column(sa.Text, nullable=False)
      # JSON list of parameter configurations, decoded by `parameter_configurations`
      parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
      # privacy policy
      privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
      created_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )
      updated_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )

      @property
      def user(self) -> Account | None:
          """Look up the creating Account, or None if no row matches."""
          account_query = db.session.query(Account).where(Account.id == self.user_id)
          return account_query.first()

      @property
      def tenant(self) -> Tenant | None:
          """Look up the owning Tenant, or None if no row matches."""
          tenant_query = db.session.query(Tenant).where(Tenant.id == self.tenant_id)
          return tenant_query.first()

      @property
      def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
          """Decode ``parameter_configuration`` into configuration objects."""
          from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration

          raw_configs = json.loads(self.parameter_configuration)
          return [WorkflowToolParameterConfiguration(**cfg) for cfg in raw_configs]

      @property
      def app(self) -> App | None:
          """Look up the backing App, or None if no row matches."""
          app_query = db.session.query(App).where(App.id == self.app_id)
          return app_query.first()
  class MCPToolProvider(Base):
      """
      The table stores the mcp providers.

      Within a tenant, each of ``server_url_hash``, ``name`` and
      ``server_identifier`` is unique.
      """

      __tablename__ = "tool_mcp_providers"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
          sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
          sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
          sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # name of the mcp provider
      name: Mapped[str] = mapped_column(String(40), nullable=False)
      # server identifier of the mcp provider
      server_identifier: Mapped[str] = mapped_column(String(64), nullable=False)
      # encrypted url of the mcp provider (see `decrypted_server_url`)
      server_url: Mapped[str] = mapped_column(sa.Text, nullable=False)
      # hash of server_url for uniqueness check
      server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
      # icon of the mcp provider: JSON object or file reference (see `provider_icon`)
      icon: Mapped[str] = mapped_column(String(255), nullable=True)
      # owning tenant
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # who created this tool
      user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # encrypted credentials, stored as a JSON string (column is nullable)
      encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
      # authed
      authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
      # JSON list of tool definitions (see `mcp_tools`)
      tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
      created_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )
      updated_at: Mapped[datetime] = mapped_column(
          sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
      )
      timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
      sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))

      def load_user(self) -> Account | None:
          """Fetch the creating Account, or None if no row matches."""
          return db.session.query(Account).where(Account.id == self.user_id).first()

      @property
      def tenant(self) -> Tenant | None:
          """Look up the owning Tenant, or None if no row matches."""
          return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()

      @property
      def credentials(self) -> dict:
          """Return ``encrypted_credentials`` parsed as JSON, or ``{}``.

          Best-effort: a NULL column, malformed JSON, or JSON that is not an
          object all yield ``{}``. The string is parsed as stored; decryption
          is done by ``decrypted_credentials``.
          """
          try:
              parsed = json.loads(self.encrypted_credentials)
          except (TypeError, ValueError):
              # TypeError: column is NULL; ValueError (incl. JSONDecodeError): malformed JSON.
              return {}
          # Guard the declared return type: a JSON list/scalar is not a credential dict.
          return parsed if isinstance(parsed, dict) else {}

      @property
      def mcp_tools(self) -> list["MCPTool"]:
          """Decode the ``tools`` column into ``MCPTool`` objects."""
          from core.mcp.types import Tool as MCPTool

          return [MCPTool(**tool) for tool in json.loads(self.tools)]

      @property
      def provider_icon(self) -> dict[str, str] | str:
          """Return the icon as a parsed JSON dict, or a signed file URL when it is not JSON."""
          from core.file import helpers as file_helpers

          # NOTE(review): `icon` is nullable and json.loads(None) raises TypeError,
          # which is not caught here — confirm a NULL icon cannot reach this property.
          try:
              return cast(dict[str, str], json.loads(self.icon))
          except json.JSONDecodeError:
              return file_helpers.get_signed_file_url(self.icon)

      @property
      def decrypted_server_url(self) -> str:
          """Decrypt ``server_url`` with the tenant's key."""
          return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))

      @property
      def masked_server_url(self) -> str:
          """Return the decrypted URL as ``scheme://netloc``.

          A non-root path is replaced by ``/******``. Query and fragment are
          never included.
          """

          def mask_url(url: str, mask_char: str = "*") -> str:
              """
              mask the url to a simple string
              """
              parsed = urlparse(url)
              base_url = f"{parsed.scheme}://{parsed.netloc}"
              if parsed.path and parsed.path != "/":
                  return f"{base_url}/{mask_char * 6}"
              else:
                  return base_url

          return mask_url(self.decrypted_server_url)

      @property
      def decrypted_credentials(self) -> dict:
          """Decrypt ``credentials`` using this provider's credential schema."""
          from core.helper.provider_cache import NoOpProviderCredentialCache
          from core.tools.mcp_tool.provider import MCPToolProviderController
          from core.tools.utils.encryption import create_provider_encrypter

          provider_controller = MCPToolProviderController._from_db(self)
          # Named so it no longer shadows the module-level `encrypter` helper
          # that `decrypted_server_url` relies on.
          credential_encrypter, _ = create_provider_encrypter(
              tenant_id=self.tenant_id,
              config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
              cache=NoOpProviderCredentialCache(),
          )
          return credential_encrypter.decrypt(self.credentials)  # type: ignore
  class ToolModelInvoke(Base):
      """
      Log of model invocations made from tool invokes: inputs, response, token
      counts and pricing.
      """

      __tablename__ = "tool_model_invokes"
      __table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)

      id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # who invoked this tool
      user_id = mapped_column(StringUUID, nullable=False)
      # owning tenant
      tenant_id = mapped_column(StringUUID, nullable=False)
      # provider
      provider: Mapped[str] = mapped_column(String(255), nullable=False)
      # tool type
      tool_type = mapped_column(String(40), nullable=False)
      # tool name
      tool_name = mapped_column(String(128), nullable=False)
      # invocation parameters
      model_parameters = mapped_column(sa.Text, nullable=False)
      # prompt messages
      prompt_messages = mapped_column(sa.Text, nullable=False)
      # invocation response
      model_response = mapped_column(sa.Text, nullable=False)

      # usage and pricing
      prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
      answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
      answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
      answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
      provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
      total_price = mapped_column(sa.Numeric(10, 7))
      currency: Mapped[str] = mapped_column(String(255), nullable=False)

      created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
      updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
  @deprecated
  class ToolConversationVariables(Base):
      """
      Deprecated: conversation variables captured from tool invokes, stored as
      one JSON text blob per row.
      """

      __tablename__ = "tool_conversation_variables"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
          # lookups by user and by conversation are indexed
          sa.Index("user_id_idx", "user_id"),
          sa.Index("conversation_id_idx", "conversation_id"),
      )

      id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # conversation user id
      user_id = mapped_column(StringUUID, nullable=False)
      # owning tenant
      tenant_id = mapped_column(StringUUID, nullable=False)
      # conversation id
      conversation_id = mapped_column(StringUUID, nullable=False)
      # variables pool as JSON text, decoded by `variables`
      variables_str = mapped_column(sa.Text, nullable=False)
      created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
      updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

      @property
      def variables(self) -> Any:
          """Return ``variables_str`` decoded from JSON."""
          decoded = json.loads(self.variables_str)
          return decoded
  class ToolFile(Base):
      """Metadata for files generated in workflows.

      This covers more than the files an agent creates.
      """

      __tablename__ = "tool_files"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tool_file_pkey"),
          sa.Index("tool_file_conversation_id_idx", "conversation_id"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # conversation user id
      user_id: Mapped[str] = mapped_column(StringUUID)
      # owning tenant
      tenant_id: Mapped[str] = mapped_column(StringUUID)
      # conversation id (nullable; indexed)
      conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
      # storage key of the file
      file_key: Mapped[str] = mapped_column(String(255), nullable=False)
      # MIME type
      mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
      # original URL, if any (nullable)
      original_url: Mapped[str] = mapped_column(String(2048), nullable=True)
      # file name; Python-side default is ""
      name: Mapped[str] = mapped_column(default="")
      # file size; Python-side default is -1
      size: Mapped[int] = mapped_column(default=-1)
  @deprecated
  class DeprecatedPublishedAppTool(Base):
      """
      Deprecated: apps published as a tool for each person.

      Each (app_id, user_id) pair may be published only once.
      """

      __tablename__ = "tool_published_apps"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
          sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
      )

      id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      # id of the app (foreign key to apps.id)
      app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
      # who published this tool
      user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # description as JSON text, decoded by `description_i18n`
      description = mapped_column(sa.Text, nullable=False)
      # description of the tool shown to the LLM
      llm_description = mapped_column(sa.Text, nullable=False)
      # The query is exposed to the LLM as a parameter of the tool; this text
      # describes that parameter.
      query_description = mapped_column(sa.Text, nullable=False)
      # name of the query parameter
      query_name = mapped_column(String(40), nullable=False)
      # name of the tool
      tool_name = mapped_column(String(40), nullable=False)
      # author
      author = mapped_column(String(40), nullable=False)
      created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
      updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))

      @property
      def description_i18n(self) -> "I18nObject":
          """Decode the JSON ``description`` into an ``I18nObject``."""
          from core.tools.entities.common_entities import I18nObject

          payload = json.loads(self.description)
          return I18nObject(**payload)