|
|
|
@@ -24,7 +24,8 @@ from volcengine.maas.v2 import MaasService |
|
|
|
from rag.nlp import is_english |
|
|
|
from rag.utils import num_tokens_from_string |
|
|
|
from groq import Groq |
|
|
|
|
|
|
|
import json |
|
|
|
import requests |
|
|
|
|
|
|
|
class Base(ABC): |
|
|
|
def __init__(self, key, model_name, base_url): |
|
|
|
@@ -475,11 +476,83 @@ class VolcEngineChat(Base): |
|
|
|
|
|
|
|
|
|
|
|
class MiniMaxChat(Base):
    """Chat model that posts directly to a MiniMax chat-completion endpoint.

    Requests are issued with ``requests`` against ``base_url`` using a bearer
    token; both a blocking (`chat`) and an incremental (`chat_streamly`)
    interface are provided.
    """

    _DEFAULT_BASE_URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
    # Only these generation options are forwarded to the endpoint.
    _ALLOWED_GEN_KEYS = ("temperature", "top_p", "max_tokens")

    def __init__(
        self,
        key,
        model_name="abab6.5s-chat",
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        """Store the credentials and endpoint used by every request.

        Args:
            key: API key sent as the bearer token.
            model_name: Model identifier placed in the request payload.
            base_url: Full URL of the chat-completion endpoint; a falsy value
                (``None`` or ``""``) selects the default endpoint.
        """
        if not base_url:
            base_url = self._DEFAULT_BASE_URL
        # Base.__init__ is not invoked: the methods below only read the three
        # attributes set here.
        # NOTE(review): confirm nothing else relies on state that
        # Base.__init__ would have created.
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def _headers(self):
        """Return the HTTP headers shared by all requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _payload(self, system, history, gen_conf, stream=False):
        """Serialize the request body without touching the caller's objects.

        The system prompt (when given) is prepended to a *copy* of `history`,
        and `gen_conf` is filtered into a new dict, so repeated calls with the
        same arguments never accumulate system messages or lose options.
        """
        messages = list(history)
        if system:
            messages.insert(0, {"role": "system", "content": system})
        body = {"model": self.model_name, "messages": messages}
        if stream:
            body["stream"] = True
        body.update(
            {
                name: value
                for name, value in (gen_conf or {}).items()
                if name in self._ALLOWED_GEN_KEYS
            }
        )
        return json.dumps(body)

    @staticmethod
    def _describe_failure(body):
        """Return a readable reason for a response that carries no choices."""
        # NOTE(review): assumes the service reports failures under
        # ``base_resp.status_msg`` -- verify against the API reference.
        if isinstance(body, dict):
            base_resp = body.get("base_resp") or {}
            if isinstance(base_resp, dict) and base_resp.get("status_msg"):
                return (
                    f"{base_resp['status_msg']} "
                    f"(status_code={base_resp.get('status_code')})"
                )
        return f"Unexpected response without choices: {body}"

    def chat(self, system, history, gen_conf):
        """Send one blocking completion request.

        Args:
            system: Optional system prompt placed before `history`.
            history: List of ``{"role": ..., "content": ...}`` messages.
            gen_conf: Generation options; only ``temperature``, ``top_p`` and
                ``max_tokens`` are forwarded.

        Returns:
            ``(answer, total_tokens)``. On any failure the answer is a string
            starting with ``"**ERROR**: "`` and the token count is ``0``.
        """
        payload = self._payload(system, history, gen_conf)
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=self._headers(), data=payload
            )
            body = response.json()
            choices = body.get("choices") if isinstance(body, dict) else None
            if not choices:
                raise ValueError(self._describe_failure(body))
            choice = choices[0]
            ans = choice["message"]["content"].strip()
            if choice.get("finish_reason") == "length":
                # The model hit its length limit; tell the user in the
                # language of the answer.
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, body["usage"]["total_tokens"]
        except Exception as e:
            # Deliberate boundary: callers expect an error string, not a raise.
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Stream a completion, yielding the answer as it grows.

        Args:
            system: Optional system prompt placed before `history`.
            history: List of ``{"role": ..., "content": ...}`` messages.
            gen_conf: Generation options; filtered exactly as in `chat`.

        Yields:
            The accumulated answer text after each received delta, then, as
            the final item, an ``int`` token count estimated by applying
            ``num_tokens_from_string`` to the received text. If an error
            occurs, the text so far followed by ``"\\n**ERROR**: ..."`` is
            yielded before the token count.
        """
        ans = ""
        total_tokens = 0
        try:
            payload = self._payload(system, history, gen_conf, stream=True)
            # stream=True keeps requests from buffering the whole body, so
            # each event is handled as soon as it arrives.
            with requests.request(
                "POST",
                url=self.base_url,
                headers=self._headers(),
                data=payload,
                stream=True,
            ) as response:
                for raw in response.iter_lines():
                    if not raw:
                        continue  # blank separator between events
                    # Decode explicitly: the body is JSON, hence UTF-8.
                    line = raw.decode("utf-8")
                    if line.startswith("data:"):
                        data = line[5:].strip()
                    elif line.startswith("{"):
                        # A bare JSON body (not an event) -- treated below
                        # like any other event.
                        data = line.strip()
                    else:
                        continue
                    if not data or data == "[DONE]":
                        continue
                    event = json.loads(data)
                    choices = event.get("choices") if isinstance(event, dict) else None
                    if not choices:
                        raise ValueError(self._describe_failure(event))
                    delta = choices[0].get("delta")
                    if delta is None:
                        # Event without an incremental delta: nothing to append.
                        continue
                    text = delta.get("content") or ""
                    ans += text
                    total_tokens += num_tokens_from_string(text)
                    yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
|
|
|
|
|
|
|
|
|
|
|
class MistralChat(Base): |
|
|
|
@@ -748,4 +821,3 @@ class OpenRouterChat(Base): |
|
|
|
self.base_url = "https://openrouter.ai/api/v1" |
|
|
|
self.client = OpenAI(base_url=self.base_url, api_key=key) |
|
|
|
self.model_name = model_name |
|
|
|
|