| import json | |||||
| import operator | |||||
| from typing import Any, Optional, Union | |||||
| import boto3 | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||||
| from core.tools.tool.builtin_tool import BuiltinTool | |||||
| class BedrockRetrieveTool(BuiltinTool): | |||||
| bedrock_client: Any = None | |||||
| knowledge_base_id: str = None | |||||
| topk: int = None | |||||
| def _bedrock_retrieve( | |||||
| self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None | |||||
| ): | |||||
| try: | |||||
| retrieval_query = {"text": query_input} | |||||
| retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} | |||||
| # 如果有元数据过滤条件,则添加到检索配置中 | |||||
| if metadata_filter: | |||||
| retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter | |||||
| response = self.bedrock_client.retrieve( | |||||
| knowledgeBaseId=knowledge_base_id, | |||||
| retrievalQuery=retrieval_query, | |||||
| retrievalConfiguration=retrieval_configuration, | |||||
| ) | |||||
| results = [] | |||||
| for result in response.get("retrievalResults", []): | |||||
| results.append( | |||||
| { | |||||
| "content": result.get("content", {}).get("text", ""), | |||||
| "score": result.get("score", 0.0), | |||||
| "metadata": result.get("metadata", {}), | |||||
| } | |||||
| ) | |||||
| return results | |||||
| except Exception as e: | |||||
| raise Exception(f"Error retrieving from knowledge base: {str(e)}") | |||||
| def _invoke( | |||||
| self, | |||||
| user_id: str, | |||||
| tool_parameters: dict[str, Any], | |||||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||||
| """ | |||||
| invoke tools | |||||
| """ | |||||
| line = 0 | |||||
| try: | |||||
| if not self.bedrock_client: | |||||
| aws_region = tool_parameters.get("aws_region") | |||||
| if aws_region: | |||||
| self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region) | |||||
| else: | |||||
| self.bedrock_client = boto3.client("bedrock-agent-runtime") | |||||
| line = 1 | |||||
| if not self.knowledge_base_id: | |||||
| self.knowledge_base_id = tool_parameters.get("knowledge_base_id") | |||||
| if not self.knowledge_base_id: | |||||
| return self.create_text_message("Please provide knowledge_base_id") | |||||
| line = 2 | |||||
| if not self.topk: | |||||
| self.topk = tool_parameters.get("topk", 5) | |||||
| line = 3 | |||||
| query = tool_parameters.get("query", "") | |||||
| if not query: | |||||
| return self.create_text_message("Please input query") | |||||
| # 获取元数据过滤条件(如果存在) | |||||
| metadata_filter_str = tool_parameters.get("metadata_filter") | |||||
| metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None | |||||
| line = 4 | |||||
| retrieved_docs = self._bedrock_retrieve( | |||||
| query_input=query, | |||||
| knowledge_base_id=self.knowledge_base_id, | |||||
| num_results=self.topk, | |||||
| metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 | |||||
| ) | |||||
| line = 5 | |||||
| # Sort results by score in descending order | |||||
| sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True) | |||||
| line = 6 | |||||
| return [self.create_json_message(res) for res in sorted_docs] | |||||
| except Exception as e: | |||||
| return self.create_text_message(f"Exception {str(e)}, line : {line}") | |||||
| def validate_parameters(self, parameters: dict[str, Any]) -> None: | |||||
| """ | |||||
| Validate the parameters | |||||
| """ | |||||
| if not parameters.get("knowledge_base_id"): | |||||
| raise ValueError("knowledge_base_id is required") | |||||
| if not parameters.get("query"): | |||||
| raise ValueError("query is required") | |||||
| # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) | |||||
| metadata_filter_str = parameters.get("metadata_filter") | |||||
| if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): | |||||
| raise ValueError("metadata_filter must be a valid JSON object") |
| identity: | |||||
| name: bedrock_retrieve | |||||
| author: AWS | |||||
| label: | |||||
| en_US: Bedrock Retrieve | |||||
| zh_Hans: Bedrock检索 | |||||
| pt_BR: Bedrock Retrieve | |||||
| icon: icon.svg | |||||
| description: | |||||
| human: | |||||
| en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool | |||||
| zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明 | |||||
| pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. | |||||
| llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool | |||||
| parameters: | |||||
| - name: knowledge_base_id | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Bedrock Knowledge Base ID | |||||
| zh_Hans: Bedrock知识库ID | |||||
| pt_BR: Bedrock Knowledge Base ID | |||||
| human_description: | |||||
| en_US: ID of the Bedrock Knowledge Base to retrieve from | |||||
| zh_Hans: 用于检索的Bedrock知识库ID | |||||
| pt_BR: ID of the Bedrock Knowledge Base to retrieve from | |||||
| llm_description: ID of the Bedrock Knowledge Base to retrieve from | |||||
| form: form | |||||
| - name: query | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Query string | |||||
| zh_Hans: 查询语句 | |||||
| pt_BR: Query string | |||||
| human_description: | |||||
| en_US: The search query to retrieve relevant information | |||||
| zh_Hans: 用于检索相关信息的查询语句 | |||||
| pt_BR: The search query to retrieve relevant information | |||||
| llm_description: The search query to retrieve relevant information | |||||
| form: llm | |||||
| - name: topk | |||||
| type: number | |||||
| required: false | |||||
| form: form | |||||
| label: | |||||
| en_US: Limit for results count | |||||
| zh_Hans: 返回结果数量限制 | |||||
| pt_BR: Limit for results count | |||||
| human_description: | |||||
| en_US: Maximum number of results to return | |||||
| zh_Hans: 最大返回结果数量 | |||||
| pt_BR: Maximum number of results to return | |||||
| min: 1 | |||||
| max: 10 | |||||
| default: 5 | |||||
| - name: aws_region | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: AWS Region | |||||
| zh_Hans: AWS 区域 | |||||
| pt_BR: AWS Region | |||||
| human_description: | |||||
| en_US: AWS region where the Bedrock Knowledge Base is located | |||||
| zh_Hans: Bedrock知识库所在的AWS区域 | |||||
| pt_BR: AWS region where the Bedrock Knowledge Base is located | |||||
| llm_description: AWS region where the Bedrock Knowledge Base is located | |||||
| form: form | |||||
| - name: metadata_filter | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Metadata Filter | |||||
| zh_Hans: 元数据过滤器 | |||||
| pt_BR: Metadata Filter | |||||
| human_description: | |||||
| en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' | |||||
| zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})' | |||||
| pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' | |||||
| form: form |
| import base64 | |||||
| import json | |||||
| import logging | |||||
| import re | |||||
| from datetime import datetime | |||||
| from typing import Any, Union | |||||
| from urllib.parse import urlparse | |||||
| import boto3 | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||||
| from core.tools.tool.builtin_tool import BuiltinTool | |||||
| logging.basicConfig(level=logging.INFO) | |||||
| logger = logging.getLogger(__name__) | |||||
| class NovaCanvasTool(BuiltinTool): | |||||
| def _invoke( | |||||
| self, user_id: str, tool_parameters: dict[str, Any] | |||||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||||
| """ | |||||
| Invoke AWS Bedrock Nova Canvas model for image generation | |||||
| """ | |||||
| # Get common parameters | |||||
| prompt = tool_parameters.get("prompt", "") | |||||
| image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip() | |||||
| if not prompt: | |||||
| return self.create_text_message("Please provide a text prompt for image generation.") | |||||
| if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3": | |||||
| return self.create_text_message("Please provide an valid S3 URI for image output.") | |||||
| task_type = tool_parameters.get("task_type", "TEXT_IMAGE") | |||||
| aws_region = tool_parameters.get("aws_region", "us-east-1") | |||||
| # Get common image generation config parameters | |||||
| width = tool_parameters.get("width", 1024) | |||||
| height = tool_parameters.get("height", 1024) | |||||
| cfg_scale = tool_parameters.get("cfg_scale", 8.0) | |||||
| negative_prompt = tool_parameters.get("negative_prompt", "") | |||||
| seed = tool_parameters.get("seed", 0) | |||||
| quality = tool_parameters.get("quality", "standard") | |||||
| # Handle S3 image if provided | |||||
| image_input_s3uri = tool_parameters.get("image_input_s3uri", "") | |||||
| if task_type != "TEXT_IMAGE": | |||||
| if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3": | |||||
| return self.create_text_message("Please provide a valid S3 URI for image to image generation.") | |||||
| # Parse S3 URI | |||||
| parsed_uri = urlparse(image_input_s3uri) | |||||
| bucket = parsed_uri.netloc | |||||
| key = parsed_uri.path.lstrip("/") | |||||
| # Initialize S3 client and download image | |||||
| s3_client = boto3.client("s3") | |||||
| response = s3_client.get_object(Bucket=bucket, Key=key) | |||||
| image_data = response["Body"].read() | |||||
| # Base64 encode the image | |||||
| input_image = base64.b64encode(image_data).decode("utf-8") | |||||
| try: | |||||
| # Initialize Bedrock client | |||||
| bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region) | |||||
| # Base image generation config | |||||
| image_generation_config = { | |||||
| "width": width, | |||||
| "height": height, | |||||
| "cfgScale": cfg_scale, | |||||
| "seed": seed, | |||||
| "numberOfImages": 1, | |||||
| "quality": quality, | |||||
| } | |||||
| # Prepare request body based on task type | |||||
| body = {"imageGenerationConfig": image_generation_config} | |||||
| if task_type == "TEXT_IMAGE": | |||||
| body["taskType"] = "TEXT_IMAGE" | |||||
| body["textToImageParams"] = {"text": prompt} | |||||
| if negative_prompt: | |||||
| body["textToImageParams"]["negativeText"] = negative_prompt | |||||
| elif task_type == "COLOR_GUIDED_GENERATION": | |||||
| colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680") | |||||
| if not self._validate_color_string(colors): | |||||
| return self.create_text_message("Please provide valid colors in hexadecimal format.") | |||||
| body["taskType"] = "COLOR_GUIDED_GENERATION" | |||||
| body["colorGuidedGenerationParams"] = { | |||||
| "colors": colors.split("-"), | |||||
| "referenceImage": input_image, | |||||
| "text": prompt, | |||||
| } | |||||
| if negative_prompt: | |||||
| body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt | |||||
| elif task_type == "IMAGE_VARIATION": | |||||
| similarity_strength = tool_parameters.get("similarity_strength", 0.5) | |||||
| body["taskType"] = "IMAGE_VARIATION" | |||||
| body["imageVariationParams"] = { | |||||
| "images": [input_image], | |||||
| "similarityStrength": similarity_strength, | |||||
| "text": prompt, | |||||
| } | |||||
| if negative_prompt: | |||||
| body["imageVariationParams"]["negativeText"] = negative_prompt | |||||
| elif task_type == "INPAINTING": | |||||
| mask_prompt = tool_parameters.get("mask_prompt") | |||||
| if not mask_prompt: | |||||
| return self.create_text_message("Please provide a mask prompt for image inpainting.") | |||||
| body["taskType"] = "INPAINTING" | |||||
| body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt} | |||||
| if negative_prompt: | |||||
| body["inPaintingParams"]["negativeText"] = negative_prompt | |||||
| elif task_type == "OUTPAINTING": | |||||
| mask_prompt = tool_parameters.get("mask_prompt") | |||||
| if not mask_prompt: | |||||
| return self.create_text_message("Please provide a mask prompt for image outpainting.") | |||||
| outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT") | |||||
| body["taskType"] = "OUTPAINTING" | |||||
| body["outPaintingParams"] = { | |||||
| "image": input_image, | |||||
| "maskPrompt": mask_prompt, | |||||
| "outPaintingMode": outpainting_mode, | |||||
| "text": prompt, | |||||
| } | |||||
| if negative_prompt: | |||||
| body["outPaintingParams"]["negativeText"] = negative_prompt | |||||
| elif task_type == "BACKGROUND_REMOVAL": | |||||
| body["taskType"] = "BACKGROUND_REMOVAL" | |||||
| body["backgroundRemovalParams"] = {"image": input_image} | |||||
| else: | |||||
| return self.create_text_message(f"Unsupported task type: {task_type}") | |||||
| # Call Nova Canvas model | |||||
| response = bedrock.invoke_model( | |||||
| body=json.dumps(body), | |||||
| modelId="amazon.nova-canvas-v1:0", | |||||
| accept="application/json", | |||||
| contentType="application/json", | |||||
| ) | |||||
| # Process response | |||||
| response_body = json.loads(response.get("body").read()) | |||||
| if response_body.get("error"): | |||||
| raise Exception(f"Error in model response: {response_body.get('error')}") | |||||
| base64_image = response_body.get("images")[0] | |||||
| # Upload to S3 if image_output_s3uri is provided | |||||
| try: | |||||
| # Parse S3 URI for output | |||||
| parsed_uri = urlparse(image_output_s3uri) | |||||
| output_bucket = parsed_uri.netloc | |||||
| output_base_path = parsed_uri.path.lstrip("/") | |||||
| # Generate filename with timestamp | |||||
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |||||
| output_key = f"{output_base_path}/canvas-output-{timestamp}.png" | |||||
| # Initialize S3 client if not already done | |||||
| s3_client = boto3.client("s3", region_name=aws_region) | |||||
| # Decode base64 image and upload to S3 | |||||
| image_data = base64.b64decode(base64_image) | |||||
| s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png") | |||||
| logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}") | |||||
| except Exception as e: | |||||
| logger.exception("Failed to upload image to S3") | |||||
| # Return image | |||||
| return [ | |||||
| self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"), | |||||
| self.create_blob_message( | |||||
| blob=base64.b64decode(base64_image), | |||||
| meta={"mime_type": "image/png"}, | |||||
| save_as=self.VariableKey.IMAGE.value, | |||||
| ), | |||||
| ] | |||||
| except Exception as e: | |||||
| return self.create_text_message(f"Failed to generate image: {str(e)}") | |||||
| def _validate_color_string(self, color_string) -> bool: | |||||
| color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$" | |||||
| if re.match(color_pattern, color_string): | |||||
| return True | |||||
| return False | |||||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||||
| parameters = [ | |||||
| ToolParameter( | |||||
| name="prompt", | |||||
| label=I18nObject(en_US="Prompt", zh_Hans="提示词"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=True, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="Text description of the image you want to generate or modify", | |||||
| zh_Hans="您想要生成或修改的图像的文本描述", | |||||
| ), | |||||
| llm_description="Describe the image you want to generate or how you want to modify the input image", | |||||
| ), | |||||
| ToolParameter( | |||||
| name="image_input_s3uri", | |||||
| label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="image_output_s3uri", | |||||
| label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=True, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="width", | |||||
| label=I18nObject(en_US="Width", zh_Hans="宽度"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=1024, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="height", | |||||
| label=I18nObject(en_US="Height", zh_Hans="高度"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=1024, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="cfg_scale", | |||||
| label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=8.0, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="negative_prompt", | |||||
| label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default="", | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="seed", | |||||
| label=I18nObject(en_US="Seed", zh_Hans="种子值"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=0, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="aws_region", | |||||
| label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default="us-east-1", | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="task_type", | |||||
| label=I18nObject(en_US="Task Type", zh_Hans="任务类型"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default="TEXT_IMAGE", | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="quality", | |||||
| label=I18nObject(en_US="Quality", zh_Hans="质量"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default="standard", | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="colors", | |||||
| label=I18nObject(en_US="Colors", zh_Hans="颜色"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680", | |||||
| zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680", | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="similarity_strength", | |||||
| label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=0.5, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="How similar the generated image should be to the input image (0.0 to 1.0)", | |||||
| zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)", | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="mask_prompt", | |||||
| label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="Text description to generate mask for inpainting/outpainting", | |||||
| zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述", | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="outpainting_mode", | |||||
| label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default="DEFAULT", | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="Mode for outpainting (DEFAULT or other supported modes)", | |||||
| zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)", | |||||
| ), | |||||
| ), | |||||
| ] | |||||
| return parameters |
| identity: | |||||
| name: nova_canvas | |||||
| author: AWS | |||||
| label: | |||||
| en_US: AWS Bedrock Nova Canvas | |||||
| zh_Hans: AWS Bedrock Nova Canvas | |||||
| icon: icon.svg | |||||
| description: | |||||
| human: | |||||
| en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html | |||||
| zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。 | |||||
| llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. | |||||
| parameters: | |||||
| - name: task_type | |||||
| type: string | |||||
| required: false | |||||
| default: TEXT_IMAGE | |||||
| label: | |||||
| en_US: Task Type | |||||
| zh_Hans: 任务类型 | |||||
| human_description: | |||||
| en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL) | |||||
| zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除) | |||||
| form: llm | |||||
| - name: prompt | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Prompt | |||||
| zh_Hans: 提示词 | |||||
| human_description: | |||||
| en_US: Text description of the image you want to generate or modify | |||||
| zh_Hans: 您想要生成或修改的图像的文本描述 | |||||
| llm_description: Describe the image you want to generate or how you want to modify the input image | |||||
| form: llm | |||||
| - name: image_input_s3uri | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Input image s3 uri | |||||
| zh_Hans: 输入图片的s3 uri | |||||
| human_description: | |||||
| en_US: The input image to modify (required for all modes except TEXT_IMAGE) | |||||
| zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要) | |||||
| llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE. | |||||
| form: llm | |||||
| - name: image_output_s3uri | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Output S3 URI | |||||
| zh_Hans: 输出S3 URI | |||||
| human_description: | |||||
| en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png | |||||
| zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传 | |||||
| llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename. | |||||
| form: form | |||||
| - name: negative_prompt | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Negative Prompt | |||||
| zh_Hans: 负面提示词 | |||||
| human_description: | |||||
| en_US: Things you don't want in the generated image | |||||
| zh_Hans: 您不想在生成的图像中出现的内容 | |||||
| form: llm | |||||
| - name: width | |||||
| type: number | |||||
| required: false | |||||
| label: | |||||
| en_US: Width | |||||
| zh_Hans: 宽度 | |||||
| human_description: | |||||
| en_US: Width of the generated image | |||||
| zh_Hans: 生成图像的宽度 | |||||
| form: form | |||||
| default: 1024 | |||||
| - name: height | |||||
| type: number | |||||
| required: false | |||||
| label: | |||||
| en_US: Height | |||||
| zh_Hans: 高度 | |||||
| human_description: | |||||
| en_US: Height of the generated image | |||||
| zh_Hans: 生成图像的高度 | |||||
| form: form | |||||
| default: 1024 | |||||
| - name: cfg_scale | |||||
| type: number | |||||
| required: false | |||||
| label: | |||||
| en_US: CFG Scale | |||||
| zh_Hans: CFG比例 | |||||
| human_description: | |||||
| en_US: How strongly the image should conform to the prompt | |||||
| zh_Hans: 图像应该多大程度上符合提示词 | |||||
| form: form | |||||
| default: 8.0 | |||||
| - name: seed | |||||
| type: number | |||||
| required: false | |||||
| label: | |||||
| en_US: Seed | |||||
| zh_Hans: 种子值 | |||||
| human_description: | |||||
| en_US: Random seed for image generation | |||||
| zh_Hans: 图像生成的随机种子 | |||||
| form: form | |||||
| default: 0 | |||||
| - name: aws_region | |||||
| type: string | |||||
| required: false | |||||
| default: us-east-1 | |||||
| label: | |||||
| en_US: AWS Region | |||||
| zh_Hans: AWS 区域 | |||||
| human_description: | |||||
| en_US: AWS region for Bedrock service | |||||
| zh_Hans: Bedrock 服务的 AWS 区域 | |||||
| form: form | |||||
| - name: quality | |||||
| type: string | |||||
| required: false | |||||
| default: standard | |||||
| label: | |||||
| en_US: Quality | |||||
| zh_Hans: 质量 | |||||
| human_description: | |||||
| en_US: Quality of the generated image (standard or premium) | |||||
| zh_Hans: 生成图像的质量(标准或高级) | |||||
| form: form | |||||
| - name: colors | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Colors | |||||
| zh_Hans: 颜色 | |||||
| human_description: | |||||
| en_US: List of colors for color-guided generation | |||||
| zh_Hans: 颜色引导生成的颜色列表 | |||||
| form: form | |||||
| - name: similarity_strength | |||||
| type: number | |||||
| required: false | |||||
| default: 0.5 | |||||
| label: | |||||
| en_US: Similarity Strength | |||||
| zh_Hans: 相似度强度 | |||||
| human_description: | |||||
| en_US: How similar the generated image should be to the input image (0.0 to 1.0) | |||||
| zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0) | |||||
| form: form | |||||
| - name: mask_prompt | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Mask Prompt | |||||
| zh_Hans: 蒙版提示词 | |||||
| human_description: | |||||
| en_US: Text description to generate mask for inpainting/outpainting | |||||
| zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述 | |||||
| form: llm | |||||
| - name: outpainting_mode | |||||
| type: string | |||||
| required: false | |||||
| default: DEFAULT | |||||
| label: | |||||
| en_US: Outpainting Mode | |||||
| zh_Hans: 外补绘制模式 | |||||
| human_description: | |||||
| en_US: Mode for outpainting (DEFAULT or other supported modes) | |||||
| zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式) | |||||
| form: form |
| import base64 | |||||
| import logging | |||||
| import time | |||||
| from io import BytesIO | |||||
| from typing import Any, Optional, Union | |||||
| from urllib.parse import urlparse | |||||
| import boto3 | |||||
| from botocore.exceptions import ClientError | |||||
| from PIL import Image | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||||
| from core.tools.tool.builtin_tool import BuiltinTool | |||||
| logging.basicConfig(level=logging.INFO) | |||||
| logger = logging.getLogger(__name__) | |||||
| NOVA_REEL_DEFAULT_REGION = "us-east-1" | |||||
| NOVA_REEL_DEFAULT_DIMENSION = "1280x720" | |||||
| NOVA_REEL_DEFAULT_FPS = 24 | |||||
| NOVA_REEL_DEFAULT_DURATION = 6 | |||||
| NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0" | |||||
| NOVA_REEL_STATUS_CHECK_INTERVAL = 5 | |||||
| # Image requirements | |||||
| NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280 | |||||
| NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720 | |||||
| NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB" | |||||
| class NovaReelTool(BuiltinTool): | |||||
| def _invoke( | |||||
| self, user_id: str, tool_parameters: dict[str, Any] | |||||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||||
| """ | |||||
| Invoke AWS Bedrock Nova Reel model for video generation. | |||||
| Args: | |||||
| user_id: The ID of the user making the request | |||||
| tool_parameters: Dictionary containing the tool parameters | |||||
| Returns: | |||||
| ToolInvokeMessage containing either the video content or status information | |||||
| """ | |||||
| try: | |||||
| # Validate and extract parameters | |||||
| params = self._validate_and_extract_parameters(tool_parameters) | |||||
| if isinstance(params, ToolInvokeMessage): | |||||
| return params | |||||
| # Initialize AWS clients | |||||
| bedrock, s3_client = self._initialize_aws_clients(params["aws_region"]) | |||||
| # Prepare model input | |||||
| model_input = self._prepare_model_input(params, s3_client) | |||||
| if isinstance(model_input, ToolInvokeMessage): | |||||
| return model_input | |||||
| # Start video generation | |||||
| invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"]) | |||||
| invocation_arn = invocation["invocationArn"] | |||||
| # Handle async/sync mode | |||||
| return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"]) | |||||
| except ClientError as e: | |||||
| error_code = e.response.get("Error", {}).get("Code", "Unknown") | |||||
| error_message = e.response.get("Error", {}).get("Message", str(e)) | |||||
| logger.exception(f"AWS API error: {error_code} - {error_message}") | |||||
| return self.create_text_message(f"AWS service error: {error_code} - {error_message}") | |||||
| except Exception as e: | |||||
| logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True) | |||||
| return self.create_text_message(f"Failed to generate video: {str(e)}") | |||||
| def _validate_and_extract_parameters( | |||||
| self, tool_parameters: dict[str, Any] | |||||
| ) -> Union[dict[str, Any], ToolInvokeMessage]: | |||||
| """Validate and extract parameters from the input dictionary.""" | |||||
| prompt = tool_parameters.get("prompt", "") | |||||
| video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip() | |||||
| # Validate required parameters | |||||
| if not prompt: | |||||
| return self.create_text_message("Please provide a text prompt for video generation.") | |||||
| if not video_output_s3uri: | |||||
| return self.create_text_message("Please provide an S3 URI for video output.") | |||||
| # Validate S3 URI format | |||||
| if not video_output_s3uri.startswith("s3://"): | |||||
| return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") | |||||
| # Ensure S3 URI ends with '/' | |||||
| video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/" | |||||
| return { | |||||
| "prompt": prompt, | |||||
| "video_output_s3uri": video_output_s3uri, | |||||
| "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(), | |||||
| "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION), | |||||
| "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION), | |||||
| "seed": int(tool_parameters.get("seed", 0)), | |||||
| "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)), | |||||
| "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)), | |||||
| "async_mode": bool(tool_parameters.get("async", True)), | |||||
| } | |||||
| def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]: | |||||
| """Initialize AWS Bedrock and S3 clients.""" | |||||
| bedrock = boto3.client(service_name="bedrock-runtime", region_name=region) | |||||
| s3_client = boto3.client("s3", region_name=region) | |||||
| return bedrock, s3_client | |||||
| def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]: | |||||
| """Prepare the input for the Nova Reel model.""" | |||||
| model_input = { | |||||
| "taskType": "TEXT_VIDEO", | |||||
| "textToVideoParams": {"text": params["prompt"]}, | |||||
| "videoGenerationConfig": { | |||||
| "durationSeconds": params["duration"], | |||||
| "fps": params["fps"], | |||||
| "dimension": params["dimension"], | |||||
| "seed": params["seed"], | |||||
| }, | |||||
| } | |||||
| # Add image if provided | |||||
| if params["image_input_s3uri"]: | |||||
| try: | |||||
| image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"]) | |||||
| if not image_data: | |||||
| return self.create_text_message("Failed to retrieve image from S3") | |||||
| # Process and validate image | |||||
| processed_image = self._process_and_validate_image(image_data) | |||||
| if isinstance(processed_image, ToolInvokeMessage): | |||||
| return processed_image | |||||
| # Convert processed image to base64 | |||||
| img_buffer = BytesIO() | |||||
| processed_image.save(img_buffer, format="PNG") | |||||
| img_buffer.seek(0) | |||||
| input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") | |||||
| model_input["textToVideoParams"]["images"] = [ | |||||
| {"format": "png", "source": {"bytes": input_image_base64}} | |||||
| ] | |||||
| except Exception as e: | |||||
| logger.error(f"Error processing input image: {str(e)}", exc_info=True) | |||||
| return self.create_text_message(f"Failed to process input image: {str(e)}") | |||||
| return model_input | |||||
| def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]: | |||||
| """ | |||||
| Process and validate the input image according to Nova Reel requirements. | |||||
| Requirements: | |||||
| - Must be 1280x720 pixels | |||||
| - Must be RGB format (8 bits per channel) | |||||
| - If PNG, alpha channel must not have transparent/translucent pixels | |||||
| """ | |||||
| try: | |||||
| # Open image | |||||
| img = Image.open(BytesIO(image_data)) | |||||
| # Convert RGBA to RGB if needed, ensuring no transparency | |||||
| if img.mode == "RGBA": | |||||
| # Check for transparency | |||||
| if img.getchannel("A").getextrema()[0] < 255: | |||||
| return self.create_text_message( | |||||
| "PNG image contains transparent or translucent pixels, which is not supported. " | |||||
| "Please provide an image without transparency." | |||||
| ) | |||||
| # Convert to RGB | |||||
| img = img.convert("RGB") | |||||
| elif img.mode != "RGB": | |||||
| # Convert any other mode to RGB | |||||
| img = img.convert("RGB") | |||||
| # Validate/adjust dimensions | |||||
| if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT): | |||||
| logger.warning( | |||||
| f"Image dimensions {img.size} do not match required dimensions " | |||||
| f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..." | |||||
| ) | |||||
| img = img.resize( | |||||
| (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS | |||||
| ) | |||||
| # Validate bit depth | |||||
| if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE: | |||||
| return self.create_text_message( | |||||
| f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel" | |||||
| ) | |||||
| return img | |||||
| except Exception as e: | |||||
| logger.error(f"Error processing image: {str(e)}", exc_info=True) | |||||
| return self.create_text_message( | |||||
| "Failed to process image. Please ensure the image is a valid JPEG or PNG file." | |||||
| ) | |||||
| def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]: | |||||
| """Download and return image data from S3.""" | |||||
| parsed_uri = urlparse(s3_uri) | |||||
| bucket = parsed_uri.netloc | |||||
| key = parsed_uri.path.lstrip("/") | |||||
| response = s3_client.get_object(Bucket=bucket, Key=key) | |||||
| return response["Body"].read() | |||||
| def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]: | |||||
| """Start the async video generation process.""" | |||||
| return bedrock.start_async_invoke( | |||||
| modelId=NOVA_REEL_MODEL_ID, | |||||
| modelInput=model_input, | |||||
| outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}}, | |||||
| ) | |||||
| def _handle_generation_mode( | |||||
| self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool | |||||
| ) -> ToolInvokeMessage: | |||||
| """Handle async or sync video generation mode.""" | |||||
| invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn) | |||||
| video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] | |||||
| video_uri = f"{video_path}/output.mp4" | |||||
| if async_mode: | |||||
| return self.create_text_message( | |||||
| f"Video generation started.\nInvocation ARN: {invocation_arn}\n" | |||||
| f"Video will be available at: {video_uri}" | |||||
| ) | |||||
| return self._wait_for_completion(bedrock, s3_client, invocation_arn) | |||||
| def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage: | |||||
| """Wait for video generation completion and handle the result.""" | |||||
| while True: | |||||
| status_response = bedrock.get_async_invoke(invocationArn=invocation_arn) | |||||
| status = status_response["status"] | |||||
| video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] | |||||
| if status == "Completed": | |||||
| return self._handle_completed_video(s3_client, video_path) | |||||
| elif status == "Failed": | |||||
| failure_message = status_response.get("failureMessage", "Unknown error") | |||||
| return self.create_text_message(f"Video generation failed.\nError: {failure_message}") | |||||
| elif status == "InProgress": | |||||
| time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL) | |||||
| else: | |||||
| return self.create_text_message(f"Unexpected status: {status}") | |||||
| def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage: | |||||
| """Handle completed video generation and return the result.""" | |||||
| parsed_uri = urlparse(video_path) | |||||
| bucket = parsed_uri.netloc | |||||
| key = parsed_uri.path.lstrip("/") + "/output.mp4" | |||||
| try: | |||||
| response = s3_client.get_object(Bucket=bucket, Key=key) | |||||
| video_content = response["Body"].read() | |||||
| return [ | |||||
| self.create_text_message(f"Video is available at: {video_path}/output.mp4"), | |||||
| self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"), | |||||
| ] | |||||
| except Exception as e: | |||||
| logger.error(f"Error downloading video: {str(e)}", exc_info=True) | |||||
| return self.create_text_message( | |||||
| f"Video generation completed but failed to download video: {str(e)}\n" | |||||
| f"Video is available at: s3://{bucket}/{key}" | |||||
| ) | |||||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||||
| """Define the tool's runtime parameters.""" | |||||
| parameters = [ | |||||
| ToolParameter( | |||||
| name="prompt", | |||||
| label=I18nObject(en_US="Prompt", zh_Hans="提示词"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=True, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述" | |||||
| ), | |||||
| llm_description="Describe the video you want to generate", | |||||
| ), | |||||
| ToolParameter( | |||||
| name="video_output_s3uri", | |||||
| label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=True, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="dimension", | |||||
| label=I18nObject(en_US="Dimension", zh_Hans="尺寸"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default=NOVA_REEL_DEFAULT_DIMENSION, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="duration", | |||||
| label=I18nObject(en_US="Duration", zh_Hans="时长"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=NOVA_REEL_DEFAULT_DURATION, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="seed", | |||||
| label=I18nObject(en_US="Seed", zh_Hans="种子值"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=0, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="fps", | |||||
| label=I18nObject(en_US="FPS", zh_Hans="帧率"), | |||||
| type=ToolParameter.ToolParameterType.NUMBER, | |||||
| required=False, | |||||
| default=NOVA_REEL_DEFAULT_FPS, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject( | |||||
| en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数" | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="async", | |||||
| label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"), | |||||
| type=ToolParameter.ToolParameterType.BOOLEAN, | |||||
| required=False, | |||||
| default=True, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)", | |||||
| zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)", | |||||
| ), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="aws_region", | |||||
| label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| default=NOVA_REEL_DEFAULT_REGION, | |||||
| form=ToolParameter.ToolParameterForm.FORM, | |||||
| human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), | |||||
| ), | |||||
| ToolParameter( | |||||
| name="image_input_s3uri", | |||||
| label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"), | |||||
| type=ToolParameter.ToolParameterType.STRING, | |||||
| required=False, | |||||
| form=ToolParameter.ToolParameterForm.LLM, | |||||
| human_description=I18nObject( | |||||
| en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame", | |||||
| zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI", | |||||
| ), | |||||
| ), | |||||
| ] | |||||
| return parameters |
| identity: | |||||
| name: nova_reel | |||||
| author: AWS | |||||
| label: | |||||
| en_US: AWS Bedrock Nova Reel | |||||
| zh_Hans: AWS Bedrock Nova Reel | |||||
| icon: icon.svg | |||||
| description: | |||||
| human: | |||||
| en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html | |||||
| zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html | |||||
| llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution. | |||||
| parameters: | |||||
| - name: prompt | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Prompt | |||||
| zh_Hans: 提示词 | |||||
| human_description: | |||||
| en_US: Text description of the video you want to generate | |||||
| zh_Hans: 您想要生成的视频的文本描述 | |||||
| llm_description: Describe the video you want to generate | |||||
| form: llm | |||||
| - name: video_output_s3uri | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: Output S3 URI | |||||
| zh_Hans: 输出S3 URI | |||||
| human_description: | |||||
| en_US: S3 URI where the generated video will be stored | |||||
| zh_Hans: 生成的视频将存储的S3 URI | |||||
| form: form | |||||
| - name: dimension | |||||
| type: string | |||||
| required: false | |||||
| default: 1280x720 | |||||
| label: | |||||
| en_US: Dimension | |||||
| zh_Hans: 尺寸 | |||||
| human_description: | |||||
| en_US: Video dimensions (width x height) | |||||
| zh_Hans: 视频尺寸(宽 x 高) | |||||
| form: form | |||||
| - name: duration | |||||
| type: number | |||||
| required: false | |||||
| default: 6 | |||||
| label: | |||||
| en_US: Duration | |||||
| zh_Hans: 时长 | |||||
| human_description: | |||||
| en_US: Video duration in seconds | |||||
| zh_Hans: 视频时长(秒) | |||||
| form: form | |||||
| - name: seed | |||||
| type: number | |||||
| required: false | |||||
| default: 0 | |||||
| label: | |||||
| en_US: Seed | |||||
| zh_Hans: 种子值 | |||||
| human_description: | |||||
| en_US: Random seed for video generation | |||||
| zh_Hans: 视频生成的随机种子 | |||||
| form: form | |||||
| - name: fps | |||||
| type: number | |||||
| required: false | |||||
| default: 24 | |||||
| label: | |||||
| en_US: FPS | |||||
| zh_Hans: 帧率 | |||||
| human_description: | |||||
| en_US: Frames per second for the generated video | |||||
| zh_Hans: 生成视频的每秒帧数 | |||||
| form: form | |||||
| - name: async | |||||
| type: boolean | |||||
| required: false | |||||
| default: true | |||||
| label: | |||||
| en_US: Async Mode | |||||
| zh_Hans: 异步模式 | |||||
| human_description: | |||||
| en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion) | |||||
| zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成) | |||||
| form: llm | |||||
| - name: aws_region | |||||
| type: string | |||||
| required: false | |||||
| default: us-east-1 | |||||
| label: | |||||
| en_US: AWS Region | |||||
| zh_Hans: AWS 区域 | |||||
| human_description: | |||||
| en_US: AWS region for Bedrock service | |||||
| zh_Hans: Bedrock 服务的 AWS 区域 | |||||
| form: form | |||||
| - name: image_input_s3uri | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: Input Image S3 URI | |||||
| zh_Hans: 输入图像S3 URI | |||||
| human_description: | |||||
| en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame | |||||
| zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI | |||||
| form: llm | |||||
| development: | |||||
| dependencies: | |||||
| - boto3 | |||||
| - pillow |
| from typing import Any, Union | |||||
| from urllib.parse import urlparse | |||||
| import boto3 | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||||
| from core.tools.tool.builtin_tool import BuiltinTool | |||||
| class S3Operator(BuiltinTool): | |||||
| s3_client: Any = None | |||||
| def _invoke( | |||||
| self, | |||||
| user_id: str, | |||||
| tool_parameters: dict[str, Any], | |||||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||||
| """ | |||||
| invoke tools | |||||
| """ | |||||
| try: | |||||
| # Initialize S3 client if not already done | |||||
| if not self.s3_client: | |||||
| aws_region = tool_parameters.get("aws_region") | |||||
| if aws_region: | |||||
| self.s3_client = boto3.client("s3", region_name=aws_region) | |||||
| else: | |||||
| self.s3_client = boto3.client("s3") | |||||
| # Parse S3 URI | |||||
| s3_uri = tool_parameters.get("s3_uri") | |||||
| if not s3_uri: | |||||
| return self.create_text_message("s3_uri parameter is required") | |||||
| parsed_uri = urlparse(s3_uri) | |||||
| if parsed_uri.scheme != "s3": | |||||
| return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") | |||||
| bucket = parsed_uri.netloc | |||||
| # Remove leading slash from key | |||||
| key = parsed_uri.path.lstrip("/") | |||||
| operation_type = tool_parameters.get("operation_type", "read") | |||||
| generate_presign_url = tool_parameters.get("generate_presign_url", False) | |||||
| presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour | |||||
| if operation_type == "write": | |||||
| text_content = tool_parameters.get("text_content") | |||||
| if not text_content: | |||||
| return self.create_text_message("text_content parameter is required for write operation") | |||||
| # Write content to S3 | |||||
| self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8")) | |||||
| result = f"s3://{bucket}/{key}" | |||||
| # Generate presigned URL for the written object if requested | |||||
| if generate_presign_url: | |||||
| result = self.s3_client.generate_presigned_url( | |||||
| "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry | |||||
| ) | |||||
| else: # read operation | |||||
| # Get object from S3 | |||||
| response = self.s3_client.get_object(Bucket=bucket, Key=key) | |||||
| result = response["Body"].read().decode("utf-8") | |||||
| # Generate presigned URL if requested | |||||
| if generate_presign_url: | |||||
| result = self.s3_client.generate_presigned_url( | |||||
| "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry | |||||
| ) | |||||
| return self.create_text_message(text=result) | |||||
| except self.s3_client.exceptions.NoSuchBucket: | |||||
| return self.create_text_message(f"Bucket '{bucket}' does not exist") | |||||
| except self.s3_client.exceptions.NoSuchKey: | |||||
| return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'") | |||||
| except Exception as e: | |||||
| return self.create_text_message(f"Exception: {str(e)}") |
| identity: | |||||
| name: s3_operator | |||||
| author: AWS | |||||
| label: | |||||
| en_US: AWS S3 Operator | |||||
| zh_Hans: AWS S3 读写器 | |||||
| pt_BR: AWS S3 Operator | |||||
| icon: icon.svg | |||||
| description: | |||||
| human: | |||||
| en_US: AWS S3 Writer and Reader | |||||
| zh_Hans: 读写S3 bucket中的文件 | |||||
| pt_BR: AWS S3 Writer and Reader | |||||
| llm: AWS S3 Writer and Reader | |||||
| parameters: | |||||
| - name: text_content | |||||
| type: string | |||||
| required: false | |||||
| label: | |||||
| en_US: The text to write | |||||
| zh_Hans: 待写入的文本 | |||||
| pt_BR: The text to write | |||||
| human_description: | |||||
| en_US: The text to write | |||||
| zh_Hans: 待写入的文本 | |||||
| pt_BR: The text to write | |||||
| llm_description: The text to write | |||||
| form: llm | |||||
| - name: s3_uri | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: s3 uri | |||||
| zh_Hans: s3 uri | |||||
| pt_BR: s3 uri | |||||
| human_description: | |||||
| en_US: s3 uri | |||||
| zh_Hans: s3 uri | |||||
| pt_BR: s3 uri | |||||
| llm_description: s3 uri | |||||
| form: llm | |||||
| - name: aws_region | |||||
| type: string | |||||
| required: true | |||||
| label: | |||||
| en_US: region of bucket | |||||
| zh_Hans: bucket 所在的region | |||||
| pt_BR: region of bucket | |||||
| human_description: | |||||
| en_US: region of bucket | |||||
| zh_Hans: bucket 所在的region | |||||
| pt_BR: region of bucket | |||||
| llm_description: region of bucket | |||||
| form: form | |||||
| - name: operation_type | |||||
| type: select | |||||
| required: true | |||||
| label: | |||||
| en_US: operation type | |||||
| zh_Hans: 操作类型 | |||||
| pt_BR: operation type | |||||
| human_description: | |||||
| en_US: operation type | |||||
| zh_Hans: 操作类型 | |||||
| pt_BR: operation type | |||||
| default: read | |||||
| options: | |||||
| - value: read | |||||
| label: | |||||
| en_US: read | |||||
| zh_Hans: 读 | |||||
| - value: write | |||||
| label: | |||||
| en_US: write | |||||
| zh_Hans: 写 | |||||
| form: form | |||||
| - name: generate_presign_url | |||||
| type: boolean | |||||
| required: false | |||||
| label: | |||||
| en_US: Generate presigned URL | |||||
| zh_Hans: 生成预签名URL | |||||
| human_description: | |||||
| en_US: Whether to generate a presigned URL for the S3 object | |||||
| zh_Hans: 是否生成S3对象的预签名URL | |||||
| default: false | |||||
| form: form | |||||
| - name: presign_expiry | |||||
| type: number | |||||
| required: false | |||||
| label: | |||||
| en_US: Presigned URL expiration time | |||||
| zh_Hans: 预签名URL有效期 | |||||
| human_description: | |||||
| en_US: Expiration time in seconds for the presigned URL | |||||
| zh_Hans: 预签名URL的有效期(秒) | |||||
| default: 3600 | |||||
| form: form |