
chat_model.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import json
import requests

class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                # Rough token accounting: one token per streamed chunk.
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
        # By convention, the final yield of chat_streamly() is the token count.
        yield total_tokens
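
# A minimal usage sketch for the interface Base defines (illustrative only:
# the API key and model name below are placeholders, and any OpenAI-compatible
# subclass such as GptTurbo works the same way):
#
#   mdl = GptTurbo("sk-...", model_name="gpt-3.5-turbo")
#   ans, tk_count = mdl.chat("You are a helpful assistant.",
#                            [{"role": "user", "content": "Hello!"}],
#                            {"temperature": 0.7})
#
#   # chat_streamly() is a generator: each yield is the answer accumulated so
#   # far (a str), and the final yield is the total token count (an int).
#   for chunk in mdl.chat_streamly(None, [{"role": "user", "content": "Hi"}], {}):
#       if isinstance(chunk, str):
#           print(chunk)
#       else:
#           tk_count = chunk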

class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        # Xinference does not validate the API key, so a placeholder is enough.
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)

class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name

class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                # Enable Baichuan's built-in web search tool.
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if resp.choices[0].finish_reason == "stop":
                    if not resp.choices[0].delta.content:
                        continue
                    total_tokens = resp.usage.get('total_tokens', 0)
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield (ans + "\n**ERROR**: " + resp.message) if str(resp.message).find("Access") < 0 \
                        else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # ZhipuAI does not accept the OpenAI-style penalty parameters.
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the OpenAI-style gen_conf onto Ollama's options
            # (e.g. max_tokens becomes num_predict).
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    # The final chunk carries the token counts.
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                    return
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0

class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Send a pickled (method, args, kwargs) tuple and wait for the
                # pickled return value; reconnect and retry up to three times.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield token_count
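
# A minimal sketch of the server side that RPCProxy assumes (hypothetical; the
# actual server lives elsewhere in the project). It must listen with the same
# authkey and answer each pickled (name, args, kwargs) tuple with a pickled
# return value; `handlers` here is a hypothetical dispatch dict mapping method
# names like "chat" to callables:
#
#   from multiprocessing.connection import Listener
#   import pickle
#
#   with Listener(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu") as srv:
#       conn = srv.accept()
#       while True:
#           name, args, kwargs = pickle.loads(conn.recv())
#           conn.send(pickle.dumps(handlers[name](*args, **kwargs)))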

class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        To avoid modifying the original database fields, and because the
        VolcEngine authentication scheme is unusual, the ak, sk and ep_id are
        assembled into api_key as a serialized dict and parsed here for use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        auth = json.loads(key)  # key is expected to be a JSON object; see the sketch below
        self.volc_ak = auth.get('volc_ak', '')
        self.volc_sk = auth.get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = auth.get('ep_id', '')
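
    # A hedged example of the assembled key described in the docstring above
    # (all field values are placeholders):
    #
    #   key = json.dumps({
    #       "volc_ak": "AK...",       # VolcEngine access key
    #       "volc_sk": "SK...",       # VolcEngine secret key
    #       "ep_id": "ep-2024...",    # endpoint id, used as the model name
    #   })
    #   mdl = VolcEngineChat(key, model_name="display-name", base_url="")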

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

class MiniMaxChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            # The SSE body is parsed after the request completes: each event
            # starts with a "data: " prefix, which resp[6:] strips off.
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                if "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                else:
                    continue
                ans += text
                total_tokens += num_tokens_from_string(text)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            # The Mistral client raises its own exceptions, not openai.APIError.
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        auth = json.loads(key)  # key is expected to be a JSON object; see the sketch below
        self.bedrock_ak = auth.get('bedrock_ak', '')
        self.bedrock_sk = auth.get('bedrock_sk', '')
        self.bedrock_region = auth.get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
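
    # A hedged example of the credential bundle expected above (all values are
    # placeholders):
    #
    #   key = json.dumps({
    #       "bedrock_ak": "AKIA...",        # AWS access key id
    #       "bedrock_sk": "...",            # AWS secret access key
    #       "bedrock_region": "us-east-1",  # Bedrock region
    #   })
    #   mdl = BedrockChat(key, "anthropic.claude-3-sonnet-20240229-v1:0")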

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        # Rename the remaining keys to Bedrock's camelCase inference parameters.
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf.pop("top_p")
        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf.pop("top_p")
        # ai21 models are special-cased to a single non-streaming call.
        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf
                )
                ans = response["output"]["message"]["content"][0]["text"]
                yield ans
                yield num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                yield f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
                yield 0
            return
        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract the streamed response text in real time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
        yield num_tokens_from_string(ans)

class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        if system:
            # Gemini has no system role; prepend the prompt as a user turn.
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        # Convert OpenAI-style messages to Gemini's role/parts format.
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        total_token_count = 0
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
            # Usage metadata is only available once streaming has finished.
            total_token_count = response._chunks[-1].usage_metadata.total_token_count
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_token_count

class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

# OpenRouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        self.base_url = base_url
        self.client = OpenAI(base_url=self.base_url, api_key=key)
        self.model_name = model_name