| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 | import json
import logging
from unittest import mock
import pytest
from flask import Flask, Response
from configs import dify_config
from extensions import ext_request_logging
from extensions.ext_request_logging import _is_content_type_json, _log_request_finished, init_app
def test_is_content_type_json():
    """
    Test the _is_content_type_json function.
    """
    assert _is_content_type_json("application/json") is True
    # content type header with charset option.
    assert _is_content_type_json("application/json; charset=utf-8") is True
    # content type header with charset option, in uppercase.
    assert _is_content_type_json("APPLICATION/JSON; CHARSET=UTF-8") is True
    assert _is_content_type_json("text/html") is False
    assert _is_content_type_json("") is False
_KEY_NEEDLE = "needle"
_VALUE_NEEDLE = _KEY_NEEDLE[::-1]
_RESPONSE_NEEDLE = "response"
def _get_test_app():
    app = Flask(__name__)
    @app.route("/", methods=["GET", "POST"])
    def handler():
        return _RESPONSE_NEEDLE
    return app
# NOTE(QuantumGhost): Due to the design of Flask, we need to use monkey patch to write tests.
@pytest.fixture
def mock_request_receiver(monkeypatch) -> mock.Mock:
    mock_log_request_started = mock.Mock()
    monkeypatch.setattr(ext_request_logging, "_log_request_started", mock_log_request_started)
    return mock_log_request_started
@pytest.fixture
def mock_response_receiver(monkeypatch) -> mock.Mock:
    mock_log_request_finished = mock.Mock()
    monkeypatch.setattr(ext_request_logging, "_log_request_finished", mock_log_request_finished)
    return mock_log_request_finished
@pytest.fixture
def mock_logger(monkeypatch) -> logging.Logger:
    _logger = mock.MagicMock(spec=logging.Logger)
    monkeypatch.setattr(ext_request_logging, "_logger", _logger)
    return _logger
@pytest.fixture
def enable_request_logging(monkeypatch):
    monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True)
class TestRequestLoggingExtension:
    def test_receiver_should_not_be_invoked_if_configuration_is_disabled(
        self,
        monkeypatch,
        mock_request_receiver,
        mock_response_receiver,
    ):
        monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", False)
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.get("/")
        mock_request_receiver.assert_not_called()
        mock_response_receiver.assert_not_called()
    def test_receiver_should_be_called_if_enabled(
        self,
        enable_request_logging,
        mock_request_receiver,
        mock_response_receiver,
    ):
        """
        Test the request logging extension with JSON data.
        """
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
        mock_request_receiver.assert_called_once()
        mock_response_receiver.assert_called_once()
class TestLoggingLevel:
    @pytest.mark.usefixtures("enable_request_logging")
    def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger):
        mock_logger.isEnabledFor.return_value = False
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
        mock_logger.debug.assert_not_called()
class TestRequestReceiverLogging:
    @pytest.mark.usefixtures("enable_request_logging")
    def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post("/", data="plain text")
        assert mock_logger.debug.call_count == 1
        call_args = mock_logger.debug.call_args[0]
        assert "Received Request" in call_args[0]
        assert call_args[1] == "POST"
        assert call_args[2] == "/"
        assert "Request Body" not in call_args[0]
    @pytest.mark.usefixtures("enable_request_logging")
    def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
        assert mock_logger.debug.call_count == 1
        call_args = mock_logger.debug.call_args[0]
        assert "Received Request" in call_args[0]
        assert "Request Body" in call_args[0]
        assert call_args[1] == "POST"
        assert call_args[2] == "/"
        assert _KEY_NEEDLE in call_args[3]
    @pytest.mark.usefixtures("enable_request_logging")
    def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post("/", headers={"Content-Type": "application/json"})
        assert mock_logger.debug.call_count == 1
        call_args = mock_logger.debug.call_args[0]
        assert "Received Request" in call_args[0]
        assert "Request Body" not in call_args[0]
        assert call_args[1] == "POST"
        assert call_args[2] == "/"
    @pytest.mark.usefixtures("enable_request_logging")
    def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            client.post(
                "/",
                headers={"Content-Type": "application/json"},
                data="{",
            )
        assert mock_logger.debug.call_count == 0
        assert mock_logger.exception.call_count == 1
        exception_call_args = mock_logger.exception.call_args[0]
        assert exception_call_args[0] == "Failed to parse JSON request"
class TestResponseReceiverLogging:
    @pytest.mark.usefixtures("enable_request_logging")
    def test_non_json_response(self, enable_request_logging, mock_logger):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        response = Response(
            "OK",
            headers={"Content-Type": "text/plain"},
        )
        _log_request_finished(app, response)
        assert mock_logger.debug.call_count == 1
        call_args = mock_logger.debug.call_args[0]
        assert "Response" in call_args[0]
        assert "200" in call_args[1]
        assert call_args[2] == "text/plain"
        assert "Response Body" not in call_args[0]
    @pytest.mark.usefixtures("enable_request_logging")
    def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        response = Response(
            json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}),
            headers={"Content-Type": "application/json"},
        )
        _log_request_finished(app, response)
        assert mock_logger.debug.call_count == 1
        call_args = mock_logger.debug.call_args[0]
        assert "Response" in call_args[0]
        assert "Response Body" in call_args[0]
        assert "200" in call_args[1]
        assert call_args[2] == "application/json"
        assert _KEY_NEEDLE in call_args[3]
    @pytest.mark.usefixtures("enable_request_logging")
    def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
        mock_logger.isEnabledFor.return_value = True
        app = _get_test_app()
        response = Response(
            "{",
            headers={"Content-Type": "application/json"},
        )
        _log_request_finished(app, response)
        assert mock_logger.debug.call_count == 0
        assert mock_logger.exception.call_count == 1
        exception_call_args = mock_logger.exception.call_args[0]
        assert exception_call_args[0] == "Failed to parse JSON response"
class TestResponseUnmodified:
    def test_when_request_logging_disabled(self):
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            response = client.post(
                "/",
                headers={"Content-Type": "application/json"},
                data="{",
            )
        assert response.text == _RESPONSE_NEEDLE
        assert response.status_code == 200
    @pytest.mark.usefixtures("enable_request_logging")
    def test_when_request_logging_enabled(self, enable_request_logging):
        app = _get_test_app()
        init_app(app)
        with app.test_client() as client:
            response = client.post(
                "/",
                headers={"Content-Type": "application/json"},
                data="{",
            )
        assert response.text == _RESPONSE_NEEDLE
        assert response.status_code == 200
 |