Du kannst nicht mehr als 25 Themen auswählen. Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

llm_service.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import inspect
  17. import logging
  18. import re
  19. from functools import partial
  20. from typing import Generator
  21. from langfuse import Langfuse
  22. from api import settings
  23. from api.db import LLMType
  24. from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
  25. from api.db.services.common_service import CommonService
  26. from api.db.services.langfuse_service import TenantLangfuseService
  27. from api.db.services.user_service import TenantService
  28. from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService):
    """CommonService subclass bound to the ``LLMFactories`` model."""

    # Model class this service operates on.
    model = LLMFactories
class LLMService(CommonService):
    """CommonService subclass bound to the ``LLM`` model."""

    # Model class this service operates on.
    model = LLM
  33. class TenantLLMService(CommonService):
  34. model = TenantLLM
  35. @classmethod
  36. @DB.connection_context()
  37. def get_api_key(cls, tenant_id, model_name):
  38. mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
  39. if not fid:
  40. objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
  41. else:
  42. objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
  43. if (not objs) and fid:
  44. if fid == "LocalAI":
  45. mdlnm += "___LocalAI"
  46. elif fid == "HuggingFace":
  47. mdlnm += "___HuggingFace"
  48. elif fid == "OpenAI-API-Compatible":
  49. mdlnm += "___OpenAI-API"
  50. elif fid == "VLLM":
  51. mdlnm += "___VLLM"
  52. objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
  53. if not objs:
  54. return
  55. return objs[0]
  56. @classmethod
  57. @DB.connection_context()
  58. def get_my_llms(cls, tenant_id):
  59. fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
  60. objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
  61. return list(objs)
  62. @staticmethod
  63. def split_model_name_and_factory(model_name):
  64. arr = model_name.split("@")
  65. if len(arr) < 2:
  66. return model_name, None
  67. if len(arr) > 2:
  68. return "@".join(arr[0:-1]), arr[-1]
  69. # model name must be xxx@yyy
  70. try:
  71. model_factories = settings.FACTORY_LLM_INFOS
  72. model_providers = set([f["name"] for f in model_factories])
  73. if arr[-1] not in model_providers:
  74. return model_name, None
  75. return arr[0], arr[-1]
  76. except Exception as e:
  77. logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
  78. return model_name, None
  79. @classmethod
  80. @DB.connection_context()
  81. def get_model_config(cls, tenant_id, llm_type, llm_name=None):
  82. e, tenant = TenantService.get_by_id(tenant_id)
  83. if not e:
  84. raise LookupError("Tenant not found")
  85. if llm_type == LLMType.EMBEDDING.value:
  86. mdlnm = tenant.embd_id if not llm_name else llm_name
  87. elif llm_type == LLMType.SPEECH2TEXT.value:
  88. mdlnm = tenant.asr_id
  89. elif llm_type == LLMType.IMAGE2TEXT.value:
  90. mdlnm = tenant.img2txt_id if not llm_name else llm_name
  91. elif llm_type == LLMType.CHAT.value:
  92. mdlnm = tenant.llm_id if not llm_name else llm_name
  93. elif llm_type == LLMType.RERANK:
  94. mdlnm = tenant.rerank_id if not llm_name else llm_name
  95. elif llm_type == LLMType.TTS:
  96. mdlnm = tenant.tts_id if not llm_name else llm_name
  97. else:
  98. assert False, "LLM type error"
  99. model_config = cls.get_api_key(tenant_id, mdlnm)
  100. mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
  101. if not model_config: # for some cases seems fid mismatch
  102. model_config = cls.get_api_key(tenant_id, mdlnm)
  103. if model_config:
  104. model_config = model_config.to_dict()
  105. llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
  106. if not llm and fid: # for some cases seems fid mismatch
  107. llm = LLMService.query(llm_name=mdlnm)
  108. if llm:
  109. model_config["is_tools"] = llm[0].is_tools
  110. if not model_config:
  111. if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
  112. llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
  113. if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
  114. model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
  115. if not model_config:
  116. if mdlnm == "flag-embedding":
  117. model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
  118. else:
  119. if not mdlnm:
  120. raise LookupError(f"Type of {llm_type} model is not set.")
  121. raise LookupError("Model({}) not authorized".format(mdlnm))
  122. return model_config
  123. @classmethod
  124. @DB.connection_context()
  125. def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
  126. model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
  127. kwargs.update({"provider": model_config["llm_factory"]})
  128. if llm_type == LLMType.EMBEDDING.value:
  129. if model_config["llm_factory"] not in EmbeddingModel:
  130. return
  131. return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
  132. if llm_type == LLMType.RERANK:
  133. if model_config["llm_factory"] not in RerankModel:
  134. return
  135. return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
  136. if llm_type == LLMType.IMAGE2TEXT.value:
  137. if model_config["llm_factory"] not in CvModel:
  138. return
  139. return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
  140. if llm_type == LLMType.CHAT.value:
  141. if model_config["llm_factory"] not in ChatModel:
  142. return
  143. return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
  144. if llm_type == LLMType.SPEECH2TEXT:
  145. if model_config["llm_factory"] not in Seq2txtModel:
  146. return
  147. return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
  148. if llm_type == LLMType.TTS:
  149. if model_config["llm_factory"] not in TTSModel:
  150. return
  151. return TTSModel[model_config["llm_factory"]](
  152. model_config["api_key"],
  153. model_config["llm_name"],
  154. base_url=model_config["api_base"],
  155. )
  156. @classmethod
  157. @DB.connection_context()
  158. def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
  159. e, tenant = TenantService.get_by_id(tenant_id)
  160. if not e:
  161. logging.error(f"Tenant not found: {tenant_id}")
  162. return 0
  163. llm_map = {
  164. LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
  165. LLMType.SPEECH2TEXT.value: tenant.asr_id,
  166. LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
  167. LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
  168. LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
  169. LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
  170. }
  171. mdlnm = llm_map.get(llm_type)
  172. if mdlnm is None:
  173. logging.error(f"LLM type error: {llm_type}")
  174. return 0
  175. llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
  176. try:
  177. num = (
  178. cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
  179. .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
  180. .execute()
  181. )
  182. except Exception:
  183. logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
  184. return 0
  185. return num
  186. @classmethod
  187. @DB.connection_context()
  188. def get_openai_models(cls):
  189. objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
  190. return list(objs)
  191. @staticmethod
  192. def llm_id2llm_type(llm_id: str) -> str | None:
  193. llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
  194. llm_factories = settings.FACTORY_LLM_INFOS
  195. for llm_factory in llm_factories:
  196. for llm in llm_factory["llm"]:
  197. if llm_id == llm["llm_name"]:
  198. return llm["model_type"].split(",")[-1]
  199. for llm in LLMService.query(llm_name=llm_id):
  200. return llm.model_type
  201. llm = TenantLLMService.get_or_none(llm_name=llm_id)
  202. if llm:
  203. return llm.model_type
  204. for llm in TenantLLMService.query(llm_name=llm_id):
  205. return llm.model_type
  206. class LLMBundle:
  207. def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
  208. self.tenant_id = tenant_id
  209. self.llm_type = llm_type
  210. self.llm_name = llm_name
  211. self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
  212. assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
  213. model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
  214. self.max_length = model_config.get("max_tokens", 8192)
  215. self.is_tools = model_config.get("is_tools", False)
  216. self.verbose_tool_use = kwargs.get("verbose_tool_use")
  217. langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
  218. self.langfuse = None
  219. if langfuse_keys:
  220. langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
  221. if langfuse.auth_check():
  222. self.langfuse = langfuse
  223. trace_id = self.langfuse.create_trace_id()
  224. self.trace_context = {"trace_id": trace_id}
  225. def bind_tools(self, toolcall_session, tools):
  226. if not self.is_tools:
  227. logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
  228. return
  229. self.mdl.bind_tools(toolcall_session, tools)
  230. def encode(self, texts: list):
  231. if self.langfuse:
  232. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
  233. embeddings, used_tokens = self.mdl.encode(texts)
  234. llm_name = getattr(self, "llm_name", None)
  235. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
  236. logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
  237. if self.langfuse:
  238. generation.update(usage_details={"total_tokens": used_tokens})
  239. generation.end()
  240. return embeddings, used_tokens
  241. def encode_queries(self, query: str):
  242. if self.langfuse:
  243. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
  244. emd, used_tokens = self.mdl.encode_queries(query)
  245. llm_name = getattr(self, "llm_name", None)
  246. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
  247. logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
  248. if self.langfuse:
  249. generation.update(usage_details={"total_tokens": used_tokens})
  250. generation.end()
  251. return emd, used_tokens
  252. def similarity(self, query: str, texts: list):
  253. if self.langfuse:
  254. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
  255. sim, used_tokens = self.mdl.similarity(query, texts)
  256. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
  257. logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
  258. if self.langfuse:
  259. generation.update(usage_details={"total_tokens": used_tokens})
  260. generation.end()
  261. return sim, used_tokens
  262. def describe(self, image, max_tokens=300):
  263. if self.langfuse:
  264. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
  265. txt, used_tokens = self.mdl.describe(image)
  266. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
  267. logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
  268. if self.langfuse:
  269. generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
  270. generation.end()
  271. return txt
  272. def describe_with_prompt(self, image, prompt):
  273. if self.langfuse:
  274. generation = self.language.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
  275. txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
  276. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
  277. logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
  278. if self.langfuse:
  279. generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
  280. generation.end()
  281. return txt
  282. def transcription(self, audio):
  283. if self.langfuse:
  284. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
  285. txt, used_tokens = self.mdl.transcription(audio)
  286. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
  287. logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
  288. if self.langfuse:
  289. generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
  290. generation.end()
  291. return txt
  292. def tts(self, text: str) -> Generator[bytes, None, None]:
  293. if self.langfuse:
  294. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
  295. for chunk in self.mdl.tts(text):
  296. if isinstance(chunk, int):
  297. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
  298. logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
  299. return
  300. yield chunk
  301. if self.langfuse:
  302. generation.end()
  303. def _remove_reasoning_content(self, txt: str) -> str:
  304. first_think_start = txt.find("<think>")
  305. if first_think_start == -1:
  306. return txt
  307. last_think_end = txt.rfind("</think>")
  308. if last_think_end == -1:
  309. return txt
  310. if last_think_end < first_think_start:
  311. return txt
  312. return txt[last_think_end + len("</think>") :]
  313. @staticmethod
  314. def _clean_param(chat_partial, **kwargs):
  315. func = chat_partial.func
  316. sig = inspect.signature(func)
  317. keyword_args = []
  318. support_var_args = False
  319. for param in sig.parameters.values():
  320. if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
  321. support_var_args = True
  322. elif param.kind == inspect.Parameter.KEYWORD_ONLY:
  323. keyword_args.append(param.name)
  324. use_kwargs = kwargs
  325. if not support_var_args:
  326. use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
  327. return use_kwargs
  328. def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
  329. if self.langfuse:
  330. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
  331. chat_partial = partial(self.mdl.chat, system, history, gen_conf)
  332. if self.is_tools and self.mdl.is_tools:
  333. chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
  334. use_kwargs = self._clean_param(chat_partial, **kwargs)
  335. txt, used_tokens = chat_partial(**use_kwargs)
  336. txt = self._remove_reasoning_content(txt)
  337. if not self.verbose_tool_use:
  338. txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
  339. if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
  340. logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
  341. if self.langfuse:
  342. generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
  343. generation.end()
  344. return txt
  345. def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
  346. if self.langfuse:
  347. generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
  348. ans = ""
  349. chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
  350. total_tokens = 0
  351. if self.is_tools and self.mdl.is_tools:
  352. chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
  353. use_kwargs = self._clean_param(chat_partial, **kwargs)
  354. for txt in chat_partial(**use_kwargs):
  355. if isinstance(txt, int):
  356. total_tokens = txt
  357. if self.langfuse:
  358. generation.update(output={"output": ans})
  359. generation.end()
  360. break
  361. if txt.endswith("</think>"):
  362. ans = ans.rstrip("</think>")
  363. if not self.verbose_tool_use:
  364. txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
  365. ans += txt
  366. yield ans
  367. if total_tokens > 0:
  368. if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
  369. logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))