您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。

api_key_auth_service.py 3.2KB

(line-number gutter: lines 1–76)
  1. import json
  2. from sqlalchemy import select
  3. from core.helper import encrypter
  4. from extensions.ext_database import db
  5. from models.source import DataSourceApiKeyAuthBinding
  6. from services.auth.api_key_auth_factory import ApiKeyAuthFactory
  7. class ApiKeyAuthService:
  8. @staticmethod
  9. def get_provider_auth_list(tenant_id: str):
  10. data_source_api_key_bindings = db.session.scalars(
  11. select(DataSourceApiKeyAuthBinding).where(
  12. DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
  13. )
  14. ).all()
  15. return data_source_api_key_bindings
  16. @staticmethod
  17. def create_provider_auth(tenant_id: str, args: dict):
  18. auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
  19. if auth_result:
  20. # Encrypt the api key
  21. api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
  22. args["credentials"]["config"]["api_key"] = api_key
  23. data_source_api_key_binding = DataSourceApiKeyAuthBinding()
  24. data_source_api_key_binding.tenant_id = tenant_id
  25. data_source_api_key_binding.category = args["category"]
  26. data_source_api_key_binding.provider = args["provider"]
  27. data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
  28. db.session.add(data_source_api_key_binding)
  29. db.session.commit()
  30. @staticmethod
  31. def get_auth_credentials(tenant_id: str, category: str, provider: str):
  32. data_source_api_key_bindings = (
  33. db.session.query(DataSourceApiKeyAuthBinding)
  34. .where(
  35. DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
  36. DataSourceApiKeyAuthBinding.category == category,
  37. DataSourceApiKeyAuthBinding.provider == provider,
  38. DataSourceApiKeyAuthBinding.disabled.is_(False),
  39. )
  40. .first()
  41. )
  42. if not data_source_api_key_bindings:
  43. return None
  44. credentials = json.loads(data_source_api_key_bindings.credentials)
  45. return credentials
  46. @staticmethod
  47. def delete_provider_auth(tenant_id: str, binding_id: str):
  48. data_source_api_key_binding = (
  49. db.session.query(DataSourceApiKeyAuthBinding)
  50. .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
  51. .first()
  52. )
  53. if data_source_api_key_binding:
  54. db.session.delete(data_source_api_key_binding)
  55. db.session.commit()
  56. @classmethod
  57. def validate_api_key_auth_args(cls, args):
  58. if "category" not in args or not args["category"]:
  59. raise ValueError("category is required")
  60. if "provider" not in args or not args["provider"]:
  61. raise ValueError("provider is required")
  62. if "credentials" not in args or not args["credentials"]:
  63. raise ValueError("credentials is required")
  64. if not isinstance(args["credentials"], dict):
  65. raise ValueError("credentials must be a dictionary")
  66. if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
  67. raise ValueError("auth_type is required")