|
|
|
@@ -0,0 +1,509 @@ |
|
|
|
from base64 import b64encode |
|
|
|
from hashlib import sha1 |
|
|
|
from hmac import new as hmac_new |
|
|
|
from json import loads as json_loads |
|
|
|
from threading import Lock |
|
|
|
from time import sleep, time |
|
|
|
from typing import Any |
|
|
|
|
|
|
|
from httpx import get, post |
|
|
|
from requests import get as requests_get |
|
|
|
from yarl import URL |
|
|
|
|
|
|
|
from core.tools.entities.common_entities import I18nObject |
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption |
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool |
|
|
|
|
|
|
|
|
|
|
|
class AIPPTGenerateTool(BuiltinTool): |
|
|
|
""" |
|
|
|
A tool for generating a ppt |
|
|
|
""" |
|
|
|
|
|
|
|
_api_base_url = URL('https://co.aippt.cn/api') |
|
|
|
_api_token_cache = {} |
|
|
|
_api_token_cache_lock = Lock() |
|
|
|
|
|
|
|
_task = {} |
|
|
|
_task_type_map = { |
|
|
|
'auto': 1, |
|
|
|
'markdown': 7, |
|
|
|
} |
|
|
|
|
|
|
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: |
|
|
|
""" |
|
|
|
Invokes the AIPPT generate tool with the given user ID and tool parameters. |
|
|
|
|
|
|
|
Args: |
|
|
|
user_id (str): The ID of the user invoking the tool. |
|
|
|
tool_parameters (dict[str, Any]): The parameters for the tool |
|
|
|
|
|
|
|
Returns: |
|
|
|
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. |
|
|
|
""" |
|
|
|
title = tool_parameters.get('title', '') |
|
|
|
if not title: |
|
|
|
return self.create_text_message('Please provide a title for the ppt') |
|
|
|
|
|
|
|
model = tool_parameters.get('model', 'aippt') |
|
|
|
if not model: |
|
|
|
return self.create_text_message('Please provide a model for the ppt') |
|
|
|
|
|
|
|
outline = tool_parameters.get('outline', '') |
|
|
|
|
|
|
|
# create task |
|
|
|
task_id = self._create_task( |
|
|
|
type=self._task_type_map['auto' if not outline else 'markdown'], |
|
|
|
title=title, |
|
|
|
content=outline, |
|
|
|
user_id=user_id |
|
|
|
) |
|
|
|
|
|
|
|
# get suit |
|
|
|
color = tool_parameters.get('color') |
|
|
|
style = tool_parameters.get('style') |
|
|
|
|
|
|
|
if color == '__default__': |
|
|
|
color_id = '' |
|
|
|
else: |
|
|
|
color_id = int(color.split('-')[1]) |
|
|
|
|
|
|
|
if style == '__default__': |
|
|
|
style_id = '' |
|
|
|
else: |
|
|
|
style_id = int(style.split('-')[1]) |
|
|
|
|
|
|
|
suit_id = self._get_suit(style_id=style_id, colour_id=color_id) |
|
|
|
|
|
|
|
# generate outline |
|
|
|
if not outline: |
|
|
|
self._generate_outline( |
|
|
|
task_id=task_id, |
|
|
|
model=model, |
|
|
|
user_id=user_id |
|
|
|
) |
|
|
|
|
|
|
|
# generate content |
|
|
|
self._generate_content( |
|
|
|
task_id=task_id, |
|
|
|
model=model, |
|
|
|
user_id=user_id |
|
|
|
) |
|
|
|
|
|
|
|
# generate ppt |
|
|
|
_, ppt_url = self._generate_ppt( |
|
|
|
task_id=task_id, |
|
|
|
suit_id=suit_id, |
|
|
|
user_id=user_id |
|
|
|
) |
|
|
|
|
|
|
|
return self.create_text_message('''the ppt has been created successfully,''' |
|
|
|
f'''the ppt url is {ppt_url}''' |
|
|
|
'''please give the ppt url to user and direct user to download it.''') |
|
|
|
|
|
|
|
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: |
|
|
|
""" |
|
|
|
Create a task |
|
|
|
|
|
|
|
:param type: the task type |
|
|
|
:param title: the task title |
|
|
|
:param content: the task content |
|
|
|
|
|
|
|
:return: the task ID |
|
|
|
""" |
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), |
|
|
|
} |
|
|
|
response = post( |
|
|
|
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), |
|
|
|
headers=headers, |
|
|
|
files={ |
|
|
|
'type': ('', str(type)), |
|
|
|
'title': ('', title), |
|
|
|
'content': ('', content) |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to create task: {response.get("msg")}') |
|
|
|
|
|
|
|
return response.get('data', {}).get('id') |
|
|
|
|
|
|
|
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: |
|
|
|
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ |
|
|
|
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' |
|
|
|
api_url %= {'task_id': task_id} |
|
|
|
|
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), |
|
|
|
} |
|
|
|
|
|
|
|
response = requests_get( |
|
|
|
url=api_url, |
|
|
|
headers=headers, |
|
|
|
stream=True, |
|
|
|
timeout=(10, 60) |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
outline = '' |
|
|
|
for chunk in response.iter_lines(delimiter=b'\n\n'): |
|
|
|
if not chunk: |
|
|
|
continue |
|
|
|
|
|
|
|
event = '' |
|
|
|
lines = chunk.decode('utf-8').split('\n') |
|
|
|
for line in lines: |
|
|
|
if line.startswith('event:'): |
|
|
|
event = line[6:] |
|
|
|
elif line.startswith('data:'): |
|
|
|
data = line[5:] |
|
|
|
if event == 'message': |
|
|
|
try: |
|
|
|
data = json_loads(data) |
|
|
|
outline += data.get('content', '') |
|
|
|
except Exception as e: |
|
|
|
pass |
|
|
|
elif event == 'close': |
|
|
|
break |
|
|
|
elif event == 'error' or event == 'filter': |
|
|
|
raise Exception(f'Failed to generate outline: {data}') |
|
|
|
|
|
|
|
return outline |
|
|
|
|
|
|
|
def _generate_content(self, task_id: str, model: str, user_id: str) -> str: |
|
|
|
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ |
|
|
|
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' |
|
|
|
api_url %= {'task_id': task_id} |
|
|
|
|
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), |
|
|
|
} |
|
|
|
|
|
|
|
response = requests_get( |
|
|
|
url=api_url, |
|
|
|
headers=headers, |
|
|
|
stream=True, |
|
|
|
timeout=(10, 60) |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
if model == 'aippt': |
|
|
|
content = '' |
|
|
|
for chunk in response.iter_lines(delimiter=b'\n\n'): |
|
|
|
if not chunk: |
|
|
|
continue |
|
|
|
|
|
|
|
event = '' |
|
|
|
lines = chunk.decode('utf-8').split('\n') |
|
|
|
for line in lines: |
|
|
|
if line.startswith('event:'): |
|
|
|
event = line[6:] |
|
|
|
elif line.startswith('data:'): |
|
|
|
data = line[5:] |
|
|
|
if event == 'message': |
|
|
|
try: |
|
|
|
data = json_loads(data) |
|
|
|
content += data.get('content', '') |
|
|
|
except Exception as e: |
|
|
|
pass |
|
|
|
elif event == 'close': |
|
|
|
break |
|
|
|
elif event == 'error' or event == 'filter': |
|
|
|
raise Exception(f'Failed to generate content: {data}') |
|
|
|
|
|
|
|
return content |
|
|
|
elif model == 'wenxin': |
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to generate content: {response.get("msg")}') |
|
|
|
|
|
|
|
return response.get('data', '') |
|
|
|
|
|
|
|
return '' |
|
|
|
|
|
|
|
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: |
|
|
|
""" |
|
|
|
Generate a ppt |
|
|
|
|
|
|
|
:param task_id: the task ID |
|
|
|
:param suit_id: the suit ID |
|
|
|
:return: the cover url of the ppt and the ppt url |
|
|
|
""" |
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), |
|
|
|
} |
|
|
|
|
|
|
|
response = post( |
|
|
|
str(self._api_base_url / 'design' / 'v2' / 'save'), |
|
|
|
headers=headers, |
|
|
|
data={ |
|
|
|
'task_id': task_id, |
|
|
|
'template_id': suit_id |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}') |
|
|
|
|
|
|
|
id = response.get('data', {}).get('id') |
|
|
|
cover_url = response.get('data', {}).get('cover_url') |
|
|
|
|
|
|
|
response = post( |
|
|
|
str(self._api_base_url / 'download' / 'export' / 'file'), |
|
|
|
headers=headers, |
|
|
|
data={ |
|
|
|
'id': id, |
|
|
|
'format': 'ppt', |
|
|
|
'files_to_zip': False, |
|
|
|
'edit': True |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}') |
|
|
|
|
|
|
|
export_code = response.get('data') |
|
|
|
if not export_code: |
|
|
|
raise Exception('Failed to generate ppt, the export code is empty') |
|
|
|
|
|
|
|
current_iteration = 0 |
|
|
|
while current_iteration < 50: |
|
|
|
# get ppt url |
|
|
|
response = post( |
|
|
|
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), |
|
|
|
headers=headers, |
|
|
|
data={ |
|
|
|
'task_key': export_code |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}') |
|
|
|
|
|
|
|
if response.get('msg') == '导出中': |
|
|
|
current_iteration += 1 |
|
|
|
sleep(2) |
|
|
|
continue |
|
|
|
|
|
|
|
ppt_url = response.get('data', []) |
|
|
|
if len(ppt_url) == 0: |
|
|
|
raise Exception('Failed to generate ppt, the ppt url is empty') |
|
|
|
|
|
|
|
return cover_url, ppt_url[0] |
|
|
|
|
|
|
|
raise Exception('Failed to generate ppt, the export is timeout') |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: |
|
|
|
""" |
|
|
|
Get API token |
|
|
|
|
|
|
|
:param credentials: the credentials |
|
|
|
:return: the API token |
|
|
|
""" |
|
|
|
access_key = credentials['aippt_access_key'] |
|
|
|
secret_key = credentials['aippt_secret_key'] |
|
|
|
|
|
|
|
cache_key = f'{access_key}#@#{user_id}' |
|
|
|
|
|
|
|
with cls._api_token_cache_lock: |
|
|
|
# clear expired tokens |
|
|
|
now = time() |
|
|
|
for key in list(cls._api_token_cache.keys()): |
|
|
|
if cls._api_token_cache[key]['expire'] < now: |
|
|
|
del cls._api_token_cache[key] |
|
|
|
|
|
|
|
if cache_key in cls._api_token_cache: |
|
|
|
return cls._api_token_cache[cache_key]['token'] |
|
|
|
|
|
|
|
# get token |
|
|
|
headers = { |
|
|
|
'x-api-key': access_key, |
|
|
|
'x-timestamp': str(int(now)), |
|
|
|
'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) |
|
|
|
} |
|
|
|
|
|
|
|
param = { |
|
|
|
'uid': user_id, |
|
|
|
'channel': '' |
|
|
|
} |
|
|
|
|
|
|
|
response = get( |
|
|
|
str(cls._api_base_url / 'grant' / 'token'), |
|
|
|
params=param, |
|
|
|
headers=headers |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
response = response.json() |
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') |
|
|
|
|
|
|
|
token = response.get('data', {}).get('token') |
|
|
|
expire = response.get('data', {}).get('time_expire') |
|
|
|
|
|
|
|
with cls._api_token_cache_lock: |
|
|
|
cls._api_token_cache[cache_key] = { |
|
|
|
'token': token, |
|
|
|
'expire': now + expire |
|
|
|
} |
|
|
|
|
|
|
|
return token |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: |
|
|
|
return b64encode( |
|
|
|
hmac_new( |
|
|
|
key=secret_key.encode('utf-8'), |
|
|
|
msg=f'GET@/api/grant/token/@{timestamp}'.encode(), |
|
|
|
digestmod=sha1 |
|
|
|
).digest() |
|
|
|
).decode('utf-8') |
|
|
|
|
|
|
|
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: |
|
|
|
""" |
|
|
|
Get styles |
|
|
|
|
|
|
|
:param credentials: the credentials |
|
|
|
:return: Tuple[list[dict[id, color]], list[dict[id, style]] |
|
|
|
""" |
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id) |
|
|
|
} |
|
|
|
response = get( |
|
|
|
str(self._api_base_url / 'template_component' / 'suit' / 'select'), |
|
|
|
headers=headers |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
|
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') |
|
|
|
|
|
|
|
colors = [{ |
|
|
|
'id': f'id-{item.get("id")}', |
|
|
|
'name': item.get('name'), |
|
|
|
'en_name': item.get('en_name', item.get('name')), |
|
|
|
} for item in response.get('data', {}).get('colour') or []] |
|
|
|
styles = [{ |
|
|
|
'id': f'id-{item.get("id")}', |
|
|
|
'name': item.get('title'), |
|
|
|
} for item in response.get('data', {}).get('suit_style') or []] |
|
|
|
|
|
|
|
return colors, styles |
|
|
|
|
|
|
|
def _get_suit(self, style_id: int, colour_id: int) -> int: |
|
|
|
""" |
|
|
|
Get suit |
|
|
|
""" |
|
|
|
headers = { |
|
|
|
'x-channel': '', |
|
|
|
'x-api-key': self.runtime.credentials['aippt_access_key'], |
|
|
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') |
|
|
|
} |
|
|
|
response = get( |
|
|
|
str(self._api_base_url / 'template_component' / 'suit' / 'search'), |
|
|
|
headers=headers, |
|
|
|
params={ |
|
|
|
'style_id': style_id, |
|
|
|
'colour_id': colour_id, |
|
|
|
'page': 1, |
|
|
|
'page_size': 1 |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
if response.status_code != 200: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.text}') |
|
|
|
|
|
|
|
response = response.json() |
|
|
|
|
|
|
|
if response.get('code') != 0: |
|
|
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') |
|
|
|
|
|
|
|
if len(response.get('data', {}).get('list') or []) > 0: |
|
|
|
return response.get('data', {}).get('list')[0].get('id') |
|
|
|
|
|
|
|
raise Exception('Failed to get suit, the suit does not exist, please check the style and color') |
|
|
|
|
|
|
|
def get_runtime_parameters(self) -> list[ToolParameter]: |
|
|
|
""" |
|
|
|
Get runtime parameters |
|
|
|
|
|
|
|
Override this method to add runtime parameters to the tool. |
|
|
|
""" |
|
|
|
try: |
|
|
|
colors, styles = self.get_styles(user_id='__dify_system__') |
|
|
|
except Exception as e: |
|
|
|
colors, styles = [ |
|
|
|
{'id': -1, 'name': '__default__'} |
|
|
|
], [ |
|
|
|
{'id': -1, 'name': '__default__'} |
|
|
|
] |
|
|
|
|
|
|
|
return [ |
|
|
|
ToolParameter( |
|
|
|
name='color', |
|
|
|
label=I18nObject(zh_Hans='颜色', en_US='Color'), |
|
|
|
human_description=I18nObject(zh_Hans='颜色', en_US='Color'), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
required=False, |
|
|
|
default=colors[0]['id'], |
|
|
|
options=[ |
|
|
|
ToolParameterOption( |
|
|
|
value=color['id'], |
|
|
|
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) |
|
|
|
) for color in colors |
|
|
|
] |
|
|
|
), |
|
|
|
ToolParameter( |
|
|
|
name='style', |
|
|
|
label=I18nObject(zh_Hans='风格', en_US='Style'), |
|
|
|
human_description=I18nObject(zh_Hans='风格', en_US='Style'), |
|
|
|
type=ToolParameter.ToolParameterType.SELECT, |
|
|
|
form=ToolParameter.ToolParameterForm.FORM, |
|
|
|
required=False, |
|
|
|
default=styles[0]['id'], |
|
|
|
options=[ |
|
|
|
ToolParameterOption( |
|
|
|
value=style['id'], |
|
|
|
label=I18nObject(zh_Hans=style['name'], en_US=style['name']) |
|
|
|
) for style in styles |
|
|
|
] |
|
|
|
), |
|
|
|
] |