### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
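For reviewers: the core idea of this PR is to cache LLM and embedding results in Redis so that GraphRAG/RAPTOR re-runs and chunk rebuilding (keyword/question generation) can reuse earlier outputs instead of calling the model again. A minimal sketch of the key scheme used by the new helpers in `graphrag/utils.py` (an xxhash64 over model name, prompt, history and generation config; entries expire after 24 hours; `REDIS_CONN` stands in for the project's Redis wrapper shown in the diff):

```python
import xxhash

def llm_cache_key(llm_name: str, system: str, history, gen_conf) -> str:
    # Same key derivation as get_llm_cache/set_llm_cache below: hash the model
    # name, system prompt, chat history and generation config, so any change
    # to the request produces a different key.
    hasher = xxhash.xxh64()
    for part in (llm_name, system, history, gen_conf):
        hasher.update(str(part).encode("utf-8"))
    return hasher.hexdigest()

# Usage sketch (REDIS_CONN is the project's Redis wrapper; 24 h TTL as in the PR):
# k = llm_cache_key(chat_model.llm_name, system_prompt, history, {"temperature": 0.3})
# cached = REDIS_CONN.get(k)
# if not cached:
#     response = chat_model.chat(system_prompt, history, {"temperature": 0.3})
#     REDIS_CONN.set(k, response.encode("utf-8"), 24 * 3600)
```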
@@ -271,7 +271,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
    idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
    idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
    if idx >= len(prev_tasks):
        return 0
    prev_task = prev_tasks[idx]
@@ -279,7 +279,11 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
        return 0
    task["chunk_ids"] = prev_task["chunk_ids"]
    task["progress"] = 1.0
    task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
    if "from_page" in task and "to_page" in task:
        task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
    else:
        task["progress_msg"] = ""
    task["progress_msg"] += "reused previous task's chunks."
    prev_task["chunk_ids"] = ""
    return len(task["chunk_ids"].split())
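The reuse lookup above relies on `bisect.bisect_left` with a `key` function (Python 3.10+) to find the previous task that starts at the same page. A small self-contained illustration (the task dicts are made up for the example):

```python
import bisect

# prev_tasks must be sorted by "from_page" for bisect to work
prev_tasks = [{"from_page": 0, "chunk_ids": "a b"},
              {"from_page": 12, "chunk_ids": "c"},
              {"from_page": 24, "chunk_ids": "d e f"}]

task = {"from_page": 12, "to_page": 24}
idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0),
                         key=lambda x: x.get("from_page", 0))
print(idx, prev_tasks[idx]["from_page"])  # 1 12 -> candidate task covering the same pages
```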
@@ -16,6 +16,7 @@ from typing import Any
import tiktoken
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
@@ -33,10 +34,9 @@ class ClaimExtractorResult:
    source_docs: dict[str, Any]
class ClaimExtractor:
class ClaimExtractor(Extractor):
    """Claim extractor class definition."""
    _llm: CompletionLLM
    _extraction_prompt: str
    _summary_prompt: str
    _output_formatter_prompt: str
@@ -169,7 +169,7 @@ class ClaimExtractor:
        }
        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        results = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        claims = results.strip().removesuffix(completion_delimiter)
        history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
@@ -177,7 +177,7 @@ class ClaimExtractor:
        for i in range(self._max_gleanings):
            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": text})
            extension = self._llm.chat("", history, gen_conf)
            extension = self._chat("", history, gen_conf)
            claims += record_delimiter + extension.strip().removesuffix(
                completion_delimiter
            )
@@ -188,7 +188,7 @@ class ClaimExtractor:
            history.append({"role": "assistant", "content": extension})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._llm.chat("", history, self._loop_args)
            continuation = self._chat("", history, self._loop_args)
            if continuation != "YES":
                break
@@ -15,6 +15,7 @@ import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.extractor import Extractor
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
@@ -30,10 +31,9 @@ class CommunityReportsResult:
    structured_output: list[dict]
class CommunityReportsExtractor:
class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition."""
    _llm: CompletionLLM
    _extraction_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
@@ -74,7 +74,7 @@ class CommunityReportsExtractor:
        text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
        gen_conf = {"temperature": 0.3}
        try:
            response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
            response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
            token_count += num_tokens_from_string(text + response)
            response = re.sub(r"^[^\{]*", "", response)
            response = re.sub(r"[^\}]*$", "", response)
@@ -8,6 +8,7 @@ Reference:
"""
import json
from dataclasses import dataclass
from graphrag.extractor import Extractor
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
@@ -42,10 +43,9 @@ class SummarizationResult:
    description: str
class SummarizeExtractor:
class SummarizeExtractor(Extractor):
    """Unipartite graph extractor class definition."""
    _llm: CompletionLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
@@ -143,4 +143,4 @@ class SummarizeExtractor:
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._llm.chat("", [{"role": "user", "content": text}])
        return self._chat("", [{"role": "user", "content": text}])
@@ -21,6 +21,8 @@ from dataclasses import dataclass
from typing import Any
import networkx as nx
from graphrag.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
@@ -39,10 +41,9 @@ class EntityResolutionResult:
    output: nx.Graph
class EntityResolution:
class EntityResolution(Extractor):
    """Entity resolution class definition."""
    _llm: CompletionLLM
    _resolution_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
@@ -117,7 +118,7 @@ class EntityResolution:
        }
        text = perform_variable_replacements(self._resolution_prompt, variables=variables)
        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        result = self._process_results(len(candidate_resolution_i[1]), response,
                                       prompt_variables.get(self._record_delimiter_key,
                                                            DEFAULT_RECORD_DELIMITER),
@@ -0,0 +1,34 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from graphrag.utils import get_llm_cache, set_llm_cache
from rag.llm.chat_model import Base as CompletionLLM


class Extractor:
    _llm: CompletionLLM

    def __init__(self, llm_invoker: CompletionLLM):
        self._llm = llm_invoker

    def _chat(self, system, history, gen_conf):
        response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
        if response:
            return response
        response = self._llm.chat(system, history, gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
        return response
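The GraphRAG extractor classes touched by this PR inherit this new base class and call `self._chat(...)` instead of `self._llm.chat(...)`, so every prompt is transparently served from the Redis cache when a matching entry exists. A minimal usage sketch (the `ToyExtractor` subclass and its prompt are made up for illustration):

```python
from graphrag.extractor import Extractor

class ToyExtractor(Extractor):
    """Hypothetical subclass: every LLM call goes through the cached _chat()."""

    def summarize(self, text: str) -> str:
        history = [{"role": "user", "content": f"Summarize:\n{text}"}]
        # The first call hits the model and stores the answer; identical calls
        # within the 24 h TTL are answered from Redis.
        return self._chat("You are a concise assistant.", history, {"temperature": 0.3})

# toy = ToyExtractor(llm_invoker=chat_model)  # chat_model: a rag.llm.chat_model.Base instance
# print(toy.summarize("GraphRAG caches LLM output to speed up re-runs."))
```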
@@ -12,6 +12,8 @@ import traceback
from typing import Any, Callable, Mapping
from dataclasses import dataclass
import tiktoken
from graphrag.extractor import Extractor
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
@@ -34,10 +36,9 @@ class GraphExtractionResult:
    source_docs: dict[Any, Any]
class GraphExtractor:
class GraphExtractor(Extractor):
    """Unipartite graph extractor class definition."""
    _llm: CompletionLLM
    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
@@ -165,9 +166,7 @@ class GraphExtractor:
        token_count = 0
        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.3}
        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        token_count = num_tokens_from_string(text + response)
        results = response or ""
@@ -177,9 +176,7 @@ class GraphExtractor:
        for i in range(self._max_gleanings):
            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": text})
            response = self._llm.chat("", history, gen_conf)
            if response.find("**ERROR**") >=0:
                raise Exception(response)
            response = self._chat("", history, gen_conf)
            results += response or ""
            # if this is the final glean, don't bother updating the continuation flag
@@ -187,7 +184,7 @@ class GraphExtractor:
                break
            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._llm.chat("", history, self._loop_args)
            continuation = self._chat("", history, self._loop_args)
            if continuation != "YES":
                break
@@ -23,6 +23,7 @@ from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from graphrag.extractor import Extractor
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
@@ -37,8 +38,7 @@ class MindMapResult:
    output: dict
class MindMapExtractor:
    _llm: CompletionLLM
class MindMapExtractor(Extractor):
    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn
@@ -190,7 +190,7 @@ class MindMapExtractor:
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        logging.debug(self._todict(markdown_to_json.dictify(response)))
@@ -6,9 +6,15 @@ Reference:
"""
import html
import json
import re
from typing import Any, Callable
import numpy as np
import xxhash
from rag.utils.redis_conn import REDIS_CONN

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
@@ -60,3 +66,49 @@ def dict_has_keys_with_types(
            return False
    return True


def get_llm_cache(llmnm, txt, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))
    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin.decode("utf-8")


def set_llm_cache(llmnm, txt, v: str, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))
    k = hasher.hexdigest()
    REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)


def get_embed_cache(llmnm, txt):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return np.array(json.loads(bin.decode("utf-8")))


def set_embed_cache(llmnm, txt, arr):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    k = hasher.hexdigest()
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
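The embedding cache round-trips vectors through JSON, so what comes back out of `get_embed_cache` is a numpy array rebuilt from the stored list. A quick standalone sketch of that serialization step (no Redis involved; the vector values are made up):

```python
import json
import numpy as np

# Mirror of how set_embed_cache / get_embed_cache above serialize vectors.
vec = np.array([0.12, -0.34, 0.56])
stored = json.dumps(vec.tolist())        # what gets written to Redis (as UTF-8 bytes)
restored = np.array(json.loads(stored))  # what get_embed_cache returns on a hit
assert np.allclose(vec, restored)
```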
@@ -21,6 +21,7 @@ import umap
import numpy as np
from sklearn.mixture import GaussianMixture
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
from rag.utils import truncate
@@ -33,6 +34,27 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        self._prompt = prompt
        self._max_token = max_token

    def _chat(self, system, history, gen_conf):
        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        if response:
            return response
        response = self._llm_model.chat(system, history, gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    def _embedding_encode(self, txt):
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response:
            return response
        embds, _ = self._embd_model.encode([txt])
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: ")
        embds = embds[0]
        set_embed_cache(self._embd_model.llm_name, txt, embds)
        return embds

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
@@ -57,7 +79,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                texts = [chunks[i][0] for i in ck_idx]
                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                cnt = self._llm_model.chat("You're a helpful assistant.",
                cnt = self._chat("You're a helpful assistant.",
                                 [{"role": "user",
                                   "content": self._prompt.format(cluster_content=cluster_content)}],
                                 {"temperature": 0.3, "max_tokens": self._max_token}
@@ -67,9 +89,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                logging.debug(f"SUM: {cnt}")
                embds, _ = self._embd_model.encode([cnt])
                with lock:
                    if not len(embds[0]):
                        return
                    chunks.append((cnt, embds[0]))
                    chunks.append((cnt, self._embedding_encode(cnt)))
            except Exception as e:
                logging.exception("summarize got exception")
                return e
@@ -19,6 +19,8 @@
import sys
from api.utils.log_utils import initRootLogger
from graphrag.utils import get_llm_cache, set_llm_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)
@@ -232,9 +234,6 @@ def build_chunks(task, progress_callback):
        if not d.get("image"):
            _ = d.pop("image", None)
            d["img_id"] = ""
            d["page_num_int"] = []
            d["position_int"] = []
            d["top_int"] = []
            docs.append(d)
            continue
@@ -262,8 +261,16 @@ def build_chunks(task, progress_callback):
        progress_callback(msg="Start to generate keywords for every chunk ...")
        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
        for d in docs:
            d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
                                                    task["parser_config"]["auto_keywords"]).split(",")
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
                                   {"topn": task["parser_config"]["auto_keywords"]})
            if not cached:
                cached = keyword_extraction(chat_mdl, d["content_with_weight"],
                                            task["parser_config"]["auto_keywords"])
                if cached:
                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
                                  {"topn": task["parser_config"]["auto_keywords"]})
            d["important_kwd"] = cached.split(",")
            d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
        progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
@@ -272,7 +279,15 @@ def build_chunks(task, progress_callback):
        progress_callback(msg="Start to generate questions for every chunk ...")
        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
        for d in docs:
            d["question_kwd"] = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]).split("\n")
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
                                   {"topn": task["parser_config"]["auto_questions"]})
            if not cached:
                cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
                if cached:
                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
                                  {"topn": task["parser_config"]["auto_questions"]})
            d["question_kwd"] = cached.split("\n")
            d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
        progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))