### What problem does this PR solve?

#4543

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
```diff
 import binascii
 import os
 import json
+import json_repair
 import re
 from collections import defaultdict
 from copy import deepcopy
```
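Context for the new dependency: `json_repair` is a drop-in for `json.loads` that tolerates the almost-JSON LLMs tend to emit (single quotes, trailing commas) where the standard library raises. A minimal sketch, assuming the `json-repair` package is installed; the sample string is hypothetical:

```python
import json
import json_repair

llm_output = "{'keywords': ['graph', 'rag',]}"  # hypothetical almost-JSON from an LLM

try:
    json.loads(llm_output)
except json.JSONDecodeError:
    print("json.loads rejects single quotes and trailing commas")

print(json_repair.loads(llm_output))  # {'keywords': ['graph', 'rag']}
```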
```diff
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
         prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}

     if stream:
         last_ans = ""
```
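The only behavioral change in this hunk is the `re.sub` on the returned prompt: it inserts a space before every newline, presumably to improve how the appended timing report renders in the chat front end. A toy illustration of the literal transformation (the prompt text is made up):

```python
import re

prompt = " - Total: 12.3ms\n - Retrieval: 4.5ms"  # made-up timing report
print(repr(re.sub(r"\n", " \n", prompt)))
# ' - Total: 12.3ms \n - Retrieval: 4.5ms'
```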
```diff
     if kwd.find("**ERROR**") >= 0:
         raise Exception(kwd)
-    kwd = re.sub(r".*?\{", "{", kwd)
-    return json.loads(kwd)
+    try:
+        return json_repair.loads(kwd)
+    except json_repair.JSONDecodeError:
+        try:
+            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
+            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            return json_repair.loads(result)
+        except Exception as e:
+            logging.exception(f"JSON parsing error: {result} -> {e}")
+            raise e
```
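When even `json_repair` fails, the fallback strips the echoed prompt and role markers from the reply, then keeps only the first `{...}` span before re-parsing. A minimal sketch of that brace-extraction step, with a made-up raw reply:

```python
raw = 'model Output: {"keywords": "knowledge graph"} extra trailing text'  # hypothetical reply
inner = "{" + raw.split("{")[1].split("}")[0] + "}"
print(inner)  # {"keywords": "knowledge graph"}
```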
```diff
             break
     if ents:
-        ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
+        ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
     else:
         ents = ""
     if relas:
-        relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
+        relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
     else:
         relas = ""
```

```diff
     if not txts:
         return ""
-    return "\n-Community Report-\n" + "\n".join(txts)
+    return "\n---- Community Report ----\n" + "\n".join(txts)
```
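For reference, each section is a CSV dump of an entity or relation table; the wider `---- ... ----` delimiters make the section boundaries easier to spot in the assembled prompt. A small sketch of what one section looks like, with an invented row:

```python
import pandas as pd

ents = [{"entity": "RAGFlow", "entity_type": "software"}]  # invented data
print("\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv()))
# ---- Entities ----
# ,entity,entity_type
# 0,RAGFlow,software
```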
```diff
 if __name__ == "__main__":
     from api import settings
     from rag.nlp import search, rag_tokenizer
+    from rag.utils.doc_store_conn import OrderByExpr
     from rag.utils.redis_conn import REDIS_CONN
```
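The new `OrderByExpr` import exists to support the paginated doc-store scan in the `rebuild_graph` function added below.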
```diff
 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
```

```diff
             res.field[id]["source_id"]
         except Exception:
             continue
-    return None, None
+    return rebuild_graph(tenant_id, kb_id)
```
```diff
 def set_graph(tenant_id, kb_id, graph, docids):
```

```diff
         res.append(a)
     return list(set(res))
+
+
+def rebuild_graph(tenant_id, kb_id):
+    graph = nx.Graph()
+    src_ids = []
+    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
+    bs = 256
+    for i in range(0, 10000000, bs):
+        es_res = settings.docStoreConn.search(flds, [],
+                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                              [],
+                                              OrderByExpr(),
+                                              i, bs, search.index_name(tenant_id), [kb_id]
+                                              )
+        tot = settings.docStoreConn.getTotal(es_res)
+        if tot == 0:
+            return None, None
+        es_res = settings.docStoreConn.getFields(es_res, flds)
+        for id, d in es_res.items():
+            src_ids.extend(d.get("source_id", []))
+            if d["knowledge_graph_kwd"] == "entity":
+                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
+            else:
+                graph.add_edge(
+                    d["from_entity_kwd"],
+                    d["to_entity_kwd"],
+                    weight=int(d["weight_int"])
+                )
+        if len(es_res.keys()) < 128:
+            return graph, list(set(src_ids))
+    return graph, list(set(src_ids))
```
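`rebuild_graph` pages through the stored entity and relation chunks `bs` rows at a time and returns `(None, None)` when nothing is indexed, so a missing cached graph can now be reconstructed instead of silently returning nothing. Hypothetical usage, assuming a configured doc store; the tenant and KB ids are placeholders:

```python
graph, src_ids = rebuild_graph("tenant_0", "kb_0")  # placeholder ids
if graph is not None:
    print(graph.number_of_nodes(), "entities,", graph.number_of_edges(), "relations")
    print(len(src_ids), "distinct source document ids")
```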
```diff
     cnt = np.sum([c for _, c in aggs])
     tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
                      key=lambda x: x[1] * -1)[:topn_tags]
-    return {a: c for a, c in tag_fea if c > 0}
+    return {a: max(1, c) for a, c in tag_fea}
```
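The behavioral change: tags whose score rounds to 0 were previously dropped from the result; now every selected tag is kept with a floor of 1. A toy computation with invented numbers, mirroring the formula above:

```python
S, topn_tags = 1000, 3                    # invented smoothing constant and cutoff
all_tags = {"nlp": 0.02, "rag": 0.0001}   # invented global tag frequencies
aggs = [("nlp", 5), ("rag", 1)]           # invented per-chunk tag counts
cnt = sum(c for _, c in aggs)
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / all_tags.get(a, 0.0001)))
                  for a, c in aggs], key=lambda x: -x[1])[:topn_tags]
print({a: c for a, c in tag_fea if c > 0})  # old: {'rag': 2} -- 'nlp' silently dropped
print({a: max(1, c) for a, c in tag_fea})   # new: {'rag': 2, 'nlp': 1}
```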
```diff
                                          random.choices(examples, k=2) if len(examples)>2 else examples,
                                          topn=topn_tags)
                 if cached:
-                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
-                    d[TAG_FLD] = json.loads(cached)
+                    cached = json.dumps(cached)
+            if cached:
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
+                d[TAG_FLD] = json.loads(cached)

         progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
```
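The reordering fixes the cache contract: `content_tagging` returns a dict, so a fresh result is serialized with `json.dumps` before caching, and both the cache-hit and fresh-result paths now flow through the same `set_llm_cache`/`json.loads` code. A minimal sketch of the intended flow, assuming `TAG_FLD` names the tag-feature field and the cache stores strings:

```python
import json

cached = {"rag": 2, "nlp": 1}   # hypothetical content_tagging result (a dict)
if cached:
    cached = json.dumps(cached)  # the cache layer is assumed to expect a string
if cached:
    d = {}
    d["tag_feas"] = json.loads(cached)  # TAG_FLD assumed to be "tag_feas"
    print(d)
```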