浏览代码

robust for json parser (#17687)

tags/1.3.0
zxfishhack 6 个月前
父节点
当前提交
5541a1f80e
没有帐户链接到提交者的电子邮件

+ 38
- 22
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py 查看文件

@@ -1,4 +1,5 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
@@ -58,6 +59,30 @@ from .prompts import (
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)

logger = logging.getLogger(__name__)


def extract_json(text):
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
stack = []
for i, c in enumerate(text):
if c in {"{", "["}:
stack.append(c)
elif c in {"}", "]"}:
# check if stack is empty
if not stack:
return text[:i]
# check if the last element in stack is matching
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
stack.pop()
if not stack:
return text[: i + 1]
else:
return text[:i]
return None


class ParameterExtractorNode(LLMNode):
"""
@@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode):
Extract complete json response.
"""

def extract_json(text):
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
stack = []
for i, c in enumerate(text):
if c in {"{", "["}:
stack.append(c)
elif c in {"}", "]"}:
# check if stack is empty
if not stack:
return text[:i]
# check if the last element in stack is matching
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
stack.pop()
if not stack:
return text[: i + 1]
else:
return text[:i]
return None

# extract json from the text
for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[":
@@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode):
return cast(dict, json.loads(json_str))
except Exception:
pass
logger.info(f"extra error: {result}")
return None

def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
@@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode):
if not tool_call or not tool_call.function.arguments:
return None

return cast(dict, json.loads(tool_call.function.arguments))
result = tool_call.function.arguments
# extract json from the arguments
for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
try:
return cast(dict, json.loads(json_str))
except Exception:
pass
logger.info(f"extra error: {result}")
return None

def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""

+ 41
- 0
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py 查看文件

@@ -5,6 +5,7 @@ from typing import Optional
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@@ -311,6 +312,46 @@ def test_extract_json_response():
assert result["location"] == "kawaii"


def test_extract_json_from_tool_call():
"""
Test extract json response.
"""

node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "parameter-extractor",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {},
},
"query": ["sys", "query"],
"parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
"reasoning_mode": "prompt",
"instruction": "{{#sys.query#}}",
"memory": None,
},
},
)

result = node._extract_json_from_tool_call(
AssistantPromptMessage.ToolCall(
id="llm",
type="parameter-extractor",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name="foo", arguments="""{"location":"kawaii"}{"location": 1}"""
),
)
)

assert result is not None
assert result["location"] == "kawaii"


def test_chat_parameter_extractor_with_memory(setup_model_mock):
"""
Test chat parameter extractor with memory.

正在加载...
取消
保存