
change task execution logic (#103)

KevinHuSh · 1 year ago · commit b89ac3c4be · tags/v0.1.0

8 files changed, with 25 additions and 16 deletions

1. deepdoc/vision/layout_recognizer.py (+4, -1)
2. docker/entrypoint.sh (+9, -5)
3. rag/nlp/__init__.py (+0, -1)
4. rag/nlp/query.py (+2, -2)
5. rag/nlp/search.py (+3, -3)
6. rag/nlp/term_weight.py (+1, -0)
7. rag/settings.py (+2, -0)
8. rag/svr/task_executor.py (+4, -4)

deepdoc/vision/layout_recognizer.py (+4, -1)

from collections import Counter
from copy import deepcopy
import numpy as np
+ from api.db import ParserType
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer
]
def __init__(self, domain):
super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+ self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b):
i += 1
continue
lts_[ii]["visited"] = True
- if lts_[ii]["type"] in ["footer", "header", "reference"]:
+ if lts_[ii]["type"] in self.garbage_layouts:
if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
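
This hunk lifts the hard-coded list of discardable layout types into a `self.garbage_layouts` attribute set in `__init__`, so the check inside `__call__` reads from a single shared definition. A minimal, self-contained sketch of that filtering pattern, assuming a simplified box structure (the class below is an illustrative stand-in, not the real `LayoutRecognizer`):

```python
# Hypothetical, simplified illustration of the garbage-layout filtering idea.
class LayoutFilter:
    def __init__(self):
        # Layout types whose text is dropped from the main content.
        self.garbage_layouts = ["footer", "header", "reference"]

    def collect_garbage(self, boxes):
        """Split boxes into kept content and garbage text grouped by layout type."""
        garbages, kept = {}, []
        for b in boxes:
            if b["type"] in self.garbage_layouts:
                garbages.setdefault(b["type"], []).append(b["text"])
            else:
                kept.append(b)
        return kept, garbages


boxes = [
    {"type": "text", "text": "Main paragraph."},
    {"type": "footer", "text": "Page 3 of 10"},
    {"type": "reference", "text": "[1] Some citation"},
]
kept, garbages = LayoutFilter().collect_garbage(boxes)
print(kept)      # only the "text" box survives
print(garbages)  # {'footer': ['Page 3 of 10'], 'reference': ['[1] Some citation']}
```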

docker/entrypoint.sh (+9, -5)

PY=/root/miniconda3/envs/py11/bin/python

function task_exe(){
- sleep 60;
- while [ 1 -eq 1 ];do mpirun -n 4 --allow-run-as-root $PY rag/svr/task_executor.py ; done
+ while [ 1 -eq 1 ];do
+ $PY rag/svr/task_executor.py $1 $2;
+ done
}

function watch_broker(){
}

task_bro &
- task_exe &
+ WS=8
+ for ((i=0;i<WS;i++))
+ do
+ task_exe $i $WS &
+ done

$PY api/ragflow_server.py
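
The executor is no longer launched under `mpirun`; instead the script starts `WS=8` independent worker processes, passing each one its worker index (`$i`) and the total worker count (`$WS`), and restarts a worker whenever it exits. A rough Python equivalent of that launch-and-respawn pattern (the supervisor below is a sketch for illustration only; the paths and worker count are just the values from the script above):

```python
# Sketch of the multi-worker launch pattern in entrypoint.sh, assuming
# task_executor.py accepts "<worker_index> <worker_count>" on the command line.
import subprocess
import threading

PY = "/root/miniconda3/envs/py11/bin/python"
WS = 8  # number of task-executor workers, as in the shell script

def run_worker(index: int, total: int) -> None:
    # Restart the worker whenever it exits, mirroring the bash `while` loop.
    while True:
        subprocess.run([PY, "rag/svr/task_executor.py", str(index), str(total)])

if __name__ == "__main__":
    for i in range(WS):
        threading.Thread(target=run_worker, args=(i, WS), daemon=True).start()
    # The real script then runs the API server in the foreground.
    subprocess.run([PY, "api/ragflow_server.py"])
```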



rag/nlp/__init__.py (+0, -1)

d["page_num_int"].append(pn + 1) d["page_num_int"].append(pn + 1)
d["top_int"].append(top) d["top_int"].append(top)
d["position_int"].append((pn + 1, left, right, top, bottom)) d["position_int"].append((pn + 1, left, right, top, bottom))
d["top_int"] = d["top_int"][:1]
def remove_contents_table(sections, eng=False): def remove_contents_table(sections, eng=False):
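
The deleted line used to truncate `d["top_int"]` to its first element; without it, every appended position keeps its own top coordinate, so `top_int` stays aligned one-to-one with `page_num_int` and `position_int`. A toy illustration of the difference (field names come from the hunk above; the data is invented):

```python
# Toy positions for one chunk: (page, left, right, top, bottom).
positions = [(1, 10, 90, 100, 120), (1, 10, 90, 300, 320), (2, 10, 90, 50, 70)]

d = {"page_num_int": [], "top_int": [], "position_int": []}
for pn, left, right, top, bottom in positions:
    d["page_num_int"].append(pn)
    d["top_int"].append(top)
    d["position_int"].append((pn, left, right, top, bottom))

# Old behaviour: d["top_int"] = d["top_int"][:1]  ->  [100]
# New behaviour: the list is left intact          ->  [100, 300, 50]
print(d["top_int"])
```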

rag/nlp/query.py (+2, -2)

s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
- s += v * dtwt[k]
+ s += v# * dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v * v
d = 1e-9
for k, v in dtwt.items():
d += v * v
- return s / math.sqrt(q) / math.sqrt(d)
+ return s / q#math.sqrt(q) / math.sqrt(d)
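
The scoring change replaces a cosine similarity between the query and document term-weight maps with a coverage-style score: the new code sums the query-side weights of terms that also occur in the document and divides by the sum of squared query weights, ignoring the document-side weights entirely. A side-by-side sketch of the two formulas, with hypothetical helper names (`qtwt` and `dtwt` are term-to-weight dicts as in the hunk above):

```python
import math

def old_similarity(qtwt: dict, dtwt: dict) -> float:
    # Cosine similarity of the two term-weight vectors (pre-commit behaviour).
    s = 1e-9
    for k, v in qtwt.items():
        if k in dtwt:
            s += v * dtwt[k]
    q = sum(v * v for v in qtwt.values()) + 1e-9
    d = sum(v * v for v in dtwt.values()) + 1e-9
    return s / math.sqrt(q) / math.sqrt(d)

def new_similarity(qtwt: dict, dtwt: dict) -> float:
    # Sum of matched query weights, normalized by the query's squared weight.
    s = 1e-9
    for k, v in qtwt.items():
        if k in dtwt:
            s += v
    q = sum(v * v for v in qtwt.values()) + 1e-9
    return s / q

qtwt = {"rag": 0.6, "pipeline": 0.4}
dtwt = {"rag": 0.8, "search": 0.2}
print(old_similarity(qtwt, dtwt), new_similarity(qtwt, dtwt))
```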

rag/nlp/search.py (+3, -3)

return [float(t) for t in txt.split("\t")]

def insert_citations(self, answer, chunks, chunk_v,
- embd_mdl, tkweight=0.3, vtweight=0.7):
+ embd_mdl, tkweight=0.7, vtweight=0.3):
assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)):
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
- if mx < 0.55:
+ if mx < 0.35:
continue
cites[idx[i]] = list(
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
if i not in cites:
continue
for c in cites[i]: assert int(c) < len(chunk_v)
- res += "##%s$$" % "$".join(cites[i])
+ for c in cites[i]: res += f" ##{c}$$"

return res
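
Three things change in `insert_citations`: the default blend now favours token similarity over vector similarity (`tkweight=0.7`, `vtweight=0.3`), the citation threshold drops from 0.55 to 0.35, and every cited chunk gets its own ` ##<id>$$` marker instead of one combined `##id$id$$` marker. A small sketch of that selection and formatting logic, with hypothetical stand-ins (`token_sim`, `vector_sim`) for the similarity arrays the real code computes:

```python
import numpy as np

def pick_citations(token_sim, vector_sim, tkweight=0.7, vtweight=0.3, floor=0.35):
    # Blend token and vector similarity, then keep chunks close to the maximum.
    sim = tkweight * np.asarray(token_sim) + vtweight * np.asarray(vector_sim)
    mx = np.max(sim) * 0.99
    if mx < floor:  # the floor was 0.55 before this commit
        return []
    return [str(i) for i in range(len(sim)) if sim[i] > mx][:4]

def format_markers(cites):
    # New marker style: one " ##<chunk_id>$$" per cited chunk.
    return "".join(f" ##{c}$$" for c in cites)

print(format_markers(pick_citations([0.6, 0.2, 0.58], [0.4, 0.1, 0.5])))  # " ##2$$"
```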



rag/nlp/term_weight.py (+1, -0)

def ner(t):
if not self.ne or t not in self.ne:
return 1
+ if re.match(r"[0-9,.]+$", t): return 2
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1}
return m[self.ne[t]]
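
The added line gives purely numeric tokens (digits, commas, dots) a fixed weight of 2 before the named-entity table is consulted, so a number is weighted consistently even when it happens to appear in the NE dictionary. A stripped-down version of the lookup, with a toy map standing in for `self.ne`:

```python
import re

NE = {"3,000": "stock", "acme": "corp"}  # toy stand-in for self.ne
WEIGHTS = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, "firstnm": 1}

def ner(t: str) -> int:
    if not NE or t not in NE:
        return 1
    if re.match(r"[0-9,.]+$", t):
        return 2  # numeric tokens get a fixed boost (the new line)
    return WEIGHTS[NE[t]]

print(ner("hello"), ner("3,000"), ner("acme"))  # 1 2 3
```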

rag/settings.py (+2, -0)

minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger")
chunk_logger = getLogger("chunk_logger")
+ database_logger = getLogger("database")

rag/svr/task_executor.py (+4, -4)

import sys
import traceback
from functools import partial
+ from timeit import default_timer as timer
+ from rag.settings import database_logger
+ from rag.settings import cron_logger, DOC_MAXIMUM_SIZE

import numpy as np
from elasticsearch_dsl import Q

from api.db.services.task_service import TaskService
- from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
- from api.settings import database_logger
from api.utils.file_utils import get_project_base_directory

BATCH_SIZE = 64
from mpi4py import MPI

comm = MPI.COMM_WORLD
- main(comm.Get_size(), comm.Get_rank())
+ main(int(sys.argv[2]), int(sys.argv[1]))
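
The executor now takes its identity from the command line instead of from MPI: `main()` is called with `int(sys.argv[2])` (the worker count, `$WS` from entrypoint.sh) and `int(sys.argv[1])` (the worker index, `$i`), in the same order as the old `main(comm.Get_size(), comm.Get_rank())`. A sketch of how such a (count, index) pair can be used to split pending work; `claim_tasks` is hypothetical and only illustrates the partitioning idea, not the real queue logic:

```python
import sys

def claim_tasks(all_task_ids, worker_count: int, worker_index: int):
    # Modulo partition: worker i handles every i-th task out of worker_count.
    return [t for n, t in enumerate(all_task_ids) if n % worker_count == worker_index]

def main(worker_count: int, worker_index: int) -> None:
    # Previously the pair came from MPI: main(comm.Get_size(), comm.Get_rank()).
    tasks = claim_tasks(list(range(20)), worker_count, worker_index)
    print(f"worker {worker_index}/{worker_count} handles {tasks}")

if __name__ == "__main__":
    # entrypoint.sh invokes: task_executor.py <worker_index> <worker_count>
    main(int(sys.argv[2]), int(sys.argv[1]))
```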
