| @@ -216,6 +216,7 @@ UNSTRUCTURED_API_KEY= | |||
| SSRF_PROXY_HTTP_URL= | |||
| SSRF_PROXY_HTTPS_URL= | |||
| SSRF_DEFAULT_MAX_RETRIES=3 | |||
| BATCH_UPLOAD_LIMIT=10 | |||
| KEYWORD_DATA_SOURCE_TYPE=database | |||
| @@ -1,13 +1,16 @@ | |||
| """ | |||
| Proxy requests to avoid SSRF | |||
| """ | |||
| import logging | |||
| import os | |||
| import time | |||
| import httpx | |||
| SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') | |||
| SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') | |||
| SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') | |||
| SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) | |||
| proxies = { | |||
| 'http://': SSRF_PROXY_HTTP_URL, | |||
| @@ -15,34 +18,55 @@ proxies = { | |||
| } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None | |||
| def make_request(method, url, **kwargs): | |||
| if SSRF_PROXY_ALL_URL: | |||
| return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) | |||
| elif proxies: | |||
| return httpx.request(method=method, url=url, proxies=proxies, **kwargs) | |||
| else: | |||
| return httpx.request(method=method, url=url, **kwargs) | |||
| BACKOFF_FACTOR = 0.5 | |||
| STATUS_FORCELIST = [429, 500, 502, 503, 504] | |||
| def get(url, **kwargs): | |||
| return make_request('GET', url, **kwargs) | |||
| def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| retries = 0 | |||
| while retries <= max_retries: | |||
| try: | |||
| if SSRF_PROXY_ALL_URL: | |||
| response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) | |||
| elif proxies: | |||
| response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) | |||
| else: | |||
| response = httpx.request(method=method, url=url, **kwargs) | |||
| if response.status_code not in STATUS_FORCELIST: | |||
| return response | |||
| else: | |||
| logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") | |||
| def post(url, **kwargs): | |||
| return make_request('POST', url, **kwargs) | |||
| except httpx.RequestError as e: | |||
| logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") | |||
| retries += 1 | |||
| if retries <= max_retries: | |||
| time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) | |||
| def put(url, **kwargs): | |||
| return make_request('PUT', url, **kwargs) | |||
| raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") | |||
| def patch(url, **kwargs): | |||
| return make_request('PATCH', url, **kwargs) | |||
| def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('GET', url, max_retries=max_retries, **kwargs) | |||
| def delete(url, **kwargs): | |||
| return make_request('DELETE', url, **kwargs) | |||
| def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('POST', url, max_retries=max_retries, **kwargs) | |||
| def head(url, **kwargs): | |||
| return make_request('HEAD', url, **kwargs) | |||
| def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('PUT', url, max_retries=max_retries, **kwargs) | |||
| def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('PATCH', url, max_retries=max_retries, **kwargs) | |||
| def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('DELETE', url, max_retries=max_retries, **kwargs) | |||
| def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| return make_request('HEAD', url, max_retries=max_retries, **kwargs) | |||
| @@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool): | |||
| if tool_parameters.get('no_cache', False): | |||
| headers['X-No-Cache'] = 'true' | |||
| max_retries = tool_parameters.get('max_retries', 3) | |||
| response = ssrf_proxy.get( | |||
| str(URL(self._jina_reader_endpoint + url)), | |||
| headers=headers, | |||
| params=request_params, | |||
| timeout=(10, 60), | |||
| max_retries=max_retries | |||
| ) | |||
| if tool_parameters.get('summary', False): | |||
| @@ -150,3 +150,17 @@ parameters: | |||
| pt_BR: Habilitar resumo para a saída | |||
| llm_description: enable summary | |||
| form: form | |||
| - name: max_retries | |||
| type: number | |||
| required: false | |||
| default: 3 | |||
| label: | |||
| en_US: Retry | |||
| zh_Hans: 重试 | |||
| pt_BR: Repetir | |||
| human_description: | |||
| en_US: Number of times to retry the request if it fails | |||
| zh_Hans: 请求失败时重试的次数 | |||
| pt_BR: Número de vezes para repetir a solicitação se falhar | |||
| llm_description: Number of times to retry the request if it fails | |||
| form: form | |||
| @@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool): | |||
| if tool_parameters.get('no_cache', False): | |||
| headers['X-No-Cache'] = 'true' | |||
| max_retries = tool_parameters.get('max_retries', 3) | |||
| response = ssrf_proxy.get( | |||
| str(URL(self._jina_search_endpoint + query)), | |||
| headers=headers, | |||
| timeout=(10, 60) | |||
| timeout=(10, 60), | |||
| max_retries=max_retries | |||
| ) | |||
| return self.create_text_message(response.text) | |||
| @@ -91,3 +91,17 @@ parameters: | |||
| pt_BR: Ignorar o cache | |||
| llm_description: bypass the cache | |||
| form: form | |||
| - name: max_retries | |||
| type: number | |||
| required: false | |||
| default: 3 | |||
| label: | |||
| en_US: Retry | |||
| zh_Hans: 重试 | |||
| pt_BR: Repetir | |||
| human_description: | |||
| en_US: Number of times to retry the request if it fails | |||
| zh_Hans: 请求失败时重试的次数 | |||
| pt_BR: Número de vezes para repetir a solicitação se falhar | |||
| llm_description: Number of times to retry the request if it fails | |||
| form: form | |||
| @@ -0,0 +1,52 @@ | |||
| import random | |||
| from unittest.mock import MagicMock, patch | |||
| from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request | |||
| @patch('httpx.request') | |||
| def test_successful_request(mock_request): | |||
| mock_response = MagicMock() | |||
| mock_response.status_code = 200 | |||
| mock_request.return_value = mock_response | |||
| response = make_request('GET', 'http://example.com') | |||
| assert response.status_code == 200 | |||
| @patch('httpx.request') | |||
| def test_retry_exceed_max_retries(mock_request): | |||
| mock_response = MagicMock() | |||
| mock_response.status_code = 500 | |||
| side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES | |||
| mock_request.side_effect = side_effects | |||
| try: | |||
| make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) | |||
| raise AssertionError("Expected Exception not raised") | |||
| except Exception as e: | |||
| assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" | |||
| @patch('httpx.request') | |||
| def test_retry_logic_success(mock_request): | |||
| side_effects = [] | |||
| for _ in range(SSRF_DEFAULT_MAX_RETRIES): | |||
| status_code = random.choice(STATUS_FORCELIST) | |||
| mock_response = MagicMock() | |||
| mock_response.status_code = status_code | |||
| side_effects.append(mock_response) | |||
| mock_response_200 = MagicMock() | |||
| mock_response_200.status_code = 200 | |||
| side_effects.append(mock_response_200) | |||
| mock_request.side_effect = side_effects | |||
| response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) | |||
| assert response.status_code == 200 | |||
| assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 | |||
| assert mock_request.call_args_list[0][1].get('method') == 'GET' | |||