
chat_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC

import openai
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from rag.nlp import is_english
from rag.utils import num_tokens_from_string

class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement the chat method!")
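
# All subclasses implement the same contract: chat() prepends the optional
# system prompt to `history` (a list of {"role": ..., "content": ...} dicts),
# forwards the generation settings in `gen_conf` to the backing model, and
# returns a tuple of (answer text, completion-token count). On failure they
# return ("**ERROR**: <details>", 0) rather than raising.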

class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

class MoonshotChat(GptTurbo):
    # Moonshot exposes an OpenAI-compatible endpoint, so chat() is inherited
    # from GptTurbo unchanged.
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.output_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the generic gen_conf keys onto Ollama's sampling options;
            # note top_p (a float) must go to "top_p", not "top_k".
            options = {"temperature": gen_conf.get("temperature", 0.1),
                       "num_predict": gen_conf.get("max_tokens", 128),
                       "top_p": gen_conf.get("top_p", 0.3),
                       "presence_penalty": gen_conf.get("presence_penalty", 0.4),
                       "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
                       }
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

class LocalLLM(Base):
    class RPCProxy:
        """Thin pickle-over-socket proxy for a model served on a local port."""

        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Retry up to three times, reconnecting after each failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, *args, **kwargs):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0
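
# Minimal usage sketch, not part of the original module: it assumes an
# OPENAI_API_KEY environment variable and network access to the OpenAI
# endpoint; any of the OpenAI-compatible subclasses above works the same way.
if __name__ == "__main__":
    import os

    mdl = GptTurbo(key=os.environ["OPENAI_API_KEY"])
    # gen_conf is forwarded verbatim as keyword arguments to the completion call.
    answer, used_tokens = mdl.chat(
        system="You are a concise assistant.",
        history=[{"role": "user", "content": "Say hello in one sentence."}],
        gen_conf={"temperature": 0.1, "max_tokens": 64},
    )
    print(answer, used_tokens)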