
# rag_pipeline.py

import json
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate  # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeBaseUpdateConfiguration,
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @staticmethod
    def get_pipeline_templates(
        type: str = "built-in", language: str = "en-US"
    ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :return:
        """
        mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")

        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        db.session.commit()

        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")

        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )

        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None

        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )

        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not pipeline.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )

        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)

        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")

        workflows = session.scalars(stmt).all()

        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more
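
    # Note: get_all_published_workflow fetches limit + 1 rows so `has_more` can
    # be decided without a second COUNT query; the extra sentinel row is dropped
    # before the page is returned.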

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)

        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables

        # commit db session changes
        db.session.commit()

        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)

        # return draft workflow
        return workflow

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )

        # commit db session changes
        session.add(workflow)

        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge_index":
                knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )

        # return new workflow
        return workflow

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)

        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)

        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None

        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )

        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id

        db.session.add(workflow_node_execution)
        db.session.commit()

        return workflow_node_execution

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")

        # run published workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )

        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = published_workflow.id

        db.session.add(workflow_node_execution)
        db.session.commit()

        return workflow_node_execution

    def run_datasource_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
    ) -> dict:
        """
        Run published workflow datasource
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")

        # locate the datasource node in the graph; nodes are stored as a list, not a dict
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        datasource_node_data = next(
            (node.get("data", {}) for node in datasource_nodes if node.get("id") == node_id),
            None,
        )
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=datasource_node_data.get("provider_id"),
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
            datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
            online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [page.model_dump() for page in online_document_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
            datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
            website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [result.model_dump() for result in website_crawl_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        else:
            raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run free workflow node
        """
        # run free workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )

        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()

            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result

                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break

            if not node_run_result:
                raise ValueError("Node run failed with no run result")

            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None

            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error

        return workflow_node_execution
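
    # Error-handling note: a FAILED single-step result on a node configured to
    # continue on error is recorded as EXCEPTION rather than FAILED, with either
    # the node's default values (DEFAULT_VALUE strategy) or just the
    # error_message/error_type fields merged into `outputs` by the method above.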

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes

        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None

        allowed_fields = ["marked_name", "marked_comment"]

        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)

        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)

        return workflow

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of published rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []

        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of draft rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []

        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get debug workflow run list
        Only returns runs with triggered_from == debugging

        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))

        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
        )

        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()

            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")

            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

        has_more = False
        if len(workflow_runs) == limit:
            # the runs are ordered newest-first, so the last item is the oldest on this page
            current_page_last_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_last_workflow_run.created_at,
                WorkflowRun.id != current_page_last_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
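
    # Usage sketch (illustrative, not part of the original file): callers page
    # through debug runs cursor-style by passing the id of the last run from the
    # previous page, e.g.
    #   page1 = service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 20})
    #   page2 = service.get_rag_pipeline_paginate_workflow_runs(
    #       pipeline, {"limit": 20, "last_id": page1.data[-1].id}
    #   )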

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail

        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )

        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        if not workflow_run:
            return []

        # Build the repository used to load node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            app_id=pipeline.id,
            user=user,
            triggered_from=None,
        )

        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)

        # Convert domain models to database models
        workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

        return workflow_node_executions
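

# Usage sketch (illustrative, not part of the original module): the typical
# draft -> single-step debug -> publish flow this service supports. How the
# caller obtains `pipeline` and `account` is assumed here (normally the Flask
# controller layer supplies them); everything else uses only names defined or
# imported above.
#
# def debug_and_publish(pipeline: Pipeline, account: Account, graph: dict) -> Workflow:
#     service = RagPipelineService()
#     # keep the draft in sync with the editor state; unique_hash guards against
#     # concurrent edits (WorkflowHashNotEqualError on mismatch); None skips the
#     # check for a brand-new draft
#     service.sync_draft_workflow(
#         pipeline=pipeline,
#         graph=graph,
#         unique_hash=None,
#         account=account,
#         environment_variables=[],
#         conversation_variables=[],
#         rag_pipeline_variables=[],
#     )
#     # single-step one node against the draft before publishing; "start" is a
#     # hypothetical node id taken from the graph being edited
#     execution = service.run_draft_workflow_node(
#         pipeline=pipeline, node_id="start", user_inputs={}, account=account
#     )
#     if execution.status != WorkflowNodeExecutionStatus.SUCCEEDED.value:
#         raise RuntimeError(f"Node failed: {execution.error}")
#     # snapshot the draft as a new published version inside one session
#     with Session(db.engine) as session:
#         workflow = service.publish_workflow(session=session, pipeline=pipeline, account=account)
#         session.commit()
#     return workflow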