@@ -1,5 +1,6 @@
import json
import os
from typing import Optional
from unittest.mock import MagicMock

import pytest
@@ -7,6 +8,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
@@ -61,6 +63,16 @@ def get_mocked_fetch_model_config(

    return MagicMock(return_value=(model_instance, model_config))


def get_mocked_fetch_memory(memory_text: str):
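    # Stand-in for TokenBufferMemory: it mirrors get_history_prompt_text's
    # signature but always returns the canned memory_text, so no conversation
    # needs to exist in the database.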
    class MemoryMock:
        def get_history_prompt_text(self, human_prefix: str = "Human",
                                    ai_prefix: str = "Assistant",
                                    max_token_limit: int = 2000,
                                    message_limit: Optional[int] = None):
            return memory_text

    return MagicMock(return_value=MemoryMock())


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_function_calling_parameter_extractor(setup_openai_mock):
    """
@@ -354,4 +366,83 @@ def test_extract_json_response():
    hello world.
    """)

    assert result['location'] == 'kawaii'
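

# The ['none'] variant of the Anthropic mock answers with neither a function
# call nor extractable text, so this run is expected to fall back to the
# node's empty default result.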
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
    """
    Test chat parameter extractor with memory.
    """
    node = ParameterExtractorNode(
        tenant_id='1',
        app_id='1',
        workflow_id='1',
        user_id='1',
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
        config={
            'id': 'llm',
            'data': {
                'title': '123',
                'type': 'parameter-extractor',
                'model': {
                    'provider': 'anthropic',
                    'name': 'claude-2',
                    'mode': 'chat',
                    'completion_params': {}
                },
                'query': ['sys', 'query'],
                'parameters': [{
                    'name': 'location',
                    'type': 'string',
                    'description': 'location',
                    'required': True
                }],
                'reasoning_mode': 'prompt',
                'instruction': '',
                'memory': {
                    'window': {
                        'enabled': True,
                        'size': 50
                    }
                },
            }
        }
    )
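
    # Stub the model-config fetch, the memory fetch, and the DB session so the
    # test runs without a live Anthropic provider or database connection.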
    node._fetch_model_config = get_mocked_fetch_model_config(
        provider='anthropic', model='claude-2', mode='chat', credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        }
    )
    node._fetch_memory = get_mocked_fetch_memory('customized memory')
    db.session.close = MagicMock()

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.QUERY: 'what\'s the weather in SF',
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})

    result = node.run(pool)
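
    # With nothing to extract, the node should still succeed, reporting the
    # fallback reason and an empty string for the required parameter.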
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs.get('location') == ''
    assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
    prompts = result.process_data.get('prompts')
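
    # Anthropic-style chat prompts alternate user/assistant turns, so verify
    # that no two consecutive prompts share the same role.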
    latest_role = None
    for prompt in prompts:
        if prompt.get('role') == 'user':
            if '<structure>' in prompt.get('text'):
                assert '<structure>\n{"type": "object"' in prompt.get('text')
        elif prompt.get('role') == 'system':
            assert 'customized memory' in prompt.get('text')

        if latest_role is not None:
            assert latest_role != prompt.get('role')

        if prompt.get('role') in ['user', 'assistant']:
            latest_role = prompt.get('role')