Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

rag_pipeline.py 30KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790
  1. import json
  2. import threading
  3. import time
  4. from collections.abc import Callable, Generator, Sequence
  5. from datetime import UTC, datetime
  6. from typing import Any, Optional, cast
  7. from uuid import uuid4
  8. from flask_login import current_user
  9. from sqlalchemy import or_, select
  10. from sqlalchemy.orm import Session
  11. import contexts
  12. from configs import dify_config
  13. from core.datasource.entities.datasource_entities import (
  14. DatasourceProviderType,
  15. GetOnlineDocumentPagesResponse,
  16. GetWebsiteCrawlResponse,
  17. )
  18. from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
  19. from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
  20. from core.model_runtime.utils.encoders import jsonable_encoder
  21. from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
  22. from core.variables.variables import Variable
  23. from core.workflow.entities.node_entities import NodeRunResult
  24. from core.workflow.errors import WorkflowNodeRunFailedError
  25. from core.workflow.graph_engine.entities.event import InNodeEvent
  26. from core.workflow.nodes.base.node import BaseNode
  27. from core.workflow.nodes.enums import ErrorStrategy, NodeType
  28. from core.workflow.nodes.event.event import RunCompletedEvent
  29. from core.workflow.nodes.event.types import NodeEvent
  30. from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
  31. from core.workflow.repository.workflow_node_execution_repository import OrderConfig
  32. from core.workflow.workflow_entry import WorkflowEntry
  33. from extensions.ext_database import db
  34. from libs.infinite_scroll_pagination import InfiniteScrollPagination
  35. from models.account import Account
  36. from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
  37. from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
  38. from models.model import EndUser
  39. from models.workflow import (
  40. Workflow,
  41. WorkflowNodeExecution,
  42. WorkflowNodeExecutionStatus,
  43. WorkflowNodeExecutionTriggeredFrom,
  44. WorkflowRun,
  45. WorkflowType,
  46. )
  47. from services.dataset_service import DatasetService
  48. from services.entities.knowledge_entities.rag_pipeline_entities import (
  49. KnowledgeConfiguration,
  50. PipelineTemplateInfoEntity,
  51. )
  52. from services.errors.app import WorkflowHashNotEqualError
  53. from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
  54. class RagPipelineService:
  @classmethod
  def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
      """
      Return the pipeline templates from the requested source.

      :param type: "built-in" uses the retrieval mode configured in
          HOSTED_FETCH_PIPELINE_TEMPLATES_MODE; any other value uses the
          "customized" retrieval.
      :param language: template language. For built-in templates, when the result has
          no "pipeline_templates" and the language is not en-US, the built-in en-US
          templates are returned instead.
      :return: the dict produced by the retrieval implementation
      """
      if type != "built-in":
          customized_retrieval = PipelineTemplateRetrievalFactory.get_pipeline_template_factory("customized")()
          return customized_retrieval.get_pipeline_templates(language)

      configured_mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
      retrieval = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(configured_mode)()
      templates = retrieval.get_pipeline_templates(language)
      # Empty localized result: fall back to the built-in English templates.
      if not templates.get("pipeline_templates") and language != "en-US":
          builtin_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
          templates = builtin_retrieval.fetch_pipeline_templates_from_builtin("en-US")
      return templates
  @classmethod
  def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
      """
      Fetch the detail of a single pipeline template.

      The retrieval source is always the mode configured in
      HOSTED_FETCH_PIPELINE_TEMPLATES_MODE.

      :param template_id: template id
      :return: the detail dict returned by the retrieval implementation, or None
      """
      retrieval_cls = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(
          dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
      )
      detail: Optional[dict] = retrieval_cls().get_pipeline_template_detail(template_id)
      return detail
  @classmethod
  def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
      """
      Update the name, description and icon of a customized pipeline template, then commit.

      The template is looked up within the current user's tenant.

      :param template_id: template id
      :param template_info: the new name, description and icon info
      :return: the updated PipelineCustomizedTemplate
      :raises ValueError: if no template with this id exists in the current tenant
      """
      # Query and commit through db.session, like every other method in this service;
      # the db extension object itself does not provide session query/commit.
      customized_template: PipelineCustomizedTemplate | None = (
          db.session.query(PipelineCustomizedTemplate)
          .filter(
              PipelineCustomizedTemplate.id == template_id,
              PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
          )
          .first()
      )
      if not customized_template:
          raise ValueError("Customized pipeline template not found.")
      customized_template.name = template_info.name
      customized_template.description = template_info.description
      customized_template.icon = template_info.icon_info.model_dump()
      db.session.commit()
      return customized_template
  @classmethod
  def delete_customized_pipeline_template(cls, template_id: str):
      """
      Delete a customized pipeline template of the current user's tenant, then commit.

      :param template_id: template id
      :raises ValueError: if no template with this id exists in the current tenant
      """
      customized_template: PipelineCustomizedTemplate | None = (
          db.session.query(PipelineCustomizedTemplate)
          .filter(
              PipelineCustomizedTemplate.id == template_id,
              PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
          )
          .first()
      )
      if not customized_template:
          raise ValueError("Customized pipeline template not found.")
      # Instance deletion and commit are session operations; db.delete would build a
      # SQL DELETE construct rather than delete this ORM instance.
      db.session.delete(customized_template)
      db.session.commit()
  def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
      """
      Return the draft workflow of a RAG pipeline.

      :param pipeline: pipeline whose draft is wanted
      :return: the Workflow with version "draft" for this pipeline and tenant, or None
      """
      draft_query = db.session.query(Workflow).filter(
          Workflow.tenant_id == pipeline.tenant_id,
          Workflow.app_id == pipeline.id,
          Workflow.version == "draft",
      )
      return draft_query.first()
  def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
      """
      Return the workflow referenced by ``pipeline.workflow_id``.

      :param pipeline: pipeline whose current workflow is wanted
      :return: the matching Workflow, or None when the pipeline has no workflow_id or
          no row matches tenant, pipeline and workflow id
      """
      workflow_id = pipeline.workflow_id
      if workflow_id:
          return (
              db.session.query(Workflow)
              .filter(
                  Workflow.tenant_id == pipeline.tenant_id,
                  Workflow.app_id == pipeline.id,
                  Workflow.id == workflow_id,
              )
              .first()
          )
      return None
  def get_all_published_workflow(
      self,
      *,
      session: Session,
      pipeline: Pipeline,
      page: int,
      limit: int,
      user_id: str | None,
      named_only: bool = False,
  ) -> tuple[Sequence[Workflow], bool]:
      """
      Return one page of the pipeline's workflows, ordered by version descending.

      :param session: session used to run the query
      :param pipeline: pipeline whose workflows are listed
      :param page: 1-based page number
      :param limit: page size
      :param user_id: when given, only workflows created by this user
      :param named_only: when True, only workflows with a non-empty marked_name
      :return: (workflows on this page, whether a further page exists);
          ([], False) when the pipeline has no workflow_id
      """
      if not pipeline.workflow_id:
          return [], False

      # NOTE(review): there is no version filter, so the "draft" row also matches
      # app_id and can appear in the list — confirm this is intended.
      conditions = [Workflow.app_id == pipeline.id]
      if user_id:
          conditions.append(Workflow.created_by == user_id)
      if named_only:
          conditions.append(Workflow.marked_name != "")

      # One extra row is fetched so we can tell whether another page follows.
      stmt = (
          select(Workflow)
          .where(*conditions)
          .order_by(Workflow.version.desc())
          .limit(limit + 1)
          .offset((page - 1) * limit)
      )
      rows = session.scalars(stmt).all()
      has_more = len(rows) > limit
      return (rows[:-1] if has_more else rows), has_more
  def sync_draft_workflow(
      self,
      *,
      pipeline: Pipeline,
      graph: dict,
      unique_hash: Optional[str],
      account: Account,
      environment_variables: Sequence[Variable],
      conversation_variables: Sequence[Variable],
      rag_pipeline_variables: list,
  ) -> Workflow:
      """
      Create or overwrite the pipeline's draft workflow and commit the session.

      When a draft exists, its graph, variables and updated_by/updated_at are replaced.
      Otherwise a new RAG_PIPELINE draft is created and ``pipeline.workflow_id`` is set
      to its id.

      :param unique_hash: must equal the existing draft's unique_hash (optimistic check)
      :return: the created or updated draft workflow
      :raises WorkflowHashNotEqualError: if a draft exists and its unique_hash differs
      """
      draft = self.get_draft_workflow(pipeline=pipeline)

      if draft:
          if draft.unique_hash != unique_hash:
              raise WorkflowHashNotEqualError()
          draft.graph = json.dumps(graph)
          draft.updated_by = account.id
          draft.updated_at = datetime.now(UTC).replace(tzinfo=None)
          draft.environment_variables = environment_variables
          draft.conversation_variables = conversation_variables
          draft.rag_pipeline_variables = rag_pipeline_variables
      else:
          draft = Workflow(
              tenant_id=pipeline.tenant_id,
              app_id=pipeline.id,
              features="{}",
              type=WorkflowType.RAG_PIPELINE.value,
              version="draft",
              graph=json.dumps(graph),
              created_by=account.id,
              environment_variables=environment_variables,
              conversation_variables=conversation_variables,
              rag_pipeline_variables=rag_pipeline_variables,
          )
          db.session.add(draft)
          # flush so the new row's id is available before linking it to the pipeline
          db.session.flush()
          # NOTE(review): get_published_workflow resolves pipeline.workflow_id, so this makes
          # the draft the "published" workflow until something re-points it — confirm intended.
          pipeline.workflow_id = draft.id

      db.session.commit()
      # TODO: emit a draft-synced event, e.g.
      # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
      return draft
  def publish_workflow(
      self,
      *,
      session: Session,
      pipeline: Pipeline,
      account: Account,
  ) -> Workflow:
      """
      Snapshot the pipeline's draft workflow into a new versioned workflow.

      The new workflow is added to ``session`` but not committed here. For every
      "knowledge-index" node in its graph, the pipeline's dataset settings are updated
      from that node's data.

      :return: the new (uncommitted) workflow
      :raises ValueError: if there is no draft workflow ("No valid workflow found."), or a
          knowledge-index node exists but the pipeline has no dataset ("Dataset not found")
      """
      draft_workflow = session.scalar(
          select(Workflow).where(
              Workflow.tenant_id == pipeline.tenant_id,
              Workflow.app_id == pipeline.id,
              Workflow.version == "draft",
          )
      )
      if not draft_workflow:
          raise ValueError("No valid workflow found.")

      published = Workflow.new(
          tenant_id=pipeline.tenant_id,
          app_id=pipeline.id,
          type=draft_workflow.type,
          # the version label is the current UTC time (naive) rendered as a string
          version=str(datetime.now(UTC).replace(tzinfo=None)),
          graph=draft_workflow.graph,
          features=draft_workflow.features,
          created_by=account.id,
          environment_variables=draft_workflow.environment_variables,
          conversation_variables=draft_workflow.conversation_variables,
          rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
          marked_name="",
          marked_comment="",
      )
      # Staged on the session only; no commit happens in this method.
      session.add(published)

      knowledge_index_nodes = (
          node
          for node in published.graph_dict.get("nodes", [])
          if node.get("data", {}).get("type") == "knowledge-index"
      )
      for node in knowledge_index_nodes:
          knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
          dataset = pipeline.dataset
          if not dataset:
              raise ValueError("Dataset not found")
          DatasetService.update_rag_pipeline_dataset_settings(
              session=session,
              dataset=dataset,
              knowledge_configuration=knowledge_configuration,
              has_published=pipeline.is_published,
          )
      return published
  def get_default_block_configs(self) -> list[dict]:
      """
      Collect the default config of the latest version of every registered node type.

      Node classes whose default config is empty/falsy are left out.
      """
      return [
          config
          for versions in NODE_TYPE_CLASSES_MAPPING.values()
          if (config := versions[LATEST_VERSION].get_default_config())
      ]
  def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
      """
      Return the default config of the latest version of one node type.

      :param node_type: node type value, converted with ``NodeType(node_type)``
      :param filters: forwarded to the node class's get_default_config
      :return: the default config, or None when the node type has no registered classes
          or its default config is empty/falsy
      """
      node_type_enum = NodeType(node_type)
      if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
          return None
      latest_node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
      config = latest_node_class.get_default_config(filters=filters)
      return config if config else None
  def run_draft_workflow_node(
      self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
  ) -> WorkflowNodeExecution:
      """
      Run one node of the pipeline's draft workflow and persist the execution record.

      :param pipeline: pipeline whose draft workflow is used
      :param node_id: id of the node to run
      :param user_inputs: inputs passed to the single-step run
      :param account: account that triggers the run (stored as created_by)
      :return: the committed WorkflowNodeExecution
      :raises ValueError: if the pipeline has no draft workflow
      """
      workflow = self.get_draft_workflow(pipeline=pipeline)
      if not workflow:
          raise ValueError("Workflow not initialized")

      def run_single_step():
          return WorkflowEntry.single_step_run(
              workflow=workflow,
              node_id=node_id,
              user_inputs=user_inputs,
              user_id=account.id,
          )

      started = time.perf_counter()
      execution = self._handle_node_run_result(
          getter=run_single_step,
          start_at=started,
          tenant_id=pipeline.tenant_id,
          node_id=node_id,
      )
      execution.app_id = pipeline.id
      execution.created_by = account.id
      execution.workflow_id = workflow.id
      db.session.add(execution)
      db.session.commit()
      return execution
  def run_published_workflow_node(
      self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
  ) -> WorkflowNodeExecution:
      """
      Run one node of the pipeline's published workflow and persist the execution record.

      :param pipeline: pipeline whose published workflow is used
      :param node_id: id of the node to run
      :param user_inputs: inputs passed to the single-step run
      :param account: account that triggers the run (stored as created_by)
      :return: the committed WorkflowNodeExecution
      :raises ValueError: if the pipeline has no published workflow
      """
      workflow = self.get_published_workflow(pipeline=pipeline)
      if not workflow:
          raise ValueError("Workflow not initialized")

      def run_single_step():
          return WorkflowEntry.single_step_run(
              workflow=workflow,
              node_id=node_id,
              user_inputs=user_inputs,
              user_id=account.id,
          )

      started = time.perf_counter()
      execution = self._handle_node_run_result(
          getter=run_single_step,
          start_at=started,
          tenant_id=pipeline.tenant_id,
          node_id=node_id,
      )
      execution.app_id = pipeline.id
      execution.created_by = account.id
      execution.workflow_id = workflow.id
      db.session.add(execution)
      db.session.commit()
      return execution
  def run_datasource_workflow_node(
      self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
  ) -> dict:
      """
      Run the datasource configured on a node of the pipeline's published workflow.

      :param pipeline: pipeline whose published workflow holds the node
      :param node_id: id of the datasource node in the workflow graph
      :param user_inputs: forwarded to the datasource as ``datasource_parameters``
      :param account: its id is forwarded as ``user_id``
      :param datasource_type: value converted with ``DatasourceProviderType(...)`` to
          resolve the datasource runtime
      :return: ``{"result": [...], "provider_type": ...}`` where ``result`` holds the
          ``model_dump()`` of each returned page / crawl result and ``provider_type`` is
          read from the node data
      :raises ValueError: if there is no published workflow, the node is missing or has
          no data, or the runtime's provider type is neither ONLINE_DOCUMENT nor
          WEBSITE_CRAWL
      """
      published_workflow = self.get_published_workflow(pipeline=pipeline)
      if not published_workflow:
          raise ValueError("Workflow not initialized")

      # The graph's "nodes" entry is a list of node dicts (it is iterated as a list
      # elsewhere in this service), so the node must be found by id, not via dict .get().
      # Assumes each node dict carries its id under "id" — the Dify graph format.
      datasource_node_data: dict = {}
      for node in published_workflow.graph_dict.get("nodes", []):
          if node.get("id") == node_id:
              datasource_node_data = node.get("data", {})
              break
      if not datasource_node_data:
          raise ValueError("Datasource node data not found")

      from core.datasource.datasource_manager import DatasourceManager

      datasource_runtime = DatasourceManager.get_datasource_runtime(
          provider_id=datasource_node_data.get("provider_id"),
          datasource_name=datasource_node_data.get("datasource_name"),
          tenant_id=pipeline.tenant_id,
          datasource_type=DatasourceProviderType(datasource_type),
      )
      # datasource_provider_type is a method: compare its return value. A bound method
      # never equals an enum member, which would make the WEBSITE_CRAWL branch unreachable.
      provider_type = datasource_runtime.datasource_provider_type()

      if provider_type == DatasourceProviderType.ONLINE_DOCUMENT:
          online_document_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
          online_document_result: GetOnlineDocumentPagesResponse = (
              online_document_runtime._get_online_document_pages(
                  user_id=account.id,
                  datasource_parameters=user_inputs,
                  provider_type=provider_type,
              )
          )
          return {
              "result": [page.model_dump() for page in online_document_result.result],
              "provider_type": datasource_node_data.get("provider_type"),
          }

      if provider_type == DatasourceProviderType.WEBSITE_CRAWL:
          website_crawl_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
          website_crawl_result: GetWebsiteCrawlResponse = website_crawl_runtime._get_website_crawl(
              user_id=account.id,
              datasource_parameters=user_inputs,
              provider_type=provider_type,
          )
          return {
              "result": [result.model_dump() for result in website_crawl_result.result],
              "provider_type": datasource_node_data.get("provider_type"),
          }

      raise ValueError(f"Unsupported datasource provider: {provider_type}")
  def run_free_workflow_node(
      self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
  ) -> WorkflowNodeExecution:
      """
      Run a single node built from ``node_data`` via WorkflowEntry.run_free_node.

      The returned execution record is neither added to the session nor committed here.

      :return: the WorkflowNodeExecution built by _handle_node_run_result
      """

      def run_node():
          return WorkflowEntry.run_free_node(
              node_id=node_id,
              node_data=node_data,
              tenant_id=tenant_id,
              user_id=user_id,
              user_inputs=user_inputs,
          )

      return self._handle_node_run_result(
          getter=run_node,
          start_at=time.perf_counter(),
          tenant_id=tenant_id,
          node_id=node_id,
      )
  def _handle_node_run_result(
      self,
      getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
      start_at: float,
      tenant_id: str,
      node_id: str,
  ) -> WorkflowNodeExecution:
      """
      Drive a single node run and turn its outcome into a WorkflowNodeExecution.

      :param getter: zero-argument callable returning (node instance, event generator)
      :param start_at: time.perf_counter() value taken before the run; used for elapsed_time
      :param tenant_id: tenant id stored on the execution record
      :param node_id: node id stored on the execution record
      :return: an unsaved WorkflowNodeExecution (triggered_from SINGLE_STEP, index 1)
      :raises ValueError: if the event stream ends without a RunCompletedEvent
      """
      try:
          node_instance, events = getter()
          node_run_result: NodeRunResult | None = None
          for event in events:
              if not isinstance(event, RunCompletedEvent):
                  continue
              node_run_result = event.run_result
              # sign output files
              node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
              break

          if not node_run_result:
              raise ValueError("Node run failed with no run result")

          # Single-step debug: a failed node configured to continue on error is reported as
          # an EXCEPTION result whose outputs carry the error details.
          if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
              error_strategy = node_instance.node_data.error_strategy
              error_outputs: dict[str, Any] = {
                  "error_message": node_run_result.error,
                  "error_type": node_run_result.error_type,
              }
              if error_strategy is ErrorStrategy.DEFAULT_VALUE:
                  # default values first, error keys override on collision
                  error_outputs = {**node_instance.node_data.default_value_dict, **error_outputs}
              node_run_result = NodeRunResult(
                  status=WorkflowNodeExecutionStatus.EXCEPTION,
                  error=node_run_result.error,
                  inputs=node_run_result.inputs,
                  metadata={"error_strategy": error_strategy},
                  outputs=error_outputs,
              )

          run_succeeded = node_run_result.status in (
              WorkflowNodeExecutionStatus.SUCCEEDED,
              WorkflowNodeExecutionStatus.EXCEPTION,
          )
          error = None if run_succeeded else node_run_result.error
      except WorkflowNodeRunFailedError as e:
          node_instance = e.node_instance
          run_succeeded = False
          node_run_result = None
          error = e.error

      execution = WorkflowNodeExecution()
      execution.id = str(uuid4())
      execution.tenant_id = tenant_id
      execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
      execution.index = 1
      execution.node_id = node_id
      execution.node_type = node_instance.node_type
      execution.title = node_instance.node_data.title
      execution.elapsed_time = time.perf_counter() - start_at
      execution.created_by_role = CreatorUserRole.ACCOUNT.value
      execution.created_at = datetime.now(UTC).replace(tzinfo=None)
      execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

      if not (run_succeeded and node_run_result):
          execution.status = WorkflowNodeExecutionStatus.FAILED.value
          execution.error = error
          return execution

      def _converted(value):
          # empty/None payloads are stored as JSON null
          return WorkflowEntry.handle_special_values(value) if value else None

      execution.inputs = json.dumps(_converted(node_run_result.inputs))
      execution.process_data = json.dumps(_converted(node_run_result.process_data))
      execution.outputs = json.dumps(_converted(node_run_result.outputs))
      execution.execution_metadata = (
          json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
      )
      if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
          execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
      elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
          execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
          execution.error = node_run_result.error
      return execution
  def update_workflow(
      self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
  ) -> Optional[Workflow]:
      """
      Update the editable attributes of a workflow on the given session.

      Only the "marked_name" and "marked_comment" keys of ``data`` are applied; any other
      key is ignored. updated_by/updated_at are refreshed whenever the workflow exists.
      Nothing is committed here.

      :param session: SQLAlchemy session used for the lookup and the changes
      :param workflow_id: workflow id
      :param tenant_id: tenant id the workflow must belong to
      :param account_id: stored as updated_by (this method performs no permission check)
      :param data: field name -> new value
      :return: the updated workflow, or None if no workflow matches id and tenant
      """
      workflow = session.scalar(
          select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
      )
      if not workflow:
          return None

      editable_fields = ("marked_name", "marked_comment")
      for field_name, new_value in data.items():
          if field_name in editable_fields:
              setattr(workflow, field_name, new_value)

      workflow.updated_by = account_id
      workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
      return workflow
  def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
      """
      Return the published workflow's rag pipeline variables for one node.

      A variable is selected when its "belong_to_node_id" is ``node_id`` or "shared".

      :return: the selected variables; [] when the workflow has none
      :raises ValueError: if the pipeline has no published workflow
      """
      workflow = self.get_published_workflow(pipeline=pipeline)
      if not workflow:
          raise ValueError("Workflow not initialized")

      variables = workflow.rag_pipeline_variables
      if not variables:
          return []
      wanted_owners = (node_id, "shared")
      return [variable for variable in variables if variable.get("belong_to_node_id") in wanted_owners]
  def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
      """
      Return the draft workflow's rag pipeline variables for one node.

      A variable is selected when its "belong_to_node_id" is ``node_id`` or "shared".

      :return: the selected variables; [] when the workflow has none
      :raises ValueError: if the pipeline has no draft workflow
      """
      workflow = self.get_draft_workflow(pipeline=pipeline)
      if not workflow:
          raise ValueError("Workflow not initialized")

      variables = workflow.rag_pipeline_variables
      if not variables:
          return []
      wanted_owners = (node_id, "shared")
      return [variable for variable in variables if variable.get("belong_to_node_id") in wanted_owners]
  def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
      """
      Get one page of the pipeline's workflow runs, newest first.

      Only runs triggered from RAG_PIPELINE_RUN or RAG_PIPELINE_DEBUGGING are returned.

      :param pipeline: pipeline whose runs are listed (tenant and app id are matched)
      :param args: request args; "limit" (default 20) is the page size and the optional
          "last_id" is the id of the last run on the previous page
      :return: InfiniteScrollPagination with the runs, the limit and has_more
      :raises ValueError: if "last_id" is given but no matching run exists
      """
      limit = int(args.get("limit", 20))
      base_query = db.session.query(WorkflowRun).filter(
          WorkflowRun.tenant_id == pipeline.tenant_id,
          WorkflowRun.app_id == pipeline.id,
          or_(
              WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
              WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
          ),
      )

      last_id = args.get("last_id")
      if last_id:
          last_workflow_run = base_query.filter(WorkflowRun.id == last_id).first()
          if not last_workflow_run:
              raise ValueError("Last workflow run not exists")
          page_query = base_query.filter(
              WorkflowRun.created_at < last_workflow_run.created_at,
              WorkflowRun.id != last_workflow_run.id,
          )
      else:
          page_query = base_query
      workflow_runs = page_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

      has_more = False
      # Only a full, non-empty page can be followed by more rows. The emptiness check
      # guards limit <= 0, where workflow_runs[-1] would raise IndexError.
      if workflow_runs and len(workflow_runs) == limit:
          last_on_page = workflow_runs[-1]
          rest_count = base_query.filter(
              WorkflowRun.created_at < last_on_page.created_at,
              WorkflowRun.id != last_on_page.id,
          ).count()
          has_more = rest_count > 0

      return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
  def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
      """
      Return one workflow run of the pipeline.

      :param pipeline: pipeline the run must belong to (tenant and app id are matched)
      :param run_id: workflow run id
      :return: the WorkflowRun, or None if not found
      """
      run_query = db.session.query(WorkflowRun).filter(
          WorkflowRun.tenant_id == pipeline.tenant_id,
          WorkflowRun.app_id == pipeline.id,
          WorkflowRun.id == run_id,
      )
      return run_query.first()
  def get_rag_pipeline_workflow_run_node_executions(
      self,
      pipeline: Pipeline,
      run_id: str,
      user: Account | EndUser,
  ) -> list[WorkflowNodeExecution]:
      """
      List the node executions of one workflow run as database models.

      The plugin tool provider context variables are reset (empty dict and a fresh lock)
      before the run is checked.

      :param pipeline: pipeline the run belongs to
      :param run_id: workflow run id
      :param user: user passed to the node execution repository
      :return: node executions ordered by index descending; [] if the run is not found
      """
      workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
      contexts.plugin_tool_providers.set({})
      contexts.plugin_tool_providers_lock.set(threading.Lock())
      if not workflow_run:
          return []

      repository = SQLAlchemyWorkflowNodeExecutionRepository(
          session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
      )
      node_executions = repository.get_by_workflow_run(
          workflow_run_id=run_id,
          order_config=OrderConfig(order_by=["index"], order_direction="desc"),
          triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
      )
      # the repository returns domain models; convert each back to its DB model
      return [repository.to_db_model(execution) for execution in node_executions]
  679. @classmethod
  680. def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
  681. """
  682. Publish customized pipeline template
  683. """
  684. pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
  685. if not pipeline:
  686. raise ValueError("Pipeline not found")
  687. if not pipeline.workflow_id:
  688. raise ValueError("Pipeline workflow not found")
  689. workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
  690. if not workflow:
  691. raise ValueError("Workflow not found")
  692. db.session.commit()