You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

prompts.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import datetime
  17. import json
  18. import logging
  19. import re
  20. from copy import deepcopy
  21. from typing import Tuple
  22. import jinja2
  23. import json_repair
  24. from api.utils import hash_str2int
  25. from rag.prompts.prompt_template import load_prompt
  26. from rag.settings import TAG_FLD
  27. from rag.utils import encoder, num_tokens_from_string
# Sentinel the agent flows use to mark the end of a generation stream.
STOP_TOKEN="<|STOP|>"
# Name of the pseudo-tool an agent calls to deliver its final answer
# (see tool_schema(), which injects a schema under this name).
COMPLETE_TASK="complete_task"
  30. def get_value(d, k1, k2):
  31. return d.get(k1, d.get(k2))
  32. def chunks_format(reference):
  33. return [
  34. {
  35. "id": get_value(chunk, "chunk_id", "id"),
  36. "content": get_value(chunk, "content", "content_with_weight"),
  37. "document_id": get_value(chunk, "doc_id", "document_id"),
  38. "document_name": get_value(chunk, "docnm_kwd", "document_name"),
  39. "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
  40. "image_id": get_value(chunk, "image_id", "img_id"),
  41. "positions": get_value(chunk, "positions", "position_int"),
  42. "url": chunk.get("url"),
  43. "similarity": chunk.get("similarity"),
  44. "vector_similarity": chunk.get("vector_similarity"),
  45. "term_similarity": chunk.get("term_similarity"),
  46. "doc_type": chunk.get("doc_type_kwd"),
  47. }
  48. for chunk in reference.get("chunks", [])
  49. ]
  50. def message_fit_in(msg, max_length=4000):
  51. def count():
  52. nonlocal msg
  53. tks_cnts = []
  54. for m in msg:
  55. tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
  56. total = 0
  57. for m in tks_cnts:
  58. total += m["count"]
  59. return total
  60. c = count()
  61. if c < max_length:
  62. return c, msg
  63. msg_ = [m for m in msg if m["role"] == "system"]
  64. if len(msg) > 1:
  65. msg_.append(msg[-1])
  66. msg = msg_
  67. c = count()
  68. if c < max_length:
  69. return c, msg
  70. ll = num_tokens_from_string(msg_[0]["content"])
  71. ll2 = num_tokens_from_string(msg_[-1]["content"])
  72. if ll / (ll + ll2) > 0.8:
  73. m = msg_[0]["content"]
  74. m = encoder.decode(encoder.encode(m)[: max_length - ll2])
  75. msg[0]["content"] = m
  76. return max_length, msg
  77. m = msg_[-1]["content"]
  78. m = encoder.decode(encoder.encode(m)[: max_length - ll2])
  79. msg[-1]["content"] = m
  80. return max_length, msg
def kb_prompt(kbinfos, max_tokens, hash_id=False):
    """Format retrieved chunks into per-chunk knowledge sections for a prompt.

    Each section carries an ID line, a tree of metadata nodes (title, URL,
    document meta fields) and the chunk content. Chunks are dropped from the
    tail once the running token count exceeds ~97% of ``max_tokens``.

    Params:
        kbinfos: retrieval result; ``kbinfos["chunks"]`` is a list of chunk
            dicts using either internal or external field names (see get_value).
        max_tokens: token budget for the combined knowledge text.
        hash_id: when True, the ID line is a stable hash of the chunk id
            (mod 100) instead of the positional index.

    Returns a list of formatted section strings.
    """
    # Imported lazily to avoid a circular import at module load time.
    from api.db.services.document_service import DocumentService

    knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]]
    kwlg_len = len(knowledges)
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        if not c:
            # Empty content contributes nothing and is not counted.
            continue
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}")
            break
    # NOTE(review): chunks_num is incremented before the budget check, so the
    # chunk that tips the budget is still included in the formatting loop
    # below even though it was cut from `knowledges` — confirm intended.
    docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]])
    # Map document id -> its user-defined metadata fields.
    docs = {d.id: d.meta_fields for d in docs}

    def draw_node(k, line):
        # Render one "├── key: value" tree line; newlines in the value are
        # flattened to spaces. Returns "" for empty/None values.
        if line is not None and not isinstance(line, str):
            line = str(line)
        if not line:
            return ""
        return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

    # Rebuild `knowledges` as fully formatted sections.
    knowledges = []
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        cnt += draw_node("URL", ck['url']) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"
        cnt += get_value(ck, "content", "content_with_weight")
        knowledges.append(cnt)
    return knowledges
# Prompt template texts, loaded once at import time via
# rag.prompts.prompt_template.load_prompt and rendered on demand below.
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")
META_FILTER = load_prompt("meta_filter")
# Shared Jinja2 environment for all templates above. autoescape is off
# because the rendered output is plain text fed to LLMs, not HTML.
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
  134. def citation_prompt() -> str:
  135. template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
  136. return template.render()
  137. def citation_plus(sources: str) -> str:
  138. template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
  139. return template.render(example=citation_prompt(), sources=sources)
  140. def keyword_extraction(chat_mdl, content, topn=3):
  141. template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
  142. rendered_prompt = template.render(content=content, topn=topn)
  143. msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
  144. _, msg = message_fit_in(msg, chat_mdl.max_length)
  145. kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
  146. if isinstance(kwd, tuple):
  147. kwd = kwd[0]
  148. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  149. if kwd.find("**ERROR**") >= 0:
  150. return ""
  151. return kwd
  152. def question_proposal(chat_mdl, content, topn=3):
  153. template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
  154. rendered_prompt = template.render(content=content, topn=topn)
  155. msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
  156. _, msg = message_fit_in(msg, chat_mdl.max_length)
  157. kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
  158. if isinstance(kwd, tuple):
  159. kwd = kwd[0]
  160. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  161. if kwd.find("**ERROR**") >= 0:
  162. return ""
  163. return kwd
  164. def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
  165. from api.db import LLMType
  166. from api.db.services.llm_service import LLMBundle
  167. from api.db.services.tenant_llm_service import TenantLLMService
  168. if not chat_mdl:
  169. if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
  170. chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
  171. else:
  172. chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
  173. conv = []
  174. for m in messages:
  175. if m["role"] not in ["user", "assistant"]:
  176. continue
  177. conv.append("{}: {}".format(m["role"].upper(), m["content"]))
  178. conversation = "\n".join(conv)
  179. today = datetime.date.today().isoformat()
  180. yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
  181. tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
  182. template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
  183. rendered_prompt = template.render(
  184. today=today,
  185. yesterday=yesterday,
  186. tomorrow=tomorrow,
  187. conversation=conversation,
  188. language=language,
  189. )
  190. ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
  191. ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  192. return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
  193. def cross_languages(tenant_id, llm_id, query, languages=[]):
  194. from api.db import LLMType
  195. from api.db.services.llm_service import LLMBundle
  196. from api.db.services.tenant_llm_service import TenantLLMService
  197. if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
  198. chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
  199. else:
  200. chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
  201. rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
  202. rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
  203. ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
  204. ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  205. if ans.find("**ERROR**") >= 0:
  206. return query
  207. return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
  208. def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
  209. template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
  210. for ex in examples:
  211. ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)
  212. rendered_prompt = template.render(
  213. topn=topn,
  214. all_tags=all_tags,
  215. examples=examples,
  216. content=content,
  217. )
  218. msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
  219. _, msg = message_fit_in(msg, chat_mdl.max_length)
  220. kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
  221. if isinstance(kwd, tuple):
  222. kwd = kwd[0]
  223. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  224. if kwd.find("**ERROR**") >= 0:
  225. raise Exception(kwd)
  226. try:
  227. obj = json_repair.loads(kwd)
  228. except json_repair.JSONDecodeError:
  229. try:
  230. result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
  231. result = "{" + result.split("{")[1].split("}")[0] + "}"
  232. obj = json_repair.loads(result)
  233. except Exception as e:
  234. logging.exception(f"JSON parsing error: {result} -> {e}")
  235. raise e
  236. res = {}
  237. for k, v in obj.items():
  238. try:
  239. if int(v) > 0:
  240. res[str(k)] = int(v)
  241. except Exception:
  242. pass
  243. return res
  244. def vision_llm_describe_prompt(page=None) -> str:
  245. template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
  246. return template.render(page=page)
  247. def vision_llm_figure_describe_prompt() -> str:
  248. template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
  249. return template.render()
  250. def tool_schema(tools_description: list[dict], complete_task=False):
  251. if not tools_description:
  252. return ""
  253. desc = {}
  254. if complete_task:
  255. desc[COMPLETE_TASK] = {
  256. "type": "function",
  257. "function": {
  258. "name": COMPLETE_TASK,
  259. "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
  260. "parameters": {
  261. "type": "object",
  262. "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
  263. "required": ["answer"]
  264. }
  265. }
  266. }
  267. for tool in tools_description:
  268. desc[tool["function"]["name"]] = tool
  269. return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
  270. def form_history(history, limit=-6):
  271. context = ""
  272. for h in history[limit:]:
  273. if h["role"] == "system":
  274. continue
  275. role = "USER"
  276. if h["role"].upper()!= role:
  277. role = "AGENT"
  278. context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
  279. return context
  280. def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
  281. tools_desc = tool_schema(tools_description)
  282. context = ""
  283. template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
  284. context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
  285. kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": context}], {})
  286. if isinstance(kwd, tuple):
  287. kwd = kwd[0]
  288. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  289. if kwd.find("**ERROR**") >= 0:
  290. return ""
  291. return kwd
  292. def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
  293. if not tools_description:
  294. return ""
  295. desc = tool_schema(tools_description)
  296. template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
  297. user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
  298. hist = deepcopy(history)
  299. if hist[-1]["role"] == "user":
  300. hist[-1]["content"] += user_prompt
  301. else:
  302. hist.append({"role": "user", "content": user_prompt})
  303. json_str = chat_mdl.chat(template.render(task_analisys=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
  304. hist[1:], stop=["<|stop|>"])
  305. tk_cnt = num_tokens_from_string(json_str)
  306. json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
  307. return json_str, tk_cnt
  308. def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
  309. tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
  310. goal = history[1]["content"]
  311. template = PROMPT_JINJA_ENV.from_string(REFLECT)
  312. user_prompt = template.render(goal=goal, tool_calls=tool_calls)
  313. hist = deepcopy(history)
  314. if hist[-1]["role"] == "user":
  315. hist[-1]["content"] += user_prompt
  316. else:
  317. hist.append({"role": "user", "content": user_prompt})
  318. _, msg = message_fit_in(hist, chat_mdl.max_length)
  319. ans = chat_mdl.chat(msg[0]["content"], msg[1:])
  320. ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  321. return """
  322. **Observation**
  323. {}
  324. **Reflection**
  325. {}
  326. """.format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
  327. def form_message(system_prompt, user_prompt):
  328. return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
  329. def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
  330. template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
  331. system_prompt = template.render(name=name,
  332. params=json.dumps(params, ensure_ascii=False, indent=2),
  333. result=result)
  334. user_prompt = "→ Summary: "
  335. _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
  336. ans = chat_mdl.chat(msg[0]["content"], msg[1:])
  337. return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  338. def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str]):
  339. template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
  340. system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
  341. user_prompt = " → rank: "
  342. _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
  343. ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
  344. return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  345. def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
  346. sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
  347. current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
  348. metadata_keys=json.dumps(meta_data),
  349. user_question=query
  350. )
  351. user_prompt = "Generate filters:"
  352. ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
  353. ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
  354. try:
  355. ans = json_repair.loads(ans)
  356. assert isinstance(ans, list), ans
  357. return ans
  358. except Exception:
  359. logging.exception(f"Loading json failure: {ans}")
  360. return []