| import binascii | |||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||
| from typing import Any | from typing import Any | ||||
| provider: str, | provider: str, | ||||
| system_credentials: Mapping[str, Any], | system_credentials: Mapping[str, Any], | ||||
| ) -> PluginOAuthAuthorizationUrlResponse: | ) -> PluginOAuthAuthorizationUrlResponse: | ||||
| return self._request_with_plugin_daemon_response( | |||||
| response = self._request_with_plugin_daemon_response_stream( | |||||
| "POST", | "POST", | ||||
| f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", | f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", | ||||
| PluginOAuthAuthorizationUrlResponse, | PluginOAuthAuthorizationUrlResponse, | ||||
| "Content-Type": "application/json", | "Content-Type": "application/json", | ||||
| }, | }, | ||||
| ) | ) | ||||
| for resp in response: | |||||
| return resp | |||||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||||
| def get_credentials( | def get_credentials( | ||||
| self, | self, | ||||
| # encode request to raw http request | # encode request to raw http request | ||||
| raw_request_bytes = self._convert_request_to_raw_data(request) | raw_request_bytes = self._convert_request_to_raw_data(request) | ||||
| return self._request_with_plugin_daemon_response( | |||||
| response = self._request_with_plugin_daemon_response_stream( | |||||
| "POST", | "POST", | ||||
| f"plugin/{tenant_id}/dispatch/oauth/get_credentials", | f"plugin/{tenant_id}/dispatch/oauth/get_credentials", | ||||
| PluginOAuthCredentialsResponse, | PluginOAuthCredentialsResponse, | ||||
| "data": { | "data": { | ||||
| "provider": provider, | "provider": provider, | ||||
| "system_credentials": system_credentials, | "system_credentials": system_credentials, | ||||
| "raw_request_bytes": raw_request_bytes, | |||||
| # for json serialization | |||||
| "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), | |||||
| }, | }, | ||||
| }, | }, | ||||
| headers={ | headers={ | ||||
| "Content-Type": "application/json", | "Content-Type": "application/json", | ||||
| }, | }, | ||||
| ) | ) | ||||
| for resp in response: | |||||
| return resp | |||||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | def _convert_request_to_raw_data(self, request: Request) -> bytes: | ||||
| """ | """ | ||||
| """ | """ | ||||
| # Start with the request line | # Start with the request line | ||||
| method = request.method | method = request.method | ||||
| path = request.path | |||||
| path = request.full_path | |||||
| protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") | protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") | ||||
| raw_data = f"{method} {path} {protocol}\r\n".encode() | raw_data = f"{method} {path} {protocol}\r\n".encode() | ||||
| import json | |||||
| import uuid | |||||
| from core.plugin.impl.base import BasePluginClient | from core.plugin.impl.base import BasePluginClient | ||||
| from extensions.ext_redis import redis_client | |||||
| class OAuthProxyService(BasePluginClient): | |||||
| # Default max age for proxy context parameter in seconds | |||||
| __MAX_AGE__ = 5 * 60 # 5 minutes | |||||
| @staticmethod | |||||
| def create_proxy_context(user_id, tenant_id, plugin_id, provider): | |||||
| """ | |||||
| Create a proxy context for an OAuth 2.0 authorization request. | |||||
| This parameter is a crucial security measure to prevent Cross-Site Request | |||||
| Forgery (CSRF) attacks. It works by generating a unique nonce and storing it | |||||
| in a distributed cache (Redis) along with the user's session context. | |||||
| The returned nonce should be included as the 'proxy_context' parameter in the | |||||
| authorization URL. Upon callback, the `use_proxy_context` method | |||||
| is used to verify the state, ensuring the request's integrity and authenticity, | |||||
| and mitigating replay attacks. | |||||
| """ | |||||
| seconds, _ = redis_client.time() | |||||
| context_id = str(uuid.uuid4()) | |||||
| data = { | |||||
| "user_id": user_id, | |||||
| "plugin_id": plugin_id, | |||||
| "tenant_id": tenant_id, | |||||
| "provider": provider, | |||||
| # encode redis time to avoid distribution time skew | |||||
| "timestamp": seconds, | |||||
| } | |||||
| # ignore nonce collision | |||||
| redis_client.setex( | |||||
| f"oauth_proxy_context:{context_id}", | |||||
| OAuthProxyService.__MAX_AGE__, | |||||
| json.dumps(data), | |||||
| ) | |||||
| return context_id | |||||
| class OAuthService(BasePluginClient): | |||||
| @classmethod | |||||
| def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: | |||||
| return "1234567890" | |||||
| @staticmethod | |||||
| def use_proxy_context(context_id, max_age=__MAX_AGE__): | |||||
| """ | |||||
| Validate the proxy context parameter. | |||||
| This checks if the context_id is valid and not expired. | |||||
| """ | |||||
| if not context_id: | |||||
| raise ValueError("context_id is required") | |||||
| # get data from redis | |||||
| data = redis_client.getdel(f"oauth_proxy_context:{context_id}") | |||||
| if not data: | |||||
| raise ValueError("context_id is invalid") | |||||
| # check if data is expired | |||||
| seconds, _ = redis_client.time() | |||||
| state = json.loads(data) | |||||
| if state.get("timestamp") < seconds - max_age: | |||||
| raise ValueError("context_id is expired") | |||||
| return state |
| import json | |||||
| from werkzeug import Request | from werkzeug import Request | ||||
| from werkzeug.datastructures import Headers | from werkzeug.datastructures import Headers | ||||
| from werkzeug.test import EnvironBuilder | from werkzeug.test import EnvironBuilder | ||||
| request = Request(builder.get_environ()) | request = Request(builder.get_environ()) | ||||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | ||||
| assert b"GET /test HTTP/1.1" in raw_request_bytes | |||||
| assert b"GET /test? HTTP/1.1" in raw_request_bytes | |||||
| assert b"Content-Type: application/json" in raw_request_bytes | |||||
| assert b"\r\n\r\n" in raw_request_bytes | |||||
| def test_oauth_convert_request_to_raw_data_with_query_params(): | |||||
| oauth_handler = OAuthHandler() | |||||
| builder = EnvironBuilder( | |||||
| method="GET", | |||||
| path="/test", | |||||
| query_string="code=abc123&state=xyz789", | |||||
| headers=Headers({"Content-Type": "application/json"}), | |||||
| ) | |||||
| request = Request(builder.get_environ()) | |||||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||||
| assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes | |||||
| assert b"Content-Type: application/json" in raw_request_bytes | |||||
| assert b"\r\n\r\n" in raw_request_bytes | |||||
| def test_oauth_convert_request_to_raw_data_with_post_body(): | |||||
| oauth_handler = OAuthHandler() | |||||
| builder = EnvironBuilder( | |||||
| method="POST", | |||||
| path="/test", | |||||
| data="param1=value1¶m2=value2", | |||||
| headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}), | |||||
| ) | |||||
| request = Request(builder.get_environ()) | |||||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||||
| assert b"POST /test? HTTP/1.1" in raw_request_bytes | |||||
| assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes | |||||
| assert b"\r\n\r\n" in raw_request_bytes | |||||
| assert b"param1=value1¶m2=value2" in raw_request_bytes | |||||
| def test_oauth_convert_request_to_raw_data_with_json_body(): | |||||
| oauth_handler = OAuthHandler() | |||||
| json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"} | |||||
| builder = EnvironBuilder( | |||||
| method="POST", | |||||
| path="/test", | |||||
| data=json.dumps(json_data), | |||||
| headers=Headers({"Content-Type": "application/json"}), | |||||
| ) | |||||
| request = Request(builder.get_environ()) | |||||
| raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) | |||||
| assert b"POST /test? HTTP/1.1" in raw_request_bytes | |||||
| assert b"Content-Type: application/json" in raw_request_bytes | assert b"Content-Type: application/json" in raw_request_bytes | ||||
| assert b"\r\n\r\n" in raw_request_bytes | assert b"\r\n\r\n" in raw_request_bytes | ||||
| assert b'"code": "abc123"' in raw_request_bytes | |||||
| assert b'"state": "xyz789"' in raw_request_bytes | |||||
| assert b'"grant_type": "authorization_code"' in raw_request_bytes |