rag_pipeline.py
import json
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :return:
        """
        mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        # fetch limit + 1 rows: the extra row signals that a further page exists
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            # trim the surplus row before returning the page
            workflows = workflows[:-1]
        return workflows, has_more

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
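
    # A minimal usage sketch (hypothetical caller code, not part of this
    # module): unique_hash acts as an optimistic lock, so the client reads the
    # current draft hash and passes it back; if another editor synced in
    # between, sync_draft_workflow raises WorkflowHashNotEqualError.
    #
    #   service = RagPipelineService()
    #   draft = service.get_draft_workflow(pipeline=pipeline)
    #   workflow = service.sync_draft_workflow(
    #       pipeline=pipeline,
    #       graph=edited_graph,  # hypothetical edited graph dict
    #       unique_hash=draft.unique_hash if draft else None,
    #       account=account,
    #       environment_variables=[],
    #       conversation_variables=[],
    #       rag_pipeline_variables=[],
    #   )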

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
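
    # A hedged sketch of driving publish_workflow (the session handling is an
    # assumption here; the method adds the new version but the caller owns the
    # transaction):
    #
    #   with Session(db.engine) as session:
    #       workflow = RagPipelineService().publish_workflow(
    #           session=session, pipeline=pipeline, account=account
    #       )
    #       session.commit()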

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution
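
    # Usage sketch (hypothetical node id and inputs): single-step debugging of
    # one node in the draft workflow; the execution record is persisted by the
    # method itself.
    #
    #   execution = RagPipelineService().run_draft_workflow_node(
    #       pipeline=pipeline,
    #       node_id="extractor-node",
    #       user_inputs={"text": "hello"},
    #       account=account,
    #   )
    #   print(execution.status, execution.outputs)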

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_datasource_workflow_node_status(
        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
    ) -> dict:
        """
        Get the status of a datasource job started by a workflow datasource node
        """
        if is_published:
            # fetch published workflow by app_model
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the workflow graph
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
            provider=datasource_node_data.get("provider_name"),
            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                    user_id=account.id,
                    datasource_parameters={"job_id": job_id},
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [result.model_dump() for result in website_crawl_result.result],
                    "job_id": website_crawl_result.job_id,
                    "status": website_crawl_result.status,
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case _:
                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> dict:
        """
        Run a datasource node of the draft or published workflow
        """
        if is_published:
            # fetch published workflow by app_model
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the workflow graph
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        # fall back to the values configured on the node for missing user inputs
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        for key, value in datasource_parameters.items():
            if not user_inputs.get(key):
                user_inputs[key] = value["value"]

        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
            provider=datasource_node_data.get("provider_name"),
            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.ONLINE_DOCUMENT:
                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
                    user_id=account.id,
                    datasource_parameters=user_inputs,
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [page.model_dump() for page in online_document_result.result],
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                    user_id=account.id,
                    datasource_parameters=user_inputs,
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [result.model_dump() for result in website_crawl_result.result],
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case _:
                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")
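
    # A hedged end-to-end sketch (hypothetical values; caller code, not part of
    # this module): run a website-crawl datasource node, then poll the crawl
    # job. Where the job id appears in the crawl response depends on the
    # datasource plugin.
    #
    #   service = RagPipelineService()
    #   result = service.run_datasource_workflow_node(
    #       pipeline=pipeline,
    #       node_id="datasource-node",
    #       user_inputs={"url": "https://example.com"},
    #       account=account,
    #       datasource_type="website_crawl",
    #       is_published=False,
    #   )
    #   status = service.run_datasource_workflow_node_status(
    #       pipeline=pipeline,
    #       node_id="datasource-node",
    #       job_id=job_id,  # job id reported by the crawl plugin
    #       account=account,
    #       datasource_type="website_crawl",
    #       is_published=False,
    #   )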

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free node (not attached to any workflow)
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    # DEFAULT_VALUE strategy: substitute the node's configured defaults
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.node_data.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
        return workflow_node_execution

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # get the first step (datasource) node
        datasource_node_data = None
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            # only surface parameters that are not bound to a variable selector
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables
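
    # Illustration of the selector pattern above (values are hypothetical): a
    # datasource parameter whose value references another node's variable,
    # e.g. "{{#start_node.url#}}", is resolved inside the workflow, so only
    # parameters holding literal values (e.g. "https://example.com") are
    # surfaced as user inputs for the first step.
    #
    #   _SELECTOR = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
    #   re.match(_SELECTOR, "{{#start_node.url#}}")   # -> Match: resolved internally
    #   re.match(_SELECTOR, "https://example.com")    # -> None: asked from the user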

    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # get the first step (datasource) node
        datasource_node_data = None
        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            # only surface parameters that are not bound to a variable selector
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list
        Only returns runs triggered from rag pipeline run or debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run not exists")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at,
                    WorkflowRun.id != last_workflow_run.id,
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
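
    # Cursor-pagination sketch (hypothetical caller code, assuming the returned
    # InfiniteScrollPagination exposes the data/has_more it was built with):
    # the first call omits "last_id"; later pages pass the id of the last run
    # already received.
    #
    #   service = RagPipelineService()
    #   page = service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 20})
    #   while page.has_more:
    #       page = service.get_rag_pipeline_paginate_workflow_runs(
    #           pipeline, {"limit": 20, "last_id": page.data[-1].id}
    #       )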

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        db.session.commit()