소스 검색

feat(file_factory): Standardize custom file type into known types (#11028)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/0.12.0
-LAN- 11 달 전
부모
커밋
8565c18e84
No account linked to committer's email address
1개의 변경된 파일74개의 추가작업 그리고 26개의 파일을 삭제
  1. 74
    26
      api/factories/file_factory.py

+ 74
- 26
api/factories/file_factory.py 파일 보기

import mimetypes import mimetypes
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, cast


import httpx import httpx
from sqlalchemy import select from sqlalchemy import select


from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
transfer_method=transfer_method, transfer_method=transfer_method,
) )


if not _is_file_valid_with_config(file=file, config=config):
if not _is_file_valid_with_config(
input_file_type=mapping.get("type", FileType.CUSTOM),
file_extension=file.extension,
file_transfer_method=file.transfer_method,
config=config,
):
raise ValueError(f"File validation failed for file: {file.filename}") raise ValueError(f"File validation failed for file: {file.filename}")


return file return file
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
) -> File: ) -> File:
file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where( stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"), UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id, UploadFile.tenant_id == tenant_id,
) )


row = db.session.scalar(stmt) row = db.session.scalar(stmt)

if row is None: if row is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")


file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)

return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=row.name, filename=row.name,
mime_type, filename, file_size = _get_remote_file_info(url) mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"


file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)

return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=filename, filename=filename,
tenant_id=tenant_id, tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
type=file_type,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=url, remote_url=url,
mime_type=mime_type, mime_type=mime_type,
mime_type = mimetypes.guess_type(filename)[0] or "" mime_type = mimetypes.guess_type(filename)[0] or ""


resp = ssrf_proxy.head(url, follow_redirects=True) resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK: if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"): if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"')) filename = str(content_disposition.split("filename=")[-1].strip('"'))
return mime_type, filename, file_size return mime_type, filename, file_size




def _get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type


def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")


extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)


return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
) )




def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
def _is_file_valid_with_config(
*,
input_file_type: str,
file_extension: str,
file_transfer_method: FileTransferMethod,
config: FileUploadConfig,
) -> bool:
if (
config.allowed_file_types
and input_file_type not in config.allowed_file_types
and input_file_type != FileType.CUSTOM
):
return False return False


if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions:
if config.allowed_file_extensions and file_extension not in config.allowed_file_extensions:
return False return False


if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods:
if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
return False return False


if file.type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
if input_file_type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods:
return False return False


return True return True


def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
"""
If custom type, try to guess the file type by extension and mime_type.
"""
if file_type != FileType.CUSTOM:
return FileType(file_type)
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)
if guessed_type is None and mime_type:
guessed_type = _get_file_type_by_mimetype(mime_type)
return guessed_type or FileType.CUSTOM


def _get_file_type_by_extension(extension: str) -> FileType | None:
extension = extension.lstrip(".")
if extension in IMAGE_EXTENSIONS:
return FileType.IMAGE
elif extension in VIDEO_EXTENSIONS:
return FileType.VIDEO
elif extension in AUDIO_EXTENSIONS:
return FileType.AUDIO
elif extension in DOCUMENT_EXTENSIONS:
return FileType.DOCUMENT


def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type

Loading…
취소
저장