You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import datetime
  17. import json
  18. import logging
  19. import os
  20. import re
  21. from collections import defaultdict
  22. import json_repair
  23. from api.db import LLMType
  24. from api.db.services.document_service import DocumentService
  25. from api.db.services.llm_service import TenantLLMService, LLMBundle
  26. from api.utils.file_utils import get_project_base_directory
  27. from rag.settings import TAG_FLD
  28. from rag.utils import num_tokens_from_string, encoder
  29. def chunks_format(reference):
  30. def get_value(d, k1, k2):
  31. return d.get(k1, d.get(k2))
  32. return [{
  33. "id": get_value(chunk, "chunk_id", "id"),
  34. "content": get_value(chunk, "content", "content_with_weight"),
  35. "document_id": get_value(chunk, "doc_id", "document_id"),
  36. "document_name": get_value(chunk, "docnm_kwd", "document_name"),
  37. "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
  38. "image_id": get_value(chunk, "image_id", "img_id"),
  39. "positions": get_value(chunk, "positions", "position_int"),
  40. "url": chunk.get("url")
  41. } for chunk in reference.get("chunks", [])]
  42. def llm_id2llm_type(llm_id):
  43. llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
  44. fnm = os.path.join(get_project_base_directory(), "conf")
  45. llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
  46. for llm_factory in llm_factories["factory_llm_infos"]:
  47. for llm in llm_factory["llm"]:
  48. if llm_id == llm["llm_name"]:
  49. return llm["model_type"].strip(",")[-1]
  50. def message_fit_in(msg, max_length=4000):
  51. def count():
  52. nonlocal msg
  53. tks_cnts = []
  54. for m in msg:
  55. tks_cnts.append(
  56. {"role": m["role"], "count": num_tokens_from_string(m["content"])})
  57. total = 0
  58. for m in tks_cnts:
  59. total += m["count"]
  60. return total
  61. c = count()
  62. if c < max_length:
  63. return c, msg
  64. msg_ = [m for m in msg[:-1] if m["role"] == "system"]
  65. if len(msg) > 1:
  66. msg_.append(msg[-1])
  67. msg = msg_
  68. c = count()
  69. if c < max_length:
  70. return c, msg
  71. ll = num_tokens_from_string(msg_[0]["content"])
  72. ll2 = num_tokens_from_string(msg_[-1]["content"])
  73. if ll / (ll + ll2) > 0.8:
  74. m = msg_[0]["content"]
  75. m = encoder.decode(encoder.encode(m)[:max_length - ll2])
  76. msg[0]["content"] = m
  77. return max_length, msg
  78. m = msg_[1]["content"]
  79. m = encoder.decode(encoder.encode(m)[:max_length - ll2])
  80. msg[1]["content"] = m
  81. return max_length, msg
  82. def kb_prompt(kbinfos, max_tokens):
  83. knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
  84. used_token_count = 0
  85. chunks_num = 0
  86. for i, c in enumerate(knowledges):
  87. used_token_count += num_tokens_from_string(c)
  88. chunks_num += 1
  89. if max_tokens * 0.97 < used_token_count:
  90. knowledges = knowledges[:i]
  91. logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
  92. break
  93. docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
  94. docs = {d.id: d.meta_fields for d in docs}
  95. doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
  96. for ck in kbinfos["chunks"][:chunks_num]:
  97. doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
  98. doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
  99. knowledges = []
  100. for nm, cks_meta in doc2chunks.items():
  101. txt = f"Document: {nm} \n"
  102. for k, v in cks_meta["meta"].items():
  103. txt += f"{k}: {v}\n"
  104. txt += "Relevant fragments as following:\n"
  105. for i, chunk in enumerate(cks_meta["chunks"], 1):
  106. txt += f"{i}. {chunk}\n"
  107. knowledges.append(txt)
  108. return knowledges
  109. def keyword_extraction(chat_mdl, content, topn=3):
  110. prompt = f"""
  111. Role: You're a text analyzer.
  112. Task: extract the most important keywords/phrases of a given piece of text content.
  113. Requirements:
  114. - Summarize the text content, and give top {topn} important keywords/phrases.
  115. - The keywords MUST be in language of the given piece of text content.
  116. - The keywords are delimited by ENGLISH COMMA.
  117. - Keywords ONLY in output.
  118. ### Text Content
  119. {content}
  120. """
  121. msg = [
  122. {"role": "system", "content": prompt},
  123. {"role": "user", "content": "Output: "}
  124. ]
  125. _, msg = message_fit_in(msg, chat_mdl.max_length)
  126. kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
  127. if isinstance(kwd, tuple):
  128. kwd = kwd[0]
  129. kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
  130. if kwd.find("**ERROR**") >= 0:
  131. return ""
  132. return kwd
  133. def question_proposal(chat_mdl, content, topn=3):
  134. prompt = f"""
  135. Role: You're a text analyzer.
  136. Task: propose {topn} questions about a given piece of text content.
  137. Requirements:
  138. - Understand and summarize the text content, and propose top {topn} important questions.
  139. - The questions SHOULD NOT have overlapping meanings.
  140. - The questions SHOULD cover the main content of the text as much as possible.
  141. - The questions MUST be in language of the given piece of text content.
  142. - One question per line.
  143. - Question ONLY in output.
  144. ### Text Content
  145. {content}
  146. """
  147. msg = [
  148. {"role": "system", "content": prompt},
  149. {"role": "user", "content": "Output: "}
  150. ]
  151. _, msg = message_fit_in(msg, chat_mdl.max_length)
  152. kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
  153. if isinstance(kwd, tuple):
  154. kwd = kwd[0]
  155. kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
  156. if kwd.find("**ERROR**") >= 0:
  157. return ""
  158. return kwd
def full_question(tenant_id, llm_id, messages):
    """Rewrite the latest user turn into a standalone, fully-specified question.

    Feeds the user/assistant conversation history plus today's date to the
    model so references ("his mother", "tomorrow") are resolved into an
    absolute question. Falls back to the last message's content verbatim
    when the model returns an error.
    """
    # Pick the LLM bundle matching the model's declared type.
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        # Only user/assistant turns belong in the rendered conversation.
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conv = "\n".join(conv)
    # Concrete dates let the model translate relative expressions.
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant
Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.
Requirements & Restrictions:
- Text generated MUST be in the same language of the original user's question.
- If the user's latest question is completely, don't do anything, just return the original question.
- DON'T generate anything except a refined question.
######################
-Examples-
######################
# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT:  Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?
------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT:  Fred Trump.
USER: And his mother?
ASSISTANT:  Mary Trump.
User: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?
------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT:  Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?
######################
# Real Data
## Conversation
{conv}
###############
"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # Strip any chain-of-thought block emitted by reasoning models.
    ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
    # On model error, fall back to the user's last message unchanged.
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Tag *content* with the top-n most relevant tags from *all_tags*.

    Builds a few-shot prompt from *examples* (each carrying "content" and a
    TAG_FLD mapping of tag -> relevance score), asks the chat model, and
    parses its JSON answer.

    Returns:
        dict: tag -> relevance score (1-10), as parsed from the model output.

    Raises:
        Exception: when the model reply contains "**ERROR**", or the reply
            cannot be parsed as JSON even after the cleanup fallback.
    """
    prompt = f"""
Role: You're a text analyzer.
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must be range from 1 to 10.
- Keywords ONLY in output.
# TAG SET
{", ".join(all_tags)}
"""
    # Few-shot examples: each contributes its text plus the expected JSON
    # tag mapping. NOTE(review): the example counter starts at 0 here.
    for i, ex in enumerate(examples):
        prompt += """
# Examples {}
### Text Content
{}
Output:
{}
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
    prompt += f"""
# Real Data
### Text Content
{content}
"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    # Trim to the model's context budget before calling.
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    # Strip any chain-of-thought block emitted by reasoning models.
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)
    try:
        return json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Fallback: some models echo the prompt or wrap the JSON in extra
        # text — drop the echoed prompt, then keep only the first {...} span.
        try:
            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
            result = '{' + result.split('{')[1].split('}')[0] + '}'
            return json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e