@@ -1,5 +1,3 @@ | |||
import base64 | |||
import io | |||
import json | |||
import random | |||
import uuid | |||
@@ -8,7 +6,7 @@ import httpx | |||
from websocket import WebSocket | |||
from yarl import URL | |||
from core.file.file_manager import _get_encoded_string | |||
from core.file.file_manager import download | |||
from core.file.models import File | |||
@@ -29,8 +27,7 @@ class ComfyUiClient: | |||
return response.content | |||
def upload_image(self, image_file: File) -> dict: | |||
image_content = base64.b64decode(_get_encoded_string(image_file)) | |||
file = io.BytesIO(image_content) | |||
file = download(image_file) | |||
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} | |||
res = httpx.post(str(self.base_url / "upload/image"), files=files) | |||
return res.json() | |||
@@ -47,12 +44,7 @@ class ComfyUiClient: | |||
ws.connect(ws_address) | |||
return ws, client_id | |||
def set_prompt( | |||
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" | |||
) -> dict: | |||
""" | |||
find the first KSampler, then can find the prompt node through it. | |||
""" | |||
def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict: | |||
prompt = origin_prompt.copy() | |||
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} | |||
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] | |||
@@ -64,9 +56,20 @@ class ComfyUiClient: | |||
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] | |||
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt | |||
if image_name != "": | |||
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] | |||
prompt.get(image_loader)["inputs"]["image"] = image_name | |||
return prompt | |||
def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict: | |||
prompt = origin_prompt.copy() | |||
for index, image_node_id in enumerate(image_ids): | |||
prompt[image_node_id]["inputs"]["image"] = image_names[index] | |||
return prompt | |||
def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict: | |||
prompt = origin_prompt.copy() | |||
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} | |||
load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"] | |||
for load_image, image_name in zip(load_image_nodes, image_names): | |||
prompt.get(load_image)["inputs"]["image"] = image_name | |||
return prompt | |||
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): |
@@ -1,7 +1,9 @@ | |||
import json | |||
from typing import Any | |||
from core.file import FileType | |||
from core.tools.entities.tool_entities import ToolInvokeMessage | |||
from core.tools.errors import ToolParameterValidationError | |||
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient | |||
from core.tools.tool.builtin_tool import BuiltinTool | |||
@@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool): | |||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: | |||
comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) | |||
positive_prompt = tool_parameters.get("positive_prompt") | |||
negative_prompt = tool_parameters.get("negative_prompt") | |||
positive_prompt = tool_parameters.get("positive_prompt", "") | |||
negative_prompt = tool_parameters.get("negative_prompt", "") | |||
images = tool_parameters.get("images") or [] | |||
workflow = tool_parameters.get("workflow_json") | |||
image_name = "" | |||
if image := tool_parameters.get("image"): | |||
image_names = [] | |||
for image in images: | |||
if image.type != FileType.IMAGE: | |||
continue | |||
image_name = comfyui.upload_image(image).get("name") | |||
image_names.append(image_name) | |||
set_prompt_with_ksampler = True | |||
if "{{positive_prompt}}" in workflow: | |||
set_prompt_with_ksampler = False | |||
workflow = workflow.replace("{{positive_prompt}}", positive_prompt) | |||
workflow = workflow.replace("{{negative_prompt}}", negative_prompt) | |||
try: | |||
origin_prompt = json.loads(workflow) | |||
prompt = json.loads(workflow) | |||
except: | |||
return self.create_text_message("the Workflow JSON is not correct") | |||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name) | |||
if set_prompt_with_ksampler: | |||
try: | |||
prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt) | |||
except: | |||
raise ToolParameterValidationError( | |||
"Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json" | |||
) | |||
if image_names: | |||
if image_ids := tool_parameters.get("image_ids"): | |||
image_ids = image_ids.split(",") | |||
try: | |||
prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids) | |||
except: | |||
raise ToolParameterValidationError("the Image Node ID List not match your upload image files.") | |||
else: | |||
prompt = comfyui.set_prompt_images_by_default(prompt, image_names) | |||
images = comfyui.generate_image_by_prompt(prompt) | |||
result = [] | |||
for img in images: |
@@ -24,12 +24,12 @@ parameters: | |||
zh_Hans: 负面提示词 | |||
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. | |||
form: llm | |||
- name: image | |||
type: file | |||
- name: images | |||
type: files | |||
label: | |||
en_US: Input Image | |||
en_US: Input Images | |||
zh_Hans: 输入的图片 | |||
llm_description: The input image, used to transfer to the comfyui workflow to generate another image. | |||
llm_description: The input images, used to transfer to the comfyui workflow to generate another image. | |||
form: llm | |||
- name: workflow_json | |||
type: string | |||
@@ -40,3 +40,15 @@ parameters: | |||
en_US: exported from ComfyUI workflow | |||
zh_Hans: 从ComfyUI的工作流中导出 | |||
form: form | |||
- name: image_ids | |||
type: string | |||
label: | |||
en_US: Image Node ID List | |||
zh_Hans: 图片节点ID列表 | |||
placeholder: | |||
en_US: Use commas to separate multiple node ID | |||
zh_Hans: 多个节点ID时使用半角逗号分隔 | |||
human_description: | |||
en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list. | |||
zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI | |||
form: form |