|
|
|
@@ -12,39 +12,45 @@ class CotAgentOutputParser: |
|
|
|
def handle_react_stream_output( |
|
|
|
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict |
|
|
|
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: |
|
|
|
def parse_action(json_str): |
|
|
|
try: |
|
|
|
action = json.loads(json_str, strict=False) |
|
|
|
action_name = None |
|
|
|
action_input = None |
|
|
|
|
|
|
|
# cohere always returns a list |
|
|
|
if isinstance(action, list) and len(action) == 1: |
|
|
|
action = action[0] |
|
|
|
|
|
|
|
for key, value in action.items(): |
|
|
|
if "input" in key.lower(): |
|
|
|
action_input = value |
|
|
|
else: |
|
|
|
action_name = value |
|
|
|
|
|
|
|
if action_name is not None and action_input is not None: |
|
|
|
return AgentScratchpadUnit.Action( |
|
|
|
action_name=action_name, |
|
|
|
action_input=action_input, |
|
|
|
) |
|
|
|
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]: |
|
|
|
action_name = None |
|
|
|
action_input = None |
|
|
|
if isinstance(action, str): |
|
|
|
try: |
|
|
|
action = json.loads(action, strict=False) |
|
|
|
except json.JSONDecodeError: |
|
|
|
return action or "" |
|
|
|
|
|
|
|
# cohere always returns a list |
|
|
|
if isinstance(action, list) and len(action) == 1: |
|
|
|
action = action[0] |
|
|
|
|
|
|
|
for key, value in action.items(): |
|
|
|
if "input" in key.lower(): |
|
|
|
action_input = value |
|
|
|
else: |
|
|
|
return json_str or "" |
|
|
|
action_name = value |
|
|
|
|
|
|
|
if action_name is not None and action_input is not None: |
|
|
|
return AgentScratchpadUnit.Action( |
|
|
|
action_name=action_name, |
|
|
|
action_input=action_input, |
|
|
|
) |
|
|
|
else: |
|
|
|
return json.dumps(action) |
|
|
|
|
|
|
|
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]: |
|
|
|
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE) |
|
|
|
if not blocks: |
|
|
|
return [] |
|
|
|
try: |
|
|
|
json_blocks = [] |
|
|
|
for block in blocks: |
|
|
|
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) |
|
|
|
json_blocks.append(json.loads(json_text, strict=False)) |
|
|
|
return json_blocks |
|
|
|
except: |
|
|
|
return json_str or "" |
|
|
|
|
|
|
|
def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: |
|
|
|
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) |
|
|
|
if not code_blocks: |
|
|
|
return |
|
|
|
for block in code_blocks: |
|
|
|
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) |
|
|
|
yield parse_action(json_text) |
|
|
|
return [] |
|
|
|
|
|
|
|
code_block_cache = "" |
|
|
|
code_block_delimiter_count = 0 |
|
|
|
@@ -78,7 +84,7 @@ class CotAgentOutputParser: |
|
|
|
delta = response_content[index : index + steps] |
|
|
|
yield_delta = False |
|
|
|
|
|
|
|
if delta == "`": |
|
|
|
if not in_json and delta == "`": |
|
|
|
last_character = delta |
|
|
|
code_block_cache += delta |
|
|
|
code_block_delimiter_count += 1 |
|
|
|
@@ -159,8 +165,14 @@ class CotAgentOutputParser: |
|
|
|
if code_block_delimiter_count == 3: |
|
|
|
if in_code_block: |
|
|
|
last_character = delta |
|
|
|
yield from extra_json_from_code_block(code_block_cache) |
|
|
|
code_block_cache = "" |
|
|
|
action_json_list = extra_json_from_code_block(code_block_cache) |
|
|
|
if action_json_list: |
|
|
|
for action_json in action_json_list: |
|
|
|
yield parse_action(action_json) |
|
|
|
code_block_cache = "" |
|
|
|
else: |
|
|
|
index += steps |
|
|
|
continue |
|
|
|
|
|
|
|
in_code_block = not in_code_block |
|
|
|
code_block_delimiter_count = 0 |