
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import inspect
import logging
import re
from functools import partial
from typing import Generator

from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService


class LLMService(CommonService):
    model = LLM


def get_init_tenant_llm(user_id):
    """Build the initial per-tenant LLM rows from the factory configs in settings."""
    from api import settings

    tenant_llm = []

    # Keep one config per factory, preserving the order they appear in settings.
    seen = set()
    factory_configs = []
    for factory_config in [
        settings.CHAT_CFG,
        settings.EMBEDDING_CFG,
        settings.ASR_CFG,
        settings.IMAGE2TEXT_CFG,
        settings.RERANK_CFG,
    ]:
        factory_name = factory_config["factory"]
        if factory_name not in seen:
            seen.add(factory_name)
            factory_configs.append(factory_config)

    for factory_config in factory_configs:
        for llm in LLMService.query(fid=factory_config["factory"]):
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": factory_config["factory"],
                    "llm_name": llm.llm_name,
                    "model_type": llm.model_type,
                    "api_key": factory_config["api_key"],
                    "api_base": factory_config["base_url"],
                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
                }
            )

    if settings.LIGHTEN != 1:
        for builtin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
            mdlnm, fid = TenantLLMService.split_model_name_and_factory(builtin_embedding_model)
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": fid,
                    "llm_name": mdlnm,
                    "model_type": "embedding",
                    "api_key": "",
                    "api_base": "",
                    "max_tokens": 1024 if builtin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
                }
            )

    # Deduplicate on (tenant_id, llm_factory, llm_name), keeping the first occurrence.
    unique = {}
    for item in tenant_llm:
        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
        if key not in unique:
            unique[key] = item
    return list(unique.values())
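

# A minimal usage sketch (hypothetical caller; the actual tenant-registration flow
# lives in the user service). The rows are plain dicts keyed like the TenantLLM
# model, so a bulk insert through the service layer is one way to persist them:
#
#     rows = get_init_tenant_llm(user_id)
#     if rows:
#         TenantLLMService.insert_many(rows)  # assumption: CommonService exposes insert_many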


class LLMBundle(LLM4Tenant):
    """Tenant-scoped facade over one model: every call records token usage and, when configured, a Langfuse trace."""

    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)

    def bind_tools(self, toolcall_session, tools):
        if not self.is_tools:
            logging.warning(f"Model {self.llm_name} does not support tool calls, but one or more tools were assigned to it!")
            return
        self.mdl.bind_tools(toolcall_session, tools)

    def encode(self, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})

        embeddings, used_tokens = self.mdl.encode(texts)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return embeddings, used_tokens

    def encode_queries(self, query: str):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})

        emd, used_tokens = self.mdl.encode_queries(query)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return emd, used_tokens

    def similarity(self, query: str, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})

        sim, used_tokens = self.mdl.similarity(query, texts)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return sim, used_tokens

    def describe(self, image, max_tokens=300):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.describe(image)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def describe_with_prompt(self, image, prompt):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})

        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe_with_prompt can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def transcription(self, audio):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.transcription(audio)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def tts(self, text: str) -> Generator[bytes, None, None]:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})

        for chunk in self.mdl.tts(text):
            # The underlying model yields audio bytes, then the token count as a final int.
            if isinstance(chunk, int):
                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                if self.langfuse:
                    generation.end()
                return
            yield chunk

        if self.langfuse:
            generation.end()
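
    # Consumption sketch: callers see only the raw audio bytes, since the final
    # int chunk terminates the generator, e.g.
    #     audio = b"".join(chunk for chunk in bundle.tts("hello"))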

    def _remove_reasoning_content(self, txt: str) -> str:
        first_think_start = txt.find("<think>")
        if first_think_start == -1:
            return txt

        last_think_end = txt.rfind("</think>")
        if last_think_end == -1:
            return txt

        if last_think_end < first_think_start:
            return txt

        return txt[last_think_end + len("</think>") :]
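
    # Behavior sketch (illustrative strings, not from the test suite):
    #   "<think>chain of thought</think>The answer is 42."  ->  "The answer is 42."
    #   "no reasoning markers"                              ->  unchanged
    #   "</think>stray close before <think>"                ->  unchanged (close precedes open)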

    @staticmethod
    def _clean_param(chat_partial, **kwargs):
        func = chat_partial.func
        sig = inspect.signature(func)

        keyword_args = []
        support_var_args = False
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
                support_var_args = True
            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
                keyword_args.append(param.name)

        # If the model's chat function accepts *args/**kwargs, pass everything through;
        # otherwise drop any kwarg it does not declare as keyword-only.
        use_kwargs = kwargs
        if not support_var_args:
            use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
        return use_kwargs
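
    # Example (hypothetical model signature): given
    #     def chat(self, system, history, gen_conf, *, images=None): ...
    # _clean_param(partial(mdl.chat, system, history, gen_conf), images=[...], foo=1)
    # returns {"images": [...]}: "foo" is dropped because the target neither declares
    # it keyword-only nor accepts **kwargs.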

    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)

        txt, used_tokens = chat_partial(**use_kwargs)
        txt = self._remove_reasoning_content(txt)

        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

        # txt is a str here, so guarding the usage update on isinstance(txt, int)
        # would skip accounting entirely; record the tokens unconditionally.
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        ans = ""
        total_tokens = 0
        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)

        for txt in chat_partial(**use_kwargs):
            # The model yields text chunks, then the total token count as a final int.
            if isinstance(txt, int):
                total_tokens = txt
                if self.langfuse:
                    generation.update(output={"output": ans})
                    generation.end()
                break

            if txt.endswith("</think>"):
                # Strip the literal suffix; str.rstrip("</think>") would remove any
                # trailing run of those characters, not the tag itself.
                ans = ans.removesuffix("</think>")

            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

            ans += txt
            yield ans

        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
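

# A minimal end-to-end sketch (hypothetical tenant id; LLM4Tenant resolves the
# tenant's configured default model when llm_name is omitted):
#
#     from api.db import LLMType  # assumption: llm_type values come from this enum
#
#     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
#     answer = chat_mdl.chat("You are a helpful assistant.", [{"role": "user", "content": "Hi"}], {"temperature": 0.1})
#
#     emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING)
#     vectors, tokens = emb_mdl.encode(["hello", "world"])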