| @@ -1,17 +1,21 @@ | |||
| from typing import Any | |||
| import websocket | |||
| from yarl import URL | |||
| from core.tools.errors import ToolProviderCredentialValidationError | |||
| from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool | |||
| from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | |||
| class ComfyUIProvider(BuiltinToolProviderController): | |||
| def _validate_credentials(self, credentials: dict[str, Any]) -> None: | |||
| ws = websocket.WebSocket() | |||
| base_url = URL(credentials.get("base_url")) | |||
| ws_address = f"ws://{base_url.authority}/ws?clientId=test123" | |||
| try: | |||
| ComfyuiStableDiffusionTool().fork_tool_runtime( | |||
| runtime={ | |||
| "credentials": credentials, | |||
| } | |||
| ).validate_models() | |||
| ws.connect(ws_address) | |||
| except Exception as e: | |||
| raise ToolProviderCredentialValidationError(str(e)) | |||
| raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") | |||
| finally: | |||
| ws.close() | |||
| @@ -4,11 +4,9 @@ identity: | |||
| label: | |||
| en_US: ComfyUI | |||
| zh_Hans: ComfyUI | |||
| pt_BR: ComfyUI | |||
| description: | |||
| en_US: ComfyUI is a tool for generating images which can be deployed locally. | |||
| zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。 | |||
| pt_BR: ComfyUI is a tool for generating images which can be deployed locally. | |||
| icon: icon.png | |||
| tags: | |||
| - image | |||
| @@ -17,26 +15,9 @@ credentials_for_provider: | |||
| type: text-input | |||
| required: true | |||
| label: | |||
| en_US: Base URL | |||
| zh_Hans: ComfyUI服务器的Base URL | |||
| pt_BR: Base URL | |||
| en_US: The URL of ComfyUI Server | |||
| zh_Hans: ComfyUI服务器的URL | |||
| placeholder: | |||
| en_US: Please input your ComfyUI server's Base URL | |||
| zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL | |||
| pt_BR: Please input your ComfyUI server's Base URL | |||
| model: | |||
| type: text-input | |||
| required: true | |||
| label: | |||
| en_US: Model with suffix | |||
| zh_Hans: 模型, 需要带后缀 | |||
| pt_BR: Model with suffix | |||
| placeholder: | |||
| en_US: Please input your model | |||
| zh_Hans: 请输入你的模型名称 | |||
| pt_BR: Please input your model | |||
| help: | |||
| en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors | |||
| zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors | |||
| pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors | |||
| url: https://github.com/comfyanonymous/ComfyUI#installing | |||
| url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui | |||
| @@ -0,0 +1,105 @@ | |||
| import json | |||
| import random | |||
| import uuid | |||
| import httpx | |||
| from websocket import WebSocket | |||
| from yarl import URL | |||
| class ComfyUiClient: | |||
| def __init__(self, base_url: str): | |||
| self.base_url = URL(base_url) | |||
| def get_history(self, prompt_id: str): | |||
| res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) | |||
| history = res.json()[prompt_id] | |||
| return history | |||
| def get_image(self, filename: str, subfolder: str, folder_type: str): | |||
| response = httpx.get( | |||
| str(self.base_url / "view"), | |||
| params={"filename": filename, "subfolder": subfolder, "type": folder_type}, | |||
| ) | |||
| return response.content | |||
| def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False): | |||
| # plan to support img2img in dify 0.10.0 | |||
| with open(input_path, "rb") as file: | |||
| files = {"image": (name, file, "image/png")} | |||
| data = {"type": image_type, "overwrite": str(overwrite).lower()} | |||
| res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files) | |||
| return res | |||
| def queue_prompt(self, client_id: str, prompt: dict): | |||
| res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) | |||
| prompt_id = res.json()["prompt_id"] | |||
| return prompt_id | |||
| def open_websocket_connection(self): | |||
| client_id = str(uuid.uuid4()) | |||
| ws = WebSocket() | |||
| ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" | |||
| ws.connect(ws_address) | |||
| return ws, client_id | |||
| def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""): | |||
| """ | |||
| find the first KSampler, then can find the prompt node through it. | |||
| """ | |||
| prompt = origin_prompt.copy() | |||
| id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} | |||
| k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] | |||
| prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) | |||
| positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] | |||
| prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt | |||
| if negative_prompt != "": | |||
| negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] | |||
| prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt | |||
| return prompt | |||
| def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): | |||
| node_ids = list(prompt.keys()) | |||
| finished_nodes = [] | |||
| while True: | |||
| out = ws.recv() | |||
| if isinstance(out, str): | |||
| message = json.loads(out) | |||
| if message["type"] == "progress": | |||
| data = message["data"] | |||
| current_step = data["value"] | |||
| print("In K-Sampler -> Step: ", current_step, " of: ", data["max"]) | |||
| if message["type"] == "execution_cached": | |||
| data = message["data"] | |||
| for itm in data["nodes"]: | |||
| if itm not in finished_nodes: | |||
| finished_nodes.append(itm) | |||
| print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") | |||
| if message["type"] == "executing": | |||
| data = message["data"] | |||
| if data["node"] not in finished_nodes: | |||
| finished_nodes.append(data["node"]) | |||
| print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") | |||
| if data["node"] is None and data["prompt_id"] == prompt_id: | |||
| break # Execution is done | |||
| else: | |||
| continue | |||
| def generate_image_by_prompt(self, prompt: dict): | |||
| try: | |||
| ws, client_id = self.open_websocket_connection() | |||
| prompt_id = self.queue_prompt(client_id, prompt) | |||
| self.track_progress(prompt, ws, prompt_id) | |||
| history = self.get_history(prompt_id) | |||
| images = [] | |||
| for output in history["outputs"].values(): | |||
| for img in output.get("images", []): | |||
| image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) | |||
| images.append(image_data) | |||
| return images | |||
| finally: | |||
| ws.close() | |||
| @@ -1,10 +1,10 @@ | |||
| identity: | |||
| name: txt2img workflow | |||
| name: txt2img | |||
| author: Qun | |||
| label: | |||
| en_US: Txt2Img Workflow | |||
| zh_Hans: Txt2Img Workflow | |||
| pt_BR: Txt2Img Workflow | |||
| en_US: Txt2Img | |||
| zh_Hans: Txt2Img | |||
| pt_BR: Txt2Img | |||
| description: | |||
| human: | |||
| en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader. | |||
| @@ -0,0 +1,32 @@ | |||
| import json | |||
| from typing import Any | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| from .comfyui_client import ComfyUiClient | |||
| class ComfyUIWorkflowTool(BuiltinTool): | |||
| def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: | |||
| comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) | |||
| positive_prompt = tool_parameters.get("positive_prompt") | |||
| negative_prompt = tool_parameters.get("negative_prompt") | |||
| workflow = tool_parameters.get("workflow_json") | |||
| try: | |||
| origin_prompt = json.loads(workflow) | |||
| except: | |||
| return self.create_text_message("the Workflow JSON is not correct") | |||
| prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt) | |||
| images = comfyui.generate_image_by_prompt(prompt) | |||
| result = [] | |||
| for img in images: | |||
| result.append( | |||
| self.create_blob_message( | |||
| blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value | |||
| ) | |||
| ) | |||
| return result | |||
| @@ -0,0 +1,35 @@ | |||
| identity: | |||
| name: workflow | |||
| author: hjlarry | |||
| label: | |||
| en_US: workflow | |||
| zh_Hans: 工作流 | |||
| description: | |||
| human: | |||
| en_US: Run ComfyUI workflow. | |||
| zh_Hans: 运行ComfyUI工作流。 | |||
| llm: Run ComfyUI workflow. | |||
| parameters: | |||
| - name: positive_prompt | |||
| type: string | |||
| label: | |||
| en_US: Prompt | |||
| zh_Hans: 提示词 | |||
| llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. | |||
| form: llm | |||
| - name: negative_prompt | |||
| type: string | |||
| label: | |||
| en_US: Negative Prompt | |||
| zh_Hans: 负面提示词 | |||
| llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. | |||
| form: llm | |||
| - name: workflow_json | |||
| type: string | |||
| required: true | |||
| label: | |||
| en_US: Workflow JSON | |||
| human_description: | |||
| en_US: exported from ComfyUI workflow | |||
| zh_Hans: 从ComfyUI的工作流中导出 | |||
| form: form | |||