
chat_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client

from rag.nlp import is_english
from rag.utils import num_tokens_from_string


class Base(ABC):
    """Wrapper around any OpenAI-compatible chat-completions endpoint."""

    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            # Tell the user when the answer was cut off by the length limit.
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        # Counts streamed chunks as a rough proxy for tokens; yielded last.
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url: base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url: base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url: base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)
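
# Illustrative helper, not part of the original module: every wrapper above and
# below exposes the same chat()/chat_streamly() contract, so a caller can swap
# providers freely. The system prompt and gen_conf values here are placeholders.
def _demo_chat(model: Base, question: str):
    """Run one blocking and one streaming round trip against any wrapper."""
    history = [{"role": "user", "content": question}]
    answer, tokens = model.chat("You are a helpful assistant.", list(history),
                                {"temperature": 0.7})
    final = ""
    for piece in model.chat_streamly(None, list(history), {"temperature": 0.7}):
        if isinstance(piece, str):
            final = piece      # the growing partial answer
        else:
            tokens = piece     # the generator yields a token/chunk count last
    return answer, final, tokens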
class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0  # yielded last, even when the request itself fails
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 \
                        else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # The ZhipuAI endpoint does not accept OpenAI-style penalty parameters.
            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0  # yielded last, even when the request itself fails
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the OpenAI-style generation config onto Ollama options.
            options = {}
            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options
            )
            for resp in response:
                if resp["done"]:
                    # The final chunk carries the prompt/completion token counts.
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0
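
# Usage sketch for the Ollama wrapper (an illustration added here, not original
# code; the host URL and model name are placeholder assumptions):
#
#   mdl = OllamaChat(None, "llama2", base_url="http://localhost:11434")
#   answer, tokens = mdl.chat(None, [{"role": "user", "content": "Hello"}],
#                             {"temperature": 0.1, "max_tokens": 128})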
class LocalLLM(Base):
    class RPCProxy:
        """Pickle-over-socket proxy to a locally served model."""

        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Retry a few times, reconnecting after every failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield token_count
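

# Minimal smoke test, added as a hedged sketch rather than original code. It
# assumes OPENAI_API_KEY is set in the environment and that the default OpenAI
# endpoint is reachable; any other wrapper above can be substituted the same way.
if __name__ == "__main__":
    import os

    mdl = GptTurbo(os.environ.get("OPENAI_API_KEY", ""))
    msgs = [{"role": "user", "content": "Say hello in one short sentence."}]

    answer, tokens = mdl.chat("You are a terse assistant.", list(msgs), {"temperature": 0.2})
    print(answer, f"({tokens} tokens)")

    # Streaming: every yield but the last is the partial answer so far.
    for piece in mdl.chat_streamly(None, list(msgs), {"temperature": 0.2}):
        if isinstance(piece, str):
            print(piece)
        else:
            print(f"streamed chunks counted: {piece}")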