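"""rag_pipeline.py

Service layer for Dify RAG pipeline workflows: pipeline template retrieval and
management, draft/published workflow access and synchronization, single-step
node and datasource node execution, and workflow-run pagination.
"""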
import json
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                # fall back to the built-in English templates when the requested language has none
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :return:
        """
        mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result
    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()
    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow
    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False
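        # fetch limit + 1 rows so we can tell whether a further page exists
        # without issuing a separate COUNT query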
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
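        # optimistic concurrency check: the caller passes the hash of the draft
        # it last saw; a mismatch means the draft changed in the meantime and
        # the sync is rejected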
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # TODO: trigger workflow events
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
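        # publishing snapshots the draft as a new immutable workflow row whose
        # version string is the current UTC timestamp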
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # add the new workflow to the session
        session.add(workflow)
        # propagate each knowledge-index node's configuration to the dataset
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config
    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution
    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.app_id = pipeline.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution
    def run_datasource_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
    ) -> dict:
        """
        Run a datasource node of the published workflow
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the published graph
        datasource_node_data = None
        start_at = time.perf_counter()
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        for key, value in datasource_parameters.items():
            if not user_inputs.get(key):
                user_inputs[key] = value["value"]

        # imported locally, presumably to avoid a circular import at module load time
        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
            datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
            online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [page.model_dump() for page in online_document_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
            datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
            website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                user_id=account.id,
                datasource_parameters=user_inputs,
                provider_type=datasource_runtime.datasource_provider_type(),
            )
            return {
                "result": [result.model_dump() for result in website_crawl_result.result],
                "provider_type": datasource_node_data.get("provider_type"),
            }
        else:
            raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")
    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free node that is not attached to any workflow
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution
    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: callable returning the node instance and its event generator
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
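            # a node that failed but continues on error is downgraded to
            # EXCEPTION, which still counts as a successful run for persistence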
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error

        return workflow_node_execution
    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
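        # only the display metadata fields may be updated; anything else in
        # `data` is silently ignored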
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow
    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # keep only variables that belong to this node or are shared
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") in (node_id, "shared")
        ]
        return datasource_provider_variables
    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # get the first step (datasource) node
        datasource_node_data = None
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        if datasource_parameters:
            datasource_parameters_map = {item["variable"]: item for item in datasource_parameters}
        else:
            datasource_parameters_map = {}
        variables = datasource_node_data.get("variables", {})
        user_input_variables = []
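        # a value of the form {{#node_id.variable#}} references another node's
        # output; anything else is a literal the user must supply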
        for key, value in variables.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(datasource_parameters_map.get(key, {}))
        return user_input_variables
    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # get the first step (datasource) node
        datasource_node_data = None
        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        if datasource_parameters:
            datasource_parameters_map = {item["variable"]: item for item in datasource_parameters}
        else:
            datasource_parameters_map = {}
        variables = datasource_node_data.get("variables", {})
        user_input_variables = []
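        # same literal-vs-reference filter as the published variant above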
        for key, value in variables.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(datasource_parameters_map.get(key, {}))
        return user_input_variables
    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # keep only variables that belong to this node or are shared
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") in (node_id, "shared")
        ]
        return datasource_provider_variables
    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list
        Only returns runs triggered from RAG pipeline run or debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
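        # keyset (cursor) pagination: `last_id` identifies the last run the
        # client has seen, and the next page is everything created before it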
  679. if args.get("last_id"):
  680. last_workflow_run = base_query.filter(
  681. WorkflowRun.id == args.get("last_id"),
  682. ).first()
  683. if not last_workflow_run:
  684. raise ValueError("Last workflow run not exists")
  685. workflow_runs = (
  686. base_query.filter(
  687. WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
  688. )
  689. .order_by(WorkflowRun.created_at.desc())
  690. .limit(limit)
  691. .all()
  692. )
  693. else:
  694. workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
  695. has_more = False
  696. if len(workflow_runs) == limit:
  697. current_page_first_workflow_run = workflow_runs[-1]
  698. rest_count = base_query.filter(
  699. WorkflowRun.created_at < current_page_first_workflow_run.created_at,
  700. WorkflowRun.id != current_page_first_workflow_run.id,
  701. ).count()
  702. if rest_count > 0:
  703. has_more = True
  704. return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run
    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        # Convert domain models to database models
        workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
        return workflow_node_executions
    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        db.session.commit()