You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

workflow_draft_variable.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. import logging
  2. from typing import NoReturn
  3. from flask import Response
  4. from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
  5. from sqlalchemy.orm import Session
  6. from werkzeug.exceptions import Forbidden
  7. from controllers.console import api, console_ns
  8. from controllers.console.app.error import (
  9. DraftWorkflowNotExist,
  10. )
  11. from controllers.console.app.wraps import get_app_model
  12. from controllers.console.wraps import account_initialization_required, setup_required
  13. from controllers.web.error import InvalidArgumentError, NotFoundError
  14. from core.variables.segment_group import SegmentGroup
  15. from core.variables.segments import ArrayFileSegment, FileSegment, Segment
  16. from core.variables.types import SegmentType
  17. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  18. from factories.file_factory import build_from_mapping, build_from_mappings
  19. from factories.variable_factory import build_segment_with_type
  20. from libs.login import current_user, login_required
  21. from models import App, AppMode, db
  22. from models.account import Account
  23. from models.workflow import WorkflowDraftVariable
  24. from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
  25. from services.workflow_service import WorkflowService
  26. logger = logging.getLogger(__name__)
  27. def _convert_values_to_json_serializable_object(value: Segment):
  28. if isinstance(value, FileSegment):
  29. return value.value.model_dump()
  30. elif isinstance(value, ArrayFileSegment):
  31. return [i.model_dump() for i in value.value]
  32. elif isinstance(value, SegmentGroup):
  33. return [_convert_values_to_json_serializable_object(i) for i in value.value]
  34. else:
  35. return value.value
  36. def _serialize_var_value(variable: WorkflowDraftVariable):
  37. value = variable.get_value()
  38. # create a copy of the value to avoid affecting the model cache.
  39. value = value.model_copy(deep=True)
  40. # Refresh the url signature before returning it to client.
  41. if isinstance(value, FileSegment):
  42. file = value.value
  43. file.remote_url = file.generate_url()
  44. elif isinstance(value, ArrayFileSegment):
  45. files = value.value
  46. for file in files:
  47. file.remote_url = file.generate_url()
  48. return _convert_values_to_json_serializable_object(value)
  49. def _create_pagination_parser():
  50. parser = reqparse.RequestParser()
  51. parser.add_argument(
  52. "page",
  53. type=inputs.int_range(1, 100_000),
  54. required=False,
  55. default=1,
  56. location="args",
  57. help="the page of data requested",
  58. )
  59. parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
  60. return parser
  61. def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
  62. value_type = workflow_draft_var.value_type
  63. return value_type.exposed_type().value
  64. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
  65. "id": fields.String,
  66. "type": fields.String(attribute=lambda model: model.get_variable_type()),
  67. "name": fields.String,
  68. "description": fields.String,
  69. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  70. "value_type": fields.String(attribute=_serialize_variable_type),
  71. "edited": fields.Boolean(attribute=lambda model: model.edited),
  72. "visible": fields.Boolean,
  73. }
  74. _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
  75. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
  76. value=fields.Raw(attribute=_serialize_var_value),
  77. )
  78. _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
  79. "id": fields.String,
  80. "type": fields.String(attribute=lambda _: "env"),
  81. "name": fields.String,
  82. "description": fields.String,
  83. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  84. "value_type": fields.String(attribute=_serialize_variable_type),
  85. "edited": fields.Boolean(attribute=lambda model: model.edited),
  86. "visible": fields.Boolean,
  87. }
  88. _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
  89. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
  90. }
  91. def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
  92. return var_list.variables
  93. _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
  94. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
  95. "total": fields.Raw(),
  96. }
  97. _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
  98. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
  99. }
  100. def _api_prerequisite(f):
  101. """Common prerequisites for all draft workflow variable APIs.
  102. It ensures the following conditions are satisfied:
  103. - Dify has been property setup.
  104. - The request user has logged in and initialized.
  105. - The requested app is a workflow or a chat flow.
  106. - The request user has the edit permission for the app.
  107. """
  108. @setup_required
  109. @login_required
  110. @account_initialization_required
  111. @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
  112. def wrapper(*args, **kwargs):
  113. assert isinstance(current_user, Account)
  114. if not current_user.has_edit_permission:
  115. raise Forbidden()
  116. return f(*args, **kwargs)
  117. return wrapper
  118. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
  119. class WorkflowVariableCollectionApi(Resource):
  120. @api.doc("get_workflow_variables")
  121. @api.doc(description="Get draft workflow variables")
  122. @api.doc(params={"app_id": "Application ID"})
  123. @api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
  124. @api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
  125. @_api_prerequisite
  126. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
  127. def get(self, app_model: App):
  128. """
  129. Get draft workflow
  130. """
  131. parser = _create_pagination_parser()
  132. args = parser.parse_args()
  133. # fetch draft workflow by app_model
  134. workflow_service = WorkflowService()
  135. workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
  136. if not workflow_exist:
  137. raise DraftWorkflowNotExist()
  138. # fetch draft workflow by app_model
  139. with Session(bind=db.engine, expire_on_commit=False) as session:
  140. draft_var_srv = WorkflowDraftVariableService(
  141. session=session,
  142. )
  143. workflow_vars = draft_var_srv.list_variables_without_values(
  144. app_id=app_model.id,
  145. page=args.page,
  146. limit=args.limit,
  147. )
  148. return workflow_vars
  149. @api.doc("delete_workflow_variables")
  150. @api.doc(description="Delete all draft workflow variables")
  151. @api.response(204, "Workflow variables deleted successfully")
  152. @_api_prerequisite
  153. def delete(self, app_model: App):
  154. draft_var_srv = WorkflowDraftVariableService(
  155. session=db.session(),
  156. )
  157. draft_var_srv.delete_workflow_variables(app_model.id)
  158. db.session.commit()
  159. return Response("", 204)
  160. def validate_node_id(node_id: str) -> NoReturn | None:
  161. if node_id in [
  162. CONVERSATION_VARIABLE_NODE_ID,
  163. SYSTEM_VARIABLE_NODE_ID,
  164. ]:
  165. # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
  166. # with specific `node_id` in database, we still want to make the API separated. By disallowing
  167. # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
  168. # we mitigate the risk that user of the API depending on the implementation detail of the API.
  169. #
  170. # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
  171. raise InvalidArgumentError(
  172. f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
  173. )
  174. return None
  175. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
  176. class NodeVariableCollectionApi(Resource):
  177. @api.doc("get_node_variables")
  178. @api.doc(description="Get variables for a specific node")
  179. @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
  180. @api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  181. @_api_prerequisite
  182. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  183. def get(self, app_model: App, node_id: str):
  184. validate_node_id(node_id)
  185. with Session(bind=db.engine, expire_on_commit=False) as session:
  186. draft_var_srv = WorkflowDraftVariableService(
  187. session=session,
  188. )
  189. node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
  190. return node_vars
  191. @api.doc("delete_node_variables")
  192. @api.doc(description="Delete all variables for a specific node")
  193. @api.response(204, "Node variables deleted successfully")
  194. @_api_prerequisite
  195. def delete(self, app_model: App, node_id: str):
  196. validate_node_id(node_id)
  197. srv = WorkflowDraftVariableService(db.session())
  198. srv.delete_node_variables(app_model.id, node_id)
  199. db.session.commit()
  200. return Response("", 204)
  201. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
  202. class VariableApi(Resource):
  203. _PATCH_NAME_FIELD = "name"
  204. _PATCH_VALUE_FIELD = "value"
  205. @api.doc("get_variable")
  206. @api.doc(description="Get a specific workflow variable")
  207. @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
  208. @api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  209. @api.response(404, "Variable not found")
  210. @_api_prerequisite
  211. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  212. def get(self, app_model: App, variable_id: str):
  213. draft_var_srv = WorkflowDraftVariableService(
  214. session=db.session(),
  215. )
  216. variable = draft_var_srv.get_variable(variable_id=variable_id)
  217. if variable is None:
  218. raise NotFoundError(description=f"variable not found, id={variable_id}")
  219. if variable.app_id != app_model.id:
  220. raise NotFoundError(description=f"variable not found, id={variable_id}")
  221. return variable
  222. @api.doc("update_variable")
  223. @api.doc(description="Update a workflow variable")
  224. @api.expect(
  225. api.model(
  226. "UpdateVariableRequest",
  227. {
  228. "name": fields.String(description="Variable name"),
  229. "value": fields.Raw(description="Variable value"),
  230. },
  231. )
  232. )
  233. @api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  234. @api.response(404, "Variable not found")
  235. @_api_prerequisite
  236. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  237. def patch(self, app_model: App, variable_id: str):
  238. # Request payload for file types:
  239. #
  240. # Local File:
  241. #
  242. # {
  243. # "type": "image",
  244. # "transfer_method": "local_file",
  245. # "url": "",
  246. # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
  247. # }
  248. #
  249. # Remote File:
  250. #
  251. #
  252. # {
  253. # "type": "image",
  254. # "transfer_method": "remote_url",
  255. # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
  256. # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
  257. # }
  258. parser = reqparse.RequestParser()
  259. parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
  260. # Parse 'value' field as-is to maintain its original data structure
  261. parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
  262. draft_var_srv = WorkflowDraftVariableService(
  263. session=db.session(),
  264. )
  265. args = parser.parse_args(strict=True)
  266. variable = draft_var_srv.get_variable(variable_id=variable_id)
  267. if variable is None:
  268. raise NotFoundError(description=f"variable not found, id={variable_id}")
  269. if variable.app_id != app_model.id:
  270. raise NotFoundError(description=f"variable not found, id={variable_id}")
  271. new_name = args.get(self._PATCH_NAME_FIELD, None)
  272. raw_value = args.get(self._PATCH_VALUE_FIELD, None)
  273. if new_name is None and raw_value is None:
  274. return variable
  275. new_value = None
  276. if raw_value is not None:
  277. if variable.value_type == SegmentType.FILE:
  278. if not isinstance(raw_value, dict):
  279. raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
  280. raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
  281. elif variable.value_type == SegmentType.ARRAY_FILE:
  282. if not isinstance(raw_value, list):
  283. raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
  284. if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
  285. raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
  286. raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
  287. new_value = build_segment_with_type(variable.value_type, raw_value)
  288. draft_var_srv.update_variable(variable, name=new_name, value=new_value)
  289. db.session.commit()
  290. return variable
  291. @api.doc("delete_variable")
  292. @api.doc(description="Delete a workflow variable")
  293. @api.response(204, "Variable deleted successfully")
  294. @api.response(404, "Variable not found")
  295. @_api_prerequisite
  296. def delete(self, app_model: App, variable_id: str):
  297. draft_var_srv = WorkflowDraftVariableService(
  298. session=db.session(),
  299. )
  300. variable = draft_var_srv.get_variable(variable_id=variable_id)
  301. if variable is None:
  302. raise NotFoundError(description=f"variable not found, id={variable_id}")
  303. if variable.app_id != app_model.id:
  304. raise NotFoundError(description=f"variable not found, id={variable_id}")
  305. draft_var_srv.delete_variable(variable)
  306. db.session.commit()
  307. return Response("", 204)
  308. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
  309. class VariableResetApi(Resource):
  310. @api.doc("reset_variable")
  311. @api.doc(description="Reset a workflow variable to its default value")
  312. @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
  313. @api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  314. @api.response(204, "Variable reset (no content)")
  315. @api.response(404, "Variable not found")
  316. @_api_prerequisite
  317. def put(self, app_model: App, variable_id: str):
  318. draft_var_srv = WorkflowDraftVariableService(
  319. session=db.session(),
  320. )
  321. workflow_srv = WorkflowService()
  322. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  323. if draft_workflow is None:
  324. raise NotFoundError(
  325. f"Draft workflow not found, app_id={app_model.id}",
  326. )
  327. variable = draft_var_srv.get_variable(variable_id=variable_id)
  328. if variable is None:
  329. raise NotFoundError(description=f"variable not found, id={variable_id}")
  330. if variable.app_id != app_model.id:
  331. raise NotFoundError(description=f"variable not found, id={variable_id}")
  332. resetted = draft_var_srv.reset_variable(draft_workflow, variable)
  333. db.session.commit()
  334. if resetted is None:
  335. return Response("", 204)
  336. else:
  337. return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  338. def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
  339. with Session(bind=db.engine, expire_on_commit=False) as session:
  340. draft_var_srv = WorkflowDraftVariableService(
  341. session=session,
  342. )
  343. if node_id == CONVERSATION_VARIABLE_NODE_ID:
  344. draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
  345. elif node_id == SYSTEM_VARIABLE_NODE_ID:
  346. draft_vars = draft_var_srv.list_system_variables(app_model.id)
  347. else:
  348. draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
  349. return draft_vars
  350. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
  351. class ConversationVariableCollectionApi(Resource):
  352. @api.doc("get_conversation_variables")
  353. @api.doc(description="Get conversation variables for workflow")
  354. @api.doc(params={"app_id": "Application ID"})
  355. @api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  356. @api.response(404, "Draft workflow not found")
  357. @_api_prerequisite
  358. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  359. def get(self, app_model: App):
  360. # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
  361. # so their IDs can be returned to the caller.
  362. workflow_srv = WorkflowService()
  363. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  364. if draft_workflow is None:
  365. raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
  366. draft_var_srv = WorkflowDraftVariableService(db.session())
  367. draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
  368. db.session.commit()
  369. return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
  370. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
  371. class SystemVariableCollectionApi(Resource):
  372. @api.doc("get_system_variables")
  373. @api.doc(description="Get system variables for workflow")
  374. @api.doc(params={"app_id": "Application ID"})
  375. @api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  376. @_api_prerequisite
  377. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  378. def get(self, app_model: App):
  379. return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
  380. @console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
  381. class EnvironmentVariableCollectionApi(Resource):
  382. @api.doc("get_environment_variables")
  383. @api.doc(description="Get environment variables for workflow")
  384. @api.doc(params={"app_id": "Application ID"})
  385. @api.response(200, "Environment variables retrieved successfully")
  386. @api.response(404, "Draft workflow not found")
  387. @_api_prerequisite
  388. def get(self, app_model: App):
  389. """
  390. Get draft workflow
  391. """
  392. # fetch draft workflow by app_model
  393. workflow_service = WorkflowService()
  394. workflow = workflow_service.get_draft_workflow(app_model=app_model)
  395. if workflow is None:
  396. raise DraftWorkflowNotExist()
  397. env_vars = workflow.environment_variables
  398. env_vars_list = []
  399. for v in env_vars:
  400. env_vars_list.append(
  401. {
  402. "id": v.id,
  403. "type": "env",
  404. "name": v.name,
  405. "description": v.description,
  406. "selector": v.selector,
  407. "value_type": v.value_type.exposed_type().value,
  408. "value": v.value,
  409. # Do not track edited for env vars.
  410. "edited": False,
  411. "visible": True,
  412. "editable": True,
  413. }
  414. )
  415. return {"items": env_vars_list}