
Refa: limit embedding concurrency and fix `chat_with_tool` (#8543)

### What problem does this PR solve?

#8538

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
tags/v0.20.0
Kevin Hu, 4 months ago
Parent
Commit
e441c17c2c
2 files changed, with 75 additions and 303 deletions
1. rag/llm/chat_model.py (+67, -295)
2. rag/raptor.py (+8, -8)

+ 67
- 295
rag/llm/chat_model.py View file

@@ -18,11 +18,9 @@ import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
from http import HTTPStatus
from typing import Any, Protocol
from urllib.parse import urljoin

@@ -61,9 +59,6 @@ class ToolCallSession(Protocol):


class Base(ABC):
tools: list[Any]
toolcall_sessions: dict[str, ToolCallSession]
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@@ -146,6 +141,37 @@ class Base(ABC):
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({
"name": name,
"args": args,
"result": res
}, ensure_ascii=False, indent=2) + "</tool_call>"

def _append_history(self, hist, tool_call, tool_res):
hist.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
try:
if isinstance(tool_res, dict):
tool_res = json.dumps(tool_res, ensure_ascii=False)
finally:
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
return hist

def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
@@ -160,18 +186,19 @@ class Base(ABC):
if system:
history.insert(0, {"role": "system", "content": system})

gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
for attempt in range(self.max_retries+1):
history = hist
for _ in range(self.max_rounds * 2):
try:
try:
for _ in range(self.max_rounds*2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
raise Exception("500 response structure error.")
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")

if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@@ -188,14 +215,17 @@ class Base(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))

except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."

def chat(self, system, history, gen_conf):
@@ -228,9 +258,7 @@ class Base(ABC):
return final_tool_calls

def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]

gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system:
history.insert(0, {"role": "system", "content": system})
@@ -240,9 +268,9 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
for _ in range(self.max_rounds * 2):
reasoning_start = False
try:
try:
for _ in range(self.max_rounds*2):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
final_tool_calls = {}
answer = ""
@@ -252,9 +280,11 @@ class Base(ABC):
index = tool_call.index

if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue

if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
@@ -293,40 +323,26 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
tool_response = self.toolcall_session[name].tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield total_tokens
return
yield self._verbose_tool_use(name, {}, str(e))

assert False, "Shouldn't be here."
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield total_tokens
return

yield total_tokens

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
reasoning_start = False
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):

class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

import dashscope

dashscope.api_key = key
self.model_name = model_name
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)

def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
# if self.is_reasoning_model(self.model_name):
# return super().chat(system, history, gen_conf)

stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
if not stream_flag:
from http import HTTPStatus

tools = self.tools

if system:
history.insert(0, {"role": "system", "content": system})

response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
assistant_output = response.output.choices[0].message
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
ans += "<think>" + ans + "</think>"
ans += response.output.choices[0].message.content

if "tool_calls" not in assistant_output:
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count

tk_count += self.total_token_count(response)
history.append(assistant_output)

while "tool_calls" in assistant_output:
tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
tool_name = assistant_output.tool_calls[0]["function"]["name"]
if tool_name:
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
history.append(tool_info)

response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
if response.output.choices[0].get("finish_reason", "") == "length":
tk_count += self.total_token_count(response)
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count

tk_count += self.total_token_count(response)
assistant_output = response.output.choices[0].message
if assistant_output.content is None:
assistant_output.content = ""
history.append(response)
ans += assistant_output["content"]
return ans, tk_count
else:
return "**ERROR**: " + response.message, tk_count
else:
result_list = []
for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
result_list.append(result)
error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
if len(error_msg_list) > 0:
return "**ERROR**: " + "".join(error_msg_list), 0
else:
return "".join(result_list[:-1]), result_list[-1]

def _chat(self, history, gen_conf):
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
return super()._chat(history, gen_conf)
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
return "**ERROR**: " + response.message, tk_count

def _wrap_toolcall_message(self, old_message, message):
if not old_message:
return message
tool_call_id = message["tool_calls"][0].get("id")
if tool_call_id:
old_message.tool_calls[0]["id"] = tool_call_id
function = message.tool_calls[0]["function"]
if function:
if function.get("name"):
old_message.tool_calls[0]["function"]["name"] = function["name"]
if function.get("arguments"):
old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
return old_message

def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
from http import HTTPStatus

if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
tk_count = 0
try:
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
tool_info = {"content": "", "role": "tool"}
toolcall_message = None
tool_name = ""
tool_arguments = ""
finish_completion = False
reasoning_start = False
while not finish_completion:
for resp in response:
if resp.status_code == HTTPStatus.OK:
assistant_output = resp.output.choices[0].message
ans = resp.output.choices[0].message.content
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
ans = resp.output.choices[0].message.reasoning_content
if not reasoning_start:
reasoning_start = True
ans = "<think>" + ans
else:
ans = ans + "</think>"

if "tool_calls" not in assistant_output:
reasoning_start = False
tk_count += self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
finish_reason = resp.output.choices[0]["finish_reason"]
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
continue

tk_count += self.total_token_count(resp)
toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
if "tool_calls" in assistant_output:
tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
if tool_call_finish_reason == "tool_calls":
try:
tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
except Exception as e:
logging.exception(msg="_chat_streamly_with_tool tool call error")
yield ans + "\n**ERROR**: " + str(e)
finish_completion = True
break

tool_name = toolcall_message.tool_calls[0]["function"]["name"]
history.append(toolcall_message)
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
history.append(tool_info)
tool_info = {"content": "", "role": "tool"}
tool_name = ""
tool_arguments = ""
toolcall_message = None
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
else:
yield (
ans + "\n**ERROR**: " + resp.output.choices[0].message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
logging.exception(msg="_chat_streamly_with_tool")
yield ans + "\n**ERROR**: " + str(e)
yield tk_count

def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus

if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
tk_count = 0
try:
response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
for resp in response:
if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]["message"]["content"]
tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
else:
yield (
ans + "\n**ERROR**: " + resp.message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

yield tk_count

def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]

for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
yield txt

def chat_streamly(self, system, history, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
return super().chat_streamly(system, history, gen_conf)

return self._chat_streamly(system, history, gen_conf)

@staticmethod
def is_reasoning_model(model_name: str) -> bool:
return any(
[
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
]
)
super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
return


class ZhipuChat(Base):
@@ -1877,4 +1649,4 @@ class GPUStackChat(Base):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
super().__init__(key, model_name, base_url, **kwargs)
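
For reference, a minimal sketch of what the new `_append_history` helper in this diff does: it consolidates the two messages that were previously appended inline, first the assistant turn that requested the tool call, then the tool result. The `SimpleNamespace` stand-in below is hypothetical and only mimics the shape of an OpenAI tool-call object; it is not part of this PR.

```python
import json
from types import SimpleNamespace

def _append_history(hist, tool_call, tool_res):
    # 1) Record the assistant turn that requested the tool call.
    hist.append({
        "role": "assistant",
        "tool_calls": [{
            "index": tool_call.index,
            "id": tool_call.id,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
            "type": "function",
        }],
    })
    # 2) Record the tool result, serializing dicts to a JSON string first.
    if isinstance(tool_res, dict):
        tool_res = json.dumps(tool_res, ensure_ascii=False)
    hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
    return hist

# Hypothetical stand-in mimicking an OpenAI tool-call object.
tool_call = SimpleNamespace(
    index=0,
    id="call_0",
    function=SimpleNamespace(name="search", arguments='{"query": "ragflow"}'),
)
history = _append_history([], tool_call, {"hits": 3})
print(json.dumps(history, indent=2, ensure_ascii=False))
```

Both `chat_with_tools` and `chat_streamly_with_tools` now call this helper, so the assistant tool-call message is no longer dropped from the history before the tool result is appended.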

+ 8
- 8
rag/raptor.py View file

@@ -105,14 +105,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
],
{"temperature": 0.3, "max_tokens": self._max_token},
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))

labels = []
while end - start > 1:
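
The raptor.py hunk only shows the summarization/embedding block being re-indented; per the PR title, the surrounding change limits how many embedding calls run concurrently. A minimal sketch of one way to do that with an `asyncio.Semaphore` is shown below; the `limiter` name, the `embed_with_limit` helper, and the cap of 5 are assumptions for illustration, not taken from this diff.

```python
import asyncio

# Assumed concurrency cap; not taken from this diff.
MAX_CONCURRENT_EMBEDDINGS = 5
limiter = asyncio.Semaphore(MAX_CONCURRENT_EMBEDDINGS)

async def embed_with_limit(embed_fn, text):
    # Only MAX_CONCURRENT_EMBEDDINGS coroutines may call the embedding
    # model at the same time; the rest wait on the semaphore.
    async with limiter:
        return await embed_fn(text)

async def main():
    async def fake_embed(text):  # stand-in for the real embedding call
        await asyncio.sleep(0.1)
        return [0.0] * 8

    texts = [f"summary {i}" for i in range(20)]
    vectors = await asyncio.gather(*(embed_with_limit(fake_embed, t) for t in texts))
    print(len(vectors), "embeddings computed")

asyncio.run(main())
```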
