Feat: repair corrupted PDF files on upload automatically (#7693)

### What problem does this PR solve?

Make a best effort to automatically repair corrupted PDF files on upload.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
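
The heavy lifting is done by a new helper, `read_potential_broken_pdf` in `api/utils/file_utils.py`: it first tries to open the uploaded bytes with pdfplumber and, if that fails, rewrites them through Ghostscript before retrying. A minimal usage sketch (the sample path is hypothetical):

```python
from api.utils.file_utils import read_potential_broken_pdf

# Hypothetical local file, used only for illustration.
with open("maybe_corrupted.pdf", "rb") as fh:
    blob = fh.read()

# Returns repaired bytes when Ghostscript can rewrite the file;
# otherwise the original bytes are returned unchanged.
blob = read_potential_broken_pdf(blob)
```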
tags/v0.19.0
Yongteng Lei, 5 months ago
Commit 0ebf05440e (no account is linked to the committer's email)
5 files changed, 252 insertions(+), 324 deletions(-)
1. .github/workflows/tests.yml (+2, -2)
2. Dockerfile (+2, -1)
3. api/apps/document_app.py (+121, -204)
4. api/db/services/file_service.py (+56, -87)
5. api/utils/file_utils.py (+71, -30)

.github/workflows/tests.yml (+2, -2)

```diff
 # https://github.com/astral-sh/ruff-action
 - name: Static check with Ruff
-  uses: astral-sh/ruff-action@v2
+  uses: astral-sh/ruff-action@v3
   with:
-    version: ">=0.8.2"
+    version: ">=0.11.x"
     args: "check"

 - name: Build ragflow:nightly-slim
```

Dockerfile (+2, -1)

```diff
     apt install -y libatk-bridge2.0-0 && \
     apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
     apt install -y libjemalloc-dev && \
-    apt install -y python3-pip pipx nginx unzip curl wget git vim less
+    apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
+    apt install -y ghostscript

 RUN if [ "$NEED_MIRROR" == "1" ]; then \
     pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \
```
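
Ghostscript is added to the image because the new repair path shells out to the `gs` binary; when it is missing, the code simply returns the original bytes (see `repair_pdf_with_ghostscript` below). A quick availability check, as a sketch:

```python
import shutil
import subprocess

# repair_pdf_with_ghostscript() checks shutil.which("gs") and silently
# skips repair when Ghostscript is not installed.
if shutil.which("gs") is None:
    print("Ghostscript not found; corrupted PDFs will be stored as-is")
else:
    version = subprocess.run(["gs", "--version"], capture_output=True, text=True)
    print("Ghostscript", version.stdout.strip())
```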

api/apps/document_app.py (+121, -204)



```diff
 import flask
 from flask import request
-from flask_login import login_required, current_user
+from flask_login import current_user, login_required

-from deepdoc.parser.html_parser import RAGFlowHtmlParser
-from rag.nlp import search
-
-from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, TaskStatus, ParserType, FileSource
+from api import settings
+from api.constants import IMG_BASE64_PREFIX
+from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
 from api.db.db_models import File, Task
+from api.db.services import duplicate_name
+from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
-from api.db.services.task_service import queue_tasks
-from api.db.services.user_service import UserTenantService
-from api.db.services import duplicate_name
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.task_service import TaskService
-from api.db.services.document_service import DocumentService, doc_upload_and_parse
+from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.user_service import UserTenantService
+from api.utils import get_uuid
 from api.utils.api_utils import (
-    server_error_response,
     get_data_error_result,
+    get_json_result,
+    server_error_response,
     validate_request,
 )
-from api.utils import get_uuid
-from api import settings
-from api.utils.api_utils import get_json_result
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
+from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
 from api.utils.web_utils import html2pdf, is_valid_url
-from api.constants import IMG_BASE64_PREFIX
+from deepdoc.parser.html_parser import RAGFlowHtmlParser
+from rag.nlp import search
+from rag.utils.storage_factory import STORAGE_IMPL


-@manager.route('/upload', methods=['POST'])  # noqa: F821
+@manager.route("/upload", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("kb_id")
 def upload():
     kb_id = request.form.get("kb_id")
     if not kb_id:
-        return get_json_result(
-            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
-    if 'file' not in request.files:
-        return get_json_result(
-            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+    if "file" not in request.files:
+        return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)

-    file_objs = request.files.getlist('file')
+    file_objs = request.files.getlist("file")
     for file_obj in file_objs:
-        if file_obj.filename == '':
-            return get_json_result(
-                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
+        if file_obj.filename == "":
+            return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)

     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")

     err, files = FileService.upload_document(kb, file_objs, current_user.id)
-    files = [f[0] for f in files] # remove the blob

+    if not files:
+        return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
+    files = [f[0] for f in files]  # remove the blob

     if err:
-        return get_json_result(
-            data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
+        return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=files)
```
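
With the repair hook in place, `upload()` only reports a DATA_ERROR when none of the submitted files could be stored, for example a PDF that even Ghostscript cannot rewrite. A hypothetical client call (host and port are assumptions; the route is registered under `/v1/document` and sits behind `@login_required`, so a valid login session is needed and is omitted here):

```python
import requests

# Hypothetical base URL and kb_id; session/auth handling is omitted.
resp = requests.post(
    "http://127.0.0.1:9380/v1/document/upload",
    data={"kb_id": "<your_kb_id>"},
    files=[("file", ("broken.pdf", open("broken.pdf", "rb"), "application/pdf"))],
)
# The message hints at a corrupted/unsupported file when nothing could be saved.
print(resp.json())
```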




```diff
-@manager.route('/web_crawl', methods=['POST'])  # noqa: F821
+@manager.route("/web_crawl", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("kb_id", "name", "url")
 def web_crawl():
     kb_id = request.form.get("kb_id")
     if not kb_id:
-        return get_json_result(
-            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     name = request.form.get("name")
     url = request.form.get("url")
     if not is_valid_url(url):
-        return get_json_result(
-            data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
     kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

     try:
-        filename = duplicate_name(
-            DocumentService.query,
-            name=name + ".pdf",
-            kb_id=kb.id)
+        filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id)
         filetype = filename_type(filename)
         if filetype == FileType.OTHER.value:
             raise RuntimeError("This type of file has not been supported yet!")
             "name": filename,
             "location": location,
             "size": len(blob),
-            "thumbnail": thumbnail(filename, blob)
+            "thumbnail": thumbnail(filename, blob),
         }
         if doc["type"] == FileType.VISUAL:
             doc["parser_id"] = ParserType.PICTURE.value
     return get_json_result(data=True)


-@manager.route('/create', methods=['POST'])  # noqa: F821
+@manager.route("/create", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("name", "kb_id")
 def create():
     req = request.json
     kb_id = req["kb_id"]
     if not kb_id:
-        return get_json_result(
-            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)

     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)
         if not e:
-            return get_data_error_result(
-                message="Can't find this knowledgebase!")
+            return get_data_error_result(message="Can't find this knowledgebase!")

         if DocumentService.query(name=req["name"], kb_id=kb_id):
-            return get_data_error_result(
-                message="Duplicated document name in the same knowledgebase.")
-
-        doc = DocumentService.insert({
-            "id": get_uuid(),
-            "kb_id": kb.id,
-            "parser_id": kb.parser_id,
-            "parser_config": kb.parser_config,
-            "created_by": current_user.id,
-            "type": FileType.VIRTUAL,
-            "name": req["name"],
-            "location": "",
-            "size": 0
-        })
+            return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
+
+        doc = DocumentService.insert(
+            {
+                "id": get_uuid(),
+                "kb_id": kb.id,
+                "parser_id": kb.parser_id,
+                "parser_config": kb.parser_config,
+                "created_by": current_user.id,
+                "type": FileType.VIRTUAL,
+                "name": req["name"],
+                "location": "",
+                "size": 0,
+            }
+        )
         return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)


-@manager.route('/list', methods=['POST'])  # noqa: F821
+@manager.route("/list", methods=["POST"])  # noqa: F821
 @login_required
 def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
-        return get_json_result(
-            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
-        if KnowledgebaseService.query(
-                tenant_id=tenant.tenant_id, id=kb_id):
+        if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
             break
     else:
-        return get_json_result(
-            data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=settings.RetCode.OPERATING_ERROR)
+        return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
     keywords = request.args.get("keywords", "")

     page_number = int(request.args.get("page", 0))
     if run_status:
         invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
         if invalid_status:
-            return get_data_error_result(
-                message=f"Invalid filter run status conditions: {', '.join(invalid_status)}"
-            )
+            return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")

     types = req.get("types", [])
     if types:
         invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
         if invalid_types:
-            return get_data_error_result(
-                message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
-            )
+            return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")

     try:
-        docs, tol = DocumentService.get_by_kb_id(
-            kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types)
+        docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types)

         for doc_item in docs:
-            if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX):
-                doc_item['thumbnail'] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
+            if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
+                doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"

         return get_json_result(data={"total": tol, "docs": docs})
     except Exception as e:
         return server_error_response(e)


-@manager.route('/infos', methods=['POST'])  # noqa: F821
+@manager.route("/infos", methods=["POST"])  # noqa: F821
 @login_required
 def docinfos():
     req = request.json
     doc_ids = req["doc_ids"]
     for doc_id in doc_ids:
         if not DocumentService.accessible(doc_id, current_user.id):
-            return get_json_result(
-                data=False,
-                message='No authorization.',
-                code=settings.RetCode.AUTHENTICATION_ERROR
-            )
+            return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
     docs = DocumentService.get_by_ids(doc_ids)
     return get_json_result(data=list(docs.dicts()))


-@manager.route('/thumbnails', methods=['GET'])  # noqa: F821
+@manager.route("/thumbnails", methods=["GET"])  # noqa: F821
 # @login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
-        return get_json_result(
-            data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)

     try:
         docs = DocumentService.get_thumbnails(doc_ids)

         for doc_item in docs:
-            if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX):
-                doc_item['thumbnail'] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}"
+            if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
+                doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}"

         return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
     except Exception as e:
         return server_error_response(e)


-@manager.route('/change_status', methods=['POST'])  # noqa: F821
+@manager.route("/change_status", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_id", "status")
 def change_status():
     req = request.json
     if str(req["status"]) not in ["0", "1"]:
-        return get_json_result(
-            data=False,
-            message='"Status" must be either 0 or 1!',
-            code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR)

     if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(
-            data=False,
-            message='No authorization.',
-            code=settings.RetCode.AUTHENTICATION_ERROR)
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
             return get_data_error_result(message="Document not found!")
         e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
         if not e:
-            return get_data_error_result(
-                message="Can't find this knowledgebase!")
+            return get_data_error_result(message="Can't find this knowledgebase!")

-        if not DocumentService.update_by_id(
-                req["doc_id"], {"status": str(req["status"])}):
-            return get_data_error_result(
-                message="Database error (Document update)!")
+        if not DocumentService.update_by_id(req["doc_id"], {"status": str(req["status"])}):
+            return get_data_error_result(message="Database error (Document update)!")

         status = int(req["status"])
-        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
-                                     search.index_name(kb.tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)


-@manager.route('/rm', methods=['POST'])  # noqa: F821
+@manager.route("/rm", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_id")
 def rm():

     for doc_id in doc_ids:
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
-            return get_json_result(
-                data=False,
-                message='No authorization.',
-                code=settings.RetCode.AUTHENTICATION_ERROR
-            )
+            return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

     root_folder = FileService.get_root_folder(current_user.id)
     pf_id = root_folder["id"]

         TaskService.filter_delete([Task.doc_id == doc_id])
         if not DocumentService.remove_document(doc, tenant_id):
-            return get_data_error_result(
-                message="Database error (Document removal)!")
+            return get_data_error_result(message="Database error (Document removal)!")

         f2d = File2DocumentService.get_by_document_id(doc_id)
         deleted_file_count = 0
     return get_json_result(data=True)


-@manager.route('/run', methods=['POST'])  # noqa: F821
+@manager.route("/run", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_ids", "run")
 def run():
     req = request.json
     for doc_id in req["doc_ids"]:
         if not DocumentService.accessible(doc_id, current_user.id):
-            return get_json_result(
-                data=False,
-                message='No authorization.',
-                code=settings.RetCode.AUTHENTICATION_ERROR
-            )
+            return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         kb_table_num_map = {}
         for id in req["doc_ids"]:
             if kb_id not in kb_table_num_map:
                 count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
                 kb_table_num_map[kb_id] = count
-                if kb_table_num_map[kb_id] <=0:
+                if kb_table_num_map[kb_id] <= 0:
                     KnowledgebaseService.delete_field_map(kb_id)
             bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
             queue_tasks(doc, bucket, name, 0)
         return server_error_response(e)


-@manager.route('/rename', methods=['POST'])  # noqa: F821
+@manager.route("/rename", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_id", "name")
 def rename():
     req = request.json
     if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(
-            data=False,
-            message='No authorization.',
-            code=settings.RetCode.AUTHENTICATION_ERROR
-        )
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
-                doc.name.lower()).suffix:
-            return get_json_result(
-                data=False,
-                message="The extension of file can't be changed",
-                code=settings.RetCode.ARGUMENT_ERROR)
+        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
+            return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR)
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
-                return get_data_error_result(
-                    message="Duplicated document name in the same knowledgebase.")
+                return get_data_error_result(message="Duplicated document name in the same knowledgebase.")

-        if not DocumentService.update_by_id(
-                req["doc_id"], {"name": req["name"]}):
-            return get_data_error_result(
-                message="Database error (Document rename)!")
+        if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
+            return get_data_error_result(message="Database error (Document rename)!")

         informs = File2DocumentService.get_by_document_id(req["doc_id"])
         if informs:
         return server_error_response(e)


-@manager.route('/get/<doc_id>', methods=['GET'])  # noqa: F821
+@manager.route("/get/<doc_id>", methods=["GET"])  # noqa: F821
 # @login_required
 def get(doc_id):
     try:
         ext = re.search(r"\.([^.]+)$", doc.name)
         if ext:
             if doc.type == FileType.VISUAL.value:
-                response.headers.set('Content-Type', 'image/%s' % ext.group(1))
+                response.headers.set("Content-Type", "image/%s" % ext.group(1))
             else:
-                response.headers.set(
-                    'Content-Type',
-                    'application/%s' %
-                    ext.group(1))
+                response.headers.set("Content-Type", "application/%s" % ext.group(1))
         return response
     except Exception as e:
         return server_error_response(e)


-@manager.route('/change_parser', methods=['POST'])  # noqa: F821
+@manager.route("/change_parser", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_id", "parser_id")
 def change_parser():
     req = request.json

     if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(
-            data=False,
-            message='No authorization.',
-            code=settings.RetCode.AUTHENTICATION_ERROR
-        )
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             else:
                 return get_json_result(data=True)

-        if ((doc.type == FileType.VISUAL and req["parser_id"] != "picture")
-                or (re.search(
-                    r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation")):
+        if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
             return get_data_error_result(message="Not supported yet!")

-        e = DocumentService.update_by_id(doc.id,
-                                         {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
-                                          "run": TaskStatus.UNSTART.value})
+        e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
         if not e:
             return get_data_error_result(message="Document not found!")
         if "parser_config" in req:
             DocumentService.update_parser_config(doc.id, req["parser_config"])
         if doc.token_num > 0:
-            e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
-                                                    doc.process_duation * -1)
+            e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duation * -1)
             if not e:
                 return get_data_error_result(message="Document not found!")
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         return server_error_response(e)


-@manager.route('/image/<image_id>', methods=['GET'])  # noqa: F821
+@manager.route("/image/<image_id>", methods=["GET"])  # noqa: F821
 # @login_required
 def get_image(image_id):
     try:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
         response = flask.make_response(STORAGE_IMPL.get(bkt, nm))
-        response.headers.set('Content-Type', 'image/JPEG')
+        response.headers.set("Content-Type", "image/JPEG")
         return response
     except Exception as e:
         return server_error_response(e)


-@manager.route('/upload_and_parse', methods=['POST'])  # noqa: F821
+@manager.route("/upload_and_parse", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("conversation_id")
 def upload_and_parse():
-    if 'file' not in request.files:
-        return get_json_result(
-            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
+    if "file" not in request.files:
+        return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)

-    file_objs = request.files.getlist('file')
+    file_objs = request.files.getlist("file")
     for file_obj in file_objs:
-        if file_obj.filename == '':
-            return get_json_result(
-                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
+        if file_obj.filename == "":
+            return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)

     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)

     return get_json_result(data=doc_ids)


-@manager.route('/parse', methods=['POST'])  # noqa: F821
+@manager.route("/parse", methods=["POST"])  # noqa: F821
 @login_required
 def parse():
     url = request.json.get("url") if request.json else ""
     if url:
         if not is_valid_url(url):
-            return get_json_result(
-                data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
+            return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
         download_path = os.path.join(get_project_base_directory(), "logs/downloads")
         os.makedirs(download_path, exist_ok=True)
         from seleniumwire.webdriver import Chrome, ChromeOptions

         options = ChromeOptions()
-        options.add_argument('--headless')
-        options.add_argument('--disable-gpu')
-        options.add_argument('--no-sandbox')
-        options.add_argument('--disable-dev-shm-usage')
-        options.add_experimental_option('prefs', {
-            'download.default_directory': download_path,
-            'download.prompt_for_download': False,
-            'download.directory_upgrade': True,
-            'safebrowsing.enabled': True
-        })
+        options.add_argument("--headless")
+        options.add_argument("--disable-gpu")
+        options.add_argument("--no-sandbox")
+        options.add_argument("--disable-dev-shm-usage")
+        options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True})
         driver = Chrome(options=options)
         driver.get(url)
         res_headers = [r.response.headers for r in driver.requests if r and r.response]

         r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
         if not r or not r.group(1):
-            return get_json_result(
-                data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
+            return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
         f = File(r.group(1), os.path.join(download_path, r.group(1)))
         txt = FileService.parse_docs([f], current_user.id)
         return get_json_result(data=txt)

-    if 'file' not in request.files:
-        return get_json_result(
-            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
+    if "file" not in request.files:
+        return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)

-    file_objs = request.files.getlist('file')
+    file_objs = request.files.getlist("file")
     txt = FileService.parse_docs(file_objs, current_user.id)

     return get_json_result(data=txt)


-@manager.route('/set_meta', methods=['POST'])  # noqa: F821
+@manager.route("/set_meta", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("doc_id", "meta")
 def set_meta():
     req = request.json
     if not DocumentService.accessible(req["doc_id"], current_user.id):
-        return get_json_result(
-            data=False,
-            message='No authorization.',
-            code=settings.RetCode.AUTHENTICATION_ERROR
-        )
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
     try:
         meta = json.loads(req["meta"])
     except Exception as e:
-        return get_json_result(
-            data=False, message=f'Json syntax error: {e}', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
     if not isinstance(meta, dict):
-        return get_json_result(
-            data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
+        return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)

     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")

-        if not DocumentService.update_by_id(
-                req["doc_id"], {"meta_fields": meta}):
-            return get_data_error_result(
-                message="Database error (meta updates)!")
+        if not DocumentService.update_by_id(req["doc_id"], {"meta_fields": meta}):
+            return get_data_error_result(message="Database error (meta updates)!")

         return get_json_result(data=True)
     except Exception as e:
```

api/db/services/file_service.py (+56, -87)

```diff
 # limitations under the License.
 #
 import logging
-import re
 import os
+import re
 from concurrent.futures import ThreadPoolExecutor

 from flask_login import current_user
 from peewee import fn

-from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType
-from api.db.db_models import DB, File2Document, Knowledgebase
-from api.db.db_models import File, Document
+from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
+from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
 from api.db.services import duplicate_name
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils import get_uuid
-from api.utils.file_utils import filename_type, thumbnail_img
+from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
 from rag.utils.storage_factory import STORAGE_IMPL


     @classmethod
     @DB.connection_context()
-    def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page,
-                     orderby, desc, keywords):
+    def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords):
         # Get files by parent folder ID with pagination and filtering
         # Args:
         #     tenant_id: ID of the tenant
         # Returns:
         #     Tuple of (file_list, total_count)
         if keywords:
-            files = cls.model.select().where(
-                (cls.model.tenant_id == tenant_id),
-                (cls.model.parent_id == pf_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower())),
-                ~(cls.model.id == pf_id)
-            )
+            files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id))
         else:
-            files = cls.model.select().where((cls.model.tenant_id == tenant_id),
-                                             (cls.model.parent_id == pf_id),
-                                             ~(cls.model.id == pf_id)
-                                             )
+            files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id))
         count = files.count()
         if desc:
             files = files.order_by(cls.model.getter_by(orderby).desc())
         for file in res_files:
             if file["type"] == FileType.FOLDER.value:
                 file["size"] = cls.get_folder_size(file["id"])
-                file['kbs_info'] = []
-                children = list(cls.model.select().where(
-                    (cls.model.tenant_id == tenant_id),
-                    (cls.model.parent_id == file["id"]),
-                    ~(cls.model.id == file["id"]),
-                ).dicts())
-                file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
+                file["kbs_info"] = []
+                children = list(
+                    cls.model.select()
+                    .where(
+                        (cls.model.tenant_id == tenant_id),
+                        (cls.model.parent_id == file["id"]),
+                        ~(cls.model.id == file["id"]),
+                    )
+                    .dicts()
+                )
+                file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
                 continue
-            kbs_info = cls.get_kb_id_by_file_id(file['id'])
-            file['kbs_info'] = kbs_info
+            kbs_info = cls.get_kb_id_by_file_id(file["id"])
+            file["kbs_info"] = kbs_info

         return res_files, count

         #     file_id: File ID
         # Returns:
         #     List of dictionaries containing knowledge base IDs and names
-        kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
-               .join(File2Document, on=(File2Document.file_id == file_id))
-               .join(Document, on=(File2Document.document_id == Document.id))
-               .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
-               .where(cls.model.id == file_id))
+        kbs = (
+            cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
+            .join(File2Document, on=(File2Document.file_id == file_id))
+            .join(Document, on=(File2Document.document_id == Document.id))
+            .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
+            .where(cls.model.id == file_id)
+        )
         if not kbs:
             return []
         kbs_info_list = []
         for kb in list(kbs.dicts()):
-            kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
+            kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
         return kbs_info_list

     @classmethod
         if count > len(name) - 2:
             return file
         else:
-            file = cls.insert({
-                "id": get_uuid(),
-                "parent_id": parent_id,
-                "tenant_id": current_user.id,
-                "created_by": current_user.id,
-                "name": name[count],
-                "location": "",
-                "size": 0,
-                "type": FileType.FOLDER.value
-            })
+            file = cls.insert(
+                {"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
+            )
             return cls.create_folder(file, file.id, name, count + 1)

     @classmethod
         #     tenant_id: Tenant ID
         # Returns:
         #     Root folder dictionary
-        for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
-                                             (cls.model.parent_id == cls.model.id)
-                                             ):
+        for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
             return file.to_dict()

         file_id = get_uuid()
         #     tenant_id: Tenant ID
         # Returns:
         #     Knowledge base folder dictionary
-        for root in cls.model.select().where(
-                (cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
-            for folder in cls.model.select().where(
-                    (cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id),
-                    (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
+        for root in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
+            for folder in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
                 return folder.to_dict()
         assert False, "Can't find the KB folder. Database init error."

             "type": ty,
             "size": size,
             "location": location,
-            "source_type": FileSource.KNOWLEDGEBASE
+            "source_type": FileSource.KNOWLEDGEBASE,
         }
         cls.save(**file)
         return file
         # Args:
         #     root_id: Root folder ID
         #     tenant_id: Tenant ID
-        for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\
-                                          & (cls.model.parent_id == root_id)):
+        for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)):
             return
         folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)

-        for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id==tenant_id):
+        for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id):
             kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"])
             for doc in DocumentService.query(kb_id=kb.id):
                 FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id)
     @DB.connection_context()
     def delete_folder_by_pf_id(cls, user_id, folder_id):
         try:
-            files = cls.model.select().where((cls.model.tenant_id == user_id)
-                                             & (cls.model.parent_id == folder_id))
+            files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id))
             for file in files:
                 cls.delete_folder_by_pf_id(user_id, file.id)
-            return cls.model.delete().where((cls.model.tenant_id == user_id)
-                                            & (cls.model.id == folder_id)).execute(),
+            return (cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute(),)
         except Exception:
             logging.exception("delete_folder_by_pf_id")
             raise RuntimeError("Database error (File retrieval)!")

         def dfs(parent_id):
             nonlocal size
-            for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(
-                    cls.model.parent_id == parent_id, cls.model.id != parent_id):
+            for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id):
                 size += f.size
                 if f.type == FileType.FOLDER.value:
                     dfs(f.id)
             "type": doc["type"],
             "size": doc["size"],
             "location": doc["location"],
-            "source_type": FileSource.KNOWLEDGEBASE
+            "source_type": FileSource.KNOWLEDGEBASE,
         }
         cls.save(**file)
         File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]})
     @classmethod
     @DB.connection_context()
     def move_file(cls, file_ids, folder_id):
         try:
-            cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id })
+            cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id})
         except Exception:
             logging.exception("move_file")
             raise RuntimeError("Database error (File move)!")
         err, files = [], []
         for file in file_objs:
             try:
-                MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+                MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
                 if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
                     raise RuntimeError("Exceed the maximum file number of a free user!")
                 if len(file.filename.encode("utf-8")) >= 128:
                     raise RuntimeError("Exceed the maximum length of file name!")

-                filename = duplicate_name(
-                    DocumentService.query,
-                    name=file.filename,
-                    kb_id=kb.id)
+                filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
                 filetype = filename_type(filename)
                 if filetype == FileType.OTHER.value:
                     raise RuntimeError("This type of file has not been supported yet!")
                 location = filename
                 while STORAGE_IMPL.obj_exist(kb.id, location):
                     location += "_"

                 blob = file.read()
+                if filetype == FileType.PDF.value:
+                    blob = read_potential_broken_pdf(blob)
                 STORAGE_IMPL.put(kb.id, location, blob)

                 doc_id = get_uuid()

                 img = thumbnail_img(filename, blob)
-                thumbnail_location = ''
+                thumbnail_location = ""
                 if img is not None:
-                    thumbnail_location = f'thumbnail_{doc_id}.png'
+                    thumbnail_location = f"thumbnail_{doc_id}.png"
                     STORAGE_IMPL.put(kb.id, thumbnail_location, img)

                 doc = {
                     "name": filename,
                     "location": location,
                     "size": len(blob),
-                    "thumbnail": thumbnail_location
+                    "thumbnail": thumbnail_location,
                 }
                 DocumentService.insert(doc)


     @staticmethod
     def parse_docs(file_objs, user_id):
-        from rag.app import presentation, picture, naive, audio, email
+        from rag.app import audio, email, naive, picture, presentation

         def dummy(prog=None, msg=""):
             pass

-        FACTORY = {
-            ParserType.PRESENTATION.value: presentation,
-            ParserType.PICTURE.value: picture,
-            ParserType.AUDIO.value: audio,
-            ParserType.EMAIL.value: email
-        }
+        FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
         parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
         exe = ThreadPoolExecutor(max_workers=12)
         threads = []
         for file in file_objs:
-            kwargs = {
-                "lang": "English",
-                "callback": dummy,
-                "parser_config": parser_config,
-                "from_page": 0,
-                "to_page": 100000,
-                "tenant_id": user_id
-            }
+            kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": user_id}
             filetype = filename_type(file.filename)
             blob = file.read()
             threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
             return ParserType.PRESENTATION.value
         if re.search(r"\.(eml)$", filename):
             return ParserType.EMAIL.value
         return default
```


api/utils/file_utils.py (+71, -30)

```diff
 import json
 import os
 import re
+import shutil
+import subprocess
 import sys
+import tempfile
 import threading
 from io import BytesIO

 import pdfplumber
-from PIL import Image
 from cachetools import LRUCache, cached
+from PIL import Image
 from ruamel.yaml import YAML

-from api.db import FileType
 from api.constants import IMG_BASE64_PREFIX
+from api.db import FileType

 PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
 RAG_BASE = os.getenv("RAG_BASE")


 def get_home_cache_dir():
-    dir = os.path.join(os.path.expanduser('~'), ".ragflow")
+    dir = os.path.join(os.path.expanduser("~"), ".ragflow")
     try:
         os.mkdir(dir)
     except OSError:
         with open(json_conf_path) as f:
             return json.load(f)
     except BaseException:
-        raise EnvironmentError(
-            "loading json file config from '{}' failed!".format(json_conf_path)
-        )
+        raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))


 def dump_json_conf(config_data, conf_path):
         with open(json_conf_path, "w") as f:
             json.dump(config_data, f, indent=4)
     except BaseException:
-        raise EnvironmentError(
-            "loading json file config from '{}' failed!".format(json_conf_path)
-        )
+        raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))


 def load_json_conf_real_time(conf_path):
         with open(json_conf_path) as f:
             return json.load(f)
     except BaseException:
-        raise EnvironmentError(
-            "loading json file config from '{}' failed!".format(json_conf_path)
-        )
+        raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))


 def load_yaml_conf(conf_path):
     conf_path = os.path.join(get_project_base_directory(), conf_path)
     try:
         with open(conf_path) as f:
-            yaml = YAML(typ='safe', pure=True)
+            yaml = YAML(typ="safe", pure=True)
             return yaml.load(f)
     except Exception as e:
-        raise EnvironmentError(
-            "loading yaml file config from {} failed:".format(conf_path), e
-        )
+        raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e)


 def rewrite_yaml_conf(conf_path, config):
         yaml = YAML(typ="safe")
             yaml.dump(config, f)
     except Exception as e:
-        raise EnvironmentError(
-            "rewrite yaml file config {} failed:".format(conf_path), e
-        )
+        raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e)


 def rewrite_json_file(filepath, json_data):
-    with open(filepath, "w", encoding='utf-8') as f:
+    with open(filepath, "w", encoding="utf-8") as f:
         json.dump(json_data, f, indent=4, separators=(",", ": "))
         f.close()


     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(
-            r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
+    if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
         return FileType.DOC.value

-    if re.match(
-            r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
+    if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         return FileType.AURAL.value

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):

     return FileType.OTHER.value


 def thumbnail_img(filename, blob):
     """
     MySQL LongText max length is 65535
     if re.match(r".*\.pdf$", filename):
         with sys.modules[LOCK_KEY_pdfplumber]:
             pdf = pdfplumber.open(BytesIO(blob))

         buffered = BytesIO()
         resolution = 32
         img = None
         return buffered.getvalue()

     elif re.match(r".*\.(ppt|pptx)$", filename):
-        import aspose.slides as slides
         import aspose.pydrawing as drawing
+        import aspose.slides as slides

         try:
             with slides.Presentation(BytesIO(blob)) as presentation:
                 buffered = BytesIO()
                 img = None
                 for _ in range(10):
                     # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float
-                    presentation.slides[0].get_thumbnail(scale, scale).save(
-                        buffered, drawing.imaging.ImageFormat.png)
+                    presentation.slides[0].get_thumbnail(scale, scale).save(buffered, drawing.imaging.ImageFormat.png)
                     img = buffered.getvalue()
                     if len(img) >= 64000:
                         scale = scale / 2.0

 def thumbnail(filename, blob):
     img = thumbnail_img(filename, blob)
     if img is not None:
-        return IMG_BASE64_PREFIX + \
-            base64.b64encode(img).decode("utf-8")
+        return IMG_BASE64_PREFIX + base64.b64encode(img).decode("utf-8")
     else:
-        return ''
+        return ""


 def traversal_files(base):
         for f in fs:
             fullname = os.path.join(root, f)
             yield fullname


+def repair_pdf_with_ghostscript(input_bytes):
+    if shutil.which("gs") is None:
+        return input_bytes
+
+    with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_in, tempfile.NamedTemporaryFile(suffix=".pdf") as temp_out:
+        temp_in.write(input_bytes)
+        temp_in.flush()
+
+        cmd = [
+            "gs",
+            "-o",
+            temp_out.name,
+            "-sDEVICE=pdfwrite",
+            "-dPDFSETTINGS=/prepress",
+            temp_in.name,
+        ]
+        try:
+            proc = subprocess.run(cmd, capture_output=True, text=True)
+            if proc.returncode != 0:
+                return input_bytes
+        except Exception:
+            return input_bytes
+
+        temp_out.seek(0)
+        repaired_bytes = temp_out.read()
+
+    return repaired_bytes
+
+
+def read_potential_broken_pdf(blob):
+    def try_open(blob):
+        try:
+            with pdfplumber.open(BytesIO(blob)) as pdf:
+                if pdf.pages:
+                    return True
+        except Exception:
+            return False
+        return False
+
+    if try_open(blob):
+        return blob
+
+    repaired = repair_pdf_with_ghostscript(blob)
+    if try_open(repaired):
+        return repaired
+
+    return blob
```
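
A rough way to exercise the fallback path is to deliberately damage a known-good PDF and run it through the helper. Whether Ghostscript can actually recover the file depends on the kind of damage, so treat this only as an illustrative sketch (paths are hypothetical):

```python
from io import BytesIO

import pdfplumber

from api.utils.file_utils import read_potential_broken_pdf

with open("sample.pdf", "rb") as fh:       # hypothetical local file
    good = fh.read()

broken = good[: len(good) - 1024]          # crude simulation of a truncated upload
result = read_potential_broken_pdf(broken)

try:
    with pdfplumber.open(BytesIO(result)) as pdf:
        print("readable again, pages:", len(pdf.pages))
except Exception:
    print("still unreadable; the original (broken) bytes were returned")
```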
