您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。

plugin.py 7.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import datetime
  2. import re
  3. from collections.abc import Mapping
  4. from enum import StrEnum, auto
  5. from typing import Any, Optional
  6. from packaging.version import InvalidVersion, Version
  7. from pydantic import BaseModel, Field, field_validator, model_validator
  8. from werkzeug.exceptions import NotFound
  9. from core.agent.plugin_entities import AgentStrategyProviderEntity
  10. from core.model_runtime.entities.provider_entities import ProviderEntity
  11. from core.plugin.entities.base import BasePluginEntity
  12. from core.plugin.entities.endpoint import EndpointProviderDeclaration
  13. from core.tools.entities.common_entities import I18nObject
  14. from core.tools.entities.tool_entities import ToolProviderEntity
  15. class PluginInstallationSource(StrEnum):
  16. Github = auto()
  17. Marketplace = auto()
  18. Package = auto()
  19. Remote = auto()
  20. class PluginResourceRequirements(BaseModel):
  21. memory: int
  22. class Permission(BaseModel):
  23. class Tool(BaseModel):
  24. enabled: Optional[bool] = Field(default=False)
  25. class Model(BaseModel):
  26. enabled: Optional[bool] = Field(default=False)
  27. llm: Optional[bool] = Field(default=False)
  28. text_embedding: Optional[bool] = Field(default=False)
  29. rerank: Optional[bool] = Field(default=False)
  30. tts: Optional[bool] = Field(default=False)
  31. speech2text: Optional[bool] = Field(default=False)
  32. moderation: Optional[bool] = Field(default=False)
  33. class Node(BaseModel):
  34. enabled: Optional[bool] = Field(default=False)
  35. class Endpoint(BaseModel):
  36. enabled: Optional[bool] = Field(default=False)
  37. class Storage(BaseModel):
  38. enabled: Optional[bool] = Field(default=False)
  39. size: int = Field(ge=1024, le=1073741824, default=1048576)
  40. tool: Optional[Tool] = Field(default=None)
  41. model: Optional[Model] = Field(default=None)
  42. node: Optional[Node] = Field(default=None)
  43. endpoint: Optional[Endpoint] = Field(default=None)
  44. storage: Optional[Storage] = Field(default=None)
  45. permission: Optional[Permission] = Field(default=None)
  46. class PluginCategory(StrEnum):
  47. Tool = auto()
  48. Model = auto()
  49. Extension = auto()
  50. AgentStrategy = "agent-strategy"
  51. class PluginDeclaration(BaseModel):
  52. class Plugins(BaseModel):
  53. tools: Optional[list[str]] = Field(default_factory=list[str])
  54. models: Optional[list[str]] = Field(default_factory=list[str])
  55. endpoints: Optional[list[str]] = Field(default_factory=list[str])
  56. class Meta(BaseModel):
  57. minimum_dify_version: Optional[str] = Field(default=None)
  58. version: Optional[str] = Field(default=None)
  59. @field_validator("minimum_dify_version")
  60. @classmethod
  61. def validate_minimum_dify_version(cls, v: Optional[str]) -> Optional[str]:
  62. if v is None:
  63. return v
  64. try:
  65. Version(v)
  66. return v
  67. except InvalidVersion as e:
  68. raise ValueError(f"Invalid version format: {v}") from e
  69. version: str = Field(...)
  70. author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
  71. name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
  72. description: I18nObject
  73. icon: str
  74. icon_dark: Optional[str] = Field(default=None)
  75. label: I18nObject
  76. category: PluginCategory
  77. created_at: datetime.datetime
  78. resource: PluginResourceRequirements
  79. plugins: Plugins
  80. tags: list[str] = Field(default_factory=list)
  81. repo: Optional[str] = Field(default=None)
  82. verified: bool = Field(default=False)
  83. tool: Optional[ToolProviderEntity] = None
  84. model: Optional[ProviderEntity] = None
  85. endpoint: Optional[EndpointProviderDeclaration] = None
  86. agent_strategy: Optional[AgentStrategyProviderEntity] = None
  87. meta: Meta
  88. @field_validator("version")
  89. @classmethod
  90. def validate_version(cls, v: str) -> str:
  91. try:
  92. Version(v)
  93. return v
  94. except InvalidVersion as e:
  95. raise ValueError(f"Invalid version format: {v}") from e
  96. @model_validator(mode="before")
  97. @classmethod
  98. def validate_category(cls, values: dict):
  99. # auto detect category
  100. if values.get("tool"):
  101. values["category"] = PluginCategory.Tool
  102. elif values.get("model"):
  103. values["category"] = PluginCategory.Model
  104. elif values.get("agent_strategy"):
  105. values["category"] = PluginCategory.AgentStrategy
  106. else:
  107. values["category"] = PluginCategory.Extension
  108. return values
  109. class PluginInstallation(BasePluginEntity):
  110. tenant_id: str
  111. endpoints_setups: int
  112. endpoints_active: int
  113. runtime_type: str
  114. source: PluginInstallationSource
  115. meta: Mapping[str, Any]
  116. plugin_id: str
  117. plugin_unique_identifier: str
  118. version: str
  119. checksum: str
  120. declaration: PluginDeclaration
  121. class PluginEntity(PluginInstallation):
  122. name: str
  123. installation_id: str
  124. version: str
  125. @model_validator(mode="after")
  126. def set_plugin_id(self):
  127. if self.declaration.tool:
  128. self.declaration.tool.plugin_id = self.plugin_id
  129. return self
  130. class GenericProviderID:
  131. organization: str
  132. plugin_name: str
  133. provider_name: str
  134. is_hardcoded: bool
  135. def to_string(self) -> str:
  136. return str(self)
  137. def __str__(self) -> str:
  138. return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
  139. def __init__(self, value: str, is_hardcoded: bool = False):
  140. if not value:
  141. raise NotFound("plugin not found, please add plugin")
  142. # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
  143. if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
  144. # check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
  145. if re.match(r"^[a-z0-9_-]+$", value):
  146. value = f"langgenius/{value}/{value}"
  147. else:
  148. raise ValueError(f"Invalid plugin id {value}")
  149. self.organization, self.plugin_name, self.provider_name = value.split("/")
  150. self.is_hardcoded = is_hardcoded
  151. def is_langgenius(self) -> bool:
  152. return self.organization == "langgenius"
  153. @property
  154. def plugin_id(self) -> str:
  155. return f"{self.organization}/{self.plugin_name}"
  156. class ModelProviderID(GenericProviderID):
  157. def __init__(self, value: str, is_hardcoded: bool = False):
  158. super().__init__(value, is_hardcoded)
  159. if self.organization == "langgenius" and self.provider_name == "google":
  160. self.plugin_name = "gemini"
  161. class ToolProviderID(GenericProviderID):
  162. def __init__(self, value: str, is_hardcoded: bool = False):
  163. super().__init__(value, is_hardcoded)
  164. if self.organization == "langgenius":
  165. if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
  166. self.plugin_name = f"{self.provider_name}_tool"
  167. class PluginDependency(BaseModel):
  168. class Type(StrEnum):
  169. Github = PluginInstallationSource.Github
  170. Marketplace = PluginInstallationSource.Marketplace
  171. Package = PluginInstallationSource.Package
  172. class Github(BaseModel):
  173. repo: str
  174. version: str
  175. package: str
  176. github_plugin_unique_identifier: str
  177. @property
  178. def plugin_unique_identifier(self) -> str:
  179. return self.github_plugin_unique_identifier
  180. class Marketplace(BaseModel):
  181. marketplace_plugin_unique_identifier: str
  182. @property
  183. def plugin_unique_identifier(self) -> str:
  184. return self.marketplace_plugin_unique_identifier
  185. class Package(BaseModel):
  186. plugin_unique_identifier: str
  187. type: Type
  188. value: Github | Marketplace | Package
  189. current_identifier: Optional[str] = None
  190. class MissingPluginDependency(BaseModel):
  191. plugin_unique_identifier: str
  192. current_identifier: Optional[str] = None