|
|
|
@@ -5,7 +5,10 @@ import requests |
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage |
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool |
|
|
|
|
|
|
|
FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" |
|
|
|
FLUX_URL = { |
|
|
|
"schnell": "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image", |
|
|
|
"dev": "https://api.siliconflow.cn/v1/image/generations", |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
class FluxTool(BuiltinTool): |
|
|
|
@@ -24,8 +27,12 @@ class FluxTool(BuiltinTool): |
|
|
|
"seed": tool_parameters.get("seed"), |
|
|
|
"num_inference_steps": tool_parameters.get("num_inference_steps", 20), |
|
|
|
} |
|
|
|
model = tool_parameters.get("model", "schnell") |
|
|
|
url = FLUX_URL.get(model) |
|
|
|
if model == "dev": |
|
|
|
payload["model"] = "black-forest-labs/FLUX.1-dev" |
|
|
|
|
|
|
|
response = requests.post(FLUX_URL, json=payload, headers=headers) |
|
|
|
response = requests.post(url, json=payload, headers=headers) |
|
|
|
if response.status_code != 200: |
|
|
|
return self.create_text_message(f"Got Error Response:{response.text}") |
|
|
|
|