Du kannst nicht mehr als 25 Themen auswählen. Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

file_service.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import hashlib
  2. import os
  3. import uuid
  4. from typing import Any, Literal, Union
  5. from sqlalchemy import Engine
  6. from sqlalchemy.orm import sessionmaker
  7. from werkzeug.exceptions import NotFound
  8. from configs import dify_config
  9. from constants import (
  10. AUDIO_EXTENSIONS,
  11. DOCUMENT_EXTENSIONS,
  12. IMAGE_EXTENSIONS,
  13. VIDEO_EXTENSIONS,
  14. )
  15. from core.file import helpers as file_helpers
  16. from core.rag.extractor.extract_processor import ExtractProcessor
  17. from extensions.ext_storage import storage
  18. from libs.datetime_utils import naive_utc_now
  19. from libs.helper import extract_tenant_id
  20. from libs.login import current_user
  21. from models.account import Account
  22. from models.enums import CreatorUserRole
  23. from models.model import EndUser, UploadFile
  24. from .errors.file import FileTooLargeError, UnsupportedFileTypeError
  25. PREVIEW_WORDS_LIMIT = 3000
  26. class FileService:
  27. _session_maker: sessionmaker
  28. def __init__(self, session_factory: sessionmaker | Engine | None = None):
  29. if isinstance(session_factory, Engine):
  30. self._session_maker = sessionmaker(bind=session_factory)
  31. elif isinstance(session_factory, sessionmaker):
  32. self._session_maker = session_factory
  33. else:
  34. raise AssertionError("must be a sessionmaker or an Engine.")
  35. def upload_file(
  36. self,
  37. *,
  38. filename: str,
  39. content: bytes,
  40. mimetype: str,
  41. user: Union[Account, EndUser, Any],
  42. source: Literal["datasets"] | None = None,
  43. source_url: str = "",
  44. ) -> UploadFile:
  45. # get file extension
  46. extension = os.path.splitext(filename)[1].lstrip(".").lower()
  47. # check if filename contains invalid characters
  48. if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
  49. raise ValueError("Filename contains invalid characters")
  50. if len(filename) > 200:
  51. filename = filename.split(".")[0][:200] + "." + extension
  52. if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
  53. raise UnsupportedFileTypeError()
  54. # get file size
  55. file_size = len(content)
  56. # check if the file size is exceeded
  57. if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
  58. raise FileTooLargeError
  59. # generate file key
  60. file_uuid = str(uuid.uuid4())
  61. current_tenant_id = extract_tenant_id(user)
  62. file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
  63. # save file to storage
  64. storage.save(file_key, content)
  65. # save file to db
  66. upload_file = UploadFile(
  67. tenant_id=current_tenant_id or "",
  68. storage_type=dify_config.STORAGE_TYPE,
  69. key=file_key,
  70. name=filename,
  71. size=file_size,
  72. extension=extension,
  73. mime_type=mimetype,
  74. created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
  75. created_by=user.id,
  76. created_at=naive_utc_now(),
  77. used=False,
  78. hash=hashlib.sha3_256(content).hexdigest(),
  79. source_url=source_url,
  80. )
  81. # The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
  82. # We can directly generate the `source_url` here before committing.
  83. if not upload_file.source_url:
  84. upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
  85. with self._session_maker(expire_on_commit=False) as session:
  86. session.add(upload_file)
  87. session.commit()
  88. return upload_file
  89. @staticmethod
  90. def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
  91. if extension in IMAGE_EXTENSIONS:
  92. file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
  93. elif extension in VIDEO_EXTENSIONS:
  94. file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
  95. elif extension in AUDIO_EXTENSIONS:
  96. file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
  97. else:
  98. file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
  99. return file_size <= file_size_limit
  100. @staticmethod
  101. def upload_text(text: str, text_name: str) -> UploadFile:
  102. assert isinstance(current_user, Account)
  103. assert current_user.current_tenant_id is not None
  104. if len(text_name) > 200:
  105. text_name = text_name[:200]
  106. # user uuid as file name
  107. file_uuid = str(uuid.uuid4())
  108. file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
  109. # save file to storage
  110. storage.save(file_key, text.encode("utf-8"))
  111. # save file to db
  112. upload_file = UploadFile(
  113. tenant_id=current_user.current_tenant_id,
  114. storage_type=dify_config.STORAGE_TYPE,
  115. key=file_key,
  116. name=text_name,
  117. size=len(text),
  118. extension="txt",
  119. mime_type="text/plain",
  120. created_by=current_user.id,
  121. created_by_role=CreatorUserRole.ACCOUNT,
  122. created_at=naive_utc_now(),
  123. used=True,
  124. used_by=current_user.id,
  125. used_at=naive_utc_now(),
  126. )
  127. with self._session_maker(expire_on_commit=False) as session:
  128. session.add(upload_file)
  129. session.commit()
  130. return upload_file
  131. def get_file_preview(self, file_id: str):
  132. with self._session_maker(expire_on_commit=False) as session:
  133. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  134. if not upload_file:
  135. raise NotFound("File not found")
  136. # extract text from file
  137. extension = upload_file.extension
  138. if extension.lower() not in DOCUMENT_EXTENSIONS:
  139. raise UnsupportedFileTypeError()
  140. text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
  141. text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
  142. return text
  143. def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
  144. result = file_helpers.verify_image_signature(
  145. upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
  146. )
  147. if not result:
  148. raise NotFound("File not found or signature is invalid")
  149. with self._session_maker(expire_on_commit=False) as session:
  150. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  151. if not upload_file:
  152. raise NotFound("File not found or signature is invalid")
  153. # extract text from file
  154. extension = upload_file.extension
  155. if extension.lower() not in IMAGE_EXTENSIONS:
  156. raise UnsupportedFileTypeError()
  157. generator = storage.load(upload_file.key, stream=True)
  158. return generator, upload_file.mime_type
  159. def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
  160. result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
  161. if not result:
  162. raise NotFound("File not found or signature is invalid")
  163. with self._session_maker(expire_on_commit=False) as session:
  164. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  165. if not upload_file:
  166. raise NotFound("File not found or signature is invalid")
  167. generator = storage.load(upload_file.key, stream=True)
  168. return generator, upload_file
  169. def get_public_image_preview(self, file_id: str):
  170. with self._session_maker(expire_on_commit=False) as session:
  171. upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
  172. if not upload_file:
  173. raise NotFound("File not found or signature is invalid")
  174. # extract text from file
  175. extension = upload_file.extension
  176. if extension.lower() not in IMAGE_EXTENSIONS:
  177. raise UnsupportedFileTypeError()
  178. generator = storage.load(upload_file.key)
  179. return generator, upload_file.mime_type