@@ -18,11 +18,9 @@ import json
 import logging
 import os
 import random
 import re
 import time
 from abc import ABC
 from copy import deepcopy
 from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
@@ -61,9 +59,6 @@ class ToolCallSession(Protocol):


 class Base(ABC):
     tools: list[Any]
     toolcall_sessions: dict[str, ToolCallSession]

     def __init__(self, key, model_name, base_url, **kwargs):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
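+        # LM_TIMEOUT_SECONDS overrides the per-request client timeout (default 600s).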
@@ -146,6 +141,37 @@ class Base(ABC):
             error_code = ERROR_MAX_RETRIES
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

+    def _verbose_tool_use(self, name, args, res):
+        return "<tool_call>" + json.dumps({
+            "name": name,
+            "args": args,
+            "result": res
+        }, ensure_ascii=False, indent=2) + "</tool_call>"
+
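+    # OpenAI-style tool calling requires that an assistant message carrying
+    # "tool_calls" is followed by one {"role": "tool"} message per call id;
+    # _append_history below appends both halves of that pair in one step.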
+    def _append_history(self, hist, tool_call, tool_res):
+        hist.append(
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "index": tool_call.index,
+                        "id": tool_call.id,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments,
+                        },
+                        "type": "function",
+                    },
+                ],
+            }
+        )
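+        # Dict results are serialized to JSON; the finally-clause guarantees a
+        # "tool" message is appended even if serialization raises.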
+        try:
+            if isinstance(tool_res, dict):
+                tool_res = json.dumps(tool_res, ensure_ascii=False)
+        finally:
+            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
+        return hist
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
@@ -160,18 +186,19 @@ class Base(ABC):
         if system:
             history.insert(0, {"role": "system", "content": system})

         gen_conf = self._clean_conf(gen_conf)
         ans = ""
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
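+        # The try block now sits inside the per-round loop, so a transient API
+        # failure retries that round (with backoff, presumably via _exceptions)
+        # instead of restarting the whole tool-call conversation.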
-        for attempt in range(self.max_retries+1):
-            try:
-                for _ in range(self.max_rounds*2):
+        for attempt in range(self.max_retries + 1):
+            history = hist
+            for _ in range(self.max_rounds * 2):
+                try:
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
-                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
-                        raise Exception("500 response structure error.")
+                    if any([not response.choices, not response.choices[0].message]):
+                        raise Exception(f"500 response structure error. Response: {response}")

                     if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                         if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@@ -188,14 +215,17 @@ class Base(ABC):
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
                             tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            history = self._append_history(history, tool_call, tool_response)
                             ans += self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
                             logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                             ans += self._verbose_tool_use(name, {}, str(e))
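+                # _exceptions() is expected to return an error string only once
+                # retries are exhausted (ERROR_MAX_RETRIES); a falsy result
+                # lets the loop retry the round.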
-            except Exception as e:
-                e = self._exceptions(e, attempt)
-                if e:
-                    return e, tk_count
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        return e, tk_count
+        assert False, "Shouldn't be here."

     def chat(self, system, history, gen_conf):
@@ -228,9 +258,7 @@ class Base(ABC):
         return final_tool_calls

     def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]

         gen_conf = self._clean_conf(gen_conf)
         tools = self.tools
         if system:
             history.insert(0, {"role": "system", "content": system})
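+        # Streaming contract: this generator yields str chunks (text deltas and
+        # <tool_call> traces) and finally the total token count as its last item.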
@@ -240,9 +268,9 @@ class Base(ABC):
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries + 1):
-            try:
-                for _ in range(self.max_rounds*2):
-                    reasoning_start = False
+            history = hist
+            for _ in range(self.max_rounds * 2):
+                reasoning_start = False
+                try:
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
                     answer = ""
@@ -252,9 +280,11 @@ class Base(ABC):
                         index = tool_call.index

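+                            # Tool-call deltas stream in fragments keyed by index:
+                            # the first fragment carries id/name, later fragments
+                            # only extend "arguments", and any fragment's
+                            # arguments may be None.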
                         if index not in final_tool_calls:
+                            if not tool_call.function.arguments:
+                                tool_call.function.arguments = ""
                             final_tool_calls[index] = tool_call
                         else:
-                            final_tool_calls[index].function.arguments += tool_call.function.arguments
+                            final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
                         continue

                     if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
@@ -293,40 +323,26 @@ class Base(ABC):
                         name = tool_call.function.name
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
-                            tool_response = self.toolcall_session[name].tool_call(name, args)
-                            history.append(
-                                {
-                                    "role": "assistant",
-                                    "tool_calls": [
-                                        {
-                                            "index": tool_call.index,
-                                            "id": tool_call.id,
-                                            "function": {
-                                                "name": tool_call.function.name,
-                                                "arguments": tool_call.function.arguments,
-                                            },
-                                            "type": "function",
-                                        },
-                                    ],
-                                }
-                            )
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
+                            history = self._append_history(history, tool_call, tool_response)
                             yield self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
                             logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                             yield self._verbose_tool_use(name, {}, str(e))
-            except Exception as e:
-                e = self._exceptions(e, attempt)
-                if e:
-                    yield total_tokens
-                    return
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        yield total_tokens
+                        return
-        yield total_tokens
+        assert False, "Shouldn't be here."

     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
         reasoning_start = False
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):

 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-
-        import dashscope
-
-        dashscope.api_key = key
-        self.model_name = model_name
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
-
-    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        # if self.is_reasoning_model(self.model_name):
-        #     return super().chat(system, history, gen_conf)
-
-        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
-        if not stream_flag:
-            from http import HTTPStatus
-
-            tools = self.tools
-
-            if system:
-                history.insert(0, {"role": "system", "content": system})
-
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
-            ans = ""
-            tk_count = 0
-            if response.status_code == HTTPStatus.OK:
-                assistant_output = response.output.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.output.choices[0].message.content
-
-                if "tool_calls" not in assistant_output:
-                    tk_count += self.total_token_count(response)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
-
-                while "tool_calls" in assistant_output:
-                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
-                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
-                    if tool_name:
-                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
-                        tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
-                        history.append(tool_info)
-
-                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        tk_count += self.total_token_count(response)
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, tk_count
-
-                    tk_count += self.total_token_count(response)
-                    assistant_output = response.output.choices[0].message
-                    if assistant_output.content is None:
-                        assistant_output.content = ""
-                    history.append(response)
-                    ans += assistant_output["content"]
-                return ans, tk_count
-            else:
-                return "**ERROR**: " + response.message, tk_count
-        else:
-            result_list = []
-            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
-                result_list.append(result)
-            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
-            if len(error_msg_list) > 0:
-                return "**ERROR**: " + "".join(error_msg_list), 0
-            else:
-                return "".join(result_list[:-1]), result_list[-1]
-
-    def _chat(self, history, gen_conf):
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super()._chat(history, gen_conf)
-        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]["message"]["content"]
-            tk_count += self.total_token_count(response)
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                if is_chinese([ans]):
-                    ans += LENGTH_NOTIFICATION_CN
-                else:
-                    ans += LENGTH_NOTIFICATION_EN
-            return ans, tk_count
-        return "**ERROR**: " + response.message, tk_count
-
-    def _wrap_toolcall_message(self, old_message, message):
-        if not old_message:
-            return message
-        tool_call_id = message["tool_calls"][0].get("id")
-        if tool_call_id:
-            old_message.tool_calls[0]["id"] = tool_call_id
-        function = message.tool_calls[0]["function"]
-        if function:
-            if function.get("name"):
-                old_message.tool_calls[0]["function"]["name"] = function["name"]
-            if function.get("arguments"):
-                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
-        return old_message
-
-    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-            tool_info = {"content": "", "role": "tool"}
-            toolcall_message = None
-            tool_name = ""
-            tool_arguments = ""
-            finish_completion = False
-            reasoning_start = False
-            while not finish_completion:
-                for resp in response:
-                    if resp.status_code == HTTPStatus.OK:
-                        assistant_output = resp.output.choices[0].message
-                        ans = resp.output.choices[0].message.content
-                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                            ans = resp.output.choices[0].message.reasoning_content
-                            if not reasoning_start:
-                                reasoning_start = True
-                                ans = "<think>" + ans
-                            else:
-                                ans = ans + "</think>"
-
-                        if "tool_calls" not in assistant_output:
-                            reasoning_start = False
-                            tk_count += self.total_token_count(resp)
-                            if resp.output.choices[0].get("finish_reason", "") == "length":
-                                if is_chinese([ans]):
-                                    ans += LENGTH_NOTIFICATION_CN
-                                else:
-                                    ans += LENGTH_NOTIFICATION_EN
-                            finish_reason = resp.output.choices[0]["finish_reason"]
-                            if finish_reason == "stop":
-                                finish_completion = True
-                                yield ans
-                                break
-                            yield ans
-                            continue
-
-                        tk_count += self.total_token_count(resp)
-                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
-                        if "tool_calls" in assistant_output:
-                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
-                            if tool_call_finish_reason == "tool_calls":
-                                try:
-                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
-                                except Exception as e:
-                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
-
-                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
-                                history.append(toolcall_message)
-                                tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
-                                history.append(tool_info)
-                                tool_info = {"content": "", "role": "tool"}
-                                tool_name = ""
-                                tool_arguments = ""
-                                toolcall_message = None
-                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-                    else:
-                        yield (
-                            ans + "\n**ERROR**: " + resp.output.choices[0].message
-                            if not re.search(r" (key|quota)", str(resp.message).lower())
-                            else "Out of credit. Please set the API key in **settings > Model providers.**"
-                        )
-        except Exception as e:
-            logging.exception(msg="_chat_streamly_with_tool")
-            yield ans + "\n**ERROR**: " + str(e)
-        yield tk_count
-
-    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
-            for resp in response:
-                if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]["message"]["content"]
-                    tk_count = self.total_token_count(resp)
-                    if resp.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    yield ans
-                else:
-                    yield (
-                        ans + "\n**ERROR**: " + resp.message
-                        if not re.search(r" (key|quota)", str(resp.message).lower())
-                        else "Out of credit. Please set the API key in **settings > Model providers.**"
-                    )
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
-            yield txt
-
-    def chat_streamly(self, system, history, gen_conf):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super().chat_streamly(system, history, gen_conf)
-
-        return self._chat_streamly(system, history, gen_conf)
-
-    @staticmethod
-    def is_reasoning_model(model_name: str) -> bool:
-        return any(
-            [
-                model_name.lower().find("deepseek") >= 0,
-                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
-            ]
-        )
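+        # DashScope's OpenAI-compatible endpoint lets QWenChat reuse the Base
+        # tool-calling, retry, and streaming logic in place of the SDK-specific
+        # implementation removed above; note the caller-supplied base_url is
+        # ignored in favor of the fixed compatible-mode URL.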
+        super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+        return


 class ZhipuChat(Base):
@@ -1877,4 +1649,4 @@ class GPUStackChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
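+        # Caveat: urljoin() drops the last path segment unless base_url ends
+        # with "/", e.g. urljoin("http://host/api", "v1") -> "http://host/v1".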
         base_url = urljoin(base_url, "v1")
         super().__init__(key, model_name, base_url, **kwargs)