瀏覽代碼

Fix/incorrect parameter extractor memory (#6038)

tags/0.6.13
Yeuoly 1 年之前
父節點
當前提交
a877d4831d
沒有連結到貢獻者的電子郵件帳戶。

+ 1
- 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py 查看文件

@@ -365,7 +365,7 @@ class ParameterExtractorNode(LLMNode):
files=[],
context='',
memory_config=node_data.memory,
memory=memory,
memory=None,
model_config=model_config
)


+ 92
- 1
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py 查看文件

@@ -1,5 +1,6 @@
import json
import os
from typing import Optional
from unittest.mock import MagicMock

import pytest
@@ -7,6 +8,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
@@ -61,6 +63,16 @@ def get_mocked_fetch_model_config(

return MagicMock(return_value=(model_instance, model_config))

def get_mocked_fetch_memory(memory_text: str):
class MemoryMock:
def get_history_prompt_text(self, human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None):
return memory_text

return MagicMock(return_value=MemoryMock())

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_function_calling_parameter_extractor(setup_openai_mock):
"""
@@ -354,4 +366,83 @@ def test_extract_json_response():
hello world.
""")

assert result['location'] == 'kawaii'
assert result['location'] == 'kawaii'

@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
"""
Test chat parameter extractor with memory.
"""
node = ParameterExtractorNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
config={
'id': 'llm',
'data': {
'title': '123',
'type': 'parameter-extractor',
'model': {
'provider': 'anthropic',
'name': 'claude-2',
'mode': 'chat',
'completion_params': {}
},
'query': ['sys', 'query'],
'parameters': [{
'name': 'location',
'type': 'string',
'description': 'location',
'required': True
}],
'reasoning_mode': 'prompt',
'instruction': '',
'memory': {
'window': {
'enabled': True,
'size': 50
}
},
}
}
)

node._fetch_model_config = get_mocked_fetch_model_config(
provider='anthropic', model='claude-2', mode='chat', credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
}
)
node._fetch_memory = get_mocked_fetch_memory('customized memory')
db.session.close = MagicMock()

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})

result = node.run(pool)

assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs.get('location') == ''
assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
prompts = result.process_data.get('prompts')

latest_role = None
for prompt in prompts:
if prompt.get('role') == 'user':
if '<structure>' in prompt.get('text'):
assert '<structure>\n{"type": "object"' in prompt.get('text')
elif prompt.get('role') == 'system':
assert 'customized memory' in prompt.get('text')

if latest_role is not None:
assert latest_role != prompt.get('role')

if prompt.get('role') in ['user', 'assistant']:
latest_role = prompt.get('role')

Loading…
取消
儲存