
model.py

import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
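    """
    Backwards-invocation entry points that let plugins call models (LLM, text
    embedding, rerank, TTS, speech-to-text, moderation, summary) through the
    host application.
    """
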
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM and deduct the tenant's quota based on the reported usage.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model; default to streaming when the payload does not specify
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )
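
        # invoke_llm returns a chunk generator when streaming and a single LLMResult
        # otherwise; both branches below deduct quota and clear prompt_messages so the
        # full prompt is not echoed back to the plugin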
        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        llm_utils.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    chunk.prompt_messages = []
                    yield chunk

            return handle()
        else:
            if response.usage:
                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
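
            # wrap the blocking result in a single chunk so callers can consume
            # streaming and non-streaming responses through one code path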
            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=[],
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text embedding model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a text-to-speech model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )
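
        # TTS yields raw audio bytes; hex-encode each chunk so it stays JSON-safe on
        # its way back to the plugin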
        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )
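
        # the audio arrives hex-encoded from the plugin; decode it into a temporary
        # file, since invoke_speech2text expects a file-like object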
        # invoke model
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        Get the max context tokens of the tenant's system LLM.
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        Count the tokens of the given prompt messages.
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Invoke the tenant's system model on behalf of a plugin.
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        Summarize text that may exceed the model's context window, chunking and
        recursing as needed.
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher. You can quickly identify the main
points of a webpage and reproduce them in your own words while retaining the original meaning and
key points. However, the text you received is too long and may be only part of the full text.
Please summarize the text you received.
Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""
        # if the content already fits comfortably within the context window, return it unchanged
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content

        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )
            assert isinstance(summary.message.content, str)
            return summary.message.content
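
        # chunking: split on newlines, break oversized lines apart, then greedily pack
        # lines into messages that stay within the token budget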
        lines = content.split("\n")
        new_lines: list[str] = []

        # split long lines into multiple lines
        for line in lines:
            if not line.strip():
                continue
            if len(line) < max_tokens * 0.5:
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages with max tokens
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                # the current message is full, so start a new one
                messages.append(line)
            else:
                messages[-1] += line
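
        # summarize each chunk independently; the 0.5-0.7 factors are heuristic margins,
        # presumably leaving headroom for the system prompt and the model's output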
        summaries = []
        for message in messages:
            summaries.append(summarize(message))

        result = "\n".join(summaries)

        # if the merged summary is still too long, summarize it again recursively
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result