You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tool_entities.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. import base64
  2. import enum
  3. from collections.abc import Mapping
  4. from enum import Enum
  5. from typing import Any, Optional, Union
  6. from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
  7. from core.entities.provider_entities import ProviderConfig
  8. from core.plugin.entities.parameters import (
  9. MCPServerParameterType,
  10. PluginParameter,
  11. PluginParameterOption,
  12. PluginParameterType,
  13. as_normal_type,
  14. cast_parameter_value,
  15. init_frontend_parameter,
  16. )
  17. from core.rag.entities.citation_metadata import RetrievalSourceMetadata
  18. from core.tools.entities.common_entities import I18nObject
  19. from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
  20. class ToolLabelEnum(Enum):
  21. SEARCH = "search"
  22. IMAGE = "image"
  23. VIDEOS = "videos"
  24. WEATHER = "weather"
  25. FINANCE = "finance"
  26. DESIGN = "design"
  27. TRAVEL = "travel"
  28. SOCIAL = "social"
  29. NEWS = "news"
  30. MEDICAL = "medical"
  31. PRODUCTIVITY = "productivity"
  32. EDUCATION = "education"
  33. BUSINESS = "business"
  34. ENTERTAINMENT = "entertainment"
  35. UTILITIES = "utilities"
  36. OTHER = "other"
  37. class ToolProviderType(enum.StrEnum):
  38. """
  39. Enum class for tool provider
  40. """
  41. PLUGIN = "plugin"
  42. BUILT_IN = "builtin"
  43. WORKFLOW = "workflow"
  44. API = "api"
  45. APP = "app"
  46. DATASET_RETRIEVAL = "dataset-retrieval"
  47. MCP = "mcp"
  48. @classmethod
  49. def value_of(cls, value: str) -> "ToolProviderType":
  50. """
  51. Get value of given mode.
  52. :param value: mode value
  53. :return: mode
  54. """
  55. for mode in cls:
  56. if mode.value == value:
  57. return mode
  58. raise ValueError(f"invalid mode value {value}")
  59. class ApiProviderSchemaType(Enum):
  60. """
  61. Enum class for api provider schema type.
  62. """
  63. OPENAPI = "openapi"
  64. SWAGGER = "swagger"
  65. OPENAI_PLUGIN = "openai_plugin"
  66. OPENAI_ACTIONS = "openai_actions"
  67. @classmethod
  68. def value_of(cls, value: str) -> "ApiProviderSchemaType":
  69. """
  70. Get value of given mode.
  71. :param value: mode value
  72. :return: mode
  73. """
  74. for mode in cls:
  75. if mode.value == value:
  76. return mode
  77. raise ValueError(f"invalid mode value {value}")
  78. class ApiProviderAuthType(Enum):
  79. """
  80. Enum class for api provider auth type.
  81. """
  82. NONE = "none"
  83. API_KEY_HEADER = "api_key_header"
  84. API_KEY_QUERY = "api_key_query"
  85. @classmethod
  86. def value_of(cls, value: str) -> "ApiProviderAuthType":
  87. """
  88. Get value of given mode.
  89. :param value: mode value
  90. :return: mode
  91. """
  92. # 'api_key' deprecated in PR #21656
  93. # normalize & tiny alias for backward compatibility
  94. v = (value or "").strip().lower()
  95. if v == "api_key":
  96. v = cls.API_KEY_HEADER.value
  97. for mode in cls:
  98. if mode.value == v:
  99. return mode
  100. valid = ", ".join(m.value for m in cls)
  101. raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
  102. class ToolInvokeMessage(BaseModel):
  103. class TextMessage(BaseModel):
  104. text: str
  105. class JsonMessage(BaseModel):
  106. json_object: dict
  107. class BlobMessage(BaseModel):
  108. blob: bytes
  109. class BlobChunkMessage(BaseModel):
  110. id: str = Field(..., description="The id of the blob")
  111. sequence: int = Field(..., description="The sequence of the chunk")
  112. total_length: int = Field(..., description="The total length of the blob")
  113. blob: bytes = Field(..., description="The blob data of the chunk")
  114. end: bool = Field(..., description="Whether the chunk is the last chunk")
  115. class FileMessage(BaseModel):
  116. pass
  117. class VariableMessage(BaseModel):
  118. variable_name: str = Field(..., description="The name of the variable")
  119. variable_value: Any = Field(..., description="The value of the variable")
  120. stream: bool = Field(default=False, description="Whether the variable is streamed")
  121. @model_validator(mode="before")
  122. @classmethod
  123. def transform_variable_value(cls, values) -> Any:
  124. """
  125. Only basic types and lists are allowed.
  126. """
  127. value = values.get("variable_value")
  128. if not isinstance(value, dict | list | str | int | float | bool):
  129. raise ValueError("Only basic types and lists are allowed.")
  130. # if stream is true, the value must be a string
  131. if values.get("stream"):
  132. if not isinstance(value, str):
  133. raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
  134. return values
  135. @field_validator("variable_name", mode="before")
  136. @classmethod
  137. def transform_variable_name(cls, value: str) -> str:
  138. """
  139. The variable name must be a string.
  140. """
  141. if value in {"json", "text", "files"}:
  142. raise ValueError(f"The variable name '{value}' is reserved.")
  143. return value
  144. class LogMessage(BaseModel):
  145. class LogStatus(Enum):
  146. START = "start"
  147. ERROR = "error"
  148. SUCCESS = "success"
  149. id: str
  150. label: str = Field(..., description="The label of the log")
  151. parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
  152. error: Optional[str] = Field(default=None, description="The error message")
  153. status: LogStatus = Field(..., description="The status of the log")
  154. data: Mapping[str, Any] = Field(..., description="Detailed log data")
  155. metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
  156. class RetrieverResourceMessage(BaseModel):
  157. retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
  158. context: str = Field(..., description="context")
  159. class MessageType(Enum):
  160. TEXT = "text"
  161. IMAGE = "image"
  162. LINK = "link"
  163. BLOB = "blob"
  164. JSON = "json"
  165. IMAGE_LINK = "image_link"
  166. BINARY_LINK = "binary_link"
  167. VARIABLE = "variable"
  168. FILE = "file"
  169. LOG = "log"
  170. BLOB_CHUNK = "blob_chunk"
  171. RETRIEVER_RESOURCES = "retriever_resources"
  172. type: MessageType = MessageType.TEXT
  173. """
  174. plain text, image url or link url
  175. """
  176. message: (
  177. JsonMessage
  178. | TextMessage
  179. | BlobChunkMessage
  180. | BlobMessage
  181. | LogMessage
  182. | FileMessage
  183. | None
  184. | VariableMessage
  185. | RetrieverResourceMessage
  186. )
  187. meta: dict[str, Any] | None = None
  188. @field_validator("message", mode="before")
  189. @classmethod
  190. def decode_blob_message(cls, v):
  191. if isinstance(v, dict) and "blob" in v:
  192. try:
  193. v["blob"] = base64.b64decode(v["blob"])
  194. except Exception:
  195. pass
  196. return v
  197. @field_serializer("message")
  198. def serialize_message(self, v):
  199. if isinstance(v, self.BlobMessage):
  200. return {"blob": base64.b64encode(v.blob).decode("utf-8")}
  201. return v
  202. class ToolInvokeMessageBinary(BaseModel):
  203. mimetype: str = Field(..., description="The mimetype of the binary")
  204. url: str = Field(..., description="The url of the binary")
  205. file_var: Optional[dict[str, Any]] = None
  206. class ToolParameter(PluginParameter):
  207. """
  208. Overrides type
  209. """
  210. class ToolParameterType(enum.StrEnum):
  211. """
  212. removes TOOLS_SELECTOR from PluginParameterType
  213. """
  214. STRING = PluginParameterType.STRING.value
  215. NUMBER = PluginParameterType.NUMBER.value
  216. BOOLEAN = PluginParameterType.BOOLEAN.value
  217. SELECT = PluginParameterType.SELECT.value
  218. SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
  219. FILE = PluginParameterType.FILE.value
  220. FILES = PluginParameterType.FILES.value
  221. APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
  222. MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
  223. ANY = PluginParameterType.ANY.value
  224. DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
  225. # MCP object and array type parameters
  226. ARRAY = MCPServerParameterType.ARRAY.value
  227. OBJECT = MCPServerParameterType.OBJECT.value
  228. # deprecated, should not use.
  229. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
  230. def as_normal_type(self):
  231. return as_normal_type(self)
  232. def cast_value(self, value: Any):
  233. return cast_parameter_value(self, value)
  234. class ToolParameterForm(Enum):
  235. SCHEMA = "schema" # should be set while adding tool
  236. FORM = "form" # should be set before invoking tool
  237. LLM = "llm" # will be set by LLM
  238. type: ToolParameterType = Field(..., description="The type of the parameter")
  239. human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
  240. form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
  241. llm_description: Optional[str] = None
  242. # MCP object and array type parameters use this field to store the schema
  243. input_schema: Optional[dict] = None
  244. @classmethod
  245. def get_simple_instance(
  246. cls,
  247. name: str,
  248. llm_description: str,
  249. typ: ToolParameterType,
  250. required: bool,
  251. options: Optional[list[str]] = None,
  252. ) -> "ToolParameter":
  253. """
  254. get a simple tool parameter
  255. :param name: the name of the parameter
  256. :param llm_description: the description presented to the LLM
  257. :param typ: the type of the parameter
  258. :param required: if the parameter is required
  259. :param options: the options of the parameter
  260. """
  261. # convert options to ToolParameterOption
  262. if options:
  263. option_objs = [
  264. PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  265. for option in options
  266. ]
  267. else:
  268. option_objs = []
  269. return cls(
  270. name=name,
  271. label=I18nObject(en_US="", zh_Hans=""),
  272. placeholder=None,
  273. human_description=I18nObject(en_US="", zh_Hans=""),
  274. type=typ,
  275. form=cls.ToolParameterForm.LLM,
  276. llm_description=llm_description,
  277. required=required,
  278. options=option_objs,
  279. )
  280. def init_frontend_parameter(self, value: Any):
  281. return init_frontend_parameter(self, self.type, value)
  282. class ToolProviderIdentity(BaseModel):
  283. author: str = Field(..., description="The author of the tool")
  284. name: str = Field(..., description="The name of the tool")
  285. description: I18nObject = Field(..., description="The description of the tool")
  286. icon: str = Field(..., description="The icon of the tool")
  287. icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool")
  288. label: I18nObject = Field(..., description="The label of the tool")
  289. tags: Optional[list[ToolLabelEnum]] = Field(
  290. default=[],
  291. description="The tags of the tool",
  292. )
  293. class ToolIdentity(BaseModel):
  294. author: str = Field(..., description="The author of the tool")
  295. name: str = Field(..., description="The name of the tool")
  296. label: I18nObject = Field(..., description="The label of the tool")
  297. provider: str = Field(..., description="The provider of the tool")
  298. icon: Optional[str] = None
  299. class ToolDescription(BaseModel):
  300. human: I18nObject = Field(..., description="The description presented to the user")
  301. llm: str = Field(..., description="The description presented to the LLM")
  302. class ToolEntity(BaseModel):
  303. identity: ToolIdentity
  304. parameters: list[ToolParameter] = Field(default_factory=list)
  305. description: Optional[ToolDescription] = None
  306. output_schema: Optional[dict] = None
  307. has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
  308. # pydantic configs
  309. model_config = ConfigDict(protected_namespaces=())
  310. @field_validator("parameters", mode="before")
  311. @classmethod
  312. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
  313. return v or []
  314. class OAuthSchema(BaseModel):
  315. client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
  316. credentials_schema: list[ProviderConfig] = Field(
  317. default_factory=list, description="The schema of the OAuth credentials"
  318. )
  319. class ToolProviderEntity(BaseModel):
  320. identity: ToolProviderIdentity
  321. plugin_id: Optional[str] = None
  322. credentials_schema: list[ProviderConfig] = Field(default_factory=list)
  323. oauth_schema: Optional[OAuthSchema] = None
  324. class ToolProviderEntityWithPlugin(ToolProviderEntity):
  325. tools: list[ToolEntity] = Field(default_factory=list)
  326. class WorkflowToolParameterConfiguration(BaseModel):
  327. """
  328. Workflow tool configuration
  329. """
  330. name: str = Field(..., description="The name of the parameter")
  331. description: str = Field(..., description="The description of the parameter")
  332. form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
  333. class ToolInvokeMeta(BaseModel):
  334. """
  335. Tool invoke meta
  336. """
  337. time_cost: float = Field(..., description="The time cost of the tool invoke")
  338. error: Optional[str] = None
  339. tool_config: Optional[dict] = None
  340. @classmethod
  341. def empty(cls) -> "ToolInvokeMeta":
  342. """
  343. Get an empty instance of ToolInvokeMeta
  344. """
  345. return cls(time_cost=0.0, error=None, tool_config={})
  346. @classmethod
  347. def error_instance(cls, error: str) -> "ToolInvokeMeta":
  348. """
  349. Get an instance of ToolInvokeMeta with error
  350. """
  351. return cls(time_cost=0.0, error=error, tool_config={})
  352. def to_dict(self) -> dict:
  353. return {
  354. "time_cost": self.time_cost,
  355. "error": self.error,
  356. "tool_config": self.tool_config,
  357. }
  358. class ToolLabel(BaseModel):
  359. """
  360. Tool label
  361. """
  362. name: str = Field(..., description="The name of the tool")
  363. label: I18nObject = Field(..., description="The label of the tool")
  364. icon: str = Field(..., description="The icon of the tool")
  365. class ToolInvokeFrom(Enum):
  366. """
  367. Enum class for tool invoke
  368. """
  369. WORKFLOW = "workflow"
  370. AGENT = "agent"
  371. PLUGIN = "plugin"
  372. class ToolSelector(BaseModel):
  373. dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
  374. class Parameter(BaseModel):
  375. name: str = Field(..., description="The name of the parameter")
  376. type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
  377. required: bool = Field(..., description="Whether the parameter is required")
  378. description: str = Field(..., description="The description of the parameter")
  379. default: Optional[Union[int, float, str]] = None
  380. options: Optional[list[PluginParameterOption]] = None
  381. provider_id: str = Field(..., description="The id of the provider")
  382. credential_id: Optional[str] = Field(default=None, description="The id of the credential")
  383. tool_name: str = Field(..., description="The name of the tool")
  384. tool_description: str = Field(..., description="The description of the tool")
  385. tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
  386. tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
  387. def to_plugin_parameter(self) -> dict[str, Any]:
  388. return self.model_dump()
  389. class CredentialType(enum.StrEnum):
  390. API_KEY = "api-key"
  391. OAUTH2 = "oauth2"
  392. def get_name(self):
  393. if self == CredentialType.API_KEY:
  394. return "API KEY"
  395. elif self == CredentialType.OAUTH2:
  396. return "AUTH"
  397. else:
  398. return self.value.replace("-", " ").upper()
  399. def is_editable(self):
  400. return self == CredentialType.API_KEY
  401. def is_validate_allowed(self):
  402. return self == CredentialType.API_KEY
  403. @classmethod
  404. def values(cls):
  405. return [item.value for item in cls]
  406. @classmethod
  407. def of(cls, credential_type: str) -> "CredentialType":
  408. type_name = credential_type.lower()
  409. if type_name == "api-key":
  410. return cls.API_KEY
  411. elif type_name == "oauth2":
  412. return cls.OAUTH2
  413. else:
  414. raise ValueError(f"Invalid credential type: {credential_type}")