Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.


  1. import io
  2. from flask import send_file
  3. from flask_login import current_user
  4. from flask_restx import Resource, reqparse
  5. from werkzeug.exceptions import Forbidden
  6. from controllers.console import api
  7. from controllers.console.wraps import account_initialization_required, setup_required
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from libs.helper import StrLen, uuid_value
  12. from libs.login import login_required
  13. from services.billing_service import BillingService
  14. from services.model_provider_service import ModelProviderService
  class ModelProviderListApi(Resource):
      """List the model providers of the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return ``{"data": [...]}`` with the workspace's provider list.

          Query parameters:
              model_type: optional; must be one of the ``ModelType`` enum
                  values. It is forwarded to the service as ``model_type``
                  (presumably to filter providers by model type — confirm).
          """
          workspace_id = current_user.current_tenant_id
          allowed_model_types = [member.value for member in ModelType]

          query_parser = reqparse.RequestParser()
          query_parser.add_argument(
              "model_type",
              type=str,
              required=False,
              nullable=True,
              choices=allowed_model_types,
              location="args",
          )
          query = query_parser.parse_args()

          providers = ModelProviderService().get_provider_list(
              tenant_id=workspace_id,
              model_type=query.get("model_type"),
          )
          return jsonable_encoder({"data": providers})
  class ModelProviderCredentialApi(Resource):
      """Read, create, update and delete the stored credentials of one provider.

      Every write operation (POST/PUT/DELETE) is restricted to workspace
      admins/owners; GET has no such check in this class.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"credentials": ...}`` for ``provider``.

          Query parameters:
              credential_id: optional UUID; if not provided, the service is
                  asked for the credential currently in use.
          """
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          credentials = model_provider_service.get_provider_credential(
              tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
          )
          return {"credentials": credentials}

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Create a named credential for ``provider``; responds 201.

          JSON body:
              credentials: dict of credential values (required).
              name: credential name, limited by ``StrLen(30)`` (required).

          Raises:
              Forbidden: the current user is not an admin/owner.
              ValueError: the service rejected the credentials
                  (``CredentialsValidateFailedError``); the original error
                  is chained as ``__cause__``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          try:
              model_provider_service.create_provider_credential(
                  tenant_id=current_user.current_tenant_id,
                  provider=provider,
                  credentials=args["credentials"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              # Chain the cause so the validation traceback is not lost in logs.
              raise ValueError(str(ex)) from ex
          return {"result": "success"}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def put(self, provider: str):
          """Update an existing credential of ``provider``.

          JSON body:
              credential_id: UUID of the credential to update (required).
              credentials: dict of new credential values (required).
              name: new credential name, limited by ``StrLen(30)`` (required).

          Raises:
              Forbidden: the current user is not an admin/owner.
              ValueError: the service rejected the credentials
                  (``CredentialsValidateFailedError``), chained as ``__cause__``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          try:
              model_provider_service.update_provider_credential(
                  tenant_id=current_user.current_tenant_id,
                  provider=provider,
                  credentials=args["credentials"],
                  credential_id=args["credential_id"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              raise ValueError(str(ex)) from ex
          return {"result": "success"}

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Remove the credential whose ``credential_id`` is given in the JSON body.

          Raises:
              Forbidden: the current user is not an admin/owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          model_provider_service.remove_provider_credential(
              tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
          )
          # NOTE(review): a 204 No Content response must not carry a body per
          # HTTP semantics, so clients should not rely on this payload. The
          # status is kept unchanged for compatibility with existing callers.
          return {"result": "success"}, 204
  class ModelProviderCredentialSwitchApi(Resource):
      """Make one of a provider's stored credentials the active one."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Activate the credential named by ``credential_id`` in the JSON body.

          Only admins/owners may switch credentials. ``credential_id`` is
          parsed with ``uuid_value``, the same validator the credential CRUD
          endpoints in this module use. Malformed IDs are therefore
          (presumably) rejected during argument parsing instead of being
          passed through to the service.

          Raises:
              Forbidden: the current user is not an admin/owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()
          service = ModelProviderService()
          service.switch_active_provider_credential(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              credential_id=args["credential_id"],
          )
          return {"result": "success"}
  class ModelProviderValidateApi(Resource):
      """Validate candidate credentials for a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Report whether the ``credentials`` dict in the JSON body validates.

          Returns ``{"result": "success"}`` on success. When the service raises
          ``CredentialsValidateFailedError``, it returns
          ``{"result": "error", "error": <message>}``, falling back to
          ``"Unknown error"`` when the exception message is empty.
          """
          body_parser = reqparse.RequestParser()
          body_parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          body = body_parser.parse_args()
          workspace_id = current_user.current_tenant_id
          service = ModelProviderService()

          try:
              service.validate_provider_credentials(
                  tenant_id=workspace_id,
                  provider=provider,
                  credentials=body["credentials"],
              )
          except CredentialsValidateFailedError as failure:
              return {"result": "error", "error": str(failure) or "Unknown error"}

          return {"result": "success"}
  class ModelProviderIconApi(Resource):
      """Serve a model provider's icon image as a file response.

      NOTE(review): unlike the other resources here, this one has no
      setup/login decorators and takes ``tenant_id`` from the URL
      (presumably so browsers can load icons directly — confirm).
      """

      def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
          """Return the icon bytes with their mimetype.

          Raises:
              ValueError: the service returned no icon for this combination.
          """
          icon_bytes, icon_mimetype = ModelProviderService().get_model_provider_icon(
              tenant_id=tenant_id,
              provider=provider,
              icon_type=icon_type,
              lang=lang,
          )
          if icon_bytes is not None:
              return send_file(io.BytesIO(icon_bytes), mimetype=icon_mimetype)
          raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}")
  class PreferredProviderTypeUpdateApi(Resource):
      """Set the preferred provider type (``system`` or ``custom``) of a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Persist ``preferred_provider_type`` from the JSON body.

          Only admins/owners may call this. The value must be ``"system"`` or
          ``"custom"``.

          Raises:
              Forbidden: the current user is not an admin/owner.
          """
          is_privileged = current_user.is_admin_or_owner
          if not is_privileged:
              raise Forbidden()

          workspace_id = current_user.current_tenant_id
          body_parser = reqparse.RequestParser()
          body_parser.add_argument(
              "preferred_provider_type",
              type=str,
              required=True,
              nullable=False,
              choices=["system", "custom"],
              location="json",
          )
          body = body_parser.parse_args()

          ModelProviderService().switch_preferred_provider(
              tenant_id=workspace_id,
              provider=provider,
              preferred_provider_type=body["preferred_provider_type"],
          )
          return {"result": "success"}
  class ModelProviderPaymentCheckoutUrlApi(Resource):
      """Produce a billing checkout link for a paid model provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return whatever ``BillingService.get_model_provider_payment_link`` returns.

          Only ``"anthropic"`` is accepted. ``BillingService.is_tenant_owner_or_admin``
          is then called as a permission gate (presumably it raises for
          non-admins — confirm in BillingService).

          Raises:
              ValueError: ``provider`` is not ``"anthropic"``.
          """
          supported_provider = "anthropic"
          if provider != supported_provider:
              raise ValueError(f"provider name {provider} is invalid")

          BillingService.is_tenant_owner_or_admin(current_user)

          user = current_user
          return BillingService.get_model_provider_payment_link(
              provider_name=provider,
              tenant_id=user.current_tenant_id,
              account_id=user.id,
              prefilled_email=user.email,
          )
  # Route table: each resource is registered on ``api`` in the order listed.
  # ``<path:provider>`` uses Werkzeug's path converter, which also matches
  # slashes, so provider identifiers may contain "/".
  for _resource, _route in (
      (ModelProviderListApi, "/workspaces/current/model-providers"),
      (ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials"),
      (
          ModelProviderCredentialSwitchApi,
          "/workspaces/current/model-providers/<path:provider>/credentials/switch",
      ),
      (
          ModelProviderValidateApi,
          "/workspaces/current/model-providers/<path:provider>/credentials/validate",
      ),
      (
          PreferredProviderTypeUpdateApi,
          "/workspaces/current/model-providers/<path:provider>/preferred-provider-type",
      ),
      (ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url"),
      (
          ModelProviderIconApi,
          "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
      ),
  ):
      api.add_resource(_resource, _route)
  # Drop the loop variables so they do not linger as module globals.
  del _resource, _route