瀏覽代碼

Set PARALLEL_DEVICES default value to 0 (#7935)

### What problem does this PR solve?


It would fail if PARALLEL_DEVICES is None in the OCR class, because it
passes 0 to the TextDetector and TextRecognizer init methods.

It would also be simpler to set 0 as the default value for
PARALLEL_DEVICES.

### Type of change

- [x] Refactoring
tags/v0.19.1
giiiiiithub 5 月之前
父節點
當前提交
6ba5a4348a
No account linked to committer's email address
共有 3 個文件被更改,包括 8 次插入9 次删除
  1. 1
    1
      deepdoc/parser/pdf_parser.py
  2. 6
    7
      deepdoc/vision/ocr.py
  3. 1
    1
      rag/settings.py

+ 1
- 1
deepdoc/parser/pdf_parser.py 查看文件



self.ocr = OCR() self.ocr = OCR()
self.parallel_limiter = None self.parallel_limiter = None
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
if PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]


if hasattr(self, "model_speciess"): if hasattr(self, "model_speciess"):

+ 6
- 7
deepdoc/vision/ocr.py 查看文件

"rag/res/deepdoc") "rag/res/deepdoc")
# Append muti-gpus task to the list # Append muti-gpus task to the list
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0:
if PARALLEL_DEVICES > 0:
self.text_detector = [] self.text_detector = []
self.text_recognizer = [] self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES): for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id)) self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else: else:
self.text_detector = [TextDetector(model_dir, 0)]
self.text_recognizer = [TextRecognizer(model_dir, 0)]
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]


except Exception: except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False) local_dir_use_symlinks=False)
if PARALLEL_DEVICES is not None:
assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1"
if PARALLEL_DEVICES > 0:
self.text_detector = [] self.text_detector = []
self.text_recognizer = [] self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES): for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id)) self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else: else:
self.text_detector = [TextDetector(model_dir, 0)]
self.text_recognizer = [TextRecognizer(model_dir, 0)]
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]


self.drop_score = 0.5 self.drop_score = 0.5
self.crop_image_res_index = 0 self.crop_image_res_index = 0

+ 1
- 1
rag/settings.py 查看文件

PAGERANK_FLD = "pagerank_fea" PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas" TAG_FLD = "tag_feas"


PARALLEL_DEVICES = None
PARALLEL_DEVICES = 0
try: try:
import torch.cuda import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count() PARALLEL_DEVICES = torch.cuda.device_count()

Loading…
取消
儲存