### What problem does this PR solve? Introduced OCR.recognize_batch ### Type of change - [x] Performance Improvement tags/v0.17.0
| @@ -17,6 +17,7 @@ | |||
| import logging | |||
| import os | |||
| import random | |||
| from timeit import default_timer as timer | |||
| import xgboost as xgb | |||
| from io import BytesIO | |||
| @@ -277,7 +278,11 @@ class RAGFlowPdfParser: | |||
| b["SP"] = ii | |||
| def __ocr(self, pagenum, img, chars, ZM=3): | |||
| start = timer() | |||
| bxs = self.ocr.detect(np.array(img)) | |||
| logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)") | |||
| start = timer() | |||
| if not bxs: | |||
| self.boxes.append([]) | |||
| return | |||
| @@ -308,14 +313,22 @@ class RAGFlowPdfParser: | |||
| else: | |||
| bxs[ii]["text"] += c["text"] | |||
| logging.info(f"__ocr sorting {len(chars)} chars cost {timer() - start}s") | |||
| start = timer() | |||
| boxes_to_reg = [] | |||
| img_np = np.array(img) | |||
| for b in bxs: | |||
| if not b["text"]: | |||
| left, right, top, bott = b["x0"] * ZM, b["x1"] * \ | |||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||
| b["text"] = self.ocr.recognize(np.array(img), | |||
| np.array([[left, top], [right, top], [right, bott], [left, bott]], | |||
| dtype=np.float32)) | |||
| b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32)) | |||
| boxes_to_reg.append(b) | |||
| del b["txt"] | |||
| texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg]) | |||
| for i in range(len(boxes_to_reg)): | |||
| boxes_to_reg[i]["text"] = texts[i] | |||
| del boxes_to_reg[i]["box_image"] | |||
| logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s") | |||
| bxs = [b for b in bxs if b["text"]] | |||
| if self.mean_height[-1] == 0: | |||
| self.mean_height[-1] = np.median([b["bottom"] - b["top"] | |||
| @@ -951,6 +964,7 @@ class RAGFlowPdfParser: | |||
| self.page_cum_height = [0] | |||
| self.page_layout = [] | |||
| self.page_from = page_from | |||
| start = timer() | |||
| try: | |||
| self.pdf = pdfplumber.open(fnm) if isinstance( | |||
| fnm, str) else pdfplumber.open(BytesIO(fnm)) | |||
| @@ -965,6 +979,7 @@ class RAGFlowPdfParser: | |||
| self.total_page = len(self.pdf.pages) | |||
| except Exception: | |||
| logging.exception("RAGFlowPdfParser __images__") | |||
| logging.info(f"__images__ dedupe_chars cost {timer() - start}s") | |||
| self.outlines = [] | |||
| try: | |||
| @@ -994,7 +1009,7 @@ class RAGFlowPdfParser: | |||
| else: | |||
| self.is_english = False | |||
| # st = timer() | |||
| start = timer() | |||
| for i, img in enumerate(self.page_images): | |||
| chars = self.page_chars[i] if not self.is_english else [] | |||
| self.mean_height.append( | |||
| @@ -1016,7 +1031,7 @@ class RAGFlowPdfParser: | |||
| self.__ocr(i + 1, img, chars, zoomin) | |||
| if callback and i % 6 == 5: | |||
| callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| # print("OCR:", timer()-st) | |||
| logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") | |||
| if not self.is_english and not any( | |||
| [c for c in self.page_chars]) and self.boxes: | |||
| @@ -620,6 +620,16 @@ class OCR(object): | |||
| return "" | |||
| return text | |||
def recognize_batch(self, img_list):
    """Recognize the text of several images with a single recognizer call.

    Args:
        img_list: Sized collection of images, handed unchanged to
            ``self.text_recognizer``.

    Returns:
        A list with one string per recognizer result, in the order the
        recognizer produced them. Any result whose score is below
        ``self.drop_score`` is replaced by an empty string.
    """
    # Nothing to recognize: avoid invoking the recognizer on an empty batch.
    if len(img_list) == 0:
        return []
    # The recognizer also reports its elapsed time, which is not used here.
    rec_res, _ = self.text_recognizer(img_list)
    return ["" if score < self.drop_score else text for text, score in rec_res]
| def __call__(self, img, cls=True): | |||
| time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} | |||
| @@ -18,7 +18,7 @@ | |||
| # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code | |||
| import random | |||
| import sys | |||
| from api.utils.log_utils import initRootLogger | |||
| from api.utils.log_utils import initRootLogger, get_project_base_directory | |||
| from graphrag.general.index import WithCommunity, WithResolution, Dealer | |||
| from graphrag.light.graph_extractor import GraphExtractor as LightKGExt | |||
| from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt | |||
| @@ -42,6 +42,7 @@ from io import BytesIO | |||
| from multiprocessing.context import TimeoutError | |||
| from timeit import default_timer as timer | |||
| import tracemalloc | |||
| import signal | |||
| import numpy as np | |||
| from peewee import DoesNotExist | |||
| @@ -96,6 +97,35 @@ DONE_TASKS = 0 | |||
| FAILED_TASKS = 0 | |||
| CURRENT_TASK = None | |||
| tracemalloc_started = False | |||
| # SIGUSR1 handler: start tracemalloc and take snapshot | |||
def start_tracemalloc_and_snapshot(signum, frame):
    """Signal handler: make sure tracemalloc is running, then dump a snapshot.

    On the first call tracemalloc is started and the module-level
    ``tracemalloc_started`` flag is set; later calls only log that it is
    already running. In both cases a snapshot is taken and written to
    ``<project base>/logs/<pid>_snapshot_<timestamp>.trace``.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    global tracemalloc_started
    if not tracemalloc_started:
        logging.info("got SIGUSR1, start tracemalloc")
        tracemalloc.start()
        tracemalloc_started = True
    else:
        logging.info("got SIGUSR1, tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_dir = os.path.abspath(os.path.join(get_project_base_directory(), "logs"))
    # Ensure the target directory exists so dump() cannot fail on a missing path.
    os.makedirs(snapshot_dir, exist_ok=True)
    # The PID keeps snapshots from concurrently running processes apart.
    snapshot_file = os.path.join(snapshot_dir, f"{os.getpid()}_snapshot_{timestamp}.trace")
    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    logging.info(f"taken snapshot {snapshot_file}")
| # SIGUSR2 handler: stop tracemalloc | |||
def stop_tracemalloc(signum, frame):
    """Signal handler: stop tracemalloc if this module started it.

    Uses the module-level ``tracemalloc_started`` flag to decide whether
    there is anything to stop, and clears the flag after stopping.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    global tracemalloc_started
    if tracemalloc_started:
        logging.info("got SIGUSR2, stop tracemalloc")
        tracemalloc.stop()
        tracemalloc_started = False
    else:
        logging.info("got SIGUSR2, tracemalloc not running")
| class TaskCanceledException(Exception): | |||
| def __init__(self, msg): | |||
| @@ -712,26 +742,18 @@ def main(): | |||
| logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}') | |||
| settings.init_settings() | |||
| print_rag_settings() | |||
| signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot) | |||
| signal.signal(signal.SIGUSR2, stop_tracemalloc) | |||
| TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0")) | |||
| if TRACE_MALLOC_ENABLED: | |||
| start_tracemalloc_and_snapshot(None, None) | |||
| background_thread = threading.Thread(target=report_status) | |||
| background_thread.daemon = True | |||
| background_thread.start() | |||
| TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0")) | |||
| TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0")) | |||
| if TRACE_MALLOC_DELTA > 0: | |||
| if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA: | |||
| TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA | |||
| tracemalloc.start() | |||
| snapshot1 = tracemalloc.take_snapshot() | |||
| while True: | |||
| handle_task() | |||
| num_tasks = DONE_TASKS + FAILED_TASKS | |||
| if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0: | |||
| snapshot2 = tracemalloc.take_snapshot() | |||
| analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0) | |||
| snapshot1 = snapshot2 | |||
| snapshot2 = None | |||
| if __name__ == "__main__": | |||
| main() | |||