| @@ -1,3 +1,4 @@ | |||
| import binascii | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): | |||
| provider: str, | |||
| system_credentials: Mapping[str, Any], | |||
| ) -> PluginOAuthAuthorizationUrlResponse: | |||
| return self._request_with_plugin_daemon_response( | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", | |||
| PluginOAuthAuthorizationUrlResponse, | |||
| @@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient): | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||
| def get_credentials( | |||
| self, | |||
| @@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient): | |||
| # encode request to raw http request | |||
| raw_request_bytes = self._convert_request_to_raw_data(request) | |||
| return self._request_with_plugin_daemon_response( | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/get_credentials", | |||
| PluginOAuthCredentialsResponse, | |||
| @@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient): | |||
| "data": { | |||
| "provider": provider, | |||
| "system_credentials": system_credentials, | |||
| "raw_request_bytes": raw_request_bytes, | |||
| # for json serialization | |||
| "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), | |||
| }, | |||
| }, | |||
| headers={ | |||
| @@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient): | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | |||
| """ | |||
| @@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient): | |||
| """ | |||
| # Start with the request line | |||
| method = request.method | |||
| path = request.path | |||
| path = request.full_path | |||
| protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") | |||
| raw_data = f"{method} {path} {protocol}\r\n".encode() | |||
| @@ -1,7 +1,61 @@ | |||
| import json | |||
| import uuid | |||
| from core.plugin.impl.base import BasePluginClient | |||
| from extensions.ext_redis import redis_client | |||
| class OAuthProxyService(BasePluginClient): | |||
| # Default max age for proxy context parameter in seconds | |||
| __MAX_AGE__ = 5 * 60 # 5 minutes | |||
| @staticmethod | |||
| def create_proxy_context(user_id, tenant_id, plugin_id, provider): | |||
| """ | |||
| Create a proxy context for an OAuth 2.0 authorization request. | |||
| This parameter is a crucial security measure to prevent Cross-Site Request | |||
| Forgery (CSRF) attacks. It works by generating a unique nonce and storing it | |||
| in a distributed cache (Redis) along with the user's session context. | |||
| The returned nonce should be included as the 'proxy_context' parameter in the | |||
| authorization URL. Upon callback, the `use_proxy_context` method | |||
| is used to verify the state, ensuring the request's integrity and authenticity, | |||
| and mitigating replay attacks. | |||
| """ | |||
| seconds, _ = redis_client.time() | |||
| context_id = str(uuid.uuid4()) | |||
| data = { | |||
| "user_id": user_id, | |||
| "plugin_id": plugin_id, | |||
| "tenant_id": tenant_id, | |||
| "provider": provider, | |||
| # encode redis time to avoid distribution time skew | |||
| "timestamp": seconds, | |||
| } | |||
| # ignore nonce collision | |||
| redis_client.setex( | |||
| f"oauth_proxy_context:{context_id}", | |||
| OAuthProxyService.__MAX_AGE__, | |||
| json.dumps(data), | |||
| ) | |||
| return context_id | |||
| class OAuthService(BasePluginClient): | |||
| @classmethod | |||
| def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: | |||
| return "1234567890" | |||
| @staticmethod | |||
| def use_proxy_context(context_id, max_age=__MAX_AGE__): | |||
| """ | |||
| Validate the proxy context parameter. | |||
| This checks if the context_id is valid and not expired. | |||
| """ | |||
| if not context_id: | |||
| raise ValueError("context_id is required") | |||
| # get data from redis | |||
| data = redis_client.getdel(f"oauth_proxy_context:{context_id}") | |||
| if not data: | |||
| raise ValueError("context_id is invalid") | |||
| # check if data is expired | |||
| seconds, _ = redis_client.time() | |||
| state = json.loads(data) | |||
| if state.get("timestamp") < seconds - max_age: | |||
| raise ValueError("context_id is expired") | |||
| return state | |||
| @@ -1,3 +1,5 @@ | |||
| import json | |||
| from werkzeug import Request | |||
| from werkzeug.datastructures import Headers | |||
| from werkzeug.test import EnvironBuilder | |||
| @@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data(): | |||
| request = Request(builder.get_environ()) | |||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||
| assert b"GET /test HTTP/1.1" in raw_request_bytes | |||
| assert b"GET /test? HTTP/1.1" in raw_request_bytes | |||
| assert b"Content-Type: application/json" in raw_request_bytes | |||
| assert b"\r\n\r\n" in raw_request_bytes | |||
| def test_oauth_convert_request_to_raw_data_with_query_params(): | |||
| oauth_handler = OAuthHandler() | |||
| builder = EnvironBuilder( | |||
| method="GET", | |||
| path="/test", | |||
| query_string="code=abc123&state=xyz789", | |||
| headers=Headers({"Content-Type": "application/json"}), | |||
| ) | |||
| request = Request(builder.get_environ()) | |||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||
| assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes | |||
| assert b"Content-Type: application/json" in raw_request_bytes | |||
| assert b"\r\n\r\n" in raw_request_bytes | |||
| def test_oauth_convert_request_to_raw_data_with_post_body(): | |||
| oauth_handler = OAuthHandler() | |||
| builder = EnvironBuilder( | |||
| method="POST", | |||
| path="/test", | |||
| data="param1=value1¶m2=value2", | |||
| headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}), | |||
| ) | |||
| request = Request(builder.get_environ()) | |||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||
| assert b"POST /test? HTTP/1.1" in raw_request_bytes | |||
| assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes | |||
| assert b"\r\n\r\n" in raw_request_bytes | |||
| assert b"param1=value1¶m2=value2" in raw_request_bytes | |||
| def test_oauth_convert_request_to_raw_data_with_json_body(): | |||
| oauth_handler = OAuthHandler() | |||
| json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"} | |||
| builder = EnvironBuilder( | |||
| method="POST", | |||
| path="/test", | |||
| data=json.dumps(json_data), | |||
| headers=Headers({"Content-Type": "application/json"}), | |||
| ) | |||
| request = Request(builder.get_environ()) | |||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||
| assert b"POST /test? HTTP/1.1" in raw_request_bytes | |||
| assert b"Content-Type: application/json" in raw_request_bytes | |||
| assert b"\r\n\r\n" in raw_request_bytes | |||
| assert b'"code": "abc123"' in raw_request_bytes | |||
| assert b'"state": "xyz789"' in raw_request_bytes | |||
| assert b'"grant_type": "authorization_code"' in raw_request_bytes | |||