rag_pipeline.py
import json
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesRequest,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlRequest,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate  # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @staticmethod
    def get_pipeline_templates(
        type: str = "built-in", language: str = "en-US"
    ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result.get("pipeline_templates")
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result.get("pipeline_templates")

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :return:
        """
        mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
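
    # Illustrative use of the fetch-one-extra pagination above (a sketch, not part
    # of the original file; `pipeline` and an open `session` are assumed):
    #
    #     workflows, has_more = RagPipelineService().get_all_published_workflow(
    #         session=session, pipeline=pipeline, page=1, limit=20, user_id=None
    #     )
    #     # The statement requests limit + 1 rows; getting more than `limit` back
    #     # proves a next page exists, and the extra row is trimmed before returning.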

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
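
    # Optimistic-locking sketch for sync_draft_workflow (illustrative only;
    # `pipeline`, `account`, and `graph` are assumed to come from the caller):
    #
    #     draft = RagPipelineService().get_draft_workflow(pipeline=pipeline)
    #     base_hash = draft.unique_hash if draft else None
    #     try:
    #         RagPipelineService().sync_draft_workflow(
    #             pipeline=pipeline, graph=graph, unique_hash=base_hash,
    #             account=account, environment_variables=[],
    #             conversation_variables=[], rag_pipeline_variables=[],
    #         )
    #     except WorkflowHashNotEqualError:
    #         # the draft changed after `base_hash` was read; reload and retry
    #         ...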

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
        marked_name: str = "",
        marked_comment: str = "",
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            marked_name=marked_name,
            marked_comment=marked_comment,
        )
        # commit db session changes
        session.add(workflow)
        # trigger app workflow events TODO
        # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
        # return new workflow
        return workflow
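
    # Note: publish_workflow only session.add()s the new version snapshot; it does
    # not commit. A sketch of a caller, assuming the caller owns the transaction:
    #
    #     with Session(db.engine) as session:
    #         workflow = RagPipelineService().publish_workflow(
    #             session=session, pipeline=pipeline, account=account
    #         )
    #         session.commit()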

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config
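
    # Illustrative lookup (assumption: "code" is a valid NodeType value here):
    #
    #     config = RagPipelineService().get_default_block_config("code")
    #     # -> default config dict of the latest "code" node version, or None when
    #     #    the type is unmapped; an unknown string raises ValueError in NodeType()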

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_datasource_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
    ) -> dict:
        """
        Run published workflow datasource
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the published graph
        # (graph "nodes" is a list of node dicts identified by their "id" field)
        start_at = time.perf_counter()
        datasource_node_data = None
        for node in published_workflow.graph_dict.get("nodes", []):
            if node.get("id") == node_id:
                datasource_node_data = node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=datasource_node_data.get("provider_id"),
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
            datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
            online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [page.model_dump() for page in online_document_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
            datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
            website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [result.model_dump() for result in website_crawl_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        else:
            raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")
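
    # Hedged usage sketch (parameter values are assumptions, not from this file):
    #
    #     result = RagPipelineService().run_datasource_workflow_node(
    #         pipeline=pipeline,
    #         node_id="datasource-node-id",
    #         user_inputs={"url": "https://example.com"},
    #         account=account,
    #         datasource_type="website_crawl",
    #     )
    #     # -> {"result": [...crawl results...], "provider_type": ...}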

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run free workflow node
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error

        return workflow_node_execution
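
    # Status mapping performed above, summarized (scenarios are illustrative):
    #
    #     # FAILED + should_continue_on_error + DEFAULT_VALUE strategy
    #     #   -> EXCEPTION, outputs = node defaults + error_message/error_type
    #     # FAILED + should_continue_on_error + any other strategy
    #     #   -> EXCEPTION, outputs = error_message/error_type only
    #     # FAILED without continue-on-error (or WorkflowNodeRunFailedError)
    #     #   -> FAILED, error recorded on the execution row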

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow
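
    # Only whitelisted fields are applied; everything else in `data` is ignored.
    # Illustrative call (names are assumptions):
    #
    #     service = RagPipelineService()
    #     with Session(db.engine) as session:
    #         workflow = service.update_workflow(
    #             session=session, workflow_id=wf_id, tenant_id=tenant_id,
    #             account_id=account.id, data={"marked_name": "v2", "graph": "..."},
    #         )
    #         # "graph" is silently dropped; only marked_name changes
    #         session.commit()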

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get debug workflow run list
        Only returns runs with triggered_from == debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
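
    # Cursor-style paging sketch (illustrative; `pipeline` is assumed):
    #
    #     service = RagPipelineService()
    #     page1 = service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 20})
    #     if page1.has_more:
    #         page2 = service.get_rag_pipeline_paginate_workflow_runs(
    #             pipeline, {"limit": 20, "last_id": page1.data[-1].id}
    #         )
    #     # each page resumes strictly before the previous page's oldest run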

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            app_id=pipeline.id,
            user=user,
            triggered_from=None,
        )
        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
        # Convert domain models to database models
        workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
        return workflow_node_executions