|
|
|
@@ -67,6 +67,34 @@ class Generate(ComponentBase): |
|
|
|
cpnts = [para["component_id"] for para in self._param.parameters] |
|
|
|
return cpnts |
|
|
|
|
|
|
|
def set_cite(self, retrieval_res, answer): |
|
|
|
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], |
|
|
|
[ck["vector"] for _, ck in retrieval_res.iterrows()], |
|
|
|
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, |
|
|
|
self._canvas.get_embedding_model()), tkweight=0.7, |
|
|
|
vtweight=0.3) |
|
|
|
doc_ids = set([]) |
|
|
|
recall_docs = [] |
|
|
|
for i in idx: |
|
|
|
did = retrieval_res.loc[int(i), "doc_id"] |
|
|
|
if did in doc_ids: continue |
|
|
|
doc_ids.add(did) |
|
|
|
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) |
|
|
|
|
|
|
|
del retrieval_res["vector"] |
|
|
|
del retrieval_res["content_ltks"] |
|
|
|
|
|
|
|
reference = { |
|
|
|
"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], |
|
|
|
"doc_aggs": recall_docs |
|
|
|
} |
|
|
|
|
|
|
|
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: |
|
|
|
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" |
|
|
|
res = {"content": answer, "reference": reference} |
|
|
|
|
|
|
|
return res |
|
|
|
|
|
|
|
def _run(self, history, **kwargs): |
|
|
|
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) |
|
|
|
prompt = self._param.prompt |
|
|
|
@@ -87,9 +115,8 @@ class Generate(ComponentBase): |
|
|
|
prompt = re.sub(r"\{%s\}" % n, str(v), prompt) |
|
|
|
|
|
|
|
downstreams = self._canvas.get_component(self._id)["downstream"] |
|
|
|
if kwargs.get("stream") \ |
|
|
|
and len(downstreams) == 1 \ |
|
|
|
and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer": |
|
|
|
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[ |
|
|
|
"obj"].component_name.lower() == "answer": |
|
|
|
return partial(self.stream_output, chat_mdl, prompt, retrieval_res) |
|
|
|
|
|
|
|
if "empty_response" in retrieval_res.columns: |
|
|
|
@@ -97,27 +124,8 @@ class Generate(ComponentBase): |
|
|
|
|
|
|
|
ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), |
|
|
|
self._param.gen_conf()) |
|
|
|
|
|
|
|
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: |
|
|
|
ans, idx = retrievaler.insert_citations(ans, |
|
|
|
[ck["content_ltks"] |
|
|
|
for _, ck in retrieval_res.iterrows()], |
|
|
|
[ck["vector"] |
|
|
|
for _, ck in retrieval_res.iterrows()], |
|
|
|
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, |
|
|
|
self._canvas.get_embedding_model()), |
|
|
|
tkweight=0.7, |
|
|
|
vtweight=0.3) |
|
|
|
del retrieval_res["vector"] |
|
|
|
retrieval_res = retrieval_res.to_dict("records") |
|
|
|
df = [] |
|
|
|
for i in idx: |
|
|
|
df.append(retrieval_res[int(i)]) |
|
|
|
r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans) |
|
|
|
assert r, f"{i} => {ans}" |
|
|
|
df[-1]["content"] = r.group(1) |
|
|
|
ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans) |
|
|
|
if ans: df.append({"content": ans}) |
|
|
|
df = self.set_cite(retrieval_res, ans) |
|
|
|
return pd.DataFrame(df) |
|
|
|
|
|
|
|
return Generate.be_output(ans) |
|
|
|
@@ -138,34 +146,7 @@ class Generate(ComponentBase): |
|
|
|
yield res |
|
|
|
|
|
|
|
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: |
|
|
|
answer, idx = retrievaler.insert_citations(answer, |
|
|
|
[ck["content_ltks"] |
|
|
|
for _, ck in retrieval_res.iterrows()], |
|
|
|
[ck["vector"] |
|
|
|
for _, ck in retrieval_res.iterrows()], |
|
|
|
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, |
|
|
|
self._canvas.get_embedding_model()), |
|
|
|
tkweight=0.7, |
|
|
|
vtweight=0.3) |
|
|
|
doc_ids = set([]) |
|
|
|
recall_docs = [] |
|
|
|
for i in idx: |
|
|
|
did = retrieval_res.loc[int(i), "doc_id"] |
|
|
|
if did in doc_ids: continue |
|
|
|
doc_ids.add(did) |
|
|
|
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) |
|
|
|
|
|
|
|
del retrieval_res["vector"] |
|
|
|
del retrieval_res["content_ltks"] |
|
|
|
|
|
|
|
reference = { |
|
|
|
"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], |
|
|
|
"doc_aggs": recall_docs |
|
|
|
} |
|
|
|
|
|
|
|
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: |
|
|
|
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" |
|
|
|
res = {"content": answer, "reference": reference} |
|
|
|
res = self.set_cite(retrieval_res, answer) |
|
|
|
yield res |
|
|
|
|
|
|
|
self.set_output(res) |