You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

file_manager.py 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import base64
  2. from collections.abc import Mapping
  3. from configs import dify_config
  4. from core.helper import ssrf_proxy
  5. from core.model_runtime.entities import (
  6. AudioPromptMessageContent,
  7. DocumentPromptMessageContent,
  8. ImagePromptMessageContent,
  9. TextPromptMessageContent,
  10. VideoPromptMessageContent,
  11. )
  12. from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
  13. from core.tools.signature import sign_tool_file
  14. from extensions.ext_storage import storage
  15. from . import helpers
  16. from .enums import FileAttribute
  17. from .models import File, FileTransferMethod, FileType
  18. def get_attr(*, file: File, attr: FileAttribute):
  19. match attr:
  20. case FileAttribute.TYPE:
  21. return file.type.value
  22. case FileAttribute.SIZE:
  23. return file.size
  24. case FileAttribute.NAME:
  25. return file.filename
  26. case FileAttribute.MIME_TYPE:
  27. return file.mime_type
  28. case FileAttribute.TRANSFER_METHOD:
  29. return file.transfer_method.value
  30. case FileAttribute.URL:
  31. return file.remote_url
  32. case FileAttribute.EXTENSION:
  33. return file.extension
  34. case FileAttribute.RELATED_ID:
  35. return file.related_id
  36. def to_prompt_message_content(
  37. f: File,
  38. /,
  39. *,
  40. image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
  41. ) -> PromptMessageContentUnionTypes:
  42. """
  43. Convert a file to prompt message content.
  44. This function converts files to their appropriate prompt message content types.
  45. For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
  46. corresponding message content with proper encoding/URL.
  47. For unsupported file types, instead of raising an error, it returns a
  48. TextPromptMessageContent with a descriptive message about the file.
  49. Args:
  50. f: The file to convert
  51. image_detail_config: Optional detail configuration for image files
  52. Returns:
  53. PromptMessageContentUnionTypes: The appropriate message content type
  54. Raises:
  55. ValueError: If file extension or mime_type is missing
  56. """
  57. if f.extension is None:
  58. raise ValueError("Missing file extension")
  59. if f.mime_type is None:
  60. raise ValueError("Missing file mime_type")
  61. prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
  62. FileType.IMAGE: ImagePromptMessageContent,
  63. FileType.AUDIO: AudioPromptMessageContent,
  64. FileType.VIDEO: VideoPromptMessageContent,
  65. FileType.DOCUMENT: DocumentPromptMessageContent,
  66. }
  67. # Check if file type is supported
  68. if f.type not in prompt_class_map:
  69. # For unsupported file types, return a text description
  70. return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
  71. # Process supported file types
  72. params = {
  73. "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
  74. "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
  75. "format": f.extension.removeprefix("."),
  76. "mime_type": f.mime_type,
  77. }
  78. if f.type == FileType.IMAGE:
  79. params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  80. return prompt_class_map[f.type].model_validate(params)
  81. def download(f: File, /):
  82. if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
  83. return _download_file_content(f._storage_key)
  84. elif f.transfer_method == FileTransferMethod.REMOTE_URL:
  85. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  86. response.raise_for_status()
  87. return response.content
  88. raise ValueError(f"unsupported transfer method: {f.transfer_method}")
  89. def _download_file_content(path: str, /):
  90. """
  91. Download and return the contents of a file as bytes.
  92. This function loads the file from storage and ensures it's in bytes format.
  93. Args:
  94. path (str): The path to the file in storage.
  95. Returns:
  96. bytes: The contents of the file as a bytes object.
  97. Raises:
  98. ValueError: If the loaded file is not a bytes object.
  99. """
  100. data = storage.load(path, stream=False)
  101. if not isinstance(data, bytes):
  102. raise ValueError(f"file {path} is not a bytes object")
  103. return data
  104. def _get_encoded_string(f: File, /):
  105. match f.transfer_method:
  106. case FileTransferMethod.REMOTE_URL:
  107. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  108. response.raise_for_status()
  109. data = response.content
  110. case FileTransferMethod.LOCAL_FILE:
  111. data = _download_file_content(f._storage_key)
  112. case FileTransferMethod.TOOL_FILE:
  113. data = _download_file_content(f._storage_key)
  114. encoded_string = base64.b64encode(data).decode("utf-8")
  115. return encoded_string
  116. def _to_url(f: File, /):
  117. if f.transfer_method == FileTransferMethod.REMOTE_URL:
  118. if f.remote_url is None:
  119. raise ValueError("Missing file remote_url")
  120. return f.remote_url
  121. elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
  122. if f.related_id is None:
  123. raise ValueError("Missing file related_id")
  124. return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
  125. elif f.transfer_method == FileTransferMethod.TOOL_FILE:
  126. # add sign url
  127. if f.related_id is None or f.extension is None:
  128. raise ValueError("Missing file related_id or extension")
  129. return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
  130. else:
  131. raise ValueError(f"Unsupported transfer method: {f.transfer_method}")