You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

workflow_draft_variable.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import logging
  2. from typing import Any, NoReturn
  3. from flask import Response
  4. from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
  5. from sqlalchemy.orm import Session
  6. from werkzeug.exceptions import Forbidden
  7. from controllers.console import api
  8. from controllers.console.app.error import (
  9. DraftWorkflowNotExist,
  10. )
  11. from controllers.console.app.wraps import get_app_model
  12. from controllers.console.wraps import account_initialization_required, setup_required
  13. from controllers.web.error import InvalidArgumentError, NotFoundError
  14. from core.variables.segment_group import SegmentGroup
  15. from core.variables.segments import ArrayFileSegment, FileSegment, Segment
  16. from core.variables.types import SegmentType
  17. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  18. from factories.file_factory import build_from_mapping, build_from_mappings
  19. from factories.variable_factory import build_segment_with_type
  20. from libs.login import current_user, login_required
  21. from models import App, AppMode, db
  22. from models.workflow import WorkflowDraftVariable
  23. from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
  24. from services.workflow_service import WorkflowService
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  26. def _convert_values_to_json_serializable_object(value: Segment) -> Any:
  27. if isinstance(value, FileSegment):
  28. return value.value.model_dump()
  29. elif isinstance(value, ArrayFileSegment):
  30. return [i.model_dump() for i in value.value]
  31. elif isinstance(value, SegmentGroup):
  32. return [_convert_values_to_json_serializable_object(i) for i in value.value]
  33. else:
  34. return value.value
  35. def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
  36. value = variable.get_value()
  37. # create a copy of the value to avoid affecting the model cache.
  38. value = value.model_copy(deep=True)
  39. # Refresh the url signature before returning it to client.
  40. if isinstance(value, FileSegment):
  41. file = value.value
  42. file.remote_url = file.generate_url()
  43. elif isinstance(value, ArrayFileSegment):
  44. files = value.value
  45. for file in files:
  46. file.remote_url = file.generate_url()
  47. return _convert_values_to_json_serializable_object(value)
  48. def _create_pagination_parser():
  49. parser = reqparse.RequestParser()
  50. parser.add_argument(
  51. "page",
  52. type=inputs.int_range(1, 100_000),
  53. required=False,
  54. default=1,
  55. location="args",
  56. help="the page of data requested",
  57. )
  58. parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
  59. return parser
  60. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
  61. "id": fields.String,
  62. "type": fields.String(attribute=lambda model: model.get_variable_type()),
  63. "name": fields.String,
  64. "description": fields.String,
  65. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  66. "value_type": fields.String,
  67. "edited": fields.Boolean(attribute=lambda model: model.edited),
  68. "visible": fields.Boolean,
  69. }
  70. _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
  71. _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
  72. value=fields.Raw(attribute=_serialize_var_value),
  73. )
  74. _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
  75. "id": fields.String,
  76. "type": fields.String(attribute=lambda _: "env"),
  77. "name": fields.String,
  78. "description": fields.String,
  79. "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
  80. "value_type": fields.String,
  81. "edited": fields.Boolean(attribute=lambda model: model.edited),
  82. "visible": fields.Boolean,
  83. }
  84. _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
  85. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
  86. }
  87. def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
  88. return var_list.variables
  89. _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
  90. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
  91. "total": fields.Raw(),
  92. }
  93. _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
  94. "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
  95. }
  96. def _api_prerequisite(f):
  97. """Common prerequisites for all draft workflow variable APIs.
  98. It ensures the following conditions are satisfied:
  99. - Dify has been property setup.
  100. - The request user has logged in and initialized.
  101. - The requested app is a workflow or a chat flow.
  102. - The request user has the edit permission for the app.
  103. """
  104. @setup_required
  105. @login_required
  106. @account_initialization_required
  107. @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
  108. def wrapper(*args, **kwargs):
  109. if not current_user.is_editor:
  110. raise Forbidden()
  111. return f(*args, **kwargs)
  112. return wrapper
  113. class WorkflowVariableCollectionApi(Resource):
  114. @_api_prerequisite
  115. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
  116. def get(self, app_model: App):
  117. """
  118. Get draft workflow
  119. """
  120. parser = _create_pagination_parser()
  121. args = parser.parse_args()
  122. # fetch draft workflow by app_model
  123. workflow_service = WorkflowService()
  124. workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
  125. if not workflow_exist:
  126. raise DraftWorkflowNotExist()
  127. # fetch draft workflow by app_model
  128. with Session(bind=db.engine, expire_on_commit=False) as session:
  129. draft_var_srv = WorkflowDraftVariableService(
  130. session=session,
  131. )
  132. workflow_vars = draft_var_srv.list_variables_without_values(
  133. app_id=app_model.id,
  134. page=args.page,
  135. limit=args.limit,
  136. )
  137. return workflow_vars
  138. @_api_prerequisite
  139. def delete(self, app_model: App):
  140. draft_var_srv = WorkflowDraftVariableService(
  141. session=db.session(),
  142. )
  143. draft_var_srv.delete_workflow_variables(app_model.id)
  144. db.session.commit()
  145. return Response("", 204)
  146. def validate_node_id(node_id: str) -> NoReturn | None:
  147. if node_id in [
  148. CONVERSATION_VARIABLE_NODE_ID,
  149. SYSTEM_VARIABLE_NODE_ID,
  150. ]:
  151. # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
  152. # with specific `node_id` in database, we still want to make the API separated. By disallowing
  153. # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
  154. # we mitigate the risk that user of the API depending on the implementation detail of the API.
  155. #
  156. # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
  157. raise InvalidArgumentError(
  158. f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
  159. )
  160. return None
  161. class NodeVariableCollectionApi(Resource):
  162. @_api_prerequisite
  163. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  164. def get(self, app_model: App, node_id: str):
  165. validate_node_id(node_id)
  166. with Session(bind=db.engine, expire_on_commit=False) as session:
  167. draft_var_srv = WorkflowDraftVariableService(
  168. session=session,
  169. )
  170. node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
  171. return node_vars
  172. @_api_prerequisite
  173. def delete(self, app_model: App, node_id: str):
  174. validate_node_id(node_id)
  175. srv = WorkflowDraftVariableService(db.session())
  176. srv.delete_node_variables(app_model.id, node_id)
  177. db.session.commit()
  178. return Response("", 204)
  179. class VariableApi(Resource):
  180. _PATCH_NAME_FIELD = "name"
  181. _PATCH_VALUE_FIELD = "value"
  182. @_api_prerequisite
  183. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  184. def get(self, app_model: App, variable_id: str):
  185. draft_var_srv = WorkflowDraftVariableService(
  186. session=db.session(),
  187. )
  188. variable = draft_var_srv.get_variable(variable_id=variable_id)
  189. if variable is None:
  190. raise NotFoundError(description=f"variable not found, id={variable_id}")
  191. if variable.app_id != app_model.id:
  192. raise NotFoundError(description=f"variable not found, id={variable_id}")
  193. return variable
  194. @_api_prerequisite
  195. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
  196. def patch(self, app_model: App, variable_id: str):
  197. # Request payload for file types:
  198. #
  199. # Local File:
  200. #
  201. # {
  202. # "type": "image",
  203. # "transfer_method": "local_file",
  204. # "url": "",
  205. # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
  206. # }
  207. #
  208. # Remote File:
  209. #
  210. #
  211. # {
  212. # "type": "image",
  213. # "transfer_method": "remote_url",
  214. # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
  215. # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
  216. # }
  217. parser = reqparse.RequestParser()
  218. parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
  219. # Parse 'value' field as-is to maintain its original data structure
  220. parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
  221. draft_var_srv = WorkflowDraftVariableService(
  222. session=db.session(),
  223. )
  224. args = parser.parse_args(strict=True)
  225. variable = draft_var_srv.get_variable(variable_id=variable_id)
  226. if variable is None:
  227. raise NotFoundError(description=f"variable not found, id={variable_id}")
  228. if variable.app_id != app_model.id:
  229. raise NotFoundError(description=f"variable not found, id={variable_id}")
  230. new_name = args.get(self._PATCH_NAME_FIELD, None)
  231. raw_value = args.get(self._PATCH_VALUE_FIELD, None)
  232. if new_name is None and raw_value is None:
  233. return variable
  234. new_value = None
  235. if raw_value is not None:
  236. if variable.value_type == SegmentType.FILE:
  237. if not isinstance(raw_value, dict):
  238. raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
  239. raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
  240. elif variable.value_type == SegmentType.ARRAY_FILE:
  241. if not isinstance(raw_value, list):
  242. raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
  243. if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
  244. raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
  245. raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
  246. new_value = build_segment_with_type(variable.value_type, raw_value)
  247. draft_var_srv.update_variable(variable, name=new_name, value=new_value)
  248. db.session.commit()
  249. return variable
  250. @_api_prerequisite
  251. def delete(self, app_model: App, variable_id: str):
  252. draft_var_srv = WorkflowDraftVariableService(
  253. session=db.session(),
  254. )
  255. variable = draft_var_srv.get_variable(variable_id=variable_id)
  256. if variable is None:
  257. raise NotFoundError(description=f"variable not found, id={variable_id}")
  258. if variable.app_id != app_model.id:
  259. raise NotFoundError(description=f"variable not found, id={variable_id}")
  260. draft_var_srv.delete_variable(variable)
  261. db.session.commit()
  262. return Response("", 204)
  263. class VariableResetApi(Resource):
  264. @_api_prerequisite
  265. def put(self, app_model: App, variable_id: str):
  266. draft_var_srv = WorkflowDraftVariableService(
  267. session=db.session(),
  268. )
  269. workflow_srv = WorkflowService()
  270. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  271. if draft_workflow is None:
  272. raise NotFoundError(
  273. f"Draft workflow not found, app_id={app_model.id}",
  274. )
  275. variable = draft_var_srv.get_variable(variable_id=variable_id)
  276. if variable is None:
  277. raise NotFoundError(description=f"variable not found, id={variable_id}")
  278. if variable.app_id != app_model.id:
  279. raise NotFoundError(description=f"variable not found, id={variable_id}")
  280. resetted = draft_var_srv.reset_variable(draft_workflow, variable)
  281. db.session.commit()
  282. if resetted is None:
  283. return Response("", 204)
  284. else:
  285. return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
  286. def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
  287. with Session(bind=db.engine, expire_on_commit=False) as session:
  288. draft_var_srv = WorkflowDraftVariableService(
  289. session=session,
  290. )
  291. if node_id == CONVERSATION_VARIABLE_NODE_ID:
  292. draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
  293. elif node_id == SYSTEM_VARIABLE_NODE_ID:
  294. draft_vars = draft_var_srv.list_system_variables(app_model.id)
  295. else:
  296. draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
  297. return draft_vars
  298. class ConversationVariableCollectionApi(Resource):
  299. @_api_prerequisite
  300. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  301. def get(self, app_model: App):
  302. # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
  303. # so their IDs can be returned to the caller.
  304. workflow_srv = WorkflowService()
  305. draft_workflow = workflow_srv.get_draft_workflow(app_model)
  306. if draft_workflow is None:
  307. raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
  308. draft_var_srv = WorkflowDraftVariableService(db.session())
  309. draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
  310. db.session.commit()
  311. return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
  312. class SystemVariableCollectionApi(Resource):
  313. @_api_prerequisite
  314. @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
  315. def get(self, app_model: App):
  316. return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
  317. class EnvironmentVariableCollectionApi(Resource):
  318. @_api_prerequisite
  319. def get(self, app_model: App):
  320. """
  321. Get draft workflow
  322. """
  323. # fetch draft workflow by app_model
  324. workflow_service = WorkflowService()
  325. workflow = workflow_service.get_draft_workflow(app_model=app_model)
  326. if workflow is None:
  327. raise DraftWorkflowNotExist()
  328. env_vars = workflow.environment_variables
  329. env_vars_list = []
  330. for v in env_vars:
  331. env_vars_list.append(
  332. {
  333. "id": v.id,
  334. "type": "env",
  335. "name": v.name,
  336. "description": v.description,
  337. "selector": v.selector,
  338. "value_type": v.value_type.value,
  339. "value": v.value,
  340. # Do not track edited for env vars.
  341. "edited": False,
  342. "visible": True,
  343. "editable": True,
  344. }
  345. )
  346. return {"items": env_vars_list}
  347. api.add_resource(
  348. WorkflowVariableCollectionApi,
  349. "/apps/<uuid:app_id>/workflows/draft/variables",
  350. )
  351. api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
  352. api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
  353. api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
  354. api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
  355. api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
  356. api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")