# rag_pipeline.py

import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, sessionmaker

import contexts
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.helper import marketplace
from core.rag.entities.event import (
    DatasourceCompletedEvent,
    DatasourceErrorEvent,
    DatasourceProcessingEvent,
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.node_events.base import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import (  # type: ignore
    Dataset,
    Document,
    DocumentPipelineExecutionLog,
    Pipeline,
    PipelineCustomizedTemplate,
    PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader

logger = logging.getLogger(__name__)


class RagPipelineService:
    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize RagPipelineService with repository dependencies."""
        if session_maker is None:
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )

    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template type, "built-in" or customized
        :return:
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
            return built_in_result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
            return customized_result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .where(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")

        # check whether the template name already exists
        template_name = template_info.name
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .where(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                    PipelineCustomizedTemplate.id != template_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")

        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .where(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None

        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()

        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
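
    # Illustrative usage of the pagination above (the `service`, `pipeline`,
    # and session objects are placeholders): the query fetches limit + 1 rows,
    # and the extra row only signals has_more before being dropped.
    #
    #     with Session(db.engine) as session:
    #         workflows, has_more = service.get_all_published_workflow(
    #             session=session, pipeline=pipeline, page=1, limit=20, user_id=None
    #         )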

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: str | None,
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables

        # commit db session changes
        db.session.commit()

        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)

        # return draft workflow
        return workflow

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)

        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])

        from services.dataset_service import DatasetService

        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

                # update dataset
                dataset = pipeline.retrieve_dataset(session=session)
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block configs
        default_block_configs: list[dict[str, Any]] = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(dict(default_config))
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)

        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecutionModel | None:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        start_at = time.perf_counter()
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
                variable_pool=VariablePool(
                    system_variables=SystemVariable.empty(),
                    user_inputs=user_inputs,
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id

        # Create repository and save the node execution
        repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=db.engine,
            user=account,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)

        # Fetch the persisted DB model of the node execution after save
        workflow_node_execution_db_model = self._node_execution_service_repo.get_execution_by_id(
            workflow_node_execution.id
        )

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution.node_id,
                node_type=NodeType(workflow_node_execution.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
                user=account,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: str | None = None,
    ) -> Generator[Mapping[str, Any], None, None]:
        """
        Run the datasource node of a published (or draft) workflow and stream datasource events.
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")

            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            variables_map = {}
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                param_value = value.get("value")
                if not param_value:
                    variables_map[key] = param_value
                elif isinstance(param_value, str):
                    # handle string type parameter value, check if it contains a variable reference pattern
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, param_value)
                    if match:
                        # extract the variable path and try to get the value from user inputs
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        variables_map[key] = user_inputs.get(last_part, param_value)
                    else:
                        variables_map[key] = param_value
                elif isinstance(param_value, list) and param_value:
                    # handle list type parameter value, check if the last element is in user inputs
                    last_part = param_value[-1]
                    variables_map[key] = user_inputs.get(last_part, param_value)
                else:
                    # other types use the original value directly
                    variables_map[key] = param_value
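
            # Illustrative example of the mapping above (values are made up):
            # a parameter configured as {"url": {"value": "{{#start.url#}}"}}
            # matches the variable-reference pattern, so variables_map["url"]
            # is looked up as user_inputs["url"], falling back to the raw
            # template string when the user supplied no value.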

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                        datasource_runtime.get_online_document_pages(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    try:
                        for online_document_message in online_document_result:
                            end_time = time.time()
                            online_document_event = DatasourceCompletedEvent(
                                data=online_document_message.result, time_consuming=round(end_time - start_time, 2)
                            )
                            yield online_document_event.model_dump()
                    except Exception as e:
                        logger.exception("Error while fetching online document pages.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
                        datasource_runtime.online_drive_browse_files(
                            user_id=account.id,
                            request=OnlineDriveBrowseFilesRequest(
                                bucket=user_inputs.get("bucket"),
                                prefix=user_inputs.get("prefix", ""),
                                max_keys=user_inputs.get("max_keys", 20),
                                next_page_parameters=user_inputs.get("next_page_parameters"),
                            ),
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    for online_drive_message in online_drive_result:
                        end_time = time.time()
                        online_drive_event = DatasourceCompletedEvent(
                            data=online_drive_message.result,
                            time_consuming=round(end_time - start_time, 2),
                            total=None,
                            completed=None,
                        )
                        yield online_drive_event.model_dump()
                case DatasourceProviderType.WEBSITE_CRAWL:
                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                        datasource_runtime.get_website_crawl(
                            user_id=account.id,
                            datasource_parameters=variables_map,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    try:
                        for website_crawl_message in website_crawl_result:
                            end_time = time.time()
                            crawl_event: DatasourceCompletedEvent | DatasourceProcessingEvent
                            if website_crawl_message.result.status == "completed":
                                crawl_event = DatasourceCompletedEvent(
                                    data=website_crawl_message.result.web_info_list or [],
                                    total=website_crawl_message.result.total,
                                    completed=website_crawl_message.result.completed,
                                    time_consuming=round(end_time - start_time, 2),
                                )
                            else:
                                crawl_event = DatasourceProcessingEvent(
                                    total=website_crawl_message.result.total,
                                    completed=website_crawl_message.result.completed,
                                )
                            yield crawl_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during website crawl.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_workflow_node.")
            yield DatasourceErrorEvent(error=str(e)).model_dump()

    def run_datasource_node_preview(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: str | None = None,
    ) -> Mapping[str, Any]:
        """
        Preview the content returned by a datasource node of a published (or draft) workflow.
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")

            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=account.id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=user_inputs.get("workspace_id", ""),
                                page_id=user_inputs.get("page_id", ""),
                                type=user_inputs.get("type", ""),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    try:
                        variables: dict[str, Any] = {}
                        for online_document_message in online_document_result:
                            if online_document_message.type == DatasourceMessage.MessageType.VARIABLE:
                                assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage)
                                variable_name = online_document_message.message.variable_name
                                variable_value = online_document_message.message.variable_value
                                if online_document_message.message.stream:
                                    if not isinstance(variable_value, str):
                                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                                    if variable_name not in variables:
                                        variables[variable_name] = ""
                                    variables[variable_name] += variable_value
                                else:
                                    variables[variable_name] = variable_value
                        return variables
                    except Exception as e:
                        logger.exception("Error while getting online document content.")
                        raise RuntimeError(str(e)) from e
                # TODO: support Online Drive preview
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_node_preview.")
            raise RuntimeError(str(e)) from e

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free workflow node (one that is not attached to a workflow).
        """
        # run free workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: callable returning the node instance and its event generator
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
                    node_run_result = event.node_run_result
                    if node_run_result:
                        # sign output files
                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")

            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.error_strategy},
                }
                if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e._node  # type: ignore
            run_succeeded = False
            node_run_result = None
            error = e._error  # type: ignore

        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None

            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).where(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Workflow | None:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None

        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow

    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")

        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        variables = workflow.rag_pipeline_variables
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []

        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables_keys = []
        user_input_variables = []
        for _, value in datasource_parameters.items():
            if value.get("value") and isinstance(value.get("value"), str):
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                match = re.match(pattern, value["value"])
                if match:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables_keys.append(last_part)
            elif value.get("value") and isinstance(value.get("value"), list):
                last_part = value.get("value")[-1]
                user_input_variables_keys.append(last_part)
        for key, value in variables_map.items():
            if key in user_input_variables_keys:
                user_input_variables.append(value)
        return user_input_variables

    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        variables_map = {item["variable"]: item for item in rag_pipeline_variables}

        # get datasource node data
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if datasource_node_data:
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for _, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        variables_map.pop(last_part, None)
                elif value.get("value") and isinstance(value.get("value"), list):
                    last_part = value.get("value")[-1]
                    variables_map.pop(last_part, None)
        all_second_step_variables = list(variables_map.values())
        datasource_provider_variables = [
            item
            for item in all_second_step_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables
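
    # Illustrative example of the first/second step split (names are made up):
    # if the datasource node references {{#start.url#}}, then "url" is a
    # first-step input and is popped from variables_map above, so only the
    # remaining variables that belong to this node (or are "shared") come
    # back as second-step parameters.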

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get the workflow run list for a rag pipeline.
        Only returns runs triggered from rag_pipeline_run or rag_pipeline_debugging.
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))

        base_query = db.session.query(WorkflowRun).where(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.where(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.where(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

        has_more = False
        if len(workflow_runs) == limit:
            current_page_last_workflow_run = workflow_runs[-1]
            rest_count = base_query.where(
                WorkflowRun.created_at < current_page_last_workflow_run.created_at,
                WorkflowRun.id != current_page_last_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
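
    # Note: the pagination above is cursor-based (keyset) rather than
    # offset-based; the created_at of the last row on the current page acts
    # as the cursor for the next request, which keeps pages stable while new
    # workflow runs are being inserted.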

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .where(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []

        # Use the repository to get the node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )

        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["created_at"], order_direction="asc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")

        with Session(db.engine) as session:
            dataset = pipeline.retrieve_dataset(session=session)
            if not dataset:
                raise ValueError("Dataset not found")

        # check whether the template name already exists
        template_name = args.get("name")
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .where(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")

        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )

        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        with Session(db.engine) as session:
            rag_pipeline_dsl_service = RagPipelineDslService(session)
            dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)

        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()

    def is_workflow_exist(self, pipeline: Pipeline) -> bool:
        return (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .count()
        ) > 0

    def get_node_last_run(
        self, pipeline: Pipeline, workflow: Workflow, node_id: str
    ) -> WorkflowNodeExecutionModel | None:
        node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            sessionmaker(db.engine)
        )
        node_exec = node_execution_service_repo.get_node_last_execution(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )
        return node_exec

    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
        """
        Set datasource variables
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        start_at = time.perf_counter()
        node_id = args.get("start_node_id")
        if not node_id:
            raise ValueError("Node id is required")
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None

        system_inputs = SystemVariable(
            datasource_type=args.get("datasource_type", "online_document"),
            datasource_info=args.get("datasource_info", {}),
        )

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs={},
                user_id=current_user.id,
                variable_pool=VariablePool(
                    system_variables=system_inputs,
                    user_inputs={},
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id

        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=current_user,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)

        # Convert the node execution to its DB model after save
        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)  # type: ignore

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
                user=current_user,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model

    def get_recommended_plugins(self) -> dict:
        # Query active recommended plugins
        pipeline_recommended_plugins = (
            db.session.query(PipelineRecommendedPlugin)
            .where(PipelineRecommendedPlugin.active == True)
            .order_by(PipelineRecommendedPlugin.position.asc())
            .all()
        )
        if not pipeline_recommended_plugins:
            return {
                "installed_recommended_plugins": [],
                "uninstalled_recommended_plugins": [],
            }

        # Batch fetch plugin manifests
        plugin_ids = [plugin.plugin_id for plugin in pipeline_recommended_plugins]
        providers = BuiltinToolManageService.list_builtin_tools(
            user_id=current_user.id,
            tenant_id=current_user.current_tenant_id,
        )
        providers_map = {provider.plugin_id: provider.to_dict() for provider in providers}

        plugin_manifests = marketplace.batch_fetch_plugin_manifests(plugin_ids)
        plugin_manifests_map = {manifest.plugin_id: manifest for manifest in plugin_manifests}

        installed_plugin_list = []
        uninstalled_plugin_list = []
        for plugin_id in plugin_ids:
            if providers_map.get(plugin_id):
                installed_plugin_list.append(providers_map.get(plugin_id))
            else:
                plugin_manifest = plugin_manifests_map.get(plugin_id)
                if plugin_manifest:
                    uninstalled_plugin_list.append(
                        {
                            "plugin_id": plugin_id,
                            "name": plugin_manifest.name,
                            "icon": plugin_manifest.icon,
                            "plugin_unique_identifier": plugin_manifest.latest_package_identifier,
                        }
                    )

        # Build recommended plugins list
        return {
            "installed_recommended_plugins": installed_plugin_list,
            "uninstalled_recommended_plugins": uninstalled_plugin_list,
        }

    def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
        """
        Retry a document whose pipeline run ended in error.
        """
        document_pipeline_execution_log = (
            db.session.query(DocumentPipelineExecutionLog)
            .where(DocumentPipelineExecutionLog.document_id == document.id)
            .first()
        )
        if not document_pipeline_execution_log:
            raise ValueError("Document pipeline execution log not found")
        pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        # convert to app config
        workflow = self.get_published_workflow(pipeline)
        if not workflow:
            raise ValueError("Workflow not found")

        PipelineGenerator().generate(
            pipeline=pipeline,
            workflow=workflow,
            user=user,
            args={
                "inputs": document_pipeline_execution_log.input_data,
                "start_node_id": document_pipeline_execution_log.datasource_node_id,
                "datasource_type": document_pipeline_execution_log.datasource_type,
                "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
                "original_document_id": document.id,
            },
            invoke_from=InvokeFrom.PUBLISHED,
            streaming=False,
            call_depth=0,
            workflow_thread_pool_id=None,
            is_retry=True,
        )

    def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
        """
        Get datasource plugins
        """
        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError("Dataset not found")
        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        workflow: Workflow | None = None
        if is_published:
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not pipeline or not workflow:
            raise ValueError("Pipeline or workflow not found")

        datasource_nodes = workflow.graph_dict.get("nodes", [])
        datasource_plugins = []
        for datasource_node in datasource_nodes:
            if datasource_node.get("data", {}).get("type") == "datasource":
                datasource_node_data = datasource_node["data"]
                if not datasource_node_data:
                    continue

                variables = workflow.rag_pipeline_variables
                if variables:
                    variables_map = {item["variable"]: item for item in variables}
                else:
                    variables_map = {}

                datasource_parameters = datasource_node_data.get("datasource_parameters", {})
                user_input_variables_keys = []
                user_input_variables = []
                for _, value in datasource_parameters.items():
                    if value.get("value") and isinstance(value.get("value"), str):
                        pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                        match = re.match(pattern, value["value"])
                        if match:
                            full_path = match.group(1)
                            last_part = full_path.split(".")[-1]
                            user_input_variables_keys.append(last_part)
                    elif value.get("value") and isinstance(value.get("value"), list):
                        last_part = value.get("value")[-1]
                        user_input_variables_keys.append(last_part)
                for key, value in variables_map.items():
                    if key in user_input_variables_keys:
                        user_input_variables.append(value)

                # get credentials
                datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
                credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
                    tenant_id=tenant_id,
                    provider=datasource_node_data.get("provider_name"),
                    plugin_id=datasource_node_data.get("plugin_id"),
                )
                credential_info_list: list[Any] = []
                for credential in credentials:
                    credential_info_list.append(
                        {
                            "id": credential.get("id"),
                            "name": credential.get("name"),
                            "type": credential.get("type"),
                            "is_default": credential.get("is_default"),
                        }
                    )
                datasource_plugins.append(
                    {
                        "node_id": datasource_node.get("id"),
                        "plugin_id": datasource_node_data.get("plugin_id"),
                        "provider_name": datasource_node_data.get("provider_name"),
                        "datasource_type": datasource_node_data.get("provider_type"),
                        "title": datasource_node_data.get("title"),
                        "user_input_variables": user_input_variables,
                        "credentials": credential_info_list,
                    }
                )
        return datasource_plugins

    def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
        """
        Get pipeline
        """
        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError("Dataset not found")
        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        return pipeline
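

# Usage sketch (illustrative only; assumes a Flask app context with the
# database initialized and `pipeline` / `account` objects already loaded):
#
#     service = RagPipelineService()
#     draft = service.get_draft_workflow(pipeline=pipeline)
#     if draft is None:
#         draft = service.sync_draft_workflow(
#             pipeline=pipeline,
#             graph={"nodes": [], "edges": []},
#             unique_hash=None,
#             account=account,
#             environment_variables=[],
#             conversation_variables=[],
#             rag_pipeline_variables=[],
#         )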