You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

file_service.py 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import hashlib
  2. import os
  3. import uuid
  4. from typing import Literal, Union
  5. from werkzeug.exceptions import NotFound
  6. from configs import dify_config
  7. from constants import (
  8. AUDIO_EXTENSIONS,
  9. DOCUMENT_EXTENSIONS,
  10. IMAGE_EXTENSIONS,
  11. VIDEO_EXTENSIONS,
  12. )
  13. from core.file import helpers as file_helpers
  14. from core.rag.extractor.extract_processor import ExtractProcessor
  15. from extensions.ext_database import db
  16. from extensions.ext_storage import storage
  17. from libs.datetime_utils import naive_utc_now
  18. from libs.helper import extract_tenant_id
  19. from libs.login import current_user
  20. from models.account import Account
  21. from models.enums import CreatorUserRole
  22. from models.model import EndUser, UploadFile
  23. from .errors.file import FileTooLargeError, UnsupportedFileTypeError
  # Upper bound on the length of the preview text returned by FileService.get_file_preview.
  # NOTE: despite the name, it is applied as a slice on the extracted string, i.e. it
  # limits characters rather than words.
  24. PREVIEW_WORDS_LIMIT = 3000
  25. class FileService:
  @staticmethod
  def upload_file(
      *,
      filename: str,
      content: bytes,
      mimetype: str,
      user: Union[Account, EndUser],
      source: Literal["datasets"] | None = None,
      source_url: str = "",
  ) -> UploadFile:
      """Validate an uploaded file, write it to storage and record it in the database.

      Args:
          filename: Original file name; must not contain path separators or the
              other disallowed characters checked below.
          content: Raw file bytes.
          mimetype: MIME type stored on the record as-is.
          user: The uploading account or end user; determines the tenant and
              the ``created_by_role`` of the record.
          source: When ``"datasets"``, only extensions in ``DOCUMENT_EXTENSIONS``
              are accepted.
          source_url: Optional URL to store on the record. When empty, a signed
              file URL is generated after the record has been committed.

      Returns:
          The persisted ``UploadFile`` record.

      Raises:
          ValueError: If ``filename`` contains a disallowed character.
          UnsupportedFileTypeError: If ``source`` is ``"datasets"`` and the
              extension is not a document extension.
          FileTooLargeError: If the content exceeds the limit for its extension.
      """
      # Lower-cased extension without the leading dot ("" when there is none).
      stem, raw_extension = os.path.splitext(filename)
      extension = raw_extension.lstrip(".").lower()

      # Reject filenames containing path separators or other disallowed characters.
      if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
          raise ValueError("Filename contains invalid characters")

      if len(filename) > 200:
          # Truncate only the stem (everything before the final extension).
          # Splitting on the first "." would throw away most of a dotted name
          # such as "report.v2.final.pdf", and unconditionally appending "."
          # would leave a trailing dot on names without an extension.
          filename = stem[:200] + ("." + extension if extension else "")

      if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
          raise UnsupportedFileTypeError()

      file_size = len(content)
      if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
          raise FileTooLargeError()

      # The storage key is built from a fresh UUID, not from the user-supplied name.
      file_uuid = str(uuid.uuid4())
      current_tenant_id = extract_tenant_id(user)
      file_key = f"upload_files/{current_tenant_id or ''}/{file_uuid}.{extension}"

      # Write the bytes to storage first, then record the metadata row.
      storage.save(file_key, content)

      upload_file = UploadFile(
          tenant_id=current_tenant_id or "",
          storage_type=dify_config.STORAGE_TYPE,
          key=file_key,
          name=filename,
          size=file_size,
          extension=extension,
          mime_type=mimetype,
          created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
          created_by=user.id,
          created_at=naive_utc_now(),
          used=False,
          hash=hashlib.sha3_256(content).hexdigest(),
          source_url=source_url,
      )
      db.session.add(upload_file)
      db.session.commit()

      # The signed URL is derived from the record id, which is read only after
      # the first commit (presumably assigned at insert time - confirm), hence
      # the second commit.
      if not upload_file.source_url:
          upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
          db.session.add(upload_file)
          db.session.commit()

      return upload_file
  @staticmethod
  def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
      """Return True when ``file_size`` does not exceed the limit for ``extension``.

      Image, video and audio extensions each have their own configured limit;
      every other extension falls back to the general upload limit. The
      configured value is multiplied by 1024 * 1024 before being compared with
      ``file_size``. The membership tests are exact, so ``extension`` has to be
      in the same form as the entries of the extension constants.
      """
      if extension in IMAGE_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
      elif extension in VIDEO_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT
      elif extension in AUDIO_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT
      else:
          configured_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT

      max_size = configured_limit * 1024 * 1024
      return file_size <= max_size
  @staticmethod
  def upload_text(text: str, text_name: str) -> UploadFile:
      """Store ``text`` as a UTF-8 ``.txt`` file for the current account and record it.

      The record is created already marked as used by the current user.

      Args:
          text: Text content to store.
          text_name: Display name for the record; truncated to 200 characters.

      Returns:
          The persisted ``UploadFile`` record.
      """
      assert isinstance(current_user, Account)
      assert current_user.current_tenant_id is not None

      if len(text_name) > 200:
          text_name = text_name[:200]

      # The storage key is built from a fresh UUID, not from the supplied name.
      file_uuid = str(uuid.uuid4())
      file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"

      # Encode once so the stored payload and the recorded size agree:
      # len(text) counts characters, which under-reports the size of
      # non-ASCII text once it is UTF-8 encoded.
      payload = text.encode("utf-8")
      storage.save(file_key, payload)

      upload_file = UploadFile(
          tenant_id=current_user.current_tenant_id,
          storage_type=dify_config.STORAGE_TYPE,
          key=file_key,
          name=text_name,
          size=len(payload),
          extension="txt",
          mime_type="text/plain",
          created_by=current_user.id,
          created_by_role=CreatorUserRole.ACCOUNT,
          created_at=naive_utc_now(),
          used=True,
          used_by=current_user.id,
          used_at=naive_utc_now(),
      )
      db.session.add(upload_file)
      db.session.commit()

      return upload_file
  @staticmethod
  def get_file_preview(file_id: str):
      """Return the leading part of the text extracted from a document upload.

      Raises ``NotFound`` when no ``UploadFile`` has the given id and
      ``UnsupportedFileTypeError`` when its extension is not a document
      extension. The result is at most ``PREVIEW_WORDS_LIMIT`` characters long
      and is an empty string when extraction yields nothing.
      """
      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found")

      if record.extension.lower() not in DOCUMENT_EXTENSIONS:
          raise UnsupportedFileTypeError()

      extracted = ExtractProcessor.load_from_upload_file(record, return_text=True)
      if not extracted:
          return ""
      return extracted[:PREVIEW_WORDS_LIMIT]
  @staticmethod
  def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
      """Load an image upload after verifying the signed request parameters.

      Returns a ``(data, mime_type)`` tuple, where ``data`` is whatever
      ``storage.load(..., stream=True)`` yields for the file's key.

      Raises ``NotFound`` when the signature check fails or no record matches,
      and ``UnsupportedFileTypeError`` when the record is not an image.
      """
      signature_ok = file_helpers.verify_image_signature(
          upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
      )
      if not signature_ok:
          raise NotFound("File not found or signature is invalid")

      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found or signature is invalid")

      # Only image uploads may be served through this endpoint.
      if record.extension.lower() not in IMAGE_EXTENSIONS:
          raise UnsupportedFileTypeError()

      return storage.load(record.key, stream=True), record.mime_type
  @staticmethod
  def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
      """Load an upload of any type after verifying the signed request parameters.

      Returns a ``(data, upload_file)`` tuple, where ``data`` is whatever
      ``storage.load(..., stream=True)`` yields for the file's key and
      ``upload_file`` is the matching ``UploadFile`` record.

      Raises ``NotFound`` when the signature check fails or no record matches.
      """
      signature_ok = file_helpers.verify_file_signature(
          upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
      )
      if not signature_ok:
          raise NotFound("File not found or signature is invalid")

      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found or signature is invalid")

      return storage.load(record.key, stream=True), record
  @staticmethod
  def get_public_image_preview(file_id: str):
      """Load an image upload by id without any signature verification.

      Returns a ``(data, mime_type)`` tuple, where ``data`` is whatever
      ``storage.load`` yields for the file's key.

      Raises ``NotFound`` when no record matches and
      ``UnsupportedFileTypeError`` when the record is not an image.
      """
      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          # NOTE(review): the message mentions a signature although none is checked here.
          raise NotFound("File not found or signature is invalid")

      # Only image uploads may be served through this endpoint.
      if record.extension.lower() not in IMAGE_EXTENSIONS:
          raise UnsupportedFileTypeError()

      # NOTE(review): unlike get_image_preview, stream=True is not passed here -
      # confirm whether the non-streaming load is intentional.
      return storage.load(record.key), record.mime_type
  168. return generator, upload_file.mime_type