|
|
|
@@ -0,0 +1,475 @@ |
|
|
|
import json |
|
|
|
import os |
|
|
|
import random |
|
|
|
import uuid |
|
|
|
from copy import deepcopy |
|
|
|
from enum import Enum |
|
|
|
from typing import Any, Union |
|
|
|
|
|
|
|
import websocket |
|
|
|
from httpx import get, post |
|
|
|
from yarl import URL |
|
|
|
|
|
|
|
from core.tools.entities.common_entities import I18nObject |
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption |
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError |
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool |
|
|
|
|
|
|
|
SD_TXT2IMG_OPTIONS = {} |
|
|
|
LORA_NODE = { |
|
|
|
"inputs": {"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]}, |
|
|
|
"class_type": "LoraLoader", |
|
|
|
"_meta": {"title": "Load LoRA"}, |
|
|
|
} |
|
|
|
FluxGuidanceNode = { |
|
|
|
"inputs": {"guidance": 3.5, "conditioning": ["6", 0]}, |
|
|
|
"class_type": "FluxGuidance", |
|
|
|
"_meta": {"title": "FluxGuidance"}, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
class ModelType(Enum): |
|
|
|
SD15 = 1 |
|
|
|
SDXL = 2 |
|
|
|
SD3 = 3 |
|
|
|
FLUX = 4 |
|
|
|
|
|
|
|
|
|
|
|
class ComfyuiStableDiffusionTool(BuiltinTool): |
|
|
|
def _invoke( |
|
|
|
self, user_id: str, tool_parameters: dict[str, Any] |
|
|
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: |
|
|
|
""" |
|
|
|
invoke tools |
|
|
|
""" |
|
|
|
# base url |
|
|
|
base_url = self.runtime.credentials.get("base_url", "") |
|
|
|
if not base_url: |
|
|
|
return self.create_text_message("Please input base_url") |
|
|
|
|
|
|
|
if tool_parameters.get("model"): |
|
|
|
self.runtime.credentials["model"] = tool_parameters["model"] |
|
|
|
|
|
|
|
model = self.runtime.credentials.get("model", None) |
|
|
|
if not model: |
|
|
|
return self.create_text_message("Please input model") |
|
|
|
|
|
|
|
# prompt |
|
|
|
prompt = tool_parameters.get("prompt", "") |
|
|
|
if not prompt: |
|
|
|
return self.create_text_message("Please input prompt") |
|
|
|
|
|
|
|
# get negative prompt |
|
|
|
negative_prompt = tool_parameters.get("negative_prompt", "") |
|
|
|
|
|
|
|
# get size |
|
|
|
width = tool_parameters.get("width", 1024) |
|
|
|
height = tool_parameters.get("height", 1024) |
|
|
|
|
|
|
|
# get steps |
|
|
|
steps = tool_parameters.get("steps", 1) |
|
|
|
|
|
|
|
# get sampler_name |
|
|
|
sampler_name = tool_parameters.get("sampler_name", "euler") |
|
|
|
|
|
|
|
# scheduler |
|
|
|
scheduler = tool_parameters.get("scheduler", "normal") |
|
|
|
|
|
|
|
# get cfg |
|
|
|
cfg = tool_parameters.get("cfg", 7.0) |
|
|
|
|
|
|
|
# get model type |
|
|
|
model_type = tool_parameters.get("model_type", ModelType.SD15.name) |
|
|
|
|
|
|
|
# get lora |
|
|
|
# supports up to 3 loras |
|
|
|
lora_list = [] |
|
|
|
lora_strength_list = [] |
|
|
|
if tool_parameters.get("lora_1"): |
|
|
|
lora_list.append(tool_parameters["lora_1"]) |
|
|
|
lora_strength_list.append(tool_parameters.get("lora_strength_1", 1)) |
|
|
|
if tool_parameters.get("lora_2"): |
|
|
|
lora_list.append(tool_parameters["lora_2"]) |
|
|
|
lora_strength_list.append(tool_parameters.get("lora_strength_2", 1)) |
|
|
|
if tool_parameters.get("lora_3"): |
|
|
|
lora_list.append(tool_parameters["lora_3"]) |
|
|
|
lora_strength_list.append(tool_parameters.get("lora_strength_3", 1)) |
|
|
|
|
|
|
|
return self.text2img( |
|
|
|
base_url=base_url, |
|
|
|
model=model, |
|
|
|
model_type=model_type, |
|
|
|
prompt=prompt, |
|
|
|
negative_prompt=negative_prompt, |
|
|
|
width=width, |
|
|
|
height=height, |
|
|
|
steps=steps, |
|
|
|
sampler_name=sampler_name, |
|
|
|
scheduler=scheduler, |
|
|
|
cfg=cfg, |
|
|
|
lora_list=lora_list, |
|
|
|
lora_strength_list=lora_strength_list, |
|
|
|
) |
|
|
|
|
|
|
|
def get_checkpoints(self) -> list[str]: |
|
|
|
""" |
|
|
|
get checkpoints |
|
|
|
""" |
|
|
|
try: |
|
|
|
base_url = self.runtime.credentials.get("base_url", None) |
|
|
|
if not base_url: |
|
|
|
return [] |
|
|
|
api_url = str(URL(base_url) / "models" / "checkpoints") |
|
|
|
response = get(url=api_url, timeout=(2, 10)) |
|
|
|
if response.status_code != 200: |
|
|
|
return [] |
|
|
|
else: |
|
|
|
return response.json() |
|
|
|
except Exception as e: |
|
|
|
return [] |
|
|
|
|
|
|
|
def get_loras(self) -> list[str]: |
|
|
|
""" |
|
|
|
get loras |
|
|
|
""" |
|
|
|
try: |
|
|
|
base_url = self.runtime.credentials.get("base_url", None) |
|
|
|
if not base_url: |
|
|
|
return [] |
|
|
|
api_url = str(URL(base_url) / "models" / "loras") |
|
|
|
response = get(url=api_url, timeout=(2, 10)) |
|
|
|
if response.status_code != 200: |
|
|
|
return [] |
|
|
|
else: |
|
|
|
return response.json() |
|
|
|
except Exception as e: |
|
|
|
return [] |
|
|
|
|
|
|
|
def get_sample_methods(self) -> tuple[list[str], list[str]]: |
|
|
|
""" |
|
|
|
get sample method |
|
|
|
""" |
|
|
|
try: |
|
|
|
base_url = self.runtime.credentials.get("base_url", None) |
|
|
|
if not base_url: |
|
|
|
return [], [] |
|
|
|
api_url = str(URL(base_url) / "object_info" / "KSampler") |
|
|
|
response = get(url=api_url, timeout=(2, 10)) |
|
|
|
if response.status_code != 200: |
|
|
|
return [], [] |
|
|
|
else: |
|
|
|
data = response.json()["KSampler"]["input"]["required"] |
|
|
|
return data["sampler_name"][0], data["scheduler"][0] |
|
|
|
except Exception as e: |
|
|
|
return [], [] |
|
|
|
|
|
|
|
def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: |
|
|
|
""" |
|
|
|
validate models |
|
|
|
""" |
|
|
|
try: |
|
|
|
base_url = self.runtime.credentials.get("base_url", None) |
|
|
|
if not base_url: |
|
|
|
raise ToolProviderCredentialValidationError("Please input base_url") |
|
|
|
model = self.runtime.credentials.get("model", None) |
|
|
|
if not model: |
|
|
|
raise ToolProviderCredentialValidationError("Please input model") |
|
|
|
|
|
|
|
api_url = str(URL(base_url) / "models" / "checkpoints") |
|
|
|
response = get(url=api_url, timeout=(2, 10)) |
|
|
|
if response.status_code != 200: |
|
|
|
raise ToolProviderCredentialValidationError("Failed to get models") |
|
|
|
else: |
|
|
|
models = response.json() |
|
|
|
if len([d for d in models if d == model]) > 0: |
|
|
|
return self.create_text_message(json.dumps(models)) |
|
|
|
else: |
|
|
|
raise ToolProviderCredentialValidationError(f"model {model} does not exist") |
|
|
|
except Exception as e: |
|
|
|
raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") |
|
|
|
|
|
|
|
def get_history(self, base_url, prompt_id): |
|
|
|
""" |
|
|
|
get history |
|
|
|
""" |
|
|
|
url = str(URL(base_url) / "history") |
|
|
|
respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10)) |
|
|
|
return respond.json() |
|
|
|
|
|
|
|
def download_image(self, base_url, filename, subfolder, folder_type): |
|
|
|
""" |
|
|
|
download image |
|
|
|
""" |
|
|
|
url = str(URL(base_url) / "view") |
|
|
|
response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10)) |
|
|
|
return response.content |
|
|
|
|
|
|
|
def queue_prompt_image(self, base_url, client_id, prompt): |
|
|
|
""" |
|
|
|
send prompt task and rotate |
|
|
|
""" |
|
|
|
# initiate task execution |
|
|
|
url = str(URL(base_url) / "prompt") |
|
|
|
respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10)) |
|
|
|
prompt_id = respond.json()["prompt_id"] |
|
|
|
|
|
|
|
ws = websocket.WebSocket() |
|
|
|
if "https" in base_url: |
|
|
|
ws_url = base_url.replace("https", "ws") |
|
|
|
else: |
|
|
|
ws_url = base_url.replace("http", "ws") |
|
|
|
ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120) |
|
|
|
|
|
|
|
# websocket rotate execution status |
|
|
|
output_images = {} |
|
|
|
while True: |
|
|
|
out = ws.recv() |
|
|
|
if isinstance(out, str): |
|
|
|
message = json.loads(out) |
|
|
|
if message["type"] == "executing": |
|
|
|
data = message["data"] |
|
|
|
if data["node"] is None and data["prompt_id"] == prompt_id: |
|
|
|
break # Execution is done |
|
|
|
elif message["type"] == "status": |
|
|
|
data = message["data"] |
|
|
|
if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"): |
|
|
|
break # Execution is done |
|
|
|
else: |
|
|
|
continue # previews are binary data |
|
|
|
|
|
|
|
# download image when execution finished |
|
|
|
history = self.get_history(base_url, prompt_id)[prompt_id] |
|
|
|
for o in history["outputs"]: |
|
|
|
for node_id in history["outputs"]: |
|
|
|
node_output = history["outputs"][node_id] |
|
|
|
if "images" in node_output: |
|
|
|
images_output = [] |
|
|
|
for image in node_output["images"]: |
|
|
|
image_data = self.download_image(base_url, image["filename"], image["subfolder"], image["type"]) |
|
|
|
images_output.append(image_data) |
|
|
|
output_images[node_id] = images_output |
|
|
|
|
|
|
|
ws.close() |
|
|
|
|
|
|
|
return output_images |
|
|
|
|
|
|
|
def text2img( |
|
|
|
self, |
|
|
|
base_url: str, |
|
|
|
model: str, |
|
|
|
model_type: str, |
|
|
|
prompt: str, |
|
|
|
negative_prompt: str, |
|
|
|
width: int, |
|
|
|
height: int, |
|
|
|
steps: int, |
|
|
|
sampler_name: str, |
|
|
|
scheduler: str, |
|
|
|
cfg: float, |
|
|
|
lora_list: list, |
|
|
|
lora_strength_list: list, |
|
|
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: |
|
|
|
""" |
|
|
|
generate image |
|
|
|
""" |
|
|
|
if not SD_TXT2IMG_OPTIONS: |
|
|
|
current_dir = os.path.dirname(os.path.realpath(__file__)) |
|
|
|
with open(os.path.join(current_dir, "txt2img.json")) as file: |
|
|
|
SD_TXT2IMG_OPTIONS.update(json.load(file)) |
|
|
|
|
|
|
|
draw_options = deepcopy(SD_TXT2IMG_OPTIONS) |
|
|
|
draw_options["3"]["inputs"]["steps"] = steps |
|
|
|
draw_options["3"]["inputs"]["sampler_name"] = sampler_name |
|
|
|
draw_options["3"]["inputs"]["scheduler"] = scheduler |
|
|
|
draw_options["3"]["inputs"]["cfg"] = cfg |
|
|
|
# generate different image when using same prompt next time |
|
|
|
draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000) |
|
|
|
draw_options["4"]["inputs"]["ckpt_name"] = model |
|
|
|
draw_options["5"]["inputs"]["width"] = width |
|
|
|
draw_options["5"]["inputs"]["height"] = height |
|
|
|
draw_options["6"]["inputs"]["text"] = prompt |
|
|
|
draw_options["7"]["inputs"]["text"] = negative_prompt |
|
|
|
# if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent |
|
|
|
if model_type in (ModelType.SD3.name, ModelType.FLUX.name): |
|
|
|
draw_options["5"]["class_type"] = "EmptySD3LatentImage" |
|
|
|
|
|
|
|
if lora_list: |
|
|
|
# last Lora node link to KSampler node |
|
|
|
draw_options["3"]["inputs"]["model"][0] = "10" |
|
|
|
# last Lora node link to positive and negative Clip node |
|
|
|
draw_options["6"]["inputs"]["clip"][0] = "10" |
|
|
|
draw_options["7"]["inputs"]["clip"][0] = "10" |
|
|
|
# every Lora node link to next Lora node, and Checkpoints node link to first Lora node |
|
|
|
for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10): |
|
|
|
if i - 10 == len(lora_list) - 1: |
|
|
|
next_node_id = "4" |
|
|
|
else: |
|
|
|
next_node_id = str(i + 1) |
|
|
|
lora_node = deepcopy(LORA_NODE) |
|
|
|
lora_node["inputs"]["lora_name"] = lora |
|
|
|
lora_node["inputs"]["strength_model"] = strength |
|
|
|
lora_node["inputs"]["strength_clip"] = strength |
|
|
|
lora_node["inputs"]["model"][0] = next_node_id |
|
|
|
lora_node["inputs"]["clip"][0] = next_node_id |
|
|
|
draw_options[str(i)] = lora_node |
|
|
|
|
|
|
|
# FLUX need to add FluxGuidance Node |
|
|
|
if model_type == ModelType.FLUX.name: |
|
|
|
last_node_id = str(10 + len(lora_list)) |
|
|
|
draw_options[last_node_id] = deepcopy(FluxGuidanceNode) |
|
|
|
draw_options[last_node_id]["inputs"]["conditioning"][0] = "6" |
|
|
|
draw_options["3"]["inputs"]["positive"][0] = last_node_id |
|
|
|
|
|
|
|
try: |
|
|
|
client_id = str(uuid.uuid4()) |
|
|
|
result = self.queue_prompt_image(base_url, client_id, prompt=draw_options) |
|
|
|
|
|
|
|
# get first image |
|
|
|
image = b"" |
|
|
|
for node in result: |
|
|
|
for img in result[node]: |
|
|
|
if img: |
|
|
|
image = img |
|
|
|
break |
|
|
|
|
|
|
|
return self.create_blob_message( |
|
|
|
blob=image, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value |
|
|
|
) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
return self.create_text_message(f"Failed to generate image: {str(e)}") |
|
|
|
|
|
|
|
def get_runtime_parameters(self) -> list[ToolParameter]: |
|
|
|
parameters = [ |
|
|
|
ToolParameter( |
|
|
|
name="prompt", |
|
|
|
label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Image prompt, you can check the official documentation of Stable Diffusion", |
|
|
|
zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.STRING, |
|
|
|
form=ToolParameter.ToolParameterForm.LLM, |
|
|
|
llm_description="Image prompt of Stable Diffusion, you should describe the image " |
|
|
|
"you want to generate as a list of words as possible as detailed, " |
|
|
|
"the prompt must be written in English.", |
|
|
|
required=True, |
|
|
|
), |
|
|
|
] |
|
|
|
if self.runtime.credentials: |
|
|
|
try: |
|
|
|
models = self.get_checkpoints() |
|
|
|
if len(models) != 0: |
|
|
|
parameters.append( |
|
|
|
ToolParameter( |
|
|
|
name="model", |
|
|
|
label=I18nObject(en_US="Model", zh_Hans="Model"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Model of Stable Diffusion or FLUX, " |
|
|
|
"you can check the official documentation of Stable Diffusion or FLUX", |
|
|
|
zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
llm_description="Model of Stable Diffusion or FLUX, " |
|
|
|
"you can check the official documentation of Stable Diffusion or FLUX", |
|
|
|
required=True, |
|
|
|
default=models[0], |
|
|
|
options=[ |
|
|
|
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
loras = self.get_loras() |
|
|
|
if len(loras) != 0: |
|
|
|
for n in range(1, 4): |
|
|
|
parameters.append( |
|
|
|
ToolParameter( |
|
|
|
name=f"lora_{n}", |
|
|
|
label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Lora of Stable Diffusion, " |
|
|
|
"you can check the official documentation of Stable Diffusion", |
|
|
|
zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
llm_description="Lora of Stable Diffusion, " |
|
|
|
"you can check the official documentation of " |
|
|
|
"Stable Diffusion", |
|
|
|
required=False, |
|
|
|
options=[ |
|
|
|
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
sample_methods, schedulers = self.get_sample_methods() |
|
|
|
if len(sample_methods) != 0: |
|
|
|
parameters.append( |
|
|
|
ToolParameter( |
|
|
|
name="sampler_name", |
|
|
|
label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Sampling method of Stable Diffusion, " |
|
|
|
"you can check the official documentation of Stable Diffusion", |
|
|
|
zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
llm_description="Sampling method of Stable Diffusion, " |
|
|
|
"you can check the official documentation of Stable Diffusion", |
|
|
|
required=True, |
|
|
|
default=sample_methods[0], |
|
|
|
options=[ |
|
|
|
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) |
|
|
|
for i in sample_methods |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
if len(schedulers) != 0: |
|
|
|
parameters.append( |
|
|
|
ToolParameter( |
|
|
|
name="scheduler", |
|
|
|
label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Scheduler of Stable Diffusion, " |
|
|
|
"you can check the official documentation of Stable Diffusion", |
|
|
|
zh_Hans="Stable Diffusion 的Scheduler,您可以查看 Stable Diffusion 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
llm_description="Scheduler of Stable Diffusion, " |
|
|
|
"you can check the official documentation of Stable Diffusion", |
|
|
|
required=True, |
|
|
|
default=schedulers[0], |
|
|
|
options=[ |
|
|
|
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
parameters.append( |
|
|
|
ToolParameter( |
|
|
|
name="model_type", |
|
|
|
label=I18nObject(en_US="Model Type", zh_Hans="Model Type"), |
|
|
|
human_description=I18nObject( |
|
|
|
en_US="Model Type of Stable Diffusion or Flux, " |
|
|
|
"you can check the official documentation of Stable Diffusion or Flux", |
|
|
|
zh_Hans="Stable Diffusion 或 FLUX 的模型类型," |
|
|
|
"您可以查看 Stable Diffusion 或 Flux 的官方文档", |
|
|
|
), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
llm_description="Model Type of Stable Diffusion or Flux, " |
|
|
|
"you can check the official documentation of Stable Diffusion or Flux", |
|
|
|
required=True, |
|
|
|
default=ModelType.SD15.name, |
|
|
|
options=[ |
|
|
|
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) |
|
|
|
for i in ModelType.__members__ |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
except: |
|
|
|
pass |
|
|
|
|
|
|
|
return parameters |