您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

workflow_draft_variable.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. import logging
  2. from typing import NoReturn
  3. from flask import Response
  4. from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
  5. from sqlalchemy.orm import Session
  6. from werkzeug.exceptions import Forbidden
  7. from controllers.console import api
  8. from controllers.console.app.error import (
  9. DraftWorkflowNotExist,
  10. )
  11. from controllers.console.app.wraps import get_app_model
  12. from controllers.console.wraps import account_initialization_required, setup_required
  13. from controllers.web.error import InvalidArgumentError, NotFoundError
  14. from core.file import helpers as file_helpers
  15. from core.variables.segment_group import SegmentGroup
  16. from core.variables.segments import ArrayFileSegment, FileSegment, Segment
  17. from core.variables.types import SegmentType
  18. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  19. from extensions.ext_database import db
  20. from factories.file_factory import build_from_mapping, build_from_mappings
  21. from factories.variable_factory import build_segment_with_type
  22. from libs.login import current_user, login_required
  23. from models import App, AppMode
  24. from models.account import Account
  25. from models.workflow import WorkflowDraftVariable
  26. from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
  27. from services.workflow_service import WorkflowService
  28. logger = logging.getLogger(__name__)
  29. def _convert_values_to_json_serializable_object(value: Segment):
  30. if isinstance(value, FileSegment):
  31. return value.value.model_dump()
  32. elif isinstance(value, ArrayFileSegment):
  33. return [i.model_dump() for i in value.value]
  34. elif isinstance(value, SegmentGroup):
  35. return [_convert_values_to_json_serializable_object(i) for i in value.value]
  36. else:
  37. return value.value
  38. def _serialize_var_value(variable: WorkflowDraftVariable):
  39. value = variable.get_value()
  40. # create a copy of the value to avoid affecting the model cache.
  41. value = value.model_copy(deep=True)
  42. # Refresh the url signature before returning it to client.
  43. if isinstance(value, FileSegment):
  44. file = value.value
  45. file.remote_url = file.generate_url()
  46. elif isinstance(value, ArrayFileSegment):
  47. files = value.value
  48. for file in files:
  49. file.remote_url = file.generate_url()
  50. return _convert_values_to_json_serializable_object(value)
  51. def _create_pagination_parser():
  52. parser = reqparse.RequestParser()
  53. parser.add_argument(
  54. "page",
  55. type=inputs.int_range(1, 100_000),
  56. required=False,
  57. default=1,
  58. location="args",
  59. help="the page of data requested",
  60. )
  61. parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
  62. return parser
  63. def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
  64. value_type = workflow_draft_var.value_type
  65. return value_type.exposed_type().value
  66. def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
  67. """Serialize full_content information for large variables."""
  68. if not variable.is_truncated():
  69. return None
  70. variable_file = variable.variable_file
  71. assert variable_file is not None
  72. return {
  73. "size_bytes": variable_file.size,
  74. "value_type": variable_file.value_type.exposed_type().value,
  75. "length": variable_file.length,
  76. "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
  77. }
  78. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
  79. "id": fields.String,
  80. "type": fields.String(attribute=lambda model: model.get_variable_type()),
  81. "name": fields.String,
  82. "description": fields.String,
  83. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  84. "value_type": fields.String(attribute=_serialize_variable_type),
  85. "edited": fields.Boolean(attribute=lambda model: model.edited),
  86. "visible": fields.Boolean,
  87. "is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
  88. }
  89. _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
  90. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
  91. value=fields.Raw(attribute=_serialize_var_value),
  92. full_content=fields.Raw(attribute=_serialize_full_content),
  93. )
  94. _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
  95. "id": fields.String,
  96. "type": fields.String(attribute=lambda _: "env"),
  97. "name": fields.String,
  98. "description": fields.String,
  99. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  100. "value_type": fields.String(attribute=_serialize_variable_type),
  101. "edited": fields.Boolean(attribute=lambda model: model.edited),
  102. "visible": fields.Boolean,
  103. }
  104. _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
  105. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
  106. }
  107. def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
  108. return var_list.variables
  109. _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
  110. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
  111. "total": fields.Raw(),
  112. }
  113. _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
  114. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
  115. }
  116. def _api_prerequisite(f):
  117. """Common prerequisites for all draft workflow variable APIs.
  118. It ensures the following conditions are satisfied:
  119. - Dify has been property setup.
  120. - The request user has logged in and initialized.
  121. - The requested app is a workflow or a chat flow.
  122. - The request user has the edit permission for the app.
  123. """
  124. @setup_required
  125. @login_required
  126. @account_initialization_required
  127. @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
  128. def wrapper(*args, **kwargs):
  129. assert isinstance(current_user, Account)
  130. if not current_user.is_editor:
  131. raise Forbidden()
  132. return f(*args, **kwargs)
  133. return wrapper
  134. class WorkflowVariableCollectionApi(Resource):
  135. @_api_prerequisite
  136. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
  137. def get(self, app_model: App):
  138. """
  139. Get draft workflow
  140. """
  141. parser = _create_pagination_parser()
  142. args = parser.parse_args()
  143. # fetch draft workflow by app_model
  144. workflow_service = WorkflowService()
  145. workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
  146. if not workflow_exist:
  147. raise DraftWorkflowNotExist()
  148. # fetch draft workflow by app_model
  149. with Session(bind=db.engine, expire_on_commit=False) as session:
  150. draft_var_srv = WorkflowDraftVariableService(
  151. session=session,
  152. )
  153. workflow_vars = draft_var_srv.list_variables_without_values(
  154. app_id=app_model.id,
  155. page=args.page,
  156. limit=args.limit,
  157. )
  158. return workflow_vars
  159. @_api_prerequisite
  160. def delete(self, app_model: App):
  161. draft_var_srv = WorkflowDraftVariableService(
  162. session=db.session(),
  163. )
  164. draft_var_srv.delete_workflow_variables(app_model.id)
  165. db.session.commit()
  166. return Response("", 204)
  167. def validate_node_id(node_id: str) -> NoReturn | None:
  168. if node_id in [
  169. CONVERSATION_VARIABLE_NODE_ID,
  170. SYSTEM_VARIABLE_NODE_ID,
  171. ]:
  172. # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
  173. # with specific `node_id` in database, we still want to make the API separated. By disallowing
  174. # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
  175. # we mitigate the risk that user of the API depending on the implementation detail of the API.
  176. #
  177. # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
  178. raise InvalidArgumentError(
  179. f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
  180. )
  181. return None
  182. class NodeVariableCollectionApi(Resource):
  183. @_api_prerequisite
  184. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  185. def get(self, app_model: App, node_id: str):
  186. validate_node_id(node_id)
  187. with Session(bind=db.engine, expire_on_commit=False) as session:
  188. draft_var_srv = WorkflowDraftVariableService(
  189. session=session,
  190. )
  191. node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
  192. return node_vars
  193. @_api_prerequisite
  194. def delete(self, app_model: App, node_id: str):
  195. validate_node_id(node_id)
  196. srv = WorkflowDraftVariableService(db.session())
  197. srv.delete_node_variables(app_model.id, node_id)
  198. db.session.commit()
  199. return Response("", 204)
  200. class VariableApi(Resource):
  201. _PATCH_NAME_FIELD = "name"
  202. _PATCH_VALUE_FIELD = "value"
  203. @_api_prerequisite
  204. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  205. def get(self, app_model: App, variable_id: str):
  206. draft_var_srv = WorkflowDraftVariableService(
  207. session=db.session(),
  208. )
  209. variable = draft_var_srv.get_variable(variable_id=variable_id)
  210. if variable is None:
  211. raise NotFoundError(description=f"variable not found, id={variable_id}")
  212. if variable.app_id != app_model.id:
  213. raise NotFoundError(description=f"variable not found, id={variable_id}")
  214. return variable
  215. @_api_prerequisite
  216. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  217. def patch(self, app_model: App, variable_id: str):
  218. # Request payload for file types:
  219. #
  220. # Local File:
  221. #
  222. # {
  223. # "type": "image",
  224. # "transfer_method": "local_file",
  225. # "url": "",
  226. # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
  227. # }
  228. #
  229. # Remote File:
  230. #
  231. #
  232. # {
  233. # "type": "image",
  234. # "transfer_method": "remote_url",
  235. # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
  236. # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
  237. # }
  238. parser = reqparse.RequestParser()
  239. parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
  240. # Parse 'value' field as-is to maintain its original data structure
  241. parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
  242. draft_var_srv = WorkflowDraftVariableService(
  243. session=db.session(),
  244. )
  245. args = parser.parse_args(strict=True)
  246. variable = draft_var_srv.get_variable(variable_id=variable_id)
  247. if variable is None:
  248. raise NotFoundError(description=f"variable not found, id={variable_id}")
  249. if variable.app_id != app_model.id:
  250. raise NotFoundError(description=f"variable not found, id={variable_id}")
  251. new_name = args.get(self._PATCH_NAME_FIELD, None)
  252. raw_value = args.get(self._PATCH_VALUE_FIELD, None)
  253. if new_name is None and raw_value is None:
  254. return variable
  255. new_value = None
  256. if raw_value is not None:
  257. if variable.value_type == SegmentType.FILE:
  258. if not isinstance(raw_value, dict):
  259. raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
  260. raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
  261. elif variable.value_type == SegmentType.ARRAY_FILE:
  262. if not isinstance(raw_value, list):
  263. raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
  264. if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
  265. raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
  266. raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
  267. new_value = build_segment_with_type(variable.value_type, raw_value)
  268. draft_var_srv.update_variable(variable, name=new_name, value=new_value)
  269. db.session.commit()
  270. return variable
  271. @_api_prerequisite
  272. def delete(self, app_model: App, variable_id: str):
  273. draft_var_srv = WorkflowDraftVariableService(
  274. session=db.session(),
  275. )
  276. variable = draft_var_srv.get_variable(variable_id=variable_id)
  277. if variable is None:
  278. raise NotFoundError(description=f"variable not found, id={variable_id}")
  279. if variable.app_id != app_model.id:
  280. raise NotFoundError(description=f"variable not found, id={variable_id}")
  281. draft_var_srv.delete_variable(variable)
  282. db.session.commit()
  283. return Response("", 204)
  284. class VariableResetApi(Resource):
  285. @_api_prerequisite
  286. def put(self, app_model: App, variable_id: str):
  287. draft_var_srv = WorkflowDraftVariableService(
  288. session=db.session(),
  289. )
  290. workflow_srv = WorkflowService()
  291. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  292. if draft_workflow is None:
  293. raise NotFoundError(
  294. f"Draft workflow not found, app_id={app_model.id}",
  295. )
  296. variable = draft_var_srv.get_variable(variable_id=variable_id)
  297. if variable is None:
  298. raise NotFoundError(description=f"variable not found, id={variable_id}")
  299. if variable.app_id != app_model.id:
  300. raise NotFoundError(description=f"variable not found, id={variable_id}")
  301. resetted = draft_var_srv.reset_variable(draft_workflow, variable)
  302. db.session.commit()
  303. if resetted is None:
  304. return Response("", 204)
  305. else:
  306. return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  307. def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
  308. with Session(bind=db.engine, expire_on_commit=False) as session:
  309. draft_var_srv = WorkflowDraftVariableService(
  310. session=session,
  311. )
  312. if node_id == CONVERSATION_VARIABLE_NODE_ID:
  313. draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
  314. elif node_id == SYSTEM_VARIABLE_NODE_ID:
  315. draft_vars = draft_var_srv.list_system_variables(app_model.id)
  316. else:
  317. draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
  318. return draft_vars
  319. class ConversationVariableCollectionApi(Resource):
  320. @_api_prerequisite
  321. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  322. def get(self, app_model: App):
  323. # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
  324. # so their IDs can be returned to the caller.
  325. workflow_srv = WorkflowService()
  326. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  327. if draft_workflow is None:
  328. raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
  329. draft_var_srv = WorkflowDraftVariableService(db.session())
  330. draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
  331. db.session.commit()
  332. return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
  333. class SystemVariableCollectionApi(Resource):
  334. @_api_prerequisite
  335. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  336. def get(self, app_model: App):
  337. return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
  338. class EnvironmentVariableCollectionApi(Resource):
  339. @_api_prerequisite
  340. def get(self, app_model: App):
  341. """
  342. Get draft workflow
  343. """
  344. # fetch draft workflow by app_model
  345. workflow_service = WorkflowService()
  346. workflow = workflow_service.get_draft_workflow(app_model=app_model)
  347. if workflow is None:
  348. raise DraftWorkflowNotExist()
  349. env_vars = workflow.environment_variables
  350. env_vars_list = []
  351. for v in env_vars:
  352. env_vars_list.append(
  353. {
  354. "id": v.id,
  355. "type": "env",
  356. "name": v.name,
  357. "description": v.description,
  358. "selector": v.selector,
  359. "value_type": v.value_type.exposed_type().value,
  360. "value": v.value,
  361. # Do not track edited for env vars.
  362. "edited": False,
  363. "visible": True,
  364. "editable": True,
  365. }
  366. )
  367. return {"items": env_vars_list}
  368. api.add_resource(
  369. WorkflowVariableCollectionApi,
  370. "/apps/<uuid:app_id>/workflows/draft/variables",
  371. )
  372. api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
  373. api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
  374. api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
  375. api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
  376. api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
  377. api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")