|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421 |
- import logging
- from typing import Any, NoReturn
-
- from flask import Response
- from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
- from sqlalchemy.orm import Session
- from werkzeug.exceptions import Forbidden
-
- from controllers.console import api
- from controllers.console.app.error import (
- DraftWorkflowNotExist,
- )
- from controllers.console.app.wraps import get_app_model
- from controllers.console.wraps import account_initialization_required, setup_required
- from controllers.web.error import InvalidArgumentError, NotFoundError
- from core.variables.segment_group import SegmentGroup
- from core.variables.segments import ArrayFileSegment, FileSegment, Segment
- from core.variables.types import SegmentType
- from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
- from factories.file_factory import build_from_mapping, build_from_mappings
- from factories.variable_factory import build_segment_with_type
- from libs.login import current_user, login_required
- from models import App, AppMode, db
- from models.workflow import WorkflowDraftVariable
- from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
- from services.workflow_service import WorkflowService
-
- logger = logging.getLogger(__name__)
-
-
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
    """Unwrap a segment into plain data suitable for a JSON response.

    - A ``FileSegment`` yields ``model_dump()`` of the file it holds.
    - An ``ArrayFileSegment`` yields one ``model_dump()`` per held file.
    - A ``SegmentGroup`` is converted recursively, member by member.
    - Any other segment yields its ``value`` attribute unchanged.
    """
    if isinstance(value, FileSegment):
        return value.value.model_dump()
    if isinstance(value, ArrayFileSegment):
        return [file.model_dump() for file in value.value]
    if isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(member) for member in value.value]
    return value.value
-
-
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
    """Return the variable's value as JSON-serializable data.

    For file-typed values, ``remote_url`` is regenerated on every file before
    serialization so the client receives a freshly signed URL.
    """
    # Deep-copy first: the URL refresh below must not leak into the model cache.
    segment = variable.get_value().model_copy(deep=True)

    # Gather the files whose URL has to be refreshed (none for non-file segments).
    if isinstance(segment, FileSegment):
        files_to_refresh = [segment.value]
    elif isinstance(segment, ArrayFileSegment):
        files_to_refresh = segment.value
    else:
        files_to_refresh = []

    for file in files_to_refresh:
        file.remote_url = file.generate_url()

    return _convert_values_to_json_serializable_object(segment)
-
-
def _create_pagination_parser():
    """Build the query-string parser for paginated list requests.

    Both arguments are optional and read from the URL query string:

    - ``page``: integer in ``[1, 100_000]``, defaults to ``1``.
    - ``limit``: integer in ``[1, 100]``, defaults to ``20``.
    """
    page_type = inputs.int_range(1, 100_000)
    limit_type = inputs.int_range(1, 100)

    pagination_parser = reqparse.RequestParser()
    pagination_parser.add_argument(
        "page",
        type=page_type,
        required=False,
        default=1,
        location="args",
        help="the page of data requested",
    )
    pagination_parser.add_argument(
        "limit",
        type=limit_type,
        required=False,
        default=20,
        location="args",
    )
    return pagination_parser
-
-
# Marshalling schema for a draft variable when its value is left out of the response.
# Fields with a callable `attribute` are computed from the model instead of being read
# from a same-named attribute.
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
    "id": fields.String,
    "type": fields.String(attribute=lambda model: model.get_variable_type()),
    "name": fields.String,
    "description": fields.String,
    "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
    "value_type": fields.String,
    "edited": fields.Boolean(attribute=lambda model: model.edited),
    "visible": fields.Boolean,
}

# Same schema extended with the serialized value of the variable.
_WORKFLOW_DRAFT_VARIABLE_FIELDS = {
    **_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
    "value": fields.Raw(attribute=_serialize_var_value),
}
-
# Marshalling schema for an environment variable; "type" is always reported as "env".
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
    "id": fields.String,
    "type": fields.String(attribute=lambda _: "env"),
    "name": fields.String,
    "description": fields.String,
    "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
    "value_type": fields.String,
    "edited": fields.Boolean(attribute=lambda model: model.edited),
    "visible": fields.Boolean,
}

# Envelope schema: a list of environment variables under the "items" key.
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
    "items": fields.List(
        fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS),
    ),
}
-
-
- def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
- return var_list.variables
-
-
# Envelope schema for a variable listing without values: the items plus a raw "total".
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
    "items": fields.List(
        fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS),
        attribute=_get_items,
    ),
    "total": fields.Raw(),
}

# Envelope schema for a variable listing that includes each variable's value.
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
    "items": fields.List(
        fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS),
        attribute=_get_items,
    ),
}
-
-
def _api_prerequisite(f):
    """Apply the common prerequisites for all draft workflow variable APIs.

    The returned view only calls ``f`` when all of the following hold:

    - Dify has been properly set up.
    - The request user has logged in and is initialized.
    - The requested app is a workflow or a chat flow.
    - The request user has the edit permission for the app.

    Args:
        f: the view function to protect.

    Returns:
        The decorated view function.

    Raises:
        Forbidden: (at request time) when the current user is not an editor.
    """
    # Imported locally so this block does not depend on a module-level import.
    from functools import wraps

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    @wraps(f)  # keep the name/docstring of `f` on the wrapper for debugging and introspection
    def wrapper(*args, **kwargs):
        if not current_user.is_editor:
            raise Forbidden()
        return f(*args, **kwargs)

    return wrapper
-
-
class WorkflowVariableCollectionApi(Resource):
    """Collection endpoint for the draft variables of an app's workflow."""

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
    def get(self, app_model: App):
        """Return one page of the app's draft variables, without their values.

        Pagination comes from the ``page`` and ``limit`` query arguments.

        Raises:
            DraftWorkflowNotExist: when the workflow service reports that no
                workflow exists for the app.
        """
        pagination = _create_pagination_parser().parse_args()

        # Listing variables only makes sense when the app has a workflow.
        if not WorkflowService().is_workflow_exist(app_model=app_model):
            raise DraftWorkflowNotExist()

        # Read through a dedicated session; loaded objects stay usable after it closes.
        with Session(bind=db.engine, expire_on_commit=False) as session:
            variable_service = WorkflowDraftVariableService(session=session)
            variable_page = variable_service.list_variables_without_values(
                app_id=app_model.id,
                page=pagination.page,
                limit=pagination.limit,
            )

        return variable_page

    @_api_prerequisite
    def delete(self, app_model: App):
        """Delete the app's draft workflow variables and answer with an empty 204."""
        variable_service = WorkflowDraftVariableService(session=db.session())
        variable_service.delete_workflow_variables(app_model.id)
        db.session.commit()
        return Response("", 204)
-
-
def validate_node_id(node_id: str) -> NoReturn | None:
    """Reject the node IDs reserved for conversation and system variables.

    Returns:
        ``None`` when ``node_id`` is an ordinary node ID.

    Raises:
        InvalidArgumentError: when ``node_id`` is one of the reserved IDs.
    """
    reserved_node_ids = (CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID)
    if node_id not in reserved_node_ids:
        return None

    # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
    # with specific `node_id` in database, we still want to make the API separated. By disallowing
    # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
    # we mitigate the risk that user of the API depending on the implementation detail of the API.
    #
    # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
    raise InvalidArgumentError(
        f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
    )
-
-
class NodeVariableCollectionApi(Resource):
    """Collection endpoint for the draft variables of a single workflow node."""

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, app_model: App, node_id: str):
        """List the draft variables of ``node_id``, values included.

        Reserved node IDs are rejected by ``validate_node_id``.
        """
        validate_node_id(node_id)

        # Read through a dedicated session; loaded objects stay usable after it closes.
        with Session(bind=db.engine, expire_on_commit=False) as session:
            variable_service = WorkflowDraftVariableService(session=session)
            variables_of_node = variable_service.list_node_variables(app_model.id, node_id)

        return variables_of_node

    @_api_prerequisite
    def delete(self, app_model: App, node_id: str):
        """Delete the draft variables of ``node_id`` and answer with an empty 204."""
        validate_node_id(node_id)
        variable_service = WorkflowDraftVariableService(db.session())
        variable_service.delete_node_variables(app_model.id, node_id)
        db.session.commit()
        return Response("", 204)
-
-
class VariableApi(Resource):
    """Read, update and delete a single draft variable addressed by its ID."""

    _PATCH_NAME_FIELD = "name"
    _PATCH_VALUE_FIELD = "value"

    @staticmethod
    def _get_variable_or_raise(
        draft_var_srv: WorkflowDraftVariableService,
        app_model: App,
        variable_id: str,
    ) -> WorkflowDraftVariable:
        """Load the variable and check that it belongs to ``app_model``.

        A variable owned by another app is reported with the same error as a
        missing one.

        Raises:
            NotFoundError: when the variable is missing or belongs to another app.
        """
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None or variable.app_id != app_model.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        return variable

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def get(self, app_model: App, variable_id: str):
        """Return the draft variable, including its value.

        Raises:
            NotFoundError: when the variable is missing or belongs to another app.
        """
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        return self._get_variable_or_raise(draft_var_srv, app_model, variable_id)

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def patch(self, app_model: App, variable_id: str):
        """Rename the variable and/or replace its value.

        Both JSON fields (``name`` and ``value``) are optional. When neither is
        supplied the variable is returned unchanged and nothing is committed.

        Raises:
            NotFoundError: when the variable is missing or belongs to another app.
            InvalidArgumentError: when the variable is file-typed and the
                supplied value does not have the expected shape (a dict for a
                file, a list of dicts for an array of files).
        """
        # Request payload for file types:
        #
        # Local File:
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "local_file",
        #         "url": "",
        #         "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
        #     }
        #
        # Remote File:
        #
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "remote_url",
        #         "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
        #         "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
        #     }

        parser = reqparse.RequestParser()
        parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
        # Parse 'value' field as-is to maintain its original data structure
        parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")

        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        args = parser.parse_args(strict=True)

        variable = self._get_variable_or_raise(draft_var_srv, app_model, variable_id)

        new_name = args.get(self._PATCH_NAME_FIELD, None)
        raw_value = args.get(self._PATCH_VALUE_FIELD, None)
        if new_name is None and raw_value is None:
            return variable

        new_value = None
        if raw_value is not None:
            if variable.value_type == SegmentType.FILE:
                if not isinstance(raw_value, dict):
                    raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
                raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
            elif variable.value_type == SegmentType.ARRAY_FILE:
                if not isinstance(raw_value, list):
                    raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
                # Validate every element (not just the first) and report the
                # type of the offending element rather than of the list.
                for index, item in enumerate(raw_value):
                    if not isinstance(item, dict):
                        raise InvalidArgumentError(
                            description=f"expected dict for files[{index}], got {type(item)}"
                        )
                raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
            new_value = build_segment_with_type(variable.value_type, raw_value)
        draft_var_srv.update_variable(variable, name=new_name, value=new_value)
        db.session.commit()
        return variable

    @_api_prerequisite
    def delete(self, app_model: App, variable_id: str):
        """Delete the draft variable and answer with an empty 204.

        Raises:
            NotFoundError: when the variable is missing or belongs to another app.
        """
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        variable = self._get_variable_or_raise(draft_var_srv, app_model, variable_id)
        draft_var_srv.delete_variable(variable)
        db.session.commit()
        return Response("", 204)
-
-
class VariableResetApi(Resource):
    """Endpoint that resets a single draft variable."""

    @_api_prerequisite
    def put(self, app_model: App, variable_id: str):
        """Reset the variable against the app's draft workflow.

        Returns:
            The marshalled variable when the reset yields one, otherwise an
            empty 204 response.

        Raises:
            NotFoundError: when the app has no draft workflow, or the variable
                is missing or belongs to another app.
        """
        variable_service = WorkflowDraftVariableService(session=db.session())

        draft_workflow = WorkflowService().get_draft_workflow(app_model)
        if draft_workflow is None:
            raise NotFoundError(
                f"Draft workflow not found, app_id={app_model.id}",
            )

        # A variable of another app is reported exactly like a missing one.
        variable = variable_service.get_variable(variable_id=variable_id)
        if variable is None or variable.app_id != app_model.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")

        reset_result = variable_service.reset_variable(draft_workflow, variable)
        db.session.commit()
        if reset_result is not None:
            return marshal(reset_result, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
        return Response("", 204)
-
-
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
    """List the app's draft variables for ``node_id``.

    The two reserved node IDs select the conversation and system variable
    listings; any other ID lists the variables of that node.
    """
    # Read through a dedicated session; loaded objects stay usable after it closes.
    with Session(bind=db.engine, expire_on_commit=False) as session:
        variable_service = WorkflowDraftVariableService(session=session)
        if node_id == CONVERSATION_VARIABLE_NODE_ID:
            listing = variable_service.list_conversation_variables(app_model.id)
        elif node_id == SYSTEM_VARIABLE_NODE_ID:
            listing = variable_service.list_system_variables(app_model.id)
        else:
            listing = variable_service.list_node_variables(app_id=app_model.id, node_id=node_id)
    return listing
-
-
class ConversationVariableCollectionApi(Resource):
    """Collection endpoint for the conversation variables of a draft workflow."""

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, app_model: App):
        """List the conversation variables of the app's draft workflow.

        Raises:
            NotFoundError: when the app has no draft workflow.
        """
        draft_workflow = WorkflowService().get_draft_workflow(app_model)
        if draft_workflow is None:
            raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")

        # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
        # so their IDs can be returned to the caller.
        variable_service = WorkflowDraftVariableService(db.session())
        variable_service.prefill_conversation_variable_default_values(draft_workflow)
        db.session.commit()

        return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
-
-
class SystemVariableCollectionApi(Resource):
    """Collection endpoint for the system variables of a draft workflow."""

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, app_model: App):
        """List the app's draft system variables."""
        system_variables = _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
        return system_variables
-
-
class EnvironmentVariableCollectionApi(Resource):
    """Collection endpoint for the environment variables of a draft workflow."""

    @_api_prerequisite
    def get(self, app_model: App):
        """List the environment variables of the app's draft workflow.

        Returns:
            A dict with a single ``"items"`` key holding one plain dict per
            environment variable.

        Raises:
            DraftWorkflowNotExist: when the app has no draft workflow.
        """
        draft_workflow = WorkflowService().get_draft_workflow(app_model=app_model)
        if draft_workflow is None:
            raise DraftWorkflowNotExist()

        items = [
            {
                "id": env_var.id,
                "type": "env",
                "name": env_var.name,
                "description": env_var.description,
                "selector": env_var.selector,
                "value_type": env_var.value_type.value,
                "value": env_var.value,
                # Do not track edited for env vars.
                "edited": False,
                "visible": True,
                "editable": True,
            }
            for env_var in draft_workflow.environment_variables
        ]
        return {"items": items}
-
-
# Route table for the draft-variable endpoints, registered in the listed order.
_DRAFT_VARIABLE_ROUTES = (
    (WorkflowVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/variables"),
    (NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables"),
    (VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>"),
    (VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset"),
    (ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables"),
    (SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables"),
    (EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables"),
)

for _resource_cls, _url_rule in _DRAFT_VARIABLE_ROUTES:
    api.add_resource(_resource_cls, _url_rule)
|