import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMResultChunkWithStructuredOutput,
    LLMResultWithStructuredOutput,
)
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeLLMWithStructuredOutput,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        invoke llm
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        llm_utils.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    chunk.prompt_messages = []
                    yield chunk

            return handle()
        else:
            if response.usage:
                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=[],
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)
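
    # Usage sketch (illustrative; `user_id`, `tenant` and `payload` are assumed caller-side names):
    # both branches above return a generator of LLMResultChunk, so a hypothetical caller could
    # stream the deltas directly, assuming the delta message content is plain text:
    #
    #     for chunk in PluginModelBackwardsInvocation.invoke_llm(user_id, tenant, payload):
    #         if isinstance(chunk.delta.message.content, str):
    #             print(chunk.delta.message.content, end="")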

    @classmethod
    def invoke_llm_with_structured_output(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
    ):
        """
        invoke llm with structured output
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
        if not model_schema:
            raise ValueError(f"Model schema not found for {payload.model}")

        response = invoke_llm_with_structured_output(
            provider=payload.provider,
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=payload.prompt_messages,
            json_schema=payload.structured_output_schema,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
            model_parameters=payload.completion_params,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        llm_utils.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    chunk.prompt_messages = []
                    yield chunk

            return handle()
        else:
            if response.usage:
                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            def handle_non_streaming(
                response: LLMResultWithStructuredOutput,
            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
                yield LLMResultChunkWithStructuredOutput(
                    model=response.model,
                    prompt_messages=[],
                    system_fingerprint=response.system_fingerprint,
                    structured_output=response.structured_output,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        invoke text embedding
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        invoke rerank
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        invoke tts
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()
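
    # Usage sketch (illustrative; `user_id`, `tenant` and `payload` are assumed caller-side names):
    # invoke_tts yields dicts whose "result" field is a hex-encoded audio chunk, so a hypothetical
    # consumer could reassemble the audio bytes with unhexlify:
    #
    #     chunks = PluginModelBackwardsInvocation.invoke_tts(user_id, tenant, payload)
    #     audio = b"".join(unhexlify(item["result"]) for item in chunks)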

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        invoke speech2text
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }
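
    # Usage sketch (illustrative): invoke_speech2text decodes payload.file with unhexlify, so a
    # hypothetical caller is assumed to hex-encode the raw audio bytes first, e.g.:
    #
    #     file_field = hexlify(audio_bytes).decode("utf-8")  # `audio_bytes` is assumed raw mp3 data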

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        invoke moderation
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        get system model max tokens
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        get prompt tokens
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        invoke system model
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        invoke summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher. You are interested in language,
and you can quickly identify the main point of a webpage and reproduce it in your own words while
retaining the original meaning and keeping the key points.
However, the text you received is too long, and it may be only a part of the full text.
Please summarize the text you received.
Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content

        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )
            assert isinstance(summary.message.content, str)
            return summary.message.content

        lines = content.split("\n")
        new_lines: list[str] = []

        # split long line into multiple lines
        for i in range(len(lines)):
            line = lines[i]
            if not line.strip():
                continue

            if len(line) < max_tokens * 0.5:
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages with max tokens
        messages: list[str] = []
        for line in new_lines:
            if len(messages) == 0:
                messages.append(line)
            else:
                # merge each line into the last chunk once, or start a new chunk if it would overflow
                if len(messages[-1]) + len(line) < max_tokens * 0.5:
                    messages[-1] += line
                elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                    messages.append(line)
                else:
                    messages[-1] += line

        summaries = []
        for i in range(len(messages)):
            message = messages[i]
            summary = summarize(message)
            summaries.append(summary)

        result = "\n".join(summaries)
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result
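
    # Usage sketch (illustrative; `user_id`, `tenant` and `long_text` are assumed caller-side names):
    # the recursive chunk-and-merge summary above is driven with a RequestInvokeSummary payload, e.g.:
    #
    #     summary = PluginModelBackwardsInvocation.invoke_summary(
    #         user_id=user_id,
    #         tenant=tenant,
    #         payload=RequestInvokeSummary(text=long_text, instruction="keep the key points"),
    #     )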