|
|
|
@@ -59,6 +59,7 @@ class Base(ABC): |
|
|
|
# Configure retry parameters |
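# LLM_MAX_RETRIES sets the maximum number of attempts per request (default 5); LLM_BASE_DELAY is the base backoff delay in seconds (default 2.0)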
|
|
|
self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5)) |
|
|
|
self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0)) |
|
|
|
self.is_tools = False |
|
|
|
|
|
|
|
def _get_delay(self, attempt): |
|
|
|
"""Calculate retry delay time""" |
|
|
|
@@ -89,6 +90,91 @@ class Base(ABC): |
|
|
|
else: |
|
|
|
return ERROR_GENERIC |
|
|
|
|
|
|
|
def bind_tools(self, toolcall_session, tools): |
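"""Bind a tool-call session and tool schema list to this model; does nothing if either is missing."""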
|
|
|
if not (toolcall_session and tools): |
|
|
|
return |
|
|
|
self.is_tools = True |
|
|
|
self.toolcall_session = toolcall_session |
|
|
|
self.tools = tools |
|
|
|
|
|
|
|
def chat_with_tools(self, system: str, history: list, gen_conf: dict): |
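"""Run a tool-calling chat loop with exponential-backoff retries; returns (answer, token_count)."""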
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
|
tools = self.tools |
|
|
|
|
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
|
|
|
|
ans = "" |
|
|
|
tk_count = 0 |
|
|
|
# Implement exponential backoff retry strategy |
|
|
|
for attempt in range(self.max_retries): |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) |
|
|
|
|
|
|
|
assistant_output = response.choices[0].message |
|
|
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans += "<think>" + ans + "</think>" |
|
|
|
ans += response.choices[0].message.content |
|
|
|
|
|
|
|
if not response.choices[0].message.tool_calls: |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if response.choices[0].finish_reason == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
history.append(assistant_output) |
|
|
|
|
|
|
|
for tool_call in response.choices[0].message.tool_calls: |
|
|
|
name = tool_call.function.name |
|
|
|
args = json.loads(tool_call.function.arguments) |
|
|
|
|
|
|
|
tool_response = self.toolcall_session.tool_call(name, args) |
|
|
|
# if tool_response.choices[0].finish_reason == "length": |
|
|
|
# if is_chinese(ans): |
|
|
|
# ans += LENGTH_NOTIFICATION_CN |
|
|
|
# else: |
|
|
|
# ans += LENGTH_NOTIFICATION_EN |
|
|
|
# return ans, tk_count + self.total_token_count(tool_response) |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) |
|
|
|
|
|
|
|
final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) |
|
|
|
assistant_output = final_response.choices[0].message |
|
|
|
if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans += "<think>" + ans + "</think>" |
|
|
|
ans += final_response.choices[0].message.content |
|
|
|
if final_response.choices[0].finish_reason == "length": |
|
|
|
tk_count += self.total_token_count(final_response)
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logging.exception("OpenAI cat_with_tools") |
|
|
|
# Classify the error |
|
|
|
error_code = self._classify_error(e) |
|
|
|
|
|
|
|
# Check if it's a rate limit error or server error and not the last attempt |
|
|
|
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 |
|
|
|
|
|
|
|
if should_retry: |
|
|
|
delay = self._get_delay(attempt) |
|
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") |
|
|
|
time.sleep(delay) |
|
|
|
else: |
|
|
|
# For non-rate limit errors or the last attempt, return an error message |
|
|
|
if attempt == self.max_retries - 1: |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 |
|
|
|
|
|
|
|
def chat(self, system, history, gen_conf): |
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
@@ -127,6 +213,127 @@ class Base(ABC): |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0 |
|
|
|
|
|
|
|
def _wrap_toolcall_message(self, stream): |
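"""Accumulate streamed tool-call deltas into complete tool calls keyed by index."""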
|
|
|
final_tool_calls = {} |
|
|
|
|
|
|
|
for chunk in stream: |
|
|
|
for tool_call in chunk.choices[0].delta.tool_calls or []: |
|
|
|
index = tool_call.index |
|
|
|
|
|
|
|
if index not in final_tool_calls: |
|
|
|
final_tool_calls[index] = tool_call |
|
|
|
|
|
|
|
final_tool_calls[index].function.arguments += tool_call.function.arguments |
|
|
|
|
|
|
|
return final_tool_calls |
|
|
|
|
|
|
|
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): |
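"""Stream a tool-calling chat: yields incremental answer text, then the total token count as the final item."""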
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
|
tools = self.tools |
|
|
|
|
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
|
|
|
|
ans = "" |
|
|
|
total_tokens = 0 |
|
|
|
reasoning_start = False |
|
|
|
finish_completion = False |
|
|
|
final_tool_calls = {} |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) |
|
|
|
while not finish_completion: |
|
|
|
for resp in response: |
|
|
|
if resp.choices[0].delta.tool_calls: |
|
|
|
for tool_call in resp.choices[0].delta.tool_calls or []: |
|
|
|
index = tool_call.index |
|
|
|
|
|
|
|
if index not in final_tool_calls: |
|
|
|
final_tool_calls[index] = tool_call |
|
|
|
|
|
|
|
final_tool_calls[index].function.arguments += tool_call.function.arguments |
|
|
|
if resp.choices[0].finish_reason != "stop": |
|
|
|
continue |
|
|
|
else: |
|
|
|
if not resp.choices: |
|
|
|
continue |
|
|
|
if not resp.choices[0].delta.content: |
|
|
|
resp.choices[0].delta.content = "" |
|
|
|
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: |
|
|
|
ans = "" |
|
|
|
if not reasoning_start: |
|
|
|
reasoning_start = True |
|
|
|
ans = "<think>" |
|
|
|
ans += resp.choices[0].delta.reasoning_content + "</think>" |
|
|
|
else: |
|
|
|
reasoning_start = False |
|
|
|
ans = resp.choices[0].delta.content |
|
|
|
|
|
|
|
tol = self.total_token_count(resp) |
|
|
|
if not tol: |
|
|
|
total_tokens += num_tokens_from_string(resp.choices[0].delta.content) |
|
|
|
else: |
|
|
|
total_tokens += tol |
|
|
|
|
|
|
|
finish_reason = resp.choices[0].finish_reason |
|
|
|
if finish_reason == "tool_calls" and final_tool_calls: |
|
|
|
for tool_call in final_tool_calls.values(): |
|
|
|
name = tool_call.function.name |
|
|
|
try: |
|
|
|
if name == "get_current_weather": |
|
|
|
args = json.loads('{"location":"Shanghai"}') |
|
|
|
else: |
|
|
|
args = json.loads(tool_call.function.arguments) |
|
|
|
except Exception: |
|
|
|
continue |
|
|
|
|
|
|
tool_response = self.toolcall_session.tool_call(name, args) |
|
|
|
history.append( |
|
|
|
{ |
|
|
|
"role": "assistant", |
|
|
|
"refusal": "", |
|
|
|
"content": "", |
|
|
|
"audio": "", |
|
|
|
"function_call": "", |
|
|
|
"tool_calls": [ |
|
|
|
{ |
|
|
|
"index": tool_call.index, |
|
|
|
"id": tool_call.id, |
|
|
|
"function": tool_call.function, |
|
|
|
"type": "function", |
|
|
|
}, |
|
|
|
], |
|
|
|
} |
|
|
|
) |
|
|
|
# if tool_response.choices[0].finish_reason == "length": |
|
|
|
# if is_chinese(ans): |
|
|
|
# ans += LENGTH_NOTIFICATION_CN |
|
|
|
# else: |
|
|
|
# ans += LENGTH_NOTIFICATION_EN |
|
|
|
# return ans, total_tokens + self.total_token_count(tool_response) |
|
|
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) |
|
|
|
final_tool_calls = {} |
|
|
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) |
|
|
|
continue |
|
|
|
if finish_reason == "length": |
|
|
|
if is_chinese(ans): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, total_tokens + self.total_token_count(resp) |
|
|
|
if finish_reason == "stop": |
|
|
|
finish_completion = True |
|
|
|
yield ans |
|
|
|
break |
|
|
|
yield ans |
|
|
|
continue |
|
|
|
|
|
|
|
except openai.APIError as e: |
|
|
|
yield ans + "\n**ERROR**: " + str(e) |
|
|
|
|
|
|
|
yield total_tokens |
|
|
|
|
|
|
|
def chat_streamly(self, system, history, gen_conf): |
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
@@ -156,7 +363,7 @@ class Base(ABC): |
|
|
|
if not tol: |
|
|
|
total_tokens += num_tokens_from_string(resp.choices[0].delta.content) |
|
|
|
else: |
|
|
|
total_tokens += tol
|
|
|
|
|
|
|
if resp.choices[0].finish_reason == "length": |
|
|
|
if is_chinese(ans): |
|
|
|
@@ -180,9 +387,10 @@ class Base(ABC): |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
return 0 |
|
|
|
|
|
|
|
|
|
|
|
def _calculate_dynamic_ctx(self, history): |
|
|
|
"""Calculate dynamic context window size""" |
|
|
|
|
|
|
|
def count_tokens(text): |
|
|
|
"""Calculate token count for text""" |
|
|
|
# Simple calculation: 1 token per ASCII character |
|
|
|
@@ -207,15 +415,16 @@ class Base(ABC): |
|
|
|
|
|
|
|
# Apply 1.2x buffer ratio |
|
|
|
total_tokens_with_buffer = int(total_tokens * 1.2) |
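# Round the buffered total up to the next multiple of 8192 tokens, with a floor of 8192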
|
|
|
|
|
|
|
|
|
|
|
if total_tokens_with_buffer <= 8192: |
|
|
|
ctx_size = 8192 |
|
|
|
else: |
|
|
|
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 |
|
|
|
ctx_size = ctx_multiplier * 8192 |
|
|
|
|
|
|
|
|
|
|
|
return ctx_size |
|
|
|
|
|
|
|
|
|
|
|
class GptTurbo(Base): |
|
|
|
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): |
|
|
|
if not base_url: |
|
|
|
@@ -350,6 +559,8 @@ class BaiChuanChat(Base): |
|
|
|
|
|
|
|
class QWenChat(Base): |
|
|
|
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
import dashscope |
|
|
|
|
|
|
|
dashscope.api_key = key |
|
|
|
@@ -357,6 +568,78 @@ class QWenChat(Base): |
|
|
|
if self.is_reasoning_model(self.model_name): |
|
|
|
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1") |
|
|
|
|
|
|
|
def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]: |
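"""Tool-calling chat via dashscope Generation.call; uses the streaming helper unless QWEN_CHAT_BY_STREAM is set to false."""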
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
# if self.is_reasoning_model(self.model_name): |
|
|
|
# return super().chat(system, history, gen_conf) |
|
|
|
|
|
|
|
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true" |
|
|
|
if not stream_flag: |
|
|
|
from http import HTTPStatus |
|
|
|
|
|
|
|
tools = self.tools |
|
|
|
|
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
|
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf) |
|
|
|
ans = "" |
|
|
|
tk_count = 0 |
|
|
|
if response.status_code == HTTPStatus.OK: |
|
|
|
assistant_output = response.output.choices[0].message |
|
|
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans += "<think>" + ans + "</think>" |
|
|
|
ans += response.output.choices[0].message.content |
|
|
|
|
|
|
|
if "tool_calls" not in assistant_output: |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if response.output.choices[0].get("finish_reason", "") == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
history.append(assistant_output) |
|
|
|
|
|
|
|
while "tool_calls" in assistant_output: |
|
|
|
tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]} |
|
|
|
tool_name = assistant_output.tool_calls[0]["function"]["name"] |
|
|
|
if tool_name: |
|
|
|
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"]) |
|
|
|
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments) |
|
|
|
history.append(tool_info) |
|
|
|
|
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf) |
|
|
|
if response.output.choices[0].get("finish_reason", "") == "length": |
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, tk_count |
|
|
|
|
|
|
|
tk_count += self.total_token_count(response) |
|
|
|
assistant_output = response.output.choices[0].message |
|
|
|
if assistant_output.content is None: |
|
|
|
assistant_output.content = "" |
|
|
|
history.append(assistant_output)
|
|
|
ans += assistant_output["content"] |
|
|
|
return ans, tk_count |
|
|
|
else: |
|
|
|
return "**ERROR**: " + response.message, tk_count |
|
|
|
else: |
|
|
|
result_list = [] |
|
|
|
for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True): |
|
|
|
result_list.append(result) |
|
|
|
error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0] |
|
|
|
if len(error_msg_list) > 0: |
|
|
|
return "**ERROR**: " + "".join(error_msg_list), 0 |
|
|
|
else: |
|
|
|
return "".join(result_list[:-1]), result_list[-1] |
|
|
|
|
|
|
|
def chat(self, system, history, gen_conf): |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
@@ -393,6 +676,99 @@ class QWenChat(Base): |
|
|
|
else: |
|
|
|
return "".join(result_list[:-1]), result_list[-1] |
|
|
|
|
|
|
|
def _wrap_toolcall_message(self, old_message, message): |
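"""Merge an incremental tool-call chunk into the accumulated tool-call message."""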
|
|
|
if not old_message: |
|
|
|
return message |
|
|
|
tool_call_id = message["tool_calls"][0].get("id") |
|
|
|
if tool_call_id: |
|
|
|
old_message.tool_calls[0]["id"] = tool_call_id |
|
|
|
function = message.tool_calls[0]["function"] |
|
|
|
if function: |
|
|
|
if function.get("name"): |
|
|
|
old_message.tool_calls[0]["function"]["name"] = function["name"] |
|
|
|
if function.get("arguments"): |
|
|
|
old_message.tool_calls[0]["function"]["arguments"] += function["arguments"] |
|
|
|
return old_message |
|
|
|
|
|
|
|
def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True): |
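"""Stream a tool-calling chat through dashscope Generation.call, executing tools when finish_reason is "tool_calls"; yields answer chunks and finally the token count."""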
|
|
|
from http import HTTPStatus |
|
|
|
|
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
ans = "" |
|
|
|
tk_count = 0 |
|
|
|
try: |
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf) |
|
|
|
tool_info = {"content": "", "role": "tool"} |
|
|
|
toolcall_message = None |
|
|
|
tool_name = "" |
|
|
|
tool_arguments = "" |
|
|
|
finish_completion = False |
|
|
|
reasoning_start = False |
|
|
|
while not finish_completion: |
|
|
|
for resp in response: |
|
|
|
if resp.status_code == HTTPStatus.OK: |
|
|
|
assistant_output = resp.output.choices[0].message |
|
|
|
ans = resp.output.choices[0].message.content |
|
|
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: |
|
|
|
ans = resp.output.choices[0].message.reasoning_content |
|
|
|
if not reasoning_start: |
|
|
|
reasoning_start = True |
|
|
|
ans = "<think>" + ans |
|
|
|
else: |
|
|
|
ans = ans + "</think>" |
|
|
|
|
|
|
|
if "tool_calls" not in assistant_output: |
|
|
|
reasoning_start = False |
|
|
|
tk_count += self.total_token_count(resp) |
|
|
|
if resp.output.choices[0].get("finish_reason", "") == "length": |
|
|
|
if is_chinese([ans]): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
finish_reason = resp.output.choices[0]["finish_reason"] |
|
|
|
if finish_reason == "stop": |
|
|
|
finish_completion = True |
|
|
|
yield ans |
|
|
|
break |
|
|
|
yield ans |
|
|
|
continue |
|
|
|
|
|
|
|
tk_count += self.total_token_count(resp) |
|
|
|
toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output) |
|
|
|
if "tool_calls" in assistant_output: |
|
|
|
tool_call_finish_reason = resp.output.choices[0]["finish_reason"] |
|
|
|
if tool_call_finish_reason == "tool_calls": |
|
|
|
try: |
|
|
|
tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"]) |
|
|
|
except Exception as e: |
|
|
|
logging.exception(msg="_chat_streamly_with_tool tool call error") |
|
|
|
yield ans + "\n**ERROR**: " + str(e) |
|
|
|
finish_completion = True |
|
|
|
break |
|
|
|
|
|
|
|
tool_name = toolcall_message.tool_calls[0]["function"]["name"] |
|
|
|
history.append(toolcall_message) |
|
|
|
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments) |
|
|
|
history.append(tool_info) |
|
|
|
tool_info = {"content": "", "role": "tool"} |
|
|
|
tool_name = "" |
|
|
|
tool_arguments = "" |
|
|
|
toolcall_message = None |
|
|
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf) |
|
|
|
else: |
|
|
|
yield ( |
|
|
|
ans + "\n**ERROR**: " + resp.output.choices[0].message |
|
|
|
if not re.search(r" (key|quota)", str(resp.message).lower()) |
|
|
|
else "Out of credit. Please set the API key in **settings > Model providers.**" |
|
|
|
) |
|
|
|
except Exception as e: |
|
|
|
logging.exception(msg="_chat_streamly_with_tool") |
|
|
|
yield ans + "\n**ERROR**: " + str(e) |
|
|
|
yield tk_count |
|
|
|
|
|
|
|
def _chat_streamly(self, system, history, gen_conf, incremental_output=True): |
|
|
|
from http import HTTPStatus |
|
|
|
|
|
|
|
@@ -425,6 +801,13 @@ class QWenChat(Base): |
|
|
|
|
|
|
|
yield tk_count |
|
|
|
|
|
|
|
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True): |
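"""Drop max_tokens from gen_conf and delegate to _chat_streamly_with_tools."""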
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
|
for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output): |
|
|
|
yield txt |
|
|
|
|
|
|
|
def chat_streamly(self, system, history, gen_conf): |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
@@ -445,6 +828,8 @@ class QWenChat(Base): |
|
|
|
|
|
|
|
class ZhipuChat(Base): |
|
|
|
def __init__(self, key, model_name="glm-3-turbo", **kwargs): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
self.client = ZhipuAI(api_key=key) |
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
@@ -504,6 +889,8 @@ class ZhipuChat(Base): |
|
|
|
|
|
|
|
class OllamaChat(Base): |
|
|
|
def __init__(self, key, model_name, **kwargs): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"}) |
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
@@ -515,10 +902,8 @@ class OllamaChat(Base): |
|
|
|
try: |
|
|
|
# Calculate context size |
|
|
|
ctx_size = self._calculate_dynamic_ctx(history) |
|
|
|
|
|
|
|
options = {"num_ctx": ctx_size}
|
|
|
if "temperature" in gen_conf: |
|
|
|
options["temperature"] = gen_conf["temperature"] |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
@@ -545,9 +930,7 @@ class OllamaChat(Base): |
|
|
|
try: |
|
|
|
# Calculate context size |
|
|
|
ctx_size = self._calculate_dynamic_ctx(history) |
|
|
|
options = {"num_ctx": ctx_size}
|
|
|
if "temperature" in gen_conf: |
|
|
|
options["temperature"] = gen_conf["temperature"] |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
@@ -561,7 +944,7 @@ class OllamaChat(Base): |
|
|
|
|
|
|
|
ans = "" |
|
|
|
try: |
|
|
|
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
|
|
|
for resp in response: |
|
|
|
if resp["done"]: |
|
|
|
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) |
|
|
|
@@ -578,6 +961,8 @@ class OllamaChat(Base): |
|
|
|
|
|
|
|
class LocalAIChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
if not base_url: |
|
|
|
raise ValueError("Local llm url cannot be None") |
|
|
|
if base_url.split("/")[-1] != "v1": |
|
|
|
@@ -613,6 +998,8 @@ class LocalLLM(Base): |
|
|
|
return do_rpc |
|
|
|
|
|
|
|
def __init__(self, key, model_name): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from jina import Client |
|
|
|
|
|
|
|
self.client = Client(port=12345, protocol="grpc", asyncio=True) |
|
|
|
@@ -659,6 +1046,8 @@ class LocalLLM(Base): |
|
|
|
|
|
|
|
class VolcEngineChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
""" |
|
|
|
Because we do not want to modify the original database fields, and the VolcEngine authentication method is special,
we assemble ark_api_key and ep_id into api_key, store it as a dictionary, and parse it when used.
|
|
|
@@ -677,6 +1066,8 @@ class MiniMaxChat(Base): |
|
|
|
model_name, |
|
|
|
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", |
|
|
|
): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
if not base_url: |
|
|
|
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2" |
|
|
|
self.base_url = base_url |
|
|
|
@@ -755,6 +1146,8 @@ class MiniMaxChat(Base): |
|
|
|
|
|
|
|
class MistralChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from mistralai.client import MistralClient |
|
|
|
|
|
|
|
self.client = MistralClient(api_key=key) |
|
|
|
@@ -808,6 +1201,8 @@ class MistralChat(Base): |
|
|
|
|
|
|
|
class BedrockChat(Base): |
|
|
|
def __init__(self, key, model_name, **kwargs): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
import boto3 |
|
|
|
|
|
|
|
self.bedrock_ak = json.loads(key).get("bedrock_ak", "") |
|
|
|
@@ -887,6 +1282,8 @@ class BedrockChat(Base): |
|
|
|
|
|
|
|
class GeminiChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from google.generativeai import GenerativeModel, client |
|
|
|
|
|
|
|
client.configure(api_key=key) |
|
|
|
@@ -947,6 +1344,8 @@ class GeminiChat(Base): |
|
|
|
|
|
|
|
class GroqChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=""): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from groq import Groq |
|
|
|
|
|
|
|
self.client = Groq(api_key=key) |
|
|
|
@@ -1049,6 +1448,8 @@ class PPIOChat(Base): |
|
|
|
|
|
|
|
class CoHereChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=""): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from cohere import Client |
|
|
|
|
|
|
|
self.client = Client(api_key=key) |
|
|
|
@@ -1171,6 +1572,8 @@ class YiChat(Base): |
|
|
|
|
|
|
|
class ReplicateChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from replicate.client import Client |
|
|
|
|
|
|
|
self.model_name = model_name |
|
|
|
@@ -1218,6 +1621,8 @@ class ReplicateChat(Base): |
|
|
|
|
|
|
|
class HunyuanChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
from tencentcloud.common import credential |
|
|
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client |
|
|
|
|
|
|
|
@@ -1321,6 +1726,8 @@ class SparkChat(Base): |
|
|
|
|
|
|
|
class BaiduYiyanChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
import qianfan |
|
|
|
|
|
|
|
key = json.loads(key) |
|
|
|
@@ -1372,6 +1779,8 @@ class BaiduYiyanChat(Base): |
|
|
|
|
|
|
|
class AnthropicChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
import anthropic |
|
|
|
|
|
|
|
self.client = anthropic.Anthropic(api_key=key) |
|
|
|
@@ -1452,6 +1861,8 @@ class AnthropicChat(Base): |
|
|
|
|
|
|
|
class GoogleChat(Base): |
|
|
|
def __init__(self, key, model_name, base_url=None): |
|
|
|
super().__init__(key, model_name, base_url=None) |
|
|
|
|
|
|
|
import base64 |
|
|
|
|
|
|
|
from google.oauth2 import service_account |