Ver código fonte

fix: create_blob_message of tool will always create image type file (#10701)

tags/0.11.2
非法操作 11 meses atrás
pai
commit
4b2abf8ac2
Nenhuma conta vinculada ao e-mail do autor do commit

+ 0
- 9
api/core/workflow/nodes/tool/tool_node.py Ver arquivo

from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from os import path
from typing import Any from typing import Any


from sqlalchemy import select from sqlalchemy import select
for response in tool_response: for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = str(response.message) if response.message else None url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = str(url).split("/")[-1].split(".")[0] tool_file_id = str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)


) )
result.append(file) result.append(file)
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0] tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session: with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
raise ValueError(f"tool file {tool_file_id} not exists") raise ValueError(f"tool file {tool_file_id} not exists")
mapping = { mapping = {
"tool_file_id": tool_file_id, "tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": FileTransferMethod.TOOL_FILE, "transfer_method": FileTransferMethod.TOOL_FILE,
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
extension = ".bin"
mapping = { mapping = {
"tool_file_id": tool_file_id, "tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": transfer_method, "transfer_method": transfer_method,
"url": url, "url": url,
} }

+ 16
- 1
api/factories/file_factory.py Ver arquivo

return mime_type, filename, file_size return mime_type, filename, file_size




def _get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type


def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")


extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))


return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
tenant_id=tenant_id, tenant_id=tenant_id,
filename=tool_file.name, filename=tool_file.name,
type=FileType.value_of(mapping.get("type")),
type=file_type,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=tool_file.original_url, remote_url=tool_file.original_url,
related_id=tool_file.id, related_id=tool_file.id,

Carregando…
Cancelar
Salvar