You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_entities.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import base64
  2. import enum
  3. from collections.abc import Mapping
  4. from enum import Enum
  5. from typing import Any, Optional, Union
  6. from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
  7. from core.entities.provider_entities import ProviderConfig
  8. from core.plugin.entities.parameters import (
  9. PluginParameter,
  10. PluginParameterOption,
  11. PluginParameterType,
  12. as_normal_type,
  13. cast_parameter_value,
  14. init_frontend_parameter,
  15. )
  16. from core.tools.entities.common_entities import I18nObject
  17. from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
  18. class ToolLabelEnum(Enum):
  19. SEARCH = "search"
  20. IMAGE = "image"
  21. VIDEOS = "videos"
  22. WEATHER = "weather"
  23. FINANCE = "finance"
  24. DESIGN = "design"
  25. TRAVEL = "travel"
  26. SOCIAL = "social"
  27. NEWS = "news"
  28. MEDICAL = "medical"
  29. PRODUCTIVITY = "productivity"
  30. EDUCATION = "education"
  31. BUSINESS = "business"
  32. ENTERTAINMENT = "entertainment"
  33. UTILITIES = "utilities"
  34. OTHER = "other"
  35. class DatasourceProviderType(enum.StrEnum):
  36. """
  37. Enum class for datasource provider
  38. """
  39. RAG_PIPELINE = "rag_pipeline"
  40. @classmethod
  41. def value_of(cls, value: str) -> "DatasourceProviderType":
  42. """
  43. Get value of given mode.
  44. :param value: mode value
  45. :return: mode
  46. """
  47. for mode in cls:
  48. if mode.value == value:
  49. return mode
  50. raise ValueError(f"invalid mode value {value}")
  51. class ApiProviderSchemaType(Enum):
  52. """
  53. Enum class for api provider schema type.
  54. """
  55. OPENAPI = "openapi"
  56. SWAGGER = "swagger"
  57. OPENAI_PLUGIN = "openai_plugin"
  58. OPENAI_ACTIONS = "openai_actions"
  59. @classmethod
  60. def value_of(cls, value: str) -> "ApiProviderSchemaType":
  61. """
  62. Get value of given mode.
  63. :param value: mode value
  64. :return: mode
  65. """
  66. for mode in cls:
  67. if mode.value == value:
  68. return mode
  69. raise ValueError(f"invalid mode value {value}")
  70. class ApiProviderAuthType(Enum):
  71. """
  72. Enum class for api provider auth type.
  73. """
  74. NONE = "none"
  75. API_KEY = "api_key"
  76. @classmethod
  77. def value_of(cls, value: str) -> "ApiProviderAuthType":
  78. """
  79. Get value of given mode.
  80. :param value: mode value
  81. :return: mode
  82. """
  83. for mode in cls:
  84. if mode.value == value:
  85. return mode
  86. raise ValueError(f"invalid mode value {value}")
  87. class ToolInvokeMessage(BaseModel):
  88. class TextMessage(BaseModel):
  89. text: str
  90. class JsonMessage(BaseModel):
  91. json_object: dict
  92. class BlobMessage(BaseModel):
  93. blob: bytes
  94. class FileMessage(BaseModel):
  95. pass
  96. class VariableMessage(BaseModel):
  97. variable_name: str = Field(..., description="The name of the variable")
  98. variable_value: Any = Field(..., description="The value of the variable")
  99. stream: bool = Field(default=False, description="Whether the variable is streamed")
  100. @model_validator(mode="before")
  101. @classmethod
  102. def transform_variable_value(cls, values) -> Any:
  103. """
  104. Only basic types and lists are allowed.
  105. """
  106. value = values.get("variable_value")
  107. if not isinstance(value, dict | list | str | int | float | bool):
  108. raise ValueError("Only basic types and lists are allowed.")
  109. # if stream is true, the value must be a string
  110. if values.get("stream"):
  111. if not isinstance(value, str):
  112. raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
  113. return values
  114. @field_validator("variable_name", mode="before")
  115. @classmethod
  116. def transform_variable_name(cls, value: str) -> str:
  117. """
  118. The variable name must be a string.
  119. """
  120. if value in {"json", "text", "files"}:
  121. raise ValueError(f"The variable name '{value}' is reserved.")
  122. return value
  123. class LogMessage(BaseModel):
  124. class LogStatus(Enum):
  125. START = "start"
  126. ERROR = "error"
  127. SUCCESS = "success"
  128. id: str
  129. label: str = Field(..., description="The label of the log")
  130. parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
  131. error: Optional[str] = Field(default=None, description="The error message")
  132. status: LogStatus = Field(..., description="The status of the log")
  133. data: Mapping[str, Any] = Field(..., description="Detailed log data")
  134. metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
  135. class MessageType(Enum):
  136. TEXT = "text"
  137. IMAGE = "image"
  138. LINK = "link"
  139. BLOB = "blob"
  140. JSON = "json"
  141. IMAGE_LINK = "image_link"
  142. BINARY_LINK = "binary_link"
  143. VARIABLE = "variable"
  144. FILE = "file"
  145. LOG = "log"
  146. type: MessageType = MessageType.TEXT
  147. """
  148. plain text, image url or link url
  149. """
  150. message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
  151. meta: dict[str, Any] | None = None
  152. @field_validator("message", mode="before")
  153. @classmethod
  154. def decode_blob_message(cls, v):
  155. if isinstance(v, dict) and "blob" in v:
  156. try:
  157. v["blob"] = base64.b64decode(v["blob"])
  158. except Exception:
  159. pass
  160. return v
  161. @field_serializer("message")
  162. def serialize_message(self, v):
  163. if isinstance(v, self.BlobMessage):
  164. return {"blob": base64.b64encode(v.blob).decode("utf-8")}
  165. return v
  166. class ToolInvokeMessageBinary(BaseModel):
  167. mimetype: str = Field(..., description="The mimetype of the binary")
  168. url: str = Field(..., description="The url of the binary")
  169. file_var: Optional[dict[str, Any]] = None
  170. class DatasourceParameter(PluginParameter):
  171. """
  172. Overrides type
  173. """
  174. class DatasourceParameterType(enum.StrEnum):
  175. """
  176. removes TOOLS_SELECTOR from PluginParameterType
  177. """
  178. STRING = PluginParameterType.STRING.value
  179. NUMBER = PluginParameterType.NUMBER.value
  180. BOOLEAN = PluginParameterType.BOOLEAN.value
  181. SELECT = PluginParameterType.SELECT.value
  182. SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
  183. FILE = PluginParameterType.FILE.value
  184. FILES = PluginParameterType.FILES.value
  185. APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
  186. MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
  187. # deprecated, should not use.
  188. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
  189. def as_normal_type(self):
  190. return as_normal_type(self)
  191. def cast_value(self, value: Any):
  192. return cast_parameter_value(self, value)
  193. class DatasourceParameterForm(Enum):
  194. SCHEMA = "schema" # should be set while adding tool
  195. FORM = "form" # should be set before invoking tool
  196. LLM = "llm" # will be set by LLM
  197. type: DatasourceParameterType = Field(..., description="The type of the parameter")
  198. human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
  199. form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
  200. llm_description: Optional[str] = None
  201. @classmethod
  202. def get_simple_instance(
  203. cls,
  204. name: str,
  205. llm_description: str,
  206. typ: DatasourceParameterType,
  207. required: bool,
  208. options: Optional[list[str]] = None,
  209. ) -> "DatasourceParameter":
  210. """
  211. get a simple datasource parameter
  212. :param name: the name of the parameter
  213. :param llm_description: the description presented to the LLM
  214. :param typ: the type of the parameter
  215. :param required: if the parameter is required
  216. :param options: the options of the parameter
  217. """
  218. # convert options to ToolParameterOption
  219. # FIXME fix the type error
  220. if options:
  221. option_objs = [
  222. PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  223. for option in options
  224. ]
  225. else:
  226. option_objs = []
  227. return cls(
  228. name=name,
  229. label=I18nObject(en_US="", zh_Hans=""),
  230. placeholder=None,
  231. human_description=I18nObject(en_US="", zh_Hans=""),
  232. type=typ,
  233. form=cls.ToolParameterForm.LLM,
  234. llm_description=llm_description,
  235. required=required,
  236. options=option_objs,
  237. )
  238. def init_frontend_parameter(self, value: Any):
  239. return init_frontend_parameter(self, self.type, value)
  240. class ToolProviderIdentity(BaseModel):
  241. author: str = Field(..., description="The author of the tool")
  242. name: str = Field(..., description="The name of the tool")
  243. description: I18nObject = Field(..., description="The description of the tool")
  244. icon: str = Field(..., description="The icon of the tool")
  245. label: I18nObject = Field(..., description="The label of the tool")
  246. tags: Optional[list[ToolLabelEnum]] = Field(
  247. default=[],
  248. description="The tags of the tool",
  249. )
  250. class DatasourceIdentity(BaseModel):
  251. author: str = Field(..., description="The author of the tool")
  252. name: str = Field(..., description="The name of the tool")
  253. label: I18nObject = Field(..., description="The label of the tool")
  254. provider: str = Field(..., description="The provider of the tool")
  255. icon: Optional[str] = None
  256. class DatasourceDescription(BaseModel):
  257. human: I18nObject = Field(..., description="The description presented to the user")
  258. llm: str = Field(..., description="The description presented to the LLM")
  259. class DatasourceEntity(BaseModel):
  260. identity: DatasourceIdentity
  261. parameters: list[DatasourceParameter] = Field(default_factory=list)
  262. description: Optional[DatasourceDescription] = None
  263. output_schema: Optional[dict] = None
  264. has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
  265. # pydantic configs
  266. model_config = ConfigDict(protected_namespaces=())
  267. @field_validator("parameters", mode="before")
  268. @classmethod
  269. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
  270. return v or []
  271. class ToolProviderEntity(BaseModel):
  272. identity: ToolProviderIdentity
  273. plugin_id: Optional[str] = None
  274. credentials_schema: list[ProviderConfig] = Field(default_factory=list)
  275. class DatasourceProviderEntityWithPlugin(ToolProviderEntity):
  276. datasources: list[DatasourceEntity] = Field(default_factory=list)
  277. class WorkflowToolParameterConfiguration(BaseModel):
  278. """
  279. Workflow tool configuration
  280. """
  281. name: str = Field(..., description="The name of the parameter")
  282. description: str = Field(..., description="The description of the parameter")
  283. form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter")
  284. class DatasourceInvokeMeta(BaseModel):
  285. """
  286. Datasource invoke meta
  287. """
  288. time_cost: float = Field(..., description="The time cost of the tool invoke")
  289. error: Optional[str] = None
  290. tool_config: Optional[dict] = None
  291. @classmethod
  292. def empty(cls) -> "DatasourceInvokeMeta":
  293. """
  294. Get an empty instance of DatasourceInvokeMeta
  295. """
  296. return cls(time_cost=0.0, error=None, tool_config={})
  297. @classmethod
  298. def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
  299. """
  300. Get an instance of DatasourceInvokeMeta with error
  301. """
  302. return cls(time_cost=0.0, error=error, tool_config={})
  303. def to_dict(self) -> dict:
  304. return {
  305. "time_cost": self.time_cost,
  306. "error": self.error,
  307. "tool_config": self.tool_config,
  308. }
  309. class DatasourceLabel(BaseModel):
  310. """
  311. Datasource label
  312. """
  313. name: str = Field(..., description="The name of the tool")
  314. label: I18nObject = Field(..., description="The label of the tool")
  315. icon: str = Field(..., description="The icon of the tool")
  316. class DatasourceInvokeFrom(Enum):
  317. """
  318. Enum class for datasource invoke
  319. """
  320. RAG_PIPELINE = "rag_pipeline"
  321. class DatasourceSelector(BaseModel):
  322. dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
  323. class Parameter(BaseModel):
  324. name: str = Field(..., description="The name of the parameter")
  325. type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter")
  326. required: bool = Field(..., description="Whether the parameter is required")
  327. description: str = Field(..., description="The description of the parameter")
  328. default: Optional[Union[int, float, str]] = None
  329. options: Optional[list[PluginParameterOption]] = None
  330. provider_id: str = Field(..., description="The id of the provider")
  331. datasource_name: str = Field(..., description="The name of the datasource")
  332. datasource_description: str = Field(..., description="The description of the datasource")
  333. datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
  334. datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
  335. def to_plugin_parameter(self) -> dict[str, Any]:
  336. return self.model_dump()