You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_entities.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import enum
  2. from enum import Enum
  3. from typing import Any, Optional
  4. from pydantic import BaseModel, Field, ValidationInfo, field_validator
  5. from yarl import URL
  6. from configs import dify_config
  7. from core.entities.provider_entities import ProviderConfig
  8. from core.plugin.entities.oauth import OAuthSchema
  9. from core.plugin.entities.parameters import (
  10. PluginParameter,
  11. PluginParameterOption,
  12. PluginParameterType,
  13. as_normal_type,
  14. cast_parameter_value,
  15. init_frontend_parameter,
  16. )
  17. from core.tools.entities.common_entities import I18nObject
  18. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
  19. class DatasourceProviderType(enum.StrEnum):
  20. """
  21. Enum class for datasource provider
  22. """
  23. ONLINE_DOCUMENT = "online_document"
  24. LOCAL_FILE = "local_file"
  25. WEBSITE_CRAWL = "website_crawl"
  26. ONLINE_DRIVE = "online_drive"
  27. @classmethod
  28. def value_of(cls, value: str) -> "DatasourceProviderType":
  29. """
  30. Get value of given mode.
  31. :param value: mode value
  32. :return: mode
  33. """
  34. for mode in cls:
  35. if mode.value == value:
  36. return mode
  37. raise ValueError(f"invalid mode value {value}")
  class DatasourceParameter(PluginParameter):
      """
      Overrides type
      """

      # Narrowed parameter-type enum. Each member reuses the string value of the
      # corresponding PluginParameterType member; per its docstring, TOOLS_SELECTOR is omitted.
      class DatasourceParameterType(enum.StrEnum):
          """
          removes TOOLS_SELECTOR from PluginParameterType
          """

          STRING = PluginParameterType.STRING.value
          NUMBER = PluginParameterType.NUMBER.value
          BOOLEAN = PluginParameterType.BOOLEAN.value
          SELECT = PluginParameterType.SELECT.value
          SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
          FILE = PluginParameterType.FILE.value
          FILES = PluginParameterType.FILES.value

          # deprecated, should not use.
          SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value

          def as_normal_type(self):
              # Function bodies skip class scope during name lookup, so this calls the
              # module-level helper imported from core.plugin.entities.parameters.
              return as_normal_type(self)

          def cast_value(self, value: Any):
              # Delegate conversion of ``value`` for this parameter type to the shared helper.
              return cast_parameter_value(self, value)

      type: DatasourceParameterType = Field(description="The type of the parameter")
      description: I18nObject = Field(description="The description of the parameter")

      @classmethod
      def get_simple_instance(
          cls,
          name: str,
          typ: DatasourceParameterType,
          required: bool,
          options: Optional[list[str]] = None,
      ) -> "DatasourceParameter":
          """
          Build a minimal datasource parameter.

          The label and description are empty I18n strings and no placeholder is set.

          :param name: the name of the parameter
          :param typ: the type of the parameter
          :param required: whether the parameter is required
          :param options: optional plain-string choices; each becomes a
              ``PluginParameterOption`` whose value and en_US/zh_Hans labels are that string
          :return: the constructed ``DatasourceParameter``
          """
          # FIXME (carried over from the original): a type error was reported around this
          # conversion; the details are not visible in this file.
          choices = [
              PluginParameterOption(value=choice, label=I18nObject(en_US=choice, zh_Hans=choice))
              for choice in (options or [])
          ]
          return cls(
              name=name,
              label=I18nObject(en_US="", zh_Hans=""),
              placeholder=None,
              type=typ,
              required=required,
              options=choices,
              description=I18nObject(en_US="", zh_Hans=""),
          )

      def init_frontend_parameter(self, value: Any):
          # Delegates to the module-level helper of the same name, passing this parameter's type.
          return init_frontend_parameter(self, self.type, value)
  class DatasourceIdentity(BaseModel):
      # Identity metadata of one datasource; every field except ``icon`` is required.
      author: str = Field(description="The author of the datasource")
      name: str = Field(description="The name of the datasource")
      label: I18nObject = Field(description="The label of the datasource")
      provider: str = Field(description="The provider of the datasource")
      # Optional; None when omitted.
      icon: Optional[str] = Field(default=None)
  class DatasourceEntity(BaseModel):
      """
      A single datasource: its identity, parameters, description and optional output schema.
      """

      identity: DatasourceIdentity
      parameters: list[DatasourceParameter] = Field(default_factory=list)
      # This field is the datasource's description (its label lives on ``identity.label``).
      description: I18nObject = Field(..., description="The description of the datasource")
      # NOTE(review): presumably a JSON-schema-like dict describing the datasource output; confirm with callers.
      output_schema: Optional[dict] = None

      @field_validator("parameters", mode="before")
      @classmethod
      def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
          """
          Normalize a falsy ``parameters`` input (e.g. ``None``) to an empty list.

          Runs before field validation so an explicit ``null`` is accepted.

          :param v: the raw ``parameters`` value
          :param validation_info: pydantic validation context (unused)
          :return: ``v`` unchanged if truthy, otherwise ``[]``
          """
          return v or []
  class DatasourceProviderIdentity(BaseModel):
      """
      Identity metadata of a datasource provider.
      """

      author: str = Field(..., description="The author of the datasource provider")
      name: str = Field(..., description="The name of the datasource provider")
      description: I18nObject = Field(..., description="The description of the datasource provider")
      icon: str = Field(..., description="The icon of the datasource provider")
      label: I18nObject = Field(..., description="The label of the datasource provider")
      # default_factory keeps this consistent with the module's other list fields.
      tags: Optional[list[ToolLabelEnum]] = Field(
          default_factory=list,
          description="The tags of the datasource provider",
      )

      def generate_datasource_icon_url(self, tenant_id: str) -> str:
          """
          Build the URL for this provider's icon.

          Icons in a small built-in set are already absolute asset URLs and are
          returned unchanged. Any other icon is used as the ``filename`` query
          parameter of the console endpoint
          ``<CONSOLE_API_URL>/console/api/workspaces/current/plugin/icon``.

          :param tenant_id: value for the ``tenant_id`` query parameter
          :return: the icon URL as a string
          """
          # Icons that are already full asset URLs and are returned as-is.
          hard_coded_icons = ("https://assets.dify.ai/images/File%20Upload.svg",)
          if self.icon in hard_coded_icons:
              return self.icon

          # Fall back to a relative base when CONSOLE_API_URL is not configured.
          base = URL(dify_config.CONSOLE_API_URL or "/")
          icon_endpoint = base / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
          # yarl's ``%`` operator adds these pairs to the URL's query string.
          return str(icon_endpoint % {"tenant_id": tenant_id, "filename": self.icon})
  class DatasourceProviderEntity(BaseModel):
      """
      Datasource provider entity
      """

      # Required identity block (author, name, description, icon, label, tags).
      identity: DatasourceProviderIdentity
      # Credential configuration entries; empty list when none are given.
      credentials_schema: list[ProviderConfig] = Field(default_factory=list)
      # Optional OAuth schema; None when not provided.
      oauth_schema: Optional[OAuthSchema] = Field(default=None)
      # Required: which DatasourceProviderType this provider is.
      provider_type: DatasourceProviderType
  class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
      # Provider entity plus the list of datasources it contains (empty by default).
      datasources: list[DatasourceEntity] = Field(default_factory=list)
  class DatasourceInvokeMeta(BaseModel):
      """
      Datasource invoke meta
      """

      time_cost: float = Field(..., description="The time cost of the datasource invoke")
      error: Optional[str] = None
      # Field name kept as ``tool_config`` for compatibility with existing callers.
      tool_config: Optional[dict] = None

      @classmethod
      def empty(cls) -> "DatasourceInvokeMeta":
          """
          Return a meta with zero time cost, no error and an empty ``tool_config``.
          """
          return cls(time_cost=0.0, error=None, tool_config={})

      @classmethod
      def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
          """
          Return a meta that records ``error``, with zero time cost and an empty ``tool_config``.

          :param error: the error message to store
          """
          return cls(time_cost=0.0, error=error, tool_config={})

      def to_dict(self) -> dict:
          """
          Serialize to a plain dict with ``time_cost``, ``error`` and ``tool_config`` keys.

          ``tool_config`` is returned by reference, not copied.
          """
          return {
              "time_cost": self.time_cost,
              "error": self.error,
              "tool_config": self.tool_config,
          }
  class DatasourceLabel(BaseModel):
      """
      Datasource label
      """

      # Descriptions previously said "tool" (copied from the tool entities); corrected to "datasource".
      name: str = Field(..., description="The name of the datasource")
      label: I18nObject = Field(..., description="The label of the datasource")
      icon: str = Field(..., description="The icon of the datasource")
  class DatasourceInvokeFrom(Enum):
      """
      Enum class for datasource invoke
      """

      # The only invocation origin currently defined.
      RAG_PIPELINE = "rag_pipeline"
  class OnlineDocumentPage(BaseModel):
      """
      Online document page
      """

      page_id: str = Field(description="The page id")
      page_name: str = Field(description="The page title")
      # Optional icon payload; its dict structure is not constrained here.
      page_icon: Optional[dict] = Field(default=None, description="The page icon")
      type: str = Field(description="The type of the page")
      # Kept as a raw string; no timestamp format is enforced by this model.
      last_edited_time: str = Field(description="The last edited time")
      # Optional; None when omitted (presumably for top-level pages — confirm with the datasource plugin).
      parent_id: Optional[str] = Field(default=None, description="The parent page id")
  class OnlineDocumentInfo(BaseModel):
      """
      Online document info
      """

      # Workspace metadata is optional and defaults to None.
      workspace_id: Optional[str] = Field(default=None, description="The workspace id")
      workspace_name: Optional[str] = Field(default=None, description="The workspace name")
      workspace_icon: Optional[str] = Field(default=None, description="The workspace icon")
      # Both counts and the page list are required.
      total: int = Field(description="The total number of documents")
      pages: list[OnlineDocumentPage] = Field(description="The pages of the online document")
  class OnlineDocumentPagesMessage(BaseModel):
      """
      Get online document pages response
      """

      # Required list of OnlineDocumentInfo entries.
      result: list[OnlineDocumentInfo] = Field()
  class GetOnlineDocumentPageContentRequest(BaseModel):
      """
      Get online document page content request
      """

      # All three identifiers are required.
      workspace_id: str = Field(description="The workspace id")
      page_id: str = Field(description="The page id")
      type: str = Field(description="The type of the page")
  class OnlineDocumentPageContent(BaseModel):
      """
      Online document page content
      """

      # Location of the page (both required) ...
      workspace_id: str = Field(description="The workspace id")
      page_id: str = Field(description="The page id")
      # ... and its content as a single required string.
      content: str = Field(description="The content of the page")
  class GetOnlineDocumentPageContentResponse(BaseModel):
      """
      Get online document page content response
      """

      # Required wrapper around a single page's content.
      result: OnlineDocumentPageContent = Field()
  class GetWebsiteCrawlRequest(BaseModel):
      """
      Get website crawl request
      """

      # Required; the dict's keys are not constrained by this model.
      crawl_parameters: dict = Field(description="The crawl parameters")
  class WebSiteInfoDetail(BaseModel):
      # One crawled page; every field is a required string.
      source_url: str = Field(description="The url of the website")
      content: str = Field(description="The content of the website")
      title: str = Field(description="The title of the website")
      description: str = Field(description="The description of the website")
  class WebSiteInfo(BaseModel):
      """
      Website info
      """

      # Required, but may be explicitly None.
      status: Optional[str] = Field(..., description="crawl job status")
      # default_factory instead of a literal [] matches the module's other list fields.
      web_info_list: Optional[list[WebSiteInfoDetail]] = Field(
          default_factory=list,
          description="The crawled website details",
      )
      total: Optional[int] = Field(default=0, description="The total number of websites")
      completed: Optional[int] = Field(default=0, description="The number of completed websites")
  class WebsiteCrawlMessage(BaseModel):
      """
      Get website crawl response
      """

      # Build a fresh empty WebSiteInfo per message rather than using a single
      # module-level instance created at import time as the default.
      result: WebSiteInfo = Field(
          default_factory=lambda: WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
      )
  class DatasourceMessage(ToolInvokeMessage):
      # Datasource-specific alias of ToolInvokeMessage; adds no fields or behavior.
      ...
  251. #########################
  252. # Online drive file
  253. #########################
  class OnlineDriveFile(BaseModel):
      """
      Online drive file
      """

      # All fields required. ``type`` distinguishes folders from files (per its description).
      id: str = Field(description="The file ID")
      name: str = Field(description="The file name")
      size: int = Field(description="The file size")
      type: str = Field(description="The file type: folder or file")
  class OnlineDriveFileBucket(BaseModel):
      """
      Online drive file bucket
      """

      bucket: Optional[str] = Field(default=None, description="The file bucket")
      files: list[OnlineDriveFile] = Field(description="The file list")
      # Pagination state: False and None by default.
      is_truncated: bool = Field(default=False, description="Whether the result is truncated")
      next_page_parameters: Optional[dict] = Field(default=None, description="Parameters for fetching the next page")
  class OnlineDriveBrowseFilesRequest(BaseModel):
      """
      Get online drive file list request
      """

      bucket: Optional[str] = Field(default=None, description="The file bucket")
      # Only ``prefix`` is required.
      prefix: str = Field(description="The parent folder ID")
      # Page size defaults to 20.
      max_keys: int = Field(default=20, description="Page size for pagination")
      next_page_parameters: Optional[dict] = Field(default=None, description="Parameters for fetching the next page")
  class OnlineDriveBrowseFilesResponse(BaseModel):
      """
      Get online drive file list response
      """

      # Required list of buckets, each holding its own file list.
      result: list[OnlineDriveFileBucket] = Field(description="The list of file buckets")
  class OnlineDriveDownloadFileRequest(BaseModel):
      """
      Get online drive file
      """

      # Required file id; bucket is optional and defaults to None.
      id: str = Field(description="The id of the file")
      bucket: Optional[str] = Field(default=None, description="The name of the bucket")