| @@ -8,9 +8,10 @@ from extensions.ext_redis import redis_client | |||
| class OAuthProxyService(BasePluginClient): | |||
| # Default max age for proxy context parameter in seconds | |||
| __MAX_AGE__ = 5 * 60 # 5 minutes | |||
| __KEY_PREFIX__ = "oauth_proxy_context:" | |||
| @staticmethod | |||
| def create_proxy_context(user_id, tenant_id, plugin_id, provider): | |||
| def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): | |||
| """ | |||
| Create a proxy context for an OAuth 2.0 authorization request. | |||
| @@ -23,26 +24,22 @@ class OAuthProxyService(BasePluginClient): | |||
| is used to verify the state, ensuring the request's integrity and authenticity, | |||
| and mitigating replay attacks. | |||
| """ | |||
| seconds, _ = redis_client.time() | |||
| context_id = str(uuid.uuid4()) | |||
| data = { | |||
| "user_id": user_id, | |||
| "plugin_id": plugin_id, | |||
| "tenant_id": tenant_id, | |||
| "provider": provider, | |||
| # encode redis time to avoid distribution time skew | |||
| "timestamp": seconds, | |||
| } | |||
| # ignore nonce collision | |||
| redis_client.setex( | |||
| f"oauth_proxy_context:{context_id}", | |||
| f"{OAuthProxyService.__KEY_PREFIX__}{context_id}", | |||
| OAuthProxyService.__MAX_AGE__, | |||
| json.dumps(data), | |||
| ) | |||
| return context_id | |||
| @staticmethod | |||
| def use_proxy_context(context_id, max_age=__MAX_AGE__): | |||
| def use_proxy_context(context_id: str): | |||
| """ | |||
| Validate the proxy context parameter. | |||
| This checks if the context_id is valid and not expired. | |||
| @@ -50,12 +47,7 @@ class OAuthProxyService(BasePluginClient): | |||
| if not context_id: | |||
| raise ValueError("context_id is required") | |||
| # get data from redis | |||
| data = redis_client.getdel(f"oauth_proxy_context:{context_id}") | |||
| data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") | |||
| if not data: | |||
| raise ValueError("context_id is invalid") | |||
| # check if data is expired | |||
| seconds, _ = redis_client.time() | |||
| state = json.loads(data) | |||
| if state.get("timestamp") < seconds - max_age: | |||
| raise ValueError("context_id is expired") | |||
| return state | |||
| return json.loads(data) | |||