@@ -18,7 +18,7 @@ import os
import json
import re
from copy import deepcopy
from timeit import default_timer as timer
from api.db import LLMType, ParserType
from api.db.db_models import Dialog, Conversation
from api.db.services.common_service import CommonService
@@ -88,6 +88,7 @@ def llm_id2llm_type(llm_id):
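# chat() is a generator: it streams partial answers as they arrive and finally
# yields the decorated answer together with references and timing information.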
def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    st = timer()
    llm = LLMService.query(llm_name=dialog.llm_id)
    if not llm:
        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
@@ -158,25 +159,16 @@ def chat(dialog, messages, stream=True, **kwargs):
                                    doc_ids=attachments,
                                    top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
        knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

        # Self-RAG: if the retrieved chunks are judged irrelevant to the latest
        # question, rewrite the question and run retrieval once more.
        if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
            questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
            kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                     dialog.similarity_threshold,
                                     dialog.vector_similarity_weight,
                                     doc_ids=attachments,
                                     top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
            knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

    chat_logger.info(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    retrieval_tm = timer()
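
    # Nothing retrieved: short-circuit with the dialog's configured canned
    # answer (plus synthesized audio when a TTS model is configured).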
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
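
    # "------" marks chunk boundaries so the LLM can tell the chunks apart.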
    kwargs["knowledge"] = "\n------\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
@@ -192,7 +184,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                  max_tokens - used_token_count)

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
        refs = []
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retr.insert_citations(answer,
@@ -216,7 +208,9 @@ def chat(dialog, messages, stream=True, **kwargs):

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        done_tm = timer()
        # LLM time is measured from the end of retrieval so the two figures don't overlap.
        prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % ((retrieval_tm - st) * 1000, (done_tm - retrieval_tm) * 1000)
        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}

    if stream:
        last_ans = ""
@@ -415,4 +409,75 @@ def tts(tts_mdl, text):
bin = b"" |
|
|
|
for chunk in tts_mdl.tts(text): |
|
|
|
bin += chunk |
|
|
|
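    # Hex-encode the raw audio bytes so they survive JSON transport.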
    return binascii.hexlify(bin).decode("utf-8")


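# ask(): stateless one-shot Q&A over the given knowledge bases; like chat(),
# it is a generator that streams the growing answer and ends with references.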
def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))

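    # Use the knowledge-graph retriever only when every KB was parsed as a KG.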
    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
    retr = retrievaler if not is_kg else kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length

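    # One-shot retrieval with fixed settings (reading the positional arguments
    # by the parameter order of the retrieval call in chat(): page 1, top 12
    # chunks, similarity threshold 0.1, vector similarity weight 0.3).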
    kbinfos = retr.retrieval(question, embd_mdl, tenant_id, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

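    # Greedily keep whole chunks until ~97% of the context window is used.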
    used_token_count = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            break

prompt = """ |
|
|
|
Role: You're a smart assistant. Your name is Miss R. |
|
|
|
Task: Summarize the information from knowledge bases and answer user's question. |
|
|
|
Requirements and restriction: |
|
|
|
- DO NOT make things up, especially for numbers. |
|
|
|
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided. |
|
|
|
- Answer with markdown format text. |
|
|
|
- Answer in language of user's question. |
|
|
|
- DO NOT make things up, especially for numbers. |
|
|
|
|
|
|
|
### Information from knowledge bases |
|
|
|
%s |
|
|
|
|
|
|
|
The above is information from knowledge bases. |
|
|
|
|
|
|
|
"""%"\n".join(knowledges) |
|
|
|
msg = [{"role": "user", "content": question}] |
|
|
|
|
|
|
|
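    # decorate_answer(): post-process the final answer with inline citations
    # (0.7 token weight, 0.3 vector weight) and build the reference payload.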
    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retr.insert_citations(answer,
                                            [ck["content_ltks"]
                                             for ck in kbinfos["chunks"]],
                                            [ck["vector"]
                                             for ck in kbinfos["chunks"]],
                                            embd_mdl,
                                            tkweight=0.7,
                                            vtweight=0.3)
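        # Keep only the documents that were actually cited; if nothing was
        # cited, fall back to every retrieved document.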
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

answer = "" |
|
|
|
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): |
|
|
|
answer = ans |
|
|
|
yield {"answer": answer, "reference": {}} |
|
|
|
yield decorate_answer(answer) |
|
|
|
|
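
# A minimal consumption sketch (hypothetical caller; "kb_ids" and "tenant_id"
# must come from your own deployment):
#
#   for delta in ask("What does Self-RAG change?", kb_ids, tenant_id):
#       print(delta["answer"])  # the answer so far, re-sent in full each time
#   # the final yielded item also carries the references in delta["reference"]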