@@ -24,6 +24,7 @@ from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
+from rag.utils import num_tokens_from_string
 from groq import Groq
 import os
 import json
 import requests
@@ -60,9 +61,16 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices:continue
                 ans += resp.choices[0].delta.content
-                total_tokens += 1
+                total_tokens = (
+                    (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                    if not hasattr(resp, "usage")
+                    else resp.usage["total_tokens"]
+                )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
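The added accounting prefers the provider-reported `usage` total and only falls back to `num_tokens_from_string` when a chunk carries no usage information. A minimal standalone sketch of that rule (the function and parameter names here are illustrative, not part of this file):

```python
# Illustrative sketch of the accounting rule above, not code from this change.
def accumulate_tokens(total_tokens, delta_text, usage_total=None,
                      count_tokens=lambda s: max(1, len(s) // 4)):
    """Prefer the provider-reported running total; otherwise add a local estimate."""
    if usage_total is not None:
        return usage_total                              # authoritative value from the API
    return total_tokens + count_tokens(delta_text)      # fallback estimate for this chunk

total = 0
total = accumulate_tokens(total, "Hello, ")             # estimated locally
total = accumulate_tokens(total, "world!")              # estimated locally
total = accumulate_tokens(total, "", usage_total=42)    # provider total wins
assert total == 42
```

One caveat: the removed guard also skipped chunks whose `delta.content` was empty or `None`, while the new `if not resp.choices:continue` does not, so a stop chunk with `content=None` could still raise on the `ans +=` line.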
@@ -85,8 +93,13 @@ class MoonshotChat(Base):
+        if not base_url: base_url="https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)


 class XinferenceChat(Base):
     def __init__(self, key=None, model_name="", base_url=""):
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
         key = "xxx"
         super().__init__(key, model_name, base_url)
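These constructors now only normalize the key and `base_url` before deferring to the shared constructor. As a rough sketch of the pattern (the real `Base(ABC)` is defined earlier in this file and may differ in detail; the class names below are placeholders):

```python
# Placeholder names; the real Base(ABC) lives earlier in chat_model.py.
import os
from abc import ABC
from openai import OpenAI

class SketchBase(ABC):
    def __init__(self, key, model_name, base_url):
        # One OpenAI-compatible client per provider; subclasses only vary key/URL handling.
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

class SketchXinference(SketchBase):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")  # normalize to the /v1 root
        super().__init__(key or "xxx", model_name, base_url)  # local servers ignore the key
```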
@@ -349,79 +362,13 @@ class OllamaChat(Base):

 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
-        if base_url[-1] == "/":
-            base_url = base_url[:-1]
-        self.base_url = base_url + "/v1/chat/completions"
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]

-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        headers = {
-            "Content-Type": "application/json",
-        }
-        payload = json.dumps(
-            {"model": self.model_name, "messages": history, **gen_conf}
-        )
-        try:
-            response = requests.request(
-                "POST", url=self.base_url, headers=headers, data=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            if response["choices"][0]["finish_reason"] == "length":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0

-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        total_tokens = 0
-        try:
-            headers = {
-                "Content-Type": "application/json",
-            }
-            payload = json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": history,
-                    "stream": True,
-                    **gen_conf,
-                }
-            )
-            response = requests.request(
-                "POST",
-                url=self.base_url,
-                headers=headers,
-                data=payload,
-            )
-            for resp in response.content.decode("utf-8").split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "delta" in resp["choices"][0]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                total_tokens += 1
-                yield ans

-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)

-        yield total_tokens

 class LocalLLM(Base):
     class RPCProxy:
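With the hand-rolled `requests` implementation removed, `LocalAIChat` relies on the `chat`/`chat_streamly` inherited from `Base`. A hypothetical usage sketch; the endpoint and model name are placeholders, and `gen_conf` uses the keys these classes whitelist:

```python
# Hypothetical usage; "http://localhost:8080" and the model name are placeholders.
llm = LocalAIChat(key="empty",
                  model_name="gpt-3.5-turbo___LocalAI",  # the part before "___" is sent to the server
                  base_url="http://localhost:8080")
gen_conf = {"temperature": 0.3, "top_p": 0.9, "max_tokens": 512}
history = [{"role": "user", "content": "Say hi."}]

# Blocking call: returns (answer, total_tokens).
answer, used_tokens = llm.chat("You are a helpful assistant.", history, gen_conf)

# Streaming call: yields the growing answer, then the token count last.
for piece in llm.chat_streamly("You are a helpful assistant.", history, gen_conf):
    print(piece)
```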
@@ -892,9 +839,10 @@ class GroqChat:
 ## openrouter
 class OpenRouterChat(Base):
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
-        self.base_url = "https://openrouter.ai/api/v1"
-        self.client = OpenAI(base_url=self.base_url, api_key=key)
-        self.model_name = model_name
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        super().__init__(key, model_name, base_url)


 class StepFunChat(Base):
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
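The default argument and the `if not base_url` guard coexist because callers may pass an explicit empty string; either way the client ends up on the public endpoint. A small hypothetical check (the key is a placeholder):

```python
# Hypothetical; both instances are configured against https://openrouter.ai/api/v1.
a = OpenRouterChat(key="<api-key>", model_name="openai/gpt-4o")
b = OpenRouterChat(key="<api-key>", model_name="openai/gpt-4o", base_url="")
```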
@@ -904,87 +852,17 @@ class StepFunChat(Base):

 class NvidiaChat(Base):
-    def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
-    ):
+    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
         if not base_url:
-            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
-        self.base_url = base_url
-        self.model_name = model_name
-        self.api_key = key
-        self.headers = {
-            "accept": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json",
-        }

-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        payload = {"model": self.model_name, "messages": history, **gen_conf}
-        try:
-            response = requests.post(
-                url=self.base_url, headers=self.headers, json=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0

-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        ans = ""
-        total_tokens = 0
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "stream": True,
-            **gen_conf,
-        }

-        try:
-            response = requests.post(
-                url=self.base_url,
-                headers=self.headers,
-                json=payload,
-            )
-            for resp in response.text.split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "content" in resp["choices"][0]["delta"]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                if "usage" in resp:
-                    total_tokens = resp["usage"]["total_tokens"]
-                yield ans

-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)

-        yield total_tokens
+            base_url = "https://integrate.api.nvidia.com/v1"
+        super().__init__(key, model_name, base_url)


 class LmStudioChat(Base):
     def __init__(self, key, model_name, base_url):
-        from os.path import join
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
-            self.base_url = join(base_url, "v1")
-        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
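For reference, the streaming loop removed from `NvidiaChat` parsed server-sent events by hand: split the body on blank lines, strip the leading `data: ` prefix (hence `resp[6:]`), and pick up `usage` when a chunk reports it. The refactor delegates all of this to the OpenAI SDK via `super().__init__`. A self-contained sketch of that parsing step, using a made-up payload:

```python
import json

# Made-up SSE body in the shape the removed code expected.
raw = (
    'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n'
    'data: {"choices": [{"delta": {"content": "lo"}}], "usage": {"total_tokens": 7}}\n\n'
    "data: [DONE]\n\n"
)

ans, total_tokens = "", 0
for event in raw.split("\n\n"):
    if "choices" not in event:        # skips the empty tail and "data: [DONE]"
        continue
    chunk = json.loads(event[6:])     # drop the leading "data: " prefix
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        ans += delta["content"]
    if "usage" in chunk:
        total_tokens = chunk["usage"]["total_tokens"]
print(ans, total_tokens)  # -> Hello 7
```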