### What problem does this PR solve?

Made `task_executor` async to speed up parsing.

### Type of change

- [x] Performance Improvement

Tag: v0.17.1
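The diffs below follow one recurring pattern: extractor classes gain an `async def __call__`, synchronous call sites bridge into them with `trio.run(...)`, fan-out happens inside a trio nursery, and blocking LLM requests are pushed onto worker threads. A minimal, self-contained sketch of that bridge follows; the class and function names here are hypothetical stand-ins, not RAGFlow's actual classes:

```python
import trio


class MindMapLikeExtractor:
    """Hypothetical stand-in for an extractor whose __call__ became async."""

    async def __call__(self, texts: list[str]) -> dict:
        results: list[str] = []
        async with trio.open_nursery() as nursery:
            for t in texts:
                nursery.start_soon(self._one, t, results)
        return {"output": results}

    async def _one(self, text: str, out: list) -> None:
        # Blocking work (an LLM request in RAGFlow) runs on a worker thread,
        # so the other tasks in the nursery keep making progress.
        out.append(await trio.to_thread.run_sync(str.upper, text))


def sync_entry_point(texts: list[str]) -> dict:
    # Synchronous callers bridge into the async extractor with trio.run,
    # mirroring `mind_map = trio.run(mindmap, chunks)` in the diffs below.
    return trio.run(MindMapLikeExtractor(), texts)


if __name__ == "__main__":
    print(sync_entry_point(["hello", "world"]))
```

`trio.run(mindmap, [...])` in the API diffs is exactly this shape: the extractor instance is the async callable and the chunk list is its argument.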
```diff
 import re
 import traceback
 from copy import deepcopy
+import trio
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
         rank_feature=label_question(question, [kb])
     )
     mindmap = MindMapExtractor(chat_mdl)
-    mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
+    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
+    mind_map = mind_map.output
     if "error" in mind_map:
         return server_error_response(Exception(mind_map["error"]))
     return get_json_result(data=mind_map)
```
```diff
 from copy import deepcopy
 from datetime import datetime
 from io import BytesIO
+import trio
 from peewee import fn
         if parser_ids[doc_id] != ParserType.PICTURE.value:
             mindmap = MindMapExtractor(llm_bdl)
             try:
-                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
-                                      ensure_ascii=False, indent=2)
+                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
+                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                 if len(mind_map) < 32:
                     raise Exception("Few content: " + mind_map)
                 cks.append({
```
```diff
 import json
 import os
 import re
+import sys
+import threading
 from io import BytesIO
 import pdfplumber
 PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
 RAG_BASE = os.getenv("RAG_BASE")
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 def get_project_base_directory(*args):
     global PROJECT_BASE
     """
     filename = filename.lower()
     if re.match(r".*\.pdf$", filename):
-        pdf = pdfplumber.open(BytesIO(blob))
-        buffered = BytesIO()
-        resolution = 32
-        img = None
-        for _ in range(10):
-            # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
-            pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
-            img = buffered.getvalue()
-            if len(img) >= 64000 and resolution >= 2:
-                resolution = resolution / 2
-                buffered = BytesIO()
-            else:
-                break
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(BytesIO(blob))
+            buffered = BytesIO()
+            resolution = 32
+            img = None
+            for _ in range(10):
+                # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
+                pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
+                img = buffered.getvalue()
+                if len(img) >= 64000 and resolution >= 2:
+                    resolution = resolution / 2
+                    buffered = BytesIO()
+                else:
+                    break
         pdf.close()
         return img
```
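The `LOCK_KEY_pdfplumber` hunks store a `threading.Lock` in `sys.modules` under a string key, so every importer in the process shares a single lock object and pdfplumber calls are serialized even when several threads (including trio worker threads) render PDFs at once. A small sketch of the trick, with a hypothetical key and a placeholder for the pdfplumber call:

```python
import sys
import threading

# Hypothetical key; the diff uses "global_shared_lock_pdfplumber".
LOCK_KEY = "global_shared_lock_demo"
if LOCK_KEY not in sys.modules:
    # sys.modules is one process-wide dict, so every importer ends up
    # sharing the same Lock object no matter how this module is loaded.
    sys.modules[LOCK_KEY] = threading.Lock()


def render_first_page(path: str) -> int:
    # Serialize the non-thread-safe library call, like the
    # `with sys.modules[LOCK_KEY_pdfplumber]:` blocks above.
    with sys.modules[LOCK_KEY]:
        # placeholder for pdfplumber.open(path).pages[0].to_image(...)
        return len(path)
```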
```diff
 import logging
 from logging.handlers import RotatingFileHandler
+initialized_root_logger = False
 def get_project_base_directory():
     PROJECT_BASE = os.path.abspath(
         os.path.join(
     return PROJECT_BASE
 def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
-    logger = logging.getLogger()
-    if logger.hasHandlers():
+    global initialized_root_logger
+    if initialized_root_logger:
         return
+    initialized_root_logger = True
+    logger = logging.getLogger()
+    logger.handlers.clear()
     log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
     os.makedirs(os.path.dirname(log_path), exist_ok=True)
```
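The logging change swaps the `hasHandlers()` check for a module-level flag, so repeated calls to `initRootLogger` become a no-op after the first one even when other libraries have already attached handlers to the root logger. Roughly, assuming a simplified handler setup:

```python
import logging

initialized_root_logger = False  # module-level guard, as in the diff

def init_root_logger_once(level: int = logging.INFO) -> None:
    global initialized_root_logger
    if initialized_root_logger:
        return
    initialized_root_logger = True
    logger = logging.getLogger()
    logger.handlers.clear()  # drop handlers other libraries attached earlier
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(level)
```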
```diff
 import os
 import random
 from timeit import default_timer as timer
+import sys
+import threading
 import xgboost as xgb
 from io import BytesIO
 from copy import deepcopy
 from huggingface_hub import snapshot_download
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 class RAGFlowPdfParser:
     def __init__(self):
         self.ocr = OCR()
     @staticmethod
     def total_page_number(fnm, binary=None):
         try:
-            pdf = pdfplumber.open(
-                fnm) if not binary else pdfplumber.open(BytesIO(binary))
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                pdf = pdfplumber.open(
+                    fnm) if not binary else pdfplumber.open(BytesIO(binary))
             total_page = len(pdf.pages)
             pdf.close()
             return total_page
         self.page_from = page_from
         start = timer()
         try:
-            self.pdf = pdfplumber.open(fnm) if isinstance(
-                fnm, str) else pdfplumber.open(BytesIO(fnm))
-            self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                                enumerate(self.pdf.pages[page_from:page_to])]
-            try:
-                self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
-            except Exception as e:
-                logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
-                self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
-            self.total_page = len(self.pdf.pages)
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                try:
+                    self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
+                except Exception as e:
+                    logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
+                    self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
+                self.total_page = len(self.pdf.pages)
         except Exception:
             logging.exception("RAGFlowPdfParser __images__")
         logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
```
```diff
 # limitations under the License.
 #
 import io
+import sys
+import threading
 import pdfplumber
 from .ocr import OCR
 from .table_structure_recognizer import TableStructureRecognizer
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 def init_in_out(args):
     from PIL import Image
     import os
     def pdf_pages(fnm, zoomin=3):
         nonlocal outputs, images
-        pdf = pdfplumber.open(fnm)
-        images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                  enumerate(pdf.pages)]
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(fnm)
+            images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                      enumerate(pdf.pages)]
         for i, page in enumerate(images):
             outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
```
```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import itertools
 import re
 import time
 from typing import Any, Callable
 import networkx as nx
+import trio
 from graphrag.general.extractor import Extractor
 from rag.nlp import is_english
 import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements
+from graphrag.utils import perform_variable_replacements, chat_limiter
 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"
-    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
         # Wire defaults into the prompt variables
-        prompt_variables = {
+        self.prompt_variables = {
             **prompt_variables,
             self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
             or DEFAULT_RECORD_DELIMITER,
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
-        gen_conf = {"temperature": 0.5}
         resolution_result = set()
-        for candidate_resolution_i in candidate_resolution.items():
-            if candidate_resolution_i[1]:
-                try:
-                    pair_txt = [
-                        f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
-                    for index, candidate in enumerate(candidate_resolution_i[1]):
-                        pair_txt.append(
-                            f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
-                    sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
-                    pair_txt.append(
-                        f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
-                    pair_prompt = '\n'.join(pair_txt)
-                    variables = {
-                        **prompt_variables,
-                        self._input_text_key: pair_prompt
-                    }
-                    text = perform_variable_replacements(self._resolution_prompt, variables=variables)
-                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-                    result = self._process_results(len(candidate_resolution_i[1]), response,
-                                                   prompt_variables.get(self._record_delimiter_key,
-                                                                        DEFAULT_RECORD_DELIMITER),
-                                                   prompt_variables.get(self._entity_index_dilimiter_key,
-                                                                        DEFAULT_ENTITY_INDEX_DELIMITER),
-                                                   prompt_variables.get(self._resolution_result_delimiter_key,
-                                                                        DEFAULT_RESOLUTION_RESULT_DELIMITER))
-                    for result_i in result:
-                        resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
-                except Exception:
-                    logging.exception("error entity resolution")
+        async with trio.open_nursery() as nursery:
+            for candidate_resolution_i in candidate_resolution.items():
+                if not candidate_resolution_i[1]:
+                    continue
+                nursery.start_soon(self._resolve_candidate, candidate_resolution_i, resolution_result)
         connect_graph = nx.Graph()
         removed_entities = []
             removed_entities=removed_entities
         )
+    async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
+        gen_conf = {"temperature": 0.5}
+        pair_txt = [
+            f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
+        for index, candidate in enumerate(candidate_resolution_i[1]):
+            pair_txt.append(
+                f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
+        sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
+        pair_txt.append(
+            f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
+        pair_prompt = '\n'.join(pair_txt)
+        variables = {
+            **self.prompt_variables,
+            self._input_text_key: pair_prompt
+        }
+        text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        result = self._process_results(len(candidate_resolution_i[1]), response,
+                                       self.prompt_variables.get(self._record_delimiter_key,
+                                                                 DEFAULT_RECORD_DELIMITER),
+                                       self.prompt_variables.get(self._entity_index_dilimiter_key,
+                                                                 DEFAULT_ENTITY_INDEX_DELIMITER),
+                                       self.prompt_variables.get(self._resolution_result_delimiter_key,
+                                                                 DEFAULT_RESOLUTION_RESULT_DELIMITER))
+        for result_i in result:
+            resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
     def _process_results(
         self,
         records_length: int,
```
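The entity-resolution hunk shows the concurrency shape used across the graphrag extractors: each candidate batch becomes a nursery task, the task takes the shared `chat_limiter` and then runs the blocking `_chat` call on a worker thread with `trio.to_thread.run_sync`, appending its result to a shared container. A stripped-down sketch, assuming `chat_limiter` is a `trio.CapacityLimiter` (a local stand-in is defined here) and using a fake chat function:

```python
import trio

# Assumed: graphrag.utils.chat_limiter caps how many LLM requests run at once.
chat_limiter = trio.CapacityLimiter(10)

def fake_chat(prompt: str) -> str:
    # Stand-in for the blocking `_chat` call to the LLM.
    return prompt.upper()

async def resolve_one(prompt: str, results: set) -> None:
    async with chat_limiter:  # hold one of the N concurrency slots
        answer = await trio.to_thread.run_sync(fake_chat, prompt)
    results.add(answer)

async def resolve_all(prompts: list[str]) -> set:
    results: set[str] = set()
    async with trio.open_nursery() as nursery:
        for p in prompts:
            nursery.start_soon(resolve_one, p, results)
    return results

if __name__ == "__main__":
    print(trio.run(resolve_all, ["alpha", "beta", "gamma"]))
```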
```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
import argparse
import json
import re
import traceback
from dataclasses import dataclass
from typing import Any

import tiktoken

from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1


@dataclass
class ClaimExtractorResult:
    """Claim extractor result class definition."""

    output: list[dict]
    source_docs: dict[str, Any]


class ClaimExtractor(Extractor):
    """Claim extractor class definition."""

    _extraction_prompt: str
    _summary_prompt: str
    _output_formatter_prompt: str
    _input_text_key: str
    _input_entity_spec_key: str
    _input_claim_description_key: str
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _completion_delimiter_key: str
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        extraction_prompt: str | None = None,
        input_text_key: str | None = None,
        input_entity_spec_key: str | None = None,
        input_claim_description_key: str | None = None,
        input_resolved_entities_key: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        completion_delimiter_key: str | None = None,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
        self._input_text_key = input_text_key or "input_text"
        self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._input_claim_description_key = (
            input_claim_description_key or "claim_description"
        )
        self._input_resolved_entities_key = (
            input_resolved_entities_key or "resolved_entities"
        )
        self._max_gleanings = (
            max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)

        # Construct the looping arguments
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

    def __call__(
        self, inputs: dict[str, Any], prompt_variables: dict | None = None
    ) -> ClaimExtractorResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        texts = inputs[self._input_text_key]
        entity_spec = str(inputs[self._input_entity_spec_key])
        claim_description = inputs[self._input_claim_description_key]
        resolved_entities = inputs.get(self._input_resolved_entities_key, {})
        source_doc_map = {}

        prompt_args = {
            self._input_entity_spec_key: entity_spec,
            self._input_claim_description_key: claim_description,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
        }

        all_claims: list[dict] = []
        for doc_index, text in enumerate(texts):
            document_id = f"d{doc_index}"
            try:
                claims = self._process_document(prompt_args, text, doc_index)
                all_claims += [
                    self._clean_claim(c, document_id, resolved_entities) for c in claims
                ]
                source_doc_map[document_id] = text
            except Exception as e:
                logging.exception("error extracting claim")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {"doc_index": doc_index, "text": text},
                )
                continue

        return ClaimExtractorResult(
            output=all_claims,
            source_docs=source_doc_map,
        )

    def _clean_claim(
        self, claim: dict, document_id: str, resolved_entities: dict
    ) -> dict:
        # clean the parsed claims to remove any claims with status = False
        obj = claim.get("object_id", claim.get("object"))
        subject = claim.get("subject_id", claim.get("subject"))

        # If subject or object in resolved entities, then replace with resolved entity
        obj = resolved_entities.get(obj, obj)
        subject = resolved_entities.get(subject, subject)
        claim["object_id"] = obj
        claim["subject_id"] = subject
        claim["doc_id"] = document_id
        return claim

    def _process_document(
        self, prompt_args: dict, doc, doc_index: int
    ) -> list[dict]:
        record_delimiter = prompt_args.get(
            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
        )
        completion_delimiter = prompt_args.get(
            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
        )

        variables = {
            self._input_text_key: doc,
            **prompt_args,
        }
        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        claims = results.strip().removesuffix(completion_delimiter)
        history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": text})
            extension = self._chat("", history, gen_conf)
            claims += record_delimiter + extension.strip().removesuffix(
                completion_delimiter
            )

            # If this isn't the last loop, check to see if we should continue
            if i >= self._max_gleanings - 1:
                break

            history.append({"role": "assistant", "content": extension})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._chat("", history, self._loop_args)
            if continuation != "YES":
                break

        result = self._parse_claim_tuples(claims, prompt_args)
        for r in result:
            r["doc_id"] = f"{doc_index}"
        return result

    def _parse_claim_tuples(
        self, claims: str, prompt_variables: dict
    ) -> list[dict[str, Any]]:
        """Parse claim tuples."""
        record_delimiter = prompt_variables.get(
            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
        )
        completion_delimiter = prompt_variables.get(
            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
        )
        tuple_delimiter = prompt_variables.get(
            self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
        )

        def pull_field(index: int, fields: list[str]) -> str | None:
            return fields[index].strip() if len(fields) > index else None

        result: list[dict[str, Any]] = []
        claims_values = (
            claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
        )
        for claim in claims_values:
            claim = claim.strip().removeprefix("(").removesuffix(")")
            claim = re.sub(r".*Output:", "", claim)

            # Ignore the completion delimiter
            if claim == completion_delimiter:
                continue

            claim_fields = claim.split(tuple_delimiter)
            o = {
                "subject_id": pull_field(0, claim_fields),
                "object_id": pull_field(1, claim_fields),
                "type": pull_field(2, claim_fields),
                "status": pull_field(3, claim_fields),
                "start_date": pull_field(4, claim_fields),
                "end_date": pull_field(5, claim_fields),
                "description": pull_field(6, claim_fields),
                "source_text": pull_field(7, claim_fields),
                "doc_id": pull_field(8, claim_fields),
            }
            if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
                continue
            result.append(o)
        return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
    args = parser.parse_args()

    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api import settings
    from api.db.services.knowledgebase_service import KnowledgebaseService

    kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)

    ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
    docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
    info = {
        "input_text": docs,
        "entity_specs": "organization, person",
        "claim_description": ""
    }
    claim = ex(info)
    logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))
```
```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

CLAIM_EXTRACTION_PROMPT = """
################
-Target activity-
################
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.

################
-Goal-
################
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.

################
-Steps-
################
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
- 5. If there's nothing satisfy the above requirements, just keep output empty.
- 6. When finished, output {completion_delimiter}

################
-Examples-
################
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}

###########################
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}

################
-Real Data-
################
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output:"""

CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
```
```diff
 from graphrag.general.extractor import Extractor
 from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
+from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
 from timeit import default_timer as timer
+import trio
 @dataclass
         self._extraction_prompt = COMMUNITY_REPORT_PROMPT
         self._max_report_length = max_report_length or 1500
-    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
             }
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             gen_conf = {"temperature": 0.3}
-            try:
-                response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-                token_count += num_tokens_from_string(text + response)
-                response = re.sub(r"^[^\{]*", "", response)
-                response = re.sub(r"[^\}]*$", "", response)
-                response = re.sub(r"\{\{", "{", response)
-                response = re.sub(r"\}\}", "}", response)
-                logging.debug(response)
-                response = json.loads(response)
-                if not dict_has_keys_with_types(response, [
-                    ("title", str),
-                    ("summary", str),
-                    ("findings", list),
-                    ("rating", float),
-                    ("rating_explanation", str),
-                ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-            except Exception:
-                logging.exception("CommunityReportsExtractor got exception")
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            token_count += num_tokens_from_string(text + response)
+            response = re.sub(r"^[^\{]*", "", response)
+            response = re.sub(r"[^\}]*$", "", response)
+            response = re.sub(r"\{\{", "{", response)
+            response = re.sub(r"\}\}", "}", response)
+            logging.debug(response)
+            response = json.loads(response)
+            if not dict_has_keys_with_types(response, [
+                ("title", str),
+                ("summary", str),
+                ("findings", list),
+                ("rating", float),
+                ("rating_explanation", str),
+            ]):
                 continue
+            response["weight"] = weight
+            response["entities"] = ents
             add_community_info2graph(graph, ents, response["title"])
             res_str.append(self._get_text_output(response))
```
```diff
 # limitations under the License.
 #
 import logging
-import os
 import re
 from collections import defaultdict, Counter
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import Callable
+import trio
 from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
 from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
-    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
+    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.utils import truncate
         )
         return dict(maybe_nodes), dict(maybe_edges)
-    def __call__(
+    async def __call__(
         self, chunks: list[tuple[str, str]],
         callback: Callable | None = None
     ):
-        results = []
-        max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
+        self.callback = callback
+        start_ts = trio.current_time()
+        out_results = []
+        async with trio.open_nursery() as nursery:
             for i, (cid, ck) in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                threads.append(
-                    exe.submit(self._process_single_content, (cid, ck)))
-            for i, _ in enumerate(threads):
-                n, r, tc = _.result()
-                if not isinstance(n, Exception):
-                    results.append((n, r))
-                    if callback:
-                        callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
-                elif callback:
-                    callback(msg="Knowledge graph extraction error:{}".format(str(n)))
+                nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
-        for m_nodes, m_edges in results:
+        sum_token_count = 0
+        for m_nodes, m_edges, token_count in out_results:
             for k, v in m_nodes.items():
                 maybe_nodes[k].extend(v)
             for k, v in m_edges.items():
                 maybe_edges[tuple(sorted(k))].extend(v)
-        logging.info("Inserting entities into storage...")
+            sum_token_count += token_count
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
+        start_ts = now
+        logging.info("Entities merging...")
         all_entities_data = []
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
+        async with trio.open_nursery() as nursery:
             for en_nm, ents in maybe_nodes.items():
-                threads.append(
-                    exe.submit(self._merge_nodes, en_nm, ents))
-            for t in threads:
-                n = t.result()
-                if not isinstance(n, Exception):
-                    all_entities_data.append(n)
-                elif callback:
-                    callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
-        logging.info("Inserting relationships into storage...")
+        start_ts = now
+        logging.info("Relationships merging...")
         all_relationships_data = []
-        for (src, tgt), rels in maybe_edges.items():
-            all_relationships_data.append(self._merge_edges(src, tgt, rels))
+        async with trio.open_nursery() as nursery:
+            for (src, tgt), rels in maybe_edges.items():
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
         if not len(all_entities_data) and not len(all_relationships_data):
             logging.warning(
         return all_entities_data, all_relationships_data
-    def _merge_nodes(self, entity_name: str, entities: list[dict]):
+    async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
         if not entities:
             return
         already_entity_types = []
             sorted(set([dp["description"] for dp in entities] + already_description))
         )
         already_source_ids = flat_uniq_list(entities, "source_id")
-        try:
-            description = self._handle_entity_relation_summary(
-                entity_name, description
-            )
-            node_data = dict(
-                entity_type=entity_type,
-                description=description,
-                source_id=already_source_ids,
-            )
-            node_data["entity_name"] = entity_name
-            self._set_entity_(entity_name, node_data)
-            return node_data
-        except Exception as e:
-            return e
+        description = await self._handle_entity_relation_summary(entity_name, description)
+        node_data = dict(
+            entity_type=entity_type,
+            description=description,
+            source_id=already_source_ids,
+        )
+        node_data["entity_name"] = entity_name
+        self._set_entity_(entity_name, node_data)
+        all_relationships_data.append(node_data)
-    def _merge_edges(
+    async def _merge_edges(
         self,
         src_id: str,
         tgt_id: str,
-        edges_data: list[dict]
+        edges_data: list[dict],
+        all_relationships_data
     ):
         if not edges_data:
             return
                 "description": description,
                 "entity_type": 'UNKNOWN'
             })
-        description = self._handle_entity_relation_summary(
+        description = await self._handle_entity_relation_summary(
             f"({src_id}, {tgt_id})", description
         )
         edge_data = dict(
             source_id=source_id
         )
         self._set_relation_(src_id, tgt_id, edge_data)
-        return edge_data
+        all_relationships_data.append(edge_data)
-    def _handle_entity_relation_summary(
+    async def _handle_entity_relation_summary(
         self,
         entity_or_relation_name: str,
         description: str
         )
         use_prompt = prompt_template.format(**context_base)
         logging.info(f"Trigger summary: {entity_or_relation_name}")
-        summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
+        async with chat_limiter:
+            summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
         return summary
```
```diff
 - [graphrag](https://github.com/microsoft/graphrag)
 """
-import logging
 import re
 from typing import Any, Callable
 from dataclasses import dataclass
 import tiktoken
+import trio
 from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
 from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
-from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
+from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
             self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
         }
-    def _process_single_content(self,
-                                chunk_key_dp: tuple[str, str]
-                                ):
+    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
         token_count = 0
         chunk_key = chunk_key_dp[0]
         content = chunk_key_dp[1]
         variables = {
             **self._prompt_variables,
             self._input_text_key: content,
         }
-        try:
-            gen_conf = {"temperature": 0.3}
-            hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
-            response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
-            token_count += num_tokens_from_string(hint_prompt + response)
-            results = response or ""
-            history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
-            # Repeat to ensure we maximize entity count
-            for i in range(self._max_gleanings):
-                text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
-                history.append({"role": "user", "content": text})
-                response = self._chat("", history, gen_conf)
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
-                results += response or ""
-                # if this is the final glean, don't bother updating the continuation flag
-                if i >= self._max_gleanings - 1:
-                    break
-                history.append({"role": "assistant", "content": response})
-                history.append({"role": "user", "content": LOOP_PROMPT})
-                continuation = self._chat("", history, {"temperature": 0.8})
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
-                if continuation != "YES":
-                    break
-            record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
-            tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
-            records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
-            records = [r for r in records if r.strip()]
-            maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
-            return maybe_nodes, maybe_edges, token_count
-        except Exception as e:
-            logging.exception("error extracting graph")
-            return e, None, None
+        gen_conf = {"temperature": 0.3}
+        hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
+        token_count += num_tokens_from_string(hint_prompt + response)
+        results = response or ""
+        history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
+        # Repeat to ensure we maximize entity count
+        for i in range(self._max_gleanings):
+            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
+            history.append({"role": "user", "content": text})
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
+            results += response or ""
+            # if this is the final glean, don't bother updating the continuation flag
+            if i >= self._max_gleanings - 1:
+                break
+            history.append({"role": "assistant", "content": response})
+            history.append({"role": "user", "content": LOOP_PROMPT})
+            async with chat_limiter:
+                continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
+            if continuation != "YES":
+                break
+        record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
+        tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
+        records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
+        records = [r for r in records if r.strip()]
+        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
+        out_results.append((maybe_nodes, maybe_edges, token_count))
+        if self.callback:
+            self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
```
| import logging | import logging | ||||
| from functools import reduce, partial | from functools import reduce, partial | ||||
| import networkx as nx | import networkx as nx | ||||
| import trio | |||||
| from api import settings | from api import settings | ||||
| from graphrag.general.community_reports_extractor import CommunityReportsExtractor | from graphrag.general.community_reports_extractor import CommunityReportsExtractor | ||||
| embed_bdl=None, | embed_bdl=None, | ||||
| callback=None | callback=None | ||||
| ): | ): | ||||
| docids = list(set([docid for docid,_ in chunks])) | |||||
| self.tenant_id = tenant_id | |||||
| self.kb_id = kb_id | |||||
| self.chunks = chunks | |||||
| self.llm_bdl = llm_bdl | self.llm_bdl = llm_bdl | ||||
| self.embed_bdl = embed_bdl | self.embed_bdl = embed_bdl | ||||
| ext = extractor(self.llm_bdl, language=language, | |||||
| self.ext = extractor(self.llm_bdl, language=language, | |||||
| entity_types=entity_types, | entity_types=entity_types, | ||||
| get_entity=partial(get_entity, tenant_id, kb_id), | get_entity=partial(get_entity, tenant_id, kb_id), | ||||
| set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), | set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), | ||||
| get_relation=partial(get_relation, tenant_id, kb_id), | get_relation=partial(get_relation, tenant_id, kb_id), | ||||
| set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl) | set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl) | ||||
| ) | ) | ||||
| ents, rels = ext(chunks, callback) | |||||
| self.graph = nx.Graph() | self.graph = nx.Graph() | ||||
| self.callback = callback | |||||
| async def __call__(self): | |||||
| docids = list(set([docid for docid, _ in self.chunks])) | |||||
| ents, rels = await self.ext(self.chunks, self.callback) | |||||
| for en in ents: | for en in ents: | ||||
| self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"]) | self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"]) | ||||
| #description=rel["description"] | #description=rel["description"] | ||||
| ) | ) | ||||
| with RedisDistributedLock(kb_id, 60*60): | |||||
| old_graph, old_doc_ids = get_graph(tenant_id, kb_id) | |||||
| with RedisDistributedLock(self.kb_id, 60*60): | |||||
| old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id) | |||||
| if old_graph is not None: | if old_graph is not None: | ||||
| logging.info("Merge with an exiting graph...................") | logging.info("Merge with an exiting graph...................") | ||||
| self.graph = reduce(graph_merge, [old_graph, self.graph]) | self.graph = reduce(graph_merge, [old_graph, self.graph]) | ||||
| update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2) | |||||
| update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2) | |||||
| if old_doc_ids: | if old_doc_ids: | ||||
| docids.extend(old_doc_ids) | docids.extend(old_doc_ids) | ||||
| docids = list(set(docids)) | docids = list(set(docids)) | ||||
| set_graph(tenant_id, kb_id, self.graph, docids) | |||||
| set_graph(self.tenant_id, self.kb_id, self.graph, docids) | |||||
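The Dealer refactor moves all heavy work out of __init__ into an async __call__: the constructor only stores inputs, and extraction, merging under the Redis lock, and persistence happen when the caller awaits the instance. A minimal sketch of the same deferred-work pattern (the names here are stand-ins, not the real extractor):

    import trio

    class Dealer:
        def __init__(self, chunks):
            # Only store inputs here; no I/O or LLM calls in the constructor.
            self.chunks = chunks
            self.result = None

        async def __call__(self):
            # All blocking and concurrent work happens when the caller awaits the instance.
            await trio.sleep(0)          # placeholder for extraction, merging, persistence
            self.result = len(self.chunks)

    async def main():
        dealer = Dealer(["chunk-a", "chunk-b"])
        await dealer()                   # or trio.run(dealer) from synchronous code
        print(dealer.result)

    trio.run(main)

Keeping the constructor cheap also lets the task executor build these objects freely and decide later, under its own concurrency limits, when each one actually runs.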
| class WithResolution(Dealer): | class WithResolution(Dealer): | ||||
| embed_bdl=None, | embed_bdl=None, | ||||
| callback=None | callback=None | ||||
| ): | ): | ||||
| self.tenant_id = tenant_id | |||||
| self.kb_id = kb_id | |||||
| self.llm_bdl = llm_bdl | self.llm_bdl = llm_bdl | ||||
| self.embed_bdl = embed_bdl | self.embed_bdl = embed_bdl | ||||
| with RedisDistributedLock(kb_id, 60*60): | |||||
| self.graph, doc_ids = get_graph(tenant_id, kb_id) | |||||
| self.callback = callback | |||||
| async def __call__(self): | |||||
| with RedisDistributedLock(self.kb_id, 60*60): | |||||
| self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id)) | |||||
| if not self.graph: | if not self.graph: | ||||
| logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}") | |||||
| if callback: | |||||
| callback(-1, msg="Faild to fetch the graph.") | |||||
| logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}") | |||||
| if self.callback: | |||||
| self.callback(-1, msg="Faild to fetch the graph.") | |||||
| return | return | ||||
| if callback: | |||||
| callback(msg="Fetch the existing graph.") | |||||
| if self.callback: | |||||
| self.callback(msg="Fetch the existing graph.") | |||||
| er = EntityResolution(self.llm_bdl, | er = EntityResolution(self.llm_bdl, | ||||
| get_entity=partial(get_entity, tenant_id, kb_id), | |||||
| set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), | |||||
| get_relation=partial(get_relation, tenant_id, kb_id), | |||||
| set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)) | |||||
| reso = er(self.graph) | |||||
| get_entity=partial(get_entity, self.tenant_id, self.kb_id), | |||||
| set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl), | |||||
| get_relation=partial(get_relation, self.tenant_id, self.kb_id), | |||||
| set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl)) | |||||
| reso = await er(self.graph) | |||||
| self.graph = reso.graph | self.graph = reso.graph | ||||
| logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) | logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) | ||||
| if callback: | |||||
| callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) | |||||
| update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2) | |||||
| set_graph(tenant_id, kb_id, self.graph, doc_ids) | |||||
| if self.callback: | |||||
| self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) | |||||
| await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)) | |||||
| await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids)) | |||||
| settings.docStoreConn.delete({ | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ | |||||
| "knowledge_graph_kwd": "relation", | "knowledge_graph_kwd": "relation", | ||||
| "kb_id": kb_id, | |||||
| "kb_id": self.kb_id, | |||||
| "from_entity_kwd": reso.removed_entities | "from_entity_kwd": reso.removed_entities | ||||
| }, search.index_name(tenant_id), kb_id) | |||||
| settings.docStoreConn.delete({ | |||||
| }, search.index_name(self.tenant_id), self.kb_id)) | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ | |||||
| "knowledge_graph_kwd": "relation", | "knowledge_graph_kwd": "relation", | ||||
| "kb_id": kb_id, | |||||
| "kb_id": self.kb_id, | |||||
| "to_entity_kwd": reso.removed_entities | "to_entity_kwd": reso.removed_entities | ||||
| }, search.index_name(tenant_id), kb_id) | |||||
| settings.docStoreConn.delete({ | |||||
| }, search.index_name(self.tenant_id), self.kb_id)) | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ | |||||
| "knowledge_graph_kwd": "entity", | "knowledge_graph_kwd": "entity", | ||||
| "kb_id": kb_id, | |||||
| "kb_id": self.kb_id, | |||||
| "entity_kwd": reso.removed_entities | "entity_kwd": reso.removed_entities | ||||
| }, search.index_name(tenant_id), kb_id) | |||||
| }, search.index_name(self.tenant_id), self.kb_id)) | |||||
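WithResolution keeps the blocking storage and search calls (get_graph, update_nodes_pagerank_nhop_neighbour, set_graph, docStoreConn.delete) off the trio event loop by pushing each one into a worker thread with trio.to_thread.run_sync. A small sketch of that offloading pattern, using a stand-in blocking function:

    import time
    import trio

    def blocking_delete(kb_id: str) -> int:
        # Stand-in for a synchronous docStoreConn.delete(...) call.
        time.sleep(0.1)
        return 1

    async def main():
        # run_sync moves the blocking call to a thread so other trio tasks keep running.
        deleted = await trio.to_thread.run_sync(blocking_delete, "kb-123")
        print("deleted", deleted)

    trio.run(main)

Because each run_sync call is awaited before the loop moves on, wrapping the call in a lambda, as the diff does, does not run into late-binding problems with loop variables.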
| class WithCommunity(Dealer): | class WithCommunity(Dealer): | ||||
| callback=None | callback=None | ||||
| ): | ): | ||||
| self.tenant_id = tenant_id | |||||
| self.kb_id = kb_id | |||||
| self.community_structure = None | self.community_structure = None | ||||
| self.community_reports = None | self.community_reports = None | ||||
| self.llm_bdl = llm_bdl | self.llm_bdl = llm_bdl | ||||
| self.embed_bdl = embed_bdl | self.embed_bdl = embed_bdl | ||||
| with RedisDistributedLock(kb_id, 60*60): | |||||
| self.graph, doc_ids = get_graph(tenant_id, kb_id) | |||||
| self.callback = callback | |||||
| async def __call__(self): | |||||
| with RedisDistributedLock(self.kb_id, 60*60): | |||||
| self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id) | |||||
| if not self.graph: | if not self.graph: | ||||
| logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}") | |||||
| if callback: | |||||
| callback(-1, msg="Faild to fetch the graph.") | |||||
| logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}") | |||||
| if self.callback: | |||||
| self.callback(-1, msg="Faild to fetch the graph.") | |||||
| return | return | ||||
| if callback: | |||||
| callback(msg="Fetch the existing graph.") | |||||
| if self.callback: | |||||
| self.callback(msg="Fetch the existing graph.") | |||||
| cr = CommunityReportsExtractor(self.llm_bdl, | cr = CommunityReportsExtractor(self.llm_bdl, | ||||
| get_entity=partial(get_entity, tenant_id, kb_id), | |||||
| set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), | |||||
| get_relation=partial(get_relation, tenant_id, kb_id), | |||||
| set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)) | |||||
| cr = cr(self.graph, callback=callback) | |||||
| get_entity=partial(get_entity, self.tenant_id, self.kb_id), | |||||
| set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl), | |||||
| get_relation=partial(get_relation, self.tenant_id, self.kb_id), | |||||
| set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl)) | |||||
| cr = await cr(self.graph, callback=self.callback) | |||||
| self.community_structure = cr.structured_output | self.community_structure = cr.structured_output | ||||
| self.community_reports = cr.output | self.community_reports = cr.output | ||||
| set_graph(tenant_id, kb_id, self.graph, doc_ids) | |||||
| await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids)) | |||||
| if callback: | |||||
| callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output))) | |||||
| if self.callback: | |||||
| self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output))) | |||||
| settings.docStoreConn.delete({ | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ | |||||
| "knowledge_graph_kwd": "community_report", | "knowledge_graph_kwd": "community_report", | ||||
| "kb_id": kb_id | |||||
| }, search.index_name(tenant_id), kb_id) | |||||
| "kb_id": self.kb_id | |||||
| }, search.index_name(self.tenant_id), self.kb_id)) | |||||
| for stru, rep in zip(self.community_structure, self.community_reports): | for stru, rep in zip(self.community_structure, self.community_reports): | ||||
| obj = { | obj = { | ||||
| "weight_flt": stru["weight"], | "weight_flt": stru["weight"], | ||||
| "entities_kwd": stru["entities"], | "entities_kwd": stru["entities"], | ||||
| "important_kwd": stru["entities"], | "important_kwd": stru["entities"], | ||||
| "kb_id": kb_id, | |||||
| "kb_id": self.kb_id, | |||||
| "source_id": doc_ids, | "source_id": doc_ids, | ||||
| "available_int": 0 | "available_int": 0 | ||||
| } | } | ||||
| # chunk["q_%d_vec" % len(ebd[0])] = ebd[0] | # chunk["q_%d_vec" % len(ebd[0])] = ebd[0] | ||||
| #except Exception as e: | #except Exception as e: | ||||
| # logging.exception(f"Fail to embed entity relation: {e}") | # logging.exception(f"Fail to embed entity relation: {e}") | ||||
| settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)) | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id))) | |||||
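Each community report is indexed with a separate blocking docStoreConn.insert wrapped in run_sync. If the doc store tolerates it, several of those inserts could also be issued concurrently from a nursery; a hedged sketch under that assumption (insert_one is hypothetical):

    import time
    import trio

    def insert_one(report: dict) -> None:
        # Stand-in for a synchronous docStoreConn.insert([...]) call.
        time.sleep(0.05)

    async def insert_reports(reports: list[dict]) -> None:
        # Each blocking insert runs in its own worker thread; the nursery waits for all of them.
        async with trio.open_nursery() as nursery:
            for rep in reports:
                nursery.start_soon(trio.to_thread.run_sync, insert_one, rep)

    trio.run(insert_reports, [{"title": "community 0"}, {"title": "community 1"}])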
| import logging | import logging | ||||
| import collections | import collections | ||||
| import os | |||||
| import re | import re | ||||
| import traceback | |||||
| from typing import Any | from typing import Any | ||||
| from concurrent.futures import ThreadPoolExecutor | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| import trio | |||||
| from graphrag.general.extractor import Extractor | from graphrag.general.extractor import Extractor | ||||
| from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT | from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT | ||||
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements | |||||
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter | |||||
| from rag.llm.chat_model import Base as CompletionLLM | from rag.llm.chat_model import Base as CompletionLLM | ||||
| import markdown_to_json | import markdown_to_json | ||||
| from functools import reduce | from functools import reduce | ||||
| ) | ) | ||||
| return arr | return arr | ||||
| def __call__( | |||||
| async def __call__( | |||||
| self, sections: list[str], prompt_variables: dict[str, Any] | None = None | self, sections: list[str], prompt_variables: dict[str, Any] | None = None | ||||
| ) -> MindMapResult: | ) -> MindMapResult: | ||||
| """Call method definition.""" | """Call method definition.""" | ||||
| if prompt_variables is None: | if prompt_variables is None: | ||||
| prompt_variables = {} | prompt_variables = {} | ||||
| try: | |||||
| res = [] | |||||
| max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12)) | |||||
| with ThreadPoolExecutor(max_workers=max_workers) as exe: | |||||
| threads = [] | |||||
| token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) | |||||
| texts = [] | |||||
| cnt = 0 | |||||
| for i in range(len(sections)): | |||||
| section_cnt = num_tokens_from_string(sections[i]) | |||||
| if cnt + section_cnt >= token_count and texts: | |||||
| threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) | |||||
| texts = [] | |||||
| cnt = 0 | |||||
| texts.append(sections[i]) | |||||
| cnt += section_cnt | |||||
| if texts: | |||||
| threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) | |||||
| for i, _ in enumerate(threads): | |||||
| res.append(_.result()) | |||||
| if not res: | |||||
| return MindMapResult(output={"id": "root", "children": []}) | |||||
| merge_json = reduce(self._merge, res) | |||||
| if len(merge_json) > 1: | |||||
| keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)] | |||||
| keyset = set(i for i in keys if i) | |||||
| merge_json = { | |||||
| "id": "root", | |||||
| "children": [ | |||||
| { | |||||
| "id": self._key(k), | |||||
| "children": self._be_children(v, keyset) | |||||
| } | |||||
| for k, v in merge_json.items() if isinstance(v, dict) and self._key(k) | |||||
| ] | |||||
| } | |||||
| else: | |||||
| k = self._key(list(merge_json.keys())[0]) | |||||
| merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})} | |||||
| except Exception as e: | |||||
| logging.exception("error mind graph") | |||||
| self._on_error( | |||||
| e, | |||||
| traceback.format_exc(), None | |||||
| ) | |||||
| merge_json = {"error": str(e)} | |||||
| res = [] | |||||
| token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) | |||||
| texts = [] | |||||
| cnt = 0 | |||||
| async with trio.open_nursery() as nursery: | |||||
| for i in range(len(sections)): | |||||
| section_cnt = num_tokens_from_string(sections[i]) | |||||
| if cnt + section_cnt >= token_count and texts: | |||||
| nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) | |||||
| texts = [] | |||||
| cnt = 0 | |||||
| texts.append(sections[i]) | |||||
| cnt += section_cnt | |||||
| if texts: | |||||
| nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) | |||||
| if not res: | |||||
| return MindMapResult(output={"id": "root", "children": []}) | |||||
| merge_json = reduce(self._merge, res) | |||||
| if len(merge_json) > 1: | |||||
| keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)] | |||||
| keyset = set(i for i in keys if i) | |||||
| merge_json = { | |||||
| "id": "root", | |||||
| "children": [ | |||||
| { | |||||
| "id": self._key(k), | |||||
| "children": self._be_children(v, keyset) | |||||
| } | |||||
| for k, v in merge_json.items() if isinstance(v, dict) and self._key(k) | |||||
| ] | |||||
| } | |||||
| else: | |||||
| k = self._key(list(merge_json.keys())[0]) | |||||
| merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})} | |||||
| return MindMapResult(output=merge_json) | return MindMapResult(output=merge_json) | ||||
| return self._list_to_kv(to_ret) | return self._list_to_kv(to_ret) | ||||
| def _process_document( | |||||
| self, text: str, prompt_variables: dict[str, str] | |||||
| async def _process_document( | |||||
| self, text: str, prompt_variables: dict[str, str], out_res | |||||
| ) -> str: | ) -> str: | ||||
| variables = { | variables = { | ||||
| **prompt_variables, | **prompt_variables, | ||||
| } | } | ||||
| text = perform_variable_replacements(self._mind_map_prompt, variables=variables) | text = perform_variable_replacements(self._mind_map_prompt, variables=variables) | ||||
| gen_conf = {"temperature": 0.5} | gen_conf = {"temperature": 0.5} | ||||
| response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) | |||||
| async with chat_limiter: | |||||
| response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) | |||||
| response = re.sub(r"```[^\n]*", "", response) | response = re.sub(r"```[^\n]*", "", response) | ||||
| logging.debug(response) | logging.debug(response) | ||||
| logging.debug(self._todict(markdown_to_json.dictify(response))) | logging.debug(self._todict(markdown_to_json.dictify(response))) | ||||
| return self._todict(markdown_to_json.dictify(response)) | |||||
| out_res.append(self._todict(markdown_to_json.dictify(response))) |
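MindMapExtractor drops the ThreadPoolExecutor in favour of a trio nursery. Since nursery.start_soon has no return value, each _process_document task appends its result into a shared list that the caller reads once the nursery block exits. A minimal sketch of that collect-into-a-list pattern:

    import trio

    async def process_document(text: str, out_res: list) -> None:
        # start_soon cannot return a value, so each task appends into a shared list.
        await trio.sleep(0)              # placeholder for the throttled chat call
        out_res.append({"id": text})

    async def main():
        res = []
        async with trio.open_nursery() as nursery:
            for batch in ["batch-1", "batch-2", "batch-3"]:
                nursery.start_soon(process_document, batch, res)
        # The nursery block only exits after every task has finished.
        print(res)

    trio.run(main)

Appending to a plain list is safe here because trio tasks share one thread and only switch at await points.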
| import json | import json | ||||
| import networkx as nx | import networkx as nx | ||||
| import trio | |||||
| from api import settings | from api import settings | ||||
| from api.db import LLMType | from api.db import LLMType | ||||
| embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) | embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) | ||||
| dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl) | dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl) | ||||
| trio.run(dealer) | |||||
| print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2)) | print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2)) | ||||
| dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl) | dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl) | ||||
| trio.run(dealer) | |||||
| dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl) | dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl) | ||||
| trio.run(dealer) | |||||
| print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports) | print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports) | ||||
| print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2)) | print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2)) |
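trio.run expects an async callable plus its arguments and raises a TypeError when handed an already-created coroutine, which is why the smoke test above passes the Dealer instance itself (its __call__ is async) rather than calling it first:

    import trio

    class Dealer:
        async def __call__(self):
            await trio.sleep(0)
            return "done"

    dealer = Dealer()
    print(trio.run(dealer))          # correct: trio calls dealer() itself
    # trio.run(dealer())             # TypeError: trio wants an async function, not a coroutine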
| Reference: | Reference: | ||||
| - [graphrag](https://github.com/microsoft/graphrag) | - [graphrag](https://github.com/microsoft/graphrag) | ||||
| """ | """ | ||||
| import logging | |||||
| import re | import re | ||||
| from typing import Any, Callable | from typing import Any, Callable | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS | from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS | ||||
| from graphrag.light.graph_prompt import PROMPTS | from graphrag.light.graph_prompt import PROMPTS | ||||
| from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers | |||||
| from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter | |||||
| from rag.llm.chat_model import Base as CompletionLLM | from rag.llm.chat_model import Base as CompletionLLM | ||||
| import networkx as nx | import networkx as nx | ||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| import trio | |||||
| @dataclass | @dataclass | ||||
| ) | ) | ||||
| self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count) | self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count) | ||||
| def _process_single_content(self, chunk_key_dp: tuple[str, str]): | |||||
| async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results): | |||||
| token_count = 0 | token_count = 0 | ||||
| chunk_key = chunk_key_dp[0] | chunk_key = chunk_key_dp[0] | ||||
| content = chunk_key_dp[1] | content = chunk_key_dp[1] | ||||
| **self._context_base, input_text="{input_text}" | **self._context_base, input_text="{input_text}" | ||||
| ).format(**self._context_base, input_text=content) | ).format(**self._context_base, input_text=content) | ||||
| try: | |||||
| gen_conf = {"temperature": 0.8} | |||||
| final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf) | |||||
| token_count += num_tokens_from_string(hint_prompt + final_result) | |||||
| history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt) | |||||
| for now_glean_index in range(self._max_gleanings): | |||||
| glean_result = self._chat(hint_prompt, history, gen_conf) | |||||
| history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}]) | |||||
| token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) | |||||
| final_result += glean_result | |||||
| if now_glean_index == self._max_gleanings - 1: | |||||
| break | |||||
| gen_conf = {"temperature": 0.8} | |||||
| async with chat_limiter: | |||||
| final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)) | |||||
| token_count += num_tokens_from_string(hint_prompt + final_result) | |||||
| history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt) | |||||
| for now_glean_index in range(self._max_gleanings): | |||||
| async with chat_limiter: | |||||
| glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) | |||||
| history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}]) | |||||
| token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) | |||||
| final_result += glean_result | |||||
| if now_glean_index == self._max_gleanings - 1: | |||||
| break | |||||
| if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf) | |||||
| token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) | |||||
| if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() | |||||
| if if_loop_result != "yes": | |||||
| break | |||||
| async with chat_limiter: | |||||
| if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf)) | |||||
| token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) | |||||
| if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() | |||||
| if if_loop_result != "yes": | |||||
| break | |||||
| records = split_string_by_multi_markers( | |||||
| final_result, | |||||
| [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]], | |||||
| ) | |||||
| rcds = [] | |||||
| for record in records: | |||||
| record = re.search(r"\((.*)\)", record) | |||||
| if record is None: | |||||
| continue | |||||
| rcds.append(record.group(1)) | |||||
| records = rcds | |||||
| maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"]) | |||||
| return maybe_nodes, maybe_edges, token_count | |||||
| except Exception as e: | |||||
| logging.exception("error extracting graph") | |||||
| return e, None, None | |||||
| records = split_string_by_multi_markers( | |||||
| final_result, | |||||
| [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]], | |||||
| ) | |||||
| rcds = [] | |||||
| for record in records: | |||||
| record = re.search(r"\((.*)\)", record) | |||||
| if record is None: | |||||
| continue | |||||
| rcds.append(record.group(1)) | |||||
| records = rcds | |||||
| maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"]) | |||||
| out_results.append((maybe_nodes, maybe_edges, token_count)) | |||||
| if self.callback: | |||||
| self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.") |
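The extraction above does an initial pass, then "gleans" up to _max_gleanings extra passes, asking the model after each one whether to continue and stopping as soon as the answer is anything but "yes". A simplified, self-contained sketch of that control flow (chat is a stand-in for the real, thread-offloaded LLM call):

    import trio

    MAX_GLEANINGS = 2
    chat_limiter = trio.CapacityLimiter(10)

    async def chat(prompt: str, history: list) -> str:
        await trio.sleep(0)              # stand-in for the actual chat request
        return "no" if prompt == "continue?" else "(entity<|>A<|>thing)"

    async def extract(content: str) -> str:
        async with chat_limiter:
            result = await chat(content, [])
        for i in range(MAX_GLEANINGS):
            async with chat_limiter:
                result += await chat("glean more", [])
            if i == MAX_GLEANINGS - 1:
                break
            async with chat_limiter:
                answer = (await chat("continue?", [])).strip().strip('"').strip("'").lower()
            if answer != "yes":
                break
        return result

    print(trio.run(extract, "some chunk"))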
| from copy import deepcopy | from copy import deepcopy | ||||
| from hashlib import md5 | from hashlib import md5 | ||||
| from typing import Any, Callable | from typing import Any, Callable | ||||
| import os | |||||
| import trio | |||||
| import networkx as nx | import networkx as nx | ||||
| import numpy as np | import numpy as np | ||||
| ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] | ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] | ||||
| chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100))) | |||||
| def perform_variable_replacements( | def perform_variable_replacements( | ||||
| input: str, history: list[dict] | None = None, variables: dict | None = None | input: str, history: list[dict] | None = None, variables: dict | None = None |
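chat_limiter is a module-level trio.CapacityLimiter sized by the MAX_CONCURRENT_CHATS environment variable, so at most that many LLM requests are in flight process-wide regardless of which extractor issued them. A minimal illustration of how a CapacityLimiter gates concurrency:

    import trio

    limiter = trio.CapacityLimiter(2)    # at most two holders at a time

    async def call_llm(i: int) -> None:
        async with limiter:              # waits here once two tasks already hold the limiter
            print("start", i)
            await trio.sleep(0.1)        # stand-in for the actual chat request
            print("done", i)

    async def main():
        async with trio.open_nursery() as nursery:
            for i in range(5):
                nursery.start_soon(call_llm, i)

    trio.run(main)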
| "pyodbc>=5.2.0,<6.0.0", | "pyodbc>=5.2.0,<6.0.0", | ||||
| "pyicu>=2.13.1,<3.0.0", | "pyicu>=2.13.1,<3.0.0", | ||||
| "flasgger>=0.9.7.1,<0.10.0", | "flasgger>=0.9.7.1,<0.10.0", | ||||
| "xxhash>=3.5.0,<4.0.0" | |||||
| "xxhash>=3.5.0,<4.0.0", | |||||
| "trio>=0.29.0", | |||||
| ] | ] | ||||
| [project.optional-dependencies] | [project.optional-dependencies] | ||||
| "flagembedding==1.2.10", | "flagembedding==1.2.10", | ||||
| "torch>=2.5.0,<3.0.0", | "torch>=2.5.0,<3.0.0", | ||||
| "transformers>=4.35.0,<5.0.0" | "transformers>=4.35.0,<5.0.0" | ||||
| ] | |||||
| ] | |||||
| [[tool.uv.index]] | |||||
| url = "https://mirrors.aliyun.com/pypi/simple" |
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import logging | import logging | ||||
| import os | |||||
| import re | import re | ||||
| from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait | |||||
| from threading import Lock | from threading import Lock | ||||
| import umap | import umap | ||||
| import numpy as np | import numpy as np | ||||
| from sklearn.mixture import GaussianMixture | from sklearn.mixture import GaussianMixture | ||||
| import trio | |||||
| from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache | |||||
| from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter | |||||
| from rag.utils import truncate | from rag.utils import truncate | ||||
| optimal_clusters = n_clusters[np.argmin(bics)] | optimal_clusters = n_clusters[np.argmin(bics)] | ||||
| return optimal_clusters | return optimal_clusters | ||||
| def __call__(self, chunks, random_state, callback=None): | |||||
| async def __call__(self, chunks, random_state, callback=None): | |||||
| layers = [(0, len(chunks))] | layers = [(0, len(chunks))] | ||||
| start, end = 0, len(chunks) | start, end = 0, len(chunks) | ||||
| if len(chunks) <= 1: | if len(chunks) <= 1: | ||||
| return [] | return [] | ||||
| chunks = [(s, a) for s, a in chunks if s and len(a) > 0] | chunks = [(s, a) for s, a in chunks if s and len(a) > 0] | ||||
| def summarize(ck_idx, lock): | |||||
| async def summarize(ck_idx, lock): | |||||
| nonlocal chunks | nonlocal chunks | ||||
| try: | try: | ||||
| texts = [chunks[i][0] for i in ck_idx] | texts = [chunks[i][0] for i in ck_idx] | ||||
| len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) | len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) | ||||
| cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) | cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) | ||||
| cnt = self._chat("You're a helpful assistant.", | |||||
| [{"role": "user", | |||||
| "content": self._prompt.format(cluster_content=cluster_content)}], | |||||
| {"temperature": 0.3, "max_tokens": self._max_token} | |||||
| ) | |||||
| async with chat_limiter: | |||||
| cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.", | |||||
| [{"role": "user", | |||||
| "content": self._prompt.format(cluster_content=cluster_content)}], | |||||
| {"temperature": 0.3, "max_tokens": self._max_token} | |||||
| )) | |||||
| cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", | cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", | ||||
| cnt) | cnt) | ||||
| logging.debug(f"SUM: {cnt}") | logging.debug(f"SUM: {cnt}") | ||||
| return e | return e | ||||
| labels = [] | labels = [] | ||||
| lock = Lock() | |||||
| while end - start > 1: | while end - start > 1: | ||||
| embeddings = [embd for _, embd in chunks[start: end]] | embeddings = [embd for _, embd in chunks[start: end]] | ||||
| if len(embeddings) == 2: | if len(embeddings) == 2: | ||||
| summarize([start, start + 1], Lock()) | |||||
| await summarize([start, start + 1], lock) | |||||
| if callback: | if callback: | ||||
| callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) | callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) | ||||
| labels.extend([0, 0]) | labels.extend([0, 0]) | ||||
| probs = gm.predict_proba(reduced_embeddings) | probs = gm.predict_proba(reduced_embeddings) | ||||
| lbls = [np.where(prob > self._threshold)[0] for prob in probs] | lbls = [np.where(prob > self._threshold)[0] for prob in probs] | ||||
| lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] | lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] | ||||
| lock = Lock() | |||||
| with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor: | |||||
| threads = [] | |||||
| async with trio.open_nursery() as nursery: | |||||
| for c in range(n_clusters): | for c in range(n_clusters): | ||||
| ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] | ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] | ||||
| if not ck_idx: | if not ck_idx: | ||||
| continue | continue | ||||
| threads.append(executor.submit(summarize, ck_idx, lock)) | |||||
| wait(threads, return_when=ALL_COMPLETED) | |||||
| for th in threads: | |||||
| if isinstance(th.result(), Exception): | |||||
| raise th.result() | |||||
| logging.debug(str([t.result() for t in threads])) | |||||
| async with chat_limiter: | |||||
| nursery.start_soon(lambda: summarize(ck_idx, lock)) | |||||
| assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) | assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) | ||||
| labels.extend(lbls) | labels.extend(lbls) |
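In the clustering loop above each cluster summary is scheduled with nursery.start_soon; the per-request throttling actually happens inside summarize, which acquires chat_limiter around its chat call. When scheduling from a loop, the usual trio idiom is to pass the loop variables positionally so each task is bound to its own arguments at scheduling time; a small sketch with a stand-in summarize:

    import trio

    async def summarize(ck_idx: list[int], lock) -> None:
        await trio.sleep(0)              # stand-in for building one cluster summary

    async def main():
        lock = trio.Lock()               # a trio.Lock is the async-native counterpart of threading.Lock
        clusters = [[0, 1], [2], [3, 4]]
        async with trio.open_nursery() as nursery:
            for ck_idx in clusters:
                # Positional args are bound when the task is scheduled, not when it runs.
                nursery.start_soon(summarize, ck_idx, lock)

    trio.run(main)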
| CONSUMER_NAME = "task_executor_" + CONSUMER_NO | CONSUMER_NAME = "task_executor_" + CONSUMER_NO | ||||
| initRootLogger(CONSUMER_NAME) | initRootLogger(CONSUMER_NAME) | ||||
| import asyncio | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| from datetime import datetime | from datetime import datetime | ||||
| import xxhash | import xxhash | ||||
| import copy | import copy | ||||
| import re | import re | ||||
| import time | |||||
| import threading | |||||
| from functools import partial | from functools import partial | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from multiprocessing.context import TimeoutError | from multiprocessing.context import TimeoutError | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| import tracemalloc | import tracemalloc | ||||
| import resource | |||||
| import signal | import signal | ||||
| import trio | |||||
| import numpy as np | import numpy as np | ||||
| from peewee import DoesNotExist | from peewee import DoesNotExist | ||||
| from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor | from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor | ||||
| from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD | from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD | ||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| from rag.utils.redis_conn import REDIS_CONN, Payload | |||||
| from rag.utils.redis_conn import REDIS_CONN | |||||
| from rag.utils.storage_factory import STORAGE_IMPL | from rag.utils.storage_factory import STORAGE_IMPL | ||||
| from graphrag.utils import chat_limiter | |||||
| BATCH_SIZE = 64 | BATCH_SIZE = 64 | ||||
| ParserType.TAG.value: tag | ParserType.TAG.value: tag | ||||
| } | } | ||||
| UNACKED_ITERATOR = None | |||||
| CONSUMER_NAME = "task_consumer_" + CONSUMER_NO | CONSUMER_NAME = "task_consumer_" + CONSUMER_NO | ||||
| PAYLOAD: Payload | None = None | |||||
| BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds") | BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds") | ||||
| PENDING_TASKS = 0 | PENDING_TASKS = 0 | ||||
| LAG_TASKS = 0 | LAG_TASKS = 0 | ||||
| mt_lock = threading.Lock() | |||||
| DONE_TASKS = 0 | DONE_TASKS = 0 | ||||
| FAILED_TASKS = 0 | FAILED_TASKS = 0 | ||||
| CURRENT_TASK = None | |||||
| tracemalloc_started = False | |||||
| CURRENT_TASKS = {} | |||||
| MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) | |||||
| MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) | |||||
| task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS) | |||||
| chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) | |||||
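Two separate limits are introduced: task_limiter caps how many tasks the executor advances at once (MAX_CONCURRENT_TASKS, default 5) and chunk_limiter caps concurrent chunk builders (MAX_CONCURRENT_CHUNK_BUILDERS, default 1), since chunking is the memory- and CPU-heavy step. How the real main loop wires them up is outside this hunk; the nesting below is only an illustration of holding both limits for the duration of the work:

    import trio

    task_limiter = trio.CapacityLimiter(5)      # MAX_CONCURRENT_TASKS
    chunk_limiter = trio.CapacityLimiter(1)     # MAX_CONCURRENT_CHUNK_BUILDERS

    async def handle_one_task(task_id: int) -> None:
        async with task_limiter:                # at most 5 tasks make progress at once
            async with chunk_limiter:           # and only 1 of them builds chunks at a time
                await trio.sleep(0.01)          # stand-in for the chunker in a worker thread
            await trio.sleep(0.01)              # stand-in for embedding and indexing

    async def main():
        async with trio.open_nursery() as nursery:
            for task_id in range(10):
                nursery.start_soon(handle_one_task, task_id)

    trio.run(main)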
| # SIGUSR1 handler: start tracemalloc and take snapshot | # SIGUSR1 handler: start tracemalloc and take snapshot | ||||
| def start_tracemalloc_and_snapshot(signum, frame): | def start_tracemalloc_and_snapshot(signum, frame): | ||||
| global tracemalloc_started | |||||
| if not tracemalloc_started: | |||||
| logging.info("got SIGUSR1, start tracemalloc") | |||||
| if not tracemalloc.is_tracing(): | |||||
| logging.info("start tracemalloc") | |||||
| tracemalloc.start() | tracemalloc.start() | ||||
| tracemalloc_started = True | |||||
| else: | else: | ||||
| logging.info("got SIGUSR1, tracemalloc is already running") | |||||
| logging.info("tracemalloc is already running") | |||||
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||||
| snapshot_file = f"snapshot_{timestamp}.trace" | snapshot_file = f"snapshot_{timestamp}.trace" | ||||
| snapshot = tracemalloc.take_snapshot() | snapshot = tracemalloc.take_snapshot() | ||||
| snapshot.dump(snapshot_file) | snapshot.dump(snapshot_file) | ||||
| logging.info(f"taken snapshot {snapshot_file}") | |||||
| current, peak = tracemalloc.get_traced_memory() | |||||
| max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss | |||||
| logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") | |||||
| # SIGUSR2 handler: stop tracemalloc | # SIGUSR2 handler: stop tracemalloc | ||||
| def stop_tracemalloc(signum, frame): | def stop_tracemalloc(signum, frame): | ||||
| global tracemalloc_started | |||||
| if tracemalloc_started: | |||||
| logging.info("go SIGUSR2, stop tracemalloc") | |||||
| if tracemalloc.is_tracing(): | |||||
| logging.info("stop tracemalloc") | |||||
| tracemalloc.stop() | tracemalloc.stop() | ||||
| tracemalloc_started = False | |||||
| else: | else: | ||||
| logging.info("got SIGUSR2, tracemalloc not running") | |||||
| logging.info("tracemalloc not running") | |||||
| class TaskCanceledException(Exception): | class TaskCanceledException(Exception): | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): | def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): | ||||
| global PAYLOAD | |||||
| if prog is not None and prog < 0: | if prog is not None and prog < 0: | ||||
| msg = "[ERROR]" + msg | msg = "[ERROR]" + msg | ||||
| try: | |||||
| cancel = TaskService.do_cancel(task_id) | |||||
| except DoesNotExist: | |||||
| logging.warning(f"set_progress task {task_id} is unknown") | |||||
| if PAYLOAD: | |||||
| PAYLOAD.ack() | |||||
| PAYLOAD = None | |||||
| return | |||||
| cancel = TaskService.do_cancel(task_id) | |||||
| if cancel: | if cancel: | ||||
| msg += " [Canceled]" | msg += " [Canceled]" | ||||
| d["progress"] = prog | d["progress"] = prog | ||||
| logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}") | logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}") | ||||
| try: | |||||
| TaskService.update_progress(task_id, d) | |||||
| except DoesNotExist: | |||||
| logging.warning(f"set_progress task {task_id} is unknown") | |||||
| if PAYLOAD: | |||||
| PAYLOAD.ack() | |||||
| PAYLOAD = None | |||||
| return | |||||
| TaskService.update_progress(task_id, d) | |||||
| close_connection() | close_connection() | ||||
| if cancel and PAYLOAD: | |||||
| PAYLOAD.ack() | |||||
| PAYLOAD = None | |||||
| if cancel: | |||||
| raise TaskCanceledException(msg) | raise TaskCanceledException(msg) | ||||
| def collect(): | |||||
| global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS | |||||
| async def collect(): | |||||
| global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS | |||||
| global UNACKED_ITERATOR | |||||
| try: | try: | ||||
| PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker") | |||||
| if not PAYLOAD: | |||||
| PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME) | |||||
| if not PAYLOAD: | |||||
| time.sleep(1) | |||||
| return None | |||||
| if not UNACKED_ITERATOR: | |||||
| UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME) | |||||
| try: | |||||
| redis_msg = next(UNACKED_ITERATOR) | |||||
| except StopIteration: | |||||
| redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME) | |||||
| if not redis_msg: | |||||
| await trio.sleep(1) | |||||
| return None, None | |||||
| except Exception: | except Exception: | ||||
| logging.exception("Get task event from queue exception") | |||||
| return None | |||||
| logging.exception("collect got exception") | |||||
| return None, None | |||||
| msg = PAYLOAD.get_message() | |||||
| msg = redis_msg.get_message() | |||||
| if not msg: | if not msg: | ||||
| return None | |||||
| logging.error(f"collect got empty message of {redis_msg.get_msg_id()}") | |||||
| redis_msg.ack() | |||||
| return None, None | |||||
| task = None | |||||
| canceled = False | canceled = False | ||||
| try: | |||||
| task = TaskService.get_task(msg["id"]) | |||||
| if task: | |||||
| _, doc = DocumentService.get_by_id(task["doc_id"]) | |||||
| canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0 | |||||
| except DoesNotExist: | |||||
| pass | |||||
| except Exception: | |||||
| logging.exception("collect get_task exception") | |||||
| task = TaskService.get_task(msg["id"]) | |||||
| if task: | |||||
| _, doc = DocumentService.get_by_id(task["doc_id"]) | |||||
| canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0 | |||||
| if not task or canceled: | if not task or canceled: | ||||
| state = "is unknown" if not task else "has been cancelled" | state = "is unknown" if not task else "has been cancelled" | ||||
| with mt_lock: | |||||
| DONE_TASKS += 1 | |||||
| logging.info(f"collect task {msg['id']} {state}") | |||||
| FAILED_TASKS += 1 | |||||
| logging.warning(f"collect task {msg['id']} {state}") | |||||
| redis_msg.ack() | |||||
| return None | return None | ||||
| task["task_type"] = msg.get("task_type", "") | task["task_type"] = msg.get("task_type", "") | ||||
| return task | |||||
| return redis_msg, task | |||||
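collect() now first drains messages left unacknowledged by a previous (possibly crashed) consumer, through the lazily created UNACKED_ITERATOR, and only falls back to pulling fresh messages from the queue once that iterator is exhausted; empty or cancelled messages are acknowledged immediately so they are not redelivered. A simplified sketch of that drain-then-consume flow (the Redis helpers are stand-ins):

    import trio

    UNACKED = iter(["leftover-1", "leftover-2"])   # stand-in for REDIS_CONN.get_unacked_iterator(...)

    def queue_consumer():
        # Stand-in for REDIS_CONN.queue_consumer(...); None means "nothing queued right now".
        return None

    async def collect():
        try:
            msg = next(UNACKED)                    # recover work from a dead consumer first
        except StopIteration:
            msg = queue_consumer()                 # then take new work from the stream
            if not msg:
                await trio.sleep(1)                # idle back-off before the caller retries
                return None
        return msg

    print(trio.run(collect))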
| def get_storage_binary(bucket, name): | |||||
| return STORAGE_IMPL.get(bucket, name) | |||||
| async def get_storage_binary(bucket, name): | |||||
| return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name)) | |||||
| def build_chunks(task, progress_callback): | |||||
| async def build_chunks(task, progress_callback): | |||||
| if task["size"] > DOC_MAXIMUM_SIZE: | if task["size"] > DOC_MAXIMUM_SIZE: | ||||
| set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % | set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % | ||||
| (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) | (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) | ||||
| try: | try: | ||||
| st = timer() | st = timer() | ||||
| bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"]) | bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"]) | ||||
| binary = get_storage_binary(bucket, name) | |||||
| binary = await get_storage_binary(bucket, name) | |||||
| logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) | logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) | ||||
| except TimeoutError: | except TimeoutError: | ||||
| progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") | progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") | ||||
| raise | raise | ||||
| try: | try: | ||||
| cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], | |||||
| to_page=task["to_page"], lang=task["language"], callback=progress_callback, | |||||
| kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]) | |||||
| async with chunk_limiter: | |||||
| cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], | |||||
| to_page=task["to_page"], lang=task["language"], callback=progress_callback, | |||||
| kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])) | |||||
| logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) | logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) | ||||
| except TaskCanceledException: | except TaskCanceledException: | ||||
| raise | raise | ||||
| d["image"].save(output_buffer, format='JPEG') | d["image"].save(output_buffer, format='JPEG') | ||||
| st = timer() | st = timer() | ||||
| STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()) | |||||
| await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())) | |||||
| el += timer() - st | el += timer() - st | ||||
| except Exception: | except Exception: | ||||
| logging.exception( | logging.exception( | ||||
| async def doc_keyword_extraction(chat_mdl, d, topn): | async def doc_keyword_extraction(chat_mdl, d, topn): | ||||
| cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) | cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) | ||||
| if not cached: | if not cached: | ||||
| cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn) | |||||
| async with chat_limiter: | |||||
| cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn)) | |||||
| set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) | set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) | ||||
| if cached: | if cached: | ||||
| d["important_kwd"] = cached.split(",") | d["important_kwd"] = cached.split(",") | ||||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) | d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) | ||||
| return | return | ||||
| tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs] | |||||
| asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) | |||||
| async with trio.open_nursery() as nursery: | |||||
| for d in docs: | |||||
| nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"]) | |||||
| progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | ||||
| if task["parser_config"].get("auto_questions", 0): | if task["parser_config"].get("auto_questions", 0): | ||||
| async def doc_question_proposal(chat_mdl, d, topn): | async def doc_question_proposal(chat_mdl, d, topn): | ||||
| cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) | cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) | ||||
| if not cached: | if not cached: | ||||
| cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn) | |||||
| async with chat_limiter: | |||||
| cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn)) | |||||
| set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) | set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) | ||||
| if cached: | if cached: | ||||
| d["question_kwd"] = cached.split("\n") | d["question_kwd"] = cached.split("\n") | ||||
| d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) | d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) | ||||
| tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs] | |||||
| asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) | |||||
| async with trio.open_nursery() as nursery: | |||||
| for d in docs: | |||||
| nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"]) | |||||
| progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | ||||
| if task["kb_parser_config"].get("tag_kb_ids", []): | if task["kb_parser_config"].get("tag_kb_ids", []): | ||||
| cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags}) | cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags}) | ||||
| if not cached: | if not cached: | ||||
| picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples | picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples | ||||
| cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags) | |||||
| async with chat_limiter: | |||||
| cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)) | |||||
| if cached: | if cached: | ||||
| cached = json.dumps(cached) | cached = json.dumps(cached) | ||||
| if cached: | if cached: | ||||
| set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) | set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) | ||||
| d[TAG_FLD] = json.loads(cached) | d[TAG_FLD] = json.loads(cached) | ||||
| tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag] | |||||
| asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) | |||||
| async with trio.open_nursery() as nursery: | |||||
| for d in docs_to_tag: | |||||
| nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags) | |||||
| progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) | ||||
| return docs | return docs | ||||
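Each keyword, question, or tagging task above checks the LLM cache first and only calls the model, in a worker thread and under chat_limiter, on a miss, writing the result back afterwards. A stand-in sketch of that cache-around-a-threaded-call shape (call_model and _cache are hypothetical replacements for the real helpers):

    import trio

    chat_limiter = trio.CapacityLimiter(100)
    _cache: dict[str, str] = {}                  # stand-in for get_llm_cache / set_llm_cache

    def call_model(text: str) -> str:
        return "kw1,kw2"                         # stand-in for keyword_extraction(chat_mdl, text, topn)

    async def enrich(doc: dict) -> None:
        cached = _cache.get(doc["content"])
        if not cached:
            async with chat_limiter:             # throttle LLM traffic across all tasks
                cached = await trio.to_thread.run_sync(call_model, doc["content"])
            _cache[doc["content"]] = cached
        doc["important_kwd"] = cached.split(",")

    async def main():
        docs = [{"content": "chunk one"}, {"content": "chunk two"}]
        async with trio.open_nursery() as nursery:
            for d in docs:
                nursery.start_soon(enrich, d)
        print(docs)

    trio.run(main)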
| return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) | return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) | ||||
| def embedding(docs, mdl, parser_config=None, callback=None): | |||||
| async def embedding(docs, mdl, parser_config=None, callback=None): | |||||
| if parser_config is None: | if parser_config is None: | ||||
| parser_config = {} | parser_config = {} | ||||
| batch_size = 16 | batch_size = 16 | ||||
| tk_count = 0 | tk_count = 0 | ||||
| if len(tts) == len(cnts): | if len(tts) == len(cnts): | ||||
| vts, c = mdl.encode(tts[0: 1]) | |||||
| vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) | |||||
| tts = np.concatenate([vts for _ in range(len(tts))], axis=0) | tts = np.concatenate([vts for _ in range(len(tts))], axis=0) | ||||
| tk_count += c | tk_count += c | ||||
| cnts_ = np.array([]) | cnts_ = np.array([]) | ||||
| for i in range(0, len(cnts), batch_size): | for i in range(0, len(cnts), batch_size): | ||||
| vts, c = mdl.encode(cnts[i: i + batch_size]) | |||||
| vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size])) | |||||
| if len(cnts_) == 0: | if len(cnts_) == 0: | ||||
| cnts_ = vts | cnts_ = vts | ||||
| else: | else: | ||||
| return tk_count, vector_size | return tk_count, vector_size | ||||
| def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): | |||||
| async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): | |||||
| chunks = [] | chunks = [] | ||||
| vctr_nm = "q_%d_vec"%vector_size | vctr_nm = "q_%d_vec"%vector_size | ||||
| for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | ||||
| row["parser_config"]["raptor"]["threshold"] | row["parser_config"]["raptor"]["threshold"] | ||||
| ) | ) | ||||
| original_length = len(chunks) | original_length = len(chunks) | ||||
| chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) | |||||
| chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) | |||||
| doc = { | doc = { | ||||
| "doc_id": row["doc_id"], | "doc_id": row["doc_id"], | ||||
| "kb_id": [str(row["kb_id"])], | "kb_id": [str(row["kb_id"])], | ||||
| return res, tk_count | return res, tk_count | ||||
| def run_graphrag(row, chat_model, language, embedding_model, callback=None): | |||||
| async def run_graphrag(row, chat_model, language, embedding_model, callback=None): | |||||
| chunks = [] | chunks = [] | ||||
| for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], | ||||
| fields=["content_with_weight", "doc_id"]): | fields=["content_with_weight", "doc_id"]): | ||||
| chunks.append((d["doc_id"], d["content_with_weight"])) | chunks.append((d["doc_id"], d["content_with_weight"])) | ||||
| Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt, | |||||
| dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt, | |||||
| row["tenant_id"], | row["tenant_id"], | ||||
| str(row["kb_id"]), | str(row["kb_id"]), | ||||
| chat_model, | chat_model, | ||||
| entity_types=row["parser_config"]["graphrag"]["entity_types"], | entity_types=row["parser_config"]["graphrag"]["entity_types"], | ||||
| embed_bdl=embedding_model, | embed_bdl=embedding_model, | ||||
| callback=callback) | callback=callback) | ||||
| await dealer() | |||||
| def do_handle_task(task): | |||||
| async def do_handle_task(task): | |||||
| task_id = task["id"] | task_id = task["id"] | ||||
| task_from_page = task["from_page"] | task_from_page = task["from_page"] | ||||
| task_to_page = task["to_page"] | task_to_page = task["to_page"] | ||||
| task_doc_id = task["doc_id"] | task_doc_id = task["doc_id"] | ||||
| task_document_name = task["name"] | task_document_name = task["name"] | ||||
| task_parser_config = task["parser_config"] | task_parser_config = task["parser_config"] | ||||
| task_start_ts = timer() | |||||
| # prepare the progress callback function | # prepare the progress callback function | ||||
| progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) | progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) | ||||
| progress_callback(-1, msg=error_message) | progress_callback(-1, msg=error_message) | ||||
| raise Exception(error_message) | raise Exception(error_message) | ||||
| try: | |||||
| task_canceled = TaskService.do_cancel(task_id) | |||||
| except DoesNotExist: | |||||
| logging.warning(f"task {task_id} is unknown") | |||||
| return | |||||
| task_canceled = TaskService.do_cancel(task_id) | |||||
| if task_canceled: | if task_canceled: | ||||
| progress_callback(-1, msg="Task has been canceled.") | progress_callback(-1, msg="Task has been canceled.") | ||||
| return | return | ||||
| # Either using RAPTOR or Standard chunking methods | # Either using RAPTOR or Standard chunking methods | ||||
| if task.get("task_type", "") == "raptor": | if task.get("task_type", "") == "raptor": | ||||
| try: | |||||
| # bind LLM for raptor | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| # run RAPTOR | |||||
| chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) | |||||
| except TaskCanceledException: | |||||
| raise | |||||
| except Exception as e: | |||||
| error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}' | |||||
| progress_callback(-1, msg=error_message) | |||||
| logging.exception(error_message) | |||||
| raise | |||||
| # bind LLM for raptor | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| # run RAPTOR | |||||
| chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) | |||||
| # Either using graphrag or Standard chunking methods | # Either using graphrag or Standard chunking methods | ||||
| elif task.get("task_type", "") == "graphrag": | elif task.get("task_type", "") == "graphrag": | ||||
| start_ts = timer() | start_ts = timer() | ||||
| try: | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| run_graphrag(task, chat_model, task_language, embedding_model, progress_callback) | |||||
| progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| except TaskCanceledException: | |||||
| raise | |||||
| except Exception as e: | |||||
| error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}' | |||||
| progress_callback(-1, msg=error_message) | |||||
| logging.exception(error_message) | |||||
| raise | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback) | |||||
| progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| return | return | ||||
| elif task.get("task_type", "") == "graph_resolution": | elif task.get("task_type", "") == "graph_resolution": | ||||
| start_ts = timer() | start_ts = timer() | ||||
| try: | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| WithResolution( | |||||
| task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model, | |||||
| progress_callback | |||||
| ) | |||||
| progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| except TaskCanceledException: | |||||
| raise | |||||
| except Exception as e: | |||||
| error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}' | |||||
| progress_callback(-1, msg=error_message) | |||||
| logging.exception(error_message) | |||||
| raise | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| with_res = WithResolution( | |||||
| task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model, | |||||
| progress_callback | |||||
| ) | |||||
| await with_res() | |||||
| progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| return | return | ||||
| elif task.get("task_type", "") == "graph_community": | elif task.get("task_type", "") == "graph_community": | ||||
| start_ts = timer() | start_ts = timer() | ||||
| try: | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| WithCommunity( | |||||
| task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model, | |||||
| progress_callback | |||||
| ) | |||||
| progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| except TaskCanceledException: | |||||
| raise | |||||
| except Exception as e: | |||||
| error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}' | |||||
| progress_callback(-1, msg=error_message) | |||||
| logging.exception(error_message) | |||||
| raise | |||||
| chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) | |||||
| with_comm = WithCommunity( | |||||
| task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model, | |||||
| progress_callback | |||||
| ) | |||||
| await with_comm() | |||||
| progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts)) | |||||
| return | return | ||||
| else: | else: | ||||
| # Standard chunking methods | # Standard chunking methods | ||||
| start_ts = timer() | start_ts = timer() | ||||
| chunks = build_chunks(task, progress_callback) | |||||
| chunks = await build_chunks(task, progress_callback) | |||||
| logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts)) | logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts)) | ||||
| if chunks is None: | if chunks is None: | ||||
| return | return | ||||
| progress_callback(msg="Generate {} chunks".format(len(chunks))) | progress_callback(msg="Generate {} chunks".format(len(chunks))) | ||||
| start_ts = timer() | start_ts = timer() | ||||
| try: | try: | ||||
| token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback) | |||||
| token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback) | |||||
| except Exception as e: | except Exception as e: | ||||
| error_message = "Generate embedding error:{}".format(str(e)) | error_message = "Generate embedding error:{}".format(str(e)) | ||||
| progress_callback(-1, error_message) | progress_callback(-1, error_message) | ||||
| doc_store_result = "" | doc_store_result = "" | ||||
| es_bulk_size = 4 | es_bulk_size = 4 | ||||
| for b in range(0, len(chunks), es_bulk_size): | for b in range(0, len(chunks), es_bulk_size): | ||||
| doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), | |||||
| task_dataset_id) | |||||
| doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)) | |||||
| if b % 128 == 0: | if b % 128 == 0: | ||||
| progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") | progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") | ||||
| if doc_store_result: | if doc_store_result: | ||||
| TaskService.update_chunk_ids(task["id"], chunk_ids_str) | TaskService.update_chunk_ids(task["id"], chunk_ids_str) | ||||
| except DoesNotExist: | except DoesNotExist: | ||||
| logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") | logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") | ||||
| doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), | |||||
| task_dataset_id) | |||||
| doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) | |||||
| return | return | ||||
| logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, | logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, | ||||
| task_to_page, len(chunks), | task_to_page, len(chunks), | ||||
| DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) | DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) | ||||
| time_cost = timer() - start_ts | time_cost = timer() - start_ts | ||||
| progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost)) | |||||
| task_time_cost = timer() - task_start_ts | |||||
| progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) | |||||
| logging.info( | logging.info( | ||||
| "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, | "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, | ||||
| task_to_page, len(chunks), | task_to_page, len(chunks), | ||||
| token_count, time_cost)) | |||||
| token_count, task_time_cost)) | |||||
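
The indexing hunk above pushes the blocking doc-store insert and delete calls onto trio's worker-thread pool, so the single-threaded event loop can keep scheduling other coroutines while storage round-trips are in flight. Below is a minimal, self-contained sketch of that offloading pattern (not the project's real connector): `insert_batch` is a stand-in for the `settings.docStoreConn.insert` call, and the batch size mirrors the `es_bulk_size = 4` used in the diff.

```python
import time
import trio

def insert_batch(batch):
    # Stand-in for a blocking bulk insert (settings.docStoreConn.insert in the diff).
    time.sleep(0.05)
    return ""          # empty string means "no error", mirroring the diff's check

async def index_chunks(chunks, bulk_size=4):
    for b in range(0, len(chunks), bulk_size):
        batch = chunks[b:b + bulk_size]
        # Run the blocking call in a worker thread; the event loop stays free.
        error = await trio.to_thread.run_sync(insert_batch, batch)
        if error:
            raise RuntimeError(error)

trio.run(index_chunks, [{"id": i} for i in range(10)])
```

The diff wraps the real call in a lambda because it takes several arguments; passing the function and its arguments to `trio.to_thread.run_sync` directly, as above, is the equivalent idiom.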
| def handle_task(): | |||||
| global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK | |||||
| task = collect() | |||||
| if task: | |||||
| async def handle_task(): | |||||
| global DONE_TASKS, FAILED_TASKS | |||||
| redis_msg, task = await collect() | |||||
| if not task: | |||||
| return | |||||
| try: | |||||
| logging.info(f"handle_task begin for task {json.dumps(task)}") | |||||
| CURRENT_TASKS[task["id"]] = copy.deepcopy(task) | |||||
| await do_handle_task(task) | |||||
| DONE_TASKS += 1 | |||||
| CURRENT_TASKS.pop(task["id"], None) | |||||
| logging.info(f"handle_task done for task {json.dumps(task)}") | |||||
| except Exception as e: | |||||
| FAILED_TASKS += 1 | |||||
| CURRENT_TASKS.pop(task["id"], None) | |||||
| try: | try: | ||||
| logging.info(f"handle_task begin for task {json.dumps(task)}") | |||||
| with mt_lock: | |||||
| CURRENT_TASK = copy.deepcopy(task) | |||||
| do_handle_task(task) | |||||
| with mt_lock: | |||||
| DONE_TASKS += 1 | |||||
| CURRENT_TASK = None | |||||
| logging.info(f"handle_task done for task {json.dumps(task)}") | |||||
| except TaskCanceledException: | |||||
| with mt_lock: | |||||
| DONE_TASKS += 1 | |||||
| CURRENT_TASK = None | |||||
| try: | |||||
| set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException") | |||||
| except Exception: | |||||
| pass | |||||
| logging.debug("handle_task got TaskCanceledException", exc_info=True) | |||||
| except Exception as e: | |||||
| with mt_lock: | |||||
| FAILED_TASKS += 1 | |||||
| CURRENT_TASK = None | |||||
| try: | |||||
| set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}") | |||||
| except Exception: | |||||
| pass | |||||
| logging.exception(f"handle_task got exception for task {json.dumps(task)}") | |||||
| if PAYLOAD: | |||||
| PAYLOAD.ack() | |||||
| PAYLOAD = None | |||||
| def report_status(): | |||||
| global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK | |||||
| set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}") | |||||
| except Exception: | |||||
| pass | |||||
| logging.exception(f"handle_task got exception for task {json.dumps(task)}") | |||||
| redis_msg.ack() | |||||
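
With the single worker thread replaced by cooperatively scheduled coroutines, the lock-guarded `CURRENT_TASK` slot becomes a plain dict keyed by task id, and the Redis message is acknowledged once the handler finishes. A condensed sketch of that bookkeeping, in which `collect`, `do_handle_task`, and the `redis_msg` object are stand-ins for the real implementations:

```python
import copy
import json
import logging

DONE_TASKS, FAILED_TASKS = 0, 0
CURRENT_TASKS = {}

async def handle_one(collect, do_handle_task):
    """Pull one task, track it while it runs, and ack the queue message at the end."""
    global DONE_TASKS, FAILED_TASKS
    redis_msg, task = await collect()
    if not task:
        return
    CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
    try:
        await do_handle_task(task)
        DONE_TASKS += 1          # no lock needed: trio runs these coroutines in one thread
    except Exception:
        FAILED_TASKS += 1
        logging.exception(f"task {json.dumps(task)} failed")
    finally:
        CURRENT_TASKS.pop(task["id"], None)
        redis_msg.ack()          # acknowledge the stream entry whatever the outcome
```

Because trio only switches between coroutines at `await` points, the counters and the `CURRENT_TASKS` dict no longer need the `mt_lock` that the removed code held around every update.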
| async def report_status(): | |||||
| global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS | |||||
| REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME) | REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME) | ||||
| while True: | while True: | ||||
| try: | try: | ||||
| PENDING_TASKS = int(group_info.get("pending", 0)) | PENDING_TASKS = int(group_info.get("pending", 0)) | ||||
| LAG_TASKS = int(group_info.get("lag", 0)) | LAG_TASKS = int(group_info.get("lag", 0)) | ||||
| with mt_lock: | |||||
| heartbeat = json.dumps({ | |||||
| "name": CONSUMER_NAME, | |||||
| "now": now.astimezone().isoformat(timespec="milliseconds"), | |||||
| "boot_at": BOOT_AT, | |||||
| "pending": PENDING_TASKS, | |||||
| "lag": LAG_TASKS, | |||||
| "done": DONE_TASKS, | |||||
| "failed": FAILED_TASKS, | |||||
| "current": CURRENT_TASK, | |||||
| }) | |||||
| current = copy.deepcopy(CURRENT_TASKS) | |||||
| heartbeat = json.dumps({ | |||||
| "name": CONSUMER_NAME, | |||||
| "now": now.astimezone().isoformat(timespec="milliseconds"), | |||||
| "boot_at": BOOT_AT, | |||||
| "pending": PENDING_TASKS, | |||||
| "lag": LAG_TASKS, | |||||
| "done": DONE_TASKS, | |||||
| "failed": FAILED_TASKS, | |||||
| "current": current, | |||||
| }) | |||||
| REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) | REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) | ||||
| logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") | logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") | ||||
| REDIS_CONN.zpopmin(CONSUMER_NAME, expired) | REDIS_CONN.zpopmin(CONSUMER_NAME, expired) | ||||
| except Exception: | except Exception: | ||||
| logging.exception("report_status got exception") | logging.exception("report_status got exception") | ||||
| time.sleep(30) | |||||
| def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool): | |||||
| msg = "" | |||||
| if dump_full: | |||||
| stats2 = snapshot2.statistics('lineno') | |||||
| msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n" | |||||
| for stat in stats2[:10]: | |||||
| msg += f"{stat}\n" | |||||
| stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno') | |||||
| msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n" | |||||
| for stat in stats1_vs_2[:10]: | |||||
| msg += f"{stat}\n" | |||||
| msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n" | |||||
| for stat in stats1_vs_2[:3]: | |||||
| msg += '\n'.join(stat.traceback.format()) | |||||
| logging.info(msg) | |||||
| def main(): | |||||
| await trio.sleep(30) | |||||
| async def main(): | |||||
| logging.info(r""" | logging.info(r""" | ||||
| ______ __ ______ __ | ______ __ ______ __ | ||||
| /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____ | /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____ | ||||
| if TRACE_MALLOC_ENABLED: | if TRACE_MALLOC_ENABLED: | ||||
| start_tracemalloc_and_snapshot(None, None) | start_tracemalloc_and_snapshot(None, None) | ||||
| # Create an event to signal the background thread to exit | |||||
| stop_event = threading.Event() | |||||
| background_thread = threading.Thread(target=report_status) | |||||
| background_thread.daemon = True | |||||
| background_thread.start() | |||||
| # Handle SIGINT (Ctrl+C) | |||||
| def signal_handler(sig, frame): | |||||
| logging.info("Received Ctrl+C, shutting down gracefully...") | |||||
| stop_event.set() | |||||
| # Give the background thread time to clean up | |||||
| if background_thread.is_alive(): | |||||
| background_thread.join(timeout=5) | |||||
| logging.info("Exiting...") | |||||
| sys.exit(0) | |||||
| signal.signal(signal.SIGINT, signal_handler) | |||||
| try: | |||||
| while not stop_event.is_set(): | |||||
| handle_task() | |||||
| except KeyboardInterrupt: | |||||
| logging.info("Interrupted by keyboard, shutting down...") | |||||
| stop_event.set() | |||||
| if background_thread.is_alive(): | |||||
| background_thread.join(timeout=5) | |||||
| async with trio.open_nursery() as nursery: | |||||
| nursery.start_soon(report_status) | |||||
| while True: | |||||
| async with task_limiter: | |||||
| nursery.start_soon(handle_task) | |||||
| logging.error("BUG!!! You should not reach here!!!") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| main() | |||||
| trio.run(main) |
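
The new entry point starts everything under a single `trio.run`: a nursery owns a long-lived `report_status` coroutine plus the task handlers it keeps spawning, gated by `task_limiter`. A runnable sketch of that shape follows; the concurrency bound and loop counts are illustrative, and unlike the diff (which wraps the spawn itself in `async with task_limiter`) this version has each handler hold the limiter so the bound is easy to observe.

```python
import trio

task_limiter = trio.CapacityLimiter(4)       # illustrative bound, not the project's setting

async def handle_task(n):
    async with task_limiter:                 # at most 4 handlers run concurrently
        await trio.sleep(0.1)                # stands in for real parsing work
        print(f"task {n} done")

async def report_status():
    for _ in range(2):                       # the real loop runs forever with trio.sleep(30)
        print("heartbeat")
        await trio.sleep(0.2)

async def main():
    async with trio.open_nursery() as nursery:
        nursery.start_soon(report_status)
        for n in range(10):                  # the real loop is `while True`
            nursery.start_soon(handle_task, n)

if __name__ == "__main__":
    trio.run(main)
```

The nursery also gives the executor structured shutdown for free: when `main` is cancelled, every spawned handler and the heartbeat loop are cancelled with it, which is why the previous signal-handler and daemon-thread plumbing could be removed.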
| from rag.utils import singleton | from rag.utils import singleton | ||||
| class Payload: | |||||
| class RedisMsg: | |||||
| def __init__(self, consumer, queue_name, group_name, msg_id, message): | def __init__(self, consumer, queue_name, group_name, msg_id, message): | ||||
| self.__consumer = consumer | self.__consumer = consumer | ||||
| self.__queue_name = queue_name | self.__queue_name = queue_name | ||||
| def get_message(self): | def get_message(self): | ||||
| return self.__message | return self.__message | ||||
| def get_msg_id(self): | |||||
| return self.__msg_id | |||||
| @singleton | @singleton | ||||
| class RedisDB: | class RedisDB: | ||||
| ) | ) | ||||
| return False | return False | ||||
| def queue_consumer( | |||||
| self, queue_name, group_name, consumer_name, msg_id=b">" | |||||
| ) -> Payload: | |||||
| def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg: | |||||
| """https://redis.io/docs/latest/commands/xreadgroup/""" | |||||
| try: | try: | ||||
| group_info = self.REDIS.xinfo_groups(queue_name) | group_info = self.REDIS.xinfo_groups(queue_name) | ||||
| if not any(e["name"] == group_name for e in group_info): | if not any(e["name"] == group_name for e in group_info): | ||||
| "groupname": group_name, | "groupname": group_name, | ||||
| "consumername": consumer_name, | "consumername": consumer_name, | ||||
| "count": 1, | "count": 1, | ||||
| "block": 10000, | |||||
| "block": 5, | |||||
| "streams": {queue_name: msg_id}, | "streams": {queue_name: msg_id}, | ||||
| } | } | ||||
| messages = self.REDIS.xreadgroup(**args) | messages = self.REDIS.xreadgroup(**args) | ||||
| if not messages: | if not messages: | ||||
| return None | return None | ||||
| stream, element_list = messages[0] | stream, element_list = messages[0] | ||||
| if not element_list: | |||||
| return None | |||||
| msg_id, payload = element_list[0] | msg_id, payload = element_list[0] | ||||
| res = Payload(self.REDIS, queue_name, group_name, msg_id, payload) | |||||
| res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload) | |||||
| return res | return res | ||||
| except Exception as e: | except Exception as e: | ||||
| if "key" in str(e): | if "key" in str(e): | ||||
| ) | ) | ||||
| return None | return None | ||||
| def get_unacked_for(self, consumer_name, queue_name, group_name): | |||||
| def get_unacked_iterator(self, queue_name, group_name, consumer_name): | |||||
| try: | try: | ||||
| group_info = self.REDIS.xinfo_groups(queue_name) | group_info = self.REDIS.xinfo_groups(queue_name) | ||||
| if not any(e["name"] == group_name for e in group_info): | if not any(e["name"] == group_name for e in group_info): | ||||
| return | return | ||||
| pendings = self.REDIS.xpending_range( | |||||
| queue_name, | |||||
| group_name, | |||||
| min=0, | |||||
| max=10000000000000, | |||||
| count=1, | |||||
| consumername=consumer_name, | |||||
| ) | |||||
| if not pendings: | |||||
| return | |||||
| msg_id = pendings[0]["message_id"] | |||||
| msg = self.REDIS.xrange(queue_name, min=msg_id, count=1) | |||||
| _, payload = msg[0] | |||||
| return Payload(self.REDIS, queue_name, group_name, msg_id, payload) | |||||
| current_min = 0 | |||||
| while True: | |||||
| payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min) | |||||
| if not payload: | |||||
| return | |||||
| current_min = payload.get_msg_id() | |||||
| logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}") | |||||
| yield payload | |||||
| except Exception as e: | except Exception as e: | ||||
| if "key" in str(e): | if "key" in str(e): | ||||
| return | return | ||||
| logging.exception( | logging.exception( | ||||
| "RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e) | |||||
| "RedisDB.get_unacked_iterator " + consumer_name + " got exception: " | |||||
| ) | ) | ||||
| self.__open__() | self.__open__() | ||||
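
The renamed `RedisMsg` wrapper and the reworked `queue_consumer`/`get_unacked_iterator` sit on top of Redis consumer groups: `XREADGROUP` delivers an entry to exactly one consumer, and the entry stays pending until it is `XACK`ed. A small end-to-end sketch with the redis-py client; the stream, group, and consumer names here are made up for the example, and the 5 ms block matches the value the diff switches to.

```python
import redis

r = redis.Redis()  # assumes a local Redis/Valkey instance

QUEUE, GROUP, CONSUMER = "rag_tasks", "rag_group", "consumer_0"

# Create the consumer group if it does not exist yet.
try:
    r.xgroup_create(QUEUE, GROUP, id="0", mkstream=True)
except redis.ResponseError as e:
    if "BUSYGROUP" not in str(e):
        raise

r.xadd(QUEUE, {"message": '{"id": "demo-task"}'})

# Read one new entry for this consumer; block for at most 5 ms, as in the diff.
messages = r.xreadgroup(GROUP, CONSUMER, {QUEUE: ">"}, count=1, block=5)
if messages:
    _, entries = messages[0]
    for msg_id, payload in entries:
        print(msg_id, payload)
        r.xack(QUEUE, GROUP, msg_id)   # acknowledge once the task is handled
```

The diff's `get_unacked_iterator` walks the same pending set by repeatedly calling `queue_consumer` with the last seen message id instead of `">"`, which re-delivers entries that were handed out earlier but never acknowledged.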
| version = 1 | version = 1 | ||||
| revision = 1 | |||||
| requires-python = ">=3.10, <3.13" | requires-python = ">=3.10, <3.13" | ||||
| resolution-markers = [ | resolution-markers = [ | ||||
| "python_full_version >= '3.12' and sys_platform == 'darwin'", | "python_full_version >= '3.12' and sys_platform == 'darwin'", | ||||
| version = "0.8.2" | version = "0.8.2" | ||||
| source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | ||||
| sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" } | sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" } | ||||
| wheels = [ | |||||
| { url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" }, | |||||
| ] | |||||
| [[package]] | [[package]] | ||||
| name = "decorator" | name = "decorator" | ||||
| version = "0.3.6" | version = "0.3.6" | ||||
| source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | ||||
| dependencies = [ | dependencies = [ | ||||
| { name = "huggingface-hub" }, | |||||
| { name = "loguru" }, | |||||
| { name = "mmh3" }, | |||||
| { name = "numpy" }, | |||||
| { name = "onnxruntime-gpu" }, | |||||
| { name = "pillow" }, | |||||
| { name = "pystemmer" }, | |||||
| { name = "requests" }, | |||||
| { name = "snowballstemmer" }, | |||||
| { name = "tokenizers" }, | |||||
| { name = "tqdm" }, | |||||
| { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| ] | ] | ||||
| sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" } | sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" } | ||||
| wheels = [ | wheels = [ | ||||
| version = "1.19.2" | version = "1.19.2" | ||||
| source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | ||||
| dependencies = [ | dependencies = [ | ||||
| { name = "coloredlogs" }, | |||||
| { name = "flatbuffers" }, | |||||
| { name = "numpy" }, | |||||
| { name = "packaging" }, | |||||
| { name = "protobuf" }, | |||||
| { name = "sympy" }, | |||||
| { name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, | |||||
| ] | ] | ||||
| wheels = [ | wheels = [ | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" }, | { url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" }, | { url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" }, | ||||
| ] | ] | ||||
| [[package]] | |||||
| name = "pybind11" | |||||
| version = "2.13.6" | |||||
| source = { registry = "https://mirrors.aliyun.com/pypi/simple" } | |||||
| sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" } | |||||
| wheels = [ | |||||
| { url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" }, | |||||
| ] | |||||
| [[package]] | [[package]] | ||||
| name = "pyclipper" | name = "pyclipper" | ||||
| version = "1.3.0.post5" | version = "1.3.0.post5" | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" }, | { url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" }, | { url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" }, | { url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" }, | |||||
| { url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" }, | |||||
| { url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" }, | { url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" }, | { url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" }, | ||||
| { url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" }, | { url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" }, | ||||
| { name = "tencentcloud-sdk-python" }, | { name = "tencentcloud-sdk-python" }, | ||||
| { name = "tika" }, | { name = "tika" }, | ||||
| { name = "tiktoken" }, | { name = "tiktoken" }, | ||||
| { name = "trio" }, | |||||
| { name = "umap-learn" }, | { name = "umap-learn" }, | ||||
| { name = "valkey" }, | { name = "valkey" }, | ||||
| { name = "vertexai" }, | { name = "vertexai" }, | ||||
| { name = "tiktoken", specifier = "==0.7.0" }, | { name = "tiktoken", specifier = "==0.7.0" }, | ||||
| { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" }, | { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" }, | ||||
| { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, | { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, | ||||
| { name = "trio", specifier = ">=0.29.0" }, | |||||
| { name = "umap-learn", specifier = "==0.5.6" }, | { name = "umap-learn", specifier = "==0.5.6" }, | ||||
| { name = "valkey", specifier = "==6.0.2" }, | { name = "valkey", specifier = "==6.0.2" }, | ||||
| { name = "vertexai", specifier = "==1.64.0" }, | { name = "vertexai", specifier = "==1.64.0" }, | ||||
| { name = "yfinance", specifier = "==0.1.96" }, | { name = "yfinance", specifier = "==0.1.96" }, | ||||
| { name = "zhipuai", specifier = "==2.0.1" }, | { name = "zhipuai", specifier = "==2.0.1" }, | ||||
| ] | ] | ||||
| provides-extras = ["full"] | |||||
| [[package]] | [[package]] | ||||
| name = "ranx" | name = "ranx" |