Browse Source

Rebuild the graph when it is out of date. (#4607)

### What problem does this PR solve?

#4543

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
tags/v0.16.0
Kevin Hu 9 months ago
parent
commit
86892959a0
No account linked to committer's email address
5 changed files with 55 additions and 10 deletions
  1. 12
    3
      api/db/services/dialog_service.py
  2. 3
    3
      graphrag/search.py
  3. 35
    1
      graphrag/utils.py
  4. 1
    1
      rag/nlp/search.py
  5. 4
    2
      rag/svr/task_executor.py

+ 12
- 3
api/db/services/dialog_service.py View File

import binascii import binascii
import os import os
import json import json
import json_repair
import re import re
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000


prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": answer, "reference": refs, "prompt": prompt}
return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}


if stream: if stream:
last_ans = "" last_ans = ""
if kwd.find("**ERROR**") >= 0: if kwd.find("**ERROR**") >= 0:
raise Exception(kwd) raise Exception(kwd)


kwd = re.sub(r".*?\{", "{", kwd)
return json.loads(kwd)
try:
return json_repair.loads(kwd)
except json_repair.JSONDecodeError:
try:
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
return json_repair.loads(result)
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e

+ 3
- 3
graphrag/search.py View File

break break


if ents: if ents:
ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
else: else:
ents = "" ents = ""
if relas: if relas:
relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
else: else:
relas = "" relas = ""




if not txts: if not txts:
return "" return ""
return "\n-Community Report-\n" + "\n".join(txts)
return "\n---- Community Report ----\n" + "\n".join(txts)




if __name__ == "__main__": if __name__ == "__main__":

+ 35
- 1
graphrag/utils.py View File



from api import settings from api import settings
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN


ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
res.field[id]["source_id"] res.field[id]["source_id"]
except Exception: except Exception:
continue continue
return None, None
return rebuild_graph(tenant_id, kb_id)




def set_graph(tenant_id, kb_id, graph, docids): def set_graph(tenant_id, kb_id, graph, docids):
res.append(a) res.append(a)
return list(set(res)) return list(set(res))



def rebuild_graph(tenant_id, kb_id):
    """Reconstruct the knowledge graph of a knowledge base from the doc store.

    Pages through every "entity" and "relation" record stored under *kb_id*
    and assembles them into a fresh ``networkx`` graph.

    Args:
        tenant_id: Tenant whose index is queried (used to derive the index name).
        kb_id: Knowledge-base id whose graph records are loaded.

    Returns:
        A ``(graph, source_ids)`` tuple — the rebuilt ``nx.Graph`` plus the
        de-duplicated list of source document ids — or ``(None, None)`` when
        no graph records exist for this knowledge base.
    """
    graph = nx.Graph()
    src_ids = []
    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd",
            "weight_int", "knowledge_graph_kwd", "source_id"]
    bs = 256  # page size for scrolling through the doc store
    for i in range(0, 10000000, bs):
        es_res = settings.docStoreConn.search(
            flds, [],
            {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
            [],
            OrderByExpr(),
            i, bs, search.index_name(tenant_id), [kb_id]
        )
        # Total number of matching records; zero means no graph is stored at all.
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            return None, None

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for _, d in es_res.items():  # renamed loop var: `id` shadowed the builtin
            src_ids.extend(d.get("source_id", []))
            if d["knowledge_graph_kwd"] == "entity":
                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
            else:
                graph.add_edge(
                    d["from_entity_kwd"],
                    d["to_entity_kwd"],
                    weight=int(d["weight_int"])
                )

        # A short page means this was the last batch. Comparing against the
        # actual page size `bs` (the original hard-coded 128) avoids issuing
        # one extra empty query when the final page holds 128–255 records.
        if len(es_res.keys()) < bs:
            break

    return graph, list(set(src_ids))

+ 1
- 1
rag/nlp/search.py View File

cnt = np.sum([c for _, c in aggs]) cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs], tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags] key=lambda x: x[1] * -1)[:topn_tags]
return {a: c for a, c in tag_fea if c > 0}
return {a: max(1, c) for a, c in tag_fea}

+ 4
- 2
rag/svr/task_executor.py View File

random.choices(examples, k=2) if len(examples)>2 else examples, random.choices(examples, k=2) if len(examples)>2 else examples,
topn=topn_tags) topn=topn_tags)
if cached: if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)


progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st)) progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))



Loading…
Cancel
Save