# rag_pipeline.py

import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import (
    DatasourceCompletedEvent,
    DatasourceErrorEvent,
    DatasourceProcessingEvent,
)
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.node_events.base import NodeRunResult
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader

logger = logging.getLogger(__name__)


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result
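
    # Illustrative usage (a sketch; values are hypothetical): built-in
    # templates fall back to "en-US" when the requested language has none,
    # while customized templates are tenant-scoped.
    #
    #     templates = RagPipelineService.get_pipeline_templates(
    #         type="customized", language="en-US"
    #     )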

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template source, "built-in" or "customized"
        :return: template detail dict, or None if not found
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
            return built_in_result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
            return customized_result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        # check whether the template name already exists
        template_name = template_info.name
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                    PipelineCustomizedTemplate.id != template_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        # fetch one extra row to detect whether another page exists
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
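
    # Illustrative usage (assumes an open SQLAlchemy `session` and a loaded
    # `pipeline`; names are hypothetical):
    #
    #     service = RagPipelineService()
    #     workflows, has_more = service.get_all_published_workflow(
    #         session=session, pipeline=pipeline, page=1, limit=20, user_id=None
    #     )
    #     # `has_more` is derived from the limit+1 over-fetch above.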

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
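
    # Illustrative usage (hypothetical names): the `unique_hash` guard acts as
    # optimistic concurrency control. Pass the hash of the draft you loaded;
    # a stale editor gets WorkflowHashNotEqualError instead of silently
    # overwriting someone else's changes.
    #
    #     try:
    #         workflow = service.sync_draft_workflow(
    #             pipeline=pipeline,
    #             graph=new_graph,
    #             unique_hash=loaded_draft_hash,
    #             account=account,
    #             environment_variables=[],
    #             conversation_variables=[],
    #             rag_pipeline_variables=[],
    #         )
    #     except WorkflowHashNotEqualError:
    #         ...  # reload the draft and retry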

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
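
    # Illustrative publish flow (session handling is hypothetical):
    # publishing snapshots the draft into a timestamp-versioned workflow and
    # pushes any "knowledge-index" node settings down to the backing dataset.
    #
    #     with Session(db.engine) as session, session.begin():
    #         published = service.publish_workflow(
    #             session=session, pipeline=pipeline, account=account
    #         )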

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config
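
    # Illustrative lookup (the node type value is hypothetical; any NodeType
    # member's value works):
    #
    #     config = service.get_default_block_config("code")
    #     # -> default config dict, or None for types without one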

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecutionModel:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
                variable_pool=VariablePool(
                    system_variables=SystemVariable.empty(),
                    user_inputs=user_inputs,
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=account,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution)
        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model
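
    # Illustrative single-step debug call (inputs are hypothetical): runs one
    # node of the draft workflow with an empty system-variable pool, persists
    # the execution record, and saves the node's outputs as draft variables.
    #
    #     execution = service.run_draft_workflow_node(
    #         pipeline=pipeline,
    #         node_id="node-1",
    #         user_inputs={"query": "hello"},
    #         account=account,
    #     )
    #     print(execution.status)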

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Generator[Mapping[str, Any], None, None]:
        """
        Run a datasource node of the draft or published workflow
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            variables_map = {}
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        if last_part in user_inputs:
                            variables_map[key] = user_inputs[last_part]
                        else:
                            variables_map[key] = value["value"]
                    else:
                        variables_map[key] = value["value"]
                else:
                    variables_map[key] = value["value"]
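
            # The pattern above matches workflow variable references of the
            # form {{#node_id.variable#}} (one or more dotted segments after
            # the first). Only the last path segment is compared against
            # `user_inputs`: e.g. a parameter value of "{{#start.url#}}" is
            # overridden by user_inputs["url"] when present; anything else
            # falls through unchanged. (Example values are hypothetical.)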

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                        datasource_runtime.get_online_document_pages(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    try:
                        for message in online_document_result:
                            end_time = time.time()
                            online_document_event = DatasourceCompletedEvent(
                                data=message.result, time_consuming=round(end_time - start_time, 2)
                            )
                            yield online_document_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during online document.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
                        datasource_runtime.online_drive_browse_files(
                            user_id=account.id,
                            request=OnlineDriveBrowseFilesRequest(
                                bucket=user_inputs.get("bucket"),
                                prefix=user_inputs.get("prefix", ""),
                                max_keys=user_inputs.get("max_keys", 20),
                                next_page_parameters=user_inputs.get("next_page_parameters"),
                            ),
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    for message in online_drive_result:
                        end_time = time.time()
                        online_drive_event = DatasourceCompletedEvent(
                            data=message.result,
                            time_consuming=round(end_time - start_time, 2),
                            total=None,
                            completed=None,
                        )
                        yield online_drive_event.model_dump()
                case DatasourceProviderType.WEBSITE_CRAWL:
                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                        datasource_runtime.get_website_crawl(
                            user_id=account.id,
                            datasource_parameters=variables_map,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    try:
                        for message in website_crawl_result:
                            end_time = time.time()
                            if message.result.status == "completed":
                                crawl_event = DatasourceCompletedEvent(
                                    data=message.result.web_info_list or [],
                                    total=message.result.total,
                                    completed=message.result.completed,
                                    time_consuming=round(end_time - start_time, 2),
                                )
                            else:
                                crawl_event = DatasourceProcessingEvent(
                                    total=message.result.total,
                                    completed=message.result.completed,
                                )
                            yield crawl_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during website crawl.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_workflow_node.")
            yield DatasourceErrorEvent(error=str(e)).model_dump()
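
    # Illustrative consumption of the event stream (setup and the
    # datasource_type string are hypothetical): each yielded item is a plain
    # dict produced by `model_dump()`, suitable for serializing to the client
    # as-is.
    #
    #     for event in service.run_datasource_workflow_node(
    #         pipeline=pipeline,
    #         node_id="datasource-node-id",
    #         user_inputs={"url": "https://example.com"},
    #         account=account,
    #         datasource_type="website_crawl",
    #         is_published=False,
    #     ):
    #         print(event)  # processing / completed / error payloads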

    def run_datasource_node_preview(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Mapping[str, Any]:
        """
        Preview the content of a datasource node
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=account.id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=user_inputs.get("workspace_id", ""),
                                page_id=user_inputs.get("page_id", ""),
                                type=user_inputs.get("type", ""),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    try:
                        variables: dict[str, Any] = {}
                        for message in online_document_result:
                            if message.type == DatasourceMessage.MessageType.VARIABLE:
                                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                                variable_name = message.message.variable_name
                                variable_value = message.message.variable_value
                                if message.message.stream:
                                    # streamed variables arrive as string chunks and are concatenated
                                    if not isinstance(variable_value, str):
                                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                                    if variable_name not in variables:
                                        variables[variable_name] = ""
                                    variables[variable_name] += variable_value
                                else:
                                    variables[variable_name] = variable_value
                        return variables
                    except Exception as e:
                        logger.exception("Error during get online document content.")
                        raise RuntimeError(str(e))
                # TODO Online Drive
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_node_preview.")
            raise RuntimeError(str(e))

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a standalone workflow node that is not attached to a draft workflow
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, StreamCompletedEvent):
                    node_run_result = event.node_run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.error_strategy},
                }
                if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e._node
            run_succeeded = False
            node_run_result = None
            error = e._error
        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution
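
    # Note on the error handling above: a FAILED node with an `error_strategy`
    # configured is downgraded to EXCEPTION status. With DEFAULT_VALUE the
    # node's default outputs are returned alongside `error_message` and
    # `error_type`, so single-step debugging mirrors what a full run would
    # produce on the error branch.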

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow
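
    # Illustrative usage (IDs are hypothetical): only `marked_name` and
    # `marked_comment` are writable through this method; other keys in `data`
    # are silently ignored.
    #
    #     with Session(db.engine) as session, session.begin():
    #         service.update_workflow(
    #             session=session,
    #             workflow_id=workflow_id,
    #             tenant_id=tenant_id,
    #             account_id=account.id,
    #             data={"marked_name": "v1 stable"},
    #         )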

    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = workflow.rag_pipeline_variables
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if value.get("value") and isinstance(value.get("value"), str):
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                match = re.match(pattern, value["value"])
                if match:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables.append(variables_map.get(last_part, {}))
        return user_input_variables

    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
        # get datasource node data
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if datasource_node_data:
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        # variables consumed by the datasource (first step) are
                        # excluded; the default guards against references to
                        # variables that no longer exist
                        variables_map.pop(last_part, None)
        all_second_step_variables = list(variables_map.values())
        datasource_provider_variables = [
            item
            for item in all_second_step_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables
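
    # The two methods above split the published form into two screens: the
    # first step exposes only the rag_pipeline_variables referenced by the
    # datasource node's own parameters, while the second step returns the
    # remaining variables belonging to this node (or marked "shared"), with
    # first-step variables removed so nothing is asked for twice.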

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list
        Only returns runs triggered from rag pipeline run or debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
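
    # Illustrative cursor-style pagination (args are hypothetical): pass the
    # id of the last run from the previous page to fetch the next page.
    #
    #     page1 = service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 20})
    #     if page1.has_more:
    #         page2 = service.get_rag_pipeline_paginate_workflow_runs(
    #             pipeline, {"limit": 20, "last_id": page1.data[-1].id}
    #         )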

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node executions with ordering
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Dataset not found")
        # check whether the template name already exists
        template_name = args.get("name")
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )

        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()

    def is_workflow_exist(self, pipeline: Pipeline) -> bool:
        return (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .count()
        ) > 0

    def get_node_last_run(
        self, pipeline: Pipeline, workflow: Workflow, node_id: str
    ) -> WorkflowNodeExecutionModel | None:
        # TODO(QuantumGhost): This query is not fully covered by index.
        criteria = (
            WorkflowNodeExecutionModel.tenant_id == pipeline.tenant_id,
            WorkflowNodeExecutionModel.app_id == pipeline.id,
            WorkflowNodeExecutionModel.workflow_id == workflow.id,
            WorkflowNodeExecutionModel.node_id == node_id,
        )
        node_exec = (
            db.session.query(WorkflowNodeExecutionModel)
            .filter(*criteria)
            .order_by(WorkflowNodeExecutionModel.created_at.desc())
            .first()
        )
        return node_exec

    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser):
        """
        Set datasource variables
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        node_id = args.get("start_node_id")
        if not node_id:
            raise ValueError("Node id is required")
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None
        system_inputs = SystemVariable(
            datasource_type=args.get("datasource_type", "online_document"),
            datasource_info=args.get("datasource_info", {}),
        )
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs={},
                user_id=current_user.id,
                variable_pool=VariablePool(
                    system_variables=system_inputs,
                    user_inputs={},
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=current_user,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution)
        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model