You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

file_manager.py 4.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import base64
  2. from collections.abc import Mapping
  3. from configs import dify_config
  4. from core.helper import ssrf_proxy
  5. from core.model_runtime.entities import (
  6. AudioPromptMessageContent,
  7. DocumentPromptMessageContent,
  8. ImagePromptMessageContent,
  9. VideoPromptMessageContent,
  10. )
  11. from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
  12. from core.tools.signature import sign_tool_file
  13. from extensions.ext_storage import storage
  14. from . import helpers
  15. from .enums import FileAttribute
  16. from .models import File, FileTransferMethod, FileType
  17. def get_attr(*, file: File, attr: FileAttribute):
  18. match attr:
  19. case FileAttribute.TYPE:
  20. return file.type.value
  21. case FileAttribute.SIZE:
  22. return file.size
  23. case FileAttribute.NAME:
  24. return file.filename
  25. case FileAttribute.MIME_TYPE:
  26. return file.mime_type
  27. case FileAttribute.TRANSFER_METHOD:
  28. return file.transfer_method.value
  29. case FileAttribute.URL:
  30. return file.remote_url
  31. case FileAttribute.EXTENSION:
  32. return file.extension
  33. case FileAttribute.RELATED_ID:
  34. return file.related_id
  35. def to_prompt_message_content(
  36. f: File,
  37. /,
  38. *,
  39. image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
  40. ) -> PromptMessageContentUnionTypes:
  41. if f.extension is None:
  42. raise ValueError("Missing file extension")
  43. if f.mime_type is None:
  44. raise ValueError("Missing file mime_type")
  45. params = {
  46. "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
  47. "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
  48. "format": f.extension.removeprefix("."),
  49. "mime_type": f.mime_type,
  50. }
  51. if f.type == FileType.IMAGE:
  52. params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  53. prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
  54. FileType.IMAGE: ImagePromptMessageContent,
  55. FileType.AUDIO: AudioPromptMessageContent,
  56. FileType.VIDEO: VideoPromptMessageContent,
  57. FileType.DOCUMENT: DocumentPromptMessageContent,
  58. }
  59. try:
  60. return prompt_class_map[f.type].model_validate(params)
  61. except KeyError:
  62. raise ValueError(f"file type {f.type} is not supported")
  63. def download(f: File, /):
  64. if f.transfer_method in (
  65. FileTransferMethod.TOOL_FILE,
  66. FileTransferMethod.LOCAL_FILE,
  67. FileTransferMethod.DATASOURCE_FILE,
  68. ):
  69. return _download_file_content(f._storage_key)
  70. elif f.transfer_method == FileTransferMethod.REMOTE_URL:
  71. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  72. response.raise_for_status()
  73. return response.content
  74. raise ValueError(f"unsupported transfer method: {f.transfer_method}")
  75. def _download_file_content(path: str, /):
  76. """
  77. Download and return the contents of a file as bytes.
  78. This function loads the file from storage and ensures it's in bytes format.
  79. Args:
  80. path (str): The path to the file in storage.
  81. Returns:
  82. bytes: The contents of the file as a bytes object.
  83. Raises:
  84. ValueError: If the loaded file is not a bytes object.
  85. """
  86. data = storage.load(path, stream=False)
  87. if not isinstance(data, bytes):
  88. raise ValueError(f"file {path} is not a bytes object")
  89. return data
  90. def _get_encoded_string(f: File, /):
  91. match f.transfer_method:
  92. case FileTransferMethod.REMOTE_URL:
  93. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  94. response.raise_for_status()
  95. data = response.content
  96. case FileTransferMethod.LOCAL_FILE:
  97. data = _download_file_content(f._storage_key)
  98. case FileTransferMethod.TOOL_FILE:
  99. data = _download_file_content(f._storage_key)
  100. encoded_string = base64.b64encode(data).decode("utf-8")
  101. return encoded_string
  102. def _to_url(f: File, /):
  103. if f.transfer_method == FileTransferMethod.REMOTE_URL:
  104. if f.remote_url is None:
  105. raise ValueError("Missing file remote_url")
  106. return f.remote_url
  107. elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
  108. if f.related_id is None:
  109. raise ValueError("Missing file related_id")
  110. return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
  111. elif f.transfer_method == FileTransferMethod.TOOL_FILE:
  112. # add sign url
  113. if f.related_id is None or f.extension is None:
  114. raise ValueError("Missing file related_id or extension")
  115. return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
  116. else:
  117. raise ValueError(f"Unsupported transfer method: {f.transfer_method}")