| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- import urllib.parse
- from unittest.mock import MagicMock, patch
-
- import httpx
- import pytest
-
- from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
-
-
class BaseOAuthTest:
    """Fixtures and helpers shared by the concrete OAuth provider test classes.

    The class name does not start with ``Test``, so pytest does not collect it
    directly; provider test classes inherit its fixtures instead.
    """

    @pytest.fixture
    def oauth_config(self):
        """Client credentials and redirect URI used to construct every provider."""
        config = dict(
            client_id="test_client_id",
            client_secret="test_client_secret",
            redirect_uri="http://localhost/callback",
        )
        return config

    @pytest.fixture
    def mock_response(self):
        """A stand-in HTTP response whose ``json()`` returns an empty dict unless a test overrides it."""
        fake = MagicMock()
        fake.json.return_value = {}
        return fake

    def parse_auth_url(self, url):
        """Split an authorization URL into its parsed components and query parameters.

        Returns a ``(ParseResult, dict)`` pair. The dict comes from
        ``urllib.parse.parse_qs``, so every value is a list of strings and
        parameters with blank values are omitted.
        """
        components = urllib.parse.urlparse(url)
        return components, urllib.parse.parse_qs(components.query)
-
-
class TestGitHubOAuth(BaseOAuthTest):
    """Tests for the GitHub provider: authorize URL, token exchange and user lookup."""

    @pytest.fixture
    def oauth(self, oauth_config):
        """A GitHubOAuth client built from the shared test configuration."""
        return GitHubOAuth(
            oauth_config["client_id"],
            oauth_config["client_secret"],
            oauth_config["redirect_uri"],
        )

    @pytest.mark.parametrize(
        "invite_token, expected_state",
        [
            (None, None),
            ("test_invite_token", "test_invite_token"),
            ("", None),
        ],
    )
    def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
        """The URL targets GitHub's authorize endpoint with the configured client.

        ``state`` carries a non-empty invite token and is otherwise absent from
        the parsed query.
        """
        parsed, params = self.parse_auth_url(oauth.get_authorization_url(invite_token))

        assert (parsed.scheme, parsed.netloc, parsed.path) == ("https", "github.com", "/login/oauth/authorize")

        expected_params = {
            "client_id": oauth_config["client_id"],
            "redirect_uri": oauth_config["redirect_uri"],
            "scope": "user:email",
        }
        for key, value in expected_params.items():
            assert params[key][0] == value

        if not expected_state:
            assert "state" not in params
        else:
            assert params["state"][0] == expected_state

    @pytest.mark.parametrize(
        "response_data, expected_token, should_raise",
        [
            ({"access_token": "test_token"}, "test_token", False),
            ({"error": "invalid_grant"}, None, True),
            ({}, None, True),
        ],
    )
    @patch("httpx.post")
    def test_should_retrieve_access_token(
        self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
    ):
        """A payload containing ``access_token`` yields that token; any other payload raises ValueError."""
        mock_response.json.return_value = response_data
        mock_post.return_value = mock_response

        if not should_raise:
            assert oauth.get_access_token("test_code") == expected_token
            return

        with pytest.raises(ValueError, match="Error in GitHub OAuth"):
            oauth.get_access_token("test_code")

    @pytest.mark.parametrize(
        "user_data, email_data, expected_email",
        [
            # Primary e-mail present: it wins over secondary addresses.
            (
                {"id": 12345, "login": "testuser", "name": "Test User"},
                [
                    {"email": "secondary@example.com", "primary": False},
                    {"email": "primary@example.com", "primary": True},
                ],
                "primary@example.com",
            ),
            # No e-mails at all: fall back to the GitHub noreply address.
            ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
            # Only a secondary e-mail: still fall back to the noreply address.
            (
                {"id": 12345, "login": "testuser", "name": "Test User"},
                [{"email": "secondary@example.com", "primary": False}],
                "12345+testuser@users.noreply.github.com",
            ),
        ],
    )
    @patch("httpx.get")
    def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
        """Use the primary e-mail when there is one, otherwise ``<id>+<login>@users.noreply.github.com``."""
        canned = []
        for payload in (user_data, email_data):
            fake = MagicMock()
            fake.json.return_value = payload
            canned.append(fake)
        # Successive httpx.get calls receive the profile payload, then the e-mail list.
        mock_get.side_effect = canned

        user_info = oauth.get_user_info("test_token")

        assert (user_info.id, user_info.name, user_info.email) == (
            str(user_data["id"]),
            user_data["name"],
            expected_email,
        )

    @patch("httpx.get")
    def test_should_handle_network_errors(self, mock_get, oauth):
        """A transport failure raised by ``httpx.get`` propagates out of ``get_raw_user_info``."""
        mock_get.side_effect = httpx.RequestError("Network error")

        with pytest.raises(httpx.RequestError):
            oauth.get_raw_user_info("test_token")
-
-
class TestGoogleOAuth(BaseOAuthTest):
    """Tests for the Google provider: authorize URL, token exchange, user lookup and error propagation."""

    @pytest.fixture
    def oauth(self, oauth_config):
        """A GoogleOAuth client built from the shared test configuration."""
        return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])

    @pytest.mark.parametrize(
        ("invite_token", "expected_state"),
        [
            (None, None),
            ("test_invite_token", "test_invite_token"),
            ("", None),
        ],
    )
    def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
        """The URL targets Google's v2 auth endpoint and requests a code with ``openid email`` scope.

        ``state`` carries a non-empty invite token and is otherwise absent from
        the parsed query.
        """
        url = oauth.get_authorization_url(invite_token)
        parsed, params = self.parse_auth_url(url)

        assert parsed.scheme == "https"
        assert parsed.netloc == "accounts.google.com"
        assert parsed.path == "/o/oauth2/v2/auth"
        assert params["client_id"][0] == oauth_config["client_id"]
        assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
        assert params["response_type"][0] == "code"
        assert params["scope"][0] == "openid email"

        if expected_state:
            assert params["state"][0] == expected_state
        else:
            assert "state" not in params

    @pytest.mark.parametrize(
        ("response_data", "expected_token", "should_raise"),
        [
            ({"access_token": "test_token"}, "test_token", False),
            ({"error": "invalid_grant"}, None, True),
            ({}, None, True),
        ],
    )
    @patch("httpx.post")
    def test_should_retrieve_access_token(
        self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
    ):
        """A payload with ``access_token`` yields the token; any other payload raises ValueError.

        In both cases exactly one form POST is sent to the token endpoint.
        """
        mock_response.json.return_value = response_data
        mock_post.return_value = mock_response

        if should_raise:
            with pytest.raises(ValueError) as exc_info:
                oauth.get_access_token("test_code")
            assert "Error in Google OAuth" in str(exc_info.value)
        else:
            token = oauth.get_access_token("test_code")
            assert token == expected_token

        mock_post.assert_called_once_with(
            oauth._TOKEN_URL,
            data={
                "client_id": oauth_config["client_id"],
                "client_secret": oauth_config["client_secret"],
                "code": "test_code",
                "grant_type": "authorization_code",
                "redirect_uri": oauth_config["redirect_uri"],
            },
            headers={"Accept": "application/json"},
        )

    @pytest.mark.parametrize(
        ("user_data", "expected_name"),
        [
            ({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
            ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""),  # Always returns empty string
        ],
    )
    @patch("httpx.get")
    def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
        """``id`` comes from ``sub`` and ``email`` is passed through; ``name`` is empty even when provided.

        Exactly one Bearer-authenticated GET is made to the userinfo endpoint.
        """
        mock_response.json.return_value = user_data
        mock_get.return_value = mock_response

        user_info = oauth.get_user_info("test_token")

        assert user_info.id == user_data["sub"]
        assert user_info.name == expected_name
        assert user_info.email == user_data["email"]

        mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})

    @pytest.mark.parametrize("status_code", [401, 403, 500])
    @patch("httpx.get")
    def test_should_handle_http_errors(self, mock_get, oauth, status_code):
        """An error status reported via ``raise_for_status`` propagates as ``httpx.HTTPStatusError``.

        ``Response.raise_for_status`` only ever raises ``HTTPStatusError``, so
        that is the exception simulated here. Transport-level failures are
        raised by the request call itself and are covered by
        ``test_should_handle_network_errors``.
        """
        request = httpx.Request("GET", oauth._USER_INFO_URL)
        error_response = MagicMock()
        error_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            f"HTTP {status_code}",
            request=request,
            response=httpx.Response(status_code, request=request),
        )
        mock_get.return_value = error_response

        with pytest.raises(httpx.HTTPStatusError) as exc_info:
            oauth.get_raw_user_info("invalid_token")
        assert exc_info.value.response.status_code == status_code

    @pytest.mark.parametrize(
        "exception_type",
        [
            httpx.ConnectError,
            httpx.TimeoutException,
        ],
    )
    @patch("httpx.get")
    def test_should_handle_network_errors(self, mock_get, oauth, exception_type):
        """Transport failures raised by ``httpx.get`` itself propagate out of ``get_raw_user_info``."""
        mock_get.side_effect = exception_type("Error")

        with pytest.raises(exception_type):
            oauth.get_raw_user_info("invalid_token")
-
-
class TestOAuthUserInfo:
    """Tests for constructing ``OAuthUserInfo`` from keyword arguments."""

    @pytest.mark.parametrize(
        "user_data",
        [
            {"id": "123", "name": "Test User", "email": "test@example.com"},
            {"id": "456", "name": "", "email": "user@domain.com"},
            {"id": "789", "name": "Another User", "email": "another@test.org"},
        ],
    )
    def test_should_create_user_info_dataclass(self, user_data):
        """Each keyword argument is readable back from the attribute of the same name."""
        user_info = OAuthUserInfo(**user_data)

        # Checks id, name and email in turn (dict insertion order).
        for field_name, expected in user_data.items():
            assert getattr(user_info, field_name) == expected
|