|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# limitations under the License. |
|
|
|
# |
|
|
|
import re |
|
|
|
import random |
|
|
|
|
|
|
|
from openai.lib.azure import AzureOpenAI |
|
|
|
from zhipuai import ZhipuAI |
|
|
|
@@ -28,6 +29,23 @@ import os |
|
|
|
import json |
|
|
|
import requests |
|
|
|
import asyncio |
|
|
|
import logging |
|
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
# Error message constants |
|
|
|
ERROR_PREFIX = "**ERROR**" |
|
|
|
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED" |
|
|
|
ERROR_AUTHENTICATION = "AUTH_ERROR" |
|
|
|
ERROR_INVALID_REQUEST = "INVALID_REQUEST" |
|
|
|
ERROR_SERVER = "SERVER_ERROR" |
|
|
|
ERROR_TIMEOUT = "TIMEOUT" |
|
|
|
ERROR_CONNECTION = "CONNECTION_ERROR" |
|
|
|
ERROR_MODEL = "MODEL_ERROR" |
|
|
|
ERROR_CONTENT_FILTER = "CONTENT_FILTERED" |
|
|
|
ERROR_QUOTA = "QUOTA_EXCEEDED" |
|
|
|
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED" |
|
|
|
ERROR_GENERIC = "GENERIC_ERROR" |
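# Each code above is embedded in the error string returned to the caller,
# formatted as f"{ERROR_PREFIX}: {error_code} - {detail}" (see chat() below).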
|
|
|
|
|
|
|
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。" |
|
|
|
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." |
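# Appended when a reply is cut off with finish_reason == "length"; the CN
# constant carries the same notice in Chinese.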
|
|
|
@@ -38,28 +56,78 @@ class Base(ABC): |
|
|
|
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600)) |
|
|
|
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) |
|
|
|
self.model_name = model_name |
|
|
|
# Configure retry parameters |
|
|
|
self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5)) |
|
|
|
self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0)) |
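# Defaults: up to 5 attempts with a 2.0 s base delay, both overridable via
# the LLM_MAX_RETRIES and LLM_BASE_DELAY environment variables.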
|
|
|
|
|
|
|
def _get_delay(self, attempt): |
|
|
|
"""Calculate retry delay time""" |
|
|
|
return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5) |
|
|
|
|
|
|
|
def _classify_error(self, error): |
|
|
|
"""Classify error based on error message content""" |
|
|
|
error_str = str(error).lower() |
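# Note: the order of these checks matters. Rate-limit wording such as
# "tpm limit" or "429" must match before the broader quota branch below.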
|
|
|
|
|
|
|
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str: |
|
|
|
return ERROR_RATE_LIMIT |
|
|
|
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str: |
|
|
|
return ERROR_AUTHENTICATION |
|
|
|
elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str: |
|
|
|
return ERROR_INVALID_REQUEST |
|
|
|
elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str: |
|
|
|
return ERROR_SERVER |
|
|
|
elif "timeout" in error_str or "timed out" in error_str: |
|
|
|
return ERROR_TIMEOUT |
|
|
|
elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str: |
|
|
|
return ERROR_CONNECTION |
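# A bare "limit" counts as a quota problem only when the message does not
# mention "rate"; genuine rate limits are caught by the first branch above.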
|
|
|
elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or "limit" in error_str and "rate" not in error_str: |
|
|
|
return ERROR_QUOTA |
|
|
|
elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str: |
|
|
|
return ERROR_CONTENT_FILTER |
|
|
|
elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str: |
|
|
|
return ERROR_MODEL |
|
|
|
else: |
|
|
|
return ERROR_GENERIC |
|
|
|
|
|
|
|
def chat(self, system, history, gen_conf): |
|
|
|
if system: |
|
|
|
history.insert(0, {"role": "system", "content": system}) |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
del gen_conf["max_tokens"] |
|
|
|
|
|
|
# Implement exponential backoff retry strategy |
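# Only rate-limit and server errors are retried; all other errors fail fast.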
|
|
|
for attempt in range(self.max_retries): |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create( |
|
|
|
model=self.model_name, |
|
|
|
messages=history, |
|
|
|
**gen_conf) |
|
|
|
|
|
|
|
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
|
|
return "", 0 |
|
|
|
ans = response.choices[0].message.content.strip() |
|
|
|
if response.choices[0].finish_reason == "length": |
|
|
|
if is_chinese(ans): |
|
|
|
ans += LENGTH_NOTIFICATION_CN |
|
|
|
else: |
|
|
|
ans += LENGTH_NOTIFICATION_EN |
|
|
|
return ans, self.total_token_count(response) |
|
|
|
except Exception as e: |
|
|
|
# Classify the error |
|
|
|
error_code = self._classify_error(e) |
|
|
|
|
|
|
|
# Check if it's a rate limit error or server error and not the last attempt |
|
|
|
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 |
|
|
|
|
|
|
|
if should_retry: |
|
|
|
delay = self._get_delay(attempt) |
|
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt+1}/{self.max_retries})") |
|
|
|
time.sleep(delay) |
|
|
|
else: |
|
|
|
|
|
|
# For non-rate limit errors or the last attempt, return an error message |
|
|
|
if attempt == self.max_retries - 1: |
|
|
|
error_code = ERROR_MAX_RETRIES |
|
|
|
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 |
|
|
|
|
|
|
|
def chat_streamly(self, system, history, gen_conf): |
|
|
|
if system: |