chat_model.py 55KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import json
import requests
import asyncio


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if not hasattr(resp, "usage") or not resp.usage:
                    total_tokens = (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                elif isinstance(resp.usage, dict):
                    total_tokens = resp.usage.get("total_tokens", total_tokens)
                else:
                    total_tokens = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
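# Usage sketch (hypothetical key/model values): every provider class below exposes the same
# Base interface: chat() returns (answer, total_tokens), while chat_streamly() is a generator
# that yields the growing answer string and, as its final item, the token count.
#   mdl = GptTurbo("sk-xxxx", "gpt-3.5-turbo")
#   ans, tokens = mdl.chat("You are a helpful assistant.",
#                          [{"role": "user", "content": "Hi"}],
#                          {"temperature": 0.3})
#   for chunk in mdl.chat_streamly("", [{"role": "user", "content": "Hi"}], {}):
#       print(chunk)  # each yield is the answer so far; the last yield is the token count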


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url: base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url: base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name, base_url)


class HuggingFaceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url: base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        api_key = json.loads(key).get('api_key', '')
        api_version = json.loads(key).get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
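# Usage sketch (hypothetical values): AzureChat expects `key` to be a JSON string carrying the
# credential and API version, with the endpoint supplied via base_url, e.g.
#   key = json.dumps({"api_key": "<azure-api-key>", "api_version": "2024-02-01"})
#   mdl = AzureChat(key, "gpt-4o", base_url="https://<resource>.openai.azure.com")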


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens = (
                    (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                    if not hasattr(resp, "usage")
                    else resp.usage["total_tokens"]
                )
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
        if not stream_flag:
            from http import HTTPStatus
            if system:
                history.insert(0, {"role": "system", "content": system})
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                **gen_conf
            )
            ans = ""
            tk_count = 0
            if response.status_code == HTTPStatus.OK:
                ans += response.output.choices[0]['message']['content']
                tk_count += response.usage.total_tokens
                if response.output.choices[0].get("finish_reason", "") == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                return ans, tk_count
            return "**ERROR**: " + response.message, tk_count
        else:
            g = self._chat_streamly(system, history, gen_conf, incremental_output=True)
            result_list = list(g)
            error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
            if len(error_msg_list) > 0:
                return "**ERROR**: " + "".join(error_msg_list), 0
            else:
                return "".join(result_list[:-1]), result_list[-1]

    def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                incremental_output=incremental_output,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def chat_streamly(self, system, history, gen_conf):
        return self._chat_streamly(system, history, gen_conf)


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content: continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0


class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]


class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
            )

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name):
        from jina import Client
        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        from rag.svr.jina_server import Prompt, Generation
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        from rag.svr.jina_server import Prompt, Generation
        answer = ""
        try:
            res = self.client.stream_doc(
                on=endpoint, inputs=prompt, return_type=Generation
            )
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)


class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
        """
        Because we do not want to modify the original database fields, and the VolcEngine
        authentication method is quite special, the ark_api_key and ep_id are assembled into
        api_key, stored as a JSON dictionary, and parsed here for use.
        model_name is for display only.
        """
        base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
        ark_api_key = json.loads(key).get('ark_api_key', '')
        model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
        super().__init__(ark_api_key, model_name, base_url)
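# Usage sketch (hypothetical values): the `key` passed to VolcEngineChat is a JSON string that
# bundles the Ark credential with the endpoint id parsed above, e.g.
#   key = json.dumps({"ark_api_key": "<ark-api-key>", "ep_id": "ep-2024xxxx-xxxx"})
#   mdl = VolcEngineChat(key, "Doubao-pro-32k")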


class MiniMaxChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                text = ""
                if "choices" in resp and "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                ans += text
                total_tokens = (
                    total_tokens + num_tokens_from_string(text)
                    if "usage" not in resp
                    else resp["usage"]["total_tokens"]
                )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content: continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
        self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
        self.bedrock_region = json.loads(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]
        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf,
                system=[{"text": (system if system else "Answer the user's message.")}],
            )
            # Extract and return the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]
        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf,
                    system=[{"text": (system if system else "Answer the user's message.")}]
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf,
                system=[{"text": (system if system else "Answer the user's message.")}]
            )
            # Extract and yield the streamed response text in real time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)
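# Usage sketch (hypothetical values): BedrockChat bundles the AWS credential and region into a
# single JSON `key` string, mirroring what __init__ parses above, e.g.
#   key = json.dumps({"bedrock_ak": "<access-key>", "bedrock_sk": "<secret-key>",
#                     "bedrock_region": "us-east-1"})
#   mdl = BedrockChat(key, "<bedrock-model-id>")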


class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        from google.generativeai.types import content_types
        if system:
            self.model._system_instruction = content_types.to_content(system)
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'role' in item and item['role'] == 'system':
                item['role'] = 'user'
            if 'content' in item:
                item['parts'] = item.pop('content')
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        from google.generativeai.types import content_types
        if system:
            self.model._system_instruction = content_types.to_content(system)
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url)


class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url)


class NvidiaChat(Base):
    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1"
        super().__init__(key, model_name, base_url)


class LmStudioChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name


class OpenAI_APIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        model_name = model_name.split("___")[0]
        super().__init__(key, model_name, base_url)


class CoHereChat(Base):
    def __init__(self, key, model_name, base_url=""):
        from cohere import Client
        self.client = Client(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            ans = response.text
            if response.finish_reason == "MAX_TOKENS":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            for resp in response:
                if resp.event_type == "text-generation":
                    ans += resp.text
                    total_tokens += num_tokens_from_string(resp.text)
                elif resp.event_type == "stream-end":
                    if resp.finish_reason == "MAX_TOKENS":
                        ans += (
                            "...\nFor the content length reason, it stopped, continue?"
                            if is_english([ans])
                            else "······\n由于长度的原因,回答被截断了,要继续吗?"
                        )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class LeptonAIChat(Base):
    def __init__(self, key, model_name, base_url=None):
        if not base_url:
            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
        super().__init__(key, model_name, base_url)


class TogetherAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url)


class PerfXCloudChat(Base):
    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, base_url)


class UpstageChat(Base):
    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, base_url)


class NovitaAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai"
        super().__init__(key, model_name, base_url)


class SILICONFLOWChat(Base):
    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url)


class YiChat(Base):
    def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1"):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, base_url)


class ReplicateChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client
        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            for resp in response:
                ans += resp
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)


class HunyuanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")

    def chat(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": _history, **_gen_conf}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {
            "Model": self.model_name,
            "Messages": _history,
            "Stream": True,
            **_gen_conf,
        }
        req.from_json_string(json.dumps(params))
        ans = ""
        total_tokens = 0
        try:
            response = self.client.ChatCompletions(req)
            for resp in response:
                resp = json.loads(resp["data"])
                if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                    continue
                ans += resp["Choices"][0]["Delta"]["Content"]
                total_tokens += 1
                yield ans
        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class SparkChat(Base):
    def __init__(
            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
    ):
        if not base_url:
            base_url = "https://spark-api-open.xf-yun.com/v1"
        model2version = {
            "Spark-Max": "generalv3.5",
            "Spark-Lite": "general",
            "Spark-Pro": "generalv3",
            "Spark-Pro-128K": "pro-128k",
            "Spark-4.0-Ultra": "4.0Ultra",
        }
        model_version = model2version[model_name]
        super().__init__(key, model_version, base_url)


class BaiduYiyanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        self.model_name = model_name.lower()
        self.system = ""

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                **gen_conf
            ).body
            ans = response['result']
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf
            )
            for resp in response:
                resp = resp.body
                ans += resp['result']
                total_tokens = resp["usage"]["total_tokens"]
                yield ans
        except Exception as e:
            # This is a generator: report errors by yielding, not returning a tuple.
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class AnthropicChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.system = ""

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096
        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=False,
                **gen_conf,
            ).json()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096
        ans = ""
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf,
            )
            for res in response.iter_lines():
                res = res.decode("utf-8")
                if "content_block_delta" in res and "data" in res:
                    text = json.loads(res[6:])["delta"]["text"]
                    ans += text
                    total_tokens += num_tokens_from_string(text)
                    yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GoogleChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.oauth2 import service_account
        import base64

        key = json.loads(key)
        access_token = json.loads(
            base64.b64decode(key.get("google_service_account_key", ""))
        )
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.system = ""

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if access_token:
                credits = service_account.Credentials.from_service_account_info(
                    access_token, scopes=scopes
                )
                request = Request()
                credits.refresh(request)
                token = credits.token
                self.client = AnthropicVertex(
                    region=region, project_id=project_id, access_token=token
                )
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            from google.cloud import aiplatform
            import vertexai.generative_models as glm

            if access_token:
                credits = service_account.Credentials.from_service_account_info(
                    access_token
                )
                aiplatform.init(
                    credentials=credits, project=project_id, location=region
                )
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        if "claude" in self.model_name:
            if "max_tokens" not in gen_conf:
                gen_conf["max_tokens"] = 4096
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=history,
                    system=self.system,
                    stream=False,
                    **gen_conf,
                ).json()
                ans = response["content"][0]["text"]
                if response["stop_reason"] == "max_tokens":
                    ans += (
                        "...\nFor the content length reason, it stopped, continue?"
                        if is_english([ans])
                        else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    )
                return (
                    ans,
                    response["usage"]["input_tokens"]
                    + response["usage"]["output_tokens"],
                )
            except Exception as e:
                return "\n**ERROR**: " + str(e), 0
        else:
            self.client._system_instruction = self.system
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
            for item in history:
                if "role" in item and item["role"] == "assistant":
                    item["role"] = "model"
                if "content" in item:
                    item["parts"] = item.pop("content")
            try:
                response = self.client.generate_content(
                    history, generation_config=gen_conf
                )
                ans = response.text
                return ans, response.usage_metadata.total_token_count
            except Exception as e:
                return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        if "claude" in self.model_name:
            if "max_tokens" not in gen_conf:
                gen_conf["max_tokens"] = 4096
            ans = ""
            total_tokens = 0
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=history,
                    system=self.system,
                    stream=True,
                    **gen_conf,
                )
                for res in response.iter_lines():
                    res = res.decode("utf-8")
                    if "content_block_delta" in res and "data" in res:
                        text = json.loads(res[6:])["delta"]["text"]
                        ans += text
                        total_tokens += num_tokens_from_string(text)
                        yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield total_tokens
        else:
            self.client._system_instruction = self.system
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
            for item in history:
                if "role" in item and item["role"] == "assistant":
                    item["role"] = "model"
                if "content" in item:
                    item["parts"] = item.pop("content")
            ans = ""
            try:
                response = self.client.generate_content(
                    history, generation_config=gen_conf, stream=True
                )
                for resp in response:
                    ans += resp.text
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield response._chunks[-1].usage_metadata.total_token_count