rag_pipeline.py
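"""
Service layer for Dify RAG pipelines: pipeline template CRUD, draft/published
workflow management, single-step node execution, datasource node runs and
previews, workflow run pagination, and recommended plugin lookup.
"""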
import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, sessionmaker

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.helper import marketplace
from core.rag.entities.event import (
    DatasourceCompletedEvent,
    DatasourceErrorEvent,
    DatasourceProcessingEvent,
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.node_events.base import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader

logger = logging.getLogger(__name__)


class RagPipelineService:
    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize RagPipelineService with repository dependencies."""
        if session_maker is None:
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )
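
    # Illustrative usage (a sketch, not part of this module):
    #
    #     templates = RagPipelineService.get_pipeline_templates(language="en-US")
    #     detail = RagPipelineService.get_pipeline_template_detail(template_id="...")
    #
    # Built-in templates are fetched via the mode configured in
    # HOSTED_FETCH_PIPELINE_TEMPLATES_MODE, falling back to the built-in en-US
    # set when the localized list is empty; any other `type` reads the
    # tenant's customized templates instead.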

    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template source, "built-in" or customized
        :return:
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
            return built_in_result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
            return customized_result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        # check whether the template name already exists
        template_name = template_info.name
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                    PipelineCustomizedTemplate.id != template_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()
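
    # Workflow versioning convention used below: a pipeline has at most one
    # row with version == "draft" that is edited in place, while publishing
    # snapshots the draft under a timestamp version and points
    # pipeline.workflow_id at the snapshot.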

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow
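
    # Pagination note: the query below fetches limit + 1 rows; if the extra
    # row comes back, there is at least one more page, and the surplus row is
    # dropped before returning.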

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
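
    # Publishing also syncs any "knowledge-index" node configuration found in
    # the graph to the pipeline's dataset via
    # DatasetService.update_rag_pipeline_dataset_settings.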

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                with Session(db.engine) as session:
                    dataset = pipeline.retrieve_dataset(session=session)
                    if not dataset:
                        raise ValueError("Dataset not found")
                    DatasetService.update_rag_pipeline_dataset_settings(
                        session=session,
                        dataset=dataset,
                        knowledge_configuration=knowledge_configuration,
                        has_published=pipeline.is_published,
                    )
        # return new workflow
        return workflow

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecutionModel | None:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
                variable_pool=VariablePool(
                    system_variables=SystemVariable.empty(),
                    user_inputs=user_inputs,
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        # Create repository and save the node execution
        repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=db.engine,
            user=account,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
        # Fetch the saved execution as its DB model
        workflow_node_execution_db_model = self._node_execution_service_repo.get_execution_by_id(
            workflow_node_execution.id
        )
        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution.node_id,
                node_type=NodeType(workflow_node_execution.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
                user=account,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model
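
    # Datasource parameters may reference pipeline variables with the
    # `{{#node_id.variable_name#}}` syntax (see the regex below). For example,
    # a parameter value of "{{#start.url#}}" (hypothetical names) resolves
    # against user_inputs via its last path segment, "url", falling back to
    # the raw value when the key is absent.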

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Generator[Mapping[str, Any], None, None]:
        """
        Run the datasource node of a draft or published workflow and stream its events.
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            variables_map = {}
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                param_value = value.get("value")
                if not param_value:
                    variables_map[key] = param_value
                elif isinstance(param_value, str):
                    # handle string type parameter value, check if it contains variable reference pattern
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, param_value)
                    if match:
                        # extract variable path and try to get value from user inputs
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        variables_map[key] = user_inputs.get(last_part, param_value)
                    else:
                        variables_map[key] = param_value
                elif isinstance(param_value, list) and param_value:
                    # handle list type parameter value, check if the last element is in user inputs
                    last_part = param_value[-1]
                    variables_map[key] = user_inputs.get(last_part, param_value)
                else:
                    # other types use the original value directly
                    variables_map[key] = param_value

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                        datasource_runtime.get_online_document_pages(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    try:
                        for message in online_document_result:
                            end_time = time.time()
                            online_document_event = DatasourceCompletedEvent(
                                data=message.result, time_consuming=round(end_time - start_time, 2)
                            )
                            yield online_document_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during online document.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
                        datasource_runtime.online_drive_browse_files(
                            user_id=account.id,
                            request=OnlineDriveBrowseFilesRequest(
                                bucket=user_inputs.get("bucket"),
                                prefix=user_inputs.get("prefix", ""),
                                max_keys=user_inputs.get("max_keys", 20),
                                next_page_parameters=user_inputs.get("next_page_parameters"),
                            ),
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    for message in online_drive_result:
                        end_time = time.time()
                        online_drive_event = DatasourceCompletedEvent(
                            data=message.result,
                            time_consuming=round(end_time - start_time, 2),
                            total=None,
                            completed=None,
                        )
                        yield online_drive_event.model_dump()
                case DatasourceProviderType.WEBSITE_CRAWL:
                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                        datasource_runtime.get_website_crawl(
                            user_id=account.id,
                            datasource_parameters=variables_map,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    try:
                        for message in website_crawl_result:
                            end_time = time.time()
                            if message.result.status == "completed":
                                crawl_event = DatasourceCompletedEvent(
                                    data=message.result.web_info_list or [],
                                    total=message.result.total,
                                    completed=message.result.completed,
                                    time_consuming=round(end_time - start_time, 2),
                                )
                            else:
                                crawl_event = DatasourceProcessingEvent(
                                    total=message.result.total,
                                    completed=message.result.completed,
                                )
                            yield crawl_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during website crawl.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case _:
                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")
        except Exception as e:
            logger.exception("Error in run_datasource_workflow_node.")
            yield DatasourceErrorEvent(error=str(e)).model_dump()

    def run_datasource_node_preview(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
        credential_id: Optional[str] = None,
    ) -> Mapping[str, Any]:
        """
        Preview the content returned by a datasource node of a draft or published workflow.
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=credential_id,
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=account.id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=user_inputs.get("workspace_id", ""),
                                page_id=user_inputs.get("page_id", ""),
                                type=user_inputs.get("type", ""),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    try:
                        variables: dict[str, Any] = {}
                        for message in online_document_result:
                            if message.type == DatasourceMessage.MessageType.VARIABLE:
                                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                                variable_name = message.message.variable_name
                                variable_value = message.message.variable_value
                                if message.message.stream:
                                    if not isinstance(variable_value, str):
                                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                                    if variable_name not in variables:
                                        variables[variable_name] = ""
                                    variables[variable_name] += variable_value
                                else:
                                    variables[variable_name] = variable_value
                        return variables
                    except Exception as e:
                        logger.exception("Error during get online document content.")
                        raise RuntimeError(str(e))
                # TODO: Online Drive
                case _:
                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")
        except Exception as e:
            logger.exception("Error in run_datasource_node_preview.")
            raise RuntimeError(str(e))

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free (standalone) workflow node that is not attached to a draft workflow.
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution
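
    # NOTE: _handle_node_run_result treats both SUCCEEDED and EXCEPTION as a
    # "run succeeded" outcome: EXCEPTION means the node failed but its
    # error_strategy (e.g. DEFAULT_VALUE) handled the failure, so outputs are
    # still persisted alongside the error message.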

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
                    node_run_result = event.node_run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.error_strategy},
                }
                if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e._node
            run_succeeded = False
            node_run_result = None
            error = e._error
        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # mark the execution as failed
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow
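
    # The two helpers below split the pipeline's rag_pipeline_variables into
    # the first-step form (variables referenced by the datasource node's own
    # parameters) and the second-step form (the remaining variables that
    # belong to the node or are shared).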

    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = workflow.rag_pipeline_variables
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables_keys = []
        user_input_variables = []
        for _, value in datasource_parameters.items():
            if value.get("value") and isinstance(value.get("value"), str):
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                match = re.match(pattern, value["value"])
                if match:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables_keys.append(last_part)
            elif value.get("value") and isinstance(value.get("value"), list):
                last_part = value.get("value")[-1]
                user_input_variables_keys.append(last_part)
        for key, value in variables_map.items():
            if key in user_input_variables_keys:
                user_input_variables.append(value)
        return user_input_variables

    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
        # get datasource node data
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if datasource_node_data:
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for _, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        variables_map.pop(last_part, None)
                elif value.get("value") and isinstance(value.get("value"), list):
                    last_part = value.get("value")[-1]
                    variables_map.pop(last_part, None)
        all_second_step_variables = list(variables_map.values())
        datasource_provider_variables = [
            item
            for item in all_second_step_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables
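
    # Run listing below uses keyset (cursor) pagination: the caller passes
    # the last seen run id, and the next page is every run created strictly
    # earlier, ordered newest first.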

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get the workflow run list of a rag pipeline.
        Only returns runs triggered from rag pipeline run or debugging.
        :param pipeline: rag pipeline
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: rag pipeline
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to access node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        # Fetch the node executions with ordering
        order_config = OrderConfig(order_by=["created_at"], order_direction="asc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        with Session(db.engine) as session:
            dataset = pipeline.retrieve_dataset(session=session)
            if not dataset:
                raise ValueError("Dataset not found")
        # check whether the template name already exists
        template_name = args.get("name")
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )

        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        with Session(db.engine) as session:
            rag_pipeline_dsl_service = RagPipelineDslService(session)
            dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()

    def is_workflow_exist(self, pipeline: Pipeline) -> bool:
        return (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .count()
        ) > 0

    def get_node_last_run(
        self, pipeline: Pipeline, workflow: Workflow, node_id: str
    ) -> WorkflowNodeExecutionModel | None:
        node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            sessionmaker(db.engine)
        )
        node_exec = node_execution_service_repo.get_node_last_execution(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )
        return node_exec

    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
        """
        Set datasource variables
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        node_id = args.get("start_node_id")
        if not node_id:
            raise ValueError("Node id is required")
        node_config = draft_workflow.get_node_config_by_id(node_id)
        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        if enclosing_node_type_and_id:
            _, enclosing_node_id = enclosing_node_type_and_id
        else:
            enclosing_node_id = None
        system_inputs = SystemVariable(
            datasource_type=args.get("datasource_type", "online_document"),
            datasource_info=args.get("datasource_info", {}),
        )
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs={},
                user_id=current_user.id,
                variable_pool=VariablePool(
                    system_variables=system_inputs,
                    user_inputs={},
                    environment_variables=[],
                    conversation_variables=[],
                    rag_pipeline_variables=[],
                ),
                variable_loader=DraftVarLoader(
                    engine=db.engine,
                    app_id=pipeline.id,
                    tenant_id=pipeline.tenant_id,
                ),
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=current_user,
            app_id=pipeline.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
        # Convert the saved execution to its DB model
        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=pipeline.id,
                node_id=workflow_node_execution_db_model.node_id,
                node_type=NodeType(workflow_node_execution_db_model.node_type),
                enclosing_node_id=enclosing_node_id,
                node_execution_id=workflow_node_execution.id,
                user=current_user,
            )
            draft_var_saver.save(
                process_data=workflow_node_execution.process_data,
                outputs=workflow_node_execution.outputs,
            )
            session.commit()
        return workflow_node_execution_db_model

    def get_recommended_plugins(self) -> dict:
        # Query active recommended plugins
        pipeline_recommended_plugins = (
            db.session.query(PipelineRecommendedPlugin)
            .filter(PipelineRecommendedPlugin.active == True)
            .order_by(PipelineRecommendedPlugin.position.asc())
            .all()
        )
        if not pipeline_recommended_plugins:
            return {
                "installed_recommended_plugins": [],
                "uninstalled_recommended_plugins": [],
            }
        # Batch fetch plugin manifests
        plugin_ids = [plugin.plugin_id for plugin in pipeline_recommended_plugins]
        providers = BuiltinToolManageService.list_builtin_tools(
            user_id=current_user.id,
            tenant_id=current_user.current_tenant_id,
        )
        providers_map = {provider.plugin_id: provider.to_dict() for provider in providers}
        plugin_manifests = marketplace.batch_fetch_plugin_manifests(plugin_ids)
        plugin_manifests_map = {manifest.plugin_id: manifest for manifest in plugin_manifests}
        installed_plugin_list = []
        uninstalled_plugin_list = []
        for plugin_id in plugin_ids:
            if providers_map.get(plugin_id):
                installed_plugin_list.append(providers_map.get(plugin_id))
            else:
                plugin_manifest = plugin_manifests_map.get(plugin_id)
                if plugin_manifest:
                    uninstalled_plugin_list.append(
                        {
                            "plugin_id": plugin_id,
                            "name": plugin_manifest.name,
                            "icon": plugin_manifest.icon,
                            "plugin_unique_identifier": plugin_manifest.latest_package_identifier,
                        }
                    )
        # Build recommended plugins list
        return {
            "installed_recommended_plugins": installed_plugin_list,
            "uninstalled_recommended_plugins": uninstalled_plugin_list,
        }
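

# Illustrative end-to-end sketch (not part of the service; assumes a Flask
# app context with a loaded `pipeline` row and an `account` — names other
# than the service API are hypothetical):
#
#     service = RagPipelineService()
#     draft = service.get_draft_workflow(pipeline=pipeline)
#     if draft:
#         with Session(db.engine) as session, session.begin():
#             published = service.publish_workflow(
#                 session=session, pipeline=pipeline, account=account
#             )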