
chat_model.py 38KB

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import json
import requests

class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
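
# The provider wrappers below follow the interface established by Base:
#   chat(system, history, gen_conf)          -> (answer_text, total_token_count)
#   chat_streamly(system, history, gen_conf) -> yields the growing answer string,
#                                               then a final token-count value.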

class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)

class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name

class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if resp.choices[0].finish_reason == "stop":
                    if not resp.choices[0].delta.content:
                        continue
                    total_tokens = resp.usage.get('total_tokens', 0)
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0

class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if base_url[-1] == "/":
            base_url = base_url[:-1]
        self.base_url = base_url + "/v1/chat/completions"
        self.model_name = model_name.split("___")[0]

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.content.decode("utf-8").split("\n\n"):
                if "choices" not in resp:
                    continue
                resp = json.loads(resp[6:])
                if "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                else:
                    continue
                ans += text
                total_tokens += 1
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield token_count

class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        Since we do not want to modify the original database fields, and the VolcEngine
        authentication method is quite special, ak, sk and ep_id are assembled into the
        api_key, stored as a dictionary string, and parsed here for use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        self.volc_ak = eval(key).get('volc_ak', '')
        self.volc_sk = eval(key).get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = eval(key).get('ep_id', '')

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

class MiniMaxChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                if "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                else:
                    continue
                ans += text
                total_tokens += num_tokens_from_string(text)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = eval(key).get('bedrock_ak', '')
        self.bedrock_sk = eval(key).get('bedrock_sk', '')
        self.bedrock_region = eval(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the streamed response text in real-time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)

class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield response._chunks[-1].usage_metadata.total_token_count

class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        self.base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(base_url=self.base_url, api_key=key)
        self.model_name = model_name


class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url)

class NvidiaChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
    ):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key
        self.headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        payload = {"model": self.model_name, "messages": history, **gen_conf}
        try:
            response = requests.post(
                url=self.base_url, headers=self.headers, json=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        payload = {
            "model": self.model_name,
            "messages": history,
            "stream": True,
            **gen_conf,
        }
        try:
            response = requests.post(
                url=self.base_url,
                headers=self.headers,
                json=payload,
            )
            for resp in response.text.split("\n\n"):
                if "choices" not in resp:
                    continue
                resp = json.loads(resp[6:])
                if "content" in resp["choices"][0]["delta"]:
                    text = resp["choices"][0]["delta"]["content"]
                else:
                    continue
                ans += text
                if "usage" in resp:
                    total_tokens = resp["usage"]["total_tokens"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class LmStudioChat(Base):
    def __init__(self, key, model_name, base_url):
        from os.path import join
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            self.base_url = join(base_url, "v1")
        else:
            self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
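
A minimal usage sketch, not part of chat_model.py: it assumes a valid OpenAI API key, relies only on the chat()/chat_streamly() conventions shown above, and the import path is inferred from the rag package layout; the key, model and prompt values are placeholders.

    from rag.llm.chat_model import GptTurbo

    mdl = GptTurbo(key="<OPENAI_API_KEY>", model_name="gpt-3.5-turbo")
    gen_conf = {"temperature": 0.3, "max_tokens": 512}
    history = [{"role": "user", "content": "Say hi."}]

    # Blocking call: returns the answer text and the total token count.
    answer, tokens = mdl.chat("You are a helpful assistant.", history, gen_conf)

    # Streaming call: yields partial answers, then the token count last.
    for chunk in mdl.chat_streamly("You are a helpful assistant.", history, gen_conf):
        if isinstance(chunk, str):
            print(chunk)
        else:
            tokens = chunk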