
Cache the result from llm for graphrag and raptor (#4051)

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.15.0
Kevin Hu, 10 months ago
parent commit cb6e9ce164

api/db/services/task_service.py  (+6, -2)

 def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
-    idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
+    idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
     if idx >= len(prev_tasks):
         return 0
     prev_task = prev_tasks[idx]

         return 0
     task["chunk_ids"] = prev_task["chunk_ids"]
     task["progress"] = 1.0
-    task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
+    if "from_page" in task and "to_page" in task:
+        task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
+    else:
+        task["progress_msg"] = ""
+    task["progress_msg"] += "reused previous task's chunks."
     prev_task["chunk_ids"] = ""

     return len(task["chunk_ids"].split())
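A minimal sketch (hypothetical task dicts; Python 3.10+ for the `key=` argument of `bisect_left`) of how the lookup above matches a task against `prev_tasks`, which must already be sorted by `from_page`. The real function also checks other fields before reusing chunks:

```python
import bisect

# Hypothetical previous tasks, sorted by "from_page".
prev_tasks = [
    {"from_page": 0, "to_page": 12, "chunk_ids": "c1 c2 c3"},
    {"from_page": 12, "to_page": 24, "chunk_ids": "c4 c5"},
]
task = {"from_page": 12, "to_page": 24}

idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0),
                         key=lambda x: x.get("from_page", 0))
if idx < len(prev_tasks) and prev_tasks[idx].get("from_page") == task.get("from_page"):
    task["chunk_ids"] = prev_tasks[idx]["chunk_ids"]
    print(len(task["chunk_ids"].split()))  # -> 2 reused chunks
```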

graphrag/__init__.py  (+0, -0)


graphrag/claim_extractor.py  (+5, -5)

 import tiktoken

 from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
+from graphrag.extractor import Extractor
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

 source_docs: dict[str, Any]


-class ClaimExtractor:
+class ClaimExtractor(Extractor):
     """Claim extractor class definition."""

-    _llm: CompletionLLM
     _extraction_prompt: str
     _summary_prompt: str
     _output_formatter_prompt: str

         }
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.5}
-        results = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         claims = results.strip().removesuffix(completion_delimiter)
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]

         for i in range(self._max_gleanings):
             text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
             history.append({"role": "user", "content": text})
-            extension = self._llm.chat("", history, gen_conf)
+            extension = self._chat("", history, gen_conf)
             claims += record_delimiter + extension.strip().removesuffix(
                 completion_delimiter
             )

             history.append({"role": "assistant", "content": extension})
             history.append({"role": "user", "content": LOOP_PROMPT})
-            continuation = self._llm.chat("", history, self._loop_args)
+            continuation = self._chat("", history, self._loop_args)
             if continuation != "YES":
                 break

graphrag/community_reports_extractor.py  (+3, -3)

 import pandas as pd
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
+from graphrag.extractor import Extractor
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types

 structured_output: list[dict]


-class CommunityReportsExtractor:
+class CommunityReportsExtractor(Extractor):
     """Community reports extractor class definition."""

-    _llm: CompletionLLM
     _extraction_prompt: str
     _output_formatter_prompt: str
     _on_error: ErrorHandlerFn

         text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
         gen_conf = {"temperature": 0.3}
         try:
-            response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+            response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
             token_count += num_tokens_from_string(text + response)
             response = re.sub(r"^[^\{]*", "", response)
             response = re.sub(r"[^\}]*$", "", response)

graphrag/description_summary.py  (+3, -3)

 import json
 from dataclasses import dataclass

+from graphrag.extractor import Extractor
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 from rag.llm.chat_model import Base as CompletionLLM

 description: str


-class SummarizeExtractor:
+class SummarizeExtractor(Extractor):
     """Unipartite graph extractor class definition."""

-    _llm: CompletionLLM
     _entity_name_key: str
     _input_descriptions_key: str
     _summarization_prompt: str

             self._input_descriptions_key: json.dumps(sorted(descriptions)),
         }
         text = perform_variable_replacements(self._summarization_prompt, variables=variables)
-        return self._llm.chat("", [{"role": "user", "content": text}])
+        return self._chat("", [{"role": "user", "content": text}])

graphrag/entity_resolution.py  (+4, -3)

 from typing import Any

 import networkx as nx

+from graphrag.extractor import Extractor
 from rag.nlp import is_english
 import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT

 output: nx.Graph


-class EntityResolution:
+class EntityResolution(Extractor):
     """Entity resolution class definition."""

-    _llm: CompletionLLM
     _resolution_prompt: str
     _output_formatter_prompt: str
     _on_error: ErrorHandlerFn

         }
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)

-        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        prompt_variables.get(self._record_delimiter_key,
                                                             DEFAULT_RECORD_DELIMITER),

graphrag/extractor.py  (+34, -0, new file)

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from graphrag.utils import get_llm_cache, set_llm_cache
from rag.llm.chat_model import Base as CompletionLLM


class Extractor:
    _llm: CompletionLLM

    def __init__(self, llm_invoker: CompletionLLM):
        self._llm = llm_invoker

    def _chat(self, system, history, gen_conf):
        response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
        if response:
            return response
        response = self._llm.chat(system, history, gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
        return response
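For illustration, a hypothetical subclass showing how the extractors touched by this PR now route LLM calls through the cached `_chat` helper instead of calling `self._llm.chat` directly:

```python
from graphrag.extractor import Extractor


# Hypothetical extractor; ClaimExtractor, GraphExtractor, etc. follow the
# same pattern after this change.
class TitleExtractor(Extractor):
    def extract_title(self, text: str) -> str:
        history = [{"role": "user", "content": text}]
        # Identical (system, history, gen_conf) requests are answered from
        # the Redis cache, so re-running a task skips repeated LLM calls.
        return self._chat("Return a short title for the given text.",
                          history, {"temperature": 0.3})
```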

graphrag/graph_extractor.py  (+6, -9)

 from typing import Any, Callable, Mapping
 from dataclasses import dataclass
 import tiktoken

+from graphrag.extractor import Extractor
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM

 source_docs: dict[Any, Any]


-class GraphExtractor:
+class GraphExtractor(Extractor):
     """Unipartite graph extractor class definition."""

-    _llm: CompletionLLM
     _join_descriptions: bool
     _tuple_delimiter_key: str
     _record_delimiter_key: str

         token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.3}
-        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-        if response.find("**ERROR**") >= 0:
-            raise Exception(response)
+        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         token_count = num_tokens_from_string(text + response)

         results = response or ""
         for i in range(self._max_gleanings):
             text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
             history.append({"role": "user", "content": text})
-            response = self._llm.chat("", history, gen_conf)
-            if response.find("**ERROR**") >= 0:
-                raise Exception(response)
+            response = self._chat("", history, gen_conf)
             results += response or ""

             # if this is the final glean, don't bother updating the continuation flag
                 break
             history.append({"role": "assistant", "content": response})
             history.append({"role": "user", "content": LOOP_PROMPT})
-            continuation = self._llm.chat("", history, self._loop_args)
+            continuation = self._chat("", history, self._loop_args)
             if continuation != "YES":
                 break



graphrag/mind_map_extractor.py  (+3, -3)

 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass

+from graphrag.extractor import Extractor
 from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 from rag.llm.chat_model import Base as CompletionLLM

 output: dict


-class MindMapExtractor:
-    _llm: CompletionLLM
+class MindMapExtractor(Extractor):
     _input_text_key: str
     _mind_map_prompt: str
     _on_error: ErrorHandlerFn

         }
         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
         gen_conf = {"temperature": 0.5}
-        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
         response = re.sub(r"```[^\n]*", "", response)
         logging.debug(response)
         logging.debug(self._todict(markdown_to_json.dictify(response)))

graphrag/utils.py  (+52, -0)

""" """


import html import html
import json
import re import re
from typing import Any, Callable from typing import Any, Callable


import numpy as np
import xxhash

from rag.utils.redis_conn import REDIS_CONN

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]




return False return False
return True return True



def get_llm_cache(llmnm, txt, history, genconf):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
hasher.update(str(history).encode("utf-8"))
hasher.update(str(genconf).encode("utf-8"))

k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return bin.decode("utf-8")


def set_llm_cache(llmnm, txt, v: str, history, genconf):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
hasher.update(str(history).encode("utf-8"))
hasher.update(str(genconf).encode("utf-8"))

k = hasher.hexdigest()
REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)


def get_embed_cache(llmnm, txt):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))

k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return np.array(json.loads(bin.decode("utf-8")))


def set_embed_cache(llmnm, txt, arr):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))

k = hasher.hexdigest()
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
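The four helpers above all derive their Redis key the same way: the xxh64 digest of the stringified inputs, stored with a 24-hour TTL. A standalone sketch of the key scheme (Redis omitted):

```python
import xxhash


def llm_cache_key(llmnm, txt, history, genconf) -> str:
    # Same derivation as get_llm_cache/set_llm_cache above.
    hasher = xxhash.xxh64()
    for part in (llmnm, txt, history, genconf):
        hasher.update(str(part).encode("utf-8"))
    return hasher.hexdigest()


# Hypothetical inputs: changing only the generation config yields a
# different key, i.e. a cache miss.
k1 = llm_cache_key("gpt-4o-mini", "Summarize ...", [], {"temperature": 0.3})
k2 = llm_cache_key("gpt-4o-mini", "Summarize ...", [], {"temperature": 0.5})
assert k1 != k2
```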

rag/raptor.py  (+24, -4)

 import numpy as np
 from sklearn.mixture import GaussianMixture

+from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
 from rag.utils import truncate

         self._prompt = prompt
         self._max_token = max_token

+    def _chat(self, system, history, gen_conf):
+        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
+        if response:
+            return response
+        response = self._llm_model.chat(system, history, gen_conf)
+        if response.find("**ERROR**") >= 0:
+            raise Exception(response)
+        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
+        return response
+
+    def _embedding_encode(self, txt):
+        response = get_embed_cache(self._embd_model.llm_name, txt)
+        if response:
+            return response
+        embds, _ = self._embd_model.encode([txt])
+        if len(embds) < 1 or len(embds[0]) < 1:
+            raise Exception("Embedding error: ")
+        embds = embds[0]
+        set_embed_cache(self._embd_model.llm_name, txt, embds)
+        return embds
+
     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
         max_clusters = min(self._max_cluster, len(embeddings))
         n_clusters = np.arange(1, max_clusters)

             texts = [chunks[i][0] for i in ck_idx]
             len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
             cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-            cnt = self._llm_model.chat("You're a helpful assistant.",
+            cnt = self._chat("You're a helpful assistant.",
                              [{"role": "user",
                                "content": self._prompt.format(cluster_content=cluster_content)}],
                              {"temperature": 0.3, "max_tokens": self._max_token}
             logging.debug(f"SUM: {cnt}")
             embds, _ = self._embd_model.encode([cnt])
             with lock:
-                if not len(embds[0]):
-                    return
-                chunks.append((cnt, embds[0]))
+                chunks.append((cnt, self._embedding_encode(cnt)))
         except Exception as e:
             logging.exception("summarize got exception")
             return e
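A small sketch of the embedding cache round trip that `_embedding_encode` relies on: vectors are JSON-serialized before being written to Redis and come back as `np.ndarray` (values here are made up, Redis omitted):

```python
import json

import numpy as np

vec = np.array([0.12, -0.03, 0.98])                       # hypothetical embedding
stored = json.dumps(vec.tolist()).encode("utf-8")         # what set_embed_cache writes
restored = np.array(json.loads(stored.decode("utf-8")))   # what get_embed_cache returns
assert np.allclose(vec, restored)
```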

rag/svr/task_executor.py  (+21, -6)



 import sys
 from api.utils.log_utils import initRootLogger
+from graphrag.utils import get_llm_cache, set_llm_cache

 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 initRootLogger(CONSUMER_NAME)

         if not d.get("image"):
             _ = d.pop("image", None)
             d["img_id"] = ""
+            d["page_num_int"] = []
+            d["position_int"] = []
+            d["top_int"] = []
             docs.append(d)
             continue

         progress_callback(msg="Start to generate keywords for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
-                                                    task["parser_config"]["auto_keywords"]).split(",")
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
+                                   {"topn": task["parser_config"]["auto_keywords"]})
+            if not cached:
+                cached = keyword_extraction(chat_mdl, d["content_with_weight"],
+                                            task["parser_config"]["auto_keywords"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
+                                  {"topn": task["parser_config"]["auto_keywords"]})
+
+            d["important_kwd"] = cached.split(",")
             d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
         progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))

         progress_callback(msg="Start to generate questions for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            d["question_kwd"] = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]).split("\n")
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
+                                   {"topn": task["parser_config"]["auto_questions"]})
+            if not cached:
+                cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
+                                  {"topn": task["parser_config"]["auto_questions"]})
+
+            d["question_kwd"] = cached.split("\n")
             d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
         progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))


