#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import re
from copy import deepcopy
from typing import Tuple

import jinja2
import json_repair

from api.utils import hash_str2int
from rag.prompts.prompt_template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string

STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"


def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


def chunks_format(reference):
    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),
            "content": get_value(chunk, "content", "content_with_weight"),
            "document_id": get_value(chunk, "doc_id", "document_id"),
            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
            "url": chunk.get("url"),
            "similarity": chunk.get("similarity"),
            "vector_similarity": chunk.get("vector_similarity"),
            "term_similarity": chunk.get("term_similarity"),
            "doc_type": chunk.get("doc_type_kwd"),
        }
        for chunk in reference.get("chunks", [])
    ]


def message_fit_in(msg, max_length=4000):
    """Trim a message list so its total token count fits within max_length."""

    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    if c < max_length:
        return c, msg

    # Drop everything except the system prompt(s) and the last message.
    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + ll2) > 0.8:
        # The system prompt dominates: truncate it, keeping the last message whole.
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[: max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    # Otherwise truncate the last message so that, together with the
    # system prompt (ll tokens), it fits the budget.
    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[: max_length - ll])
    msg[-1]["content"] = m
    return max_length, msg


def kb_prompt(kbinfos, max_tokens, hash_id=False):
    """Format retrieved chunks as tree-style knowledge snippets, attaching
    document metadata and staying within the given token budget."""
    from api.db.services.document_service import DocumentService

    knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]]
    kwlg_len = len(knowledges)
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        if not c:
            continue
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            logging.warning(f"Not all retrieved chunks fit into the prompt: {len(knowledges)}/{kwlg_len}")
            break

    docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    def draw_node(k, line):
        if not line:
            return ""
        return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

    knowledges = []
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        cnt += draw_node("URL", ck["url"]) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"
        cnt += get_value(ck, "content", "content_with_weight")
        knowledges.append(cnt)
    return knowledges
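
# --- Illustrative usage (not part of the upstream API) ------------------------
# A minimal sketch of how message_fit_in() trims an over-long conversation: it
# keeps the system prompt and the last message, then token-truncates whichever
# side dominates the budget. The message contents below are hypothetical.
def _example_message_fit_in():
    msg = [
        {"role": "system", "content": "You are a helpful assistant. " * 400},
        {"role": "user", "content": "First question, long since superseded."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    used_tokens, fitted = message_fit_in(msg, max_length=512)
    # Only the (truncated) system prompt and the last user turn survive.
    assert used_tokens <= 512 and len(fitted) == 2
    return fitted
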
"docnm_kwd", "document_name")) cnt += draw_node("URL", ck['url']) if "url" in ck else "" for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items(): cnt += draw_node(k, v) cnt += "\n└── Content:\n" cnt += get_value(ck, "content", "content_with_weight") knowledges.append(cnt) return knowledges CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt") CITATION_PLUS_TEMPLATE = load_prompt("citation_plus") CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt") CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt") CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt") FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt") KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt") QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt") VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt") VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt") ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system") ANALYZE_TASK_USER = load_prompt("analyze_task_user") NEXT_STEP = load_prompt("next_step") REFLECT = load_prompt("reflect") SUMMARY4MEMORY = load_prompt("summary4memory") RANK_MEMORY = load_prompt("rank_memory") PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True) def citation_prompt() -> str: template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE) return template.render() def citation_plus(sources: str) -> str: template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE) return template.render(example=citation_prompt(), sources=sources) def keyword_extraction(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) if kwd.find("**ERROR**") >= 0: return "" return kwd def question_proposal(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) if kwd.find("**ERROR**") >= 0: return "" return kwd def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): from api.db import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import TenantLLMService if not chat_mdl: if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) conv = [] for m in messages: if m["role"] not in ["user", "assistant"]: continue conv.append("{}: {}".format(m["role"].upper(), m["content"])) conversation = "\n".join(conv) today = datetime.date.today().isoformat() yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat() tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat() template = 
def cross_languages(tenant_id, llm_id, query, languages=[]):
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api.db.services.llm_service import TenantLLMService

    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)

    ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
    return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])


def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
    for ex in examples:
        ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)

    rendered_prompt = template.render(
        topn=topn,
        all_tags=all_tags,
        examples=examples,
        content=content,
    )

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        try:
            # Fallback: cut the first {...} object out of the raw reply.
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e
    res = {}
    for k, v in obj.items():
        try:
            if int(v) > 0:
                res[str(k)] = int(v)
        except Exception:
            pass
    return res


def vision_llm_describe_prompt(page=None) -> str:
    template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
    return template.render(page=page)


def vision_llm_figure_describe_prompt() -> str:
    template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
    return template.render()


def tool_schema(tools_description: list[dict], complete_task=False):
    """Render tool descriptions as a numbered markdown catalog for the prompt."""
    if not tools_description:
        return ""
    desc = {}
    if complete_task:
        desc[COMPLETE_TASK] = {
            "type": "function",
            "function": {
                "name": COMPLETE_TASK,
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"],
                },
            },
        }
    for tool in tools_description:
        desc[tool["function"]["name"]] = tool
    return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
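
# --- Illustrative usage (not part of the upstream API) ------------------------
# A sketch of tool_schema() with one hypothetical OpenAI-style function
# description. With complete_task=True, the built-in `complete_task` entry is
# listed first, so the agent always has an explicit way to finish.
def _example_tool_schema():
    web_search = {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web and return the top results.",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "Search query"}},
                "required": ["query"],
            },
        },
    }
    # Returns a markdown catalog: "## 1. complete_task ..." then "## 2. web_search ...".
    return tool_schema([web_search], complete_task=True)
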
def form_history(history, limit=-6):
    context = ""
    for h in history[limit:]:
        if h["role"] == "system":
            continue
        role = "USER"
        if h["role"].upper() != role:
            role = "AGENT"
        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context


def analyze_task(chat_mdl, task_name, tools_description: list[dict]):
    tools_desc = tool_schema(tools_description)
    context = ""
    template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
    kwd = chat_mdl.chat(
        ANALYZE_TASK_SYSTEM,
        [{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}],
        {},
    )
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def next_step(chat_mdl, history: list, tools_description: list[dict], task_desc):
    if not tools_description:
        return ""
    desc = tool_schema(tools_description)
    template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
    user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    # NB: the "task_analisys" spelling is kept; it must match the placeholder in
    # the NEXT_STEP template.
    json_str = chat_mdl.chat(
        template.render(task_analisys=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
        hist[1:],
        stop=["<|stop|>"],
    )
    tk_cnt = num_tokens_from_string(json_str)
    json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
    return json_str, tk_cnt


def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
    tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
    # history[0] is the system prompt; history[1] holds the user's goal.
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(REFLECT)
    user_prompt = template.render(goal=goal, tool_calls=tool_calls)
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    _, msg = message_fit_in(hist, chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return """
**Observation**
{}

**Reflection**
{}
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)


def form_message(system_prompt, user_prompt):
    return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]


def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
    template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
    system_prompt = template.render(name=name, params=json.dumps(params, ensure_ascii=False, indent=2), result=result)
    user_prompt = "→ Summary: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


def rank_memories(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str]):
    template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
    system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
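
# --- Illustrative usage (not part of the upstream API) ------------------------
# A sketch of one reflection step in the agent loop, using a hypothetical stub
# model. reflect() expects history[0] to be the system prompt and history[1]
# the user's goal; tool_call_res pairs each tool name with its raw result.
def _example_reflect():
    class _StubChatModel:
        max_length = 8192

        def chat(self, system_prompt, history, gen_conf=None):
            return "The search result answers the goal; no further calls needed."

    history = [
        {"role": "system", "content": "You are an agent."},
        {"role": "user", "content": "Find the capital of France."},
    ]
    tool_call_res = [("web_search", "Paris is the capital of France.")]
    # Returns a markdown block: **Observation** (the tool calls as JSON)
    # followed by **Reflection** (the stub's assessment).
    return reflect(_StubChatModel(), history, tool_call_res)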