#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
from functools import partial
from typing import Any, Generator

import json_repair

from agent.component.base import ComponentBase, ComponentParamBase
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in, citation_prompt
from rag.prompts.prompts import tool_call_summary


class LLMParam(ComponentParamBase):
    """
    Define the LLM component parameters.
    """

    def __init__(self):
        super().__init__()
        self.llm_id = ""
        self.sys_prompt = ""
        self.prompts = [{"role": "user", "content": "{sys.query}"}]
        self.max_tokens = 0
        self.temperature = 0
        self.top_p = 0
        self.presence_penalty = 0
        self.frequency_penalty = 0
        self.output_structure = None
        self.cite = True
        self.visual_files_var = None
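
    # Validate user-supplied parameters before the component is allowed to
    # run: sampling knobs must be numbers in range, and the model id plus
    # both prompts must be non-empty.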
    def check(self):
        self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
        self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
        self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
        self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
        self.check_decimal_float(float(self.top_p), "[Agent] Top P")
        self.check_empty(self.llm_id, "[Agent] LLM")
        self.check_empty(self.sys_prompt, "[Agent] System prompt")
        self.check_empty(self.prompts, "[Agent] User prompt")
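
    # Build the generation config handed to the chat model. A knob is only
    # forwarded when its value is positive AND the matching "...Enabled" flag
    # exists and is truthy on this object (the flags come from the front end
    # and may be absent, hence the tolerant lookup). For example, with
    # temperature=0.7 and temperatureEnabled set, gen_conf() returns
    # {"temperature": 0.7}.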
    def gen_conf(self):
        conf = {}

        def get_attr(nm):
            try:
                return getattr(self, nm)
            except Exception:
                pass

        if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
            conf["max_tokens"] = int(self.max_tokens)
        if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
            conf["temperature"] = float(self.temperature)
        if float(self.top_p) > 0 and get_attr("topPEnabled"):
            conf["top_p"] = float(self.top_p)
        if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
            conf["presence_penalty"] = float(self.presence_penalty)
        if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
            conf["frequency_penalty"] = float(self.frequency_penalty)
        return conf
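

# A minimal usage sketch (assumptions: `canvas` is a Canvas wired up by the
# agent runtime, "llm_0" is a made-up component id, and invoke() on
# ComponentBase is what ultimately calls _invoke below):
#
#   param = LLMParam()
#   param.llm_id = "deepseek-chat"   # hypothetical model id
#   param.sys_prompt = "You are a helpful assistant."
#   param.check()
#   llm = LLM(canvas, "llm_0", param)
#   llm.invoke()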


class LLM(ComponentBase):
    component_name = "LLM"

    def __init__(self, canvas, id, param: ComponentParamBase):
        super().__init__(canvas, id, param)
        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
                                  self._param.llm_id, max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error)
        self.imgs = []
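
    # Describe each template variable referenced by the prompts as a
    # one-line text field, e.g. for the debug-input form in the UI.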
    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        return res

    def get_input_elements(self) -> dict[str, Any]:
        res = self.get_input_elements_from_text(self._param.sys_prompt)
        for prompt in self._param.prompts:
            d = self.get_input_elements_from_text(prompt["content"])
            res.update(d)
        return res

    def set_debug_inputs(self, inputs: dict[str, dict]):
        self._param.debug_inputs = inputs

    def add2system_prompt(self, txt):
        self._param.sys_prompt += txt
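
    # Assemble everything a model call needs: collect vision inputs (and swap
    # to an IMAGE2TEXT bundle when images are present but a plain CHAT model
    # is configured), substitute template variables into the prompts, splice
    # in conversation history, and append the citation prompt when retrieved
    # chunks are available.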
    def _prepare_prompt_variables(self):
        if self._param.visual_files_var:
            self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
            if not self.imgs:
                self.imgs = []
            self.imgs = [img for img in self.imgs if img.startswith("data:image/")]
        if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
            self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
                                      self._param.llm_id, max_retries=self._param.max_retries,
                                      retry_interval=self._param.delay_after_error)

        args = {}
        vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
        sys_prompt = self._param.sys_prompt
        for k, o in vars.items():
            args[k] = o["value"]
            if not isinstance(args[k], str):
                try:
                    args[k] = json.dumps(args[k], ensure_ascii=False)
                except Exception:
                    args[k] = str(args[k])
            self.set_input_value(k, args[k])

        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
        for p in self._param.prompts:
            if msg and msg[-1]["role"] == p["role"]:
                continue
            msg.append(p)

        sys_prompt = self.string_format(sys_prompt, args)
        for m in msg:
            m["content"] = self.string_format(m["content"], args)

        if self._param.cite and self._canvas.get_reference()["chunks"]:
            sys_prompt += citation_prompt()
        return sys_prompt, msg
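
    # One-shot completion; passes the attached images through when present.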
    def _generate(self, msg: list[dict], **kwargs) -> str:
        if not self.imgs:
            return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
        return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
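
    # Streaming completion. The underlying chat_streamly() yields the full
    # accumulated text on every step, so delta() diffs each snapshot against
    # what was already emitted and returns only the new suffix, while passing
    # <think>...</think> reasoning markers through as discrete tokens.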
    def _generate_streamly(self, msg: list[dict], **kwargs) -> Generator[str, None, None]:
        ans = ""
        last_idx = 0
        endswith_think = False

        def delta(txt):
            nonlocal ans, last_idx, endswith_think
            delta_ans = txt[last_idx:]
            ans = txt
            if delta_ans.find("<think>") == 0:
                last_idx += len("<think>")
                return "<think>"
            elif delta_ans.find("<think>") > 0:
                # Emit only the text before the tag; resume from the tag on
                # the next call.
                pos = delta_ans.find("<think>")
                delta_ans = txt[last_idx:last_idx + pos]
                last_idx += pos
                return delta_ans
            elif delta_ans.endswith("</think>"):
                endswith_think = True
            elif endswith_think:
                endswith_think = False
                return "</think>"
            last_idx = len(ans)
            if ans.endswith("</think>"):
                last_idx -= len("</think>")
            return re.sub(r"(<think>|</think>)", "", delta_ans)

        if not self.imgs:
            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
                yield delta(txt)
        else:
            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
                yield delta(txt)
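
    # Entry point driven by the base component. Three paths:
    #   1. output_structure set: ask for JSON, retry until json_repair can
    #      parse the (cleaned) reply, and publish it as "structured_content";
    #   2. a Message component sits directly downstream: hand back a lazy
    #      streaming generator instead of a finished string;
    #   3. otherwise: blocking generation with retries, published as "content".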
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
    def _invoke(self, **kwargs):
        def clean_formated_answer(ans: str) -> str:
            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
            return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

        prompt, msg = self._prepare_prompt_variables()
        error = ""

        if self._param.output_structure:
            prompt += "\nThe output MUST follow this JSON format:\n" + json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
            prompt += "\nRedundant information is FORBIDDEN."
            for _ in range(self._param.max_retries + 1):
                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
                error = ""
                ans = self._generate(msg)
                msg.pop(0)
                if ans.find("**ERROR**") >= 0:
                    logging.error(f"LLM response error: {ans}")
                    error = ans
                    continue
                try:
                    self.set_output("structured_content", json_repair.loads(clean_formated_answer(ans)))
                    return
                except Exception:
                    msg.append({"role": "user", "content": "The answer cannot be parsed as JSON."})
                    error = "The answer cannot be parsed as JSON."
            if error:
                self.set_output("_ERROR", error)
            return
        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self._stream_output, prompt, msg))
            return

        for _ in range(self._param.max_retries + 1):
            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
            error = ""
            ans = self._generate(msg)
            msg.pop(0)
            if ans.find("**ERROR**") >= 0:
                logging.error(f"LLM response error: {ans}")
                error = ans
                continue
            self.set_output("content", ans)
            break

        if error:
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", error)
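
    # Generator handed to a downstream Message component: yields fragments as
    # they arrive and records the full answer (or the configured exception
    # fallback) as this component's "content" output.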
    def _stream_output(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer = ""
        for ans in self._generate_streamly(msg):
            if ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", ans)
                return
            yield ans
            answer += ans
        self.set_output("content", answer)

    def add_memory(self, user: str, assist: str, func_name: str, params: dict, results: str):
        summ = tool_call_summary(self.chat_mdl, func_name, params, results)
        logging.info(f"[MEMORY]: {summ}")
        self._canvas.add_memory(user, assist, summ)
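
    # Progress note surfaced in the UI while the component runs; strips the
    # "User's query:" prefix and escape backslashes from the last message.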
    def thoughts(self) -> str:
        _, msg = self._prepare_prompt_variables()
        return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."