 from collections import Counter
 from copy import deepcopy
 import numpy as np
+from api.db import ParserType
 from api.utils.file_utils import get_project_base_directory
 from deepdoc.vision import Recognizer
 ]
 def __init__(self, domain):
     super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+    self.garbage_layouts = ["footer", "header", "reference"]
 def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
 def __is_garbage(b):
     i += 1
     continue
     lts_[ii]["visited"] = True
-    if lts_[ii]["type"] in ["footer", "header", "reference"]:
+    if lts_[ii]["type"] in self.garbage_layouts:
         if lts_[ii]["type"] not in garbages:
             garbages[lts_[ii]["type"]] = []
         garbages[lts_[ii]["type"]].append(bxs[i]["text"])
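Downstream, the collected garbages map lends itself to frequency-based cleanup: footer/header text that recurs across pages is boilerplate. A minimal sketch of that idea, assuming hypothetical names throughout (only Counter comes from the imports above):

from collections import Counter

# Hedged sketch: drop boxes whose header/footer/reference text shows up
# repeatedly across pages; `bxs` and `garbages` mirror the shapes used
# above, and `min_repeats` is an assumed threshold.
def drop_repeated_garbage(bxs, garbages, min_repeats=2):
    repeated = set()
    for texts in garbages.values():
        for text, cnt in Counter(texts).items():
            if cnt >= min_repeats:
                repeated.add(text)
    return [b for b in bxs if b["text"] not in repeated]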
 PY=/root/miniconda3/envs/py11/bin/python
 function task_exe(){
-    sleep 60;
-    while [ 1 -eq 1 ];do mpirun -n 4 --allow-run-as-root $PY rag/svr/task_executor.py ; done
+    while [ 1 -eq 1 ];do
+        $PY rag/svr/task_executor.py $1 $2;
+    done
 }
 function watch_broker(){
 }
 task_bro &
-task_exe &
+WS=8
+for ((i=0;i<WS;i++))
+do
+    task_exe $i $WS &
+done
 $PY api/ragflow_server.py
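The rewrite drops the single mpirun-managed pool (and its 60-second startup sleep) in favor of WS shell-supervised workers, each restarted by its own while-loop when it exits. Every worker receives its index $i and the pool size $WS. One plausible way for the executor to turn that pair into non-overlapping work, with the modulo scheme and names being assumptions rather than the repo's confirmed logic:

# Hypothetical sketch: worker `rank` of `total` claims only task ids in
# its own residue class, so the WS processes never grab the same task.
def my_tasks(task_ids, rank: int, total: int):
    return [t for t in task_ids if t % total == rank]

# e.g. with WS=8, worker 3 takes ids 3, 11, 19, ...
assert my_tasks(range(20), 3, 8) == [3, 11, 19]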
| d["page_num_int"].append(pn + 1) | d["page_num_int"].append(pn + 1) | ||||
| d["top_int"].append(top) | d["top_int"].append(top) | ||||
| d["position_int"].append((pn + 1, left, right, top, bottom)) | d["position_int"].append((pn + 1, left, right, top, bottom)) | ||||
| d["top_int"] = d["top_int"][:1] | |||||
| def remove_contents_table(sections, eng=False): | def remove_contents_table(sections, eng=False): |
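Each position_int entry is a 5-tuple (page, left, right, top, bottom), and trimming top_int to its first element keeps only where the chunk first appears, which is enough for ordering. A tiny illustration with made-up coordinates:

d = {"page_num_int": [], "top_int": [], "position_int": []}
# A chunk spanning two boxes on pages 1 and 2 (coordinates invented):
for pn, (left, right, top, bottom) in [(0, (10, 90, 12, 20)),
                                       (1, (10, 90, 300, 310))]:
    d["page_num_int"].append(pn + 1)
    d["top_int"].append(top)
    d["position_int"].append((pn + 1, left, right, top, bottom))
d["top_int"] = d["top_int"][:1]
assert d["top_int"] == [12]   # only the first box's top survives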
 s = 1e-9
 for k, v in qtwt.items():
     if k in dtwt:
-        s += v * dtwt[k]
+        s += v  # * dtwt[k]
 q = 1e-9
 for k, v in qtwt.items():
     q += v * v
 d = 1e-9
 for k, v in dtwt.items():
     d += v * v
-return s / math.sqrt(q) / math.sqrt(d)
+return s / q  # math.sqrt(q) / math.sqrt(d)
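The commented-out factors change the metric: instead of the cosine between the two term-weight vectors, the score is now the weight of overlapping query terms normalized only on the query side. A self-contained sketch of the post-change behavior, with the function name assumed:

def token_similarity(qtwt: dict, dtwt: dict) -> float:
    # Weight of query terms that also appear in the document
    # (previously the dot product v * dtwt[k]).
    s = 1e-9 + sum(v for k, v in qtwt.items() if k in dtwt)
    # Previously normalized by sqrt(q) * sqrt(d) (cosine); now only by
    # the sum of squared query weights.
    q = 1e-9 + sum(v * v for v in qtwt.values())
    return s / q

# Full overlap with these weights scores about 1.0:
print(token_similarity({"a": 1.0, "b": 1.0}, {"a": 0.5, "b": 2.0}))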
 return [float(t) for t in txt.split("\t")]
 def insert_citations(self, answer, chunks, chunk_v,
-                     embd_mdl, tkweight=0.3, vtweight=0.7):
+                     embd_mdl, tkweight=0.7, vtweight=0.3):
     assert len(chunks) == len(chunk_v)
     pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
     for i in range(1, len(pieces)):
         chunks_tks,
         tkweight, vtweight)
     mx = np.max(sim) * 0.99
-    if mx < 0.55:
+    if mx < 0.35:
         continue
     cites[idx[i]] = list(
         set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
     if i not in cites:
         continue
     for c in cites[i]: assert int(c) < len(chunk_v)
-    res += "##%s$$" % "$".join(cites[i])
+    for c in cites[i]: res += f" ##{c}$$"
     return res
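The marker change affects how citations render: one fused marker covering several chunk ids becomes a separate marker per cited chunk, which is easier to parse downstream. A quick before/after illustration:

cites_i = ["0", "3"]
old = "##%s$$" % "$".join(cites_i)            # '##0$3$$'  (fused)
new = "".join(f" ##{c}$$" for c in cites_i)   # ' ##0$$ ##3$$' (one per chunk)
assert old == "##0$3$$" and new == " ##0$$ ##3$$"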
 def ner(t):
     if not self.ne or t not in self.ne:
         return 1
+    if re.match(r"[0-9,.]+$", t): return 2
     m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
          "firstnm": 1}
     return m[self.ne[t]]
 minio_logger = getLogger("minio")
 cron_logger = getLogger("cron_logger")
 chunk_logger = getLogger("chunk_logger")
+database_logger = getLogger("database")

 import sys
 import traceback
 from functools import partial
+from timeit import default_timer as timer
+from rag.settings import database_logger
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 import numpy as np
 from elasticsearch_dsl import Q
 from api.db.services.task_service import TaskService
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from rag.utils import ELASTICSEARCH
 from rag.utils import MINIO
 from rag.utils import rmSpace, findMaxTm
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.settings import database_logger
 from api.utils.file_utils import get_project_base_directory

 BATCH_SIZE = 64

 from mpi4py import MPI
 comm = MPI.COMM_WORLD
-main(comm.Get_size(), comm.Get_rank())
+main(int(sys.argv[2]), int(sys.argv[1]))
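With MPI no longer driving process management, the consumer id and pool size arrive via argv, matching task_exe $i $WS in the entrypoint: $1 is the worker's index and $2 the worker count, hence the swapped order in main(int(sys.argv[2]), int(sys.argv[1])). A minimal sketch of the calling convention, with the body of main assumed:

import sys

def main(total: int, rank: int) -> None:
    # Hypothetical body: each worker polls only its own task shard.
    print(f"task executor {rank}/{total} started")

if __name__ == "__main__":
    # entrypoint.sh runs: task_executor.py <index> <pool size>
    main(int(sys.argv[2]), int(sys.argv[1]))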