|
|
|
@@ -26,6 +26,7 @@ from http import HTTPStatus |
|
|
|
from typing import Any, Protocol |
|
|
|
from urllib.parse import urljoin |
|
|
|
|
|
|
|
import json_repair |
|
|
|
import openai |
|
|
|
import requests |
|
|
|
from dashscope import Generation |
|
|
|
@@ -67,11 +68,12 @@ class Base(ABC): |
|
|
|
# Configure retry parameters |
|
|
|
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) |
|
|
|
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) |
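# max_retries and base_delay come from kwargs when provided, otherwise from the LLM_MAX_RETRIES / LLM_BASE_DELAY environment variables.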
|
|
|
self.max_rounds = kwargs.get("max_rounds", 5) |
|
|
|
self.is_tools = False |
|
|
|
|
|
|
|
def _get_delay(self, attempt): |
|
|
|
def _get_delay(self): |
|
|
|
"""Calculate retry delay time""" |
|
|
|
return self.base_delay * (2**attempt) + random.uniform(0, 0.5) |
|
|
|
return self.base_delay + random.uniform(0, 0.5) |
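# The random jitter keeps concurrent callers from retrying in lockstep.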
|
|
|
|
|
|
|
def _classify_error(self, error): |
|
|
|
"""Classify error based on error message content""" |
|
|
|
@@ -116,6 +118,29 @@ class Base(ABC): |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, self.total_token_count(response) |
|
|
|
|
|
|
|
def _length_stop(self, ans): |
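"""Append the truncation notice that matches the answer's language."""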
|
|
|
if is_chinese([ans]): |
|
|
|
return ans + LENGTH_NOTIFICATION_CN |
|
|
|
return ans + LENGTH_NOTIFICATION_EN |
|
|
|
|
|
|
|
def _exceptions(self, e, attempt): |
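"""Classify the exception; sleep and return None when a retry is warranted, otherwise return an error string for the caller to surface."""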
|
|
|
logging.exception("OpenAI cat_with_tools") |
|
|
|
# Classify the error |
|
|
|
error_code = self._classify_error(e) |
|
|
|
|
|
|
|
# Check if it's a rate limit error or server error and not the last attempt |
|
|
|
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries |
|
|
|
|
|
|
|
if should_retry: |
|
|
|
delay = self._get_delay() |
|
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") |
|
|
|
time.sleep(delay) |
|
|
|
else: |
|
|
|
# For non-rate limit errors or the last attempt, return an error message |
|
|
|
if attempt == self.max_retries: |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}" |
|
|
|
|
|
|
|
def bind_tools(self, toolcall_session, tools): |
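"""Attach a tool-call session and tool schemas to this model; a no-op if either is missing."""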
|
|
|
if not (toolcall_session and tools): |
|
|
|
return |
|
|
|
@@ -124,76 +149,48 @@ class Base(ABC): |
|
|
|
self.tools = tools |
|
|
|
|
|
|
|
def chat_with_tools(self, system: str, history: list, gen_conf: dict): |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
|
tools = self.tools |
|
|
|
|
|
|
|
gen_conf = self._clean_conf() |
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
|
|
|
|
ans = "" |
|
|
|
tk_count = 0 |
|
|
|
hist = deepcopy(history) |
|
|
|
# Implement exponential backoff retry strategy |
|
|
|
for attempt in range(self.max_retries): |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) |
|
|
|
|
|
|
|
assistant_output = response.choices[0].message |
|
|
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans += "<think>" + ans + "</think>" |
|
|
|
ans += response.choices[0].message.content |
|
|
|
|
|
|
|
if not response.choices[0].message.tool_calls: |
|
|
|
for attempt in range(self.max_retries+1): |
|
|
|
history = hist |
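# Bound the tool-calling loop so a model that keeps requesting tools cannot loop forever.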
|
|
|
for _ in range(self.max_rounds*2): |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf) |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if response.choices[0].finish_reason == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
history.append(assistant_output) |
|
|
|
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
|
|
raise Exception("500 response structure error.") |
|
|
|
|
|
|
|
for tool_call in response.choices[0].message.tool_calls: |
|
|
|
name = tool_call.function.name |
|
|
|
args = json.loads(tool_call.function.arguments) |
|
|
|
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: |
|
|
|
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: |
|
|
|
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>" |
|
|
|
|
|
|
|
tool_response = self.toolcall_session.tool_call(name, args) |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) |
|
|
|
ans += response.choices[0].message.content |
|
|
|
if response.choices[0].finish_reason == "length": |
|
|
|
ans = self._length_stop(ans) |
|
|
|
|
|
|
|
final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) |
|
|
|
assistant_output = final_response.choices[0].message |
|
|
|
if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans += "<think>" + ans + "</think>" |
|
|
|
ans += final_response.choices[0].message.content |
|
|
|
if final_response.choices[0].finish_reason == "length": |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
return ans, tk_count |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logging.exception("OpenAI cat_with_tools") |
|
|
|
# Classify the error |
|
|
|
error_code = self._classify_error(e) |
|
|
|
for tool_call in response.choices[0].message.tool_calls: |
|
|
|
name = tool_call.function.name |
|
|
|
try: |
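# json_repair tolerates the slightly malformed JSON the model sometimes emits for tool arguments.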
|
|
|
args = json_repair.loads(tool_call.function.arguments) |
|
|
|
tool_response = self.toolcall_session.tool_call(name, args) |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) |
|
|
|
except Exception as e: |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) |
|
|
|
|
|
|
|
# Check if it's a rate limit error or server error and not the last attempt |
|
|
|
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 |
|
|
|
|
|
|
|
if should_retry: |
|
|
|
delay = self._get_delay(attempt) |
|
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") |
|
|
|
time.sleep(delay) |
|
|
|
else: |
|
|
|
# For non-rate limit errors or the last attempt, return an error message |
|
|
|
if attempt == self.max_retries - 1: |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 |
|
|
|
except Exception as e: |
|
|
|
e = self._exceptions(e, attempt) |
|
|
|
if e: |
|
|
|
return e, tk_count |
|
|
|
assert False, "Shouldn't be here." |
|
|
|
|
|
|
|
def chat(self, system, history, gen_conf): |
|
|
|
if system: |
|
|
|
@@ -201,26 +198,14 @@ class Base(ABC): |
|
|
|
gen_conf = self._clean_conf(gen_conf) |
|
|
|
|
|
|
|
# Implement exponential backoff retry strategy |
|
|
|
for attempt in range(self.max_retries): |
|
|
|
for attempt in range(self.max_retries+1): |
|
|
|
try: |
|
|
|
return self._chat(history, gen_conf) |
|
|
|
except Exception as e: |
|
|
|
logging.exception("chat_model.Base.chat got exception") |
|
|
|
# Classify the error |
|
|
|
error_code = self._classify_error(e) |
|
|
|
|
|
|
|
# Check if it's a rate limit error or server error and not the last attempt |
|
|
|
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 |
|
|
|
|
|
|
|
if should_retry: |
|
|
|
delay = self._get_delay(attempt) |
|
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") |
|
|
|
time.sleep(delay) |
|
|
|
else: |
|
|
|
# For non-rate limit errors or the last attempt, return an error message |
|
|
|
if attempt == self.max_retries - 1: |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 |
|
|
|
e = self._exceptions(e, attempt) |
|
|
|
if e: |
|
|
|
return e, 0 |
|
|
|
assert False, "Shouldn't be here." |
|
|
|
|
|
|
|
def _wrap_toolcall_message(self, stream): |
|
|
|
final_tool_calls = {} |
|
|
|
@@ -241,41 +226,48 @@ class Base(ABC): |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
|
tools = self.tools |
|
|
|
|
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
|
|
|
|
ans = "" |
|
|
|
total_tokens = 0 |
|
|
|
reasoning_start = False |
|
|
|
finish_completion = False |
|
|
|
final_tool_calls = {} |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) |
|
|
|
while not finish_completion: |
|
|
|
for resp in response: |
|
|
|
if resp.choices[0].delta.tool_calls: |
|
|
|
for tool_call in resp.choices[0].delta.tool_calls or []: |
|
|
|
index = tool_call.index |
|
|
|
|
|
|
|
if index not in final_tool_calls: |
|
|
|
final_tool_calls[index] = tool_call |
|
|
|
else: |
|
|
|
final_tool_calls[index].function.arguments += tool_call.function.arguments |
|
|
|
else: |
|
|
|
if not resp.choices: |
|
|
|
hist = deepcopy(history) |
|
|
|
# Implement exponential backoff retry strategy |
|
|
|
for attempt in range(self.max_retries+1): |
|
|
|
history = hist |
|
|
|
for _ in range(self.max_rounds*2): |
|
|
|
reasoning_start = False |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) |
|
|
|
final_tool_calls = {} |
|
|
|
answer = "" |
|
|
|
for resp in response: |
|
|
|
if resp.choices[0].delta.tool_calls: |
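# Tool-call arguments arrive in fragments while streaming; merge each delta into final_tool_calls by index.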
|
|
|
for tool_call in resp.choices[0].delta.tool_calls or []: |
|
|
|
index = tool_call.index |
|
|
|
|
|
|
|
if index not in final_tool_calls: |
|
|
|
final_tool_calls[index] = tool_call |
|
|
|
else: |
|
|
|
final_tool_calls[index].function.arguments += tool_call.function.arguments |
|
|
|
continue |
|
|
|
|
|
|
|
if not resp.choices or not resp.choices[0].delta or not hasattr(resp.choices[0].delta, "content"):
|
|
|
raise Exception("500 response structure error.") |
|
|
|
|
|
|
|
if not resp.choices[0].delta.content: |
|
|
|
resp.choices[0].delta.content = "" |
|
|
|
|
|
|
|
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: |
|
|
|
ans = "" |
|
|
|
if not reasoning_start: |
|
|
|
reasoning_start = True |
|
|
|
ans = "<think>" |
|
|
|
ans += resp.choices[0].delta.reasoning_content + "</think>" |
|
|
|
yield ans |
|
|
|
else: |
|
|
|
reasoning_start = False |
|
|
|
ans = resp.choices[0].delta.content |
|
|
|
answer += resp.choices[0].delta.content |
|
|
|
yield resp.choices[0].delta.content |
|
|
|
|
|
|
|
tol = self.total_token_count(resp) |
|
|
|
if not tol: |
|
|
|
@@ -283,18 +275,18 @@ class Base(ABC): |
|
|
|
else: |
|
|
|
total_tokens += tol |
|
|
|
|
|
|
|
finish_reason = resp.choices[0].finish_reason |
|
|
|
if finish_reason == "tool_calls" and final_tool_calls: |
|
|
|
for tool_call in final_tool_calls.values(): |
|
|
|
name = tool_call.function.name |
|
|
|
try: |
|
|
|
args = json.loads(tool_call.function.arguments) |
|
|
|
except Exception as e: |
|
|
|
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") |
|
|
|
yield ans + "\n**ERROR**: " + str(e) |
|
|
|
finish_completion = True |
|
|
|
break |
|
|
|
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" |
|
|
|
if finish_reason == "length": |
|
|
|
yield self._length_stop("") |
|
|
|
|
|
|
|
if answer: |
|
|
|
yield total_tokens |
|
|
|
return |
|
|
|
|
|
|
|
for tool_call in final_tool_calls.values(): |
|
|
|
name = tool_call.function.name |
|
|
|
try: |
|
|
|
args = json_repair.loads(tool_call.function.arguments) |
|
|
|
tool_response = self.toolcall_session.tool_call(name, args) |
|
|
|
history.append( |
|
|
|
{ |
|
|
|
@@ -313,26 +305,16 @@ class Base(ABC): |
|
|
|
} |
|
|
|
) |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) |
|
|
|
final_tool_calls = {} |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) |
|
|
|
continue |
|
|
|
if finish_reason == "length": |
|
|
|
if is_chinese([ans]):
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, total_tokens |
|
|
|
if finish_reason == "stop": |
|
|
|
finish_completion = True |
|
|
|
yield ans |
|
|
|
break |
|
|
|
yield ans |
|
|
|
continue |
|
|
|
except Exception as e: |
|
|
|
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) |
|
|
|
except Exception as e: |
|
|
|
e = self._exceptions(e, attempt) |
|
|
|
if e: |
|
|
|
yield total_tokens |
|
|
|
return |
|
|
|
|
|
|
|
except openai.APIError as e: |
|
|
|
yield ans + "\n**ERROR**: " + str(e) |
|
|
|
|
|
|
|
yield total_tokens |
|
|
|
assert False, "Shouldn't be here." |
|
|
|
|
|
|
|
def chat_streamly(self, system, history, gen_conf): |
|
|
|
if system: |
|
|
|
@@ -636,49 +618,21 @@ class QWenChat(Base): |
|
|
|
return "".join(result_list[:-1]), result_list[-1] |
|
|
|
|
|
|
|
def _chat(self, history, gen_conf): |
|
|
|
tk_count = 0 |
|
|
|
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: |
|
|
|
try: |
|
|
|
response = super()._chat(history, gen_conf) |
|
|
|
return response |
|
|
|
except Exception as e: |
|
|
|
error_msg = str(e).lower() |
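# Some models reject non-streaming calls; detect that error and assemble a one-shot reply from the stream instead.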
|
|
|
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg: |
|
|
|
return self._simulate_one_shot_from_stream(history, gen_conf) |
|
|
|
return super()._chat(history, gen_conf) |
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) |
|
|
|
ans = "" |
|
|
|
tk_count = 0 |
|
|
|
if response.status_code == HTTPStatus.OK: |
|
|
|
ans += response.output.choices[0]["message"]["content"] |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if response.output.choices[0].get("finish_reason", "") == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
return "**ERROR**: " + str(e), tk_count |
|
|
|
|
|
|
|
try: |
|
|
|
ans = "" |
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) |
|
|
|
if response.status_code == HTTPStatus.OK: |
|
|
|
ans += response.output.choices[0]["message"]["content"] |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if response.output.choices[0].get("finish_reason", "") == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
return "**ERROR**: " + response.message, tk_count |
|
|
|
except Exception as e: |
|
|
|
error_msg = str(e).lower() |
|
|
|
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg: |
|
|
|
return self._simulate_one_shot_from_stream(history, gen_conf) |
|
|
|
else: |
|
|
|
return "**ERROR**: " + str(e), tk_count |
|
|
|
|
|
|
|
def _simulate_one_shot_from_stream(self, history, gen_conf): |
|
|
|
""" |
|
|
|
Handles models that require streaming output but need one-shot response. |
|
|
|
""" |
|
|
|
g = self._chat_streamly("", history, gen_conf, incremental_output=True) |
|
|
|
result_list = list(g) |
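# The stream yields text chunks and ends with the total token count, so the final item is the usage figure.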
|
|
|
error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0] |
|
|
|
if len(error_msg_list) > 0: |
|
|
|
return "**ERROR**: " + "".join(error_msg_list), 0 |
|
|
|
else: |
|
|
|
return "".join(result_list[:-1]), result_list[-1] |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
return "**ERROR**: " + response.message, tk_count |
|
|
|
|
|
|
|
def _wrap_toolcall_message(self, old_message, message): |
|
|
|
if not old_message: |
|
|
|
@@ -971,10 +925,10 @@ class LocalAIChat(Base): |
|
|
|
|
|
|
|
|
|
|
|
class LocalLLM(Base): |
|
|
|
|
|
|
|
def __init__(self, key, model_name, base_url=None, **kwargs): |
|
|
|
super().__init__(key, model_name, base_url=base_url, **kwargs) |
|
|
|
from jina import Client |
|
|
|
|
|
|
|
self.client = Client(port=12345, protocol="grpc", asyncio=True) |
|
|
|
|
|
|
|
def _prepare_prompt(self, system, history, gen_conf): |
|
|
|
@@ -1031,7 +985,13 @@ class VolcEngineChat(Base): |
|
|
|
|
|
|
|
|
|
|
|
class MiniMaxChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
key, |
|
|
|
model_name, |
|
|
|
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
super().__init__(key, model_name, base_url=base_url, **kwargs) |
|
|
|
|
|
|
|
if not base_url: |
|
|
|
@@ -1263,7 +1223,6 @@ class GeminiChat(Base): |
|
|
|
|
|
|
|
def _chat(self, history, gen_conf): |
|
|
|
from google.generativeai.types import content_types |
|
|
|
|
|
|
|
system = history[0]["content"] if history and history[0]["role"] == "system" else "" |
|
|
|
hist = [] |
|
|
|
for item in history: |
|
|
|
@@ -1921,4 +1880,4 @@ class GPUStackChat(Base): |
|
|
|
if not base_url: |
|
|
|
raise ValueError("Local llm url cannot be None") |
|
|
|
base_url = urljoin(base_url, "v1") |
|
|
|
super().__init__(key, model_name, base_url, **kwargs) |
|
|
|