You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tool_entities.py 14KB

Line numbers 1–447
  1. import base64
  2. import enum
  3. from collections.abc import Mapping
  4. from enum import Enum
  5. from typing import Any, Optional, Union
  6. from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
  7. from core.entities.provider_entities import ProviderConfig
  8. from core.plugin.entities.parameters import (
  9. MCPServerParameterType,
  10. PluginParameter,
  11. PluginParameterOption,
  12. PluginParameterType,
  13. as_normal_type,
  14. cast_parameter_value,
  15. init_frontend_parameter,
  16. )
  17. from core.tools.entities.common_entities import I18nObject
  18. from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
  19. class ToolLabelEnum(Enum):
  20. SEARCH = "search"
  21. IMAGE = "image"
  22. VIDEOS = "videos"
  23. WEATHER = "weather"
  24. FINANCE = "finance"
  25. DESIGN = "design"
  26. TRAVEL = "travel"
  27. SOCIAL = "social"
  28. NEWS = "news"
  29. MEDICAL = "medical"
  30. PRODUCTIVITY = "productivity"
  31. EDUCATION = "education"
  32. BUSINESS = "business"
  33. ENTERTAINMENT = "entertainment"
  34. UTILITIES = "utilities"
  35. OTHER = "other"
  36. class ToolProviderType(enum.StrEnum):
  37. """
  38. Enum class for tool provider
  39. """
  40. PLUGIN = "plugin"
  41. BUILT_IN = "builtin"
  42. WORKFLOW = "workflow"
  43. API = "api"
  44. APP = "app"
  45. DATASET_RETRIEVAL = "dataset-retrieval"
  46. MCP = "mcp"
  47. @classmethod
  48. def value_of(cls, value: str) -> "ToolProviderType":
  49. """
  50. Get value of given mode.
  51. :param value: mode value
  52. :return: mode
  53. """
  54. for mode in cls:
  55. if mode.value == value:
  56. return mode
  57. raise ValueError(f"invalid mode value {value}")
  58. class ApiProviderSchemaType(Enum):
  59. """
  60. Enum class for api provider schema type.
  61. """
  62. OPENAPI = "openapi"
  63. SWAGGER = "swagger"
  64. OPENAI_PLUGIN = "openai_plugin"
  65. OPENAI_ACTIONS = "openai_actions"
  66. @classmethod
  67. def value_of(cls, value: str) -> "ApiProviderSchemaType":
  68. """
  69. Get value of given mode.
  70. :param value: mode value
  71. :return: mode
  72. """
  73. for mode in cls:
  74. if mode.value == value:
  75. return mode
  76. raise ValueError(f"invalid mode value {value}")
  77. class ApiProviderAuthType(Enum):
  78. """
  79. Enum class for api provider auth type.
  80. """
  81. NONE = "none"
  82. API_KEY_HEADER = "api_key_header"
  83. API_KEY_QUERY = "api_key_query"
  84. @classmethod
  85. def value_of(cls, value: str) -> "ApiProviderAuthType":
  86. """
  87. Get value of given mode.
  88. :param value: mode value
  89. :return: mode
  90. """
  91. for mode in cls:
  92. if mode.value == value:
  93. return mode
  94. raise ValueError(f"invalid mode value {value}")
  95. class ToolInvokeMessage(BaseModel):
  96. class TextMessage(BaseModel):
  97. text: str
  98. class JsonMessage(BaseModel):
  99. json_object: dict
  100. class BlobMessage(BaseModel):
  101. blob: bytes
  102. class BlobChunkMessage(BaseModel):
  103. id: str = Field(..., description="The id of the blob")
  104. sequence: int = Field(..., description="The sequence of the chunk")
  105. total_length: int = Field(..., description="The total length of the blob")
  106. blob: bytes = Field(..., description="The blob data of the chunk")
  107. end: bool = Field(..., description="Whether the chunk is the last chunk")
  108. class FileMessage(BaseModel):
  109. pass
  110. class VariableMessage(BaseModel):
  111. variable_name: str = Field(..., description="The name of the variable")
  112. variable_value: Any = Field(..., description="The value of the variable")
  113. stream: bool = Field(default=False, description="Whether the variable is streamed")
  114. @model_validator(mode="before")
  115. @classmethod
  116. def transform_variable_value(cls, values) -> Any:
  117. """
  118. Only basic types and lists are allowed.
  119. """
  120. value = values.get("variable_value")
  121. if not isinstance(value, dict | list | str | int | float | bool):
  122. raise ValueError("Only basic types and lists are allowed.")
  123. # if stream is true, the value must be a string
  124. if values.get("stream"):
  125. if not isinstance(value, str):
  126. raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
  127. return values
  128. @field_validator("variable_name", mode="before")
  129. @classmethod
  130. def transform_variable_name(cls, value: str) -> str:
  131. """
  132. The variable name must be a string.
  133. """
  134. if value in {"json", "text", "files"}:
  135. raise ValueError(f"The variable name '{value}' is reserved.")
  136. return value
  137. class LogMessage(BaseModel):
  138. class LogStatus(Enum):
  139. START = "start"
  140. ERROR = "error"
  141. SUCCESS = "success"
  142. id: str
  143. label: str = Field(..., description="The label of the log")
  144. parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
  145. error: Optional[str] = Field(default=None, description="The error message")
  146. status: LogStatus = Field(..., description="The status of the log")
  147. data: Mapping[str, Any] = Field(..., description="Detailed log data")
  148. metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
  149. class MessageType(Enum):
  150. TEXT = "text"
  151. IMAGE = "image"
  152. LINK = "link"
  153. BLOB = "blob"
  154. JSON = "json"
  155. IMAGE_LINK = "image_link"
  156. BINARY_LINK = "binary_link"
  157. VARIABLE = "variable"
  158. FILE = "file"
  159. LOG = "log"
  160. BLOB_CHUNK = "blob_chunk"
  161. type: MessageType = MessageType.TEXT
  162. """
  163. plain text, image url or link url
  164. """
  165. message: (
  166. JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
  167. )
  168. meta: dict[str, Any] | None = None
  169. @field_validator("message", mode="before")
  170. @classmethod
  171. def decode_blob_message(cls, v):
  172. if isinstance(v, dict) and "blob" in v:
  173. try:
  174. v["blob"] = base64.b64decode(v["blob"])
  175. except Exception:
  176. pass
  177. return v
  178. @field_serializer("message")
  179. def serialize_message(self, v):
  180. if isinstance(v, self.BlobMessage):
  181. return {"blob": base64.b64encode(v.blob).decode("utf-8")}
  182. return v
  183. class ToolInvokeMessageBinary(BaseModel):
  184. mimetype: str = Field(..., description="The mimetype of the binary")
  185. url: str = Field(..., description="The url of the binary")
  186. file_var: Optional[dict[str, Any]] = None
  187. class ToolParameter(PluginParameter):
  188. """
  189. Overrides type
  190. """
  191. class ToolParameterType(enum.StrEnum):
  192. """
  193. removes TOOLS_SELECTOR from PluginParameterType
  194. """
  195. STRING = PluginParameterType.STRING.value
  196. NUMBER = PluginParameterType.NUMBER.value
  197. BOOLEAN = PluginParameterType.BOOLEAN.value
  198. SELECT = PluginParameterType.SELECT.value
  199. SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
  200. FILE = PluginParameterType.FILE.value
  201. FILES = PluginParameterType.FILES.value
  202. APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
  203. MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
  204. DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
  205. # MCP object and array type parameters
  206. ARRAY = MCPServerParameterType.ARRAY.value
  207. OBJECT = MCPServerParameterType.OBJECT.value
  208. # deprecated, should not use.
  209. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
  210. def as_normal_type(self):
  211. return as_normal_type(self)
  212. def cast_value(self, value: Any):
  213. return cast_parameter_value(self, value)
  214. class ToolParameterForm(Enum):
  215. SCHEMA = "schema" # should be set while adding tool
  216. FORM = "form" # should be set before invoking tool
  217. LLM = "llm" # will be set by LLM
  218. type: ToolParameterType = Field(..., description="The type of the parameter")
  219. human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
  220. form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
  221. llm_description: Optional[str] = None
  222. # MCP object and array type parameters use this field to store the schema
  223. input_schema: Optional[dict] = None
  224. @classmethod
  225. def get_simple_instance(
  226. cls,
  227. name: str,
  228. llm_description: str,
  229. typ: ToolParameterType,
  230. required: bool,
  231. options: Optional[list[str]] = None,
  232. ) -> "ToolParameter":
  233. """
  234. get a simple tool parameter
  235. :param name: the name of the parameter
  236. :param llm_description: the description presented to the LLM
  237. :param typ: the type of the parameter
  238. :param required: if the parameter is required
  239. :param options: the options of the parameter
  240. """
  241. # convert options to ToolParameterOption
  242. if options:
  243. option_objs = [
  244. PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  245. for option in options
  246. ]
  247. else:
  248. option_objs = []
  249. return cls(
  250. name=name,
  251. label=I18nObject(en_US="", zh_Hans=""),
  252. placeholder=None,
  253. human_description=I18nObject(en_US="", zh_Hans=""),
  254. type=typ,
  255. form=cls.ToolParameterForm.LLM,
  256. llm_description=llm_description,
  257. required=required,
  258. options=option_objs,
  259. )
  260. def init_frontend_parameter(self, value: Any):
  261. return init_frontend_parameter(self, self.type, value)
  262. class ToolProviderIdentity(BaseModel):
  263. author: str = Field(..., description="The author of the tool")
  264. name: str = Field(..., description="The name of the tool")
  265. description: I18nObject = Field(..., description="The description of the tool")
  266. icon: str = Field(..., description="The icon of the tool")
  267. icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool")
  268. label: I18nObject = Field(..., description="The label of the tool")
  269. tags: Optional[list[ToolLabelEnum]] = Field(
  270. default=[],
  271. description="The tags of the tool",
  272. )
  273. class ToolIdentity(BaseModel):
  274. author: str = Field(..., description="The author of the tool")
  275. name: str = Field(..., description="The name of the tool")
  276. label: I18nObject = Field(..., description="The label of the tool")
  277. provider: str = Field(..., description="The provider of the tool")
  278. icon: Optional[str] = None
  279. class ToolDescription(BaseModel):
  280. human: I18nObject = Field(..., description="The description presented to the user")
  281. llm: str = Field(..., description="The description presented to the LLM")
  282. class ToolEntity(BaseModel):
  283. identity: ToolIdentity
  284. parameters: list[ToolParameter] = Field(default_factory=list)
  285. description: Optional[ToolDescription] = None
  286. output_schema: Optional[dict] = None
  287. has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
  288. # pydantic configs
  289. model_config = ConfigDict(protected_namespaces=())
  290. @field_validator("parameters", mode="before")
  291. @classmethod
  292. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
  293. return v or []
  294. class ToolProviderEntity(BaseModel):
  295. identity: ToolProviderIdentity
  296. plugin_id: Optional[str] = None
  297. credentials_schema: list[ProviderConfig] = Field(default_factory=list)
  298. class ToolProviderEntityWithPlugin(ToolProviderEntity):
  299. tools: list[ToolEntity] = Field(default_factory=list)
  300. class WorkflowToolParameterConfiguration(BaseModel):
  301. """
  302. Workflow tool configuration
  303. """
  304. name: str = Field(..., description="The name of the parameter")
  305. description: str = Field(..., description="The description of the parameter")
  306. form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
  307. class ToolInvokeMeta(BaseModel):
  308. """
  309. Tool invoke meta
  310. """
  311. time_cost: float = Field(..., description="The time cost of the tool invoke")
  312. error: Optional[str] = None
  313. tool_config: Optional[dict] = None
  314. @classmethod
  315. def empty(cls) -> "ToolInvokeMeta":
  316. """
  317. Get an empty instance of ToolInvokeMeta
  318. """
  319. return cls(time_cost=0.0, error=None, tool_config={})
  320. @classmethod
  321. def error_instance(cls, error: str) -> "ToolInvokeMeta":
  322. """
  323. Get an instance of ToolInvokeMeta with error
  324. """
  325. return cls(time_cost=0.0, error=error, tool_config={})
  326. def to_dict(self) -> dict:
  327. return {
  328. "time_cost": self.time_cost,
  329. "error": self.error,
  330. "tool_config": self.tool_config,
  331. }
  332. class ToolLabel(BaseModel):
  333. """
  334. Tool label
  335. """
  336. name: str = Field(..., description="The name of the tool")
  337. label: I18nObject = Field(..., description="The label of the tool")
  338. icon: str = Field(..., description="The icon of the tool")
  339. class ToolInvokeFrom(Enum):
  340. """
  341. Enum class for tool invoke
  342. """
  343. WORKFLOW = "workflow"
  344. AGENT = "agent"
  345. PLUGIN = "plugin"
  346. class ToolSelector(BaseModel):
  347. dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
  348. class Parameter(BaseModel):
  349. name: str = Field(..., description="The name of the parameter")
  350. type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
  351. required: bool = Field(..., description="Whether the parameter is required")
  352. description: str = Field(..., description="The description of the parameter")
  353. default: Optional[Union[int, float, str]] = None
  354. options: Optional[list[PluginParameterOption]] = None
  355. provider_id: str = Field(..., description="The id of the provider")
  356. tool_name: str = Field(..., description="The name of the tool")
  357. tool_description: str = Field(..., description="The description of the tool")
  358. tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
  359. tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
  360. def to_plugin_parameter(self) -> dict[str, Any]:
  361. return self.model_dump()