您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

api_entities.py 5.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from collections.abc import Mapping
  2. from datetime import datetime
  3. from typing import Any, Literal
  4. from pydantic import BaseModel, Field, field_validator
  5. from core.model_runtime.utils.encoders import jsonable_encoder
  6. from core.tools.__base.tool import ToolParameter
  7. from core.tools.entities.common_entities import I18nObject
  8. from core.tools.entities.tool_entities import CredentialType, ToolProviderType
  class ToolApiEntity(BaseModel):
      """Pydantic model carrying one tool's author, identifier, localized texts, parameters and schema.

      ``parameters`` may be omitted (``None``); ``labels`` and ``output_schema`` default to fresh
      empty containers for every instance.
      """

      author: str
      # Identifier of the tool.
      name: str
      # Display label.
      label: I18nObject
      description: I18nObject
      parameters: list[ToolParameter] | None = Field(default=None)
      labels: list[str] = Field(default_factory=list)
      output_schema: Mapping[str, object] = Field(default_factory=dict)
  17. ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
  class ToolProviderApiEntity(BaseModel):
      """Pydantic model for a tool provider, with a hand-written ``to_dict`` serializer.

      Holds the provider's identity and display data, its masked and original credentials,
      the tools it offers, and a group of MCP-related settings (server url, timeouts, headers).
      """

      id: str
      author: str
      name: str  # identifier
      description: I18nObject
      icon: str | Mapping[str, str]
      icon_dark: str | Mapping[str, str] = ""
      label: I18nObject  # label
      type: ToolProviderType
      masked_credentials: Mapping[str, object] = Field(default_factory=dict)
      original_credentials: Mapping[str, object] = Field(default_factory=dict)
      is_team_authorization: bool = False
      allow_delete: bool = True
      plugin_id: str | None = Field(default="", description="The plugin id of the tool")
      plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
      # A plain ``list`` factory is enough: the element type comes from the annotation.
      tools: list[ToolApiEntity] = Field(default_factory=list)
      labels: list[str] = Field(default_factory=list)

      # MCP
      server_url: str | None = Field(default="", description="The server url of the tool")
      updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
      server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
      timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
      sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
      masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
      original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")

      @field_validator("tools", mode="before")
      @classmethod
      def convert_none_to_empty_list(cls, v: Any) -> Any:
          """Replace an explicit ``None`` for ``tools`` with an empty list before validation."""
          return v if v is not None else []

      def to_dict(self) -> dict[str, Any]:
          """Serialize the provider to a plain dict.

          The tools are encoded with ``jsonable_encoder`` and then patched in place: a parameter
          whose type equals the ``SYSTEM_FILES`` enum value is reported as ``"files"``, and an
          ``input_schema`` entry whose value is ``None`` is removed.

          ``masked_credentials`` is emitted under the key ``"team_credentials"``;
          ``original_credentials`` is not part of the output. ``server_url`` is appended only when
          truthy, and the remaining MCP fields only when ``type`` is ``ToolProviderType.MCP`` and
          the individual value is truthy.

          Returns:
              The dict representation, with the optional keys placed after the fixed ones.
          """
          # -------------
          # overwrite tool parameter types for temp fix
          tools = jsonable_encoder(self.tools)
          system_files_type = ToolParameter.ToolParameterType.SYSTEM_FILES.value
          for tool in tools:
              # ``or []`` covers both a missing key and an explicit None / empty list.
              for parameter in tool.get("parameters") or []:
                  if parameter.get("type") == system_files_type:
                      parameter["type"] = "files"
                  if parameter.get("input_schema") is None:
                      parameter.pop("input_schema", None)
          # -------------

          optional_fields = self.optional_field("server_url", self.server_url)
          if self.type == ToolProviderType.MCP:
              # Order matters: it determines the key order of the resulting dict.
              mcp_field_names = (
                  "updated_at",
                  "server_identifier",
                  "timeout",
                  "sse_read_timeout",
                  "masked_headers",
                  "original_headers",
              )
              for field_name in mcp_field_names:
                  optional_fields.update(self.optional_field(field_name, getattr(self, field_name)))

          return {
              "id": self.id,
              "author": self.author,
              "name": self.name,
              "plugin_id": self.plugin_id,
              "plugin_unique_identifier": self.plugin_unique_identifier,
              "description": self.description.to_dict(),
              "icon": self.icon,
              "icon_dark": self.icon_dark,
              "label": self.label.to_dict(),
              "type": self.type.value,
              "team_credentials": self.masked_credentials,
              "is_team_authorization": self.is_team_authorization,
              "allow_delete": self.allow_delete,
              "tools": tools,
              "labels": self.labels,
              **optional_fields,
          }

      def optional_field(self, key: str, value: Any) -> dict[str, Any]:
          """Return dict with key-value if value is truthy, empty dict otherwise.

          Note that every falsy value is dropped, not just ``None``: ``0``, ``0.0``, ``""`` and
          empty containers all yield an empty dict.
          """
          return {key: value} if value else {}
  class ToolProviderCredentialApiEntity(BaseModel):
      """Pydantic model for a single provider credential record.

      ``id``, ``name``, ``provider`` and ``credential_type`` are required; ``is_default`` falls
      back to ``False`` and ``credentials`` to a fresh empty dict.
      """

      id: str = Field(description="The unique id of the credential")
      name: str = Field(description="The name of the credential")
      provider: str = Field(description="The provider of the credential")
      credential_type: CredentialType = Field(description="The type of the credential")
      is_default: bool = Field(
          description="Whether the credential is the default credential for the provider in the workspace",
          default=False,
      )
      credentials: Mapping[str, object] = Field(default_factory=dict, description="The credentials of the provider")
  class ToolProviderCredentialInfoApiEntity(BaseModel):
      """Pydantic model bundling a provider's supported credential types with its credentials.

      ``supported_credential_types`` and ``credentials`` are required;
      ``is_oauth_custom_client_enabled`` defaults to ``False``.
      """

      supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
      is_oauth_custom_client_enabled: bool = Field(
          description="Whether the OAuth custom client is enabled for the provider",
          default=False,
      )
      credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")