You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

rag_pipeline.py 27KB

6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
6 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
6 miesięcy temu
5 miesięcy temu
5 miesięcy temu
5 miesięcy temu
5 miesięcy temu
5 miesięcy temu
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. import json
  2. import threading
  3. import time
  4. from collections.abc import Callable, Generator, Sequence
  5. from datetime import UTC, datetime
  6. from typing import Any, Optional
  7. from uuid import uuid4
  8. from flask_login import current_user
  9. from sqlalchemy import select
  10. from sqlalchemy.orm import Session
  11. import contexts
  12. from configs import dify_config
  13. from core.model_runtime.utils.encoders import jsonable_encoder
  14. from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
  15. from core.variables.variables import Variable
  16. from core.workflow.entities.node_entities import NodeRunResult
  17. from core.workflow.errors import WorkflowNodeRunFailedError
  18. from core.workflow.graph_engine.entities.event import InNodeEvent
  19. from core.workflow.nodes.base.node import BaseNode
  20. from core.workflow.nodes.enums import ErrorStrategy, NodeType
  21. from core.workflow.nodes.event.event import RunCompletedEvent
  22. from core.workflow.nodes.event.types import NodeEvent
  23. from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
  24. from core.workflow.repository.workflow_node_execution_repository import OrderConfig
  25. from core.workflow.workflow_entry import WorkflowEntry
  26. from extensions.ext_database import db
  27. from libs.infinite_scroll_pagination import InfiniteScrollPagination
  28. from models.account import Account
  29. from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
  30. from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
  31. from models.workflow import (
  32. Workflow,
  33. WorkflowNodeExecution,
  34. WorkflowNodeExecutionStatus,
  35. WorkflowNodeExecutionTriggeredFrom,
  36. WorkflowRun,
  37. WorkflowType,
  38. )
  39. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  40. from services.errors.app import WorkflowHashNotEqualError
  41. from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
  42. class RagPipelineService:
  43. @staticmethod
  44. def get_pipeline_templates(
  45. type: str = "built-in", language: str = "en-US"
  46. ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
  47. if type == "built-in":
  48. mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
  49. retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
  50. result = retrieval_instance.get_pipeline_templates(language)
  51. if not result.get("pipeline_templates") and language != "en-US":
  52. template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
  53. result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
  54. return result.get("pipeline_templates")
  55. else:
  56. mode = "customized"
  57. retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
  58. result = retrieval_instance.get_pipeline_templates(language)
  59. return result.get("pipeline_templates")
  60. @classmethod
  61. def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]:
  62. """
  63. Get pipeline template detail.
  64. :param pipeline_id: pipeline id
  65. :return:
  66. """
  67. mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
  68. retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
  69. result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id)
  70. return result
  71. @classmethod
  72. def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
  73. """
  74. Update pipeline template.
  75. :param template_id: template id
  76. :param template_info: template info
  77. """
  78. customized_template: PipelineCustomizedTemplate | None = (
  79. db.query(PipelineCustomizedTemplate)
  80. .filter(
  81. PipelineCustomizedTemplate.id == template_id,
  82. PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
  83. )
  84. .first()
  85. )
  86. if not customized_template:
  87. raise ValueError("Customized pipeline template not found.")
  88. customized_template.name = template_info.name
  89. customized_template.description = template_info.description
  90. customized_template.icon = template_info.icon_info.model_dump()
  91. db.commit()
  92. return customized_template
  93. @classmethod
  94. def delete_customized_pipeline_template(cls, template_id: str):
  95. """
  96. Delete customized pipeline template.
  97. """
  98. customized_template: PipelineCustomizedTemplate | None = (
  99. db.query(PipelineCustomizedTemplate)
  100. .filter(
  101. PipelineCustomizedTemplate.id == template_id,
  102. PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
  103. )
  104. .first()
  105. )
  106. if not customized_template:
  107. raise ValueError("Customized pipeline template not found.")
  108. db.delete(customized_template)
  109. db.commit()
  110. def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
  111. """
  112. Get draft workflow
  113. """
  114. # fetch draft workflow by rag pipeline
  115. workflow = (
  116. db.session.query(Workflow)
  117. .filter(
  118. Workflow.tenant_id == pipeline.tenant_id,
  119. Workflow.app_id == pipeline.id,
  120. Workflow.version == "draft",
  121. )
  122. .first()
  123. )
  124. # return draft workflow
  125. return workflow
  126. def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
  127. """
  128. Get published workflow
  129. """
  130. if not pipeline.workflow_id:
  131. return None
  132. # fetch published workflow by workflow_id
  133. workflow = (
  134. db.session.query(Workflow)
  135. .filter(
  136. Workflow.tenant_id == pipeline.tenant_id,
  137. Workflow.app_id == pipeline.id,
  138. Workflow.id == pipeline.workflow_id,
  139. )
  140. .first()
  141. )
  142. return workflow
  143. def get_all_published_workflow(
  144. self,
  145. *,
  146. session: Session,
  147. pipeline: Pipeline,
  148. page: int,
  149. limit: int,
  150. user_id: str | None,
  151. named_only: bool = False,
  152. ) -> tuple[Sequence[Workflow], bool]:
  153. """
  154. Get published workflow with pagination
  155. """
  156. if not pipeline.workflow_id:
  157. return [], False
  158. stmt = (
  159. select(Workflow)
  160. .where(Workflow.app_id == pipeline.id)
  161. .order_by(Workflow.version.desc())
  162. .limit(limit + 1)
  163. .offset((page - 1) * limit)
  164. )
  165. if user_id:
  166. stmt = stmt.where(Workflow.created_by == user_id)
  167. if named_only:
  168. stmt = stmt.where(Workflow.marked_name != "")
  169. workflows = session.scalars(stmt).all()
  170. has_more = len(workflows) > limit
  171. if has_more:
  172. workflows = workflows[:-1]
  173. return workflows, has_more
    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Create or update the pipeline's draft workflow from editor state.

        :param pipeline: pipeline owning the draft
        :param graph: workflow graph, persisted as JSON text
        :param unique_hash: hash of the draft the caller last loaded; used as an
            optimistic-concurrency check against the stored draft
        :param account: account performing the save (created_by / updated_by)
        :param environment_variables: environment variables to store on the draft
        :param conversation_variables: conversation variables to store on the draft
        :param rag_pipeline_variables: rag pipeline variables to store on the draft
        :return: the created or updated draft workflow
        :raises WorkflowHashNotEqualError: when the stored draft's hash differs
            from unique_hash (someone else saved in between)
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            # flush so workflow.id is assigned before linking it to the pipeline
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            # timestamps are stored naive in UTC, matching the rest of the file
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
  224. def publish_workflow(
  225. self,
  226. *,
  227. session: Session,
  228. pipeline: Pipeline,
  229. account: Account,
  230. marked_name: str = "",
  231. marked_comment: str = "",
  232. ) -> Workflow:
  233. draft_workflow_stmt = select(Workflow).where(
  234. Workflow.tenant_id == pipeline.tenant_id,
  235. Workflow.app_id == pipeline.id,
  236. Workflow.version == "draft",
  237. )
  238. draft_workflow = session.scalar(draft_workflow_stmt)
  239. if not draft_workflow:
  240. raise ValueError("No valid workflow found.")
  241. # create new workflow
  242. workflow = Workflow.new(
  243. tenant_id=pipeline.tenant_id,
  244. app_id=pipeline.id,
  245. type=draft_workflow.type,
  246. version=str(datetime.now(UTC).replace(tzinfo=None)),
  247. graph=draft_workflow.graph,
  248. features=draft_workflow.features,
  249. created_by=account.id,
  250. environment_variables=draft_workflow.environment_variables,
  251. conversation_variables=draft_workflow.conversation_variables,
  252. marked_name=marked_name,
  253. marked_comment=marked_comment,
  254. )
  255. # commit db session changes
  256. session.add(workflow)
  257. # trigger app workflow events TODO
  258. # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
  259. # return new workflow
  260. return workflow
  261. def get_default_block_configs(self) -> list[dict]:
  262. """
  263. Get default block configs
  264. """
  265. # return default block config
  266. default_block_configs = []
  267. for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
  268. node_class = node_class_mapping[LATEST_VERSION]
  269. default_config = node_class.get_default_config()
  270. if default_config:
  271. default_block_configs.append(default_config)
  272. return default_block_configs
  273. def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
  274. """
  275. Get default config of node.
  276. :param node_type: node type
  277. :param filters: filter by node config parameters.
  278. :return:
  279. """
  280. node_type_enum = NodeType(node_type)
  281. # return default block config
  282. if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
  283. return None
  284. node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
  285. default_config = node_class.get_default_config(filters=filters)
  286. if not default_config:
  287. return None
  288. return default_config
  289. def run_draft_workflow_node(
  290. self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
  291. ) -> WorkflowNodeExecution:
  292. """
  293. Run draft workflow node
  294. """
  295. # fetch draft workflow by app_model
  296. draft_workflow = self.get_draft_workflow(pipeline=pipeline)
  297. if not draft_workflow:
  298. raise ValueError("Workflow not initialized")
  299. # run draft workflow node
  300. start_at = time.perf_counter()
  301. workflow_node_execution = self._handle_node_run_result(
  302. getter=lambda: WorkflowEntry.single_step_run(
  303. workflow=draft_workflow,
  304. node_id=node_id,
  305. user_inputs=user_inputs,
  306. user_id=account.id,
  307. ),
  308. start_at=start_at,
  309. tenant_id=pipeline.tenant_id,
  310. node_id=node_id,
  311. )
  312. workflow_node_execution.app_id = pipeline.id
  313. workflow_node_execution.created_by = account.id
  314. workflow_node_execution.workflow_id = draft_workflow.id
  315. db.session.add(workflow_node_execution)
  316. db.session.commit()
  317. return workflow_node_execution
  318. def run_published_workflow_node(
  319. self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
  320. ) -> WorkflowNodeExecution:
  321. """
  322. Run published workflow node
  323. """
  324. # fetch published workflow by app_model
  325. published_workflow = self.get_published_workflow(pipeline=pipeline)
  326. if not published_workflow:
  327. raise ValueError("Workflow not initialized")
  328. # run draft workflow node
  329. start_at = time.perf_counter()
  330. workflow_node_execution = self._handle_node_run_result(
  331. getter=lambda: WorkflowEntry.single_step_run(
  332. workflow=published_workflow,
  333. node_id=node_id,
  334. user_inputs=user_inputs,
  335. user_id=account.id,
  336. ),
  337. start_at=start_at,
  338. tenant_id=pipeline.tenant_id,
  339. node_id=node_id,
  340. )
  341. workflow_node_execution.app_id = pipeline.id
  342. workflow_node_execution.created_by = account.id
  343. workflow_node_execution.workflow_id = published_workflow.id
  344. db.session.add(workflow_node_execution)
  345. db.session.commit()
  346. return workflow_node_execution
  347. def run_datasource_workflow_node(
  348. self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
  349. ) -> WorkflowNodeExecution:
  350. """
  351. Run published workflow datasource
  352. """
  353. # fetch published workflow by app_model
  354. published_workflow = self.get_published_workflow(pipeline=pipeline)
  355. if not published_workflow:
  356. raise ValueError("Workflow not initialized")
  357. # run draft workflow node
  358. start_at = time.perf_counter()
  359. datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
  360. if not datasource_node_data:
  361. raise ValueError("Datasource node data not found")
  362. from core.datasource.datasource_manager import DatasourceManager
  363. datasource_runtime = DatasourceManager.get_datasource_runtime(
  364. provider_id=datasource_node_data.get("provider_id"),
  365. datasource_name=datasource_node_data.get("datasource_name"),
  366. tenant_id=pipeline.tenant_id,
  367. )
  368. result = datasource_runtime._invoke_first_step(
  369. inputs=user_inputs,
  370. provider_type=datasource_node_data.get("provider_type"),
  371. user_id=account.id,
  372. )
  373. return {
  374. "result": result,
  375. "provider_type": datasource_node_data.get("provider_type"),
  376. }
  377. def run_free_workflow_node(
  378. self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
  379. ) -> WorkflowNodeExecution:
  380. """
  381. Run draft workflow node
  382. """
  383. # run draft workflow node
  384. start_at = time.perf_counter()
  385. workflow_node_execution = self._handle_node_run_result(
  386. getter=lambda: WorkflowEntry.run_free_node(
  387. node_id=node_id,
  388. node_data=node_data,
  389. tenant_id=tenant_id,
  390. user_id=user_id,
  391. user_inputs=user_inputs,
  392. ),
  393. start_at=start_at,
  394. tenant_id=tenant_id,
  395. node_id=node_id,
  396. )
  397. return workflow_node_execution
    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Run a single node via *getter* and convert the outcome into a
        WorkflowNodeExecution record (not persisted here).

        :param getter: callable that starts the node run and returns the node
            instance plus its event generator
        :param start_at: perf_counter() timestamp taken before the run, used to
            compute elapsed_time
        :param tenant_id: tenant id stamped on the execution record
        :param node_id: node id stamped on the execution record
        :return: populated WorkflowNodeExecution (caller adds/commits it)
        :raises ValueError: when the generator finishes without a RunCompletedEvent
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            # Only the terminal RunCompletedEvent carries the run result;
            # everything before it is streamed progress and is ignored here.
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            # A FAILED result on a node configured to continue-on-error is
            # downgraded to EXCEPTION with synthetic outputs describing the error.
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    # DEFAULT_VALUE strategy: outputs are the node's configured
                    # defaults plus the error details.
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    # Other strategies: outputs carry only the error details.
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            # EXCEPTION (continue-on-error) still counts as a completed run for
            # persistence purposes; only FAILED carries an error through.
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            # The run itself blew up: record the failure with no run result.
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error
        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
        # timestamps are stored naive in UTC, matching the rest of the file
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
        if run_succeeded and node_run_result:
            # create workflow node execution
            # Inputs/process_data/outputs are normalized (file signing etc.)
            # and serialized as JSON text columns.
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error
        return workflow_node_execution
  495. def update_workflow(
  496. self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
  497. ) -> Optional[Workflow]:
  498. """
  499. Update workflow attributes
  500. :param session: SQLAlchemy database session
  501. :param workflow_id: Workflow ID
  502. :param tenant_id: Tenant ID
  503. :param account_id: Account ID (for permission check)
  504. :param data: Dictionary containing fields to update
  505. :return: Updated workflow or None if not found
  506. """
  507. stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
  508. workflow = session.scalar(stmt)
  509. if not workflow:
  510. return None
  511. allowed_fields = ["marked_name", "marked_comment"]
  512. for field, value in data.items():
  513. if field in allowed_fields:
  514. setattr(workflow, field, value)
  515. workflow.updated_by = account_id
  516. workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
  517. return workflow
  518. def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
  519. """
  520. Get second step parameters of rag pipeline
  521. """
  522. workflow = self.get_published_workflow(pipeline=pipeline)
  523. if not workflow:
  524. raise ValueError("Workflow not initialized")
  525. # get second step node
  526. rag_pipeline_variables = workflow.rag_pipeline_variables
  527. if not rag_pipeline_variables:
  528. return {}
  529. # get datasource provider
  530. datasource_provider_variables = [
  531. item
  532. for item in rag_pipeline_variables
  533. if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
  534. ]
  535. return datasource_provider_variables
  536. def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
  537. """
  538. Get second step parameters of rag pipeline
  539. """
  540. workflow = self.get_draft_workflow(pipeline=pipeline)
  541. if not workflow:
  542. raise ValueError("Workflow not initialized")
  543. # get second step node
  544. rag_pipeline_variables = workflow.rag_pipeline_variables
  545. if not rag_pipeline_variables:
  546. return {}
  547. # get datasource provider
  548. datasource_provider_variables = [
  549. item
  550. for item in rag_pipeline_variables
  551. if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
  552. ]
  553. return datasource_provider_variables
  554. def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
  555. """
  556. Get debug workflow run list
  557. Only return triggered_from == debugging
  558. :param app_model: app model
  559. :param args: request args
  560. """
  561. limit = int(args.get("limit", 20))
  562. base_query = db.session.query(WorkflowRun).filter(
  563. WorkflowRun.tenant_id == pipeline.tenant_id,
  564. WorkflowRun.app_id == pipeline.id,
  565. WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
  566. )
  567. if args.get("last_id"):
  568. last_workflow_run = base_query.filter(
  569. WorkflowRun.id == args.get("last_id"),
  570. ).first()
  571. if not last_workflow_run:
  572. raise ValueError("Last workflow run not exists")
  573. workflow_runs = (
  574. base_query.filter(
  575. WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
  576. )
  577. .order_by(WorkflowRun.created_at.desc())
  578. .limit(limit)
  579. .all()
  580. )
  581. else:
  582. workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
  583. has_more = False
  584. if len(workflow_runs) == limit:
  585. current_page_first_workflow_run = workflow_runs[-1]
  586. rest_count = base_query.filter(
  587. WorkflowRun.created_at < current_page_first_workflow_run.created_at,
  588. WorkflowRun.id != current_page_first_workflow_run.id,
  589. ).count()
  590. if rest_count > 0:
  591. has_more = True
  592. return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
  593. def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
  594. """
  595. Get workflow run detail
  596. :param app_model: app model
  597. :param run_id: workflow run id
  598. """
  599. workflow_run = (
  600. db.session.query(WorkflowRun)
  601. .filter(
  602. WorkflowRun.tenant_id == pipeline.tenant_id,
  603. WorkflowRun.app_id == pipeline.id,
  604. WorkflowRun.id == run_id,
  605. )
  606. .first()
  607. )
  608. return workflow_run
  609. def get_rag_pipeline_workflow_run_node_executions(
  610. self,
  611. pipeline: Pipeline,
  612. run_id: str,
  613. ) -> list[WorkflowNodeExecution]:
  614. """
  615. Get workflow run node execution list
  616. """
  617. workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
  618. contexts.plugin_tool_providers.set({})
  619. contexts.plugin_tool_providers_lock.set(threading.Lock())
  620. if not workflow_run:
  621. return []
  622. # Use the repository to get the node execution
  623. repository = SQLAlchemyWorkflowNodeExecutionRepository(
  624. session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
  625. )
  626. # Use the repository to get the node executions with ordering
  627. order_config = OrderConfig(order_by=["index"], order_direction="desc")
  628. node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
  629. return list(node_executions)