
llm_service.py 19KB

Feat: Support tool calling in Generate component (#7572)

### What problem does this PR solve?

Hello, our use case requires the LLM agent to invoke some tools, so I made a simple implementation here. This PR does two things:

1. A simple plugin mechanism based on `pluginlib`:

   This mechanism lives in the `plugin` directory. It will only load plugins from `plugin/embedded_plugins` for now. A sample plugin `bad_calculator.py` is placed in `plugin/embedded_plugins/llm_tools`; it accepts two numbers `a` and `b`, then gives the wrong result `a + b + 100`. In the future, it can load plugins from external locations with little code change.

   Plugins are divided into different types. The only plugin type supported in this PR is `llm_tools`, which must implement the `LLMToolPlugin` class in `plugin/llm_tool_plugin.py`. More plugin types can be added in the future. (A sketch of such a plugin follows this entry.)

2. A tool selector in the `Generate` component:

   Added a tool selector to select one or more tools for the LLM:

   ![image](https://github.com/user-attachments/assets/74a21fdf-9333-4175-991b-43df6524c5dc)

   And with the `bad_calculator` tool, the `qwen-max` model produces this:

   ![image](https://github.com/user-attachments/assets/93aff9c4-8550-414a-90a2-1a15a5249d94)

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
5 months ago
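For illustration, a minimal sketch of what an `llm_tools` plugin in the spirit of `bad_calculator.py` could look like. The real contract is whatever `LLMToolPlugin` in `plugin/llm_tool_plugin.py` defines; the hook names `get_metadata` and `invoke` here are assumptions for illustration, not the PR's actual API:

```
# Hypothetical sketch only: the hook names below (get_metadata, invoke) are
# assumptions; the real contract is defined by LLMToolPlugin in
# plugin/llm_tool_plugin.py.
from plugin.llm_tool_plugin import LLMToolPlugin


class BadCalculatorPlugin(LLMToolPlugin):
    """A deliberately wrong calculator: useful for verifying that the model
    really called the tool rather than doing the arithmetic itself."""

    @classmethod
    def get_metadata(cls) -> dict:
        # Assumed hook: describes the tool so it can be offered to the LLM.
        return {
            "name": "bad_calculator",
            "description": "Adds two numbers a and b (intentionally off by 100).",
            "parameters": {"a": "first number", "b": "second number"},
        }

    def invoke(self, a: int, b: int) -> str:
        # Assumed hook: executed when the model emits a matching tool call.
        return str(a + b + 100)  # wrong on purpose, per the PR description
```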
Fix: pymysql.err.InterfaceError: (0, '') during long streaming chat responses (#6548) (#7057)

### Related Issue:

https://github.com/infiniflow/ragflow/issues/6548

### Related PR:

https://github.com/infiniflow/ragflow/pull/6861

### Environment:

Commit version: [48730e0](https://github.com/infiniflow/ragflow/commit/48730e00a864606606a9d0778620d75411488740)

### Bug Description:

Unexpected `pymysql.err.InterfaceError: (0, '')` when using Peewee + PyMySQL + PooledMySQLDatabase after a long-running `chat_streamly` operation. This is a common issue with Peewee + PyMySQL + connection pooling: you end up using a connection that was silently closed by the server, but Peewee doesn't realize it's dead.

**I found that the error only occurs during longer streaming outputs** and is unrelated to the database connection context, so it's likely because:

- The prolonged streaming response caused the database connection to time out
- The original database connection might have been disconnected by the server during the streaming process

### Why This Happens

This error happens even when using `@DB.connection_context()` after the stream is done. After investigation, I found this is caused by pooled MySQL connections that appear to be open but are actually dead (expired due to `wait_timeout`).

1. `@DB.connection_context()` (as a decorator or context manager) pulls a connection from the pool.
2. If this connection was idle and expired on the MySQL server (e.g., due to `wait_timeout`), but not closed in Python, it will still be considered "open" (`DB.is_closed() == False`).
3. The real error occurs only when a SQL command is executed (such as `.get_or_none()`) and PyMySQL tries to send it to the server over the broken socket.

### Changes Made:

1. Implemented a manual connection check before executing SQL (a pool-level alternative is sketched after this entry):

```
try:
    DB.execute_sql("SELECT 1")
except Exception:
    print("Connection dead, reconnecting...")
    DB.close()
    DB.connect()
```

2. Delayed the token count update until after the streaming response is completed, to ensure the streaming output isn't interrupted by database operations.

```
total_tokens = 0
for txt in chat_streamly(system, history, gen_conf):
    if isinstance(txt, int):
        total_tokens = txt
        ......
        break
    ......

if total_tokens > 0:
    if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
        logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
```
6 months ago
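As a point of reference, peewee's `playhouse` ships a `ReconnectMixin` that retries a query once when the server has silently dropped the connection. A minimal sketch of that pool-level alternative to the manual `SELECT 1` ping (database name and credentials are placeholders):

```
# Sketch: transparent reconnection for stale pooled MySQL connections,
# instead of pinging with "SELECT 1" before every query.
from playhouse.pool import PooledMySQLDatabase
from playhouse.shortcuts import ReconnectMixin


class ReconnectPooledMySQLDatabase(ReconnectMixin, PooledMySQLDatabase):
    """Pooled MySQL database that transparently reconnects on
    'server has gone away'-style errors caused by wait_timeout."""


DB = ReconnectPooledMySQLDatabase(
    "ragflow",             # placeholder database name
    user="user",           # placeholder credentials
    password="password",
    host="localhost",
    port=3306,
    max_connections=32,
    stale_timeout=300,     # recycle connections idle longer than this (seconds)
)
```

Whether the mixin's default reconnect-error patterns cover `InterfaceError: (0, '')` is worth verifying; the explicit ping above is the more conservative fix.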
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from functools import partial
from typing import Generator

from langfuse import Langfuse

from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel


class LLMFactoriesService(CommonService):
    model = LLMFactories


class LLMService(CommonService):
    model = LLM


class TenantLLMService(CommonService):
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
        if not fid:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
        else:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)

        if (not objs) and fid:
            if fid == "LocalAI":
                mdlnm += "___LocalAI"
            elif fid == "HuggingFace":
                mdlnm += "___HuggingFace"
            elif fid == "OpenAI-API-Compatible":
                mdlnm += "___OpenAI-API"
            elif fid == "VLLM":
                mdlnm += "___VLLM"
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)

        if not objs:
            return
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
        return list(objs)

    @staticmethod
    def split_model_name_and_factory(model_name):
        arr = model_name.split("@")
        if len(arr) < 2:
            return model_name, None
        if len(arr) > 2:
            return "@".join(arr[0:-1]), arr[-1]

        # model name must be xxx@yyy
        try:
            model_factories = settings.FACTORY_LLM_INFOS
            model_providers = set([f["name"] for f in model_factories])
            if arr[-1] not in model_providers:
                return model_name, None
            return arr[0], arr[-1]
        except Exception as e:
            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
        return model_name, None
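    # Illustrative examples (added for clarity; assumes "Tongyi-Qianwen" is a
    # registered factory name in settings.FACTORY_LLM_INFOS):
    #   split_model_name_and_factory("qwen-max@Tongyi-Qianwen") -> ("qwen-max", "Tongyi-Qianwen")
    #   split_model_name_and_factory("gpt-4o") -> ("gpt-4o", None)
    #   split_model_name_and_factory("model@not-a-factory") -> ("model@not-a-factory", None)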
    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            assert False, "LLM type error"

        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:  # in some cases the stored name and factory id mismatch
            model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # in some cases the stored name and factory id mismatch
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools
        if not model_config:
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config
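    # Illustrative shape of the returned config (added for clarity; the exact
    # fields come from TenantLLM.to_dict() and may include more than shown):
    #   {"llm_factory": "Tongyi-Qianwen", "api_key": "sk-...", "llm_name": "qwen-max",
    #    "api_base": "", "is_tools": True, ...}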
    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])

        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
            )
    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            logging.error(f"Tenant not found: {tenant_id}")
            return 0

        llm_map = {
            LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
            LLMType.SPEECH2TEXT.value: tenant.asr_id,
            LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
            LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
            LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
        }

        mdlnm = llm_map.get(llm_type)
        if mdlnm is None:
            logging.error(f"LLM type error: {llm_type}")
            return 0

        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

        try:
            num = (
                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                .execute()
            )
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception, failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
            return 0

        return num
    @classmethod
    @DB.connection_context()
    def get_openai_models(cls):
        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
        return list(objs)

    @staticmethod
    def llm_id2llm_type(llm_id: str) -> str | None:
        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
        llm_factories = settings.FACTORY_LLM_INFOS
        for llm_factory in llm_factories:
            for llm in llm_factory["llm"]:
                if llm_id == llm["llm_name"]:
                    return llm["model_type"].split(",")[-1]
class LLMBundle:
    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        self.max_length = model_config.get("max_tokens", 8192)
        self.is_tools = model_config.get("is_tools", False)
        self.verbose_tool_use = kwargs.get("verbose_tool_use")

        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
        self.langfuse = None
        if langfuse_keys:
            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
            if langfuse.auth_check():
                self.langfuse = langfuse
                trace_id = self.langfuse.create_trace_id()
                self.trace_context = {"trace_id": trace_id}
    def bind_tools(self, toolcall_session, tools):
        if not self.is_tools:
            logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
            return
        self.mdl.bind_tools(toolcall_session, tools)

    def encode(self, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})

        embeddings, used_tokens = self.mdl.encode(texts)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return embeddings, used_tokens

    def encode_queries(self, query: str):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})

        emd, used_tokens = self.mdl.encode_queries(query)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return emd, used_tokens

    def similarity(self, query: str, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})

        sim, used_tokens = self.mdl.similarity(query, texts)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return sim, used_tokens

    def describe(self, image, max_tokens=300):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.describe(image)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def describe_with_prompt(self, image, prompt):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})

        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe_with_prompt can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def transcription(self, audio):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.transcription(audio)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def tts(self, text: str) -> Generator[bytes, None, None]:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})

        for chunk in self.mdl.tts(text):
            if isinstance(chunk, int):
                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                return
            yield chunk

        if self.langfuse:
            generation.end()

    def _remove_reasoning_content(self, txt: str) -> str:
        first_think_start = txt.find("<think>")
        if first_think_start == -1:
            return txt
        last_think_end = txt.rfind("</think>")
        if last_think_end == -1:
            return txt
        if last_think_end < first_think_start:
            return txt
        return txt[last_think_end + len("</think>") :]
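    # Illustrative example (added for clarity): everything up to and including
    # the last complete <think>...</think> block is stripped.
    #   _remove_reasoning_content("<think>plan steps</think>The answer is 4.") -> "The answer is 4."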
    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)

        txt, used_tokens = chat_partial(**kwargs)
        txt = self._remove_reasoning_content(txt)

        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        ans = ""
        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
        total_tokens = 0
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)

        for txt in chat_partial(**kwargs):
            if isinstance(txt, int):
                total_tokens = txt
                if self.langfuse:
                    generation.update(output={"output": ans})
                    generation.end()
                break

            if txt.endswith("</think>"):
                ans = ans.rstrip("</think>")

            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

            ans += txt
            yield ans

        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
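For orientation, a minimal usage sketch of `LLMBundle` (added for illustration; the tenant id and model name are placeholders, and the import path is assumed from the repository layout):

```
# Sketch only: tenant id and model name are placeholders.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

bundle = LLMBundle("tenant-0001", LLMType.CHAT.value, llm_name="qwen-max@Tongyi-Qianwen")

# Non-streaming: returns the final text; token usage is recorded via
# TenantLLMService.increase_usage internally.
answer = bundle.chat(
    "You are a helpful assistant.",
    [{"role": "user", "content": "What is 2 + 2?"}],
    {"temperature": 0.2},
)
print(answer)

# Streaming: yields the progressively accumulated answer, then records
# token usage once the model reports the final count.
for partial_answer in bundle.chat_streamly(
    "You are a helpful assistant.",
    [{"role": "user", "content": "Tell me a short story."}],
):
    print(partial_answer)
```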