|
|
|
@@ -24,16 +24,7 @@ from rag.utils import num_tokens_from_string |
|
|
|
|
|
|
|
|
|
|
|
class Base(ABC):
    """Common base for OpenAI-protocol chat model wrappers.

    Owns the `OpenAI` client and the model name; subclasses pass their
    provider-specific endpoint via `base_url` (several subclasses in this
    file call `super().__init__(key, model_name, base_url)`).
    """

    def __init__(self, key, model_name, base_url=None):
        """Create the API client.

        key: provider API key.
        model_name: chat model identifier.
        base_url: endpoint URL; None falls back to the OpenAI client's
            default endpoint (kept optional for backward compatibility
            with the previous two-argument signature).
        """
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        """Subclasses must implement the actual chat completion call."""
        # Fixed message: it previously said "encode method" (copy-paste
        # from an embedding base class); this is the chat method.
        raise NotImplementedError("Please implement chat method!")
|
|
|
|
|
|
|
|
|
|
|
class GptTurbo(Base):
    """Chat wrapper for OpenAI GPT models.

    NOTE(review): the original block declared ``__init__`` twice (merge
    residue — the first declaration had only the base_url fallback as its
    body); the two have been folded into a single constructor.
    """

    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        """Build the OpenAI client.

        key: OpenAI API key.
        model_name: chat model identifier.
        base_url: API endpoint; falsy values fall back to the public
            OpenAI endpoint.
        """
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
|
|
|
|
|
|
|
@@ -54,28 +45,28 @@ class GptTurbo(Base): |
|
|
|
return "**ERROR**: " + str(e), 0 |
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): merge residue — the header below has no body (a syntax
# error as written) and duplicates the MoonshotChat defined later in this
# file. Resolve the merge so the class is defined exactly once.
class MoonshotChat(GptTurbo):


# NOTE(review): duplicate of the GptTurbo defined earlier in this file;
# this appears to be the post-refactor version that delegates client
# construction to Base. Only one definition should survive.
class GptTurbo(Base):

    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        # Fall back to the public OpenAI endpoint when base_url is falsy.
        if not base_url: base_url="https://api.openai.com/v1"
        # NOTE(review): assumes Base.__init__ accepts
        # (key, model_name, base_url) and builds the client — confirm.
        super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
class MoonshotChat(Base):
    """Moonshot chat model wrapper (Moonshot serves an OpenAI-compatible API)."""

    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        """Build the client; a falsy base_url falls back to the public endpoint.

        NOTE(review): the original both constructed the client directly AND
        called super().__init__(key, model_name, base_url) — double
        initialization (and a TypeError against the two-argument Base
        signature visible in this file). The direct construction is kept.
        """
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        """Run one chat completion over `history`.

        system: optional system prompt; when set it is inserted IN PLACE at
            history[0] (the caller's list is mutated).
        history: list of {"role": ..., "content": ...} messages.
        gen_conf: extra keyword arguments forwarded to completions.create
            (temperature, max_tokens, ...).
        Returns (answer_text, total_tokens); on openai.APIError returns
            ("**ERROR**: <details>", 0) instead of raising.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                # Tell the user the reply was truncated, matching the
                # language of the answer.
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0
|
|
|
|
|
|
|
class XinferenceChat(Base):
    """Chat wrapper for a locally hosted Xinference server.

    NOTE(review): another XinferenceChat definition appears later in this
    file (merge residue); only one should survive.
    """

    def __init__(self, key=None, model_name="", base_url=""):
        # The server performs no authentication, so the caller-supplied key
        # is deliberately replaced with a placeholder.
        key = "xxx"
        # NOTE(review): assumes Base.__init__ accepts
        # (key, model_name, base_url) and builds the client — confirm.
        super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
class DeepSeekChat(Base):
    """DeepSeek chat model wrapper (OpenAI-compatible API)."""

    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        # Fall back to the public DeepSeek endpoint when base_url is falsy.
        if not base_url: base_url="https://api.deepseek.com/v1"
        # NOTE(review): assumes Base.__init__ accepts
        # (key, model_name, base_url) and builds the client — confirm.
        super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
class QWenChat(Base): |
|
|
|
@@ -157,25 +148,3 @@ class OllamaChat(Base): |
|
|
|
except Exception as e: |
|
|
|
return "**ERROR**: " + str(e), 0 |
|
|
|
|
|
|
|
|
|
|
|
class XinferenceChat(Base):
    """Chat client for an Xinference deployment, which speaks the OpenAI protocol."""

    def __init__(self, key=None, model_name="", base_url=""):
        # Xinference does no authentication; any placeholder key works.
        self.client = OpenAI(api_key="xxx", base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        """Send `history` (plus an optional system prompt, inserted in place
        at position 0) to the server and return (answer, total_tokens).
        API failures are reported as ("**ERROR**: ...", 0) rather than raised.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            top = completion.choices[0]
            text = top.message.content.strip()
            if top.finish_reason == "length":
                # Truncation notice, matching the answer's language.
                if is_english([text]):
                    text += "...\nFor the content length reason, it stopped, continue?"
                else:
                    text += "······\n由于长度的原因,回答被截断了,要继续吗?"
            return text, completion.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0
|
|
|
|