
workflow_service.py

import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from uuid import uuid4

from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker

from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter

from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService

class WorkflowService:
    """
    Workflow Service
    """

    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize WorkflowService with repository dependencies."""
        if session_maker is None:
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )

    def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
        """
        Get the most recent execution for a specific node.

        Args:
            app_model: The application model
            workflow: The workflow model
            node_id: The node identifier

        Returns:
            The most recent WorkflowNodeExecutionModel for the node, or None if not found
        """
        return self._node_execution_service_repo.get_node_last_execution(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )
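    # Illustrative usage (not part of this module): look up the latest single-step
    # execution of a node. `app_model`, `draft_workflow`, and the node id are
    # assumed to be resolved by the caller, e.g. inside a request handler.
    #
    #   service = WorkflowService()
    #   last_run = service.get_node_last_run(app_model, draft_workflow, node_id="1719472037383")
    #   if last_run is not None:
    #       print(last_run.status, last_run.elapsed_time)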
    def is_workflow_exist(self, app_model: App) -> bool:
        stmt = select(
            exists().where(
                Workflow.tenant_id == app_model.tenant_id,
                Workflow.app_id == app_model.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
        )
        return db.session.execute(stmt).scalar_one()

    def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        if workflow_id:
            return self.get_published_workflow_by_id(app_model, workflow_id)

        # fetch draft workflow by app_model
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == app_model.tenant_id,
                Workflow.app_id == app_model.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .first()
        )

        # return draft workflow
        return workflow

    def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        fetch published workflow by workflow_id
        """
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == app_model.tenant_id,
                Workflow.app_id == app_model.id,
                Workflow.id == workflow_id,
            )
            .first()
        )
        if not workflow:
            return None
        if workflow.version == Workflow.VERSION_DRAFT:
            raise IsDraftWorkflowError(
                f"Cannot use draft workflow version. Workflow ID: {workflow_id}. "
                f"Please use a published workflow version or leave workflow_id empty."
            )
        return workflow

    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not app_model.workflow_id:
            return None

        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == app_model.tenant_id,
                Workflow.app_id == app_model.id,
                Workflow.id == app_model.workflow_id,
            )
            .first()
        )

        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        app_model: App,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not app_model.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(Workflow.app_id == app_model.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )

        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)

        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")

        workflows = session.scalars(stmt).all()

        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more
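    # Illustrative usage (assumption: the caller already opened a SQLAlchemy session
    # and resolved `app_model`). Note that `limit + 1` rows are fetched so `has_more`
    # can be derived without an extra COUNT query.
    #
    #   with Session(db.engine) as session:
    #       workflows, has_more = WorkflowService().get_all_published_workflow(
    #           session=session, app_model=app_model, page=1, limit=20, user_id=None
    #       )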
    def sync_draft_workflow(
        self,
        *,
        app_model: App,
        graph: dict,
        features: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # validate features structure
        self.validate_features_structure(app_model=app_model, features=features)

        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                type=WorkflowType.from_app_mode(app_model.mode).value,
                version=Workflow.VERSION_DRAFT,
                graph=json.dumps(graph),
                features=json.dumps(features),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
            )
            db.session.add(workflow)
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.features = json.dumps(features)
            workflow.updated_by = account.id
            workflow.updated_at = naive_utc_now()
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables

        # commit db session changes
        db.session.commit()

        # trigger app workflow events
        app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow)

        # return draft workflow
        return workflow
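    # Illustrative usage (hypothetical `graph`/`features` payloads coming from the
    # draft editor; `client_hash` is the unique_hash the client last saw, used here
    # for optimistic concurrency control):
    #
    #   try:
    #       WorkflowService().sync_draft_workflow(
    #           app_model=app_model,
    #           graph=graph,
    #           features=features,
    #           unique_hash=client_hash,
    #           account=current_user,
    #           environment_variables=[],
    #           conversation_variables=[],
    #       )
    #   except WorkflowHashNotEqualError:
    #       ...  # the draft changed elsewhere; ask the client to reload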
    def publish_workflow(
        self,
        *,
        session: Session,
        app_model: App,
        account: Account,
        marked_name: str = "",
        marked_comment: str = "",
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == app_model.tenant_id,
            Workflow.app_id == app_model.id,
            Workflow.version == Workflow.VERSION_DRAFT,
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        # Validate credentials before publishing, for credential policy check
        from services.feature_service import FeatureService

        if FeatureService.get_system_features().plugin_manager.enabled:
            self._validate_workflow_credentials(draft_workflow)

        # create new workflow
        workflow = Workflow.new(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type=draft_workflow.type,
            version=Workflow.version_from_datetime(naive_utc_now()),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            marked_name=marked_name,
            marked_comment=marked_comment,
        )

        # commit db session changes
        session.add(workflow)

        # trigger app workflow events
        app_published_workflow_was_updated.send(app_model, published_workflow=workflow)

        # return new workflow
        return workflow
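    # Illustrative usage: publishing copies the current draft into a new,
    # datetime-versioned row inside the caller's transaction. (The caller is
    # assumed to commit the session and update `app_model.workflow_id` itself.)
    #
    #   with Session(db.engine) as session, session.begin():
    #       published = WorkflowService().publish_workflow(
    #           session=session, app_model=app_model, account=current_user, marked_name="v1"
    #       )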
    def _validate_workflow_credentials(self, workflow: Workflow) -> None:
        """
        Validate all credentials in workflow nodes before publishing.

        :param workflow: The workflow to validate
        :raises ValueError: If any credentials violate policy compliance
        """
        graph_dict = workflow.graph_dict
        nodes = graph_dict.get("nodes", [])

        for node in nodes:
            node_data = node.get("data", {})
            node_type = node_data.get("type")
            node_id = node.get("id", "unknown")

            try:
                # Extract and validate credentials based on node type
                if node_type == "tool":
                    credential_id = node_data.get("credential_id")
                    provider = node_data.get("provider_id")
                    if provider:
                        if credential_id:
                            # Check specific credential
                            from core.helper.credential_utils import check_credential_policy_compliance

                            check_credential_policy_compliance(
                                credential_id=credential_id,
                                provider=provider,
                                credential_type=PluginCredentialType.TOOL,
                            )
                        else:
                            # Check default workspace credential for this provider
                            self._check_default_tool_credential(workflow.tenant_id, provider)

                elif node_type == "agent":
                    agent_params = node_data.get("agent_parameters", {})
                    model_config = agent_params.get("model", {}).get("value", {})
                    if model_config.get("provider") and model_config.get("model"):
                        self._validate_llm_model_config(
                            workflow.tenant_id, model_config["provider"], model_config["model"]
                        )
                        # Validate load balancing credentials for agent model if load balancing is enabled
                        agent_model_node_data = {"model": model_config}
                        self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)

                    # Validate agent tools
                    tools = agent_params.get("tools", {}).get("value", [])
                    for tool in tools:
                        # Agent tools store provider in provider_name field
                        provider = tool.get("provider_name")
                        credential_id = tool.get("credential_id")
                        if provider:
                            if credential_id:
                                from core.helper.credential_utils import check_credential_policy_compliance

                                check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
                            else:
                                self._check_default_tool_credential(workflow.tenant_id, provider)

                elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
                    model_config = node_data.get("model", {})
                    provider = model_config.get("provider")
                    model_name = model_config.get("name")
                    if provider and model_name:
                        # Validate that the provider+model combination can fetch valid credentials
                        self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
                        # Validate load balancing credentials if load balancing is enabled
                        self._validate_load_balancing_credentials(workflow, node_data, node_id)
                    else:
                        raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")

            except Exception as e:
                if isinstance(e, ValueError):
                    raise e
                else:
                    raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")
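    # For reference, the validator above walks `workflow.graph_dict`, whose relevant
    # shape is roughly (simplified, illustrative values only):
    #
    #   {
    #       "nodes": [
    #           {"id": "llm-1", "data": {"type": "llm", "model": {"provider": "openai", "name": "gpt-4o"}}},
    #           {"id": "tool-1", "data": {"type": "tool", "provider_id": "some_provider", "credential_id": "..."}},
    #       ]
    #   }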
    def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
        """
        Validate that an LLM model configuration can fetch valid credentials.

        This method attempts to get the model instance and validates that:
        1. The provider exists and is configured
        2. The model exists in the provider
        3. Credentials can be fetched for the model
        4. The credentials pass policy compliance checks

        :param tenant_id: The tenant ID
        :param provider: The provider name
        :param model_name: The model name
        :raises ValueError: If the model configuration is invalid or credentials fail policy checks
        """
        try:
            from core.model_manager import ModelManager
            from core.model_runtime.entities.model_entities import ModelType

            # Get model instance to validate provider+model combination
            model_manager = ModelManager()
            model_manager.get_model_instance(
                tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
            )
            # The ModelInstance constructor will automatically check credential policy compliance
            # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
            # If it fails, an exception will be raised
        except Exception as e:
            raise ValueError(
                f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
            )

    def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
        """
        Check credential policy compliance for the default workspace credential of a tool provider.

        This method finds the default credential for the given provider and validates it.
        Uses the same fallback logic as runtime to handle deauthorized credentials.

        :param tenant_id: The tenant ID
        :param provider: The tool provider name
        :raises ValueError: If no default credential exists or if it fails policy compliance
        """
        try:
            from models.tools import BuiltinToolProvider

            # Use the same fallback logic as runtime: get the first available credential
            # ordered by is_default DESC, created_at ASC (same as tool_manager.py)
            default_provider = (
                db.session.query(BuiltinToolProvider)
                .where(
                    BuiltinToolProvider.tenant_id == tenant_id,
                    BuiltinToolProvider.provider == provider,
                )
                .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                .first()
            )
            if not default_provider:
                raise ValueError("No default credential found")

            # Check credential policy compliance using the default credential ID
            from core.helper.credential_utils import check_credential_policy_compliance

            check_credential_policy_compliance(
                credential_id=default_provider.id,
                provider=provider,
                credential_type=PluginCredentialType.TOOL,
                check_existence=False,
            )
        except Exception as e:
            raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")

    def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
        """
        Validate load balancing credentials for a workflow node.

        :param workflow: The workflow being validated
        :param node_data: The node data containing model configuration
        :param node_id: The node ID for error reporting
        :raises ValueError: If load balancing credentials violate policy compliance
        """
        # Extract model configuration
        model_config = node_data.get("model", {})
        provider = model_config.get("provider")
        model_name = model_config.get("name")

        if not provider or not model_name:
            return  # No model config to validate

        # Check if this model has load balancing enabled
        if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
            # Get all load balancing configurations for this model
            load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)

            # Validate each load balancing configuration
            try:
                for config in load_balancing_configs:
                    if config.get("credential_id"):
                        from core.helper.credential_utils import check_credential_policy_compliance

                        check_credential_policy_compliance(
                            config["credential_id"], provider, PluginCredentialType.MODEL
                        )
            except Exception as e:
                raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")

    def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
        """
        Check if load balancing is enabled for a specific model.

        :param tenant_id: The tenant ID
        :param provider: The provider name
        :param model_name: The model name
        :return: True if load balancing is enabled, False otherwise
        """
        try:
            from core.model_runtime.entities.model_entities import ModelType
            from core.provider_manager import ProviderManager

            # Get provider configurations
            provider_manager = ProviderManager()
            provider_configurations = provider_manager.get_configurations(tenant_id)
            provider_configuration = provider_configurations.get(provider)
            if not provider_configuration:
                return False

            # Get provider model setting
            provider_model_setting = provider_configuration.get_provider_model_setting(
                model_type=ModelType.LLM,
                model=model_name,
            )

            return provider_model_setting is not None and provider_model_setting.load_balancing_enabled
        except Exception:
            # If we can't determine the status, assume load balancing is not enabled
            return False

    def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
        """
        Get all load balancing configurations for a model.

        :param tenant_id: The tenant ID
        :param provider: The provider name
        :param model_name: The model name
        :return: List of load balancing configuration dictionaries
        """
        try:
            from services.model_load_balancing_service import ModelLoadBalancingService

            model_load_balancing_service = ModelLoadBalancingService()

            _, configs = model_load_balancing_service.get_load_balancing_configs(
                tenant_id=tenant_id,
                provider=provider,
                model=model_name,
                model_type="llm",  # Load balancing is primarily used for LLM models
                config_from="predefined-model",  # Check both predefined and custom models
            )
            _, custom_configs = model_load_balancing_service.get_load_balancing_configs(
                tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
            )

            all_configs = configs + custom_configs
            return [config for config in all_configs if config.get("credential_id")]
        except Exception:
            # If we can't get the configurations, return empty list
            # This will prevent validation errors from breaking the workflow
            return []
    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)

        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)

        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None

        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config
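    # Illustrative usage: fetch the default editor config for a single node type.
    # The string must match a `NodeType` member value, e.g. "code" or "llm".
    #
    #   config = WorkflowService().get_default_block_config("code")
    #   if config is None:
    #       ...  # node type has no default config registered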
    def run_draft_workflow_node(
        self,
        app_model: App,
        draft_workflow: Workflow,
        node_id: str,
        user_inputs: Mapping[str, Any],
        account: Account,
        query: str = "",
        files: Sequence[File] | None = None,
    ) -> WorkflowNodeExecutionModel:
        """
        Run draft workflow node
        """
        files = files or []
        with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
            draft_var_srv = WorkflowDraftVariableService(session)
            draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)

        node_config = draft_workflow.get_node_config_by_id(node_id)
        node_type = Workflow.get_node_type_from_node_config(node_config)
        node_data = node_config.get("data", {})

        if node_type == NodeType.START:
            with Session(bind=db.engine) as session, session.begin():
                draft_var_srv = WorkflowDraftVariableService(session)
                conversation_id = draft_var_srv.get_or_create_conversation(
                    account_id=account.id,
                    app=app_model,
                    workflow=draft_workflow,
                )
            start_data = StartNodeData.model_validate(node_data)
            user_inputs = _rebuild_file_for_user_inputs_in_start_node(
                tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
            )
            # init variable pool
            variable_pool = _setup_variable_pool(
                query=query,
                files=files or [],
                user_id=account.id,
                user_inputs=user_inputs,
                workflow=draft_workflow,
                # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
                conversation_variables=[],
                node_type=node_type,
                conversation_id=conversation_id,
            )
        else:
            variable_pool = VariablePool(
                system_variables=SystemVariable.empty(),
                user_inputs=user_inputs,
                environment_variables=draft_workflow.environment_variables,
                conversation_variables=[],
            )

        variable_loader = DraftVarLoader(
            engine=db.engine,
            app_id=app_model.id,
            tenant_id=app_model.tenant_id,
        )

        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None

        run = WorkflowEntry.single_step_run(
            workflow=draft_workflow,
            node_id=node_id,
            user_inputs=user_inputs,
            user_id=account.id,
            variable_pool=variable_pool,
            variable_loader=variable_loader,
        )

        # run draft workflow node
        start_at = time.perf_counter()

        node_execution = self._handle_node_run_result(
            invoke_node_fn=lambda: run,
            start_at=start_at,
            node_id=node_id,
        )

        # Set workflow_id on the NodeExecution
        node_execution.workflow_id = draft_workflow.id

        # Create repository and save the node execution
        repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=db.engine,
            user=account,
            app_id=app_model.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(node_execution)

        workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
        if workflow_node_execution is None:
            raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=app_model.id,
                node_id=workflow_node_execution.node_id,
                node_type=NodeType(workflow_node_execution.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=node_execution.id,
            )
            draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
            session.commit()
        return workflow_node_execution
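    # Illustrative usage (single-step debugging from the draft editor; the inputs
    # below are a hypothetical mapping for the target node's variables):
    #
    #   execution = WorkflowService().run_draft_workflow_node(
    #       app_model=app_model,
    #       draft_workflow=draft_workflow,
    #       node_id="1719472037383",
    #       user_inputs={"query": "hello"},
    #       account=current_user,
    #   )
    #   print(execution.status, execution.outputs)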
    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run free workflow node
        """
        # run free workflow node
        start_at = time.perf_counter()

        node_execution = self._handle_node_run_result(
            invoke_node_fn=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            node_id=node_id,
        )

        return node_execution

    def _handle_node_run_result(
        self,
        invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        node_id: str,
    ) -> WorkflowNodeExecution:
        try:
            node, node_events = invoke_node_fn()

            node_run_result: NodeRunResult | None = None
            for event in node_events:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break

            if not node_run_result:
                raise ValueError("Node run failed with no run result")

            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node.error_strategy},
                }
                if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )

            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node = e._node
            run_succeeded = False
            node_run_result = None
            error = e._error

        # Create a NodeExecution domain model
        node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id="",  # This is a single-step execution, so no workflow ID
            index=1,
            node_id=node_id,
            node_type=node.type_,
            title=node.title,
            elapsed_time=time.perf_counter() - start_at,
            created_at=naive_utc_now(),
            finished_at=naive_utc_now(),
        )

        if run_succeeded and node_run_result:
            # Set inputs, process_data, and outputs as dictionaries (not JSON strings)
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = node_run_result.outputs

            node_execution.inputs = inputs
            node_execution.process_data = process_data
            node_execution.outputs = outputs
            node_execution.metadata = node_run_result.metadata

            # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                node_execution.error = node_run_result.error
        else:
            # Set failed status and error
            node_execution.status = WorkflowNodeExecutionStatus.FAILED
            node_execution.error = error

        return node_execution
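    # Note on the error-handling branch above (informal summary, not normative):
    # when a node fails but `continue_on_error` is set, the failure is downgraded to
    # an EXCEPTION record; with ErrorStrategy.DEFAULT_VALUE the node's declared
    # default outputs are merged in so downstream debugging still sees values, i.e.
    #
    #   outputs = {**node.default_value_dict, "error_message": ..., "error_type": ...}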
    def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
        """
        Convert a basic-mode chatbot app (expert mode) or a completion app to a workflow app.

        :param app_model: App instance
        :param account: Account instance
        :param args: dict
        :return:
        """
        # chatbot convert to workflow mode
        workflow_converter = WorkflowConverter()

        if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
            raise ValueError(f"Current App mode: {app_model.mode} is not supported for conversion to workflow.")

        # convert to workflow
        new_app: App = workflow_converter.convert_to_workflow(
            app_model=app_model,
            account=account,
            name=args.get("name", "Default Name"),
            icon_type=args.get("icon_type", "emoji"),
            icon=args.get("icon", "🤖"),
            icon_background=args.get("icon_background", "#FFEAD5"),
        )

        return new_app

    def validate_features_structure(self, app_model: App, features: dict):
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            return AdvancedChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            return WorkflowAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        else:
            raise ValueError(f"Invalid app mode: {app_model.mode}")
    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes

        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)

        if not workflow:
            return None

        allowed_fields = ["marked_name", "marked_comment"]

        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)

        workflow.updated_by = account_id
        workflow.updated_at = naive_utc_now()

        return workflow

    def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
        """
        Delete a workflow

        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :return: True if successful
        :raises: ValueError if workflow not found
        :raises: WorkflowInUseError if workflow is in use
        :raises: DraftWorkflowDeletionError if workflow is a draft version
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)

        if not workflow:
            raise ValueError(f"Workflow with ID {workflow_id} not found")

        # Check if workflow is a draft version
        if workflow.version == Workflow.VERSION_DRAFT:
            raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

        # Check if this workflow is currently referenced by an app
        app_stmt = select(App).where(App.workflow_id == workflow_id)
        app = session.scalar(app_stmt)
        if app:
            # Cannot delete a workflow that's currently in use by an app
            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")

        # Don't use workflow.tool_published as it's not accurate for specific workflow versions
        # Check if there's a tool provider using this specific workflow version
        tool_provider = (
            session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == workflow.tenant_id,
                WorkflowToolProvider.app_id == workflow.app_id,
                WorkflowToolProvider.version == workflow.version,
            )
            .first()
        )
        if tool_provider:
            # Cannot delete a workflow that's published as a tool
            raise WorkflowInUseError("Cannot delete workflow that is published as a tool")

        session.delete(workflow)

        return True
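# Illustrative caller-side handling for delete_workflow (exception types as raised above):
#
#   try:
#       with Session(db.engine) as session, session.begin():
#           WorkflowService().delete_workflow(session=session, workflow_id=wf_id, tenant_id=tenant_id)
#   except (WorkflowInUseError, DraftWorkflowDeletionError) as exc:
#       ...  # surface the reason to the API client instead of deleting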
def _setup_variable_pool(
    query: str,
    files: Sequence[File],
    user_id: str,
    user_inputs: Mapping[str, Any],
    workflow: Workflow,
    node_type: NodeType,
    conversation_id: str,
    conversation_variables: list[Variable],
):
    # Only inject system variables for START node type.
    if node_type == NodeType.START:
        system_variable = SystemVariable(
            user_id=user_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            files=files or [],
            workflow_execution_id=str(uuid.uuid4()),
        )
        # Only add chatflow-specific variables for non-workflow types
        if workflow.type != WorkflowType.WORKFLOW.value:
            system_variable.query = query
            system_variable.conversation_id = conversation_id
            system_variable.dialogue_count = 0
    else:
        system_variable = SystemVariable.empty()

    # init variable pool
    variable_pool = VariablePool(
        system_variables=system_variable,
        user_inputs=user_inputs,
        environment_variables=workflow.environment_variables,
        # Based on the definition of `VariableUnion`,
        # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
        conversation_variables=cast(list[VariableUnion], conversation_variables),
    )

    return variable_pool


def _rebuild_file_for_user_inputs_in_start_node(
    tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any]
) -> Mapping[str, Any]:
    inputs_copy = dict(user_inputs)

    for variable in start_node_data.variables:
        if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST):
            continue
        if variable.variable not in user_inputs:
            continue
        value = user_inputs[variable.variable]
        file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
        inputs_copy[variable.variable] = file
    return inputs_copy


def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
    if variable_entity_type == VariableEntityType.FILE:
        if not isinstance(value, dict):
            raise ValueError(f"expected dict for file object, got {type(value)}")
        return build_from_mapping(mapping=value, tenant_id=tenant_id)
    elif variable_entity_type == VariableEntityType.FILE_LIST:
        if not isinstance(value, list):
            raise ValueError(f"expected list for file list object, got {type(value)}")
        if len(value) == 0:
            return []
        if not isinstance(value[0], dict):
            raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
        return build_from_mappings(mappings=value, tenant_id=tenant_id)
    else:
        raise Exception("unreachable")
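# Illustrative input shape for the file-rebuilding helpers above (hypothetical values;
# the exact mapping keys accepted are defined by `factories.file_factory.build_from_mapping`):
#
#   user_inputs = {
#       "attachment": {
#           "transfer_method": "remote_url",
#           "url": "https://example.com/report.pdf",
#           "type": "document",
#       }
#   }
#   rebuilt = _rebuild_file_for_user_inputs_in_start_node(
#       tenant_id=tenant_id, start_node_data=start_data, user_inputs=user_inputs
#   )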