|
|
|
@@ -14,11 +14,11 @@ |
|
|
|
# limitations under the License. |
|
|
|
# |
|
|
|
import binascii |
|
|
|
from datetime import datetime |
|
|
|
import logging |
|
|
|
import re |
|
|
|
import time |
|
|
|
from copy import deepcopy |
|
|
|
from datetime import datetime |
|
|
|
from functools import partial |
|
|
|
from timeit import default_timer as timer |
|
|
|
|
|
|
|
@@ -36,8 +36,7 @@ from api.utils import current_timestamp, datetime_format |
|
|
|
from rag.app.resume import forbidden_select_fields4resume |
|
|
|
from rag.app.tag import label_question |
|
|
|
from rag.nlp.search import index_name |
|
|
|
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \ |
|
|
|
cross_languages |
|
|
|
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in |
|
|
|
from rag.utils import num_tokens_from_string, rmSpace |
|
|
|
from rag.utils.tavily_conn import Tavily |
|
|
|
|
|
|
|
@@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) |
|
|
|
|
|
|
|
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: dict): |
|
|
|
max_index = len(kbinfos["chunks"]) |
|
|
|
|
|
|
|
def safe_add(i): |
|
|
|
if 0 <= i < max_index: |
|
|
|
idx.add(i) |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def find_and_replace(pattern, group_index=1, repl=lambda i: f"##{i}$$", flags=0): |
|
|
|
nonlocal answer |
|
|
|
for match in re.finditer(pattern, answer, flags=flags): |
|
|
|
try: |
|
|
|
i = int(match.group(group_index)) |
|
|
|
if safe_add(i): |
|
|
|
answer = answer.replace(match.group(0), repl(i)) |
|
|
|
except Exception: |
|
|
|
continue |
|
|
|
|
|
|
|
find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)") # (ID: 12) |
|
|
|
find_and_replace(r"ID[: ]+(\d+)") # ID: 12, ID 12 |
|
|
|
find_and_replace(r"\$\$(\d+)\$\$") # $$12$$ |
|
|
|
find_and_replace(r"\$\[(\d+)\]\$") # $[12]$ |
|
|
|
find_and_replace(r"\$\$(\d+)\${2,}") # $$12$$$$ |
|
|
|
find_and_replace(r"\$(\d+)\$") # $12$ |
|
|
|
find_and_replace(r"#(\d+)\$\$") # #12$$ |
|
|
|
find_and_replace(r"##(\d+)\$") # ##12$ |
|
|
|
find_and_replace(r"##(\d+)#{2,}") # ##12### |
|
|
|
find_and_replace(r"【(\d+)】") # 【12】 |
|
|
|
find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE) # ref12, ref 12, REF 12 |
|
|
|
|
|
|
|
return answer, idx |
|
|
|
|
|
|
|
def decorate_answer(answer): |
|
|
|
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer |
|
|
|
|
|
|
|
@@ -331,15 +363,7 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
if i < len(kbinfos["chunks"]): |
|
|
|
idx.add(i) |
|
|
|
|
|
|
|
# handle (ID: 1), ID: 2 etc. |
|
|
|
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer): |
|
|
|
full_match = match.group(0) |
|
|
|
id = match.group(1) or match.group(2) |
|
|
|
if id: |
|
|
|
i = int(id) |
|
|
|
if i < len(kbinfos["chunks"]): |
|
|
|
idx.add(i) |
|
|
|
answer = answer.replace(full_match, f"##{i}$$") |
|
|
|
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx) |
|
|
|
|
|
|
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) |
|
|
|
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] |
|
|
|
@@ -502,7 +526,7 @@ Please write the SQL, only SQL, without any other explanations or text. |
|
|
|
|
|
|
|
# compose Markdown table |
|
|
|
columns = ( |
|
|
|
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") |
|
|
|
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") |
|
|
|
) |
|
|
|
|
|
|
|
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") |
|
|
|
@@ -598,4 +622,5 @@ def ask(question, kb_ids, tenant_id): |
|
|
|
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): |
|
|
|
answer = ans |
|
|
|
yield {"answer": answer, "reference": {}} |
|
|
|
yield decorate_answer(answer) |
|
|
|
yield decorate_answer(answer) |
|
|
|
|