You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource.py 7.0KB


  1. from abc import ABC, abstractmethod
  2. from collections.abc import Generator
  3. from copy import deepcopy
  4. from typing import TYPE_CHECKING, Any, Optional
  5. if TYPE_CHECKING:
  6. from models.model import File
  7. from core.datasource.__base.datasource_runtime import DatasourceRuntime
  8. from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
  9. from core.tools.entities.tool_entities import (
  10. ToolInvokeMessage,
  11. ToolParameter,
  12. )
  13. class Datasource(ABC):
  14. """
  15. The base class of a datasource
  16. """
  17. entity: DatasourceEntity
  18. runtime: DatasourceRuntime
  19. def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
  20. self.entity = entity
  21. self.runtime = runtime
  22. def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
  23. """
  24. fork a new datasource with metadata
  25. :return: the new datasource
  26. """
  27. return self.__class__(
  28. entity=self.entity.model_copy(),
  29. runtime=runtime,
  30. )
  31. @abstractmethod
  32. def datasource_provider_type(self) -> DatasourceProviderType:
  33. """
  34. get the datasource provider type
  35. :return: the tool provider type
  36. """
  37. def invoke(
  38. self,
  39. user_id: str,
  40. tool_parameters: dict[str, Any],
  41. conversation_id: Optional[str] = None,
  42. app_id: Optional[str] = None,
  43. message_id: Optional[str] = None,
  44. ) -> Generator[ToolInvokeMessage]:
  45. if self.runtime and self.runtime.runtime_parameters:
  46. tool_parameters.update(self.runtime.runtime_parameters)
  47. # try parse tool parameters into the correct type
  48. tool_parameters = self._transform_tool_parameters_type(tool_parameters)
  49. result = self._invoke(
  50. user_id=user_id,
  51. tool_parameters=tool_parameters,
  52. conversation_id=conversation_id,
  53. app_id=app_id,
  54. message_id=message_id,
  55. )
  56. if isinstance(result, ToolInvokeMessage):
  57. def single_generator() -> Generator[ToolInvokeMessage, None, None]:
  58. yield result
  59. return single_generator()
  60. elif isinstance(result, list):
  61. def generator() -> Generator[ToolInvokeMessage, None, None]:
  62. yield from result
  63. return generator()
  64. else:
  65. return result
  66. def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
  67. """
  68. Transform tool parameters type
  69. """
  70. # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
  71. result = deepcopy(tool_parameters)
  72. for parameter in self.entity.parameters or []:
  73. if parameter.name in tool_parameters:
  74. result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
  75. return result
  76. @abstractmethod
  77. def _invoke(
  78. self,
  79. user_id: str,
  80. tool_parameters: dict[str, Any],
  81. conversation_id: Optional[str] = None,
  82. app_id: Optional[str] = None,
  83. message_id: Optional[str] = None,
  84. ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
  85. pass
  86. def get_runtime_parameters(
  87. self,
  88. conversation_id: Optional[str] = None,
  89. app_id: Optional[str] = None,
  90. message_id: Optional[str] = None,
  91. ) -> list[ToolParameter]:
  92. """
  93. get the runtime parameters
  94. interface for developer to dynamic change the parameters of a tool depends on the variables pool
  95. :return: the runtime parameters
  96. """
  97. return self.entity.parameters
  98. def get_merged_runtime_parameters(
  99. self,
  100. conversation_id: Optional[str] = None,
  101. app_id: Optional[str] = None,
  102. message_id: Optional[str] = None,
  103. ) -> list[ToolParameter]:
  104. """
  105. get merged runtime parameters
  106. :return: merged runtime parameters
  107. """
  108. parameters = self.entity.parameters
  109. parameters = parameters.copy()
  110. user_parameters = self.get_runtime_parameters() or []
  111. user_parameters = user_parameters.copy()
  112. # override parameters
  113. for parameter in user_parameters:
  114. # check if parameter in tool parameters
  115. for tool_parameter in parameters:
  116. if tool_parameter.name == parameter.name:
  117. # override parameter
  118. tool_parameter.type = parameter.type
  119. tool_parameter.form = parameter.form
  120. tool_parameter.required = parameter.required
  121. tool_parameter.default = parameter.default
  122. tool_parameter.options = parameter.options
  123. tool_parameter.llm_description = parameter.llm_description
  124. break
  125. else:
  126. # add new parameter
  127. parameters.append(parameter)
  128. return parameters
  129. def create_image_message(
  130. self,
  131. image: str,
  132. ) -> ToolInvokeMessage:
  133. """
  134. create an image message
  135. :param image: the url of the image
  136. :return: the image message
  137. """
  138. return ToolInvokeMessage(
  139. type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
  140. )
  141. def create_file_message(self, file: "File") -> ToolInvokeMessage:
  142. return ToolInvokeMessage(
  143. type=ToolInvokeMessage.MessageType.FILE,
  144. message=ToolInvokeMessage.FileMessage(),
  145. meta={"file": file},
  146. )
  147. def create_link_message(self, link: str) -> ToolInvokeMessage:
  148. """
  149. create a link message
  150. :param link: the url of the link
  151. :return: the link message
  152. """
  153. return ToolInvokeMessage(
  154. type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
  155. )
  156. def create_text_message(self, text: str) -> ToolInvokeMessage:
  157. """
  158. create a text message
  159. :param text: the text
  160. :return: the text message
  161. """
  162. return ToolInvokeMessage(
  163. type=ToolInvokeMessage.MessageType.TEXT,
  164. message=ToolInvokeMessage.TextMessage(text=text),
  165. )
  166. def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
  167. """
  168. create a blob message
  169. :param blob: the blob
  170. :param meta: the meta info of blob object
  171. :return: the blob message
  172. """
  173. return ToolInvokeMessage(
  174. type=ToolInvokeMessage.MessageType.BLOB,
  175. message=ToolInvokeMessage.BlobMessage(blob=blob),
  176. meta=meta,
  177. )
  178. def create_json_message(self, object: dict) -> ToolInvokeMessage:
  179. """
  180. create a json message
  181. """
  182. return ToolInvokeMessage(
  183. type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
  184. )