浏览代码

Refa: more fallbacks for bad citation format (#7710)

### What problem does this PR solve?

More fallbacks for bad citation format

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
tags/v0.19.0
Yongteng Lei 5 个月前
父节点
当前提交
e8e2a95165
没有帐户链接到提交者的电子邮件
共有 1 个文件被更改,包括 39 次插入14 次删除
  1. 39
    14
      api/db/services/dialog_service.py

+ 39
- 14
api/db/services/dialog_service.py 查看文件

# limitations under the License. # limitations under the License.
# #
import binascii import binascii
from datetime import datetime
import logging import logging
import re import re
import time import time
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from functools import partial from functools import partial
from timeit import default_timer as timer from timeit import default_timer as timer


from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
cross_languages
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily


if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)


def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    """Rewrite malformed citation markers in ``answer`` to the ``##N$$`` form.

    Each known bad format is searched for in turn. A marker is rewritten only
    when its number is a valid position in ``kbinfos["chunks"]``; in that case
    the number is also added to ``idx``. Markers with an out-of-range number,
    and citations that are already well-formed (``##N$$``), are left untouched.

    Args:
        answer: Text that may contain malformed citation markers.
        kbinfos: Mapping with a ``"chunks"`` sequence; only its length is used,
            as the exclusive upper bound for citation numbers.
        idx: Set of cited chunk positions. It is updated in place.

    Returns:
        A ``(answer, idx)`` tuple: the repaired text and the same ``idx`` set.
    """
    max_index = len(kbinfos["chunks"])

    # Tried in order. Lookarounds keep a pattern from firing inside a
    # well-formed ``##N$$`` citation (or one produced by an earlier pattern),
    # and from matching "ID"/"ref" as the tail of a longer Latin word.
    patterns = (
        (r"\(\s*ID:\s*(\d+)\s*\)", 0),  # (ID: 12)
        (r"(?<![A-Za-z])ID[: ]+(\d+)", 0),  # ID: 12, ID 12
        (r"\$\$(\d+)\${2,}", 0),  # $$12$$, $$12$$$$
        (r"\$\[(\d+)\]\$", 0),  # $[12]$
        (r"\$(\d+)\$", 0),  # $12$
        (r"(?<!#)#(\d+)\$\$", 0),  # #12$$
        (r"##(\d+)\$(?!\$)", 0),  # ##12$
        (r"##(\d+)#{2,}", 0),  # ##12###
        (r"【(\d+)】", 0),  # 【12】
        (r"(?<![A-Za-z])ref\s*(\d+)", re.IGNORECASE),  # ref12, ref 12, REF 12
    )

    def to_citation(match):
        """Return the canonical marker for ``match``, or its text unchanged."""
        original = match.group(0)
        try:
            i = int(match.group(1))
        except ValueError:
            # e.g. a digit run too long for int(); leave the text alone.
            return original
        if 0 <= i < max_index:
            idx.add(i)
            return f"##{i}$$"
        return original

    # re.sub replaces exactly the matched span, so a short marker such as
    # "ID: 1" can never rewrite the prefix of a longer one such as "ID: 12".
    for pattern, flags in patterns:
        answer = re.sub(pattern, to_citation, answer, flags=flags)

    return answer, idx

def decorate_answer(answer): def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer


if i < len(kbinfos["chunks"]): if i < len(kbinfos["chunks"]):
idx.add(i) idx.add(i)


# handle (ID: 1), ID: 2 etc.
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
full_match = match.group(0)
id = match.group(1) or match.group(2)
if id:
i = int(id)
if i < len(kbinfos["chunks"]):
idx.add(i)
answer = answer.replace(full_match, f"##{i}$$")
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)


idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]


# compose Markdown table # compose Markdown table
columns = ( columns = (
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
) )


line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans answer = ans
yield {"answer": answer, "reference": {}} yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
yield decorate_answer(answer)


正在加载...
取消
保存