You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

file_service.py 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import hashlib
  2. import os
  3. import uuid
  4. from typing import Literal, Union
  5. from sqlalchemy import Engine
  6. from sqlalchemy.orm import sessionmaker
  7. from werkzeug.exceptions import NotFound
  8. from configs import dify_config
  9. from constants import (
  10. AUDIO_EXTENSIONS,
  11. DOCUMENT_EXTENSIONS,
  12. IMAGE_EXTENSIONS,
  13. VIDEO_EXTENSIONS,
  14. )
  15. from core.file import helpers as file_helpers
  16. from core.rag.extractor.extract_processor import ExtractProcessor
  17. from extensions.ext_storage import storage
  18. from libs.datetime_utils import naive_utc_now
  19. from libs.helper import extract_tenant_id
  20. from models.account import Account
  21. from models.enums import CreatorUserRole
  22. from models.model import EndUser, UploadFile
  23. from .errors.file import FileTooLargeError, UnsupportedFileTypeError
# Upper bound on the text returned by FileService.get_file_preview. It is
# applied as a plain string slice, so despite the name it limits characters
# rather than words.
PREVIEW_WORDS_LIMIT = 3000
  25. class FileService:
  26. _session_maker: sessionmaker
  27. def __init__(self, session_factory: sessionmaker | Engine | None = None):
  28. if isinstance(session_factory, Engine):
  29. self._session_maker = sessionmaker(bind=session_factory)
  30. elif isinstance(session_factory, sessionmaker):
  31. self._session_maker = session_factory
  32. else:
  33. raise AssertionError("must be a sessionmaker or an Engine.")
  34. def upload_file(
  35. self,
  36. *,
  37. filename: str,
  38. content: bytes,
  39. mimetype: str,
  40. user: Union[Account, EndUser],
  41. source: Literal["datasets"] | None = None,
  42. source_url: str = "",
  43. ) -> UploadFile:
  44. # get file extension
  45. extension = os.path.splitext(filename)[1].lstrip(".").lower()
  46. # check if filename contains invalid characters
  47. if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
  48. raise ValueError("Filename contains invalid characters")
  49. if len(filename) > 200:
  50. filename = filename.split(".")[0][:200] + "." + extension
  51. if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
  52. raise UnsupportedFileTypeError()
  53. # get file size
  54. file_size = len(content)
  55. # check if the file size is exceeded
  56. if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
  57. raise FileTooLargeError
  58. # generate file key
  59. file_uuid = str(uuid.uuid4())
  60. current_tenant_id = extract_tenant_id(user)
  61. file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
  62. # save file to storage
  63. storage.save(file_key, content)
  64. # save file to db
  65. upload_file = UploadFile(
  66. tenant_id=current_tenant_id or "",
  67. storage_type=dify_config.STORAGE_TYPE,
  68. key=file_key,
  69. name=filename,
  70. size=file_size,
  71. extension=extension,
  72. mime_type=mimetype,
  73. created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
  74. created_by=user.id,
  75. created_at=naive_utc_now(),
  76. used=False,
  77. hash=hashlib.sha3_256(content).hexdigest(),
  78. source_url=source_url,
  79. )
  80. # The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
  81. # We can directly generate the `source_url` here before committing.
  82. if not upload_file.source_url:
  83. upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
  84. with self._session_maker(expire_on_commit=False) as session:
  85. session.add(upload_file)
  86. session.commit()
  87. return upload_file
  88. @staticmethod
  89. def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
  90. if extension in IMAGE_EXTENSIONS:
  91. file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
  92. elif extension in VIDEO_EXTENSIONS:
  93. file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
  94. elif extension in AUDIO_EXTENSIONS:
  95. file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
  96. else:
  97. file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
  98. return file_size <= file_size_limit
  99. def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
  100. if len(text_name) > 200:
  101. text_name = text_name[:200]
  102. # user uuid as file name
  103. file_uuid = str(uuid.uuid4())
  104. file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"
  105. # save file to storage
  106. storage.save(file_key, text.encode("utf-8"))
  107. # save file to db
  108. upload_file = UploadFile(
  109. tenant_id=tenant_id,
  110. storage_type=dify_config.STORAGE_TYPE,
  111. key=file_key,
  112. name=text_name,
  113. size=len(text),
  114. extension="txt",
  115. mime_type="text/plain",
  116. created_by=user_id,
  117. created_by_role=CreatorUserRole.ACCOUNT,
  118. created_at=naive_utc_now(),
  119. used=True,
  120. used_by=user_id,
  121. used_at=naive_utc_now(),
  122. )
  123. with self._session_maker(expire_on_commit=False) as session:
  124. session.add(upload_file)
  125. session.commit()
  126. return upload_file
  127. def get_file_preview(self, file_id: str):
  128. with self._session_maker(expire_on_commit=False) as session:
  129. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  130. if not upload_file:
  131. raise NotFound("File not found")
  132. # extract text from file
  133. extension = upload_file.extension
  134. if extension.lower() not in DOCUMENT_EXTENSIONS:
  135. raise UnsupportedFileTypeError()
  136. text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
  137. text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
  138. return text
  139. def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
  140. result = file_helpers.verify_image_signature(
  141. upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
  142. )
  143. if not result:
  144. raise NotFound("File not found or signature is invalid")
  145. with self._session_maker(expire_on_commit=False) as session:
  146. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  147. if not upload_file:
  148. raise NotFound("File not found or signature is invalid")
  149. # extract text from file
  150. extension = upload_file.extension
  151. if extension.lower() not in IMAGE_EXTENSIONS:
  152. raise UnsupportedFileTypeError()
  153. generator = storage.load(upload_file.key, stream=True)
  154. return generator, upload_file.mime_type
  155. def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
  156. result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
  157. if not result:
  158. raise NotFound("File not found or signature is invalid")
  159. with self._session_maker(expire_on_commit=False) as session:
  160. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  161. if not upload_file:
  162. raise NotFound("File not found or signature is invalid")
  163. generator = storage.load(upload_file.key, stream=True)
  164. return generator, upload_file
  165. def get_public_image_preview(self, file_id: str):
  166. with self._session_maker(expire_on_commit=False) as session:
  167. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  168. if not upload_file:
  169. raise NotFound("File not found or signature is invalid")
  170. # extract text from file
  171. extension = upload_file.extension
  172. if extension.lower() not in IMAGE_EXTENSIONS:
  173. raise UnsupportedFileTypeError()
  174. generator = storage.load(upload_file.key)
  175. return generator, upload_file.mime_type
  176. def get_file_content(self, file_id: str) -> str:
  177. with self._session_maker(expire_on_commit=False) as session:
  178. upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
  179. if not upload_file:
  180. raise NotFound("File not found")
  181. content = storage.load(upload_file.key)
  182. return content.decode("utf-8")
  183. def delete_file(self, file_id: str):
  184. with self._session_maker(expire_on_commit=False) as session:
  185. upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
  186. if not upload_file:
  187. return
  188. storage.delete(upload_file.key)
  189. session.delete(upload_file)
  190. session.commit()