Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

oauth_server.py 3.8KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import enum
  2. import uuid
  3. from sqlalchemy import select
  4. from sqlalchemy.orm import Session
  5. from werkzeug.exceptions import BadRequest
  6. from extensions.ext_database import db
  7. from extensions.ext_redis import redis_client
  8. from models.account import Account
  9. from models.model import OAuthProviderApp
  10. from services.account_service import AccountService
  11. class OAuthGrantType(enum.StrEnum):
  12. AUTHORIZATION_CODE = "authorization_code"
  13. REFRESH_TOKEN = "refresh_token"
  # Redis key templates, namespaced per OAuth client; fill the placeholders with str.format().
  OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
  OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
  OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"

  # Token lifetimes in seconds, passed to Redis as the key expiry.
  OAUTH_ACCESS_TOKEN_EXPIRES_IN = 12 * 60 * 60  # 12 hours
  OAUTH_REFRESH_TOKEN_EXPIRES_IN = 30 * 24 * 60 * 60  # 30 days
  19. class OAuthServerService:
  20. @staticmethod
  21. def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
  22. query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
  23. with Session(db.engine) as session:
  24. return session.execute(query).scalar_one_or_none()
  25. @staticmethod
  26. def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
  27. code = str(uuid.uuid4())
  28. redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
  29. redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
  30. return code
  31. @staticmethod
  32. def sign_oauth_access_token(
  33. grant_type: OAuthGrantType,
  34. code: str = "",
  35. client_id: str = "",
  36. refresh_token: str = "",
  37. ) -> tuple[str, str]:
  38. match grant_type:
  39. case OAuthGrantType.AUTHORIZATION_CODE:
  40. redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
  41. user_account_id = redis_client.get(redis_key)
  42. if not user_account_id:
  43. raise BadRequest("invalid code")
  44. # delete code
  45. redis_client.delete(redis_key)
  46. access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
  47. refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
  48. return access_token, refresh_token
  49. case OAuthGrantType.REFRESH_TOKEN:
  50. redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
  51. user_account_id = redis_client.get(redis_key)
  52. if not user_account_id:
  53. raise BadRequest("invalid refresh token")
  54. access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
  55. return access_token, refresh_token
  56. @staticmethod
  57. def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
  58. token = str(uuid.uuid4())
  59. redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
  60. redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
  61. return token
  62. @staticmethod
  63. def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
  64. token = str(uuid.uuid4())
  65. redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
  66. redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
  67. return token
  68. @staticmethod
  69. def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
  70. redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
  71. user_account_id = redis_client.get(redis_key)
  72. if not user_account_id:
  73. return None
  74. user_id_str = user_account_id.decode("utf-8")
  75. return AccountService.load_user(user_id_str)