chat_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import json
import requests
import asyncio
# Aliased so it does not shadow dashscope's Generation, which QWenChat relies on.
from rag.svr.jina_server import Prompt, Generation as JinaGeneration
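
# Every provider class below exposes the same two-method interface:
#   chat(system, history, gen_conf)          -> (answer_text, total_tokens)
#   chat_streamly(system, history, gen_conf) -> yields growing answer strings, then the token count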


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens = (
                    (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                    if not hasattr(resp, "usage")
                    else resp.usage["total_tokens"]
                )
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            # append the OpenAI-compatible path so the Base client hits the right endpoint
            base_url = os.path.join(base_url, "v1")
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if resp.choices[0].finish_reason == "stop":
                    if not resp.choices[0].delta.content:
                        continue
                    total_tokens = resp.usage.get('total_tokens', 0)
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0


class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.base_url = base_url
        self.client = OpenAI(api_key="empty", base_url=self.base_url)
        self.model_name = model_name.split("___")[0]
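

# LocalLLM streams from a local jina gRPC endpoint (rag/svr/jina_server, port 12345);
# the nested RPCProxy pickles method calls over a multiprocessing connection and
# retries the connection up to three times before giving up.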
class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
            )

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name):
        from jina import Client
        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        answer = ""
        try:
            res = self.client.stream_doc(
                on=endpoint, inputs=prompt, return_type=JinaGeneration
            )
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)


class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        Because we do not want to modify the original database fields, and the VolcEngine
        authentication scheme is unusual, ak, sk and ep_id are assembled into api_key,
        stored as a dictionary-like string, and parsed here for use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        self.volc_ak = eval(key).get('volc_ak', '')
        self.volc_sk = eval(key).get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = eval(key).get('ep_id', '')
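        # The api_key parsed above is expected to look roughly like the following
        # (hypothetical placeholder values, keys taken from the lookups above):
        #   '{"volc_ak": "AK...", "volc_sk": "SK...", "ep_id": "ep-..."}'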

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class MiniMaxChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                text = ""
                if "choices" in resp and "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                ans += text
                total_tokens = (
                    total_tokens + num_tokens_from_string(text)
                    if "usage" not in resp
                    else resp["usage"]["total_tokens"]
                )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = eval(key).get('bedrock_ak', '')
        self.bedrock_sk = eval(key).get('bedrock_sk', '')
        self.bedrock_region = eval(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")

        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the streamed response text in real-time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)


class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url)


class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url)


class NvidiaChat(Base):
    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1"
        super().__init__(key, model_name, base_url)


class LmStudioChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
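

# Minimal usage sketch (not part of the library): assumes an OpenAI API key is
# available in the OPENAI_API_KEY environment variable; every provider class above
# follows the same chat()/chat_streamly() calls.
if __name__ == "__main__":
    mdl = GptTurbo(os.environ.get("OPENAI_API_KEY", ""), "gpt-3.5-turbo")
    answer, used_tokens = mdl.chat(
        "You are a helpful assistant.",
        [{"role": "user", "content": "Say hello in one sentence."}],
        {"temperature": 0.7, "max_tokens": 64},
    )
    print(answer, used_tokens)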