Pārlūkot izejas kodu

refine loginfo about graprag progress (#1823)

### What problem does this PR solve?



### Type of change

- [x] Refactoring
tags/v0.10.0
Kevin Hu pirms 1 gada
vecāks
revīzija
43199c45c3
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam

+ 2
- 1
api/db/services/document_service.py Parādīt failu

if 0 <= t.progress < 1: if 0 <= t.progress < 1:
finished = False finished = False
prg += t.progress if t.progress >= 0 else 0 prg += t.progress if t.progress >= 0 else 0
msg.append(t.progress_msg)
if t.progress_msg not in msg:
msg.append(t.progress_msg)
if t.progress == -1: if t.progress == -1:
bad += 1 bad += 1
prg /= len(tsks) prg /= len(tsks)

+ 11
- 5
graphrag/community_reports_extractor.py Parādīt failu

import re import re
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List

from typing import Any, List, Callable
import networkx as nx import networkx as nx
import pandas as pd import pandas as pd

from graphrag import leiden from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer


log = logging.getLogger(__name__) log = logging.getLogger(__name__)


self._on_error = on_error or (lambda _e, _s, _d: None) self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_report_length = max_report_length or 1500 self._max_report_length = max_report_length or 1500


def __call__(self, graph: nx.Graph):
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
communities: dict[str, dict[str, List]] = leiden.run(graph, {}) communities: dict[str, dict[str, List]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]) relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = [] res_str = []
res_dict = [] res_dict = []
over, token_count = 0, 0
st = timer()
for level, comm in communities.items(): for level, comm in communities.items():
for cm_id, ents in comm.items(): for cm_id, ents in comm.items():
weight = ents["weight"] weight = ents["weight"]
"relation_df": rela_df.to_csv(index_label="id") "relation_df": rela_df.to_csv(index_label="id")
} }
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
try: try:
response = self._llm.chat(text, [], gen_conf) response = self._llm.chat(text, [], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response) response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response) response = re.sub(r"[^\}]*$", "", response)
print(response) print(response)
add_community_info2graph(graph, ents, response["title"]) add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response)) res_str.append(self._get_text_output(response))
res_dict.append(response) res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")


return CommunityReportsResult( return CommunityReportsResult(
structured_output=res_dict, structured_output=res_dict,

+ 15
- 5
graphrag/graph_extractor.py Parādīt failu

import re import re
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Mapping
from typing import Any, Mapping, Callable
import tiktoken import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx import networkx as nx
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
from timeit import default_timer as timer


DEFAULT_TUPLE_DELIMITER = "<|>" DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##" DEFAULT_RECORD_DELIMITER = "##"
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}


def __call__( def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
self, texts: list[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None
) -> GraphExtractionResult: ) -> GraphExtractionResult:
"""Call method definition.""" """Call method definition."""
if prompt_variables is None: if prompt_variables is None:
), ),
} }


st = timer()
total = len(texts)
total_token_count = 0
for doc_index, text in enumerate(texts): for doc_index, text in enumerate(texts):
try: try:
# Invoke the entity extraction # Invoke the entity extraction
result = self._process_document(text, prompt_variables)
result, token_count = self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text source_doc_map[doc_index] = text
all_records[doc_index] = result all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e: except Exception as e:
logging.exception("error extracting graph") logging.exception("error extracting graph")
self._on_error( self._on_error(
**prompt_variables, **prompt_variables,
self._input_text_key: text, self._input_text_key: text,
} }
token_count = 0
text = perform_variable_replacements(self._extraction_prompt, variables=variables) text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [], gen_conf) response = self._llm.chat(text, [], gen_conf)
token_count = num_tokens_from_string(text + response)


results = response or "" results = response or ""
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}] history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
if continuation != "YES": if continuation != "YES":
break break


return results
return results, token_count


def _process_results( def _process_results(
self, self,

+ 3
- 3
graphrag/index.py Parādīt failu

for i in range(len(chunks)): for i in range(len(chunks)):
tkn_cnt = num_tokens_from_string(chunks[i]) tkn_cnt = num_tokens_from_string(chunks[i])
if cnt+tkn_cnt >= left_token_count and texts: if cnt+tkn_cnt >= left_token_count and texts:
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
texts = [] texts = []
cnt = 0 cnt = 0
texts.append(chunks[i]) texts.append(chunks[i])
graphs = [] graphs = []
for i, _ in enumerate(threads): for i, _ in enumerate(threads):
graphs.append(_.result().output) graphs.append(_.result().output)
callback(0.5 + 0.1*i/len(threads))
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")


graph = reduce(graph_merge, graphs) graph = reduce(graph_merge, graphs)
er = EntityResolution(llm_bdl) er = EntityResolution(llm_bdl)


callback(0.6, "Extracting community reports.") callback(0.6, "Extracting community reports.")
cr = CommunityReportsExtractor(llm_bdl) cr = CommunityReportsExtractor(llm_bdl)
cr = cr(graph)
cr = cr(graph, callback=callback)
for community, desc in zip(cr.structured_output, cr.output): for community, desc in zip(cr.structured_output, cr.output):
chunk = { chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]), "title_tks": rag_tokenizer.tokenize(community["title"]),

+ 1
- 1
rag/nlp/search.py Parādīt failu

es_logger.info("TOTAL: {}".format(self.es.getTotal(res))) es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s: if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%") bqry, _ = self.qryr.question(qst, min_match="10%")
bqry = self._add_filters(bqry)
bqry = self._add_filters(bqry, req)
s["query"] = bqry.to_dict() s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17 s["knn"]["similarity"] = 0.17

Notiek ielāde…
Atcelt
Saglabāt