Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. from pydantic import Field
  2. from core.entities.provider_entities import ProviderConfig
  3. from core.tools.__base.tool_provider import ToolProviderController
  4. from core.tools.__base.tool_runtime import ToolRuntime
  5. from core.tools.custom_tool.tool import ApiTool
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_bundle import ApiToolBundle
  8. from core.tools.entities.tool_entities import (
  9. ApiProviderAuthType,
  10. ToolDescription,
  11. ToolEntity,
  12. ToolIdentity,
  13. ToolProviderEntity,
  14. ToolProviderIdentity,
  15. ToolProviderType,
  16. )
  17. from extensions.ext_database import db
  18. from models.tools import ApiToolProvider
  19. class ApiToolProviderController(ToolProviderController):
  20. provider_id: str
  21. tenant_id: str
  22. tools: list[ApiTool] = Field(default_factory=list)
  23. def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
  24. super().__init__(entity)
  25. self.provider_id = provider_id
  26. self.tenant_id = tenant_id
  27. self.tools = []
  28. @classmethod
  29. def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  30. credentials_schema = [
  31. ProviderConfig(
  32. name="auth_type",
  33. required=True,
  34. type=ProviderConfig.Type.SELECT,
  35. options=[
  36. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  37. ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
  38. ProviderConfig.Option(
  39. value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
  40. ),
  41. ],
  42. default="none",
  43. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  44. )
  45. ]
  46. if auth_type == ApiProviderAuthType.API_KEY_HEADER:
  47. credentials_schema = [
  48. *credentials_schema,
  49. ProviderConfig(
  50. name="api_key_header",
  51. required=False,
  52. default="Authorization",
  53. type=ProviderConfig.Type.TEXT_INPUT,
  54. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  55. ),
  56. ProviderConfig(
  57. name="api_key_value",
  58. required=True,
  59. type=ProviderConfig.Type.SECRET_INPUT,
  60. help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
  61. ),
  62. ProviderConfig(
  63. name="api_key_header_prefix",
  64. required=False,
  65. default="basic",
  66. type=ProviderConfig.Type.SELECT,
  67. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  68. options=[
  69. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  70. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  71. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  72. ],
  73. ),
  74. ]
  75. elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
  76. credentials_schema = [
  77. *credentials_schema,
  78. ProviderConfig(
  79. name="api_key_query_param",
  80. required=False,
  81. default="key",
  82. type=ProviderConfig.Type.TEXT_INPUT,
  83. help=I18nObject(
  84. en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
  85. ),
  86. ),
  87. ProviderConfig(
  88. name="api_key_value",
  89. required=True,
  90. type=ProviderConfig.Type.SECRET_INPUT,
  91. help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
  92. ),
  93. ]
  94. elif auth_type == ApiProviderAuthType.NONE:
  95. pass
  96. user = db_provider.user
  97. user_name = user.name if user else ""
  98. return ApiToolProviderController(
  99. entity=ToolProviderEntity(
  100. identity=ToolProviderIdentity(
  101. author=user_name,
  102. name=db_provider.name,
  103. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  104. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  105. icon=db_provider.icon,
  106. ),
  107. credentials_schema=credentials_schema,
  108. plugin_id=None,
  109. ),
  110. provider_id=db_provider.id or "",
  111. tenant_id=db_provider.tenant_id or "",
  112. )
  113. @property
  114. def provider_type(self) -> ToolProviderType:
  115. return ToolProviderType.API
  116. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  117. """
  118. parse tool bundle to tool
  119. :param tool_bundle: the tool bundle
  120. :return: the tool
  121. """
  122. return ApiTool(
  123. api_bundle=tool_bundle,
  124. provider_id=self.provider_id,
  125. entity=ToolEntity(
  126. identity=ToolIdentity(
  127. author=tool_bundle.author,
  128. name=tool_bundle.operation_id or "default_tool",
  129. label=I18nObject(
  130. en_US=tool_bundle.operation_id or "default_tool",
  131. zh_Hans=tool_bundle.operation_id or "default_tool",
  132. ),
  133. icon=self.entity.identity.icon,
  134. provider=self.provider_id,
  135. ),
  136. description=ToolDescription(
  137. human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
  138. llm=tool_bundle.summary or "",
  139. ),
  140. parameters=tool_bundle.parameters or [],
  141. ),
  142. runtime=ToolRuntime(tenant_id=self.tenant_id),
  143. )
  144. def load_bundled_tools(self, tools: list[ApiToolBundle]):
  145. """
  146. load bundled tools
  147. :param tools: the bundled tools
  148. :return: the tools
  149. """
  150. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  151. return self.tools
  152. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  153. """
  154. fetch tools from database
  155. :param tenant_id: the tenant id
  156. :return: the tools
  157. """
  158. if len(self.tools) > 0:
  159. return self.tools
  160. tools: list[ApiTool] = []
  161. # get tenant api providers
  162. db_providers: list[ApiToolProvider] = (
  163. db.session.query(ApiToolProvider)
  164. .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
  165. .all()
  166. )
  167. if db_providers and len(db_providers) != 0:
  168. for db_provider in db_providers:
  169. for tool in db_provider.tools:
  170. assistant_tool = self._parse_tool_bundle(tool)
  171. tools.append(assistant_tool)
  172. self.tools = tools
  173. return tools
  174. def get_tool(self, tool_name: str):
  175. """
  176. get tool by name
  177. :param tool_name: the name of the tool
  178. :return: the tool
  179. """
  180. if self.tools is None:
  181. self.get_tools(self.tenant_id)
  182. for tool in self.tools:
  183. if tool.entity.identity.name == tool_name:
  184. return tool
  185. raise ValueError(f"tool {tool_name} not found")