|
|
|
@@ -48,7 +48,6 @@ class ComfyUiClient: |
|
|
|
prompt = origin_prompt.copy() |
|
|
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} |
|
|
|
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] |
|
|
|
prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) |
|
|
|
positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] |
|
|
|
prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt |
|
|
|
|
|
|
|
@@ -72,6 +71,18 @@ class ComfyUiClient: |
|
|
|
prompt.get(load_image)["inputs"]["image"] = image_name |
|
|
|
return prompt |
|
|
|
|
|
|
|
def set_prompt_seed_by_id(self, origin_prompt: dict, seed_id: str) -> dict: |
|
|
|
prompt = origin_prompt.copy() |
|
|
|
if seed_id not in prompt: |
|
|
|
raise Exception("Not a valid seed node") |
|
|
|
if "seed" in prompt[seed_id]["inputs"]: |
|
|
|
prompt[seed_id]["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) |
|
|
|
elif "noise_seed" in prompt[seed_id]["inputs"]: |
|
|
|
prompt[seed_id]["inputs"]["noise_seed"] = random.randint(10**14, 10**15 - 1) |
|
|
|
else: |
|
|
|
raise Exception("Not a valid seed node") |
|
|
|
return prompt |
|
|
|
|
|
|
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): |
|
|
|
node_ids = list(prompt.keys()) |
|
|
|
finished_nodes = [] |