#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC

import openai
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from ollama import Client
from volcengine.maas.v2 import MaasService

from rag.nlp import is_english
from rag.utils import num_tokens_from_string
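

# Every provider wrapper below exposes the same interface: chat() returns a
# (answer, total_token_count) pair, while chat_streamly() is a generator that
# yields the accumulating answer text and, as its final item, the token count.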
class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        # Xinference's OpenAI-compatible endpoint ignores the API key, but the
        # OpenAI client requires a non-empty one, so pass a placeholder.
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if resp.choices[0].finish_reason == "stop":
                    if not resp.choices[0].delta.content:
                        continue
                    total_tokens = resp.usage.get('total_tokens', 0)
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 \
                        else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # The ZhipuAI API does not accept OpenAI-style penalty parameters.
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the OpenAI-style generation config onto Ollama's options.
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            # keep_alive=-1 keeps the model loaded in memory indefinitely.
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    # The final chunk carries the token counts.
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0


class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Retry up to three times, reconnecting after each failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield token_count


class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        To avoid modifying the original database fields, and because VolcEngine's
        authentication scheme is unusual, the ak, sk and ep_id are assembled into
        api_key, stored as a dictionary string, and parsed here before use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        auth = eval(key)  # key holds a dict literal; parse it once instead of per field
        self.volc_ak = auth.get('volc_ak', '')
        self.volc_sk = auth.get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = auth.get('ep_id', '')

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class MiniMaxChat(Base):
    def __init__(self, key, model_name="abab6.5s-chat",
                 base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        super().__init__(key, model_name, base_url)


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        # Drop parameters the Mistral client does not accept.
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            # The Mistral client raises its own exceptions, not openai.APIError.
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
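

# A minimal, hedged usage sketch of the shared interface above. It assumes a
# valid OpenAI API key in the OPENAI_API_KEY environment variable; the model
# name, prompts and generation settings below are illustrative only.
if __name__ == "__main__":
    import os

    mdl = GptTurbo(os.environ.get("OPENAI_API_KEY", ""), model_name="gpt-3.5-turbo")

    # Non-streaming call: returns the answer and the total token count.
    answer, tokens = mdl.chat(
        "You are a helpful assistant.",
        [{"role": "user", "content": "Say hello in one sentence."}],
        {"temperature": 0.7, "max_tokens": 64})
    print(answer, tokens)

    # Streaming call: yields growing partial answers, then the token count last.
    for chunk in mdl.chat_streamly(
            "You are a helpful assistant.",
            [{"role": "user", "content": "Say hello in one sentence."}],
            {"temperature": 0.7}):
        print(chunk)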