| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import datetime
- import json
- import logging
- import re
- from collections import defaultdict
-
- import jinja2
- import json_repair
-
- from api import settings
- from rag.prompt_template import load_prompt
- from rag.settings import TAG_FLD
- from rag.utils import encoder, num_tokens_from_string
-
-
def chunks_format(reference):
    """Normalize retrieval chunks into the public API field names.

    Each chunk may use either internal field names (e.g. ``chunk_id``,
    ``docnm_kwd``) or already-public ones (e.g. ``id``, ``document_name``);
    the internal name wins when both are present. Missing fields become None.
    """

    def _pick(src, primary, fallback):
        # Prefer the internal key; fall back to the public alias.
        return src.get(primary, src.get(fallback))

    formatted = []
    for chunk in reference.get("chunks", []):
        formatted.append(
            {
                "id": _pick(chunk, "chunk_id", "id"),
                "content": _pick(chunk, "content", "content_with_weight"),
                "document_id": _pick(chunk, "doc_id", "document_id"),
                "document_name": _pick(chunk, "docnm_kwd", "document_name"),
                "dataset_id": _pick(chunk, "kb_id", "dataset_id"),
                "image_id": _pick(chunk, "image_id", "img_id"),
                "positions": _pick(chunk, "positions", "position_int"),
                "url": chunk.get("url"),
                "similarity": chunk.get("similarity"),
                "vector_similarity": chunk.get("vector_similarity"),
                "term_similarity": chunk.get("term_similarity"),
                "doc_type": chunk.get("doc_type_kwd"),
            }
        )
    return formatted
-
-
def llm_id2llm_type(llm_id):
    """Resolve the model type (e.g. "chat", "image2text") for an LLM id.

    The id is first stripped of its factory suffix, then looked up in the
    factory catalogue from settings. Returns None when the model is unknown.
    """
    from api.db.services.llm_service import TenantLLMService

    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)

    llm_factories = settings.FACTORY_LLM_INFOS
    for llm_factory in llm_factories:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # model_type may be a comma-separated list ("chat,image2text");
                # take the last entry. The previous `.strip(",")[-1]` returned a
                # single *character*, so comparisons like `== "image2text"`
                # could never match.
                return llm["model_type"].split(",")[-1].strip()
-
-
def message_fit_in(msg, max_length=4000):
    """Trim a chat history so its token count fits within max_length.

    Strategy: if the full history fits, return it unchanged; otherwise keep
    only the system message(s) plus the latest message; if that still does
    not fit, truncate whichever of the two dominates.

    Returns (token_count, msg); msg is modified in place when truncated.
    """

    def count():
        # Reads the current (possibly reduced) binding of msg.
        nonlocal msg
        return sum(num_tokens_from_string(m["content"]) for m in msg)

    c = count()
    if c < max_length:
        return c, msg

    # Keep system messages plus the most recent message only.
    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    if not msg_:
        # Nothing left to truncate (no system message and a single-entry history).
        return c, msg

    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + ll2) > 0.8:
        # System prompt dominates: truncate it, leaving room for the last message.
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[: max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    # Otherwise truncate the last message to the budget remaining after the
    # system prompt. (Was `max_length - ll2`, i.e. a slice at the message's
    # own size when ll2 <= budget-overrun — which could leave the total over
    # max_length.)
    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[: max_length - ll])
    msg[-1]["content"] = m
    return max_length, msg
-
-
def kb_prompt(kbinfos, max_tokens):
    """Format retrieved chunks into per-document knowledge sections.

    Takes as many chunks as fit within ~97% of max_tokens, groups them by
    source document, prefixes each document's metadata fields, and returns
    one formatted text section per document.
    """
    from api.db.services.document_service import DocumentService

    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    kwlg_len = len(knowledges)
    used_token_count = 0
    chunks_num = 0
    for c in knowledges:
        used_token_count += num_tokens_from_string(c)
        if max_tokens * 0.97 < used_token_count:
            # Budget blown: drop this chunk and everything after it.
            # (Previously chunks_num was incremented before this check, so the
            # over-budget chunk still leaked into the sections built below.)
            logging.warning(f"Not all the retrieval into prompt: {chunks_num}/{kwlg_len}")
            break
        chunks_num += 1

    # Metadata for the retained chunks' source documents, keyed by doc id.
    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
        # Strip inline styles and html/body/head/title scaffolding from chunk text.
        cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL | re.IGNORECASE)
        doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
        doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        txt = f"\nDocument: {nm} \n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
        txt += "Relevant fragments as following:\n"
        for chunk in cks_meta["chunks"]:
            txt += f"{chunk}\n"
        knowledges.append(txt)
    return knowledges
-
-
# Prompt templates, loaded once at import time via load_prompt() and rendered
# with the shared Jinja2 environment below.
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")

# Shared Jinja2 environment: autoescape off (output is plain prompt text, not
# HTML); trim_blocks/lstrip_blocks keep template control tags from injecting
# stray whitespace into the rendered prompts.
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
-
-
def citation_prompt() -> str:
    """Render the citation instruction prompt (takes no template variables)."""
    return PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE).render()
-
-
def keyword_extraction(chat_mdl, content, topn=3):
    """Ask the chat model to extract the top-n keywords from content.

    Returns the model's raw answer, or "" when it reports an error.
    """
    rendered_prompt = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE).render(content=content, topn=topn)

    msg = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    answer = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(answer, tuple):
        answer = answer[0]
    # Drop any chain-of-thought text emitted before the closing </think> tag.
    answer = re.sub(r"^.*</think>", "", answer, flags=re.DOTALL)
    return "" if "**ERROR**" in answer else answer
-
-
def question_proposal(chat_mdl, content, topn=3):
    """Ask the chat model to propose top-n questions answerable from content.

    Returns the model's raw answer, or "" when it reports an error.
    """
    rendered_prompt = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE).render(content=content, topn=topn)

    msg = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    answer = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(answer, tuple):
        answer = answer[0]
    # Drop any chain-of-thought text emitted before the closing </think> tag.
    answer = re.sub(r"^.*</think>", "", answer, flags=re.DOTALL)
    return "" if "**ERROR**" in answer else answer
-
-
def full_question(tenant_id, llm_id, messages, language=None):
    """Rewrite the latest user question into a standalone, self-contained one.

    Feeds the user/assistant conversation plus today's date context to the
    model. Falls back to the raw last message when the model errors.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle

    # Guard against an empty llm_id before probing its type — consistent with
    # cross_languages() below.
    if llm_id and llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conversation = "\n".join(conv)

    # Compute "today" once so the three dates cannot straddle midnight.
    today = datetime.date.today()

    template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(
        today=today.isoformat(),
        yesterday=(today - datetime.timedelta(days=1)).isoformat(),
        tomorrow=(today + datetime.timedelta(days=1)).isoformat(),
        conversation=conversation,
        language=language,
    )

    ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # Drop any chain-of-thought text emitted before the closing </think> tag.
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
-
-
def cross_languages(tenant_id, llm_id, query, languages=None):
    """Translate query into each of languages via the tenant's LLM.

    The model is expected to emit one translation per "===" separator.
    Returns the original query unchanged when the model errors.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle

    # Avoid the mutable-default-argument pitfall (was `languages=[]`).
    if languages is None:
        languages = []

    if llm_id and llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)

    ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
    # Drop any chain-of-thought text emitted before the closing </think> tag.
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
    # Strip the "Output:" echo and blank lines, then split on "===" separators.
    return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
-
-
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Tag content against all_tags using the chat model.

    examples: few-shot example dicts, each carrying tags under TAG_FLD.
    Returns {tag: relevance} keeping only positive integer relevances.
    Raises when the model reports an error or the reply cannot be parsed.
    """
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)

    # Copy the examples before attaching the serialized tags: the original
    # wrote "tags_json" straight into the caller's dicts (a hidden side effect).
    examples = [
        {**ex, "tags_json": json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)}
        for ex in examples
    ]

    rendered_prompt = template.render(
        topn=topn,
        all_tags=all_tags,
        examples=examples,
        content=content,
    )

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    # Drop any chain-of-thought text emitted before the closing </think> tag.
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Fallback: strip prompt echo / role markers and grab the first {...}
        # span. Pre-bind result so the log line below can never hit an
        # unbound local if the cleanup itself fails.
        result = kwd
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            # Bare raise preserves the original traceback (was `raise e`).
            raise

    # Keep only tags whose relevance parses as a positive integer.
    res = {}
    for k, v in obj.items():
        try:
            if int(v) > 0:
                res[str(k)] = int(v)
        except Exception:
            pass
    return res
-
-
def vision_llm_describe_prompt(page=None) -> str:
    """Render the page-description prompt for a vision LLM.

    page: optional page identifier interpolated into the template.
    """
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT).render(page=page)
-
-
def vision_llm_figure_describe_prompt() -> str:
    """Render the figure-description prompt for a vision LLM (no variables)."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT).render()
-
-
if __name__ == "__main__":
    # Dump every loaded template, in declaration order, for manual inspection.
    for _template in (
        CITATION_PROMPT_TEMPLATE,
        CONTENT_TAGGING_PROMPT_TEMPLATE,
        CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE,
        CROSS_LANGUAGES_USER_PROMPT_TEMPLATE,
        FULL_QUESTION_PROMPT_TEMPLATE,
        KEYWORD_PROMPT_TEMPLATE,
        QUESTION_PROMPT_TEMPLATE,
        VISION_LLM_DESCRIBE_PROMPT,
        VISION_LLM_FIGURE_DESCRIBE_PROMPT,
    ):
        print(_template)
|