import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant
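

# As the module path suggests, "backwards invocation" is a plugin calling back
# into the host app: the plugin asks the host to run a model on its behalf, and
# the host resolves the tenant's provider/model settings and enforces quota.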
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        invoke llm
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )
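
        # The runtime returns a Generator of chunks when streaming, or a full
        # LLMResult otherwise; both paths deduct the tenant's LLM quota, and
        # prompt messages are cleared from chunks before they go back to the
        # plugin (the caller already has them).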
        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    chunk.prompt_messages = []
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            # wrap the non-streaming result in a single-chunk generator so the
            # plugin side can consume both cases uniformly
            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=[],
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        invoke text embedding
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        invoke rerank
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        invoke tts
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )
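
        # TTS yields raw audio bytes; hex-encode each chunk so it can travel in
        # a JSON-serializable payload (the plugin side is expected to unhexlify it)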
        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        invoke speech2text
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
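        # payload.file arrives as a hex string (the inverse of the TTS encoding
        # above); decode it into a temporary file because invoke_speech2text
        # expects a file object, and rewind it before handing it over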
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        invoke moderation
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        get system model max tokens
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        get prompt tokens
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
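
    # The two helpers above and invoke_system_model below go through
    # ModelInvocationUtils, which (as the names suggest) targets the tenant's
    # configured system model rather than a caller-specified one.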
    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        invoke system model
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        invoke summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher. You are interested in language,
can quickly find the main points of a webpage, and can reproduce them in your own words while
retaining the original meaning and keeping the key points.
However, the text you receive may be too long, and may be only a part of the full text.
Please summarize the text you are given.
Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content
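
        # Inner helpers: get_prompt_tokens (shadowing the classmethod of the
        # same name) counts tokens for the full system + user prompt, and
        # summarize makes one LLM call per chunk.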
        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )
            assert isinstance(summary.message.content, str)
            return summary.message.content

        lines = content.split("\n")
        new_lines: list[str] = []

        # split overlong lines into chunks; character counts serve as a cheap
        # proxy for tokens before falling back to real token counting
        for line in lines:
            if not line.strip():
                continue
            if len(line) < max_tokens * 0.5:
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages that stay within the token budget: extend
        # the current message while it still fits, otherwise start a new one
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                messages.append(line)
            else:
                messages[-1] += line

        summaries: list[str] = []
        for message in messages:
            summary = summarize(message)
            summaries.append(summary)

        result = "\n".join(summaries)

        # if the combined summaries are still too long, recursively summarize again
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result
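

# A minimal usage sketch (hypothetical identifiers; tenant and long_text are
# assumed to exist in the caller's scope):
#
#   result = PluginModelBackwardsInvocation.invoke_summary(
#       user_id="user-123",
#       tenant=tenant,
#       payload=RequestInvokeSummary(text=long_text, instruction="Focus on the API changes"),
#   )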