| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import enum
- import uuid
-
- from sqlalchemy import select
- from sqlalchemy.orm import Session
- from werkzeug.exceptions import BadRequest
-
- from extensions.ext_database import db
- from extensions.ext_redis import redis_client
- from models.account import Account
- from models.model import OAuthProviderApp
- from services.account_service import AccountService
-
-
- class OAuthGrantType(enum.StrEnum):
- AUTHORIZATION_CODE = "authorization_code"
- REFRESH_TOKEN = "refresh_token"
-
-
# Redis key templates, one key per (client, code/token); filled in with str.format.
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"

# Token lifetimes, expressed in seconds.
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 12 * 60 * 60  # 12 hours
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 30 * 24 * 60 * 60  # 30 days
-
-
- class OAuthServerService:
- @staticmethod
- def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
- query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
-
- with Session(db.engine) as session:
- return session.execute(query).scalar_one_or_none()
-
- @staticmethod
- def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
- code = str(uuid.uuid4())
- redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
- redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
- return code
-
- @staticmethod
- def sign_oauth_access_token(
- grant_type: OAuthGrantType,
- code: str = "",
- client_id: str = "",
- refresh_token: str = "",
- ) -> tuple[str, str]:
- match grant_type:
- case OAuthGrantType.AUTHORIZATION_CODE:
- redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
- user_account_id = redis_client.get(redis_key)
- if not user_account_id:
- raise BadRequest("invalid code")
-
- # delete code
- redis_client.delete(redis_key)
-
- access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
- refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
- return access_token, refresh_token
- case OAuthGrantType.REFRESH_TOKEN:
- redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
- user_account_id = redis_client.get(redis_key)
- if not user_account_id:
- raise BadRequest("invalid refresh token")
-
- access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
- return access_token, refresh_token
-
- @staticmethod
- def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
- token = str(uuid.uuid4())
- redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
- redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
- return token
-
- @staticmethod
- def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
- token = str(uuid.uuid4())
- redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
- redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
- return token
-
- @staticmethod
- def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
- redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
- user_account_id = redis_client.get(redis_key)
- if not user_account_id:
- return None
-
- user_id_str = user_account_id.decode("utf-8")
-
- return AccountService.load_user(user_id_str)
|