
Feat: add OCR multi-GPU and parallel processing support (#5972)

### What problem does this PR solve?

Add multi-GPU and parallel processing support for OCR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

@yuzhichang I've tried to resolve the comments in #5697. OCR jobs can
now be run on both CPU and GPU. (By the way, I've encountered a
"Generate embedding error" issue #5954 that might be due to my outdated
GPUs? I'm not sure.) Please review it and give me suggestions.
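
For a quick look at the new surface area, here is a minimal sketch of driving the `OCR` class directly with the added `parallel_devices` / `device_id` parameters. The two-GPU count and the dummy image are illustrative only; on a CPU-only machine both parameters can simply be omitted.

```python
# Illustrative sketch: two GPUs assumed; OCR builds one detector/recognizer pair per device.
import numpy as np
from deepdoc.vision import OCR

ocr = OCR(parallel_devices=2)                 # model set 0 on cuda:0, model set 1 on cuda:1

img = np.zeros((800, 600, 3), dtype=np.uint8)  # stand-in for a rendered page image
boxes = ocr.detect(img, device_id=1)           # run detection on the second GPU
results = ocr(img, device_id=0)                # full detect + recognize pipeline on the first GPU
```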

GPU:

![gpu_ocr](https://github.com/user-attachments/assets/0ee2ecfb-a665-4e50-8bc7-15941b9cd80e)

![smi](https://github.com/user-attachments/assets/a2312f8c-cf24-443d-bf89-bec50503546d)

CPU:

![cpu_ocr](https://github.com/user-attachments/assets/1ba6bb0b-94df-41ea-be79-790096da4bf1)
Tag: v0.18.0
Debug Doctor committed 7 months ago · commit 3e19044dee
5 changed files with 157 additions and 48 deletions:

1. deepdoc/parser/pdf_parser.py (+51, -18)
2. deepdoc/vision/ocr.py (+55, -24)
3. deepdoc/vision/t_ocr.py (+37, -3)
4. rag/app/naive.py (+5, -2)
5. rag/svr/task_executor.py (+9, -1)

deepdoc/parser/pdf_parser.py (+51, -18)

 from timeit import default_timer as timer
 import sys
 import threading
+import trio

 import xgboost as xgb
 from io import BytesIO
 sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()


 class RAGFlowPdfParser:
-    def __init__(self):
+    def __init__(self, parallel_devices: int | None = None):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
         ^_-
         """
-        self.ocr = OCR()
+        self.ocr = OCR(parallel_devices=parallel_devices)
+        self.parallel_devices = parallel_devices
+        self.parallel_limiter = None
+        if parallel_devices is not None and parallel_devices > 1:
+            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
         if hasattr(self, "model_speciess"):
             self.layouter = LayoutRecognizer("layout." + self.model_speciess)
         else:

         self.updown_cnt_mdl = xgb.Booster()
         if not settings.LIGHTEN:
             try:
-                import torch
+                import torch.cuda
                 if torch.cuda.is_available():
                     self.updown_cnt_mdl.set_param({"device": "cuda"})
             except Exception:

                 b["H_right"] = spans[ii]["x1"]
                 b["SP"] = ii

-    def __ocr(self, pagenum, img, chars, ZM=3):
+    def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
         start = timer()
-        bxs = self.ocr.detect(np.array(img))
+        bxs = self.ocr.detect(np.array(img), device_id)
         logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")

         start = timer()

             b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
             boxes_to_reg.append(b)
             del b["txt"]
-        texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
+        texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg], device_id)
         for i in range(len(boxes_to_reg)):
             boxes_to_reg[i]["text"] = texts[i]
             del boxes_to_reg[i]["box_image"]
         else:
             self.is_english = False

-        start = timer()
-        for i, img in enumerate(self.page_images):
-            chars = self.page_chars[i] if not self.is_english else []
-            self.mean_height.append(
-                np.median(sorted([c["height"] for c in chars])) if chars else 0
-            )
-            self.mean_width.append(
-                np.median(sorted([c["width"] for c in chars])) if chars else 8
-            )
-            self.page_cum_height.append(img.size[1] / zoomin)
+        async def __img_ocr(i, id, img, chars, limiter):
             j = 0
             while j + 1 < len(chars):
                 if chars[j]["text"] and chars[j + 1]["text"] \
                         and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
                         and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
                                                                        chars[j]["width"]) / 2:
                     chars[j]["text"] += " "
                 j += 1

-            self.__ocr(i + 1, img, chars, zoomin)
+            if limiter:
+                async with limiter:
+                    await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
+            else:
+                self.__ocr(i + 1, img, chars, zoomin, id)
             if callback and i % 6 == 5:
                 callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")

+        async def __img_ocr_launcher():
+            def __ocr_preprocess():
+                chars = self.page_chars[i] if not self.is_english else []
+                self.mean_height.append(
+                    np.median(sorted([c["height"] for c in chars])) if chars else 0
+                )
+                self.mean_width.append(
+                    np.median(sorted([c["width"] for c in chars])) if chars else 8
+                )
+                self.page_cum_height.append(img.size[1] / zoomin)
+                return chars
+
+            if self.parallel_limiter:
+                async with trio.open_nursery() as nursery:
+                    for i, img in enumerate(self.page_images):
+                        chars = __ocr_preprocess()
+
+                        nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
+                                           self.parallel_limiter[i % self.parallel_devices])
+                        await trio.sleep(0.1)
+            else:
+                for i, img in enumerate(self.page_images):
+                    chars = __ocr_preprocess()
+                    await __img_ocr(i, 0, img, chars, None)
+
+        start = timer()
+        trio.run(__img_ocr_launcher)
         logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

         if not self.is_english and not any(
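
The core of this change is the page scheduler: pages are assigned to devices round-robin (`i % parallel_devices`), each device is guarded by a `trio.CapacityLimiter(1)` so only one page is in flight per GPU, and the blocking `__ocr` call is pushed onto a worker thread. A self-contained sketch of that pattern, with a hypothetical `fake_ocr` standing in for the real OCR call:

```python
import trio

def fake_ocr(page_idx: int, device_id: int) -> str:
    # Stand-in for the blocking RAGFlowPdfParser.__ocr call (hypothetical).
    return f"page {page_idx} processed on device {device_id}"

async def run_pages(num_pages: int, num_devices: int) -> dict[int, str]:
    # One CapacityLimiter(1) per device: at most one page in flight per GPU.
    limiters = [trio.CapacityLimiter(1) for _ in range(num_devices)]
    results: dict[int, str] = {}

    async def worker(i: int, dev: int) -> None:
        async with limiters[dev]:
            # Push the blocking call onto a worker thread, as the diff does.
            results[i] = await trio.to_thread.run_sync(fake_ocr, i, dev)

    async with trio.open_nursery() as nursery:
        for i in range(num_pages):
            # Round-robin page -> device assignment, mirroring i % parallel_devices.
            nursery.start_soon(worker, i, i % num_devices)
    return results

if __name__ == "__main__":
    print(trio.run(run_pages, 8, 2))
```

Because each limiter admits one task at a time, throughput scales with the number of devices; pages queued for the same GPU simply wait their turn.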

deepdoc/vision/ocr.py (+55, -24)

     return ops


-def load_model(model_dir, nm):
+def load_model(model_dir, nm, device_id: int | None = None):
     model_file_path = os.path.join(model_dir, nm + ".onnx")
+    model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path

     global loaded_models
-    loaded_model = loaded_models.get(model_file_path)
+    loaded_model = loaded_models.get(model_cached_tag)
     if loaded_model:
         logging.info(f"load_model {model_file_path} reuses cached model")
         return loaded_model

     def cuda_is_available():
         try:
             import torch
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
                 return True
         except Exception:
             return False

     run_options = ort.RunOptions()
     if cuda_is_available():
         cuda_provider_options = {
-            "device_id": 0,  # Use specific GPU
+            "device_id": device_id,  # Use specific GPU
             "gpu_mem_limit": 512 * 1024 * 1024,  # Limit gpu memory
             "arena_extend_strategy": "kNextPowerOfTwo",  # gpu memory allocation strategy
         }

             providers=['CUDAExecutionProvider'],
             provider_options=[cuda_provider_options]
         )
-        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
+        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
         logging.info(f"load_model {model_file_path} uses GPU")
     else:
         sess = ort.InferenceSession(

         run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
         logging.info(f"load_model {model_file_path} uses CPU")

     loaded_model = (sess, run_options)
-    loaded_models[model_file_path] = loaded_model
+    loaded_models[model_cached_tag] = loaded_model
     return loaded_model


 class TextRecognizer:
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device_id: int | None = None):
         self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
         self.rec_batch_num = 16
         postprocess_params = {

             "use_space_char": True
         }
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.run_options = load_model(model_dir, 'rec')
+        self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
         self.input_tensor = self.predictor.get_inputs()[0]

     def resize_norm_img(self, img, max_wh_ratio):


 class TextDetector:
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device_id: int | None = None):
         pre_process_list = [{
             'DetResizeForTest': {
                 'limit_side_len': 960,

             "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}

         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.run_options = load_model(model_dir, 'det')
+        self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
         self.input_tensor = self.predictor.get_inputs()[0]

         img_h, img_w = self.input_tensor.shape[2:]


 class OCR:
-    def __init__(self, model_dir=None):
+    def __init__(self, model_dir=None, parallel_devices: int | None = None):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!

                 model_dir = os.path.join(
                     get_project_base_directory(),
                     "rag/res/deepdoc")
-                self.text_detector = TextDetector(model_dir)
-                self.text_recognizer = TextRecognizer(model_dir)
+                # Append one detector/recognizer pair per GPU to the lists
+                if parallel_devices is not None and parallel_devices > 0:
+                    self.text_detector = []
+                    self.text_recognizer = []
+                    for device_id in range(parallel_devices):
+                        self.text_detector.append(TextDetector(model_dir, device_id))
+                        self.text_recognizer.append(TextRecognizer(model_dir, device_id))
+                else:
+                    self.text_detector = [TextDetector(model_dir, 0)]
+                    self.text_recognizer = [TextRecognizer(model_dir, 0)]

             except Exception:
                 model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                               local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                               local_dir_use_symlinks=False)
-                self.text_detector = TextDetector(model_dir)
-                self.text_recognizer = TextRecognizer(model_dir)
+                if parallel_devices is not None:
+                    assert parallel_devices > 0, "Number of devices must be >= 1"
+                    self.text_detector = []
+                    self.text_recognizer = []
+                    for device_id in range(parallel_devices):
+                        self.text_detector.append(TextDetector(model_dir, device_id))
+                        self.text_recognizer.append(TextRecognizer(model_dir, device_id))
+                else:
+                    self.text_detector = [TextDetector(model_dir, 0)]
+                    self.text_recognizer = [TextRecognizer(model_dir, 0)]

         self.drop_score = 0.5
         self.crop_image_res_index = 0

                 break
         return _boxes

-    def detect(self, img):
+    def detect(self, img, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+
         time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

         if img is None:
             return None, None, time_dict

         start = time.time()
-        dt_boxes, elapse = self.text_detector(img)
+        dt_boxes, elapse = self.text_detector[device_id](img)
         time_dict['det'] = elapse

         if dt_boxes is None:

         return zip(self.sorted_boxes(dt_boxes), [
             ("", 0) for _ in range(len(dt_boxes))])

-    def recognize(self, ori_im, box):
+    def recognize(self, ori_im, box, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+
         img_crop = self.get_rotate_crop_image(ori_im, box)

-        rec_res, elapse = self.text_recognizer([img_crop])
+        rec_res, elapse = self.text_recognizer[device_id]([img_crop])
         text, score = rec_res[0]
         if score < self.drop_score:
             return ""
         return text

-    def recognize_batch(self, img_list):
-        rec_res, elapse = self.text_recognizer(img_list)
+    def recognize_batch(self, img_list, device_id: int | None = None):
+        if device_id is None:
+            device_id = 0
+        rec_res, elapse = self.text_recognizer[device_id](img_list)
         texts = []
         for i in range(len(rec_res)):
             text, score = rec_res[i]

             texts.append(text)
         return texts

-    def __call__(self, img, cls=True):
+    def __call__(self, img, device_id=0, cls=True):
         time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+        if device_id is None:
+            device_id = 0

         if img is None:
             return None, None, time_dict

         start = time.time()
         ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
+        dt_boxes, elapse = self.text_detector[device_id](img)
         time_dict['det'] = elapse

         if dt_boxes is None:

             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)

-        rec_res, elapse = self.text_recognizer(img_crop_list)
+        rec_res, elapse = self.text_recognizer[device_id](img_crop_list)

         time_dict['rec'] = elapse
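
The key change in `load_model` is that cached sessions are now keyed by model path plus device id, and the CUDA provider is pinned to that device, so each `TextDetector`/`TextRecognizer` in the per-device lists gets its own ONNX Runtime session. A stripped-down sketch of that caching idea; `load_det_session` is a hypothetical helper, the session arguments are abbreviated, and the paths are illustrative:

```python
import os
import onnxruntime as ort

_loaded_models: dict[str, tuple[ort.InferenceSession, ort.RunOptions]] = {}

def load_det_session(model_dir: str, device_id: int | None = None):
    """Hypothetical helper mirroring load_model's cache-key scheme."""
    model_path = os.path.join(model_dir, "det.onnx")
    # One cache entry per (model, device), so two GPUs never share a session.
    cache_key = model_path + str(device_id) if device_id is not None else model_path
    if cache_key in _loaded_models:
        return _loaded_models[cache_key]

    run_options = ort.RunOptions()
    sess = ort.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider"],
        provider_options=[{"device_id": device_id or 0}],  # pin the session to one GPU
    )
    _loaded_models[cache_key] = (sess, run_options)
    return sess, run_options
```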



deepdoc/vision/t_ocr.py (+37, -3)

 from deepdoc.vision import OCR, init_in_out
 import argparse
 import numpy as np
+import trio

+# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'  # 2 GPUs, non-contiguous ids
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'      # 1 GPU
+# os.environ['CUDA_VISIBLE_DEVICES'] = ''     # CPU only


 def main(args):
-    ocr = OCR()
+    import torch.cuda
+
+    cuda_devices = torch.cuda.device_count()
+    limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
+    ocr = OCR(parallel_devices=cuda_devices)
     images, outputs = init_in_out(args)

-    for i, img in enumerate(images):
-        bxs = ocr(np.array(img))
+    def __ocr(i, id, img):
+        print("Task {} start".format(i))
+        bxs = ocr(np.array(img), id)
         bxs = [(line[0], line[1][0]) for line in bxs]
         bxs = [{
             "text": t,

         with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
             f.write("\n".join([o["text"] for o in bxs]))

+        print("Task {} done".format(i))
+
+    async def __ocr_thread(i, id, img, limiter=None):
+        if limiter:
+            async with limiter:
+                print("Task {} use device {}".format(i, id))
+                await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
+        else:
+            __ocr(i, id, img)
+
+    async def __ocr_launcher():
+        if cuda_devices > 1:
+            async with trio.open_nursery() as nursery:
+                for i, img in enumerate(images):
+                    nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
+                    await trio.sleep(0.1)
+        else:
+            for i, img in enumerate(images):
+                await __ocr_thread(i, 0, img)
+
+    trio.run(__ocr_launcher)
+
+    print("OCR tasks are all done")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
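
Since the benchmark selects devices by index 0..N-1, `CUDA_VISIBLE_DEVICES` is what maps those indices onto physical GPUs, so a non-contiguous set of cards still shows up as contiguous ids. A small illustrative snippet; the device ids and the resulting count are assumptions about the host machine:

```python
import os

# Expose physical GPUs 0 and 2 only; the process sees them as cuda:0 and cuda:1.
# Must be set before torch / onnxruntime initialize CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"

import torch.cuda
from deepdoc.vision import OCR

print(torch.cuda.device_count())                        # -> 2 on such a machine
ocr = OCR(parallel_devices=torch.cuda.device_count())   # one model set per visible GPU
```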

rag/app/naive.py (+5, -2)

 class Pdf(PdfParser):
+    def __init__(self, parallel_devices=None):
+        super().__init__(parallel_devices)
+
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         start = timer()


 def chunk(filename, binary=None, from_page=0, to_page=100000,
-          lang="Chinese", callback=None, **kwargs):
+          lang="Chinese", parallel_devices=None, callback=None, **kwargs):
     """
     Supported file formats are docx, pdf, excel, txt.
     This method apply the naive ways to chunk files.

         return res

     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
+        pdf_parser = Pdf(parallel_devices)
         if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
             pdf_parser = PlainParser()
         sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
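
End to end, `parallel_devices` now flows chunk → Pdf → RAGFlowPdfParser → OCR. A hedged sketch of calling the naive chunker with the new argument; the file name, language, and callback are placeholders, and the remaining keyword arguments are assumed to keep their existing defaults:

```python
from rag.app import naive

with open("manual.pdf", "rb") as f:              # placeholder document
    chunks = naive.chunk(
        "manual.pdf",
        binary=f.read(),
        lang="English",
        parallel_devices=2,                      # e.g. two visible GPUs; None/0 keeps the CPU path
        callback=lambda prog=None, msg="": None,
    )
```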

rag/svr/task_executor.py (+9, -1)

 task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
 chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)

+PARALLEL_DEVICES = None
+try:
+    import torch.cuda
+    PARALLEL_DEVICES = torch.cuda.device_count()
+    logging.info(f"found {PARALLEL_DEVICES} gpus")
+except Exception:
+    logging.info("can't import package 'torch'")
+
 # SIGUSR1 handler: start tracemalloc and take snapshot
 def start_tracemalloc_and_snapshot(signum, frame):
     if not tracemalloc.is_tracing():

     try:
         async with chunk_limiter:
             cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
-                                                to_page=task["to_page"], lang=task["language"], callback=progress_callback,
+                                                to_page=task["to_page"], lang=task["language"], parallel_devices=PARALLEL_DEVICES, callback=progress_callback,
                                                 kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
         logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
     except TaskCanceledException:
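
The GPU count is detected once at executor start-up. A small self-contained restatement of that guard; the helper name is hypothetical, task_executor.py inlines the logic at module level:

```python
def detect_parallel_devices() -> int | None:
    """Return the number of visible CUDA devices: 0 on CPU-only hosts, None if torch is absent."""
    try:
        import torch.cuda
        return torch.cuda.device_count()
    except Exception:
        return None
```

Downstream, OCR falls back to a single detector/recognizer pair when this value is None or 0, so CPU-only deployments keep the previous behaviour.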
