瀏覽代碼

chore: refactor the http executor node (#5212)

tags/0.6.12
非法操作 1 年之前
父節點
當前提交
f7900f298f
沒有連結到貢獻者的電子郵件帳戶。

+ 35
- 52
api/core/helper/ssrf_proxy.py 查看文件

""" """
Proxy requests to avoid SSRF Proxy requests to avoid SSRF
""" """

import os import os


from httpx import get as _get
from httpx import head as _head
from httpx import options as _options
from httpx import patch as _patch
from httpx import post as _post
from httpx import put as _put
from requests import delete as _delete
import httpx


SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')


requests_proxies = {
'http': SSRF_PROXY_HTTP_URL,
'https': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None

httpx_proxies = {
proxies = {
'http://': SSRF_PROXY_HTTP_URL, 'http://': SSRF_PROXY_HTTP_URL,
'https://': SSRF_PROXY_HTTPS_URL 'https://': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None


def get(url, *args, **kwargs):
return _get(url=url, *args, proxies=httpx_proxies, **kwargs)

def post(url, *args, **kwargs):
return _post(url=url, *args, proxies=httpx_proxies, **kwargs)

def put(url, *args, **kwargs):
return _put(url=url, *args, proxies=httpx_proxies, **kwargs)

def patch(url, *args, **kwargs):
return _patch(url=url, *args, proxies=httpx_proxies, **kwargs)

def delete(url, *args, **kwargs):
if 'follow_redirects' in kwargs:
if kwargs['follow_redirects']:
kwargs['allow_redirects'] = kwargs['follow_redirects']
kwargs.pop('follow_redirects')
if 'timeout' in kwargs:
timeout = kwargs['timeout']
if timeout is None:
kwargs.pop('timeout')
elif isinstance(timeout, tuple):
# check length of tuple
if len(timeout) == 2:
kwargs['timeout'] = timeout
elif len(timeout) == 1:
kwargs['timeout'] = timeout[0]
elif len(timeout) > 2:
kwargs['timeout'] = (timeout[0], timeout[1])
else:
kwargs['timeout'] = (timeout, timeout)
return _delete(url=url, *args, proxies=requests_proxies, **kwargs)

def head(url, *args, **kwargs):
return _head(url=url, *args, proxies=httpx_proxies, **kwargs)

def options(url, *args, **kwargs):
return _options(url=url, *args, proxies=httpx_proxies, **kwargs)

def make_request(method, url, **kwargs):
if SSRF_PROXY_ALL_URL:
return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
elif proxies:
return httpx.request(method=method, url=url, proxies=proxies, **kwargs)
else:
return httpx.request(method=method, url=url, **kwargs)


def get(url, **kwargs):
return make_request('GET', url, **kwargs)


def post(url, **kwargs):
return make_request('POST', url, **kwargs)


def put(url, **kwargs):
return make_request('PUT', url, **kwargs)


def patch(url, **kwargs):
return make_request('PATCH', url, **kwargs)


def delete(url, **kwargs):
return make_request('DELETE', url, **kwargs)


def head(url, **kwargs):
return make_request('HEAD', url, **kwargs)

+ 37
- 79
api/core/tools/tool/api_tool.py 查看文件

import json import json
from json import dumps
from os import getenv from os import getenv
from typing import Any, Union
from typing import Any
from urllib.parse import urlencode from urllib.parse import urlencode


import httpx import httpx
import requests


import core.helper.ssrf_proxy as ssrf_proxy import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
) )



class ApiTool(Tool): class ApiTool(Tool):
api_bundle: ApiToolBundle api_bundle: ApiToolBundle
""" """
Api tool Api tool
""" """

def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
""" """
fork a new tool with meta data fork a new tool with meta data
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
runtime=Tool.Runtime(**runtime) runtime=Tool.Runtime(**runtime)
) )
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:

def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any],
format_only: bool = False) -> str:
""" """
validate the credentials for Api tool validate the credentials for Api tool
""" """
headers = self.assembling_request(parameters) headers = self.assembling_request(parameters)


if format_only: if format_only:
return
return ''


response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response # validate response


if 'api_key_header' in credentials: if 'api_key_header' in credentials:
api_key_header = credentials['api_key_header'] api_key_header = credentials['api_key_header']
if 'api_key_value' not in credentials: if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value') raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str): elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string') raise ToolProviderCredentialValidationError('api_key_value must be a string')
if 'api_key_header_prefix' in credentials: if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix'] api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic' and credentials['api_key_value']: if api_key_header_prefix == 'basic' and credentials['api_key_value']:
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom': elif api_key_header_prefix == 'custom':
pass pass
headers[api_key_header] = credentials['api_key_value'] headers[api_key_header] = credentials['api_key_value']


needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
for parameter in needed_parameters: for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters: if parameter.required and parameter.name not in parameters:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
if parameter.default is not None and parameter.name not in parameters: if parameter.default is not None and parameter.name not in parameters:
parameters[parameter.name] = parameter.default parameters[parameter.name] = parameter.default


return headers return headers


def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
def validate_and_parse_response(self, response: httpx.Response) -> str:
""" """
validate the response validate the response
""" """
return json.dumps(response) return json.dumps(response)
except Exception as e: except Exception as e:
return response.text return response.text
elif isinstance(response, requests.Response):
if not response.ok:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return 'Empty response from the tool, please check your parameters and try again.'
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps(response)
except Exception as e:
return response.text
else: else:
raise ValueError(f'Invalid response type {type(response)}') raise ValueError(f'Invalid response type {type(response)}')
def do_http_request(self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]) -> httpx.Response:

@staticmethod
def get_parameter_value(parameter, parameters):
if parameter['name'] in parameters:
return parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
return (parameter.get('schema', {}) or {}).get('default', '')

def do_http_request(self, url: str, method: str, headers: dict[str, Any],
parameters: dict[str, Any]) -> httpx.Response:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """


# check parameters # check parameters
for parameter in self.api_bundle.openapi.get('parameters', []): for parameter in self.api_bundle.openapi.get('parameters', []):
value = self.get_parameter_value(parameter, parameters)
if parameter['in'] == 'path': if parameter['in'] == 'path':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
path_params[parameter['name']] = value path_params[parameter['name']] = value


elif parameter['in'] == 'query': elif parameter['in'] == 'query':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
params[parameter['name']] = value params[parameter['name']] = value


elif parameter['in'] == 'cookie': elif parameter['in'] == 'cookie':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
cookies[parameter['name']] = value cookies[parameter['name']] = value


elif parameter['in'] == 'header': elif parameter['in'] == 'header':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
headers[parameter['name']] = value headers[parameter['name']] = value


# check if there is a request body and handle it # check if there is a request body and handle it
else: else:
body[name] = None body[name] = None
break break
# replace path parameters # replace path parameters
for name, value in path_params.items(): for name, value in path_params.items():
url = url.replace(f'{{{name}}}', f'{value}') url = url.replace(f'{{{name}}}', f'{value}')
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers: if 'Content-Type' in headers:
if headers['Content-Type'] == 'application/json': if headers['Content-Type'] == 'application/json':
body = dumps(body)
body = json.dumps(body)
elif headers['Content-Type'] == 'application/x-www-form-urlencoded': elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
body = urlencode(body) body = urlencode(body)
else: else:
body = body body = body
# do http request
if method == 'get':
response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'post':
response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'put':
response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'delete':
response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True)
elif method == 'patch':
response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'head':
response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'options':
response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)

if method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body,
timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
return response
else: else:
raise ValueError(f'Invalid http method {method}')
return response
def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10) -> Any:
raise ValueError(f'Invalid http method {self.method}')

def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]],
max_recursive=10) -> Any:
if max_recursive <= 0: if max_recursive <= 0:
raise Exception("Max recursion depth reached") raise Exception("Max recursion depth reached")
for option in any_of or []: for option in any_of or []:


# assemble invoke message # assemble invoke message
return self.create_text_message(response) return self.create_text_message(response)

+ 18
- 49
api/core/workflow/nodes/http_request/http_executor.py 查看文件

from urllib.parse import urlencode from urllib.parse import urlencode


import httpx import httpx
import requests


import core.helper.ssrf_proxy as ssrf_proxy import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector


class HttpExecutorResponse: class HttpExecutorResponse:
headers: dict[str, str] headers: dict[str, str]
response: Union[httpx.Response, requests.Response]
response: httpx.Response


def __init__(self, response: Union[httpx.Response, requests.Response] = None):
self.headers = {}
if isinstance(response, httpx.Response | requests.Response):
for k, v in response.headers.items():
self.headers[k] = v
def __init__(self, response: httpx.Response = None):
self.response = response self.response = response
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}


@property @property
def is_file(self) -> bool: def is_file(self) -> bool:
return any(v in content_type for v in file_content_types) return any(v in content_type for v in file_content_types)


def get_content_type(self) -> str: def get_content_type(self) -> str:
if 'content-type' in self.headers:
return self.headers.get('content-type')
else:
return self.headers.get('Content-Type') or ""
return self.headers.get('content-type', '')



def extract_file(self) -> tuple[str, bytes]: def extract_file(self) -> tuple[str, bytes]:
""" """


@property @property
def content(self) -> str: def content(self) -> str:
"""
get content
"""
if isinstance(self.response, httpx.Response | requests.Response):
if isinstance(self.response, httpx.Response):
return self.response.text return self.response.text
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')


@property @property
def body(self) -> bytes: def body(self) -> bytes:
"""
get body
"""
if isinstance(self.response, httpx.Response | requests.Response):
if isinstance(self.response, httpx.Response):
return self.response.content return self.response.content
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')


@property @property
def status_code(self) -> int: def status_code(self) -> int:
"""
get status code
"""
if isinstance(self.response, httpx.Response | requests.Response):
if isinstance(self.response, httpx.Response):
return self.response.status_code return self.response.status_code
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')


@property @property
def size(self) -> int: def size(self) -> int:
"""
get size
"""
return len(self.body) return len(self.body)


@property @property
def readable_size(self) -> str: def readable_size(self) -> str:
"""
get readable size
"""
if self.size < 1024: if self.size < 1024:
return f'{self.size} bytes' return f'{self.size} bytes'
elif self.size < 1024 * 1024: elif self.size < 1024 * 1024:
return False return False


@staticmethod @staticmethod
def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1):
def _to_dict(convert_text: str):
""" """
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
:param convert_item: A label for what item to be converted, params, headers or body.
:param convert_text: The string containing key-value pairs separated by '\n'.
:param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit).
:return: A dictionary containing the key-value pairs from the input string.
""" """
kv_paris = convert_text.split('\n') kv_paris = convert_text.split('\n')
result = {} result = {}
if not kv.strip(): if not kv.strip():
continue continue


kv = kv.split(':', maxsplit=maxsplit)
if len(kv) >= 3:
k, v = kv[0], ":".join(kv[1:])
elif len(kv) == 2:
k, v = kv
elif len(kv) == 1:
kv = kv.split(':', maxsplit=1)
if len(kv) == 1:
k, v = kv[0], '' k, v = kv[0], ''
else: else:
raise ValueError(f'Invalid {convert_item} {kv}')
k, v = kv
result[k.strip()] = v result[k.strip()] = v
return result return result




# extract all template in params # extract all template in params
params, params_variable_selectors = self._format_template(node_data.params, variable_pool) params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
self.params = self._to_dict("params", params)
self.params = self._to_dict(params)


# extract all template in headers # extract all template in headers
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
self.headers = self._to_dict("headers", headers)
self.headers = self._to_dict(headers)


# extract all template in body # extract all template in body
body_data_variable_selectors = [] body_data_variable_selectors = []
self.headers['Content-Type'] = 'application/x-www-form-urlencoded' self.headers['Content-Type'] = 'application/x-www-form-urlencoded'


if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
body = self._to_dict("body", body_data, 1)
body = self._to_dict(body_data)


if node_data.body.type == 'form-data': if node_data.body.type == 'form-data':
self.files = { self.files = {


return headers return headers


def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
""" """
validate the response validate the response
""" """
if isinstance(response, httpx.Response | requests.Response):
if isinstance(response, httpx.Response):
executor_response = HttpExecutorResponse(response) executor_response = HttpExecutorResponse(response)
else: else:
raise ValueError(f'Invalid response type {type(response)}') raise ValueError(f'Invalid response type {type(response)}')
'follow_redirects': True 'follow_redirects': True
} }


if self.method in ('get', 'head', 'options'):
response = getattr(ssrf_proxy, self.method)(**kwargs)
elif self.method in ('post', 'put', 'delete', 'patch'):
if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else: else:
raise ValueError(f'Invalid http method {self.method}') raise ValueError(f'Invalid http method {self.method}')

+ 36
- 0
api/tests/integration_tests/tools/__mock/http.py 查看文件

import json
from typing import Literal

import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch


class MockedHttp:
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
url: str, **kwargs) -> httpx.Response:
"""
Mocked httpx.request
"""
request = httpx.Request(
method,
url,
params=kwargs.get('params'),
headers=kwargs.get('headers'),
cookies=kwargs.get('cookies')
)
data = kwargs.get('data', None)
resp = json.dumps(data).encode('utf-8') if data else b'OK'
response = httpx.Response(
status_code=200,
request=request,
content=resp,
)
return response


@pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
yield
monkeypatch.undo()

+ 0
- 0
api/tests/integration_tests/tools/api_tool/__init__.py 查看文件


+ 39
- 0
api/tests/integration_tests/tools/api_tool/test_api_tool.py 查看文件

from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from tests.integration_tests.tools.__mock.http import setup_http_mock

tool_bundle = {
'server_url': 'http://www.example.com/{path_param}',
'method': 'post',
'author': '',
'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
{'in': 'query', 'name': 'query_param'},
{'in': 'cookie', 'name': 'cookie_param'},
{'in': 'header', 'name': 'header_param'},
],
'requestBody': {
'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
},
'parameters': []
}
parameters = {
'path_param': 'p_param',
'query_param': 'q_param',
'cookie_param': 'c_param',
'header_param': 'h_param',
'body_param': 'b_param',
}


def test_api_tool(setup_http_mock):
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
headers = tool.assembling_request(parameters)
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)

assert response.status_code == 200
assert '/p_param' == response.request.url.path
assert b'query_param=q_param' == response.request.url.query
assert 'h_param' == response.request.headers.get('header_param')
assert 'application/json' == response.request.headers.get('content-type')
assert 'cookie_param=c_param' == response.request.headers.get('cookie')
assert 'b_param' in response.content.decode()

+ 17
- 49
api/tests/integration_tests/workflow/nodes/__mock/http.py 查看文件

from json import dumps from json import dumps
from typing import Literal from typing import Literal


import httpx._api as httpx
import httpx
import pytest import pytest
import requests.api as requests
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from httpx import Request as HttpxRequest
from requests import Response as RequestsResponse
from yarl import URL


MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'



class MockedHttp: class MockedHttp:
def requests_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str,
**kwargs) -> RequestsResponse:
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
url: str, **kwargs) -> httpx.Response:
""" """
Mocked requests.request
Mocked httpx.request
""" """
response = RequestsResponse()
response.url = str(URL(url) % kwargs.get('params', {}))
response.headers = kwargs.get('headers', {})

if url == 'http://404.com': if url == 'http://404.com':
response.status_code = 404
response._content = b'Not Found'
response = httpx.Response(
status_code=404,
request=httpx.Request(method, url),
content=b'Not Found'
)
return response return response
# get data, files # get data, files
data = kwargs.get('data', None) data = kwargs.get('data', None)
files = kwargs.get('files', None) files = kwargs.get('files', None)

if data is not None: if data is not None:
resp = dumps(data).encode('utf-8') resp = dumps(data).encode('utf-8')
if files is not None:
elif files is not None:
resp = dumps(files).encode('utf-8') resp = dumps(files).encode('utf-8')
else: else:
resp = b'OK' resp = b'OK'


response.status_code = 200
response._content = resp
return response

def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
url: str, **kwargs) -> httpx.Response:
"""
Mocked httpx.request
"""
response = httpx.Response( response = httpx.Response(
status_code=200, status_code=200,
request=HttpxRequest(method, url)
request=httpx.Request(method, url),
headers=kwargs.get('headers', {}),
content=resp
) )
response.headers = kwargs.get('headers', {})

if url == 'http://404.com':
response.status_code = 404
response.content = b'Not Found'
return response
# get data, files
data = kwargs.get('data', None)
files = kwargs.get('files', None)

if data is not None:
resp = dumps(data).encode('utf-8')
if files is not None:
resp = dumps(files).encode('utf-8')
else:
resp = b'OK'

response.status_code = 200
response._content = resp
return response return response



@pytest.fixture @pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch): def setup_http_mock(request, monkeypatch: MonkeyPatch):
if not MOCK: if not MOCK:
yield yield
return return


monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
yield yield
monkeypatch.undo()
monkeypatch.undo()

+ 72
- 6
api/tests/integration_tests/workflow/nodes/test_http.py 查看文件

from urllib.parse import urlencode

import pytest import pytest


from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1) pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1)
pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2) pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2)



@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_get(setup_http_mock): def test_get(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data



@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_no_auth(setup_http_mock): def test_no_auth(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
assert '?A=b' in data assert '?A=b' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data



@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_custom_authorization_header(setup_http_mock): def test_custom_authorization_header(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert 'X-Auth: Auth' in data assert 'X-Auth: Auth' in data



@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_template(setup_http_mock): def test_template(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert 'X-Header2: 2' in data assert 'X-Header2: 2' in data



@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_json(setup_http_mock): def test_json(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data



def test_x_www_form_urlencoded(setup_http_mock): def test_x_www_form_urlencoded(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data



def test_form_data(setup_http_mock): def test_form_data(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data



def test_none_data(setup_http_mock): def test_none_data(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx',
'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert '123123123' not in data assert '123123123' not in data


def test_mock_404(setup_http_mock):
node = HttpRequestNode(config={
'id': '1',
'data': {
'title': 'http',
'desc': '',
'method': 'get',
'url': 'http://404.com',
'authorization': {
'type': 'no-auth',
'config': None,
},
'body': None,
'params': '',
'headers': 'X-Header:123',
'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)

result = node.run(pool)
resp = result.outputs

assert 404 == resp.get('status_code')
assert 'Not Found' in resp.get('body')


def test_multi_colons_parse(setup_http_mock):
node = HttpRequestNode(config={
'id': '1',
'data': {
'title': 'http',
'desc': '',
'method': 'get',
'url': 'http://example.com',
'authorization': {
'type': 'no-auth',
'config': None,
},
'params': 'Referer:http://example1.com\nRedirect:http://example2.com',
'headers': 'Referer:http://example3.com\nRedirect:http://example4.com',
'body': {
'type': 'form-data',
'data': 'Referer:http://example5.com\nRedirect:http://example6.com'
},
'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)

result = node.run(pool)
resp = result.outputs

assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request')
assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request')
assert 'http://example3.com' == resp.get('headers').get('referer')

Loading…
取消
儲存