您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。

provider.py 7.8KB


  1. from pydantic import Field
  2. from sqlalchemy import select
  3. from core.entities.provider_entities import ProviderConfig
  4. from core.tools.__base.tool_provider import ToolProviderController
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.custom_tool.tool import ApiTool
  7. from core.tools.entities.common_entities import I18nObject
  8. from core.tools.entities.tool_bundle import ApiToolBundle
  9. from core.tools.entities.tool_entities import (
  10. ApiProviderAuthType,
  11. ToolDescription,
  12. ToolEntity,
  13. ToolIdentity,
  14. ToolProviderEntity,
  15. ToolProviderIdentity,
  16. ToolProviderType,
  17. )
  18. from extensions.ext_database import db
  19. from models.tools import ApiToolProvider
  20. class ApiToolProviderController(ToolProviderController):
  21. provider_id: str
  22. tenant_id: str
  23. tools: list[ApiTool] = Field(default_factory=list)
  24. def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
  25. super().__init__(entity)
  26. self.provider_id = provider_id
  27. self.tenant_id = tenant_id
  28. self.tools = []
  29. @classmethod
  30. def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  31. credentials_schema = [
  32. ProviderConfig(
  33. name="auth_type",
  34. required=True,
  35. type=ProviderConfig.Type.SELECT,
  36. options=[
  37. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  38. ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
  39. ProviderConfig.Option(
  40. value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
  41. ),
  42. ],
  43. default="none",
  44. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  45. )
  46. ]
  47. if auth_type == ApiProviderAuthType.API_KEY_HEADER:
  48. credentials_schema = [
  49. *credentials_schema,
  50. ProviderConfig(
  51. name="api_key_header",
  52. required=False,
  53. default="Authorization",
  54. type=ProviderConfig.Type.TEXT_INPUT,
  55. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  56. ),
  57. ProviderConfig(
  58. name="api_key_value",
  59. required=True,
  60. type=ProviderConfig.Type.SECRET_INPUT,
  61. help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
  62. ),
  63. ProviderConfig(
  64. name="api_key_header_prefix",
  65. required=False,
  66. default="basic",
  67. type=ProviderConfig.Type.SELECT,
  68. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  69. options=[
  70. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  71. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  72. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  73. ],
  74. ),
  75. ]
  76. elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
  77. credentials_schema = [
  78. *credentials_schema,
  79. ProviderConfig(
  80. name="api_key_query_param",
  81. required=False,
  82. default="key",
  83. type=ProviderConfig.Type.TEXT_INPUT,
  84. help=I18nObject(
  85. en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
  86. ),
  87. ),
  88. ProviderConfig(
  89. name="api_key_value",
  90. required=True,
  91. type=ProviderConfig.Type.SECRET_INPUT,
  92. help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
  93. ),
  94. ]
  95. elif auth_type == ApiProviderAuthType.NONE:
  96. pass
  97. user = db_provider.user
  98. user_name = user.name if user else ""
  99. return ApiToolProviderController(
  100. entity=ToolProviderEntity(
  101. identity=ToolProviderIdentity(
  102. author=user_name,
  103. name=db_provider.name,
  104. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  105. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  106. icon=db_provider.icon,
  107. ),
  108. credentials_schema=credentials_schema,
  109. plugin_id=None,
  110. ),
  111. provider_id=db_provider.id or "",
  112. tenant_id=db_provider.tenant_id or "",
  113. )
  114. @property
  115. def provider_type(self) -> ToolProviderType:
  116. return ToolProviderType.API
  117. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  118. """
  119. parse tool bundle to tool
  120. :param tool_bundle: the tool bundle
  121. :return: the tool
  122. """
  123. return ApiTool(
  124. api_bundle=tool_bundle,
  125. provider_id=self.provider_id,
  126. entity=ToolEntity(
  127. identity=ToolIdentity(
  128. author=tool_bundle.author,
  129. name=tool_bundle.operation_id or "default_tool",
  130. label=I18nObject(
  131. en_US=tool_bundle.operation_id or "default_tool",
  132. zh_Hans=tool_bundle.operation_id or "default_tool",
  133. ),
  134. icon=self.entity.identity.icon,
  135. provider=self.provider_id,
  136. ),
  137. description=ToolDescription(
  138. human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
  139. llm=tool_bundle.summary or "",
  140. ),
  141. parameters=tool_bundle.parameters or [],
  142. ),
  143. runtime=ToolRuntime(tenant_id=self.tenant_id),
  144. )
  145. def load_bundled_tools(self, tools: list[ApiToolBundle]):
  146. """
  147. load bundled tools
  148. :param tools: the bundled tools
  149. :return: the tools
  150. """
  151. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  152. return self.tools
  153. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  154. """
  155. fetch tools from database
  156. :param tenant_id: the tenant id
  157. :return: the tools
  158. """
  159. if len(self.tools) > 0:
  160. return self.tools
  161. tools: list[ApiTool] = []
  162. # get tenant api providers
  163. db_providers = db.session.scalars(
  164. select(ApiToolProvider).where(
  165. ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
  166. )
  167. ).all()
  168. if db_providers and len(db_providers) != 0:
  169. for db_provider in db_providers:
  170. for tool in db_provider.tools:
  171. assistant_tool = self._parse_tool_bundle(tool)
  172. tools.append(assistant_tool)
  173. self.tools = tools
  174. return tools
  175. def get_tool(self, tool_name: str) -> ApiTool:
  176. """
  177. get tool by name
  178. :param tool_name: the name of the tool
  179. :return: the tool
  180. """
  181. if self.tools is None:
  182. self.get_tools(self.tenant_id)
  183. for tool in self.tools:
  184. if tool.entity.identity.name == tool_name:
  185. return tool
  186. raise ValueError(f"tool {tool_name} not found")