You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Generator
  3. from copy import deepcopy
  4. from typing import TYPE_CHECKING, Any, Optional
  5. if TYPE_CHECKING:
  6. from models.model import File
  7. from core.datasource.__base.datasource_runtime import DatasourceRuntime
  8. from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
  9. from core.tools.entities.tool_entities import (
  10. ToolInvokeMessage,
  11. ToolParameter,
  12. )
  13. class Datasource(ABC):
  14. """
  15. The base class of a datasource
  16. """
  17. entity: DatasourceEntity
  18. runtime: DatasourceRuntime
  19. def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
  20. self.entity = entity
  21. self.runtime = runtime
  22. def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
  23. """
  24. fork a new datasource with metadata
  25. :return: the new datasource
  26. """
  27. return self.__class__(
  28. entity=self.entity.model_copy(),
  29. runtime=runtime,
  30. )
  31. @abstractmethod
  32. def datasource_provider_type(self) -> DatasourceProviderType:
  33. """
  34. get the datasource provider type
  35. :return: the tool provider type
  36. """
  37. def invoke(
  38. self,
  39. user_id: str,
  40. tool_parameters: dict[str, Any],
  41. conversation_id: Optional[str] = None,
  42. app_id: Optional[str] = None,
  43. message_id: Optional[str] = None,
  44. ) -> Generator[ToolInvokeMessage]:
  45. if self.runtime and self.runtime.runtime_parameters:
  46. tool_parameters.update(self.runtime.runtime_parameters)
  47. # try parse tool parameters into the correct type
  48. tool_parameters = self._transform_tool_parameters_type(tool_parameters)
  49. result = self._invoke(
  50. user_id=user_id,
  51. tool_parameters=tool_parameters,
  52. conversation_id=conversation_id,
  53. app_id=app_id,
  54. message_id=message_id,
  55. )
  56. if isinstance(result, ToolInvokeMessage):
  57. def single_generator() -> Generator[ToolInvokeMessage, None, None]:
  58. yield result
  59. return single_generator()
  60. elif isinstance(result, list):
  61. def generator() -> Generator[ToolInvokeMessage, None, None]:
  62. yield from result
  63. return generator()
  64. else:
  65. return result
  66. def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
  67. """
  68. Transform tool parameters type
  69. """
  70. # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
  71. result = deepcopy(tool_parameters)
  72. for parameter in self.entity.parameters or []:
  73. if parameter.name in tool_parameters:
  74. result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
  75. return result
  76. @abstractmethod
  77. def _invoke(
  78. self,
  79. user_id: str,
  80. tool_parameters: dict[str, Any],
  81. conversation_id: Optional[str] = None,
  82. app_id: Optional[str] = None,
  83. message_id: Optional[str] = None,
  84. ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
  85. pass
  86. def get_runtime_parameters(
  87. self,
  88. conversation_id: Optional[str] = None,
  89. app_id: Optional[str] = None,
  90. message_id: Optional[str] = None,
  91. ) -> list[ToolParameter]:
  92. """
  93. get the runtime parameters
  94. interface for developer to dynamic change the parameters of a tool depends on the variables pool
  95. :return: the runtime parameters
  96. """
  97. return self.entity.parameters
  98. def get_merged_runtime_parameters(
  99. self,
  100. conversation_id: Optional[str] = None,
  101. app_id: Optional[str] = None,
  102. message_id: Optional[str] = None,
  103. ) -> list[ToolParameter]:
  104. """
  105. get merged runtime parameters
  106. :return: merged runtime parameters
  107. """
  108. parameters = self.entity.parameters
  109. parameters = parameters.copy()
  110. user_parameters = self.get_runtime_parameters() or []
  111. user_parameters = user_parameters.copy()
  112. # override parameters
  113. for parameter in user_parameters:
  114. # check if parameter in tool parameters
  115. for tool_parameter in parameters:
  116. if tool_parameter.name == parameter.name:
  117. # override parameter
  118. tool_parameter.type = parameter.type
  119. tool_parameter.form = parameter.form
  120. tool_parameter.required = parameter.required
  121. tool_parameter.default = parameter.default
  122. tool_parameter.options = parameter.options
  123. tool_parameter.llm_description = parameter.llm_description
  124. break
  125. else:
  126. # add new parameter
  127. parameters.append(parameter)
  128. return parameters
  129. def create_image_message(
  130. self,
  131. image: str,
  132. ) -> ToolInvokeMessage:
  133. """
  134. create an image message
  135. :param image: the url of the image
  136. :return: the image message
  137. """
  138. return ToolInvokeMessage(
  139. type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
  140. )
  141. def create_file_message(self, file: "File") -> ToolInvokeMessage:
  142. return ToolInvokeMessage(
  143. type=ToolInvokeMessage.MessageType.FILE,
  144. message=ToolInvokeMessage.FileMessage(),
  145. meta={"file": file},
  146. )
  147. def create_link_message(self, link: str) -> ToolInvokeMessage:
  148. """
  149. create a link message
  150. :param link: the url of the link
  151. :return: the link message
  152. """
  153. return ToolInvokeMessage(
  154. type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
  155. )
  156. def create_text_message(self, text: str) -> ToolInvokeMessage:
  157. """
  158. create a text message
  159. :param text: the text
  160. :return: the text message
  161. """
  162. return ToolInvokeMessage(
  163. type=ToolInvokeMessage.MessageType.TEXT,
  164. message=ToolInvokeMessage.TextMessage(text=text),
  165. )
  166. def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
  167. """
  168. create a blob message
  169. :param blob: the blob
  170. :param meta: the meta info of blob object
  171. :return: the blob message
  172. """
  173. return ToolInvokeMessage(
  174. type=ToolInvokeMessage.MessageType.BLOB,
  175. message=ToolInvokeMessage.BlobMessage(blob=blob),
  176. meta=meta,
  177. )
  178. def create_json_message(self, object: dict) -> ToolInvokeMessage:
  179. """
  180. create a json message
  181. """
  182. return ToolInvokeMessage(
  183. type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
  184. )