# rag_pipeline.py
import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import (
    BaseDatasourceEvent,
    DatasourceCompletedEvent,
    DatasourceErrorEvent,
    DatasourceProcessingEvent,
)
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader

logger = logging.getLogger(__name__)

class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result
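    # A minimal usage sketch (the "name" key is assumed for illustration; only
    # the "pipeline_templates" key is guaranteed by the code above):
    #
    #   result = RagPipelineService.get_pipeline_templates(language="en-US")
    #   names = [t.get("name") for t in result.get("pipeline_templates", [])]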
    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :return:
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_template_detail(template_id)
        return result
    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")

        # check whether the new template name already exists
        template_name = template_info.name
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                    PipelineCustomizedTemplate.id != template_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists.")

        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template
    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()
    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )

        # return draft workflow
        return workflow
    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None

        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )

        return workflow
    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )

        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)

        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")

        workflows = session.scalars(stmt).all()

        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more
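    # The query above fetches limit + 1 rows on purpose: the extra row only
    # signals that a further page exists and is dropped before returning.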
    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by pipeline
        workflow = self.get_draft_workflow(pipeline=pipeline)

        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables

        # commit db session changes
        db.session.commit()

        # TODO: trigger workflow events
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)

        # return draft workflow
        return workflow
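    # The unique_hash check above is an optimistic-concurrency guard: callers
    # pass the hash of the draft they last loaded, and a mismatch raises
    # WorkflowHashNotEqualError instead of silently overwriting newer edits.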
    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)

        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)

        return default_block_configs
    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)

        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None

        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config
    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecutionModel:
        """
        Run draft workflow node
        """
        # fetch draft workflow by pipeline
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        start_at = time.perf_counter()
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
                variable_pool=VariablePool(
                    system_variables=SystemVariable.empty(),
                    user_inputs=user_inputs,
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id

        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=account,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)

        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution)

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model
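    # Note: a single-step run is persisted in two places: the node-execution
    # repository (the run record itself) and DraftVariableSaver, which stores
    # the node's outputs as draft variables that DraftVarLoader can feed into
    # later single-step runs.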
    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Generator[BaseDatasourceEvent, None, None]:
        """
        Run the datasource node of a draft or published workflow
        """
        try:
            if is_published:
                # fetch published workflow
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")

            # fetch datasource node data
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            variables_map = {}
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        if last_part in user_inputs:
                            variables_map[key] = user_inputs[last_part]
                        else:
                            variables_map[key] = value["value"]
                    else:
                        variables_map[key] = value["value"]
                else:
                    variables_map[key] = value["value"]
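            # The pattern above matches variable references written as
            # "{{#node_id.variable#}}"; only the last path segment is used to
            # look up a user-supplied value, otherwise the literal string is
            # kept as-is.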
            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                        datasource_runtime.get_online_document_pages(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    try:
                        for message in online_document_result:
                            end_time = time.time()
                            online_document_event = DatasourceCompletedEvent(
                                data=message.result, time_consuming=round(end_time - start_time, 2)
                            )
                            yield online_document_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during online document.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
                        datasource_runtime.online_drive_browse_files(
                            user_id=account.id,
                            request=OnlineDriveBrowseFilesRequest(
                                bucket=user_inputs.get("bucket"),
                                prefix=user_inputs.get("prefix"),
                                max_keys=user_inputs.get("max_keys", 20),
                                start_after=user_inputs.get("start_after"),
                            ),
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    for message in online_drive_result:
                        end_time = time.time()
                        online_drive_event = DatasourceCompletedEvent(
                            data=message.result,
                            time_consuming=round(end_time - start_time, 2),
                            total=None,
                            completed=None,
                        )
                        yield online_drive_event.model_dump()
                case DatasourceProviderType.WEBSITE_CRAWL:
                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                        datasource_runtime.get_website_crawl(
                            user_id=account.id,
                            datasource_parameters=variables_map,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    try:
                        for message in website_crawl_result:
                            end_time = time.time()
                            if message.result.status == "completed":
                                crawl_event = DatasourceCompletedEvent(
                                    data=message.result.web_info_list,
                                    total=message.result.total,
                                    completed=message.result.completed,
                                    time_consuming=round(end_time - start_time, 2),
                                )
                            else:
                                crawl_event = DatasourceProcessingEvent(
                                    total=message.result.total,
                                    completed=message.result.completed,
                                )
                            yield crawl_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during website crawl.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case _:
                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
        except Exception as e:
            logger.exception("Error in run_datasource_workflow_node.")
            yield DatasourceErrorEvent(error=str(e)).model_dump()
    def run_datasource_node_preview(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Mapping[str, Any]:
        """
        Preview the content of a datasource node
        """
        try:
            if is_published:
                # fetch published workflow
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")

            # fetch datasource node data
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=account.id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=user_inputs.get("workspace_id"),
                                page_id=user_inputs.get("page_id"),
                                type=user_inputs.get("type"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    try:
                        variables: dict[str, Any] = {}
                        for message in online_document_result:
                            if message.type == DatasourceMessage.MessageType.VARIABLE:
                                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                                variable_name = message.message.variable_name
                                variable_value = message.message.variable_value
                                if message.message.stream:
                                    if not isinstance(variable_value, str):
                                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                                    if variable_name not in variables:
                                        variables[variable_name] = ""
                                    variables[variable_name] += variable_value
                                else:
                                    variables[variable_name] = variable_value
                        return variables
                    except Exception as e:
                        logger.exception("Error during get online document content.")
                        raise RuntimeError(str(e))
                # TODO: Online Drive
                case _:
                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
        except Exception as e:
            logger.exception("Error in run_datasource_node_preview.")
            raise RuntimeError(str(e))
    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free (standalone) workflow node
        """
        # run free workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution
    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()

            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result

                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break

            if not node_run_result:
                raise ValueError("Node run failed with no run result")

            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.error_strategy},
                }
                if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e._node
            run_succeeded = False
            node_run_result = None
            error = e._error

        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.type_,
            title=node_instance.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None

            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()

        return workflow_node_execution
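    # Note on the error semantics above: a FAILED node whose continue_on_error
    # flag is set is downgraded to EXCEPTION and treated as a soft success;
    # only genuinely failed published runs mark the related Document's
    # indexing_status as "error".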
    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None

        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)

        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow
    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")

        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        variables = workflow.rag_pipeline_variables
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []

        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if value.get("value") and isinstance(value.get("value"), str):
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                match = re.match(pattern, value["value"])
                if match:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables.append(variables_map.get(last_part, {}))
        return user_input_variables
    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        variables_map = {item["variable"]: item for item in rag_pipeline_variables}

        # get datasource node data
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if datasource_node_data:
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        # drop variables already consumed by the first step;
                        # pass a default so a missing key doesn't raise KeyError
                        variables_map.pop(last_part, None)
        all_second_step_variables = list(variables_map.values())
        datasource_provider_variables = [
            item
            for item in all_second_step_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables
    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list
        Only returns runs triggered from rag pipeline run or debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))

        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )

        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")

            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
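    # Pagination here is cursor-based rather than offset-based: "last_id"
    # anchors the next page, and has_more is decided by counting whether any
    # runs older than the last row of the current page remain.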
    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )

        return workflow_run
    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)

        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        if not workflow_run:
            return []

        # Use the repository to get the node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )

        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )

        return list(node_executions)
    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Dataset not found")

        # check whether the template name already exists
        template_name = args.get("name")
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists.")

        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )

        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)

        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()
    def is_workflow_exist(self, pipeline: Pipeline) -> bool:
        return (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .count()
        ) > 0
    def get_node_last_run(
        self, pipeline: Pipeline, workflow: Workflow, node_id: str
    ) -> WorkflowNodeExecutionModel | None:
        # TODO(QuantumGhost): This query is not fully covered by index.
        criteria = (
            WorkflowNodeExecutionModel.tenant_id == pipeline.tenant_id,
            WorkflowNodeExecutionModel.app_id == pipeline.id,
            WorkflowNodeExecutionModel.workflow_id == workflow.id,
            WorkflowNodeExecutionModel.node_id == node_id,
        )
        node_exec = (
            db.session.query(WorkflowNodeExecutionModel)
            .filter(*criteria)
            .order_by(WorkflowNodeExecutionModel.created_at.desc())
            .first()
        )
        return node_exec
    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser):
        # fetch draft workflow by pipeline
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=draft_workflow.id,
            index=1,
            node_id=args.get("start_node_id", ""),
            node_type=NodeType.DATASOURCE,
            title=args.get("start_node_title", "Datasource"),
            elapsed_time=0,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=None,
            metadata=None,
        )
        outputs = {
            **args.get("datasource_info", {}),
            "datasource_type": args.get("datasource_type", ""),
        }
        workflow_node_execution.outputs = outputs

        node_config = draft_workflow.get_node_config_by_id(args.get("start_node_id", ""))
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None

        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=current_user,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)

        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution)

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model