You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_oauth_clients.py 8.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import urllib.parse
  2. from unittest.mock import MagicMock, patch
  3. import httpx
  4. import pytest
  5. from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
  6. class BaseOAuthTest:
  7. """Base class for OAuth provider tests with common fixtures"""
  8. @pytest.fixture
  9. def oauth_config(self):
  10. return {
  11. "client_id": "test_client_id",
  12. "client_secret": "test_client_secret",
  13. "redirect_uri": "http://localhost/callback",
  14. }
  15. @pytest.fixture
  16. def mock_response(self):
  17. response = MagicMock()
  18. response.json.return_value = {}
  19. return response
  20. def parse_auth_url(self, url):
  21. """Helper to parse authorization URL"""
  22. parsed = urllib.parse.urlparse(url)
  23. params = urllib.parse.parse_qs(parsed.query)
  24. return parsed, params
  25. class TestGitHubOAuth(BaseOAuthTest):
  26. @pytest.fixture
  27. def oauth(self, oauth_config):
  28. return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
  29. @pytest.mark.parametrize(
  30. ("invite_token", "expected_state"),
  31. [
  32. (None, None),
  33. ("test_invite_token", "test_invite_token"),
  34. ("", None),
  35. ],
  36. )
  37. def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
  38. url = oauth.get_authorization_url(invite_token)
  39. parsed, params = self.parse_auth_url(url)
  40. assert parsed.scheme == "https"
  41. assert parsed.netloc == "github.com"
  42. assert parsed.path == "/login/oauth/authorize"
  43. assert params["client_id"][0] == oauth_config["client_id"]
  44. assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
  45. assert params["scope"][0] == "user:email"
  46. if expected_state:
  47. assert params["state"][0] == expected_state
  48. else:
  49. assert "state" not in params
  50. @pytest.mark.parametrize(
  51. ("response_data", "expected_token", "should_raise"),
  52. [
  53. ({"access_token": "test_token"}, "test_token", False),
  54. ({"error": "invalid_grant"}, None, True),
  55. ({}, None, True),
  56. ],
  57. )
  58. @patch("httpx.post")
  59. def test_should_retrieve_access_token(
  60. self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
  61. ):
  62. mock_response.json.return_value = response_data
  63. mock_post.return_value = mock_response
  64. if should_raise:
  65. with pytest.raises(ValueError) as exc_info:
  66. oauth.get_access_token("test_code")
  67. assert "Error in GitHub OAuth" in str(exc_info.value)
  68. else:
  69. token = oauth.get_access_token("test_code")
  70. assert token == expected_token
  71. @pytest.mark.parametrize(
  72. ("user_data", "email_data", "expected_email"),
  73. [
  74. # User with primary email
  75. (
  76. {"id": 12345, "login": "testuser", "name": "Test User"},
  77. [
  78. {"email": "secondary@example.com", "primary": False},
  79. {"email": "primary@example.com", "primary": True},
  80. ],
  81. "primary@example.com",
  82. ),
  83. # User with no emails - fallback to noreply
  84. ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
  85. # User with only secondary email - fallback to noreply
  86. (
  87. {"id": 12345, "login": "testuser", "name": "Test User"},
  88. [{"email": "secondary@example.com", "primary": False}],
  89. "12345+testuser@users.noreply.github.com",
  90. ),
  91. ],
  92. )
  93. @patch("httpx.get")
  94. def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
  95. user_response = MagicMock()
  96. user_response.json.return_value = user_data
  97. email_response = MagicMock()
  98. email_response.json.return_value = email_data
  99. mock_get.side_effect = [user_response, email_response]
  100. user_info = oauth.get_user_info("test_token")
  101. assert user_info.id == str(user_data["id"])
  102. assert user_info.name == user_data["name"]
  103. assert user_info.email == expected_email
  104. @patch("httpx.get")
  105. def test_should_handle_network_errors(self, mock_get, oauth):
  106. mock_get.side_effect = httpx.RequestError("Network error")
  107. with pytest.raises(httpx.RequestError):
  108. oauth.get_raw_user_info("test_token")
  109. class TestGoogleOAuth(BaseOAuthTest):
  110. @pytest.fixture
  111. def oauth(self, oauth_config):
  112. return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
  113. @pytest.mark.parametrize(
  114. ("invite_token", "expected_state"),
  115. [
  116. (None, None),
  117. ("test_invite_token", "test_invite_token"),
  118. ("", None),
  119. ],
  120. )
  121. def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
  122. url = oauth.get_authorization_url(invite_token)
  123. parsed, params = self.parse_auth_url(url)
  124. assert parsed.scheme == "https"
  125. assert parsed.netloc == "accounts.google.com"
  126. assert parsed.path == "/o/oauth2/v2/auth"
  127. assert params["client_id"][0] == oauth_config["client_id"]
  128. assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
  129. assert params["response_type"][0] == "code"
  130. assert params["scope"][0] == "openid email"
  131. if expected_state:
  132. assert params["state"][0] == expected_state
  133. else:
  134. assert "state" not in params
  135. @pytest.mark.parametrize(
  136. ("response_data", "expected_token", "should_raise"),
  137. [
  138. ({"access_token": "test_token"}, "test_token", False),
  139. ({"error": "invalid_grant"}, None, True),
  140. ({}, None, True),
  141. ],
  142. )
  143. @patch("httpx.post")
  144. def test_should_retrieve_access_token(
  145. self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
  146. ):
  147. mock_response.json.return_value = response_data
  148. mock_post.return_value = mock_response
  149. if should_raise:
  150. with pytest.raises(ValueError) as exc_info:
  151. oauth.get_access_token("test_code")
  152. assert "Error in Google OAuth" in str(exc_info.value)
  153. else:
  154. token = oauth.get_access_token("test_code")
  155. assert token == expected_token
  156. mock_post.assert_called_once_with(
  157. oauth._TOKEN_URL,
  158. data={
  159. "client_id": oauth_config["client_id"],
  160. "client_secret": oauth_config["client_secret"],
  161. "code": "test_code",
  162. "grant_type": "authorization_code",
  163. "redirect_uri": oauth_config["redirect_uri"],
  164. },
  165. headers={"Accept": "application/json"},
  166. )
  167. @pytest.mark.parametrize(
  168. ("user_data", "expected_name"),
  169. [
  170. ({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
  171. ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
  172. ],
  173. )
  174. @patch("httpx.get")
  175. def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
  176. mock_response.json.return_value = user_data
  177. mock_get.return_value = mock_response
  178. user_info = oauth.get_user_info("test_token")
  179. assert user_info.id == user_data["sub"]
  180. assert user_info.name == expected_name
  181. assert user_info.email == user_data["email"]
  182. mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
  183. @pytest.mark.parametrize(
  184. "exception_type",
  185. [
  186. httpx.HTTPError,
  187. httpx.ConnectError,
  188. httpx.TimeoutException,
  189. ],
  190. )
  191. @patch("httpx.get")
  192. def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
  193. mock_response = MagicMock()
  194. mock_response.raise_for_status.side_effect = exception_type("Error")
  195. mock_get.return_value = mock_response
  196. with pytest.raises(exception_type):
  197. oauth.get_raw_user_info("invalid_token")
  198. class TestOAuthUserInfo:
  199. @pytest.mark.parametrize(
  200. "user_data",
  201. [
  202. {"id": "123", "name": "Test User", "email": "test@example.com"},
  203. {"id": "456", "name": "", "email": "user@domain.com"},
  204. {"id": "789", "name": "Another User", "email": "another@test.org"},
  205. ],
  206. )
  207. def test_should_create_user_info_dataclass(self, user_data):
  208. user_info = OAuthUserInfo(**user_data)
  209. assert user_info.id == user_data["id"]
  210. assert user_info.name == user_data["name"]
  211. assert user_info.email == user_data["email"]