| @@ -18,10 +18,12 @@ import datetime | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from elasticsearch_dsl import Q | |||
| from rag.app.qa import rmPrefix, beAdoc | |||
| from rag.nlp import search, huqie, retrievaler | |||
| from rag.utils import ELASTICSEARCH, rmSpace | |||
| from api.db import LLMType | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| @@ -89,10 +91,8 @@ def get(): | |||
| res["chunk_id"] = id | |||
| k = [] | |||
| for n in res.keys(): | |||
| if re.search(r"(_vec$|_sm_)", n): | |||
| if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): | |||
| k.append(n) | |||
| if re.search(r"(_tks|_ltks)", n): | |||
| res[n] = rmSpace(res[n]) | |||
| for n in k: | |||
| del res[n] | |||
| @@ -106,12 +106,12 @@ def get(): | |||
| @manager.route('/set', methods=['POST']) | |||
| @login_required | |||
| @validate_request("doc_id", "chunk_id", "content_ltks", | |||
| @validate_request("doc_id", "chunk_id", "content_with_weight", | |||
| "important_kwd") | |||
| def set(): | |||
| req = request.json | |||
| d = {"id": req["chunk_id"]} | |||
| d["content_ltks"] = huqie.qie(req["content_ltks"]) | |||
| d["content_ltks"] = huqie.qie(req["content_with_weight"]) | |||
| d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) | |||
| d["important_kwd"] = req["important_kwd"] | |||
| d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) | |||
| @@ -127,8 +127,15 @@ def set(): | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| if doc.parser_id == ParserType.QA: | |||
| arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t)>1] | |||
| if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.") | |||
| q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) | |||
| d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a])) | |||
| v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) | |||
| v = 0.1 * v[0] + 0.9 * v[1] | |||
| v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | |||
| d["q_%d_vec" % len(v)] = v.tolist() | |||
| ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | |||
| return get_json_result(data=True) | |||
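Note on the /set hunk above: the chunk vector is a weighted blend of the document-title embedding and the chunk-content embedding (0.1/0.9), except for Q&A chunks, where the title carries no signal and only the content vector is kept. A minimal sketch of that fusion rule with toy numpy vectors; fuse_chunk_vector is an illustrative helper, not project code:

    import numpy as np

    def fuse_chunk_vector(title_vec, content_vec, is_qa=False):
        # Blend title and content vectors the way the /set endpoint does.
        title_vec = np.asarray(title_vec, dtype=float)
        content_vec = np.asarray(content_vec, dtype=float)
        if is_qa:                       # ParserType.QA: keep the content vector only
            return content_vec
        return 0.1 * title_vec + 0.9 * content_vec

    v = fuse_chunk_vector(np.ones(4), np.zeros(4))
    field = "q_%d_vec" % len(v)         # vector length is encoded in the ES field name, e.g. "q_4_vec"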
| @@ -18,7 +18,7 @@ from flask import request | |||
| from flask_login import login_required, current_user | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db import StatusEnum | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.user_service import TenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| @@ -27,10 +27,10 @@ from api.db.services.task_service import TaskService | |||
| from rag.nlp import search | |||
| from rag.utils import ELASTICSEARCH | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.db import FileType | |||
| from api.db import FileType, TaskStatus | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| @@ -210,13 +210,12 @@ def rm(): | |||
| @manager.route('/run', methods=['POST']) | |||
| @login_required | |||
| @validate_request("doc_ids", "run") | |||
| def rm(): | |||
| def run(): | |||
| req = request.json | |||
| try: | |||
| for id in req["doc_ids"]: | |||
| DocumentService.update_by_id(id, {"run": str(req["run"])}) | |||
| if req["run"] == "2": | |||
| TaskService.filter_delete([Task.doc_id == id]) | |||
| DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0}) | |||
| if str(req["run"]) == TaskStatus.CANCEL.value: | |||
| tenant_id = DocumentService.get_tenant_id(id) | |||
| if not tenant_id: | |||
| return get_data_error_result(retmsg="Tenant not found!") | |||
| @@ -284,12 +283,13 @@ def change_parser(): | |||
| if doc.parser_id.lower() == req["parser_id"].lower(): | |||
| return get_json_result(data=True) | |||
| e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": "", "run": 1}) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) | |||
| e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| if doc.token_num>0: | |||
| e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| @@ -21,7 +21,7 @@ from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.db import StatusEnum, UserTenantRole | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.db_models import Knowledgebase | |||
| from api.settings import stat_logger, RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| @@ -22,7 +22,7 @@ from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.db import StatusEnum, UserTenantRole | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.db_models import Knowledgebase, TenantLLM | |||
| from api.settings import stat_logger, RetCode | |||
| from api.utils.api_utils import get_json_result | |||
| @@ -61,12 +61,19 @@ class ChatStyle(StrEnum): | |||
| CUSTOM = 'Custom' | |||
| class TaskStatus(StrEnum): | |||
| RUNNING = "1" | |||
| CANCEL = "2" | |||
| DONE = "3" | |||
| FAIL = "4" | |||
| class ParserType(StrEnum): | |||
| GENERAL = "general" | |||
| PRESENTATION = "presentation" | |||
| LAWS = "laws" | |||
| MANUAL = "manual" | |||
| PAPER = "paper" | |||
| RESUME = "" | |||
| BOOK = "" | |||
| QA = "" | |||
| RESUME = "resume" | |||
| BOOK = "book" | |||
| QA = "qa" | |||
| @@ -33,8 +33,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): | |||
| DB.create_tables([model]) | |||
| for data in data_source: | |||
| current_time = current_timestamp() | |||
| for i,data in enumerate(data_source): | |||
| current_time = current_timestamp() + i | |||
| current_date = timestamp_to_date(current_time) | |||
| if 'create_time' not in data: | |||
| data['create_time'] = current_time | |||
| @@ -15,11 +15,11 @@ | |||
| # | |||
| from peewee import Expression | |||
| from api.db import TenantPermission, FileType | |||
| from api.db import TenantPermission, FileType, TaskStatus | |||
| from api.db.db_models import DB, Knowledgebase, Tenant | |||
| from api.db.db_models import Document | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.kb_service import KnowledgebaseService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db import StatusEnum | |||
| @@ -71,6 +71,7 @@ class DocumentService(CommonService): | |||
| ~(cls.model.type == FileType.VIRTUAL.value), | |||
| cls.model.progress == 0, | |||
| cls.model.update_time >= tm, | |||
| cls.model.run == TaskStatus.RUNNING.value, | |||
| (Expression(cls.model.create_time, "%%", comm) == mod))\ | |||
| .order_by(cls.model.update_time.asc())\ | |||
| .paginate(1, items_per_page) | |||
| @@ -13,13 +13,52 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from api.db.db_models import Knowledgebase, Document | |||
| from api.db import StatusEnum, TenantPermission | |||
| from api.db.db_models import Knowledgebase, DB, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| class KnowledgebaseService(CommonService): | |||
| model = Knowledgebase | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, | |||
| page_number, items_per_page, orderby, desc): | |||
| kbs = cls.model.select().where( | |||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == | |||
| TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) | |||
| & (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| if desc: | |||
| kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) | |||
| else: | |||
| kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) | |||
| class DocumentService(CommonService): | |||
| model = Document | |||
| kbs = kbs.paginate(page_number, items_per_page) | |||
| return list(kbs.dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_detail(cls, kb_id): | |||
| fields = [ | |||
| cls.model.id, | |||
| Tenant.embd_id, | |||
| cls.model.avatar, | |||
| cls.model.name, | |||
| cls.model.description, | |||
| cls.model.permission, | |||
| cls.model.doc_num, | |||
| cls.model.token_num, | |||
| cls.model.chunk_num, | |||
| cls.model.parser_id] | |||
| kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( | |||
| (cls.model.id == kb_id), | |||
| (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| if not kbs: | |||
| return | |||
| d = kbs[0].to_dict() | |||
| d["embd_id"] = kbs[0].tenant.embd_id | |||
| return d | |||
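The new get_by_tenant_ids query encodes the knowledge-base visibility rule: a KB is listed if it is shared with one of the caller's tenants under team permission, or owned by the caller, and is still valid. A framework-free sketch of that rule (assuming TenantPermission.TEAM and StatusEnum.VALID serialize to "team" and "1"; the real filter is the peewee expression above):

    def kb_visible(kb, user_id, joined_tenant_ids, team="team", valid="1"):
        shared = kb["tenant_id"] in joined_tenant_ids and kb["permission"] == team
        owned = kb["tenant_id"] == user_id
        return (shared or owned) and kb["status"] == valid

    kbs = [
        {"name": "shared", "tenant_id": "t1", "permission": "team", "status": "1"},
        {"name": "own",    "tenant_id": "u1", "permission": "me",   "status": "1"},
        {"name": "hidden", "tenant_id": "t2", "permission": "me",   "status": "1"},
    ]
    assert [k["name"] for k in kbs if kb_visible(k, "u1", ["t1"])] == ["shared", "own"]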
| @@ -1,53 +1,55 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from peewee import Expression | |||
| from api.db.db_models import DB | |||
| from api.db import StatusEnum, FileType | |||
| from api.db.db_models import Task, Document, Knowledgebase, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| class TaskService(CommonService): | |||
| model = Task | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64): | |||
| fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] | |||
| docs = cls.model.select(*fields) \ | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||
| .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ | |||
| .where( | |||
| Document.status == StatusEnum.VALID.value, | |||
| ~(Document.type == FileType.VIRTUAL.value), | |||
| cls.model.progress == 0, | |||
| cls.model.update_time >= tm, | |||
| (Expression(cls.model.create_time, "%%", comm) == mod))\ | |||
| .order_by(cls.model.update_time.asc())\ | |||
| .paginate(1, items_per_page) | |||
| return list(docs.dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def do_cancel(cls, id): | |||
| try: | |||
| cls.model.get_by_id(id) | |||
| return False | |||
| except Exception as e: | |||
| pass | |||
| return True | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from peewee import Expression | |||
| from api.db.db_models import DB | |||
| from api.db import StatusEnum, FileType, TaskStatus | |||
| from api.db.db_models import Task, Document, Knowledgebase, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.document_service import DocumentService | |||
| class TaskService(CommonService): | |||
| model = Task | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64): | |||
| fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] | |||
| docs = cls.model.select(*fields) \ | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||
| .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ | |||
| .where( | |||
| Document.status == StatusEnum.VALID.value, | |||
| ~(Document.type == FileType.VIRTUAL.value), | |||
| cls.model.progress == 0, | |||
| cls.model.update_time >= tm, | |||
| (Expression(cls.model.create_time, "%%", comm) == mod))\ | |||
| .order_by(cls.model.update_time.asc())\ | |||
| .paginate(1, items_per_page) | |||
| return list(docs.dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def do_cancel(cls, id): | |||
| try: | |||
| task = cls.model.get_by_id(id) | |||
| _, doc = DocumentService.get_by_id(task.doc_id) | |||
| return doc.run == TaskStatus.CANCEL.value | |||
| except Exception as e: | |||
| pass | |||
| return True | |||
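do_cancel changes meaning here: instead of treating a missing task row as the cancel signal, it resolves the task's document and checks whether the user set its run state to CANCEL. A condensed, framework-free sketch of the handshake between the /document/run endpoint and the worker's progress callback (status strings "1"/"2" follow the TaskStatus enum earlier in this diff):

    RUNNING, CANCEL = "1", "2"

    documents = {"doc1": {"run": RUNNING}}
    tasks = {"task1": {"doc_id": "doc1", "progress": 0.0, "msg": ""}}

    def do_cancel(task_id):
        return documents[tasks[task_id]["doc_id"]]["run"] == CANCEL

    def set_progress(task_id, prog, msg="Processing..."):
        if do_cancel(task_id):          # mirror of the worker-side check
            msg += " [Canceled]"
            prog = -1                   # -1 marks the task as failed/cancelled
        tasks[task_id].update(progress=prog, msg=msg)

    documents["doc1"]["run"] = CANCEL   # what /document/run does on cancel
    set_progress("task1", 0.5)
    assert tasks["task1"]["progress"] == -1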
| @@ -143,7 +143,7 @@ def filename_type(filename): | |||
| if re.match(r".*\.pdf$", filename): | |||
| return FileType.PDF.value | |||
| if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| return FileType.DOC.value | |||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | |||
| @@ -4,14 +4,8 @@ from nltk import word_tokenize | |||
| from rag.nlp import stemmer, huqie | |||
| def callback__(progress, msg, func): | |||
| if not func :return | |||
| func(progress, msg) | |||
| BULLET_PATTERN = [[ | |||
| r"第[零一二三四五六七八九十百]+编", | |||
| r"第[零一二三四五六七八九十百]+(编|部分)", | |||
| r"第[零一二三四五六七八九十百]+章", | |||
| r"第[零一二三四五六七八九十百]+节", | |||
| r"第[零一二三四五六七八九十百]+条", | |||
| @@ -22,6 +16,8 @@ BULLET_PATTERN = [[ | |||
| r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", | |||
| r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", | |||
| ], [ | |||
| r"第[零一二三四五六七八九十百]+章", | |||
| r"第[零一二三四五六七八九十百]+节", | |||
| r"[零一二三四五六七八九十百]+[ 、]", | |||
| r"[\((][零一二三四五六七八九十百]+[\))]", | |||
| r"[\((][0-9]{,2}[\))]", | |||
| @@ -54,7 +50,7 @@ def bullets_category(sections): | |||
| def is_english(texts): | |||
| eng = 0 | |||
| for t in texts: | |||
| if re.match(r"[a-zA-Z]", t.strip()): | |||
| if re.match(r"[a-zA-Z]{2,}", t.strip()): | |||
| eng += 1 | |||
| if eng / len(texts) > 0.8: | |||
| return True | |||
| @@ -70,3 +66,26 @@ def tokenize(d, t, eng): | |||
| d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) | |||
| def remove_contents_table(sections, eng=False): | |||
| i = 0 | |||
| while i < len(sections): | |||
| def get(i): | |||
| nonlocal sections | |||
| return (sections[i] if type(sections[i]) == type("") else sections[i][0]).strip() | |||
| if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)): | |||
| i += 1 | |||
| continue | |||
| sections.pop(i) | |||
| if i >= len(sections): break | |||
| prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) | |||
| while not prefix: | |||
| sections.pop(i) | |||
| if i >= len(sections): break | |||
| prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) | |||
| sections.pop(i) | |||
| if i >= len(sections) or not prefix: break | |||
| for j in range(i, min(i+128, len(sections))): | |||
| if not re.match(prefix, get(j)): | |||
| continue | |||
| for _ in range(i, j):sections.pop(i) | |||
| break | |||
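remove_contents_table drops a table-of-contents block: once it sees a "contents/目录"-style heading it removes entries until the body repeats the prefix of the first TOC entry. A usage sketch with hypothetical sections (the list is edited in place; the expected result is worked out by hand):

    from rag.app import remove_contents_table   # same import the chunkers below use

    sections = [
        "目录",                            # contents heading
        "Chapter One Getting Started",     # TOC entry
        "Chapter Two Advanced Topics",     # TOC entry
        "Chapter One Getting Started",     # real body resumes here
        "Body text of chapter one ...",
    ]
    remove_contents_table(sections, eng=True)
    # sections should now be:
    # ["Chapter One Getting Started", "Body text of chapter one ..."]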
| @@ -0,0 +1,156 @@ | |||
| import copy | |||
| import random | |||
| import re | |||
| from io import BytesIO | |||
| from docx import Document | |||
| import numpy as np | |||
| from rag.app import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table | |||
| from rag.nlp import huqie | |||
| from rag.parser.docx_parser import HuDocxParser | |||
| from rag.parser.pdf_parser import HuParser | |||
| class Pdf(HuParser): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| filename if not binary else binary, | |||
| zoomin, | |||
| from_page, | |||
| to_page) | |||
| callback(0.1, "OCR finished") | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| callback(0.47, "Layout analysis finished") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| callback(0.68, "Table analysis finished") | |||
| self._text_merge() | |||
| column_width = np.median([b["x1"] - b["x0"] for b in self.boxes]) | |||
| self._concat_downward(concat_between_pages=False) | |||
| self._filter_forpages() | |||
| self._merge_with_same_bullet() | |||
| callback(0.75, "Text merging finished.") | |||
| tbls = self._extract_table_figure(True, zoomin, False) | |||
| callback(0.8, "Text extraction finished") | |||
| return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes] | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| doc = { | |||
| "docnm_kwd": filename, | |||
| "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) | |||
| } | |||
| doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) | |||
| pdf_parser = None | |||
| sections,tbls = [], [] | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| doc_parser = HuDocxParser() | |||
| # TODO: table of contents need to be removed | |||
| sections, tbls = doc_parser(binary if binary else filename) | |||
| remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() | |||
| sections,tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| elif re.search(r"\.txt$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| if binary:txt = binary.decode("utf-8") | |||
| else: | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l:break | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [(l,"") for l in sections if l] | |||
| remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| bull = bullets_category(random.choices([t for t,_ in sections], k=100)) | |||
| projs = [len(BULLET_PATTERN[bull]) + 1] * len(sections) | |||
| levels = [[] for _ in range(len(BULLET_PATTERN[bull]) + 2)] | |||
| for i, (txt, layout) in enumerate(sections): | |||
| for j, p in enumerate(BULLET_PATTERN[bull]): | |||
| if re.match(p, txt.strip()): | |||
| projs[i] = j | |||
| levels[j].append(i) | |||
| break | |||
| else: | |||
| if re.search(r"(title|head)", layout): | |||
| projs[i] = len(BULLET_PATTERN[bull]) | |||
| levels[len(BULLET_PATTERN[bull])].append(i) | |||
| else: | |||
| levels[len(BULLET_PATTERN[bull]) + 1].append(i) | |||
| sections = [t for t,_ in sections] | |||
| def binary_search(arr, target): | |||
| if target > arr[-1]: return len(arr) - 1 | |||
| if target < arr[0]: return -1 | |||
| s, e = 0, len(arr) | |||
| while e - s > 1: | |||
| i = (e + s) // 2 | |||
| if target > arr[i]: | |||
| s = i | |||
| continue | |||
| elif target < arr[i]: | |||
| e = i | |||
| continue | |||
| else: | |||
| assert False | |||
| return s | |||
| cks = [] | |||
| readed = [False] * len(sections) | |||
| levels = levels[::-1] | |||
| for i, arr in enumerate(levels): | |||
| for j in arr: | |||
| if readed[j]: continue | |||
| readed[j] = True | |||
| cks.append([j]) | |||
| if i + 1 == len(levels) - 1: continue | |||
| for ii in range(i + 1, len(levels)): | |||
| jj = binary_search(levels[ii], j) | |||
| if jj < 0: break | |||
| if jj > cks[-1][-1]: cks[-1].pop(-1) | |||
| cks[-1].append(levels[ii][jj]) | |||
| # is it English | |||
| eng = is_english(random.choices(sections, k=218)) | |||
| res = [] | |||
| # add tables | |||
| for img, rows in tbls: | |||
| bs = 10 | |||
| de = ";" if eng else ";" | |||
| for i in range(0, len(rows), bs): | |||
| d = copy.deepcopy(doc) | |||
| r = de.join(rows[i:i + bs]) | |||
| r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r) | |||
| tokenize(d, r, eng) | |||
| d["image"] = img | |||
| res.append(d) | |||
| # wrap up to es documents | |||
| for ck in cks: | |||
| print("\n-".join(ck[::-1])) | |||
| ck = "\n".join(ck[::-1]) | |||
| d = copy.deepcopy(doc) | |||
| if pdf_parser: | |||
| d["image"] = pdf_parser.crop(ck) | |||
| ck = pdf_parser.remove_tag(ck) | |||
| tokenize(d, ck, eng) | |||
| res.append(d) | |||
| return res | |||
| if __name__ == "__main__": | |||
| import sys | |||
| chunk(sys.argv[1]) | |||
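In the new book chunker, binary_search is the piece that attaches each heading to its nearest ancestor at the next outline level: for a sorted list of heading indices it is meant to return the position of the closest entry at or before the target (or -1 when the target precedes them all). Because every section index lands in exactly one level, targets never coincide with entries, so the stdlib bisect gives the same answer; a quick standalone check:

    import bisect

    def closest_at_or_before(sorted_indices, target):
        return bisect.bisect_right(sorted_indices, target) - 1   # -1 when target precedes all entries

    heads = [2, 7, 15]                            # e.g. indices of level-N headings
    assert closest_at_or_before(heads, 10) == 1   # nearest heading before section 10 sits at index 7
    assert closest_at_or_before(heads, 1) == -1
    assert closest_at_or_before(heads, 99) == 2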
| @@ -3,7 +3,7 @@ import re | |||
| from io import BytesIO | |||
| from docx import Document | |||
| import numpy as np | |||
| from rag.app import callback__, bullets_category, BULLET_PATTERN, is_english, tokenize | |||
| from rag.app import bullets_category, BULLET_PATTERN, is_english, tokenize | |||
| from rag.nlp import huqie | |||
| from rag.parser.docx_parser import HuDocxParser | |||
| from rag.parser.pdf_parser import HuParser | |||
| @@ -32,12 +32,12 @@ class Pdf(HuParser): | |||
| zoomin, | |||
| from_page, | |||
| to_page) | |||
| callback__(0.1, "OCR finished", callback) | |||
| callback(0.1, "OCR finished") | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| callback__(0.77, "Layout analysis finished", callback) | |||
| callback(0.77, "Layout analysis finished") | |||
| print("paddle layouts:", timer()-start) | |||
| bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) | |||
| # is it English | |||
| @@ -75,7 +75,7 @@ class Pdf(HuParser): | |||
| b["x1"] = max(b["x1"], b_["x1"]) | |||
| bxs.pop(i + 1) | |||
| callback__(0.8, "Text extraction finished", callback) | |||
| callback(0.8, "Text extraction finished") | |||
| return [b["text"] + self._line_tag(b, zoomin) for b in bxs] | |||
| @@ -89,17 +89,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| pdf_parser = None | |||
| sections = [] | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| callback__(0.1, "Start to parse.", callback) | |||
| callback(0.1, "Start to parse.") | |||
| for txt in Docx()(filename, binary): | |||
| sections.append(txt) | |||
| callback__(0.8, "Finish parsing.", callback) | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() | |||
| for txt in pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback): | |||
| sections.append(txt) | |||
| elif re.search(r"\.txt$", filename, re.IGNORECASE): | |||
| callback__(0.1, "Start to parse.", callback) | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| if binary:txt = binary.decode("utf-8") | |||
| else: | |||
| @@ -110,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [l for l in sections if l] | |||
| callback__(0.8, "Finish parsing.", callback) | |||
| callback(0.8, "Finish parsing.") | |||
| else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| # is it English | |||
| @@ -118,7 +118,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| # Remove 'Contents' part | |||
| i = 0 | |||
| while i < len(sections): | |||
| if not re.match(r"(Contents|目录|目次)$", re.sub(r"( | |\u3000)+", "", sections[i].split("@@")[0])): | |||
| if not re.match(r"(contents|目录|目次|table of contents)$", re.sub(r"( | |\u3000)+", "", sections[i].split("@@")[0], re.IGNORECASE)): | |||
| i += 1 | |||
| continue | |||
| sections.pop(i) | |||
| @@ -133,7 +133,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| for j in range(i, min(i+128, len(sections))): | |||
| if not re.match(prefix, sections[j]): | |||
| continue | |||
| for k in range(i, j):sections.pop(i) | |||
| for _ in range(i, j):sections.pop(i) | |||
| break | |||
| bull = bullets_category(sections) | |||
| @@ -1,6 +1,6 @@ | |||
| import copy | |||
| import re | |||
| from rag.app import callback__, tokenize | |||
| from rag.app import tokenize | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from rag.utils import num_tokens_from_string | |||
| @@ -14,19 +14,19 @@ class Pdf(HuParser): | |||
| zoomin, | |||
| from_page, | |||
| to_page) | |||
| callback__(0.2, "OCR finished.", callback) | |||
| callback(0.2, "OCR finished.") | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| callback__(0.5, "Layout analysis finished.", callback) | |||
| callback(0.5, "Layout analysis finished.") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| callback__(0.7, "Table analysis finished.", callback) | |||
| callback(0.7, "Table analysis finished.") | |||
| self._text_merge() | |||
| self._concat_downward(concat_between_pages=False) | |||
| self._filter_forpages() | |||
| callback__(0.77, "Text merging finished", callback) | |||
| callback(0.77, "Text merging finished") | |||
| tbls = self._extract_table_figure(True, zoomin, False) | |||
| # clean mess | |||
| @@ -34,20 +34,8 @@ class Pdf(HuParser): | |||
| b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip()) | |||
| # merge chunks with the same bullets | |||
| i = 0 | |||
| while i + 1 < len(self.boxes): | |||
| b = self.boxes[i] | |||
| b_ = self.boxes[i + 1] | |||
| if b["text"].strip()[0] != b_["text"].strip()[0] \ | |||
| or b["page_number"]!=b_["page_number"] \ | |||
| or b["top"] > b_["bottom"]: | |||
| i += 1 | |||
| continue | |||
| b_["text"] = b["text"] + "\n" + b_["text"] | |||
| b_["x0"] = min(b["x0"], b_["x0"]) | |||
| b_["x1"] = max(b["x1"], b_["x1"]) | |||
| b_["top"] = b["top"] | |||
| self.boxes.pop(i) | |||
| self._merge_with_same_bullet() | |||
| # merge title with decent chunk | |||
| i = 0 | |||
| while i + 1 < len(self.boxes): | |||
| @@ -62,7 +50,7 @@ class Pdf(HuParser): | |||
| b_["top"] = b["top"] | |||
| self.boxes.pop(i) | |||
| callback__(0.8, "Parsing finished", callback) | |||
| callback(0.8, "Parsing finished") | |||
| for b in self.boxes: print(b["text"], b.get("layoutno")) | |||
| print(tbls) | |||
| @@ -1,11 +1,9 @@ | |||
| import copy | |||
| import re | |||
| from collections import Counter | |||
| from rag.app import callback__, bullets_category, BULLET_PATTERN, is_english, tokenize | |||
| from rag.nlp import huqie, stemmer | |||
| from rag.parser.docx_parser import HuDocxParser | |||
| from rag.app import tokenize | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from nltk.tokenize import word_tokenize | |||
| import numpy as np | |||
| from rag.utils import num_tokens_from_string | |||
| @@ -18,20 +16,20 @@ class Pdf(HuParser): | |||
| zoomin, | |||
| from_page, | |||
| to_page) | |||
| callback__(0.2, "OCR finished.", callback) | |||
| callback(0.2, "OCR finished.") | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| callback__(0.47, "Layout analysis finished", callback) | |||
| callback(0.47, "Layout analysis finished") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| callback__(0.68, "Table analysis finished", callback) | |||
| callback(0.68, "Table analysis finished") | |||
| self._text_merge() | |||
| column_width = np.median([b["x1"] - b["x0"] for b in self.boxes]) | |||
| self._concat_downward(concat_between_pages=False) | |||
| self._filter_forpages() | |||
| callback__(0.75, "Text merging finished.", callback) | |||
| callback(0.75, "Text merging finished.") | |||
| tbls = self._extract_table_figure(True, zoomin, False) | |||
| # clean mess | |||
| @@ -101,7 +99,7 @@ class Pdf(HuParser): | |||
| break | |||
| if not abstr: i = 0 | |||
| callback__(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)), callback) | |||
| callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) | |||
| for b in self.boxes: print(b["text"], b.get("layoutno")) | |||
| print(tbls) | |||
| @@ -3,7 +3,7 @@ import re | |||
| from io import BytesIO | |||
| from pptx import Presentation | |||
| from rag.app import callback__, tokenize, is_english | |||
| from rag.app import tokenize, is_english | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| @@ -43,7 +43,7 @@ class Ppt(object): | |||
| if txt: texts.append(txt) | |||
| txts.append("\n".join(texts)) | |||
| callback__(0.5, "Text extraction finished.", callback) | |||
| callback(0.5, "Text extraction finished.") | |||
| import aspose.slides as slides | |||
| import aspose.pydrawing as drawing | |||
| imgs = [] | |||
| @@ -53,7 +53,7 @@ class Ppt(object): | |||
| slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) | |||
| imgs.append(buffered.getvalue()) | |||
| assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| callback__(0.9, "Image extraction finished", callback) | |||
| callback(0.9, "Image extraction finished") | |||
| self.is_english = is_english(txts) | |||
| return [(txts[i], imgs[i]) for i in range(len(txts))] | |||
| @@ -70,7 +70,7 @@ class Pdf(HuParser): | |||
| def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): | |||
| self.__images__(filename if not binary else binary, zoomin, from_page, to_page) | |||
| callback__(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback) | |||
| callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page))) | |||
| assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) | |||
| res = [] | |||
| #################### More precisely ################### | |||
| @@ -89,7 +89,7 @@ class Pdf(HuParser): | |||
| for i in range(len(self.boxes)): | |||
| lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])]) | |||
| res.append((lines, self.page_images[i])) | |||
| callback__(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page)), callback) | |||
| callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page))) | |||
| return res | |||
| @@ -0,0 +1,104 @@ | |||
| import random | |||
| import re | |||
| from io import BytesIO | |||
| from nltk import word_tokenize | |||
| from openpyxl import load_workbook | |||
| from rag.app import is_english | |||
| from rag.nlp import huqie, stemmer | |||
| class Excel(object): | |||
| def __call__(self, fnm, binary=None, callback=None): | |||
| if not binary: | |||
| wb = load_workbook(fnm) | |||
| else: | |||
| wb = load_workbook(BytesIO(binary)) | |||
| total = 0 | |||
| for sheetname in wb.sheetnames: | |||
| total += len(list(wb[sheetname].rows)) | |||
| res, fails = [], [] | |||
| for sheetname in wb.sheetnames: | |||
| ws = wb[sheetname] | |||
| rows = list(ws.rows) | |||
| for i, r in enumerate(rows): | |||
| q, a = "", "" | |||
| for cell in r: | |||
| if not cell.value: continue | |||
| if not q: q = str(cell.value) | |||
| elif not a: a = str(cell.value) | |||
| else: break | |||
| if q and a: res.append((q, a)) | |||
| else: fails.append(str(i+1)) | |||
| if len(res) % 999 == 0: | |||
| callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else ""))) | |||
| callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1]) | |||
| return res | |||
| def rmPrefix(txt): | |||
| return re.sub(r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) | |||
| def beAdoc(d, q, a, eng): | |||
| qprefix = "Question: " if eng else "问题:" | |||
| aprefix = "Answer: " if eng else "回答:" | |||
| d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)]) | |||
| if eng: | |||
| d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)]) | |||
| else: | |||
| d["content_ltks"] = huqie.qie(q) | |||
| d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) | |||
| return d | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): | |||
| res = [] | |||
| if re.search(r"\.xlsx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| excel_parser = Excel() | |||
| for q,a in excel_parser(filename, binary, callback): | |||
| res.append(beAdoc({}, q, a, excel_parser.is_english)) | |||
| return res | |||
| elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| if binary: | |||
| txt = binary.decode("utf-8") | |||
| else: | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l: break | |||
| txt += l | |||
| lines = txt.split("\n") | |||
| eng = is_english([rmPrefix(l) for l in lines[:100]]) | |||
| fails = [] | |||
| for i, line in enumerate(lines): | |||
| arr = [l for l in line.split("\t") if len(l) > 1] | |||
| if len(arr) != 2: | |||
| fails.append(str(i)) | |||
| continue | |||
| res.append(beAdoc({}, arr[0], arr[1], eng)) | |||
| if len(res) % 999 == 0: | |||
| callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| return res | |||
| raise NotImplementedError("file type not supported yet(pptx, pdf supported)") | |||
| if __name__== "__main__": | |||
| import sys | |||
| def kk(rat, ss): | |||
| pass | |||
| print(chunk(sys.argv[1], callback=kk)) | |||
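The Q&A chunker stores each pair as a single tab-joined content_with_weight ("Question: ...\tAnswer: ...") and strips leading "Q:"/"A:"/"问题"-style markers with rmPrefix before doing so. Essentially the same prefix rule, reproduced standalone so it can be sanity-checked without the project imports:

    import re

    PREFIX = r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+"

    def strip_prefix(txt):
        return re.sub(PREFIX, "", txt.strip(), flags=re.IGNORECASE)

    assert strip_prefix("Q: What is RAG?") == "What is RAG?"
    assert strip_prefix("answer:\tretrieval-augmented generation") == "retrieval-augmented generation"
    assert strip_prefix("Plain line without a marker") == "Plain line without a marker"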
| @@ -763,7 +763,7 @@ class HuParser: | |||
| return | |||
| i = 0 | |||
| while i < len(self.boxes): | |||
| if not re.match(r"(contents|目录|目次|table of contents)$", re.sub(r"( | |\u3000)+", "", self.boxes[i]["text"].lower())): | |||
| if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", self.boxes[i]["text"].lower())): | |||
| i += 1 | |||
| continue | |||
| eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip()) | |||
| @@ -782,6 +782,22 @@ class HuParser: | |||
| for k in range(i, j): self.boxes.pop(i) | |||
| break | |||
| def _merge_with_same_bullet(self): | |||
| i = 0 | |||
| while i + 1 < len(self.boxes): | |||
| b = self.boxes[i] | |||
| b_ = self.boxes[i + 1] | |||
| if b["text"].strip()[0] != b_["text"].strip()[0] \ | |||
| or b["text"].strip()[0].lower() in set("qwertyuopasdfghjklzxcvbnm") \ | |||
| or b["top"] > b_["bottom"]: | |||
| i += 1 | |||
| continue | |||
| b_["text"] = b["text"] + "\n" + b_["text"] | |||
| b_["x0"] = min(b["x0"], b_["x0"]) | |||
| b_["x1"] = max(b["x1"], b_["x1"]) | |||
| b_["top"] = b["top"] | |||
| self.boxes.pop(i) | |||
| def _blockType(self, b): | |||
| patt = [ | |||
| ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), | |||
| @@ -1,130 +1,138 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import os | |||
| import time | |||
| import random | |||
| from timeit import default_timer as timer | |||
| from api.db.db_models import Task | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.db.services.task_service import TaskService | |||
| from rag.parser.pdf_parser import HuParser | |||
| from rag.settings import cron_logger | |||
| from rag.utils import MINIO | |||
| from rag.utils import findMaxTm | |||
| import pandas as pd | |||
| from api.db import FileType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import database_logger | |||
| from api.utils import get_format_time, get_uuid | |||
| from api.utils.file_utils import get_project_base_directory | |||
| def collect(tm): | |||
| docs = DocumentService.get_newly_uploaded(tm) | |||
| if len(docs) == 0: | |||
| return pd.DataFrame() | |||
| docs = pd.DataFrame(docs) | |||
| mtm = docs["update_time"].max() | |||
| cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) | |||
| return docs | |||
| def set_dispatching(docid): | |||
| try: | |||
| DocumentService.update_by_id( | |||
| docid, {"progress": random.randint(0, 3) / 100., | |||
| "progress_msg": "Task dispatched...", | |||
| "process_begin_at": get_format_time() | |||
| }) | |||
| except Exception as e: | |||
| cron_logger.error("set_dispatching:({}), {}".format(docid, str(e))) | |||
| def dispatch(): | |||
| tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") | |||
| tm = findMaxTm(tm_fnm) | |||
| rows = collect(tm) | |||
| if len(rows) == 0: | |||
| return | |||
| tmf = open(tm_fnm, "a+") | |||
| for _, r in rows.iterrows(): | |||
| try: | |||
| tsks = TaskService.query(doc_id=r["id"]) | |||
| if tsks: | |||
| for t in tsks: | |||
| TaskService.delete_by_id(t.id) | |||
| except Exception as e: | |||
| cron_logger.error("delete task exception:" + str(e)) | |||
| def new_task(): | |||
| nonlocal r | |||
| return { | |||
| "id": get_uuid(), | |||
| "doc_id": r["id"] | |||
| } | |||
| tsks = [] | |||
| if r["type"] == FileType.PDF.value: | |||
| pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| for p in range(0, pages, 10): | |||
| task = new_task() | |||
| task["from_page"] = p | |||
| task["to_page"] = min(p + 10, pages) | |||
| tsks.append(task) | |||
| else: | |||
| tsks.append(new_task()) | |||
| print(tsks) | |||
| bulk_insert_into_db(Task, tsks, True) | |||
| set_dispatching(r["id"]) | |||
| tmf.write(str(r["update_time"]) + "\n") | |||
| tmf.close() | |||
| def update_progress(): | |||
| docs = DocumentService.get_unfinished_docs() | |||
| for d in docs: | |||
| try: | |||
| tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) | |||
| if not tsks:continue | |||
| msg = [] | |||
| prg = 0 | |||
| finished = True | |||
| bad = 0 | |||
| for t in tsks: | |||
| if 0 <= t.progress < 1: finished = False | |||
| prg += t.progress if t.progress >= 0 else 0 | |||
| msg.append(t.progress_msg) | |||
| if t.progress == -1: bad += 1 | |||
| prg /= len(tsks) | |||
| if finished and bad: prg = -1 | |||
| msg = "\n".join(msg) | |||
| DocumentService.update_by_id(d["id"], {"progress": prg, "progress_msg": msg, "process_duation": timer()-d["process_begin_at"].timestamp()}) | |||
| except Exception as e: | |||
| cron_logger.error("fetch task exception:" + str(e)) | |||
| if __name__ == "__main__": | |||
| peewee_logger = logging.getLogger('peewee') | |||
| peewee_logger.propagate = False | |||
| peewee_logger.addHandler(database_logger.handlers[0]) | |||
| peewee_logger.setLevel(database_logger.level) | |||
| while True: | |||
| dispatch() | |||
| time.sleep(3) | |||
| update_progress() | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import os | |||
| import time | |||
| import random | |||
| from datetime import datetime | |||
| from api.db.db_models import Task | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.db.services.task_service import TaskService | |||
| from rag.parser.pdf_parser import HuParser | |||
| from rag.settings import cron_logger | |||
| from rag.utils import MINIO | |||
| from rag.utils import findMaxTm | |||
| import pandas as pd | |||
| from api.db import FileType, TaskStatus | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import database_logger | |||
| from api.utils import get_format_time, get_uuid | |||
| from api.utils.file_utils import get_project_base_directory | |||
| def collect(tm): | |||
| docs = DocumentService.get_newly_uploaded(tm) | |||
| if len(docs) == 0: | |||
| return pd.DataFrame() | |||
| docs = pd.DataFrame(docs) | |||
| mtm = docs["update_time"].max() | |||
| cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) | |||
| return docs | |||
| def set_dispatching(docid): | |||
| try: | |||
| DocumentService.update_by_id( | |||
| docid, {"progress": random.randint(0, 3) / 100., | |||
| "progress_msg": "Task dispatched...", | |||
| "process_begin_at": get_format_time() | |||
| }) | |||
| except Exception as e: | |||
| cron_logger.error("set_dispatching:({}), {}".format(docid, str(e))) | |||
| def dispatch(): | |||
| tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") | |||
| tm = findMaxTm(tm_fnm) | |||
| rows = collect(tm) | |||
| if len(rows) == 0: | |||
| return | |||
| tmf = open(tm_fnm, "a+") | |||
| for _, r in rows.iterrows(): | |||
| try: | |||
| tsks = TaskService.query(doc_id=r["id"]) | |||
| if tsks: | |||
| for t in tsks: | |||
| TaskService.delete_by_id(t.id) | |||
| except Exception as e: | |||
| cron_logger.error("delete task exception:" + str(e)) | |||
| def new_task(): | |||
| nonlocal r | |||
| return { | |||
| "id": get_uuid(), | |||
| "doc_id": r["id"] | |||
| } | |||
| tsks = [] | |||
| if r["type"] == FileType.PDF.value: | |||
| pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| for p in range(0, pages, 10): | |||
| task = new_task() | |||
| task["from_page"] = p | |||
| task["to_page"] = min(p + 10, pages) | |||
| tsks.append(task) | |||
| else: | |||
| tsks.append(new_task()) | |||
| print(tsks) | |||
| bulk_insert_into_db(Task, tsks, True) | |||
| set_dispatching(r["id"]) | |||
| tmf.write(str(r["update_time"]) + "\n") | |||
| tmf.close() | |||
| def update_progress(): | |||
| docs = DocumentService.get_unfinished_docs() | |||
| for d in docs: | |||
| try: | |||
| tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) | |||
| if not tsks:continue | |||
| msg = [] | |||
| prg = 0 | |||
| finished = True | |||
| bad = 0 | |||
| status = TaskStatus.RUNNING.value | |||
| for t in tsks: | |||
| if 0 <= t.progress < 1: finished = False | |||
| prg += t.progress if t.progress >= 0 else 0 | |||
| msg.append(t.progress_msg) | |||
| if t.progress == -1: bad += 1 | |||
| prg /= len(tsks) | |||
| if finished and bad: | |||
| prg = -1 | |||
| status = TaskStatus.FAIL.value | |||
| elif finished: status = TaskStatus.DONE.value | |||
| msg = "\n".join(msg) | |||
| info = {"process_duation": datetime.timestamp(datetime.now())-d["process_begin_at"].timestamp(), "run": status} | |||
| if prg !=0 : info["progress"] = prg | |||
| if msg: info["progress_msg"] = msg | |||
| DocumentService.update_by_id(d["id"], info) | |||
| except Exception as e: | |||
| cron_logger.error("fetch task exception:" + str(e)) | |||
| if __name__ == "__main__": | |||
| peewee_logger = logging.getLogger('peewee') | |||
| peewee_logger.propagate = False | |||
| peewee_logger.addHandler(database_logger.handlers[0]) | |||
| peewee_logger.setLevel(database_logger.level) | |||
| while True: | |||
| dispatch() | |||
| time.sleep(3) | |||
| update_progress() | |||
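update_progress now also folds the per-task states back into a document-level run status: DONE when every task finished cleanly, FAIL when any task reported -1, otherwise RUNNING with the mean progress. A condensed version of that aggregation rule (status strings as in the TaskStatus enum earlier in this diff):

    RUNNING, DONE, FAIL = "1", "3", "4"

    def aggregate(task_progresses):
        finished = all(not (0 <= p < 1) for p in task_progresses)
        bad = sum(1 for p in task_progresses if p == -1)
        prg = sum(p for p in task_progresses if p >= 0) / len(task_progresses)
        if finished and bad:
            return -1, FAIL
        if finished:
            return prg, DONE
        return prg, RUNNING

    assert aggregate([1.0, 1.0]) == (1.0, DONE)
    assert aggregate([1.0, -1]) == (-1, FAIL)
    assert aggregate([0.5, 1.0]) == (0.75, RUNNING)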
| @@ -24,8 +24,9 @@ import sys | |||
| from functools import partial | |||
| from timeit import default_timer as timer | |||
| from elasticsearch_dsl import Q | |||
| from api.db.services.task_service import TaskService | |||
| from rag.llm import EmbeddingModel, CvModel | |||
| from rag.settings import cron_logger, DOC_MAXIMUM_SIZE | |||
| from rag.utils import ELASTICSEARCH | |||
| from rag.utils import MINIO | |||
| @@ -35,7 +36,7 @@ from rag.nlp import search | |||
| from io import BytesIO | |||
| import pandas as pd | |||
| from rag.app import laws, paper, presentation, manual | |||
| from rag.app import laws, paper, presentation, manual, qa | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.document_service import DocumentService | |||
| @@ -51,13 +52,14 @@ FACTORY = { | |||
| ParserType.PRESENTATION.value: presentation, | |||
| ParserType.MANUAL.value: manual, | |||
| ParserType.LAWS.value: laws, | |||
| ParserType.QA.value: qa, | |||
| } | |||
| def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."): | |||
| cancel = TaskService.do_cancel(task_id) | |||
| if cancel: | |||
| msg = "Canceled." | |||
| msg += " [Canceled]" | |||
| prog = -1 | |||
| if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg | |||
| @@ -166,13 +168,16 @@ def init_kb(row): | |||
| def embedding(docs, mdl): | |||
| tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs] | |||
| tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs] | |||
| tk_count = 0 | |||
| tts, c = mdl.encode(tts) | |||
| tk_count += c | |||
| if len(tts) == len(cnts): | |||
| tts, c = mdl.encode(tts) | |||
| tk_count += c | |||
| cnts, c = mdl.encode(cnts) | |||
| tk_count += c | |||
| vects = 0.1 * tts + 0.9 * cnts | |||
| vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts | |||
| assert len(vects) == len(docs) | |||
| for i, d in enumerate(docs): | |||
| v = vects[i].tolist() | |||
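The change to embedding() covers chunkers that emit no docnm_kwd titles (the Q&A chunker, for instance): when the title batch and the content batch differ in length, only the content vectors are kept; otherwise titles and contents are blended 0.1/0.9 as before. A toy-sized sketch of that combination step, with plain numpy arrays standing in for model output:

    import numpy as np

    def combine(title_vecs, content_vecs):
        title_vecs = np.asarray(title_vecs, dtype=float)
        content_vecs = np.asarray(content_vecs, dtype=float)
        if len(title_vecs) == len(content_vecs):
            return 0.1 * title_vecs + 0.9 * content_vecs
        return content_vecs             # e.g. Q&A chunks: no usable title vectors

    assert np.allclose(combine([[1.0, 1.0]], [[0.0, 0.0]]), [[0.1, 0.1]])
    assert np.allclose(combine([], [[0.0, 1.0]]), [[0.0, 1.0]])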
| @@ -215,12 +220,14 @@ def main(comm, mod): | |||
| callback(msg="Finished embedding! Start to build index!") | |||
| init_kb(r) | |||
| chunk_count = len(set([c["_id"] for c in cks])) | |||
| callback(1., "Done!") | |||
| es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"])) | |||
| if es_r: | |||
| callback(-1, "Index failure!") | |||
| cron_logger.error(str(es_r)) | |||
| else: | |||
| if TaskService.do_cancel(r["id"]): | |||
| ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | |||
| callback(1., "Done!") | |||
| DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) | |||
| cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) | |||