
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin

import json_repair
import litellm
import openai
import requests
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI

from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string


# Error message constants
class LLMErrorCode(StrEnum):
    ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
    ERROR_AUTHENTICATION = "AUTH_ERROR"
    ERROR_INVALID_REQUEST = "INVALID_REQUEST"
    ERROR_SERVER = "SERVER_ERROR"
    ERROR_TIMEOUT = "TIMEOUT"
    ERROR_CONNECTION = "CONNECTION_ERROR"
    ERROR_MODEL = "MODEL_ERROR"
    ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
    ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
    ERROR_QUOTA = "QUOTA_EXCEEDED"
    ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
    ERROR_GENERIC = "GENERIC_ERROR"


class ReActMode(StrEnum):
    FUNCTION_CALL = "function_call"
    REACT = "react"


ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."


class ToolCallSession(Protocol):
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
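
# A minimal sketch (not part of the original file) of a ToolCallSession implementation;
# the class and its behavior below are hypothetical stand-ins for a real tool registry:
#
#   class EchoToolSession:
#       def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
#           # A real session would dispatch to an actual tool; this stub just echoes the request.
#           return json.dumps({"tool": name, "arguments": arguments})
#
# Any object with a matching tool_call() method can be passed to Base.bind_tools() below.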


class Base(ABC):
    def __init__(self, key, model_name, base_url, **kwargs):
        timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.model_name = model_name
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}

    def _get_delay(self):
        """Calculate retry delay time"""
        return self.base_delay * random.uniform(10, 150)

    def _classify_error(self, error):
        """Classify error based on error message content"""
        error_str = str(error).lower()
        keywords_mapping = [
            (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
            (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
            (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
            (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
            (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
            (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
            (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
            (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
            (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
            (["max rounds"], LLMErrorCode.ERROR_MODEL),
        ]
        for words, code in keywords_mapping:
            if re.search("({})".format("|".join(words)), error_str):
                return code
        return LLMErrorCode.ERROR_GENERIC
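
    # For illustration only (not in the original file): an error message such as
    # "Error code: 429 - Too Many Requests" matches the rate-limit keyword group above,
    # so _classify_error returns LLMErrorCode.ERROR_RATE_LIMIT and _exceptions will retry.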

    def _clean_conf(self, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf

    def _chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwen3") >= 0:
            kwargs["extra_body"] = {"enable_thinking": False}
        response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
        if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
            return "", 0
        ans = response.choices[0].message.content.strip()
        if response.choices[0].finish_reason == "length":
            ans = self._length_stop(ans)
        return ans, self.total_token_count(response)

    def _chat_streamly(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        reasoning_start = False
        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
        for resp in response:
            if not resp.choices:
                continue
            if not resp.choices[0].delta.content:
                resp.choices[0].delta.content = ""
            if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                ans = ""
                if not reasoning_start:
                    reasoning_start = True
                    ans = "<think>"
                ans += resp.choices[0].delta.reasoning_content + "</think>"
            else:
                reasoning_start = False
                ans = resp.choices[0].delta.content
            tol = self.total_token_count(resp)
            if not tol:
                tol = num_tokens_from_string(resp.choices[0].delta.content)
            if resp.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            yield ans, tol

    def _length_stop(self, ans):
        if is_chinese([ans]):
            return ans + LENGTH_NOTIFICATION_CN
        return ans + LENGTH_NOTIFICATION_EN

    def _exceptions(self, e, attempt):
        logging.exception("OpenAI chat_with_tools")
        # Classify the error
        error_code = self._classify_error(e)
        if attempt == self.max_retries:
            error_code = LLMErrorCode.ERROR_MAX_RETRIES
        # Check if it's a rate limit error or server error and not the last attempt
        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
        if not should_retry:
            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        delay = self._get_delay()
        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
        time.sleep(delay)

    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

    def _append_history(self, hist, tool_call, tool_res):
        hist.append(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "index": tool_call.index,
                        "id": tool_call.id,
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                        "type": "function",
                    },
                ],
            }
        )
        try:
            if isinstance(tool_res, dict):
                tool_res = json.dumps(tool_res, ensure_ascii=False)
        finally:
            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
        return hist

    def bind_tools(self, toolcall_session, tools):
        if not (toolcall_session and tools):
            return
        self.is_tools = True
        self.toolcall_session = toolcall_session
        self.tools = tools
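
    # Illustrative usage sketch (not part of the original file); the tool schema, the
    # EchoToolSession class, and the API key below are hypothetical placeholders:
    #
    #   tools = [{"type": "function", "function": {"name": "echo", "parameters": {"type": "object", "properties": {}}}}]
    #   model = GptTurbo(key="sk-...", model_name="gpt-3.5-turbo")
    #   model.bind_tools(EchoToolSession(), tools)
    #   answer, tokens = model.chat_with_tools("You are helpful.", [{"role": "user", "content": "Call echo."}], {})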

    def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        hist = deepcopy(history)
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds + 1):
                    logging.info(f"{self.tools=}")
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
                    tk_count += self.total_token_count(response)
                    if any([not response.choices, not response.choices[0].message]):
                        raise Exception(f"500 response structure error. Response: {response}")
                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
                        ans += response.choices[0].message.content
                        if response.choices[0].finish_reason == "length":
                            ans = self._length_stop(ans)
                        return ans, tk_count
                    for tool_call in response.choices[0].message.tool_calls:
                        logging.info(f"Response {tool_call=}")
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            ans += self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            ans += self._verbose_tool_use(name, {}, str(e))
                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response, token_count = self._chat(history, gen_conf)
                ans += response
                tk_count += token_count
                return ans, tk_count
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    return e, tk_count
        assert False, "Shouldn't be here."

    def chat(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries + 1):
            try:
                return self._chat(history, gen_conf, **kwargs)
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    return e, 0
        assert False, "Shouldn't be here."

    def _wrap_toolcall_message(self, stream):
        final_tool_calls = {}
        for chunk in stream:
            for tool_call in chunk.choices[0].delta.tool_calls or []:
                index = tool_call.index
                if index not in final_tool_calls:
                    final_tool_calls[index] = tool_call
                final_tool_calls[index].function.arguments += tool_call.function.arguments
        return final_tool_calls

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        tools = self.tools
        if system:
            history.insert(0, {"role": "system", "content": system})
        total_tokens = 0
        hist = deepcopy(history)
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds + 1):
                    reasoning_start = False
                    logging.info(f"{tools=}")
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
                    final_tool_calls = {}
                    answer = ""
                    for resp in response:
                        if resp.choices[0].delta.tool_calls:
                            for tool_call in resp.choices[0].delta.tool_calls or []:
                                index = tool_call.index
                                if index not in final_tool_calls:
                                    if not tool_call.function.arguments:
                                        tool_call.function.arguments = ""
                                    final_tool_calls[index] = tool_call
                                else:
                                    final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
                            continue
                        if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
                            raise Exception("500 response structure error.")
                        if not resp.choices[0].delta.content:
                            resp.choices[0].delta.content = ""
                        if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                            ans = ""
                            if not reasoning_start:
                                reasoning_start = True
                                ans = "<think>"
                            ans += resp.choices[0].delta.reasoning_content + "</think>"
                            yield ans
                        else:
                            reasoning_start = False
                            answer += resp.choices[0].delta.content
                            yield resp.choices[0].delta.content
                        tol = self.total_token_count(resp)
                        if not tol:
                            total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                        else:
                            total_tokens += tol
                        finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
                        if finish_reason == "length":
                            yield self._length_stop("")
                    if answer:
                        yield total_tokens
                        return
                    for tool_call in final_tool_calls.values():
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            yield self._verbose_tool_use(name, args, "Begin to call...")
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            yield self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            yield self._verbose_tool_use(name, {}, str(e))
                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
                for resp in response:
                    if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
                        raise Exception("500 response structure error.")
                    if not resp.choices[0].delta.content:
                        resp.choices[0].delta.content = ""
                        continue
                    tol = self.total_token_count(resp)
                    if not tol:
                        total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                    else:
                        total_tokens += tol
                    answer += resp.choices[0].delta.content
                    yield resp.choices[0].delta.content
                yield total_tokens
                return
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    yield e
                    yield total_tokens
                    return
        assert False, "Shouldn't be here."

    def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        total_tokens = 0
        try:
            for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
                yield delta_ans
                total_tokens += tol
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

    def total_token_count(self, resp):
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except Exception:
            pass
        return 0

    def _calculate_dynamic_ctx(self, history):
        """Calculate dynamic context window size"""

        def count_tokens(text):
            """Calculate token count for text"""
            # Simple calculation: 1 token per ASCII character
            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
            total = 0
            for char in text:
                if ord(char) < 128:  # ASCII characters
                    total += 1
                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
                    total += 2
            return total

        # Calculate total tokens for all messages
        total_tokens = 0
        for message in history:
            content = message.get("content", "")
            # Calculate content tokens
            content_tokens = count_tokens(content)
            # Add role marker token overhead
            role_tokens = 4
            total_tokens += content_tokens + role_tokens

        # Apply 1.2x buffer ratio
        total_tokens_with_buffer = int(total_tokens * 1.2)
        if total_tokens_with_buffer <= 8192:
            ctx_size = 8192
        else:
            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
            ctx_size = ctx_multiplier * 8192
        return ctx_size
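
    # Worked example (added for illustration, not in the original file): a history whose
    # contents total 10,000 estimated tokens becomes 12,000 after the 1.2x buffer;
    # 12,000 // 8192 + 1 == 2, so the context window is rounded up to 2 * 8192 = 16,384.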


class GptTurbo(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class XinferenceChat(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        super().__init__(key, model_name, base_url, **kwargs)


class HuggingFaceChat(Base):
    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        super().__init__(key, model_name.split("___")[0], base_url, **kwargs)


class ModelScopeChat(Base):
    _FACTORY_NAME = "ModelScope"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        super().__init__(key, model_name.split("___")[0], base_url, **kwargs)


class AzureChat(Base):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
        super().__init__(key, model_name, base_url, **kwargs)
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
        self.model_name = model_name
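
# Example key payload for AzureChat (illustrative only; field names are taken from the
# json.loads() parsing above, the endpoint and deployment name are placeholders):
#   key = json.dumps({"api_key": "<azure-key>", "api_version": "2024-02-01"})
#   chat = AzureChat(key, "gpt-4o", base_url="https://<resource>.openai.azure.com")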


class BaiChuanChat(Base):
    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "top_p": params.get("top_p", 0.85),
        }

    def _clean_conf(self, gen_conf):
        return {
            "temperature": gen_conf.get("temperature", 0.3),
            "top_p": gen_conf.get("top_p", 0.85),
        }

    def _chat(self, history, gen_conf={}, **kwargs):
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
            extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
            **gen_conf,
        )
        ans = response.choices[0].message.content.strip()
        if response.choices[0].finish_reason == "length":
            if is_chinese([ans]):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
                stream=True,
                **self._format_params(gen_conf),
            )
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans = resp.choices[0].delta.content
                tol = self.total_token_count(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
                    total_tokens = tol
                if resp.choices[0].finish_reason == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens


class ZhipuChat(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _clean_conf(self, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        return gen_conf

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        return super().chat_with_tools(system, history, gen_conf)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = self.total_token_count(resp)
                if resp.choices[0].finish_reason == "stop":
                    tk_count = self.total_token_count(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        return super().chat_streamly_with_tools(system, history, gen_conf)


class OllamaChat(Base):
    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        self.client = Client(host=base_url) if not key or key == "x" else Client(host=base_url, headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))

    def _clean_conf(self, gen_conf):
        options = {}
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        for k in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]:
            if k not in gen_conf:
                continue
            options[k] = gen_conf[k]
        return options

    def _chat(self, history, gen_conf={}, **kwargs):
        # Calculate context size
        ctx_size = self._calculate_dynamic_ctx(history)
        gen_conf["num_ctx"] = ctx_size
        response = self.client.chat(model=self.model_name, messages=history, options=gen_conf, keep_alive=self.keep_alive)
        ans = response["message"]["content"].strip()
        token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
        return ans, token_count

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)
            options = {"num_ctx": ctx_size}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            ans = ""
            try:
                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=self.keep_alive)
                for resp in response:
                    if resp["done"]:
                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                        yield token_count
                    ans = resp["message"]["content"]
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)
            yield 0
        except Exception as e:
            yield "**ERROR**: " + str(e)
            yield 0
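
# Usage sketch (illustrative only; assumes a local Ollama server with a pulled model,
# the host URL and model name below are placeholders):
#   chat = OllamaChat(key="x", model_name="llama3", base_url="http://localhost:11434")
#   answer, tokens = chat.chat("You are concise.", [{"role": "user", "content": "Hello"}], {"temperature": 0.2})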


class LocalAIChat(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]


class LocalLLM(Base):
    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        from rag.svr.jina_server import Prompt

        if system:
            history.insert(0, {"role": "system", "content": system})
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        from rag.svr.jina_server import Generation

        answer = ""
        try:
            res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)


class VolcEngineChat(Base):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
        """
        Since we do not want to modify the original database fields and the VolcEngine
        authentication method is quite special, ark_api_key and ep_id are assembled into
        api_key, stored as a dictionary, and parsed here for use.
        model_name is for display only.
        """
        base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
        ark_api_key = json.loads(key).get("ark_api_key", "")
        model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
        super().__init__(ark_api_key, model_name, base_url, **kwargs)
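
# Example key payload for VolcEngineChat (illustrative only; field names are taken from the
# json.loads() parsing above, the values are placeholders):
#   key = json.dumps({"ark_api_key": "<ark-key>", "ep_id": "ep-2024xxxx"})
#   chat = VolcEngineChat(key, model_name="display-name-only")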


class MiniMaxChat(Base):
    _FACTORY_NAME = "MiniMax"

    def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def _clean_conf(self, gen_conf):
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf):
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
        response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
        response = response.json()
        ans = response["choices"][0]["message"]["content"].strip()
        if response["choices"][0]["finish_reason"] == "length":
            if is_chinese(ans):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                text = ""
                if "choices" in resp and "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                ans = text
                tol = self.total_token_count(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(text)
                else:
                    total_tokens = tol
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens


class MistralChat(Base):
    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def _clean_conf(self, gen_conf):
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf={}, **kwargs):
        response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
        ans = response.choices[0].message.content
        if response.choices[0].finish_reason == "length":
            if is_chinese(ans):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans = resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens


## openrouter
class OpenRouterChat(Base):
    _FACTORY_NAME = "OpenRouter"

    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class StepFunChat(Base):
    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class LmStudioChat(Base):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        super().__init__(key, model_name, base_url, **kwargs)
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name


class OpenAI_APIChat(Base):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url, **kwargs):
        if not base_url:
            raise ValueError("url cannot be None")
        model_name = model_name.split("___")[0]
        super().__init__(key, model_name, base_url, **kwargs)


class PPIOChat(Base):
    _FACTORY_NAME = "PPIO"

    def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
        if not base_url:
            base_url = "https://api.ppinfra.com/v3/openai"
        super().__init__(key, model_name, base_url, **kwargs)


class LeptonAIChat(Base):
    _FACTORY_NAME = "LeptonAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not base_url:
            base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
        super().__init__(key, model_name, base_url, **kwargs)


class PerfXCloudChat(Base):
    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class UpstageChat(Base):
    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, base_url, **kwargs)


class NovitaAIChat(Base):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai"
        super().__init__(key, model_name, base_url, **kwargs)


class SILICONFLOWChat(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class YiChat(Base):
    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class GiteeChat(Base):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/"
        super().__init__(key, model_name, base_url, **kwargs)


class ReplicateChat(Base):
    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def _chat(self, history, gen_conf={}, **kwargs):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
        response = self.client.run(
            self.model_name,
            input={"system_prompt": system, "prompt": prompt, **gen_conf},
        )
        ans = "".join(response)
        return ans, num_tokens_from_string(ans)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": system, "prompt": prompt, **gen_conf},
            )
            for resp in response:
                ans = resp
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(ans)


class HunyuanChat(Base):
    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")

    def _clean_conf(self, gen_conf):
        _gen_conf = {}
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        return _gen_conf

    def _chat(self, history, gen_conf={}, **kwargs):
        from tencentcloud.hunyuan.v20230901 import models

        hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": hist, **gen_conf}
        req.from_json_string(json.dumps(params))
        response = self.client.ChatCompletions(req)
        ans = response.Choices[0].Message.Content
        return ans, response.Usage.TotalTokens

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.hunyuan.v20230901 import models

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {
            "Model": self.model_name,
            "Messages": _history,
            "Stream": True,
            **_gen_conf,
        }
        req.from_json_string(json.dumps(params))
        ans = ""
        total_tokens = 0
        try:
            response = self.client.ChatCompletions(req)
            for resp in response:
                resp = json.loads(resp["data"])
                if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                    continue
                ans = resp["Choices"][0]["Delta"]["Content"]
                total_tokens += 1
                yield ans
        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens


class SparkChat(Base):
    _FACTORY_NAME = "XunFei Spark"

    def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://spark-api-open.xf-yun.com/v1"
        model2version = {
            "Spark-Max": "generalv3.5",
            "Spark-Lite": "general",
            "Spark-Pro": "generalv3",
            "Spark-Pro-128K": "pro-128k",
            "Spark-4.0-Ultra": "4.0Ultra",
        }
        version2model = {v: k for k, v in model2version.items()}
        assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
        if model_name in model2version:
            model_version = model2version[model_name]
        else:
            model_version = model_name
        super().__init__(key, model_version, base_url, **kwargs)


class BaiduYiyanChat(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        self.model_name = model_name.lower()

    def _clean_conf(self, gen_conf):
        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf

    def _chat(self, history, gen_conf):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
        ans = response["result"]
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.do(model=self.model_name, messages=history, system=system, stream=True, **gen_conf)
            for resp in response:
                resp = resp.body
                ans = resp["result"]
                total_tokens = self.total_token_count(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens
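
# Example key payload for BaiduYiyanChat (illustrative only; field names are taken from the
# json.loads() parsing above, the credential values and model name are placeholders):
#   key = json.dumps({"yiyan_ak": "<qianfan-ak>", "yiyan_sk": "<qianfan-sk>"})
#   chat = BaiduYiyanChat(key, model_name="<qianfan-model-name>")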


class GoogleChat(Base):
    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                request = Request()
                credits.refresh(request)
                token = credits.token
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            import vertexai.generative_models as glm
            from google.cloud import aiplatform

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token)
                aiplatform.init(credentials=credits, project=project_id, location=region)
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)

    def _clean_conf(self, gen_conf):
        if "claude" in self.model_name:
            if "max_tokens" in gen_conf:
                del gen_conf["max_tokens"]
        else:
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf={}, **kwargs):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        if "claude" in self.model_name:
            response = self.client.messages.create(
                model=self.model_name,
                messages=[h for h in history if h["role"] != "system"],
                system=system,
                stream=False,
                **gen_conf,
            ).json()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )

        self.client._system_instruction = system
        hist = []
        for item in history:
            if item["role"] == "system":
                continue
            hist.append(deepcopy(item))
            item = hist[-1]
            if "role" in item and item["role"] == "assistant":
                item["role"] = "model"
            if "content" in item:
                item["parts"] = [
                    {
                        "text": item.pop("content"),
                    }
                ]
        response = self.client.generate_content(hist, generation_config=gen_conf)
        ans = response.text
        return ans, response.usage_metadata.total_token_count

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "claude" in self.model_name:
            if "max_tokens" in gen_conf:
                del gen_conf["max_tokens"]
            ans = ""
            total_tokens = 0
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=history,
                    system=system,
                    stream=True,
                    **gen_conf,
                )
                for res in response.iter_lines():
                    res = res.decode("utf-8")
                    if "content_block_delta" in res and "data" in res:
                        text = json.loads(res[6:])["delta"]["text"]
                        ans = text
                        total_tokens += num_tokens_from_string(text)
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)
            yield total_tokens
        else:
            self.client._system_instruction = system
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
            for item in history:
                if "role" in item and item["role"] == "assistant":
                    item["role"] = "model"
                if "content" in item:
                    item["parts"] = item.pop("content")
            ans = ""
            try:
                response = self.client.generate_content(history, generation_config=gen_conf, stream=True)
                for resp in response:
                    ans = resp.text
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)
            yield response._chunks[-1].usage_metadata.total_token_count
  1140. class GPUStackChat(Base):
  1141. _FACTORY_NAME = "GPUStack"
  1142. def __init__(self, key=None, model_name="", base_url="", **kwargs):
  1143. if not base_url:
  1144. raise ValueError("Local llm url cannot be None")
  1145. base_url = urljoin(base_url, "v1")
  1146. super().__init__(key, model_name, base_url, **kwargs)


class Ai302Chat(Base):
    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1", **kwargs):
        if not base_url:
            base_url = "https://api.302.ai/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class LiteLLMBase(ABC):
    _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic"]

    def __init__(self, key, model_name, base_url=None, **kwargs):
        self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
        self.provider = kwargs.get("provider", "")
        self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "")
        self.model_name = f"{self.prefix}{model_name}"
        self.api_key = key
        self.base_url = base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}
        # Factory specific fields
        if self.provider == SupportedLiteLLMProvider.Bedrock:
            bedrock_cred = json.loads(key)
            self.bedrock_ak = bedrock_cred.get("bedrock_ak", "")
            self.bedrock_sk = bedrock_cred.get("bedrock_sk", "")
            self.bedrock_region = bedrock_cred.get("bedrock_region", "")
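    # For Bedrock the `key` argument is expected to be a JSON string carrying the AWS
    # credentials parsed above, e.g. (illustrative values only):
    #   '{"bedrock_ak": "AKIA...", "bedrock_sk": "...", "bedrock_region": "us-east-1"}'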

    def _get_delay(self):
        """Calculate retry delay time"""
        return self.base_delay * random.uniform(10, 150)
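    # With the default base_delay of 2.0 seconds this jittered delay works out to roughly
    # 20-300 seconds per retry, which gives rate-limit and server errors (the only error
    # classes retried in _exceptions below) room to clear before the next attempt.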

    def _classify_error(self, error):
        """Classify error based on error message content"""
        error_str = str(error).lower()
        keywords_mapping = [
            (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
            (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
            (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
            (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
            (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
            (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
            (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
            (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
            (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
            (["max rounds"], LLMErrorCode.ERROR_MODEL),
        ]
        for words, code in keywords_mapping:
            if re.search("({})".format("|".join(words)), error_str):
                return code
        return LLMErrorCode.ERROR_GENERIC
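    # The keyword groups are checked in order, so the first match wins: an error such as
    # "429 Too Many Requests" maps to ERROR_RATE_LIMIT even though it would also match
    # later groups; anything with no keyword hit falls through to ERROR_GENERIC.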

    def _clean_conf(self, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf

    def _chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwen3") >= 0:
            kwargs["extra_body"] = {"enable_thinking": False}
        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
        response = litellm.completion(
            **completion_args,
            drop_params=True,
            timeout=self.timeout,
        )
        # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
        # Short-circuit so an empty choices list is not indexed.
        if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
            return "", 0
        ans = response.choices[0].message.content.strip()
        if response.choices[0].finish_reason == "length":
            ans = self._length_stop(ans)
        return ans, self.total_token_count(response)

    def _chat_streamly(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        reasoning_start = False
        completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
        stop = kwargs.get("stop")
        if stop:
            completion_args["stop"] = stop
        response = litellm.completion(
            **completion_args,
            drop_params=True,
            timeout=self.timeout,
        )
        for resp in response:
            if not hasattr(resp, "choices") or not resp.choices:
                continue
            delta = resp.choices[0].delta
            if not hasattr(delta, "content") or delta.content is None:
                delta.content = ""
            if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
                ans = ""
                if not reasoning_start:
                    reasoning_start = True
                    ans = "<think>"
                ans += delta.reasoning_content + "</think>"
            else:
                reasoning_start = False
                ans = delta.content
            tol = self.total_token_count(resp)
            if not tol:
                tol = num_tokens_from_string(delta.content)
            finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
            if finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            yield ans, tol

    def _length_stop(self, ans):
        if is_chinese([ans]):
            return ans + LENGTH_NOTIFICATION_CN
        return ans + LENGTH_NOTIFICATION_EN

    def _exceptions(self, e, attempt):
        logging.exception("LiteLLM chat failed")
        # Classify the error
        error_code = self._classify_error(e)
        if attempt == self.max_retries:
            error_code = LLMErrorCode.ERROR_MAX_RETRIES
        # Check if it's a rate limit error or server error and not the last attempt
        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
        if not should_retry:
            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        delay = self._get_delay()
        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
        time.sleep(delay)

    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

    def _append_history(self, hist, tool_call, tool_res):
        hist.append(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "index": tool_call.index,
                        "id": tool_call.id,
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                        "type": "function",
                    },
                ],
            }
        )
        try:
            if isinstance(tool_res, dict):
                tool_res = json.dumps(tool_res, ensure_ascii=False)
        finally:
            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
        return hist
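    # After a tool call the history therefore grows by two messages, e.g. (illustrative
    # values, not taken from a real response):
    #   {"role": "assistant", "tool_calls": [{"index": 0, "id": "call_0", "type": "function",
    #       "function": {"name": "search", "arguments": "{\"query\": \"...\"}"}}]}
    #   {"role": "tool", "tool_call_id": "call_0", "content": "..."}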

    def bind_tools(self, toolcall_session, tools):
        if not (toolcall_session and tools):
            return
        self.is_tools = True
        self.toolcall_session = toolcall_session
        self.tools = tools

    def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
        completion_args = {
            "model": self.model_name,
            "messages": history,
            "api_key": self.api_key,
            **kwargs,
        }
        if stream:
            completion_args.update(
                {
                    "stream": stream,
                }
            )
        if tools and self.tools:
            completion_args.update(
                {
                    "tools": self.tools,
                    "tool_choice": "auto",
                }
            )
        if self.provider in FACTORY_DEFAULT_BASE_URL:
            completion_args.update({"api_base": self.base_url})
        elif self.provider == SupportedLiteLLMProvider.Bedrock:
            completion_args.pop("api_key", None)
            completion_args.pop("api_base", None)
            completion_args.update(
                {
                    "aws_access_key_id": self.bedrock_ak,
                    "aws_secret_access_key": self.bedrock_sk,
                    "aws_region_name": self.bedrock_region,
                }
            )
        return completion_args
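    # For most providers the resulting kwargs look roughly like (illustrative only):
    #   {"model": "deepseek/deepseek-chat", "messages": [...], "api_key": "...",
    #    "api_base": "https://api.deepseek.com/v1", "temperature": 0.7, "stream": True}
    # Bedrock instead drops api_key/api_base in favour of the aws_* credentials, which is
    # the form litellm.completion() accepts for that provider.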

    def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        hist = deepcopy(history)
        # Retry loop with a jittered delay for rate-limit / server errors
        for attempt in range(self.max_retries + 1):
            history = deepcopy(hist)  # deepcopy is required here
            try:
                for _ in range(self.max_rounds + 1):
                    logging.info(f"{self.tools=}")
                    completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
                    response = litellm.completion(
                        **completion_args,
                        drop_params=True,
                        timeout=self.timeout,
                    )
                    tk_count += self.total_token_count(response)
                    if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
                        raise Exception(f"500 response structure error. Response: {response}")
                    message = response.choices[0].message
                    if not hasattr(message, "tool_calls") or not message.tool_calls:
                        if hasattr(message, "reasoning_content") and message.reasoning_content:
                            ans += f"<think>{message.reasoning_content}</think>"
                        ans += message.content or ""
                        if response.choices[0].finish_reason == "length":
                            ans = self._length_stop(ans)
                        return ans, tk_count
                    for tool_call in message.tool_calls:
                        logging.info(f"Response {tool_call=}")
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            ans += self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            ans += self._verbose_tool_use(name, {}, str(e))
                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response, token_count = self._chat(history, gen_conf)
                ans += response
                tk_count += token_count
                return ans, tk_count
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    return e, tk_count
        assert False, "Shouldn't be here."

    def chat(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        # Retry loop with a jittered delay for rate-limit / server errors
        for attempt in range(self.max_retries + 1):
            try:
                response = self._chat(history, gen_conf, **kwargs)
                return response
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    return e, 0
        assert False, "Shouldn't be here."

    def _wrap_toolcall_message(self, stream):
        final_tool_calls = {}
        for chunk in stream:
            for tool_call in chunk.choices[0].delta.tool_calls or []:
                index = tool_call.index
                if index not in final_tool_calls:
                    final_tool_calls[index] = tool_call
                else:
                    # Only append fragments from later chunks; appending the first chunk
                    # to itself would duplicate its argument text.
                    final_tool_calls[index].function.arguments += tool_call.function.arguments
        return final_tool_calls
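    # Merges streamed tool-call fragments keyed by their index, so a call whose JSON
    # arguments arrive split across chunks (e.g. '{"que' followed by 'ry": "..."}') is
    # reassembled into one complete tool_call object.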

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
        gen_conf = self._clean_conf(gen_conf)
        tools = self.tools
        if system:
            history.insert(0, {"role": "system", "content": system})
        total_tokens = 0
        hist = deepcopy(history)
        # Retry loop with a jittered delay for rate-limit / server errors
        for attempt in range(self.max_retries + 1):
            history = deepcopy(hist)  # deepcopy is required here
            try:
                for _ in range(self.max_rounds + 1):
                    reasoning_start = False
                    logging.info(f"{tools=}")
                    completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
                    response = litellm.completion(
                        **completion_args,
                        drop_params=True,
                        timeout=self.timeout,
                    )
                    final_tool_calls = {}
                    answer = ""
                    for resp in response:
                        if not hasattr(resp, "choices") or not resp.choices:
                            continue
                        delta = resp.choices[0].delta
                        if hasattr(delta, "tool_calls") and delta.tool_calls:
                            for tool_call in delta.tool_calls:
                                index = tool_call.index
                                if index not in final_tool_calls:
                                    if not tool_call.function.arguments:
                                        tool_call.function.arguments = ""
                                    final_tool_calls[index] = tool_call
                                else:
                                    final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
                            continue
                        if not hasattr(delta, "content") or delta.content is None:
                            delta.content = ""
                        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                            ans = ""
                            if not reasoning_start:
                                reasoning_start = True
                                ans = "<think>"
                            ans += delta.reasoning_content + "</think>"
                            yield ans
                        else:
                            reasoning_start = False
                            answer += delta.content
                            yield delta.content
                        tol = self.total_token_count(resp)
                        if not tol:
                            total_tokens += num_tokens_from_string(delta.content)
                        else:
                            total_tokens += tol
                        finish_reason = getattr(resp.choices[0], "finish_reason", "")
                        if finish_reason == "length":
                            yield self._length_stop("")
                    if answer:
                        yield total_tokens
                        return
                    for tool_call in final_tool_calls.values():
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            yield self._verbose_tool_use(name, args, "Begin to call...")
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            yield self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            history.append(
                                {
                                    "role": "tool",
                                    "tool_call_id": tool_call.id,
                                    "content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
                                }
                            )
                            yield self._verbose_tool_use(name, {}, str(e))
                logging.warning(f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
                response = litellm.completion(
                    **completion_args,
                    drop_params=True,
                    timeout=self.timeout,
                )
                for resp in response:
                    if not hasattr(resp, "choices") or not resp.choices:
                        continue
                    delta = resp.choices[0].delta
                    if not hasattr(delta, "content") or delta.content is None:
                        continue
                    tol = self.total_token_count(resp)
                    if not tol:
                        total_tokens += num_tokens_from_string(delta.content)
                    else:
                        total_tokens += tol
                    yield delta.content
                yield total_tokens
                return
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    yield e
                    yield total_tokens
                    return
        assert False, "Shouldn't be here."

    def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        total_tokens = 0
        try:
            for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
                yield delta_ans
                total_tokens += tol
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

    def total_token_count(self, resp):
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except Exception:
            pass
        return 0
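    # Handles both attribute-style responses (litellm/OpenAI objects) and plain dicts;
    # returning 0 tells the callers above to fall back to num_tokens_from_string().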

    def _calculate_dynamic_ctx(self, history):
        """Calculate dynamic context window size"""

        def count_tokens(text):
            """Calculate token count for text"""
            # Simple calculation: 1 token per ASCII character,
            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
            total = 0
            for char in text:
                if ord(char) < 128:  # ASCII characters
                    total += 1
                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
                    total += 2
            return total

        # Calculate total tokens for all messages
        total_tokens = 0
        for message in history:
            content = message.get("content", "")
            # Calculate content tokens
            content_tokens = count_tokens(content)
            # Add role marker token overhead
            role_tokens = 4
            total_tokens += content_tokens + role_tokens
        # Apply 1.2x buffer ratio
        total_tokens_with_buffer = int(total_tokens * 1.2)
        if total_tokens_with_buffer <= 8192:
            ctx_size = 8192
        else:
            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
            ctx_size = ctx_multiplier * 8192
        return ctx_size
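    # Worked example (using the rough per-character estimate above): a history of about
    # 10,000 estimated tokens becomes 12,000 after the 1.2x buffer, so the context is
    # rounded up to the next multiple of 8192: (12000 // 8192 + 1) * 8192 = 16384.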