### What problem does this PR solve?

#4543

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
```diff
@@ -17,6 +17,7 @@ import logging
 import binascii
 import os
 import json
+import json_repair
 import re
 from collections import defaultdict
 from copy import deepcopy
```
```diff
@@ -353,7 +354,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

         prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt)}

     if stream:
         last_ans = ""
```
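The `re.sub(r"\n", "  \n", prompt)` change makes the timing breakdown render line by line: markdown treats two trailing spaces before a newline as a hard line break. A quick illustration:

```python
import re

prompt = "Total: 12.3ms\nRetrieval: 8.1ms"
# Two trailing spaces before "\n" is markdown's hard-line-break syntax,
# so each timing entry stays on its own rendered line.
print(re.sub(r"\n", "  \n", prompt))
```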
```diff
@@ -795,5 +796,13 @@ Output:
     if kwd.find("**ERROR**") >= 0:
         raise Exception(kwd)
     kwd = re.sub(r".*?\{", "{", kwd)
-    return json.loads(kwd)
+    try:
+        return json_repair.loads(kwd)
+    except json_repair.JSONDecodeError:
+        try:
+            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
+            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            return json_repair.loads(result)
+        except Exception as e:
+            logging.exception(f"JSON parsing error: {result} -> {e}")
+            raise e
```
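For context, `json_repair.loads` is a drop-in replacement for `json.loads` that tries to repair malformed JSON (unclosed brackets, stray text) before parsing; the nested fallback then strips the echoed prompt and retries on the first `{...}` span. A minimal illustration of the expected repair (exact output may vary by json_repair version):

```python
import json_repair

# Truncated LLM output: the closing bracket and brace are missing.
broken = '{"keywords": ["graph", "retrieval"'
# json.loads would raise here; json_repair is expected to close the open
# structures and return {'keywords': ['graph', 'retrieval']}.
print(json_repair.loads(broken))
```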
```diff
@@ -251,11 +251,11 @@ class KGSearch(Dealer):
                 break
         if ents:
-            ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
+            ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
         else:
             ents = ""
         if relas:
-            relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
+            relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
         else:
             relas = ""
```
```diff
@@ -296,7 +296,7 @@ class KGSearch(Dealer):
         if not txts:
             return ""
-        return "\n-Community Report-\n" + "\n".join(txts)
+        return "\n---- Community Report ----\n" + "\n".join(txts)


 if __name__ == "__main__":
```
```diff
@@ -23,6 +23,7 @@ from networkx.readwrite import json_graph
 from api import settings
 from rag.nlp import search, rag_tokenizer
 from rag.utils.doc_store_conn import OrderByExpr
+from rag.utils.redis_conn import REDIS_CONN

 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
```
```diff
@@ -363,7 +364,7 @@ def get_graph(tenant_id, kb_id):
                 res.field[id]["source_id"]
         except Exception:
             continue
-    return None, None
+    return rebuild_graph(tenant_id, kb_id)


 def set_graph(tenant_id, kb_id, graph, docids):
```
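Behavioral note: when the stored graph cannot be loaded, `get_graph` now rebuilds it from the indexed entity/relation chunks instead of returning `(None, None)`. A hypothetical call site (`handle_empty_kb` is illustrative only), showing that the contract is unchanged:

```python
# Both paths return the same (graph, source_ids) pair, so callers
# need no change when the rebuild fallback is taken.
graph, src_ids = get_graph(tenant_id, kb_id)
if graph is None:
    # Reached only when the KB holds no entity/relation chunks at all.
    handle_empty_kb(kb_id)
```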
```diff
@@ -517,3 +518,36 @@ def flat_uniq_list(arr, key):
             res.append(a)
     return list(set(res))
+
+
+def rebuild_graph(tenant_id, kb_id):
+    graph = nx.Graph()
+    src_ids = []
+    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
+    bs = 256
+    for i in range(0, 10000000, bs):
+        es_res = settings.docStoreConn.search(flds, [],
+                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                              [],
+                                              OrderByExpr(),
+                                              i, bs, search.index_name(tenant_id), [kb_id])
+        tot = settings.docStoreConn.getTotal(es_res)
+        if tot == 0:
+            return None, None
+
+        es_res = settings.docStoreConn.getFields(es_res, flds)
+        for id, d in es_res.items():
+            src_ids.extend(d.get("source_id", []))
+            if d["knowledge_graph_kwd"] == "entity":
+                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
+            else:
+                graph.add_edge(
+                    d["from_entity_kwd"],
+                    d["to_entity_kwd"],
+                    weight=int(d["weight_int"])
+                )
+
+        if len(es_res.keys()) < 128:
+            return graph, list(set(src_ids))
+
+    return graph, list(set(src_ids))
```
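`rebuild_graph` walks the doc store with offset-based paging: batches of 256, a short page (here, fewer than 128 hits) ending the scan early, and a 10M-offset hard cap. A stripped-down sketch of the same pattern, assuming a hypothetical `fetch(offset, limit)` helper:

```python
def paginate(fetch, batch_size=256, hard_cap=10_000_000):
    """Yield rows page by page until an empty or short page signals the end."""
    for offset in range(0, hard_cap, batch_size):
        rows = fetch(offset, batch_size)
        if not rows:
            break
        yield from rows
        if len(rows) < batch_size:  # short page: last page reached
            break
```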
```diff
@@ -483,4 +483,4 @@ class Dealer:
         cnt = np.sum([c for _, c in aggs])
         tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
                          key=lambda x: x[1] * -1)[:topn_tags]
-        return {a: c for a, c in tag_fea if c > 0}
+        return {a: max(1, c) for a, c in tag_fea}
```
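The new dict comprehension keeps every ranked tag: a tag whose rounded score collapsed to 0 was previously dropped, and is now clamped to a minimum weight of 1. A worked example with hypothetical scores:

```python
# Hypothetical (tag, rounded score) pairs after ranking and truncation.
tag_fea = [("ai", 3), ("nlp", 1), ("misc", 0)]

old = {a: c for a, c in tag_fea if c > 0}  # {'ai': 3, 'nlp': 1} -- 'misc' lost
new = {a: max(1, c) for a, c in tag_fea}   # {'ai': 3, 'nlp': 1, 'misc': 1}
```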
```diff
@@ -327,8 +327,10 @@ def build_chunks(task, progress_callback):
                                          random.choices(examples, k=2) if len(examples)>2 else examples,
                                          topn=topn_tags)
                 if cached:
-                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
-                d[TAG_FLD] = json.loads(cached)
+                    cached = json.dumps(cached)
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
+                    d[TAG_FLD] = json.loads(cached)

             progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
```
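The reordering fixes two issues visible in the diff: a freshly generated tag result is a dict (hence the added `json.dumps`) that was previously written to the cache unserialized, and an empty result could reach `json.loads`. Now the value is serialized first, and both the cache write and the `d[TAG_FLD]` assignment run only for a non-empty JSON string. A minimal round-trip sketch with hypothetical values:

```python
import json

tags = {"ai": 3, "nlp": 1}         # dict, as a fresh tagging call returns
cached = json.dumps(tags)          # serialize before caching and parsing
assert json.loads(cached) == tags  # round-trips; empty results skip both guards
```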