@@ -40,6 +40,7 @@ from rag.nlp.search import index_name
 from rag.settings import TAG_FLD
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
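+# Tavily is a web-search API; its connector feeds optional web results into chat() and reasoning() below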
+from rag.utils.tavily_conn import Tavily
 
 
 class DialogService(CommonService):
@@ -125,6 +126,7 @@ def kb_prompt(kbinfos, max_tokens):
+        chunks_num += 1
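+        # count the chunks that fit the token budget; only these are kept for the prompt and citations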
         if max_tokens * 0.97 < used_token_count:
             knowledges = knowledges[:i]
             logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
             break
 
     docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
@@ -132,7 +134,7 @@ def kb_prompt(kbinfos, max_tokens):
 
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for ck in kbinfos["chunks"][:chunks_num]:
-        doc2chunks[ck["docnm_kwd"]]["chunks"].append(ck["content_with_weight"])
+        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
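+        # web-search chunks carry a "url" field; prepending it lets the LLM cite the source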
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
 
     knowledges = []
@@ -295,7 +297,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         knowledges = []
         if prompt_config.get("reasoning", False):
-            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
+            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, prompt_config, MAX_SEARCH_LIMIT=3):
                 if isinstance(think, str):
                     thought = think
                     knowledges = [t for t in think.split("\n") if t]
@@ -309,6 +311,11 @@ def chat(dialog, messages, stream=True, **kwargs):
                                           top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                                           rank_feature=label_question(" ".join(questions), kbs)
                                           )
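+            # optional web-search augmentation: merge Tavily chunks and doc aggregates into the KB results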
+            if prompt_config.get("tavily_api_key"):
+                tav = Tavily(prompt_config["tavily_api_key"])
+                tav_res = tav.retrieve_chunks(" ".join(questions))
+                kbinfos["chunks"].extend(tav_res["chunks"])
+                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
             if prompt_config.get("use_kg"):
                 ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                        tenant_ids,
@@ -852,7 +859,7 @@ Output:
 
 
 def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
-              tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
+              tenant_ids: list[str], kb_ids: list[str], prompt_config: dict, MAX_SEARCH_LIMIT: int = 3,
               top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
     BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
     END_SEARCH_QUERY = "<|end_search_query|>"
@@ -1023,10 +1030,28 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
                         truncated_prev_reasoning += '...\n\n'
             truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
 
+            # Retrieval procedure:
+            # 1. KB search
+            # 2. Web search (optional)
+            # 3. KG search (optional)
             kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
                                                      similarity_threshold,
                                                      vector_similarity_weight
                                                      )
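+            # 2) optional Tavily web search, gated on a user-supplied API key in prompt_config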
+            if prompt_config.get("tavily_api_key"):
+                tav = Tavily(prompt_config["tavily_api_key"])
+                tav_res = tav.retrieve_chunks(search_query)
+                kbinfos["chunks"].extend(tav_res["chunks"])
+                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
+            if prompt_config.get("use_kg"):
+                ck = settings.kg_retrievaler.retrieval(search_query,
+                                                       tenant_ids,
+                                                       kb_ids,
+                                                       embd_mdl,
+                                                       chat_mdl)
+                if ck["content_with_weight"]:
+                    kbinfos["chunks"].insert(0, ck)
 
+            # Merge chunk info for citations
             if not chunk_info["chunks"]:
                 for k in chunk_info.keys():
@@ -1048,7 +1073,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
             relevant_extraction_prompt.format(
                 prev_reasoning=truncated_prev_reasoning,
                 search_query=search_query,
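+                # larger 4096-token budget so merged web/KG chunks are not truncated before extraction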
-                document="\n".join(kb_prompt(kbinfos, 512))
+                document="\n".join(kb_prompt(kbinfos, 4096))
             ),
             [{"role": "user",
               "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],