
Refa: chat with tools. (#8210)

### What problem does this PR solve?


### Type of change
- [x] Refactoring
tags/v0.19.1
Kevin Hu, 4 months ago
parent commit 56ee69e9d9
2 changed files with 131 additions and 172 deletions
  1. rag/llm/chat_model.py (+130 −171)
  2. rag/prompts.py (+1 −1)

rag/llm/chat_model.py (+130 −171)

@@ -26,6 +26,7 @@ from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
 
+import json_repair
 import openai
 import requests
 from dashscope import Generation
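The newly imported json_repair backs the switch from json.loads to json_repair.loads for tool-call arguments later in this diff: model-emitted JSON is often slightly malformed (trailing commas, stray quotes), and the tolerant parser recovers a dict instead of raising. A quick sketch of the difference (the sample string is made up):

```python
import json

import json_repair

raw_args = '{"query": "weather in Paris", "days": 3,}'  # trailing comma: invalid JSON

try:
    json.loads(raw_args)                 # strict parser rejects it
except json.JSONDecodeError as e:
    print("json.loads failed:", e)

print(json_repair.loads(raw_args))       # {'query': 'weather in Paris', 'days': 3}
```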
@@ -67,11 +68,12 @@ class Base(ABC):
         # Configure retry parameters
         self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
         self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
+        self.max_rounds = kwargs.get("max_rounds", 5)
         self.is_tools = False
 
-    def _get_delay(self, attempt):
+    def _get_delay(self):
         """Calculate retry delay time"""
-        return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
+        return self.base_delay + random.uniform(0, 0.5)
 
     def _classify_error(self, error):
         """Classify error based on error message content"""
@@ -116,6 +118,29 @@ class Base(ABC):
                 ans += LENGTH_NOTIFICATION_EN
         return ans, self.total_token_count(response)
 
+    def _length_stop(self, ans):
+        if is_chinese([ans]):
+            return ans + LENGTH_NOTIFICATION_CN
+        return ans + LENGTH_NOTIFICATION_EN
+
+    def _exceptions(self, e, attempt):
+        logging.exception("OpenAI chat_with_tools")
+        # Classify the error
+        error_code = self._classify_error(e)
+
+        # Check if it's a rate limit error or server error and not the last attempt
+        should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
+
+        if should_retry:
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+        else:
+            # For non-rate limit errors or the last attempt, return an error message
+            if attempt == self.max_retries:
+                error_code = ERROR_MAX_RETRIES
+            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
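The new `_exceptions` helper centralizes the retry/give-up decision that each method previously duplicated: it sleeps and returns `None` when the error is retryable and attempts remain, and otherwise returns an error string for the caller to pass through. A minimal runnable sketch of that contract, with a substring check standing in for `_classify_error`:

```python
import time

def handle_exception(err, attempt, max_retries=5, base_delay=0.1):
    # Simplified stand-in for Base._exceptions.
    if "rate limit" in str(err).lower() and attempt < max_retries:
        time.sleep(base_delay)       # real code sleeps _get_delay() seconds
        return None                  # None: the caller should retry
    return f"**ERROR**: {err}"       # truthy: the caller should give up

for attempt in range(5 + 1):         # max_retries retries plus the initial attempt
    try:
        raise TimeoutError("rate limit reached")   # simulate a flaky API call
    except Exception as exc:
        outcome = handle_exception(exc, attempt)
        if outcome:
            print(outcome)           # printed once the retry budget is exhausted
            break
```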
@@ -124,76 +149,48 @@ class Base(ABC):
         self.tools = tools
 
     def chat_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        tools = self.tools
-
+        gen_conf = self._clean_conf()
         if system:
             history.insert(0, {"role": "system", "content": system})
 
         ans = ""
         tk_count = 0
+        hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
-            try:
-                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-
-                assistant_output = response.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.choices[0].message.content
-
-                if not response.choices[0].message.tool_calls:
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
+                    tk_count += self.total_token_count(response)
-                    if response.choices[0].finish_reason == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
+                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+                        raise Exception("500 response structure error.")
 
-                for tool_call in response.choices[0].message.tool_calls:
-                    name = tool_call.function.name
-                    args = json.loads(tool_call.function.arguments)
+                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
 
-                    tool_response = self.toolcall_session.tool_call(name, args)
-                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        ans += response.choices[0].message.content
+                        if response.choices[0].finish_reason == "length":
+                            ans = self._length_stop(ans)
 
-                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-                assistant_output = final_response.choices[0].message
-                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += final_response.choices[0].message.content
-                if final_response.choices[0].finish_reason == "length":
-                    tk_count += self.total_token_count(response)
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-                return ans, tk_count
+                        return ans, tk_count
 
-            except Exception as e:
-                logging.exception("OpenAI cat_with_tools")
-                # Classify the error
-                error_code = self._classify_error(e)
+                    for tool_call in response.choices[0].message.tool_calls:
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        except Exception as e:
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
 
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        return e, tk_count
+        assert False, "Shouldn't be here."
 
     def chat(self, system, history, gen_conf):
         if system:
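The rewritten `chat_with_tools` separates two concerns the old version mixed together: an outer loop over the retry budget that restores `history` from the `hist` snapshot, and an inner loop of at most `max_rounds*2` model/tool exchanges, each of which either returns the final answer or appends one round of tool results to `history`. A runnable skeleton of that control flow, with `call_llm` and `handle` as stand-ins for the completions call and `_exceptions`:

```python
from copy import deepcopy
from types import SimpleNamespace

def call_llm(history):
    # Stand-in for client.chat.completions.create(...): echoes the last turn.
    return SimpleNamespace(tool_calls=None, content="echo: " + history[-1]["content"])

def handle(exc, attempt):
    # Stand-in for Base._exceptions: give up immediately in this sketch.
    return f"**ERROR**: {exc}"

def tool_loop(history, max_retries=5, max_rounds=5):
    snapshot = deepcopy(history)
    for attempt in range(max_retries + 1):        # outer loop: retry budget
        history = snapshot                         # restart from the original messages
        for _ in range(max_rounds * 2):            # inner loop: bounded tool rounds
            try:
                message = call_llm(history)
                if not message.tool_calls:         # no tool request: final answer
                    return message.content
                for tc in message.tool_calls:      # run each tool, feed the result back
                    history.append({"role": "tool", "content": str(tc)})
            except Exception as exc:
                err = handle(exc, attempt)         # real code sleeps here on retryable errors
                if err:
                    return err
    raise AssertionError("Shouldn't be here.")

print(tool_loop([{"role": "user", "content": "hi"}]))  # -> echo: hi
```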
@@ -201,26 +198,14 @@ class Base(ABC):
         gen_conf = self._clean_conf(gen_conf)
 
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
+        for attempt in range(self.max_retries+1):
             try:
                 return self._chat(history, gen_conf)
             except Exception as e:
                 logging.exception("chat_model.Base.chat got exception")
-                # Classify the error
-                error_code = self._classify_error(e)
-
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+                e = self._exceptions(e, attempt)
+                if e:
+                    return e, 0
+        assert False, "Shouldn't be here."
 
     def _wrap_toolcall_message(self, stream):
         final_tool_calls = {}
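From the caller's side the contract is unchanged: `chat` never raises, it reports failures in-band as an error string with zero tokens. A hypothetical caller, with a fake model standing in for a real `Base` subclass:

```python
class FakeModel:
    # Stand-in returning what chat() yields once retries are exhausted.
    def chat(self, system, history, gen_conf):
        return "**ERROR**: RATE_LIMIT - too many requests", 0

answer, tokens = FakeModel().chat("You are helpful.", [{"role": "user", "content": "Hi"}], {})
if answer.startswith("**ERROR**"):   # errors come back in-band, tokens == 0
    print("LLM call failed:", answer)
else:
    print(f"answer ({tokens} tokens):", answer)
```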
@@ -241,41 +226,48 @@ class Base(ABC):
             del gen_conf["max_tokens"]
 
         tools = self.tools
 
         if system:
             history.insert(0, {"role": "system", "content": system})
 
         ans = ""
         total_tokens = 0
-        reasoning_start = False
-        finish_completion = False
-        final_tool_calls = {}
-        try:
-            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-            while not finish_completion:
-                for resp in response:
-                    if resp.choices[0].delta.tool_calls:
-                        for tool_call in resp.choices[0].delta.tool_calls or []:
-                            index = tool_call.index
-
-                            if index not in final_tool_calls:
-                                final_tool_calls[index] = tool_call
-                            else:
-                                final_tool_calls[index].function.arguments += tool_call.function.arguments
-                    else:
-                        if not resp.choices:
+        hist = deepcopy(history)
+        # Implement exponential backoff retry strategy
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                reasoning_start = False
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                    final_tool_calls = {}
+                    answer = ""
+                    for resp in response:
+                        if resp.choices[0].delta.tool_calls:
+                            for tool_call in resp.choices[0].delta.tool_calls or []:
+                                index = tool_call.index
+
+                                if index not in final_tool_calls:
+                                    final_tool_calls[index] = tool_call
+                                else:
+                                    final_tool_calls[index].function.arguments += tool_call.function.arguments
+                            continue
 
+                        if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
+                            raise Exception("500 response structure error.")
 
                         if not resp.choices[0].delta.content:
                             resp.choices[0].delta.content = ""
 
                         if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                             ans = ""
                             if not reasoning_start:
                                 reasoning_start = True
                                 ans = "<think>"
                             ans += resp.choices[0].delta.reasoning_content + "</think>"
                             yield ans
                         else:
                             reasoning_start = False
-                            ans = resp.choices[0].delta.content
+                            answer += resp.choices[0].delta.content
                             yield resp.choices[0].delta.content
 
                         tol = self.total_token_count(resp)
                         if not tol:
@@ -283,18 +275,18 @@ class Base(ABC):
                         else:
                             total_tokens += tol
 
-                        finish_reason = resp.choices[0].finish_reason
-                        if finish_reason == "tool_calls" and final_tool_calls:
-                            for tool_call in final_tool_calls.values():
-                                name = tool_call.function.name
-                                try:
-                                    args = json.loads(tool_call.function.arguments)
-                                except Exception as e:
-                                    logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
+                        finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+                        if finish_reason == "length":
+                            yield self._length_stop("")
+
+                    if answer:
+                        yield total_tokens
+                        return
 
+                    for tool_call in final_tool_calls.values():
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
                             tool_response = self.toolcall_session.tool_call(name, args)
                             history.append(
                                 {
@@ -313,26 +305,16 @@ class Base(ABC):
                                 }
                             )
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-                            final_tool_calls = {}
-                            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-                            continue
-                    if finish_reason == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, total_tokens
-                    if finish_reason == "stop":
-                        finish_completion = True
-                        yield ans
-                        break
-                    yield ans
-                    continue
+                        except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        yield total_tokens
+                        return
-
-        except openai.APIError as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
         yield total_tokens
+        assert False, "Shouldn't be here."
 
     def chat_streamly(self, system, history, gen_conf):
         if system:
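The streaming path keys `final_tool_calls` by `tool_call.index` because one tool call arrives as many deltas: the first chunk for an index carries the id and function name, and later chunks append fragments of the `arguments` string. A small reconstruction of that merge, with plain namespaces in place of SDK objects:

```python
from types import SimpleNamespace as NS

# Simulated stream: the first delta names the function; the rest carry
# argument fragments for the same index.
deltas = [
    NS(index=0, id="call_1", function=NS(name="get_weather", arguments='{"ci')),
    NS(index=0, id=None, function=NS(name=None, arguments='ty": "Par')),
    NS(index=0, id=None, function=NS(name=None, arguments='is"}')),
]

final_tool_calls = {}
for tool_call in deltas:
    index = tool_call.index
    if index not in final_tool_calls:
        final_tool_calls[index] = tool_call    # first chunk: keep id and name
    else:
        final_tool_calls[index].function.arguments += tool_call.function.arguments

tc = final_tool_calls[0]
print(tc.function.name, tc.function.arguments)  # get_weather {"city": "Paris"}
```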
@@ -636,49 +618,21 @@ class QWenChat(Base):
             return "".join(result_list[:-1]), result_list[-1]
 
     def _chat(self, history, gen_conf):
-        tk_count = 0
         if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            try:
-                response = super()._chat(history, gen_conf)
-                return response
-            except Exception as e:
-                error_msg = str(e).lower()
-                if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                    return self._simulate_one_shot_from_stream(history, gen_conf)
+            return super()._chat(history, gen_conf)
+        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
+        ans = ""
+        tk_count = 0
+        if response.status_code == HTTPStatus.OK:
+            ans += response.output.choices[0]["message"]["content"]
+            tk_count += self.total_token_count(response)
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                if is_chinese([ans]):
+                    ans += LENGTH_NOTIFICATION_CN
+                else:
-                return "**ERROR**: " + str(e), tk_count
-
-        try:
-            ans = ""
-            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-            if response.status_code == HTTPStatus.OK:
-                ans += response.output.choices[0]["message"]["content"]
-                tk_count += self.total_token_count(response)
-                if response.output.choices[0].get("finish_reason", "") == "length":
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-                return "**ERROR**: " + response.message, tk_count
-        except Exception as e:
-            error_msg = str(e).lower()
-            if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                return self._simulate_one_shot_from_stream(history, gen_conf)
-            else:
-                return "**ERROR**: " + str(e), tk_count
-
-    def _simulate_one_shot_from_stream(self, history, gen_conf):
-        """
-        Handles models that require streaming output but need one-shot response.
-        """
-        g = self._chat_streamly("", history, gen_conf, incremental_output=True)
-        result_list = list(g)
-        error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
-        if len(error_msg_list) > 0:
-            return "**ERROR**: " + "".join(error_msg_list), 0
-        else:
-            return "".join(result_list[:-1]), result_list[-1]
+                    ans += LENGTH_NOTIFICATION_EN
+            return ans, tk_count
+        return "**ERROR**: " + response.message, tk_count
 
     def _wrap_toolcall_message(self, old_message, message):
         if not old_message:
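After the refactor, `QWenChat._chat` routes reasoning and vision models through the OpenAI-compatible base class and keeps the DashScope path only for the classic Qwen models, checking `status_code` before touching the payload. A trimmed sketch of that remaining branch, assuming a configured DashScope API key (model name and prompt are placeholders):

```python
from http import HTTPStatus

from dashscope import Generation

history = [{"role": "user", "content": "Say hi"}]
response = Generation.call("qwen-turbo", messages=history, result_format="message")

if response.status_code == HTTPStatus.OK:
    print(response.output.choices[0]["message"]["content"])
else:
    print("**ERROR**: " + response.message)  # non-OK responses carry an error message
```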
@@ -971,10 +925,10 @@ class LocalAIChat(Base):
 
 
 class LocalLLM(Base):
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
         from jina import Client
 
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
@@ -1031,7 +985,13 @@ class VolcEngineChat(Base):
 
 
 class MiniMaxChat(Base):
-    def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+        **kwargs
+    ):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
 
         if not base_url:
@@ -1263,7 +1223,6 @@ class GeminiChat(Base):
 
     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
-
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
-        super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
+        super().__init__(key, model_name, base_url, **kwargs)

rag/prompts.py (+1 −1)

@@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
         cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
-        cnt += ck["content_with_weight"]
+        cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
         doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
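The `kb_prompt` change scrubs inline `style` attributes, the doctype, and top-level HTML scaffolding out of chunk text before it is packed into the prompt, trimming tokens without touching the visible content (inner tags such as `<p>` survive). A quick check of what the substitution removes, on a made-up chunk:

```python
import re

content = '<!DOCTYPE html><html><body><p style="color:red">Price: $5</p></body></html>'
cleaned = re.sub(
    r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)",
    " ",
    content,
    flags=re.DOTALL | re.IGNORECASE,
)
print(repr(cleaned))  # '   <p >Price: $5</p>  '
```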

