|
|
|
@@ -1,4 +1,5 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import uuid |
|
|
|
from collections.abc import Mapping, Sequence |
|
|
|
from typing import Any, Optional, cast |
|
|
|
@@ -58,6 +59,30 @@ from .prompts import ( |
|
|
|
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, |
|
|
|
) |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def extract_json(text): |
|
|
|
""" |
|
|
|
From a given JSON started from '{' or '[' extract the complete JSON object. |
|
|
|
""" |
|
|
|
stack = [] |
|
|
|
for i, c in enumerate(text): |
|
|
|
if c in {"{", "["}: |
|
|
|
stack.append(c) |
|
|
|
elif c in {"}", "]"}: |
|
|
|
# check if stack is empty |
|
|
|
if not stack: |
|
|
|
return text[:i] |
|
|
|
# check if the last element in stack is matching |
|
|
|
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): |
|
|
|
stack.pop() |
|
|
|
if not stack: |
|
|
|
return text[: i + 1] |
|
|
|
else: |
|
|
|
return text[:i] |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
class ParameterExtractorNode(LLMNode): |
|
|
|
""" |
|
|
|
@@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode): |
|
|
|
Extract complete json response. |
|
|
|
""" |
|
|
|
|
|
|
|
def extract_json(text): |
|
|
|
""" |
|
|
|
From a given JSON started from '{' or '[' extract the complete JSON object. |
|
|
|
""" |
|
|
|
stack = [] |
|
|
|
for i, c in enumerate(text): |
|
|
|
if c in {"{", "["}: |
|
|
|
stack.append(c) |
|
|
|
elif c in {"}", "]"}: |
|
|
|
# check if stack is empty |
|
|
|
if not stack: |
|
|
|
return text[:i] |
|
|
|
# check if the last element in stack is matching |
|
|
|
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): |
|
|
|
stack.pop() |
|
|
|
if not stack: |
|
|
|
return text[: i + 1] |
|
|
|
else: |
|
|
|
return text[:i] |
|
|
|
return None |
|
|
|
|
|
|
|
# extract json from the text |
|
|
|
for idx in range(len(result)): |
|
|
|
if result[idx] == "{" or result[idx] == "[": |
|
|
|
@@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode): |
|
|
|
return cast(dict, json.loads(json_str)) |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
logger.info(f"extra error: {result}") |
|
|
|
return None |
|
|
|
|
|
|
|
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: |
|
|
|
@@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode): |
|
|
|
if not tool_call or not tool_call.function.arguments: |
|
|
|
return None |
|
|
|
|
|
|
|
return cast(dict, json.loads(tool_call.function.arguments)) |
|
|
|
result = tool_call.function.arguments |
|
|
|
# extract json from the arguments |
|
|
|
for idx in range(len(result)): |
|
|
|
if result[idx] == "{" or result[idx] == "[": |
|
|
|
json_str = extract_json(result[idx:]) |
|
|
|
if json_str: |
|
|
|
try: |
|
|
|
return cast(dict, json.loads(json_str)) |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
logger.info(f"extra error: {result}") |
|
|
|
return None |
|
|
|
|
|
|
|
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: |
|
|
|
""" |