Browse Source

more httpx (#25651)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
tags/1.9.0
Asuka Minato 1 month ago
parent
commit
8940decd1b
No account linked to committer's email address

+ 4
- 4
api/configs/remote_settings_sources/nacos/http_request.py View File

import os import os
import time import time


import requests
import httpx


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


params = {} params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.RequestException as e:
except httpx.RequestError as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"


def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
params = {"username": self.username, "password": self.password} params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login" url = "http://" + self.server + "/nacos/v1/auth/login"
try: try:
resp = requests.request("POST", url, headers=None, params=params)
resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status() resp.raise_for_status()
response_data = resp.json() response_data = resp.json()
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")

+ 3
- 3
api/controllers/console/auth/data_source_oauth.py View File

import logging import logging


import requests
import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e:
except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e:
except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

+ 5
- 3
api/controllers/console/auth/oauth.py View File

import logging import logging


import requests
import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
except httpx.RequestError as e:
error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400



+ 6
- 2
api/controllers/console/version.py View File

import json import json
import logging import logging


import requests
import httpx
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from packaging import version from packaging import version


return result return result


try: try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args["current_version"]

+ 3
- 3
api/core/ops/aliyun_trace/data_exporter/traceclient.py View File

from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime


import requests
import httpx
from opentelemetry import trace as trace_api from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource


def api_check(self): def api_check(self):
try: try:
response = requests.head(self.endpoint, timeout=5)
response = httpx.head(self.endpoint, timeout=5)
if response.status_code == 405: if response.status_code == 405:
return True return True
else: else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False return False
except requests.RequestException as e:
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e)) logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}") raise ValueError(f"AliyunTrace API check failed: {str(e)}")



+ 6
- 6
api/libs/oauth.py View File

import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass


import requests
import httpx




@dataclass @dataclass
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)


response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")


def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"} headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
user_info = response.json() user_info = response.json()


email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json() email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) primary_email: dict = next((email for email in email_info if email["primary"] == True), {})


"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)


response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")


def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()



+ 6
- 6
api/libs/oauth_data_source.py View File

import urllib.parse import urllib.parse
from typing import Any from typing import Any


import requests
import httpx
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select from sqlalchemy import select


data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret) auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)


response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }


response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()


results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
message = response_json.get("message", "unknown error") message = response_json.get("message", "unknown error")
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json() response_json = response.json()
if "object" in response_json and response_json["object"] == "user": if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"] user_type = response_json["type"]
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()


results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))

+ 2
- 2
api/services/auth/firecrawl/firecrawl.py View File

import json import json


import requests
import httpx


from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase


return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}


def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)


def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

+ 2
- 2
api/services/auth/jina.py View File

import json import json


import requests
import httpx


from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase


return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}


def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)


def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

+ 2
- 2
api/services/auth/jina/jina.py View File

import json import json


import requests
import httpx


from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase


return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}


def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)


def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

+ 2
- 2
api/services/auth/watercrawl/watercrawl.py View File

import json import json
from urllib.parse import urljoin from urllib.parse import urljoin


import requests
import httpx


from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase


return {"Content-Type": "application/json", "X-API-KEY": self.api_key} return {"Content-Type": "application/json", "X-API-KEY": self.api_key}


def _get_request(self, url, headers): def _get_request(self, url, headers):
return requests.get(url, headers=headers)
return httpx.get(url, headers=headers)


def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

+ 2
- 2
api/services/operation_service.py View File

import os import os


import requests
import httpx




class OperationService: class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}


url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
response = httpx.request(method, url, json=json, params=params, headers=headers)


return response.json() return response.json()



+ 8
- 8
api/services/website_service.py View File

from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any


import requests
import httpx
from flask_login import current_user from flask_login import current_user


from core.helper import encrypter from core.helper import encrypter
@classmethod @classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages: if not request.options.crawl_sub_pages:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{request.url}", f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
raise ValueError("Failed to crawl:") raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")} return {"status": "active", "data": response.json().get("data")}
else: else:
response = requests.post(
response = httpx.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={ json={
"url": request.url, "url": request.url,


@classmethod @classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
} }


if crawl_status_data["status"] == "completed": if crawl_status_data["status"] == "completed":
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@classmethod @classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id: if not job_id:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{url}", f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
return dict(response.json().get("data", {})) return dict(response.json().get("data", {}))
else: else:
# Get crawl status first # Get crawl status first
status_response = requests.post(
status_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
raise ValueError("Crawl job is not completed") raise ValueError("Crawl job is not completed")


# Get processed data # Get processed data
data_response = requests.post(
data_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},

+ 6
- 9
api/tests/integration_tests/plugin/__mock/http.py View File

import os import os
from typing import Literal from typing import Literal


import httpx
import pytest import pytest
import requests


from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@classmethod @classmethod
def requests_request( def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> requests.Response:
) -> httpx.Response:
""" """
Mocked requests.request
Mocked httpx.request
""" """
request = requests.PreparedRequest()
request.method = method
request.url = url
request = httpx.Request(method, url)
if url.endswith("/tools"): if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]]( content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools() code=0, message="success", data=cls.list_tools()
else: else:
raise ValueError("") raise ValueError("")


response = requests.Response()
response.status_code = 200
response = httpx.Response(status_code=200)
response.request = request response.request = request
response._content = content.encode("utf-8") response._content = content.encode("utf-8")
return response return response
@pytest.fixture @pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch): def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH: if MOCK_SWITCH:
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)


def unpatch(): def unpatch():
monkeypatch.undo() monkeypatch.undo()

+ 2
- 2
api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py View File

import os import os
import time import time


import requests
import httpx
from clickzetta import connect from clickzetta import connect




max_retries = 30 max_retries = 30
for i in range(max_retries): for i in range(max_retries):
try: try:
response = requests.get(f"{base_url}/console/api/health")
response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200: if response.status_code == 200:
print("✓ Dify API is ready") print("✓ Dify API is ready")
break break

+ 2
- 2
api/tests/unit_tests/controllers/console/auth/test_oauth.py View File

mock_db.session.rollback = MagicMock() mock_db.session.rollback = MagicMock()


# Import the real requests module to create a proper exception # Import the real requests module to create a proper exception
import requests
import httpx


request_exception = requests.exceptions.RequestException("OAuth error")
request_exception = httpx.RequestError("OAuth error")
request_exception.response = MagicMock() request_exception.response = MagicMock()
request_exception.response.text = str(exception) request_exception.response.text = str(exception)



+ 12
- 12
api/tests/unit_tests/libs/test_oauth_clients.py View File

import urllib.parse import urllib.parse
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch


import httpx
import pytest import pytest
import requests


from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo


({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post")
@patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
): ):
), ),
], ],
) )
@patch("requests.get")
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock() user_response = MagicMock()
user_response.json.return_value = user_data user_response.json.return_value = user_data
assert user_info.name == user_data["name"] assert user_info.name == user_data["name"]
assert user_info.email == expected_email assert user_info.email == expected_email


@patch("requests.get")
@patch("httpx.get")
def test_should_handle_network_errors(self, mock_get, oauth): def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = requests.exceptions.RequestException("Network error")
mock_get.side_effect = httpx.RequestError("Network error")


with pytest.raises(requests.exceptions.RequestException):
with pytest.raises(httpx.RequestError):
oauth.get_raw_user_info("test_token") oauth.get_raw_user_info("test_token")




({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post")
@patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
): ):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
], ],
) )
@patch("requests.get")
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data mock_response.json.return_value = user_data
mock_get.return_value = mock_response mock_get.return_value = mock_response
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exception_type", "exception_type",
[ [
requests.exceptions.HTTPError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
httpx.HTTPError,
httpx.ConnectError,
httpx.TimeoutException,
], ],
) )
@patch("requests.get")
@patch("httpx.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type): def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error") mock_response.raise_for_status.side_effect = exception_type("Error")

+ 10
- 10
api/tests/unit_tests/services/auth/test_auth_integration.py View File

from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch from unittest.mock import Mock, patch


import httpx
import pytest import pytest
import requests


from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}


@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage""" """Test complete authentication flow: request → validation → encryption → storage"""
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()


@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http): def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration""" """Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response() mock_http.return_value = self._create_success_response()
assert "another_secret" not in factory_str assert "another_secret" not in factory_str


@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety""" """Test concurrent authentication creation safety"""
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)


@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http): def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling""" """Test proper HTTP error handling"""
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 401 mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}' mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response mock_http.return_value = mock_response


# PT012: Split into single statement for pytest.raises # PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((requests.exceptions.HTTPError, Exception)):
with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials() factory.validate_credentials()


@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(self, mock_http, mock_session): def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures""" """Test system recovery from network failures"""
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
mock_http.side_effect = httpx.RequestError("Network timeout")
mock_session.add = Mock() mock_session.add = Mock()
mock_session.commit = Mock() mock_session.commit = Mock()


args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}


with pytest.raises(requests.exceptions.RequestException):
with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)


mock_session.commit.assert_not_called() mock_session.commit.assert_not_called()

+ 13
- 13
api/tests/unit_tests/services/auth/test_firecrawl_auth.py View File

from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch


import httpx
import pytest import pytest
import requests


from services.auth.firecrawl.firecrawl import FirecrawlAuth from services.auth.firecrawl.firecrawl import FirecrawlAuth


FirecrawlAuth(credentials) FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error


@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message) mock_post.side_effect = exception_type(exception_message)
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)


@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post): def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
assert result is True assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"


@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")


with pytest.raises(requests.Timeout) as exc_info:
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()


# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message

+ 10
- 10
api/tests/unit_tests/services/auth/test_jina_auth.py View File

from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch


import httpx
import pytest import pytest
import requests


from services.auth.jina.jina import JinaAuth from services.auth.jina.jina import JinaAuth


JinaAuth(credentials) JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided" assert str(exc_info.value) == "No API key provided"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post): def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
json={"url": "https://example.com"}, json={"url": "https://example.com"},
) )


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_402_error(self, mock_post): def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error""" """Test handling of 402 Payment Required error"""
mock_response = MagicMock() mock_response = MagicMock()
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_409_error(self, mock_post): def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error""" """Test handling of 409 Conflict error"""
mock_response = MagicMock() mock_response = MagicMock()
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_500_error(self, mock_post): def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error""" """Test handling of 500 Internal Server Error"""
mock_response = MagicMock() mock_response = MagicMock()
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post): def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response""" """Test handling of unexpected errors with text response"""
mock_response = MagicMock() mock_response = MagicMock()
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_without_text(self, mock_post): def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response""" """Test handling of unexpected errors without text response"""
mock_response = MagicMock() mock_response = MagicMock()
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"


@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_network_errors(self, mock_post): def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors""" """Test handling of network connection errors"""
mock_post.side_effect = requests.ConnectionError("Network error")
mock_post.side_effect = httpx.ConnectError("Network error")


credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials) auth = JinaAuth(credentials)


with pytest.raises(requests.ConnectionError):
with pytest.raises(httpx.ConnectError):
auth.validate_credentials() auth.validate_credentials()


def test_should_not_expose_api_key_in_error_messages(self): def test_should_not_expose_api_key_in_error_messages(self):

+ 14
- 14
api/tests/unit_tests/services/auth/test_watercrawl_auth.py View File

from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch


import httpx
import pytest import pytest
import requests


from services.auth.watercrawl.watercrawl import WatercrawlAuth from services.auth.watercrawl.watercrawl import WatercrawlAuth


WatercrawlAuth(credentials) WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error


@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message) mock_get.side_effect = exception_type(exception_message)
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)


@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_custom_base_url_in_validation(self, mock_get): def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs""" """Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock() mock_response = MagicMock()
# Verify the correct URL was called # Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url assert mock_get.call_args[0][0] == expected_url


@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")


with pytest.raises(requests.Timeout) as exc_info:
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()


# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message

Loading…
Cancel
Save