You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

file_service.py 7.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import hashlib
  2. import os
  3. import uuid
  4. from typing import Any, Literal, Union
  5. from flask_login import current_user
  6. from werkzeug.exceptions import NotFound
  7. from configs import dify_config
  8. from constants import (
  9. AUDIO_EXTENSIONS,
  10. DOCUMENT_EXTENSIONS,
  11. IMAGE_EXTENSIONS,
  12. VIDEO_EXTENSIONS,
  13. )
  14. from core.file import helpers as file_helpers
  15. from core.rag.extractor.extract_processor import ExtractProcessor
  16. from extensions.ext_database import db
  17. from extensions.ext_storage import storage
  18. from libs.datetime_utils import naive_utc_now
  19. from libs.helper import extract_tenant_id
  20. from models.account import Account
  21. from models.enums import CreatorUserRole
  22. from models.model import EndUser, UploadFile
  23. from .errors.file import FileTooLargeError, UnsupportedFileTypeError
  24. PREVIEW_WORDS_LIMIT = 3000
  25. class FileService:
  @staticmethod
  def upload_file(
      *,
      filename: str,
      content: bytes,
      mimetype: str,
      user: Union[Account, EndUser, Any],
      source: Literal["datasets"] | None = None,
      source_url: str = "",
  ) -> UploadFile:
      """Validate an upload, write it to storage and record it in the database.

      Args:
          filename: Client-supplied file name. Must not contain any of
              ``/ \\ : * ? " < > |``. Names longer than 200 characters are
              shortened (see below).
          content: Raw file bytes; its length is the recorded file size.
          mimetype: MIME type stored on the record as-is.
          user: Uploader. An ``Account`` is recorded with the ACCOUNT creator
              role, anything else with the END_USER role.
          source: When ``"datasets"``, the extension must be one of
              ``DOCUMENT_EXTENSIONS``.
          source_url: Stored on the record; when empty, a signed file URL is
              generated after the first commit and saved in its place.

      Returns:
          The committed ``UploadFile`` row.

      Raises:
          ValueError: The file name contains a forbidden character.
          UnsupportedFileTypeError: ``source`` is ``"datasets"`` and the
              extension is not a document extension.
          FileTooLargeError: ``content`` exceeds the limit for its extension.
      """
      # Split once: the stem is reused if the name has to be shortened.
      stem, dotted_extension = os.path.splitext(filename)
      extension = dotted_extension.lstrip(".").lower()

      # Reject characters that are path separators or otherwise unsafe in file names.
      forbidden_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]
      if any(char in filename for char in forbidden_chars):
          raise ValueError("Filename contains invalid characters")

      if len(filename) > 200:
          # Shorten the stem (everything before the final extension) rather than
          # the text before the first dot, so "a.b.c.pdf"-style names keep their
          # leading parts, and do not leave a trailing "." when there is no
          # extension to re-attach.
          suffix = "." + extension if extension else ""
          filename = stem[:200] + suffix

      if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
          raise UnsupportedFileTypeError()

      file_size = len(content)
      if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
          raise FileTooLargeError

      # Storage key layout: upload_files/<tenant id>/<uuid>.<extension>
      file_uuid = str(uuid.uuid4())
      current_tenant_id = extract_tenant_id(user)
      file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension

      # The object is written to storage before the database row is created.
      storage.save(file_key, content)

      upload_file = UploadFile(
          tenant_id=current_tenant_id or "",
          storage_type=dify_config.STORAGE_TYPE,
          key=file_key,
          name=filename,
          size=file_size,
          extension=extension,
          mime_type=mimetype,
          created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
          created_by=user.id,
          created_at=naive_utc_now(),
          used=False,
          hash=hashlib.sha3_256(content).hexdigest(),
          source_url=source_url,
      )
      db.session.add(upload_file)
      db.session.commit()

      # The signed URL is built from the row id, so it can only be filled in
      # after the first commit.
      if not upload_file.source_url:
          upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
          db.session.add(upload_file)
          db.session.commit()

      return upload_file
  @staticmethod
  def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
      """Return True when ``file_size`` (bytes) does not exceed the cap for ``extension``.

      The cap comes from the image, video or audio setting in ``dify_config``
      when the extension belongs to that category, and from the general upload
      setting otherwise. Each setting is multiplied by 1024 * 1024 before the
      comparison.
      """
      if extension in IMAGE_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
      elif extension in VIDEO_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT
      elif extension in AUDIO_EXTENSIONS:
          configured_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT
      else:
          configured_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT

      max_bytes = configured_limit * 1024 * 1024
      return file_size <= max_bytes
  @staticmethod
  def upload_text(text: str, text_name: str) -> UploadFile:
      """Store ``text`` as a UTF-8 ``.txt`` file owned by the current user.

      Args:
          text: Text to persist; it is encoded as UTF-8 before being stored.
          text_name: Display name for the record, cut to 200 characters.

      Returns:
          The committed ``UploadFile`` row, already marked as used by the
          current user.
      """
      text_name = text_name[:200]

      # Encode once so the stored bytes and the recorded size always agree;
      # len(text) would undercount any non-ASCII content.
      encoded_text = text.encode("utf-8")

      # Storage key layout: upload_files/<tenant id>/<uuid>.txt
      file_uuid = str(uuid.uuid4())
      file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"

      storage.save(file_key, encoded_text)

      # One timestamp for both fields: the file is created and consumed together.
      now = naive_utc_now()
      upload_file = UploadFile(
          tenant_id=current_user.current_tenant_id,
          storage_type=dify_config.STORAGE_TYPE,
          key=file_key,
          name=text_name,
          size=len(encoded_text),
          extension="txt",
          mime_type="text/plain",
          created_by=current_user.id,
          created_by_role=CreatorUserRole.ACCOUNT,
          created_at=now,
          used=True,
          used_by=current_user.id,
          used_at=now,
      )
      db.session.add(upload_file)
      db.session.commit()

      return upload_file
  @staticmethod
  def get_file_preview(file_id: str):
      """Return the leading part of the text extracted from a document upload.

      Looks up the ``UploadFile`` by id, requires a document extension, runs
      ``ExtractProcessor`` on it and returns at most ``PREVIEW_WORDS_LIMIT``
      leading items of the result (a slice, not a word count), or ``""`` when
      nothing was extracted.

      Raises:
          NotFound: No upload exists with ``file_id``.
          UnsupportedFileTypeError: The upload is not a document type.
      """
      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found")

      if record.extension.lower() not in DOCUMENT_EXTENSIONS:
          raise UnsupportedFileTypeError()

      extracted = ExtractProcessor.load_from_upload_file(record, return_text=True)
      if not extracted:
          return ""
      return extracted[:PREVIEW_WORDS_LIMIT]
  @staticmethod
  def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
      """Load an image upload for preview after verifying its signature.

      Returns:
          A ``(data, mime_type)`` tuple, where ``data`` is whatever
          ``storage.load(..., stream=True)`` returns for the upload's key.

      Raises:
          NotFound: The signature check fails or no upload has ``file_id``
              (the same message is used for both).
          UnsupportedFileTypeError: The upload's extension is not an image
              extension.
      """
      signature_ok = file_helpers.verify_image_signature(
          upload_file_id=file_id,
          timestamp=timestamp,
          nonce=nonce,
          sign=sign,
      )
      if not signature_ok:
          raise NotFound("File not found or signature is invalid")

      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found or signature is invalid")

      # Only image uploads may be served through this endpoint.
      if record.extension.lower() not in IMAGE_EXTENSIONS:
          raise UnsupportedFileTypeError()

      return storage.load(record.key, stream=True), record.mime_type
  @staticmethod
  def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
      """Load any upload by id after verifying its file signature.

      Returns:
          A ``(data, upload_file)`` tuple, where ``data`` is whatever
          ``storage.load(..., stream=True)`` returns for the upload's key and
          ``upload_file`` is the matching ``UploadFile`` row.

      Raises:
          NotFound: The signature check fails or no upload has ``file_id``
              (the same message is used for both).
      """
      signature_ok = file_helpers.verify_file_signature(
          upload_file_id=file_id,
          timestamp=timestamp,
          nonce=nonce,
          sign=sign,
      )
      if not signature_ok:
          raise NotFound("File not found or signature is invalid")

      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found or signature is invalid")

      return storage.load(record.key, stream=True), record
  @staticmethod
  def get_public_image_preview(file_id: str):
      """Load an image upload by id without any signature check.

      Returns:
          A ``(data, mime_type)`` tuple, where ``data`` is whatever
          ``storage.load`` returns for the upload's key.

      Raises:
          NotFound: No upload has ``file_id``.
          UnsupportedFileTypeError: The upload's extension is not an image
              extension.
      """
      record = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not record:
          raise NotFound("File not found or signature is invalid")

      # Only image uploads may be served through this endpoint.
      if record.extension.lower() not in IMAGE_EXTENSIONS:
          raise UnsupportedFileTypeError()

      # NOTE(review): unlike get_image_preview, stream=True is not passed here;
      # confirm whether the non-streaming load is intentional.
      data = storage.load(record.key)
      return data, record.mime_type