### What problem does this PR solve?

Add multi-GPU and parallel processing support for OCR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

@yuzhichang I've tried to resolve the comments in #5697. OCR jobs can now run on both CPU and GPU. (By the way, I've encountered a "Generate embedding error", issue #5954, that might be due to my outdated GPUs? I'm not sure.) Please review it and give me suggestions.

GPU: *(timing screenshot)*
CPU: *(timing screenshot)*
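In short: `parallel_devices` is the number of GPUs to spread OCR over, and `None` keeps the old single-worker behavior. A minimal sketch of the intended usage (counting devices via `torch` is an assumption of the sketch, not required by the parser itself):

```python
from deepdoc.parser.pdf_parser import RAGFlowPdfParser

# Count CUDA devices if torch is available; fall back to CPU-only otherwise.
try:
    import torch.cuda
    n_gpus = torch.cuda.device_count()
except Exception:
    n_gpus = 0

# parallel_devices > 1 creates one trio.CapacityLimiter per GPU inside the
# parser, so pages are OCR'd concurrently with one in-flight job per device.
parser = RAGFlowPdfParser(parallel_devices=n_gpus if n_gpus > 0 else None)
```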
**deepdoc/parser/pdf_parser.py**

```diff
 from timeit import default_timer as timer
 import sys
 import threading
+import trio
 import xgboost as xgb
 from io import BytesIO
@@ ... @@
 sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
@@ ... @@
 class RAGFlowPdfParser:
-    def __init__(self):
+    def __init__(self, parallel_devices: int | None = None):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
         ^_-
         """
-        self.ocr = OCR()
+        self.ocr = OCR(parallel_devices=parallel_devices)
+        self.parallel_devices = parallel_devices
+        self.parallel_limiter = None
+        if parallel_devices is not None and parallel_devices > 1:
+            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
         if hasattr(self, "model_speciess"):
             self.layouter = LayoutRecognizer("layout." + self.model_speciess)
         else:
@@ ... @@
         self.updown_cnt_mdl = xgb.Booster()
         if not settings.LIGHTEN:
             try:
+                import torch
+                import torch.cuda
                 if torch.cuda.is_available():
                     self.updown_cnt_mdl.set_param({"device": "cuda"})
             except Exception:
@@ ... @@
             b["H_right"] = spans[ii]["x1"]
             b["SP"] = ii

-    def __ocr(self, pagenum, img, chars, ZM=3):
+    def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
         start = timer()
-        bxs = self.ocr.detect(np.array(img))
+        bxs = self.ocr.detect(np.array(img), device_id)
         logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
         start = timer()
@@ ... @@
             b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
             boxes_to_reg.append(b)
             del b["txt"]
-        texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
+        texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg], device_id)
         for i in range(len(boxes_to_reg)):
             boxes_to_reg[i]["text"] = texts[i]
             del boxes_to_reg[i]["box_image"]
@@ ... @@
         else:
             self.is_english = False

-        start = timer()
-        for i, img in enumerate(self.page_images):
-            chars = self.page_chars[i] if not self.is_english else []
-            self.mean_height.append(
-                np.median(sorted([c["height"] for c in chars])) if chars else 0
-            )
-            self.mean_width.append(
-                np.median(sorted([c["width"] for c in chars])) if chars else 8
-            )
-            self.page_cum_height.append(img.size[1] / zoomin)
+        async def __img_ocr(i, id, img, chars, limiter):
             j = 0
             while j + 1 < len(chars):
                 if chars[j]["text"] and chars[j + 1]["text"] \
                         and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
                         and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
                                                                        chars[j]["width"]) / 2:
                     chars[j]["text"] += " "
                 j += 1

-            self.__ocr(i + 1, img, chars, zoomin)
+            if limiter:
+                async with limiter:
+                    await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
+            else:
+                self.__ocr(i + 1, img, chars, zoomin, id)

             if callback and i % 6 == 5:
                 callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")

+        async def __img_ocr_launcher():
+            def __ocr_preprocess():
+                chars = self.page_chars[i] if not self.is_english else []
+                self.mean_height.append(
+                    np.median(sorted([c["height"] for c in chars])) if chars else 0
+                )
+                self.mean_width.append(
+                    np.median(sorted([c["width"] for c in chars])) if chars else 8
+                )
+                self.page_cum_height.append(img.size[1] / zoomin)
+                return chars
+
+            if self.parallel_limiter:
+                async with trio.open_nursery() as nursery:
+                    for i, img in enumerate(self.page_images):
+                        chars = __ocr_preprocess()
+                        nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
+                                           self.parallel_limiter[i % self.parallel_devices])
+                        await trio.sleep(0.1)
+            else:
+                for i, img in enumerate(self.page_images):
+                    chars = __ocr_preprocess()
+                    await __img_ocr(i, 0, img, chars, None)
+
+        start = timer()
+        trio.run(__img_ocr_launcher)
         logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
@@ ... @@
         if not self.is_english and not any(
```
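The scheduling pattern above, in isolation, for anyone reviewing the trio part: pages go round-robin to devices (`i % N`), and a `CapacityLimiter(1)` per device keeps at most one blocking OCR call in flight per GPU. A self-contained sketch, where `fake_ocr` stands in for `self.__ocr`:

```python
import time
import trio

N_DEVICES = 2  # assumption for the sketch

def fake_ocr(page_no, device_id):
    time.sleep(0.2)  # simulate a blocking ONNX inference call
    print(f"page {page_no} done on device {device_id}")

async def page_task(page_no, device_id, limiter):
    async with limiter:  # serialize jobs targeting this device
        await trio.to_thread.run_sync(fake_ocr, page_no, device_id)

async def launcher():
    limiters = [trio.CapacityLimiter(1) for _ in range(N_DEVICES)]
    async with trio.open_nursery() as nursery:
        for page_no in range(8):
            dev = page_no % N_DEVICES  # round-robin page -> device
            nursery.start_soon(page_task, page_no, dev, limiters[dev])

trio.run(launcher)
```

(`trio.to_thread.run_sync` also accepts a `limiter=` keyword, which could replace the explicit `async with` block.)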
**deepdoc/vision/ocr.py**

```diff
     return ops

-def load_model(model_dir, nm):
+def load_model(model_dir, nm, device_id: int | None = None):
     model_file_path = os.path.join(model_dir, nm + ".onnx")
+    model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
+
     global loaded_models
-    loaded_model = loaded_models.get(model_file_path)
+    loaded_model = loaded_models.get(model_cached_tag)
     if loaded_model:
         logging.info(f"load_model {model_file_path} reuses cached model")
         return loaded_model
@@ ... @@
     def cuda_is_available():
         try:
             import torch
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
                 return True
         except Exception:
             return False
@@ ... @@
     run_options = ort.RunOptions()
     if cuda_is_available():
         cuda_provider_options = {
-            "device_id": 0,  # Use specific GPU
+            "device_id": device_id,  # Use specific GPU
             "gpu_mem_limit": 512 * 1024 * 1024,  # Limit gpu memory
             "arena_extend_strategy": "kNextPowerOfTwo",  # gpu memory allocation strategy
         }
@@ ... @@
             providers=['CUDAExecutionProvider'],
             provider_options=[cuda_provider_options]
         )
-        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
+        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
         logging.info(f"load_model {model_file_path} uses GPU")
     else:
         sess = ort.InferenceSession(
@@ ... @@
         run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
         logging.info(f"load_model {model_file_path} uses CPU")
     loaded_model = (sess, run_options)
-    loaded_models[model_file_path] = loaded_model
+    loaded_models[model_cached_tag] = loaded_model
     return loaded_model
```
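With the `device_id` suffix in the cache key, each GPU gets its own ONNX Runtime session, and each session is pinned to its GPU through the CUDA provider options. The pinning in isolation (a sketch; the model path is a placeholder):

```python
import onnxruntime as ort

def make_session(model_path: str, device_id: int) -> ort.InferenceSession:
    # Pin this session to one specific GPU via CUDAExecutionProvider options.
    cuda_opts = {
        "device_id": device_id,              # GPU this session runs on
        "gpu_mem_limit": 512 * 1024 * 1024,  # cap the memory arena at 512 MiB
        "arena_extend_strategy": "kNextPowerOfTwo",
    }
    return ort.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider"],
        provider_options=[cuda_opts],
    )
```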
**deepdoc/vision/ocr.py** (continued)

```diff
 class TextRecognizer:
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device_id: int | None = None):
         self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
         self.rec_batch_num = 16
         postprocess_params = {
@@ ... @@
             "use_space_char": True
         }
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.run_options = load_model(model_dir, 'rec')
+        self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
         self.input_tensor = self.predictor.get_inputs()[0]

     def resize_norm_img(self, img, max_wh_ratio):
@@ ... @@
 class TextDetector:
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device_id: int | None = None):
         pre_process_list = [{
             'DetResizeForTest': {
                 'limit_side_len': 960,
@@ ... @@
             "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.run_options = load_model(model_dir, 'det')
+        self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
         self.input_tensor = self.predictor.get_inputs()[0]
         img_h, img_w = self.input_tensor.shape[2:]
@@ ... @@
 class OCR:
-    def __init__(self, model_dir=None):
+    def __init__(self, model_dir=None, parallel_devices: int | None = None):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ ... @@
             model_dir = os.path.join(
                 get_project_base_directory(),
                 "rag/res/deepdoc")
-            self.text_detector = TextDetector(model_dir)
-            self.text_recognizer = TextRecognizer(model_dir)
+            # Append one detector/recognizer per device for multi-GPU use
+            if parallel_devices is not None and parallel_devices > 0:
+                self.text_detector = []
+                self.text_recognizer = []
+                for device_id in range(parallel_devices):
+                    self.text_detector.append(TextDetector(model_dir, device_id))
+                    self.text_recognizer.append(TextRecognizer(model_dir, device_id))
+            else:
+                self.text_detector = [TextDetector(model_dir, 0)]
+                self.text_recognizer = [TextRecognizer(model_dir, 0)]
         except Exception:
             model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                           local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                           local_dir_use_symlinks=False)
-            self.text_detector = TextDetector(model_dir)
-            self.text_recognizer = TextRecognizer(model_dir)
+            if parallel_devices is not None:
+                assert parallel_devices > 0, "Number of devices must be >= 1"
+                self.text_detector = []
+                self.text_recognizer = []
+                for device_id in range(parallel_devices):
+                    self.text_detector.append(TextDetector(model_dir, device_id))
+                    self.text_recognizer.append(TextRecognizer(model_dir, device_id))
+            else:
+                self.text_detector = [TextDetector(model_dir, 0)]
+                self.text_recognizer = [TextRecognizer(model_dir, 0)]

         self.drop_score = 0.5
         self.crop_image_res_index = 0
@@ ... @@
                 break
         return _boxes

-    def detect(self, img):
+    def detect(self, img, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+
         time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

         if img is None:
             return None, None, time_dict

         start = time.time()
-        dt_boxes, elapse = self.text_detector(img)
+        dt_boxes, elapse = self.text_detector[device_id](img)
         time_dict['det'] = elapse

         if dt_boxes is None:
@@ ... @@
         return zip(self.sorted_boxes(dt_boxes), [
             ("", 0) for _ in range(len(dt_boxes))])

-    def recognize(self, ori_im, box):
+    def recognize(self, ori_im, box, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+
         img_crop = self.get_rotate_crop_image(ori_im, box)
-        rec_res, elapse = self.text_recognizer([img_crop])
+        rec_res, elapse = self.text_recognizer[device_id]([img_crop])
         text, score = rec_res[0]
         if score < self.drop_score:
             return ""
         return text

-    def recognize_batch(self, img_list):
-        rec_res, elapse = self.text_recognizer(img_list)
+    def recognize_batch(self, img_list, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+        rec_res, elapse = self.text_recognizer[device_id](img_list)
         texts = []
         for i in range(len(rec_res)):
             text, score = rec_res[i]
             texts.append(text)
         return texts

-    def __call__(self, img, cls=True):
+    def __call__(self, img, device_id=0, cls=True):
         time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+        if device_id is None:
+            device_id = 0

         if img is None:
             return None, None, time_dict

         start = time.time()
         ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
+        dt_boxes, elapse = self.text_detector[device_id](img)
         time_dict['det'] = elapse

         if dt_boxes is None:
@@ ... @@
                 img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
                 img_crop_list.append(img_crop)
-        rec_res, elapse = self.text_recognizer(img_crop_list)
+        rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
         time_dict['rec'] = elapse
```
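For a quick manual check of the new `device_id` plumbing (a sketch that assumes two visible GPUs; loading the image via PIL and the file name are placeholders):

```python
import numpy as np
from PIL import Image
from deepdoc.vision import OCR

ocr = OCR(parallel_devices=2)            # one det/rec session pair for GPU 0 and 1

img = np.array(Image.open("page.png"))   # placeholder input image
boxes = ocr.detect(img, 1)               # detection only, on device 1
results = ocr(img, device_id=0)          # full det + rec pipeline on device 0
```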
**deepdoc/vision/t_ocr.py**

```diff
 from deepdoc.vision import OCR, init_in_out
 import argparse
 import numpy as np
+import trio

+# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'  # 2 GPUs, non-contiguous
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 1 GPU
+# os.environ['CUDA_VISIBLE_DEVICES'] = ''  # CPU

 def main(args):
-    ocr = OCR()
+    import torch.cuda
+    cuda_devices = torch.cuda.device_count()
+    limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
+    ocr = OCR(parallel_devices=cuda_devices)
     images, outputs = init_in_out(args)

-    for i, img in enumerate(images):
-        bxs = ocr(np.array(img))
+    def __ocr(i, id, img):
+        print("Task {} start".format(i))
+        bxs = ocr(np.array(img), id)
         bxs = [(line[0], line[1][0]) for line in bxs]
         bxs = [{
             "text": t,
@@ ... @@
         with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
             f.write("\n".join([o["text"] for o in bxs]))
+        print("Task {} done".format(i))
+
+    async def __ocr_thread(i, id, img, limiter=None):
+        if limiter:
+            async with limiter:
+                print("Task {} use device {}".format(i, id))
+                await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
+        else:
+            __ocr(i, id, img)
+
+    async def __ocr_launcher():
+        if cuda_devices > 1:
+            async with trio.open_nursery() as nursery:
+                for i, img in enumerate(images):
+                    nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
+                    await trio.sleep(0.1)
+        else:
+            for i, img in enumerate(images):
+                await __ocr_thread(i, 0, img)
+
+    trio.run(__ocr_launcher)
+    print("OCR tasks are all done")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
```
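On the commented-out `CUDA_VISIBLE_DEVICES` lines: the variable remaps the listed physical GPUs onto contiguous logical ids starting at 0, so the `i % cuda_devices` round-robin stays valid even for a non-contiguous set like `'0,2'`. A quick check (the variable must be set before CUDA is first initialized):

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"  # expose physical GPUs 0 and 2

import torch.cuda  # imported after setting the variable

# torch now sees two logical devices, 0 and 1, mapped to physical 0 and 2,
# so device ids 0..device_count()-1 remain valid everywhere in this PR.
print(torch.cuda.device_count())  # -> 2 on a machine with 3+ GPUs
```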
**rag/app/naive.py**

```diff
 class Pdf(PdfParser):
+    def __init__(self, parallel_devices=None):
+        super().__init__(parallel_devices)
+
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         start = timer()
@@ ... @@
 def chunk(filename, binary=None, from_page=0, to_page=100000,
-          lang="Chinese", callback=None, **kwargs):
+          lang="Chinese", parallel_devices=None, callback=None, **kwargs):
     """
     Supported file formats are docx, pdf, excel, txt.
     This method apply the naive ways to chunk files.
@@ ... @@
         return res
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
+        pdf_parser = Pdf(parallel_devices)
         if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
             pdf_parser = PlainParser()
         sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
```
**rag/svr/task_executor.py**

```diff
 task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
 chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)

+PARALLEL_DEVICES = None
+try:
+    import torch.cuda
+    PARALLEL_DEVICES = torch.cuda.device_count()
+    logging.info(f"found {PARALLEL_DEVICES} gpus")
+except Exception:
+    logging.info("can't import package 'torch'")

 # SIGUSR1 handler: start tracemalloc and take snapshot
 def start_tracemalloc_and_snapshot(signum, frame):
     if not tracemalloc.is_tracing():
@@ ... @@
     try:
         async with chunk_limiter:
             cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
-                                                to_page=task["to_page"], lang=task["language"], callback=progress_callback,
+                                                to_page=task["to_page"], lang=task["language"], parallel_devices=PARALLEL_DEVICES, callback=progress_callback,
                                                 kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
         logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
     except TaskCanceledException:
```
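End to end, the `PARALLEL_DEVICES` value detected at executor startup flows through `chunker.chunk(...)` into `Pdf` and from there into `RAGFlowPdfParser`. Calling `chunk()` directly looks roughly like this (a sketch; the file name and callback are placeholders, and other keyword arguments fall back to their defaults):

```python
from rag.app import naive

def progress(prog=None, msg=""):
    # placeholder progress callback
    print(prog, msg)

# parallel_devices is threaded straight through chunk() -> Pdf() -> RAGFlowPdfParser
chunks = naive.chunk("sample.pdf", parallel_devices=2, callback=progress)
```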