rag_pipeline.py

import json
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template type, "built-in" or "customized"
        :return: template detail dict, or None if not found
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        else:
            mode = "customized"
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
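
    # NOTE: the query above fetches `limit + 1` rows on purpose: if the extra row
    # comes back, at least one more page exists, so `has_more` is set and the
    # sentinel row is dropped before returning. This avoids a separate COUNT query.
    # A minimal sketch of the same technique (names illustrative only):
    #
    #     rows = session.scalars(select(Workflow).limit(limit + 1)).all()
    #     has_more = len(rows) > limit
    #     page_rows = rows[:limit]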

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
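
    # The `unique_hash` check above is an optimistic-concurrency guard: the client
    # sends the hash of the draft it last saw, and if another editor synced a newer
    # draft in the meantime the hashes differ and WorkflowHashNotEqualError is
    # raised instead of silently overwriting their changes. A typical caller
    # (hypothetical sketch) would catch it and re-fetch the draft:
    #
    #     try:
    #         service.sync_draft_workflow(
    #             pipeline=pipeline, graph=graph, unique_hash=client_hash,
    #             account=account, environment_variables=[],
    #             conversation_variables=[], rag_pipeline_variables=[],
    #         )
    #     except WorkflowHashNotEqualError:
    #         draft = service.get_draft_workflow(pipeline=pipeline)  # reload, then retry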

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
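
    # Publishing snapshots the draft into an immutable copy whose `version` is the
    # publish timestamp, then propagates the `knowledge-index` node's settings onto
    # the backing dataset via DatasetService, so retrieval configuration stays in
    # sync with what the published pipeline will actually index.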

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config
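
    # NODE_TYPE_CLASSES_MAPPING is keyed by NodeType and versioned per node class,
    # so LATEST_VERSION picks the newest implementation of each node. A hypothetical
    # call (the node-type strings are assumptions, not confirmed by this file):
    #
    #     config = service.get_default_block_config("code")
    #     config = service.get_default_block_config("http-request", filters={"mode": "simple"})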

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_datasource_workflow_node_status(
        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
    ) -> dict:
        """
        Get the status of a datasource job for a workflow datasource node
        """
        if is_published:
            # fetch published workflow by app_model
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the workflow graph
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
            provider=datasource_node_data.get("provider_name"),
            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                    user_id=account.id,
                    datasource_parameters={"job_id": job_id},
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [result.model_dump() for result in website_crawl_result.result.web_info_list]
                    if website_crawl_result.result.web_info_list
                    else [],
                    "job_id": website_crawl_result.result.job_id,
                    "status": website_crawl_result.result.status,
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case _:
                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
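
    # This method is the polling half of a two-step crawl: run_datasource_workflow_node
    # (below) starts a crawl and returns a job_id, and callers then poll this method
    # with that job_id until `status` reports completion. The response shape mirrors
    # the start call so the client can render partial results while the crawl runs.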

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> dict:
        """
        Run a workflow datasource node
        """
        if is_published:
            # fetch published workflow by app_model
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # locate the datasource node in the workflow graph
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        # fill missing user inputs from the node's configured defaults
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        for key, value in datasource_parameters.items():
            if not user_inputs.get(key):
                user_inputs[key] = value["value"]
        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
            provider=datasource_node_data.get("provider_name"),
            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.ONLINE_DOCUMENT:
                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
                    user_id=account.id,
                    datasource_parameters=user_inputs,
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [page.model_dump() for page in online_document_result.result],
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                    user_id=account.id,
                    datasource_parameters=user_inputs,
                    provider_type=datasource_runtime.datasource_provider_type(),
                )
                return {
                    "result": [result.model_dump() for result in website_crawl_result.result.web_info_list]
                    if website_crawl_result.result.web_info_list
                    else [],
                    "job_id": website_crawl_result.result.job_id,
                    "status": website_crawl_result.result.status,
                    "provider_type": datasource_node_data.get("provider_type"),
                }
            case _:
                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
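
    # The runtime is resolved from "<plugin_id>/<provider_name>" and then dispatched
    # on the datasource type: ONLINE_DOCUMENT fetches document pages synchronously,
    # while WEBSITE_CRAWL starts an asynchronous crawl job (poll it with
    # run_datasource_workflow_node_status above). Stored credentials, when present,
    # are injected into the runtime before the call.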

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a standalone (free) workflow node, outside any saved workflow
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error
        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.node_data.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # mark the execution as failed
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution
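
    # Status mapping for single-step runs: a FAILED result on a node configured to
    # continue on error is downgraded to EXCEPTION, with outputs rebuilt per the
    # node's ErrorStrategy (DEFAULT_VALUE merges in the configured defaults; any
    # other strategy keeps only error_message/error_type). EXCEPTION therefore
    # counts as "succeeded" for persistence, while a hard FAILED on a published
    # run also flips the source Document's indexing_status to "error".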

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider variables
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # get the datasource node
        datasource_node_data = None
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables
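
    # The regex above matches Dify's variable-reference template syntax,
    # e.g. "{{#node_id.output_field#}}" or "{{#sys.document_id#}}" (segment names
    # here are illustrative). Parameters whose configured value is such a reference
    # are filled from upstream nodes at runtime; only parameters holding plain
    # literal values are surfaced as user inputs. The same filter is applied in the
    # draft variant below.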

    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # get the datasource node
        datasource_node_data = None
        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # get datasource provider variables
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list for the rag pipeline.
        Only returns runs triggered from a rag pipeline run or debugging session.
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at,
                    WorkflowRun.id != last_workflow_run.id,
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
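
    # Cursor-style ("infinite scroll") pagination: instead of page offsets, the
    # client passes the id of the last run it saw, and the next page is everything
    # created strictly before that run. A follow-up COUNT decides has_more.
    # Hypothetical call from a controller:
    #
    #     page = service.get_rag_pipeline_paginate_workflow_runs(
    #         pipeline, {"limit": 20, "last_id": previous_page.data[-1].id}
    #     )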

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node executions with ordering
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)
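
    # Node executions are read through SQLAlchemyWorkflowNodeExecutionRepository
    # rather than via direct queries, so the ordering/filtering contract
    # (OrderConfig, triggered_from) lives in one place. The plugin_tool_providers
    # context vars are primed first, presumably because deserializing execution
    # data can resolve plugin tool providers.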

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Dataset not found")
        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )
        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()