您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

prompts.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import datetime
  17. import json
  18. import logging
  19. import re
  20. from collections import defaultdict
  21. import jinja2
  22. import json_repair
  23. from api import settings
  24. from rag.prompt_template import load_prompt
  25. from rag.settings import TAG_FLD
  26. from rag.utils import encoder, num_tokens_from_string
  27. def chunks_format(reference):
  28. def get_value(d, k1, k2):
  29. return d.get(k1, d.get(k2))
  30. return [
  31. {
  32. "id": get_value(chunk, "chunk_id", "id"),
  33. "content": get_value(chunk, "content", "content_with_weight"),
  34. "document_id": get_value(chunk, "doc_id", "document_id"),
  35. "document_name": get_value(chunk, "docnm_kwd", "document_name"),
  36. "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
  37. "image_id": get_value(chunk, "image_id", "img_id"),
  38. "positions": get_value(chunk, "positions", "position_int"),
  39. "url": chunk.get("url"),
  40. "similarity": chunk.get("similarity"),
  41. "vector_similarity": chunk.get("vector_similarity"),
  42. "term_similarity": chunk.get("term_similarity"),
  43. "doc_type": chunk.get("doc_type_kwd"),
  44. }
  45. for chunk in reference.get("chunks", [])
  46. ]
  47. def llm_id2llm_type(llm_id):
  48. from api.db.services.llm_service import TenantLLMService
  49. llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
  50. llm_factories = settings.FACTORY_LLM_INFOS
  51. for llm_factory in llm_factories:
  52. for llm in llm_factory["llm"]:
  53. if llm_id == llm["llm_name"]:
  54. return llm["model_type"].strip(",")[-1]
  55. def message_fit_in(msg, max_length=4000):
  56. def count():
  57. nonlocal msg
  58. tks_cnts = []
  59. for m in msg:
  60. tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
  61. total = 0
  62. for m in tks_cnts:
  63. total += m["count"]
  64. return total
  65. c = count()
  66. if c < max_length:
  67. return c, msg
  68. msg_ = [m for m in msg if m["role"] == "system"]
  69. if len(msg) > 1:
  70. msg_.append(msg[-1])
  71. msg = msg_
  72. c = count()
  73. if c < max_length:
  74. return c, msg
  75. ll = num_tokens_from_string(msg_[0]["content"])
  76. ll2 = num_tokens_from_string(msg_[-1]["content"])
  77. if ll / (ll + ll2) > 0.8:
  78. m = msg_[0]["content"]
  79. m = encoder.decode(encoder.encode(m)[: max_length - ll2])
  80. msg[0]["content"] = m
  81. return max_length, msg
  82. m = msg_[-1]["content"]
  83. m = encoder.decode(encoder.encode(m)[: max_length - ll2])
  84. msg[-1]["content"] = m
  85. return max_length, msg
  86. def kb_prompt(kbinfos, max_tokens):
  87. from api.db.services.document_service import DocumentService
  88. knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
  89. kwlg_len = len(knowledges)
  90. used_token_count = 0
  91. chunks_num = 0
  92. for i, c in enumerate(knowledges):
  93. used_token_count += num_tokens_from_string(c)
  94. chunks_num += 1
  95. if max_tokens * 0.97 < used_token_count:
  96. knowledges = knowledges[:i]
  97. logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}")
  98. break
  99. docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
  100. docs = {d.id: d.meta_fields for d in docs}
  101. doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
  102. for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
  103. cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
  104. cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL | re.IGNORECASE)
  105. doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
  106. doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
  107. knowledges = []
  108. for nm, cks_meta in doc2chunks.items():
  109. txt = f"\nDocument: {nm} \n"
  110. for k, v in cks_meta["meta"].items():
  111. txt += f"{k}: {v}\n"
  112. txt += "Relevant fragments as following:\n"
  113. for i, chunk in enumerate(cks_meta["chunks"], 1):
  114. txt += f"{chunk}\n"
  115. knowledges.append(txt)
  116. return knowledges
# Prompt templates, loaded once at import time via rag.prompt_template.load_prompt.
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")

# Shared Jinja environment for rendering the templates above; autoescape is off
# because the output is plain-text prompts, not HTML.
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
  127. def citation_prompt() -> str:
  128. template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
  129. return template.render()
  130. def keyword_extraction(chat_mdl, content, topn=3):
  131. template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
  132. rendered_prompt = template.render(content=content, topn=topn)
  133. msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
  134. _, msg = message_fit_in(msg, chat_mdl.max_length)
  135. kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
  136. if isinstance(kwd, tuple):
  137. kwd = kwd[0]
  138. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  139. if kwd.find("**ERROR**") >= 0:
  140. return ""
  141. return kwd
  142. def question_proposal(chat_mdl, content, topn=3):
  143. template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
  144. rendered_prompt = template.render(content=content, topn=topn)
  145. msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
  146. _, msg = message_fit_in(msg, chat_mdl.max_length)
  147. kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
  148. if isinstance(kwd, tuple):
  149. kwd = kwd[0]
  150. kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
  151. if kwd.find("**ERROR**") >= 0:
  152. return ""
  153. return kwd
  154. def full_question(tenant_id, llm_id, messages, language=None):
  155. from api.db import LLMType
  156. from api.db.services.llm_service import LLMBundle
  157. if llm_id2llm_type(llm_id) == "image2text":
  158. chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
  159. else:
  160. chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
  161. conv = []
  162. for m in messages:
  163. if m["role"] not in ["user", "assistant"]:
  164. continue
  165. conv.append("{}: {}".format(m["role"].upper(), m["content"]))
  166. conversation = "\n".join(conv)
  167. today = datetime.date.today().isoformat()
  168. yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
  169. tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
  170. template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
  171. rendered_prompt = template.render(
  172. today=today,
  173. yesterday=yesterday,
  174. tomorrow=tomorrow,
  175. conversation=conversation,
  176. language=language,
  177. )
  178. ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
  179. ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  180. return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
  181. def cross_languages(tenant_id, llm_id, query, languages=[]):
  182. from api.db import LLMType
  183. from api.db.services.llm_service import LLMBundle
  184. if llm_id and llm_id2llm_type(llm_id) == "image2text":
  185. chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
  186. else:
  187. chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
  188. rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
  189. rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
  190. ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
  191. ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
  192. if ans.find("**ERROR**") >= 0:
  193. return query
  194. return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Ask *chat_mdl* to tag *content* with up to *topn* tags drawn from *all_tags*.

    *examples* is a list of few-shot example dicts; each gains a "tags_json"
    field rendered from its TAG_FLD entry. NOTE(review): the caller's dicts
    are mutated in place — confirm callers do not reuse them.

    Returns a ``{tag: weight}`` dict of positive integer weights.
    Raises ``Exception`` on model error, or re-raises on unparseable output.
    """
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
    for ex in examples:
        ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)
    rendered_prompt = template.render(
        topn=topn,
        all_tags=all_tags,
        examples=examples,
        content=content,
    )
    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    # Drop any chain-of-thought section emitted before </think>.
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)
    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Fallback: strip the echoed prompt and role labels, then pull out the
        # first {...} span and try repairing that.
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e
    # Keep only tags with positive integer weights; silently skip malformed values.
    res = {}
    for k, v in obj.items():
        try:
            if int(v) > 0:
                res[str(k)] = int(v)
        except Exception:
            pass
    return res
  231. def vision_llm_describe_prompt(page=None) -> str:
  232. template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
  233. return template.render(page=page)
  234. def vision_llm_figure_describe_prompt() -> str:
  235. template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
  236. return template.render()
  237. if __name__ == "__main__":
  238. print(CITATION_PROMPT_TEMPLATE)
  239. print(CONTENT_TAGGING_PROMPT_TEMPLATE)
  240. print(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE)
  241. print(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE)
  242. print(FULL_QUESTION_PROMPT_TEMPLATE)
  243. print(KEYWORD_PROMPT_TEMPLATE)
  244. print(QUESTION_PROMPT_TEMPLATE)
  245. print(VISION_LLM_DESCRIBE_PROMPT)
  246. print(VISION_LLM_FIGURE_DESCRIBE_PROMPT)