Переглянути джерело

Made task_executor async to speedup parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speedup parsing

### Type of change

- [x] Performance Improvement
tags/v0.17.1
Zhichang Yu 8 місяці тому
джерело
коміт
c813c1ff4c
Аккаунт користувача з таким Email не знайдено

+ 3
- 1
api/apps/conversation_app.py Переглянути файл

@@ -17,6 +17,7 @@ import json
import re
import traceback
from copy import deepcopy
import trio
from api.db.db_models import APIToken

from api.db.services.conversation_service import ConversationService, structure_answer
@@ -386,7 +387,8 @@ def mindmap():
rank_feature=label_question(question, [kb])
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)

+ 3
- 2
api/db/services/document_service.py Переглянути файл

@@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO
import trio

from peewee import fn

@@ -597,8 +598,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
ensure_ascii=False, indent=2)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({

+ 20
- 13
api/utils/file_utils.py Переглянути файл

@@ -17,6 +17,8 @@ import base64
import json
import os
import re
import sys
import threading
from io import BytesIO

import pdfplumber
@@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")

LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()


def get_project_base_directory(*args):
global PROJECT_BASE
@@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
"""
filename = filename.lower()
if re.match(r".*\.pdf$", filename):
pdf = pdfplumber.open(BytesIO(blob))
buffered = BytesIO()
resolution = 32
img = None
for _ in range(10):
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
img = buffered.getvalue()
if len(img) >= 64000 and resolution >= 2:
resolution = resolution / 2
buffered = BytesIO()
else:
break
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(BytesIO(blob))
buffered = BytesIO()
resolution = 32
img = None
for _ in range(10):
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
img = buffered.getvalue()
if len(img) >= 64000 and resolution >= 2:
resolution = resolution / 2
buffered = BytesIO()
else:
break
pdf.close()
return img


+ 7
- 2
api/utils/log_utils.py Переглянути файл

@@ -18,6 +18,8 @@ import os.path
import logging
from logging.handlers import RotatingFileHandler

initialized_root_logger = False

def get_project_base_directory():
PROJECT_BASE = os.path.abspath(
os.path.join(
@@ -29,10 +31,13 @@ def get_project_base_directory():
return PROJECT_BASE

def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
logger = logging.getLogger()
if logger.hasHandlers():
global initialized_root_logger
if initialized_root_logger:
return
initialized_root_logger = True

logger = logging.getLogger()
logger.handlers.clear()
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))

os.makedirs(os.path.dirname(log_path), exist_ok=True)

+ 21
- 13
deepdoc/parser/pdf_parser.py Переглянути файл

@@ -18,6 +18,8 @@ import logging
import os
import random
from timeit import default_timer as timer
import sys
import threading

import xgboost as xgb
from io import BytesIO
@@ -34,6 +36,10 @@ from rag.nlp import rag_tokenizer
from copy import deepcopy
from huggingface_hub import snapshot_download

LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()

class RAGFlowPdfParser:
def __init__(self):
self.ocr = OCR()
@@ -948,8 +954,9 @@ class RAGFlowPdfParser:
@staticmethod
def total_page_number(fnm, binary=None):
try:
pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
total_page = len(pdf.pages)
pdf.close()
return total_page
@@ -968,17 +975,18 @@ class RAGFlowPdfParser:
self.page_from = page_from
start = timer()
try:
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])]
try:
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
except Exception as e:
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
self.total_page = len(self.pdf.pages)
with sys.modules[LOCK_KEY_pdfplumber]:
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])]
try:
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
except Exception as e:
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
self.total_page = len(self.pdf.pages)
except Exception:
logging.exception("RAGFlowPdfParser __images__")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")

+ 11
- 4
deepdoc/vision/__init__.py Переглянути файл

@@ -14,7 +14,8 @@
# limitations under the License.
#
import io

import sys
import threading
import pdfplumber

from .ocr import OCR
@@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer


LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()


def init_in_out(args):
from PIL import Image
import os
@@ -36,9 +42,10 @@ def init_in_out(args):

def pdf_pages(fnm, zoomin=3):
nonlocal outputs, images
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(pdf.pages)]
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(pdf.pages)]

for i, page in enumerate(images):
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

+ 37
- 36
graphrag/entity_resolution.py Переглянути файл

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import time
@@ -21,13 +20,14 @@ from dataclasses import dataclass
from typing import Any, Callable

import networkx as nx
import trio

from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements
from graphrag.utils import perform_variable_replacements, chat_limiter

DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -67,13 +67,13 @@ class EntityResolution(Extractor):
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"

def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}

# Wire defaults into the prompt variables
prompt_variables = {
self.prompt_variables = {
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
@@ -94,39 +94,12 @@ class EntityResolution(Extractor):
for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]

gen_conf = {"temperature": 0.5}
resolution_result = set()
for candidate_resolution_i in candidate_resolution.items():
if candidate_resolution_i[1]:
try:
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)

variables = {
**prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)

response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
result = self._process_results(len(candidate_resolution_i[1]), response,
prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception:
logging.exception("error entity resolution")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))

connect_graph = nx.Graph()
removed_entities = []
@@ -172,6 +145,34 @@ class EntityResolution(Extractor):
removed_entities=removed_entities
)

async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
gen_conf = {"temperature": 0.5}
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**self.prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
self.prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])

def _process_results(
self,
records_length: int,

+ 0
- 268
graphrag/general/claim_extractor.py Переглянути файл

@@ -1,268 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
import argparse
import json
import re
import traceback
from dataclasses import dataclass
from typing import Any

import tiktoken

from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1


@dataclass
class ClaimExtractorResult:
"""Claim extractor result class definition."""

output: list[dict]
source_docs: dict[str, Any]


class ClaimExtractor(Extractor):
"""Claim extractor class definition."""

_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
_input_text_key: str
_input_entity_spec_key: str
_input_claim_description_key: str
_tuple_delimiter_key: str
_record_delimiter_key: str
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn

def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
input_claim_description_key: str | None = None,
input_resolved_entities_key: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
completion_delimiter_key: str | None = None,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
self._input_text_key = input_text_key or "input_text"
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._input_claim_description_key = (
input_claim_description_key or "claim_description"
)
self._input_resolved_entities_key = (
input_resolved_entities_key or "resolved_entities"
)
self._max_gleanings = (
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)

# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
) -> ClaimExtractorResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
texts = inputs[self._input_text_key]
entity_spec = str(inputs[self._input_entity_spec_key])
claim_description = inputs[self._input_claim_description_key]
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
source_doc_map = {}

prompt_args = {
self._input_entity_spec_key: entity_spec,
self._input_claim_description_key: claim_description,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
}

all_claims: list[dict] = []
for doc_index, text in enumerate(texts):
document_id = f"d{doc_index}"
try:
claims = self._process_document(prompt_args, text, doc_index)
all_claims += [
self._clean_claim(c, document_id, resolved_entities) for c in claims
]
source_doc_map[document_id] = text
except Exception as e:
logging.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
{"doc_index": doc_index, "text": text},
)
continue

return ClaimExtractorResult(
output=all_claims,
source_docs=source_doc_map,
)

def _clean_claim(
self, claim: dict, document_id: str, resolved_entities: dict
) -> dict:
# clean the parsed claims to remove any claims with status = False
obj = claim.get("object_id", claim.get("object"))
subject = claim.get("subject_id", claim.get("subject"))

# If subject or object in resolved entities, then replace with resolved entity
obj = resolved_entities.get(obj, obj)
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
claim["doc_id"] = document_id
return claim

def _process_document(
self, prompt_args: dict, doc, doc_index: int
) -> list[dict]:
record_delimiter = prompt_args.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_args.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
variables = {
self._input_text_key: doc,
**prompt_args,
}
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
claims = results.strip().removesuffix(completion_delimiter)
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
extension = self._chat("", history, gen_conf)
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)

# If this isn't the last loop, check to see if we should continue
if i >= self._max_gleanings - 1:
break

history.append({"role": "assistant", "content": extension})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
if continuation != "YES":
break

result = self._parse_claim_tuples(claims, prompt_args)
for r in result:
r["doc_id"] = f"{doc_index}"
return result

def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
) -> list[dict[str, Any]]:
"""Parse claim tuples."""
record_delimiter = prompt_variables.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_variables.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
tuple_delimiter = prompt_variables.get(
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
)

def pull_field(index: int, fields: list[str]) -> str | None:
return fields[index].strip() if len(fields) > index else None

result: list[dict[str, Any]] = []
claims_values = (
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
)
for claim in claims_values:
claim = claim.strip().removeprefix("(").removesuffix(")")
claim = re.sub(r".*Output:", "", claim)

# Ignore the completion delimiter
if claim == completion_delimiter:
continue

claim_fields = claim.split(tuple_delimiter)
o = {
"subject_id": pull_field(0, claim_fields),
"object_id": pull_field(1, claim_fields),
"type": pull_field(2, claim_fields),
"status": pull_field(3, claim_fields),
"start_date": pull_field(4, claim_fields),
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
"doc_id": pull_field(8, claim_fields),
}
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
continue
result.append(o)
return result


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService

kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)

ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",
"claim_description": ""
}
claim = ex(info)
logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))

+ 0
- 71
graphrag/general/claim_prompt.py Переглянути файл

@@ -1,71 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

CLAIM_EXTRACTION_PROMPT = """
################
-Target activity-
################
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.

################
-Goal-
################
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.

################
-Steps-
################
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.

- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
- 5. If there's nothing satisfy the above requirements, just keep output empty.
- 6. When finished, output {completion_delimiter}

################
-Examples-
################
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}

###########################
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}

################
-Real Data-
################
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output:"""


CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"

+ 21
- 23
graphrag/general/community_reports_extractor.py Переглянути файл

@@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
import trio


@dataclass
@@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500

def __call__(self, graph: nx.Graph, callback: Callable | None = None):
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

@@ -86,28 +87,25 @@ class CommunityReportsExtractor(Extractor):
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.3}
try:
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception:
logging.exception("CommunityReportsExtractor got exception")
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents

add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))

+ 47
- 56
graphrag/general/extractor.py Переглянути файл

@@ -14,16 +14,15 @@
# limitations under the License.
#
import logging
import os
import re
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Callable
import trio

from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import truncate

@@ -91,54 +90,50 @@ class Extractor:
)
return dict(maybe_nodes), dict(maybe_edges)

def __call__(
async def __call__(
self, chunks: list[tuple[str, str]],
callback: Callable | None = None
):

results = []
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
self.callback = callback
start_ts = trio.current_time()
out_results = []
async with trio.open_nursery() as nursery:
for i, (cid, ck) in enumerate(chunks):
ck = truncate(ck, int(self._llm.max_length*0.8))
threads.append(
exe.submit(self._process_single_content, (cid, ck)))

for i, _ in enumerate(threads):
n, r, tc = _.result()
if not isinstance(n, Exception):
results.append((n, r))
if callback:
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
elif callback:
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)

maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
sum_token_count = 0
for m_nodes, m_edges, token_count in out_results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logging.info("Inserting entities into storage...")
sum_token_count += token_count
now = trio.current_time()
if callback:
callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
start_ts = now
logging.info("Entities merging...")
all_entities_data = []
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
threads.append(
exe.submit(self._merge_nodes, en_nm, ents))
for t in threads:
n = t.result()
if not isinstance(n, Exception):
all_entities_data.append(n)
elif callback:
callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
now = trio.current_time()
if callback:
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")

logging.info("Inserting relationships into storage...")
start_ts = now
logging.info("Relationships merging...")
all_relationships_data = []
for (src, tgt), rels in maybe_edges.items():
all_relationships_data.append(self._merge_edges(src, tgt, rels))
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
now = trio.current_time()
if callback:
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")

if not len(all_entities_data) and not len(all_relationships_data):
logging.warning(
@@ -152,7 +147,7 @@ class Extractor:

return all_entities_data, all_relationships_data

def _merge_nodes(self, entity_name: str, entities: list[dict]):
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
if not entities:
return
already_entity_types = []
@@ -176,26 +171,22 @@ class Extractor:
sorted(set([dp["description"] for dp in entities] + already_description))
)
already_source_ids = flat_uniq_list(entities, "source_id")
try:
description = self._handle_entity_relation_summary(
entity_name, description
)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
return node_data
except Exception as e:
return e
description = await self._handle_entity_relation_summary(entity_name, description)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
all_relationships_data.append(node_data)

def _merge_edges(
async def _merge_edges(
self,
src_id: str,
tgt_id: str,
edges_data: list[dict]
edges_data: list[dict],
all_relationships_data
):
if not edges_data:
return
@@ -226,7 +217,7 @@ class Extractor:
"description": description,
"entity_type": 'UNKNOWN'
})
description = self._handle_entity_relation_summary(
description = await self._handle_entity_relation_summary(
f"({src_id}, {tgt_id})", description
)
edge_data = dict(
@@ -238,10 +229,9 @@ class Extractor:
source_id=source_id
)
self._set_relation_(src_id, tgt_id, edge_data)
all_relationships_data.append(edge_data)

return edge_data

def _handle_entity_relation_summary(
async def _handle_entity_relation_summary(
self,
entity_or_relation_name: str,
description: str
@@ -256,5 +246,6 @@ class Extractor:
)
use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}")
summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
async with chat_limiter:
summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
return summary

+ 39
- 45
graphrag/general/graph_extractor.py Переглянути файл

@@ -5,15 +5,15 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
import tiktoken
import trio

from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
@@ -102,53 +102,47 @@ class GraphExtractor(Extractor):
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
}

def _process_single_content(self,
chunk_key_dp: tuple[str, str]
):
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
token_count = 0

chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
variables = {
**self._prompt_variables,
self._input_text_key: content,
}
try:
gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + response)

results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._chat("", history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""

# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, {"temperature": 0.8})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break

record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None



gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + response)

results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""

# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

+ 67
- 54
graphrag/general/index.py Переглянути файл

@@ -17,6 +17,7 @@ import json
import logging
from functools import reduce, partial
import networkx as nx
import trio

from api import settings
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -41,18 +42,24 @@ class Dealer:
embed_bdl=None,
callback=None
):
docids = list(set([docid for docid,_ in chunks]))
self.tenant_id = tenant_id
self.kb_id = kb_id
self.chunks = chunks
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
ext = extractor(self.llm_bdl, language=language,
self.ext = extractor(self.llm_bdl, language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
)
ents, rels = ext(chunks, callback)
self.graph = nx.Graph()
self.callback = callback

async def __call__(self):
docids = list(set([docid for docid, _ in self.chunks]))
ents, rels = await self.ext(self.chunks, self.callback)
for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])

@@ -64,16 +71,16 @@ class Dealer:
#description=rel["description"]
)

with RedisDistributedLock(kb_id, 60*60):
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
with RedisDistributedLock(self.kb_id, 60*60):
old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph])
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
if old_doc_ids:
docids.extend(old_doc_ids)
docids = list(set(docids))
set_graph(tenant_id, kb_id, self.graph, docids)
set_graph(self.tenant_id, self.kb_id, self.graph, docids)


class WithResolution(Dealer):
@@ -84,47 +91,50 @@ class WithResolution(Dealer):
embed_bdl=None,
callback=None
):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl

with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return

if callback:
callback(msg="Fetch the existing graph.")
if self.callback:
self.callback(msg="Fetch the existing graph.")
er = EntityResolution(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
reso = er(self.graph)
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
reso = await er(self.graph)
self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if callback:
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
set_graph(tenant_id, kb_id, self.graph, doc_ids)
if self.callback:
self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

settings.docStoreConn.delete({
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"kb_id": self.kb_id,
"from_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"kb_id": self.kb_id,
"to_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"kb_id": self.kb_id,
"entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
}, search.index_name(self.tenant_id), self.kb_id))


class WithCommunity(Dealer):
@@ -136,38 +146,41 @@ class WithCommunity(Dealer):
callback=None
):

self.tenant_id = tenant_id
self.kb_id = kb_id
self.community_structure = None
self.community_reports = None
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl

with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return
if callback:
callback(msg="Fetch the existing graph.")
if self.callback:
self.callback(msg="Fetch the existing graph.")

cr = CommunityReportsExtractor(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
cr = cr(self.graph, callback=callback)
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
cr = await cr(self.graph, callback=self.callback)
self.community_structure = cr.structured_output
self.community_reports = cr.output
set_graph(tenant_id, kb_id, self.graph, doc_ids)
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

if callback:
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
if self.callback:
self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))

settings.docStoreConn.delete({
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report",
"kb_id": kb_id
}, search.index_name(tenant_id), kb_id)
"kb_id": self.kb_id
}, search.index_name(self.tenant_id), self.kb_id))

for stru, rep in zip(self.community_structure, self.community_reports):
obj = {
@@ -183,7 +196,7 @@ class WithCommunity(Dealer):
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id,
"kb_id": self.kb_id,
"source_id": doc_ids,
"available_int": 0
}
@@ -193,5 +206,5 @@ class WithCommunity(Dealer):
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))


+ 42
- 59
graphrag/general/mind_map_extractor.py Переглянути файл

@@ -16,16 +16,14 @@

import logging
import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import trio

from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
@@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
)
return arr

def __call__(
async def __call__(
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
) -> MindMapResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}

try:
res = []
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = []
cnt = 0
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))

for i, _ in enumerate(threads):
res.append(_.result())

if not res:
return MindMapResult(output={"id": "root", "children": []})

merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
keyset = set(i for i in keys if i)
merge_json = {
"id": "root",
"children": [
{
"id": self._key(k),
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}

except Exception as e:
logging.exception("error mind graph")
self._on_error(
e,
traceback.format_exc(), None
)
merge_json = {"error": str(e)}
res = []
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = []
cnt = 0
async with trio.open_nursery() as nursery:
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
if not res:
return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
keyset = set(i for i in keys if i)
merge_json = {
"id": "root",
"children": [
{
"id": self._key(k),
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}

return MindMapResult(output=merge_json)

@@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):

return self._list_to_kv(to_ret)

def _process_document(
self, text: str, prompt_variables: dict[str, str]
async def _process_document(
self, text: str, prompt_variables: dict[str, str], out_res
) -> str:
variables = {
**prompt_variables,
@@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))
return self._todict(markdown_to_json.dictify(response))
out_res.append(self._todict(markdown_to_json.dictify(response)))

+ 4
- 0
graphrag/general/smoke.py Переглянути файл

@@ -18,6 +18,7 @@ import argparse
import json

import networkx as nx
import trio

from api import settings
from api.db import LLMType
@@ -54,10 +55,13 @@ if __name__ == "__main__":
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
trio.run(dealer())
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))

dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())

print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))

+ 37
- 36
graphrag/light/graph_extractor.py Переглянути файл

@@ -4,16 +4,16 @@
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.light.graph_prompt import PROMPTS
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
import trio


@dataclass
@@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
)
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)

def _process_single_content(self, chunk_key_dp: tuple[str, str]):
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
@@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
**self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text=content)

try:
gen_conf = {"temperature": 0.8}
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
glean_result = self._chat(hint_prompt, history, gen_conf)
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break
gen_conf = {"temperature": 0.8}
async with chat_limiter:
final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break

if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
async with chat_limiter:
if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break

records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None
records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

+ 3
- 0
graphrag/utils.py Переглянути файл

@@ -15,6 +15,8 @@ from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable
import os
import trio

import networkx as nx
import numpy as np
@@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100)))

def perform_variable_replacements(
input: str, history: list[dict] | None = None, variables: dict | None = None

+ 6
- 2
pyproject.toml Переглянути файл

@@ -122,7 +122,8 @@ dependencies = [
"pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.13.1,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0"
"xxhash>=3.5.0,<4.0.0",
"trio>=0.29.0",
]

[project.optional-dependencies]
@@ -133,4 +134,7 @@ full = [
"flagembedding==1.2.10",
"torch>=2.5.0,<3.0.0",
"transformers>=4.35.0,<5.0.0"
]
]

[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple"

+ 16
- 20
rag/raptor.py Переглянути файл

@@ -14,15 +14,14 @@
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
from rag.utils import truncate


@@ -68,24 +67,25 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters

def __call__(self, chunks, random_state, callback=None):
async def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1:
return []
chunks = [(s, a) for s, a in chunks if s and len(a) > 0]

def summarize(ck_idx, lock):
async def summarize(ck_idx, lock):
nonlocal chunks
try:
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
cnt = self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
)
async with chat_limiter:
cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
))
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
cnt)
logging.debug(f"SUM: {cnt}")
@@ -97,10 +97,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return e

labels = []
lock = Lock()
while end - start > 1:
embeddings = [embd for _, embd in chunks[start: end]]
if len(embeddings) == 2:
summarize([start, start + 1], Lock())
await summarize([start, start + 1], lock)
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
@@ -122,19 +123,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lock = Lock()
with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor:
threads = []

async with trio.open_nursery() as nursery:
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx:
continue
threads.append(executor.submit(summarize, ck_idx, lock))
wait(threads, return_when=ALL_COMPLETED)
for th in threads:
if isinstance(th.result(), Exception):
raise th.result()
logging.debug(str([t.result() for t in threads]))
async with chat_limiter:
nursery.start_soon(lambda: summarize(ck_idx, lock))

assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)

+ 158
- 254
rag/svr/task_executor.py Переглянути файл

@@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)

import asyncio
import logging
import os
from datetime import datetime
@@ -38,14 +37,14 @@ import json
import xxhash
import copy
import re
import time
import threading
from functools import partial
from io import BytesIO
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
import resource
import signal
import trio

import numpy as np
from peewee import DoesNotExist
@@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter

BATCH_SIZE = 64

@@ -88,28 +88,28 @@ FACTORY = {
ParserType.TAG.value: tag
}

UNACKED_ITERATOR = None
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
PENDING_TASKS = 0
LAG_TASKS = 0

mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None

tracemalloc_started = False
CURRENT_TASKS = {}

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)

# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
global tracemalloc_started
if not tracemalloc_started:
logging.info("got SIGUSR1, start tracemalloc")
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
tracemalloc_started = True
else:
logging.info("got SIGUSR1, tracemalloc is already running")
logging.info("tracemalloc is already running")

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
@@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):

snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
logging.info(f"taken snapshot {snapshot_file}")
current, peak = tracemalloc.get_traced_memory()
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")

# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
global tracemalloc_started
if tracemalloc_started:
logging.info("go SIGUSR2, stop tracemalloc")
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
tracemalloc_started = False
else:
logging.info("got SIGUSR2, tracemalloc not running")
logging.info("tracemalloc not running")

class TaskCanceledException(Exception):
def __init__(self, msg):
@@ -135,17 +135,9 @@ class TaskCanceledException(Exception):


def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
try:
cancel = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
cancel = TaskService.do_cancel(task_id)

if cancel:
msg += " [Canceled]"
@@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
d["progress"] = prog

logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
try:
TaskService.update_progress(task_id, d)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
TaskService.update_progress(task_id, d)

close_connection()
if cancel and PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
if cancel:
raise TaskCanceledException(msg)

def collect():
global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR
try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not PAYLOAD:
time.sleep(1)
return None
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
try:
redis_msg = next(UNACKED_ITERATOR)
except StopIteration:
redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not redis_msg:
await trio.sleep(1)
return None, None
except Exception:
logging.exception("Get task event from queue exception")
return None
logging.exception("collect got exception")
return None, None

msg = PAYLOAD.get_message()
msg = redis_msg.get_message()
if not msg:
return None
logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
redis_msg.ack()
return None, None

task = None
canceled = False
try:
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except DoesNotExist:
pass
except Exception:
logging.exception("collect get_task exception")
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
if not task or canceled:
state = "is unknown" if not task else "has been cancelled"
with mt_lock:
DONE_TASKS += 1
logging.info(f"collect task {msg['id']} {state}")
FAILED_TASKS += 1
logging.warning(f"collect task {msg['id']} {state}")
redis_msg.ack()
return None

task["task_type"] = msg.get("task_type", "")
return task
return redis_msg, task


def get_storage_binary(bucket, name):
return STORAGE_IMPL.get(bucket, name)
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))


def build_chunks(task, progress_callback):
async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
try:
st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
binary = get_storage_binary(bucket, name)
binary = await get_storage_binary(bucket, name)
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
@@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
raise

try:
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException:
raise
@@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
d["image"].save(output_buffer, format='JPEG')

st = timer()
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
el += timer() - st
except Exception:
logging.exception(
@@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
async def doc_keyword_extraction(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached:
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

if task["parser_config"].get("auto_questions", 0):
@@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
if cached:
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

return docs
@@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


def embedding(docs, mdl, parser_config=None, callback=None):
async def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None:
parser_config = {}
batch_size = 16
@@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):

tk_count = 0
if len(tts) == len(cnts):
vts, c = mdl.encode(tts[0: 1])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
tk_count += c

cnts_ = np.array([])
for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i + batch_size])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
if len(cnts_) == 0:
cnts_ = vts
else:
@@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size


def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
row["parser_config"]["raptor"]["threshold"]
)
original_length = len(chunks)
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
@@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count


def run_graphrag(row, chat_model, language, embedding_model, callback=None):
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))

Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
@@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
await dealer()


def do_handle_task(task):
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
task_to_page = task["to_page"]
@@ -494,6 +483,7 @@ def do_handle_task(task):
task_doc_id = task["doc_id"]
task_document_name = task["name"]
task_parser_config = task["parser_config"]
task_start_ts = timer()

# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -505,11 +495,7 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message)
raise Exception(error_message)

try:
task_canceled = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"task {task_id} is unknown")
return
task_canceled = TaskService.do_cancel(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
@@ -529,71 +515,41 @@ def do_handle_task(task):

# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_resolution":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_community":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
return
else:
# Standard chunking methods
start_ts = timer()
chunks = build_chunks(task, progress_callback)
chunks = await build_chunks(task, progress_callback)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if chunks is None:
return
@@ -605,7 +561,7 @@ def do_handle_task(task):
progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer()
try:
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
except Exception as e:
error_message = "Generate embedding error:{}".format(str(e))
progress_callback(-1, error_message)
@@ -621,8 +577,7 @@ def do_handle_task(task):
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@@ -635,8 +590,7 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
@@ -645,51 +599,39 @@ def do_handle_task(task):
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
token_count, task_time_cost))


def handle_task():
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
task = collect()
if task:
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
return
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
await do_handle_task(task)
DONE_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
logging.info(f"handle_task done for task {json.dumps(task)}")
except Exception as e:
FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
with mt_lock:
CURRENT_TASK = copy.deepcopy(task)
do_handle_task(task)
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
logging.info(f"handle_task done for task {json.dumps(task)}")
except TaskCanceledException:
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
except Exception:
pass
logging.debug("handle_task got TaskCanceledException", exc_info=True)
except Exception as e:
with mt_lock:
FAILED_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None


def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
redis_msg.ack()


async def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
while True:
try:
@@ -699,17 +641,17 @@ def report_status():
PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0))

with mt_lock:
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": CURRENT_TASK,
})
current = copy.deepcopy(CURRENT_TASKS)
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": current,
})
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

@@ -718,27 +660,10 @@ def report_status():
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception:
logging.exception("report_status got exception")
time.sleep(30)


def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
msg = ""
if dump_full:
stats2 = snapshot2.statistics('lineno')
msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
for stat in stats2[:10]:
msg += f"{stat}\n"
stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
for stat in stats1_vs_2[:10]:
msg += f"{stat}\n"
msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
for stat in stats1_vs_2[:3]:
msg += '\n'.join(stat.traceback.format())
logging.info(msg)


def main():
await trio.sleep(30)


async def main():
logging.info(r"""
______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
@@ -755,33 +680,12 @@ def main():
if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None)

# Create an event to signal the background thread to exit
stop_event = threading.Event()
background_thread = threading.Thread(target=report_status)
background_thread.daemon = True
background_thread.start()
# Handle SIGINT (Ctrl+C)
def signal_handler(sig, frame):
logging.info("Received Ctrl+C, shutting down gracefully...")
stop_event.set()
# Give the background thread time to clean up
if background_thread.is_alive():
background_thread.join(timeout=5)
logging.info("Exiting...")
sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

try:
while not stop_event.is_set():
handle_task()
except KeyboardInterrupt:
logging.info("Interrupted by keyboard, shutting down...")
stop_event.set()
if background_thread.is_alive():
background_thread.join(timeout=5)
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
while True:
async with task_limiter:
nursery.start_soon(handle_task)
logging.error("BUG!!! You should not reach here!!!")

if __name__ == "__main__":
main()
trio.run(main)

+ 20
- 22
rag/utils/redis_conn.py Переглянути файл

@@ -24,7 +24,7 @@ from rag import settings
from rag.utils import singleton


class Payload:
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
self.__queue_name = queue_name
@@ -43,6 +43,9 @@ class Payload:
def get_message(self):
return self.__message

def get_msg_id(self):
return self.__msg_id


@singleton
class RedisDB:
@@ -206,9 +209,8 @@ class RedisDB:
)
return False

def queue_consumer(
self, queue_name, group_name, consumer_name, msg_id=b">"
) -> Payload:
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
"""https://redis.io/docs/latest/commands/xreadgroup/"""
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info):
@@ -217,15 +219,17 @@ class RedisDB:
"groupname": group_name,
"consumername": consumer_name,
"count": 1,
"block": 10000,
"block": 5,
"streams": {queue_name: msg_id},
}
messages = self.REDIS.xreadgroup(**args)
if not messages:
return None
stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0]
res = Payload(self.REDIS, queue_name, group_name, msg_id, payload)
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res
except Exception as e:
if "key" in str(e):
@@ -239,30 +243,24 @@ class RedisDB:
)
return None

def get_unacked_for(self, consumer_name, queue_name, group_name):
def get_unacked_iterator(self, queue_name, group_name, consumer_name):
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info):
return
pendings = self.REDIS.xpending_range(
queue_name,
group_name,
min=0,
max=10000000000000,
count=1,
consumername=consumer_name,
)
if not pendings:
return
msg_id = pendings[0]["message_id"]
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
_, payload = msg[0]
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
current_min = 0
while True:
payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
if not payload:
return
current_min = payload.get_msg_id()
logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
yield payload
except Exception as e:
if "key" in str(e):
return
logging.exception(
"RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e)
"RedisDB.get_unacked_iterator " + consumer_name + " got exception: "
)
self.__open__()


+ 21
- 31
uv.lock Переглянути файл

@@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'darwin'",
@@ -1083,9 +1084,6 @@ name = "datrie"
version = "0.8.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" },
]

[[package]]
name = "decorator"
@@ -1362,17 +1360,17 @@ name = "fastembed-gpu"
version = "0.3.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "loguru" },
{ name = "mmh3" },
{ name = "numpy" },
{ name = "onnxruntime-gpu" },
{ name = "pillow" },
{ name = "pystemmer" },
{ name = "requests" },
{ name = "snowballstemmer" },
{ name = "tokenizers" },
{ name = "tqdm" },
{ name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
wheels = [
@@ -3485,12 +3483,12 @@ name = "onnxruntime-gpu"
version = "1.19.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
{ name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@@ -4164,15 +4162,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" },
]

[[package]]
name = "pybind11"
version = "2.13.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" },
]

[[package]]
name = "pyclipper"
version = "1.3.0.post5"
@@ -4230,8 +4219,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" },
{ url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" },
{ url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" },
{ url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" },
@@ -4820,6 +4807,7 @@ dependencies = [
{ name = "tencentcloud-sdk-python" },
{ name = "tika" },
{ name = "tiktoken" },
{ name = "trio" },
{ name = "umap-learn" },
{ name = "valkey" },
{ name = "vertexai" },
@@ -4954,6 +4942,7 @@ requires-dist = [
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
{ name = "trio", specifier = ">=0.29.0" },
{ name = "umap-learn", specifier = "==0.5.6" },
{ name = "valkey", specifier = "==6.0.2" },
{ name = "vertexai", specifier = "==1.64.0" },
@@ -4969,6 +4958,7 @@ requires-dist = [
{ name = "yfinance", specifier = "==0.1.96" },
{ name = "zhipuai", specifier = "==2.0.1" },
]
provides-extras = ["full"]

[[package]]
name = "ranx"

Завантаження…
Відмінити
Зберегти