Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

file_factory.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. import mimetypes
  2. import os
  3. import urllib.parse
  4. import uuid
  5. from collections.abc import Callable, Mapping, Sequence
  6. from typing import Any
  7. import httpx
  8. from sqlalchemy import select
  9. from sqlalchemy.orm import Session
  10. from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
  11. from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
  12. from core.helper import ssrf_proxy
  13. from extensions.ext_database import db
  14. from models import MessageFile, ToolFile, UploadFile
  15. def build_from_message_files(
  16. *,
  17. message_files: Sequence["MessageFile"],
  18. tenant_id: str,
  19. config: FileUploadConfig,
  20. ) -> Sequence[File]:
  21. results = [
  22. build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
  23. for file in message_files
  24. if file.belongs_to != FileBelongsTo.ASSISTANT
  25. ]
  26. return results
  27. def build_from_message_file(
  28. *,
  29. message_file: "MessageFile",
  30. tenant_id: str,
  31. config: FileUploadConfig,
  32. ):
  33. mapping = {
  34. "transfer_method": message_file.transfer_method,
  35. "url": message_file.url,
  36. "id": message_file.id,
  37. "type": message_file.type,
  38. }
  39. # Set the correct ID field based on transfer method
  40. if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
  41. mapping["tool_file_id"] = message_file.upload_file_id
  42. else:
  43. mapping["upload_file_id"] = message_file.upload_file_id
  44. return build_from_mapping(
  45. mapping=mapping,
  46. tenant_id=tenant_id,
  47. config=config,
  48. )
  49. def build_from_mapping(
  50. *,
  51. mapping: Mapping[str, Any],
  52. tenant_id: str,
  53. config: FileUploadConfig | None = None,
  54. strict_type_validation: bool = False,
  55. ) -> File:
  56. transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
  57. build_functions: dict[FileTransferMethod, Callable] = {
  58. FileTransferMethod.LOCAL_FILE: _build_from_local_file,
  59. FileTransferMethod.REMOTE_URL: _build_from_remote_url,
  60. FileTransferMethod.TOOL_FILE: _build_from_tool_file,
  61. FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
  62. }
  63. build_func = build_functions.get(transfer_method)
  64. if not build_func:
  65. raise ValueError(f"Invalid file transfer method: {transfer_method}")
  66. file: File = build_func(
  67. mapping=mapping,
  68. tenant_id=tenant_id,
  69. transfer_method=transfer_method,
  70. strict_type_validation=strict_type_validation,
  71. )
  72. if config and not _is_file_valid_with_config(
  73. input_file_type=mapping.get("type", FileType.CUSTOM),
  74. file_extension=file.extension or "",
  75. file_transfer_method=file.transfer_method,
  76. config=config,
  77. ):
  78. raise ValueError(f"File validation failed for file: {file.filename}")
  79. return file
  80. def build_from_mappings(
  81. *,
  82. mappings: Sequence[Mapping[str, Any]],
  83. config: FileUploadConfig | None = None,
  84. tenant_id: str,
  85. strict_type_validation: bool = False,
  86. ) -> Sequence[File]:
  87. # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
  88. # Implement batch processing to reduce database load when handling multiple files.
  89. files = [
  90. build_from_mapping(
  91. mapping=mapping,
  92. tenant_id=tenant_id,
  93. config=config,
  94. strict_type_validation=strict_type_validation,
  95. )
  96. for mapping in mappings
  97. ]
  98. if (
  99. config
  100. # If image config is set.
  101. and config.image_config
  102. # And the number of image files exceeds the maximum limit
  103. and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
  104. ):
  105. raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
  106. if config and config.number_limits and len(files) > config.number_limits:
  107. raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
  108. return files
  109. def _build_from_local_file(
  110. *,
  111. mapping: Mapping[str, Any],
  112. tenant_id: str,
  113. transfer_method: FileTransferMethod,
  114. strict_type_validation: bool = False,
  115. ) -> File:
  116. upload_file_id = mapping.get("upload_file_id")
  117. if not upload_file_id:
  118. raise ValueError("Invalid upload file id")
  119. # check if upload_file_id is a valid uuid
  120. try:
  121. uuid.UUID(upload_file_id)
  122. except ValueError:
  123. raise ValueError("Invalid upload file id format")
  124. stmt = select(UploadFile).where(
  125. UploadFile.id == upload_file_id,
  126. UploadFile.tenant_id == tenant_id,
  127. )
  128. row = db.session.scalar(stmt)
  129. if row is None:
  130. raise ValueError("Invalid upload file")
  131. detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
  132. specified_type = mapping.get("type", "custom")
  133. if strict_type_validation and detected_file_type.value != specified_type:
  134. raise ValueError("Detected file type does not match the specified type. Please verify the file.")
  135. file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
  136. return File(
  137. id=mapping.get("id"),
  138. filename=row.name,
  139. extension="." + row.extension,
  140. mime_type=row.mime_type,
  141. tenant_id=tenant_id,
  142. type=file_type,
  143. transfer_method=transfer_method,
  144. remote_url=row.source_url,
  145. related_id=mapping.get("upload_file_id"),
  146. size=row.size,
  147. storage_key=row.key,
  148. )
  149. def _build_from_remote_url(
  150. *,
  151. mapping: Mapping[str, Any],
  152. tenant_id: str,
  153. transfer_method: FileTransferMethod,
  154. strict_type_validation: bool = False,
  155. ) -> File:
  156. upload_file_id = mapping.get("upload_file_id")
  157. if upload_file_id:
  158. try:
  159. uuid.UUID(upload_file_id)
  160. except ValueError:
  161. raise ValueError("Invalid upload file id format")
  162. stmt = select(UploadFile).where(
  163. UploadFile.id == upload_file_id,
  164. UploadFile.tenant_id == tenant_id,
  165. )
  166. upload_file = db.session.scalar(stmt)
  167. if upload_file is None:
  168. raise ValueError("Invalid upload file")
  169. detected_file_type = _standardize_file_type(
  170. extension="." + upload_file.extension, mime_type=upload_file.mime_type
  171. )
  172. specified_type = mapping.get("type")
  173. if strict_type_validation and specified_type and detected_file_type.value != specified_type:
  174. raise ValueError("Detected file type does not match the specified type. Please verify the file.")
  175. file_type = (
  176. FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
  177. )
  178. return File(
  179. id=mapping.get("id"),
  180. filename=upload_file.name,
  181. extension="." + upload_file.extension,
  182. mime_type=upload_file.mime_type,
  183. tenant_id=tenant_id,
  184. type=file_type,
  185. transfer_method=transfer_method,
  186. remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
  187. related_id=mapping.get("upload_file_id"),
  188. size=upload_file.size,
  189. storage_key=upload_file.key,
  190. )
  191. url = mapping.get("url") or mapping.get("remote_url")
  192. if not url:
  193. raise ValueError("Invalid file url")
  194. mime_type, filename, file_size = _get_remote_file_info(url)
  195. extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
  196. file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
  197. if file_type.value != mapping.get("type", "custom"):
  198. raise ValueError("Detected file type does not match the specified type. Please verify the file.")
  199. return File(
  200. id=mapping.get("id"),
  201. filename=filename,
  202. tenant_id=tenant_id,
  203. type=file_type,
  204. transfer_method=transfer_method,
  205. remote_url=url,
  206. mime_type=mime_type,
  207. extension=extension,
  208. size=file_size,
  209. storage_key="",
  210. )
  211. def _get_remote_file_info(url: str):
  212. file_size = -1
  213. parsed_url = urllib.parse.urlparse(url)
  214. url_path = parsed_url.path
  215. filename = os.path.basename(url_path)
  216. # Initialize mime_type from filename as fallback
  217. mime_type, _ = mimetypes.guess_type(filename)
  218. if mime_type is None:
  219. mime_type = ""
  220. resp = ssrf_proxy.head(url, follow_redirects=True)
  221. if resp.status_code == httpx.codes.OK:
  222. if content_disposition := resp.headers.get("Content-Disposition"):
  223. filename = str(content_disposition.split("filename=")[-1].strip('"'))
  224. # Re-guess mime_type from updated filename
  225. mime_type, _ = mimetypes.guess_type(filename)
  226. if mime_type is None:
  227. mime_type = ""
  228. file_size = int(resp.headers.get("Content-Length", file_size))
  229. # Fallback to Content-Type header if mime_type is still empty
  230. if not mime_type:
  231. mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
  232. return mime_type, filename, file_size
  233. def _build_from_tool_file(
  234. *,
  235. mapping: Mapping[str, Any],
  236. tenant_id: str,
  237. transfer_method: FileTransferMethod,
  238. strict_type_validation: bool = False,
  239. ) -> File:
  240. tool_file = db.session.scalar(
  241. select(ToolFile).where(
  242. ToolFile.id == mapping.get("tool_file_id"),
  243. ToolFile.tenant_id == tenant_id,
  244. )
  245. )
  246. if tool_file is None:
  247. raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
  248. extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
  249. detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
  250. specified_type = mapping.get("type")
  251. if strict_type_validation and specified_type and detected_file_type.value != specified_type:
  252. raise ValueError("Detected file type does not match the specified type. Please verify the file.")
  253. file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
  254. return File(
  255. id=mapping.get("id"),
  256. tenant_id=tenant_id,
  257. filename=tool_file.name,
  258. type=file_type,
  259. transfer_method=transfer_method,
  260. remote_url=tool_file.original_url,
  261. related_id=tool_file.id,
  262. extension=extension,
  263. mime_type=tool_file.mimetype,
  264. size=tool_file.size,
  265. storage_key=tool_file.file_key,
  266. )
  267. def _build_from_datasource_file(
  268. *,
  269. mapping: Mapping[str, Any],
  270. tenant_id: str,
  271. transfer_method: FileTransferMethod,
  272. strict_type_validation: bool = False,
  273. ) -> File:
  274. datasource_file = (
  275. db.session.query(UploadFile)
  276. .filter(
  277. UploadFile.id == mapping.get("datasource_file_id"),
  278. UploadFile.tenant_id == tenant_id,
  279. )
  280. .first()
  281. )
  282. if datasource_file is None:
  283. raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
  284. extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
  285. detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
  286. specified_type = mapping.get("type")
  287. if strict_type_validation and specified_type and detected_file_type.value != specified_type:
  288. raise ValueError("Detected file type does not match the specified type. Please verify the file.")
  289. file_type = (
  290. FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
  291. )
  292. return File(
  293. id=mapping.get("datasource_file_id"),
  294. tenant_id=tenant_id,
  295. filename=datasource_file.name,
  296. type=file_type,
  297. transfer_method=FileTransferMethod.TOOL_FILE,
  298. remote_url=datasource_file.source_url,
  299. related_id=datasource_file.id,
  300. extension=extension,
  301. mime_type=datasource_file.mime_type,
  302. size=datasource_file.size,
  303. storage_key=datasource_file.key,
  304. url=datasource_file.source_url,
  305. )
  306. def _is_file_valid_with_config(
  307. *,
  308. input_file_type: str,
  309. file_extension: str,
  310. file_transfer_method: FileTransferMethod,
  311. config: FileUploadConfig,
  312. ) -> bool:
  313. # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
  314. # These are internally generated and should bypass user upload restrictions
  315. if file_transfer_method == FileTransferMethod.TOOL_FILE:
  316. return True
  317. if (
  318. config.allowed_file_types
  319. and input_file_type not in config.allowed_file_types
  320. and input_file_type != FileType.CUSTOM
  321. ):
  322. return False
  323. if (
  324. input_file_type == FileType.CUSTOM
  325. and config.allowed_file_extensions is not None
  326. and file_extension not in config.allowed_file_extensions
  327. ):
  328. return False
  329. if input_file_type == FileType.IMAGE:
  330. if (
  331. config.image_config
  332. and config.image_config.transfer_methods
  333. and file_transfer_method not in config.image_config.transfer_methods
  334. ):
  335. return False
  336. elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
  337. return False
  338. return True
  339. def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
  340. """
  341. Infer the possible actual type of the file based on the extension and mime_type
  342. """
  343. guessed_type = None
  344. if extension:
  345. guessed_type = _get_file_type_by_extension(extension)
  346. if guessed_type is None and mime_type:
  347. guessed_type = _get_file_type_by_mimetype(mime_type)
  348. return guessed_type or FileType.CUSTOM
  349. def _get_file_type_by_extension(extension: str) -> FileType | None:
  350. extension = extension.lstrip(".")
  351. if extension in IMAGE_EXTENSIONS:
  352. return FileType.IMAGE
  353. elif extension in VIDEO_EXTENSIONS:
  354. return FileType.VIDEO
  355. elif extension in AUDIO_EXTENSIONS:
  356. return FileType.AUDIO
  357. elif extension in DOCUMENT_EXTENSIONS:
  358. return FileType.DOCUMENT
  359. return None
  360. def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
  361. if "image" in mime_type:
  362. file_type = FileType.IMAGE
  363. elif "video" in mime_type:
  364. file_type = FileType.VIDEO
  365. elif "audio" in mime_type:
  366. file_type = FileType.AUDIO
  367. elif "text" in mime_type or "pdf" in mime_type:
  368. file_type = FileType.DOCUMENT
  369. else:
  370. file_type = FileType.CUSTOM
  371. return file_type
  372. def get_file_type_by_mime_type(mime_type: str) -> FileType:
  373. return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
  374. class StorageKeyLoader:
  375. """FileKeyLoader load the storage key from database for a list of files.
  376. This loader is batched, the database query count is constant regardless of the input size.
  377. """
  378. def __init__(self, session: Session, tenant_id: str) -> None:
  379. self._session = session
  380. self._tenant_id = tenant_id
  381. def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
  382. stmt = select(UploadFile).where(
  383. UploadFile.id.in_(upload_file_ids),
  384. UploadFile.tenant_id == self._tenant_id,
  385. )
  386. return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
  387. def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
  388. stmt = select(ToolFile).where(
  389. ToolFile.id.in_(tool_file_ids),
  390. ToolFile.tenant_id == self._tenant_id,
  391. )
  392. return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
  393. def load_storage_keys(self, files: Sequence[File]):
  394. """Loads storage keys for a sequence of files by retrieving the corresponding
  395. `UploadFile` or `ToolFile` records from the database based on their transfer method.
  396. This method doesn't modify the input sequence structure but updates the `_storage_key`
  397. property of each file object by extracting the relevant key from its database record.
  398. Performance note: This is a batched operation where database query count remains constant
  399. regardless of input size. However, for optimal performance, input sequences should contain
  400. fewer than 1000 files. For larger collections, split into smaller batches and process each
  401. batch separately.
  402. """
  403. upload_file_ids: list[uuid.UUID] = []
  404. tool_file_ids: list[uuid.UUID] = []
  405. for file in files:
  406. related_model_id = file.related_id
  407. if file.related_id is None:
  408. raise ValueError("file id should not be None.")
  409. if file.tenant_id != self._tenant_id:
  410. err_msg = (
  411. f"invalid file, expected tenant_id={self._tenant_id}, "
  412. f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
  413. )
  414. raise ValueError(err_msg)
  415. model_id = uuid.UUID(related_model_id)
  416. if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
  417. upload_file_ids.append(model_id)
  418. elif file.transfer_method == FileTransferMethod.TOOL_FILE:
  419. tool_file_ids.append(model_id)
  420. tool_files = self._load_tool_files(tool_file_ids)
  421. upload_files = self._load_upload_files(upload_file_ids)
  422. for file in files:
  423. model_id = uuid.UUID(file.related_id)
  424. if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
  425. upload_file_row = upload_files.get(model_id)
  426. if upload_file_row is None:
  427. raise ValueError(f"Upload file not found for id: {model_id}")
  428. file._storage_key = upload_file_row.key
  429. elif file.transfer_method == FileTransferMethod.TOOL_FILE:
  430. tool_file_row = tool_files.get(model_id)
  431. if tool_file_row is None:
  432. raise ValueError(f"Tool file not found for id: {model_id}")
  433. file._storage_key = tool_file_row.file_key