#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
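
# Note: every chat wrapper below exposes the same two methods (this summary is
# inferred from the implementations in this file):
#   chat(system, history, gen_conf)          -> (answer_text, total_tokens)
#   chat_streamly(system, history, gen_conf) -> generator yielding the growing
#       answer string and, as its last item, the accumulated token count (int).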


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
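
# A minimal usage sketch (illustrative only: the API key and model name are
# placeholders, and the gen_conf keys mirror the ones handled above):
#
#     mdl = GptTurbo(key="<OPENAI_API_KEY>", model_name="gpt-3.5-turbo")
#     ans, tokens = mdl.chat(
#         "You are a helpful assistant.",
#         [{"role": "user", "content": "Hello"}],
#         {"temperature": 0.3, "top_p": 0.85, "max_tokens": 512})
#     for partial in mdl.chat_streamly(
#             "You are a helpful assistant.",
#             [{"role": "user", "content": "Hello"}],
#             {"temperature": 0.3}):
#         pass  # the final item yielded is the token count (int), not a string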


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if resp.choices[0].finish_reason == "stop":
                    if not resp.choices[0].delta.content:
                        continue
                    total_tokens = resp.usage.get('total_tokens', 0)
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0


class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield token_count


class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        Since we do not want to modify the original database fields, and the VolcEngine
        authentication method is quite special, ak, sk and ep_id are assembled into
        api_key as a dictionary string and parsed here for use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        self.volc_ak = eval(key).get('volc_ak', '')
        self.volc_sk = eval(key).get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = eval(key).get('ep_id', '')

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
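
# For VolcEngineChat the key argument is expected to be the string form of a
# dict carrying 'volc_ak', 'volc_sk' and 'ep_id' (it is parsed with eval above),
# e.g. with placeholder values:
#     "{'volc_ak': '<ACCESS_KEY>', 'volc_sk': '<SECRET_KEY>', 'ep_id': '<ENDPOINT_ID>'}"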


class MiniMaxChat(Base):
    def __init__(self, key, model_name="abab6.5s-chat",
                 base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        super().__init__(key, model_name, base_url)


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = eval(key).get('bedrock_ak', '')
        self.bedrock_sk = eval(key).get('bedrock_sk', '')
        self.bedrock_region = eval(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")

        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )
            # Extract and print the streamed response text in real-time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)
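
# BedrockChat likewise expects its key argument to be the string form of a dict
# carrying 'bedrock_ak', 'bedrock_sk' and 'bedrock_region' (parsed with eval above).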


class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        self.base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(base_url=self.base_url, api_key=key)
        self.model_name = model_name