您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

file_service.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import hashlib
  2. import os
  3. import uuid
  4. from typing import Any, Literal, Union
  5. from flask_login import current_user
  6. from sqlalchemy import Engine
  7. from sqlalchemy.orm import sessionmaker
  8. from werkzeug.exceptions import NotFound
  9. from configs import dify_config
  10. from constants import (
  11. AUDIO_EXTENSIONS,
  12. DOCUMENT_EXTENSIONS,
  13. IMAGE_EXTENSIONS,
  14. VIDEO_EXTENSIONS,
  15. )
  16. from core.file import helpers as file_helpers
  17. from core.rag.extractor.extract_processor import ExtractProcessor
  18. from extensions.ext_storage import storage
  19. from libs.datetime_utils import naive_utc_now
  20. from libs.helper import extract_tenant_id
  21. from models.account import Account
  22. from models.enums import CreatorUserRole
  23. from models.model import EndUser, UploadFile
  24. from .errors.file import FileTooLargeError, UnsupportedFileTypeError
  25. PREVIEW_WORDS_LIMIT = 3000
  26. class FileService:
  27. _session_maker: sessionmaker
  28. def __init__(self, session_factory: sessionmaker | Engine | None = None):
  29. if isinstance(session_factory, Engine):
  30. self._session_maker = sessionmaker(bind=session_factory)
  31. elif isinstance(session_factory, sessionmaker):
  32. self._session_maker = session_factory
  33. else:
  34. raise AssertionError("must be a sessionmaker or an Engine.")
  35. def upload_file(
  36. self,
  37. *,
  38. filename: str,
  39. content: bytes,
  40. mimetype: str,
  41. user: Union[Account, EndUser, Any],
  42. source: Literal["datasets"] | None = None,
  43. source_url: str = "",
  44. ) -> UploadFile:
  45. # get file extension
  46. extension = os.path.splitext(filename)[1].lstrip(".").lower()
  47. # check if filename contains invalid characters
  48. if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
  49. raise ValueError("Filename contains invalid characters")
  50. if len(filename) > 200:
  51. filename = filename.split(".")[0][:200] + "." + extension
  52. if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
  53. raise UnsupportedFileTypeError()
  54. # get file size
  55. file_size = len(content)
  56. # check if the file size is exceeded
  57. if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
  58. raise FileTooLargeError
  59. # generate file key
  60. file_uuid = str(uuid.uuid4())
  61. current_tenant_id = extract_tenant_id(user)
  62. file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
  63. # save file to storage
  64. storage.save(file_key, content)
  65. # save file to db
  66. upload_file = UploadFile(
  67. tenant_id=current_tenant_id or "",
  68. storage_type=dify_config.STORAGE_TYPE,
  69. key=file_key,
  70. name=filename,
  71. size=file_size,
  72. extension=extension,
  73. mime_type=mimetype,
  74. created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
  75. created_by=user.id,
  76. created_at=naive_utc_now(),
  77. used=False,
  78. hash=hashlib.sha3_256(content).hexdigest(),
  79. source_url=source_url,
  80. )
  81. # The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
  82. # We can directly generate the `source_url` here before committing.
  83. if not upload_file.source_url:
  84. upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
  85. with self._session_maker(expire_on_commit=False) as session:
  86. session.add(upload_file)
  87. session.commit()
  88. return upload_file
  89. @staticmethod
  90. def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
  91. if extension in IMAGE_EXTENSIONS:
  92. file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
  93. elif extension in VIDEO_EXTENSIONS:
  94. file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
  95. elif extension in AUDIO_EXTENSIONS:
  96. file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
  97. else:
  98. file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
  99. return file_size <= file_size_limit
  100. def upload_text(self, text: str, text_name: str) -> UploadFile:
  101. if len(text_name) > 200:
  102. text_name = text_name[:200]
  103. # user uuid as file name
  104. file_uuid = str(uuid.uuid4())
  105. file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
  106. # save file to storage
  107. storage.save(file_key, text.encode("utf-8"))
  108. # save file to db
  109. upload_file = UploadFile(
  110. tenant_id=current_user.current_tenant_id,
  111. storage_type=dify_config.STORAGE_TYPE,
  112. key=file_key,
  113. name=text_name,
  114. size=len(text),
  115. extension="txt",
  116. mime_type="text/plain",
  117. created_by=current_user.id,
  118. created_by_role=CreatorUserRole.ACCOUNT,
  119. created_at=naive_utc_now(),
  120. used=True,
  121. used_by=current_user.id,
  122. used_at=naive_utc_now(),
  123. )
  124. with self._session_maker(expire_on_commit=False) as session:
  125. session.add(upload_file)
  126. session.commit()
  127. return upload_file
  128. def get_file_preview(self, file_id: str):
  129. with self._session_maker(expire_on_commit=False) as session:
  130. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  131. if not upload_file:
  132. raise NotFound("File not found")
  133. # extract text from file
  134. extension = upload_file.extension
  135. if extension.lower() not in DOCUMENT_EXTENSIONS:
  136. raise UnsupportedFileTypeError()
  137. text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
  138. text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
  139. return text
  140. def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
  141. result = file_helpers.verify_image_signature(
  142. upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
  143. )
  144. if not result:
  145. raise NotFound("File not found or signature is invalid")
  146. with self._session_maker(expire_on_commit=False) as session:
  147. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  148. if not upload_file:
  149. raise NotFound("File not found or signature is invalid")
  150. # extract text from file
  151. extension = upload_file.extension
  152. if extension.lower() not in IMAGE_EXTENSIONS:
  153. raise UnsupportedFileTypeError()
  154. generator = storage.load(upload_file.key, stream=True)
  155. return generator, upload_file.mime_type
  156. def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
  157. result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
  158. if not result:
  159. raise NotFound("File not found or signature is invalid")
  160. with self._session_maker(expire_on_commit=False) as session:
  161. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  162. if not upload_file:
  163. raise NotFound("File not found or signature is invalid")
  164. generator = storage.load(upload_file.key, stream=True)
  165. return generator, upload_file
  166. def get_public_image_preview(self, file_id: str):
  167. with self._session_maker(expire_on_commit=False) as session:
  168. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  169. if not upload_file:
  170. raise NotFound("File not found or signature is invalid")
  171. # extract text from file
  172. extension = upload_file.extension
  173. if extension.lower() not in IMAGE_EXTENSIONS:
  174. raise UnsupportedFileTypeError()
  175. generator = storage.load(upload_file.key)
  176. return generator, upload_file.mime_type