
rag_pipeline.py 39KB

import json
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    OnlineDocumentPagesMessage,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result
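
    # Usage sketch (hypothetical, assuming a Flask app/request context where the
    # database session is available):
    #
    #     templates = RagPipelineService.get_pipeline_templates(type="built-in", language="en-US")
    #     for template in templates.get("pipeline_templates", []):
    #         print(template["name"])
    #
    # The fallback to "en-US" when a locale has no built-in templates happens
    # inside the method, so callers do not need to retry.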

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template type, "built-in" or "customized"
        :return: template detail dict, or None if not found
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
        else:
            mode = "customized"
        retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
        result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflows with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        # fetch one extra row to detect whether another page exists
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more
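
    # The "limit + 1" fetch above is a common way to detect a next page without
    # issuing a second COUNT query. A minimal, self-contained sketch of the same
    # idea (the `items` list here is hypothetical, standing in for query rows):
    #
    #     def paginate(items: list, page: int, limit: int) -> tuple[list, bool]:
    #         start = (page - 1) * limit
    #         window = items[start : start + limit + 1]
    #         has_more = len(window) > limit
    #         return window[:limit], has_more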

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block configs
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return: default config dict, or None if the node type is unknown
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    # def run_datasource_workflow_node_status(
    #     self, pipeline: Pipeline, node_id: str, job_id: str, account: Account,
    #     datasource_type: str, is_published: bool
    # ) -> dict:
    #     """
    #     Run published workflow datasource
    #     """
    #     if is_published:
    #         # fetch published workflow by app_model
    #         workflow = self.get_published_workflow(pipeline=pipeline)
    #     else:
    #         workflow = self.get_draft_workflow(pipeline=pipeline)
    #     if not workflow:
    #         raise ValueError("Workflow not initialized")
    #
    #     # run draft workflow node
    #     datasource_node_data = None
    #     start_at = time.perf_counter()
    #     datasource_nodes = workflow.graph_dict.get("nodes", [])
    #     for datasource_node in datasource_nodes:
    #         if datasource_node.get("id") == node_id:
    #             datasource_node_data = datasource_node.get("data", {})
    #             break
    #     if not datasource_node_data:
    #         raise ValueError("Datasource node data not found")
    #
    #     from core.datasource.datasource_manager import DatasourceManager
    #
    #     datasource_runtime = DatasourceManager.get_datasource_runtime(
    #         provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
    #         datasource_name=datasource_node_data.get("datasource_name"),
    #         tenant_id=pipeline.tenant_id,
    #         datasource_type=DatasourceProviderType(datasource_type),
    #     )
    #     datasource_provider_service = DatasourceProviderService()
    #     credentials = datasource_provider_service.get_real_datasource_credentials(
    #         tenant_id=pipeline.tenant_id,
    #         provider=datasource_node_data.get('provider_name'),
    #         plugin_id=datasource_node_data.get('plugin_id'),
    #     )
    #     if credentials:
    #         datasource_runtime.runtime.credentials = credentials[0].get("credentials")
    #     match datasource_type:
    #         case DatasourceProviderType.WEBSITE_CRAWL:
    #             datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
    #             website_crawl_results: list[WebsiteCrawlMessage] = []
    #             for website_message in datasource_runtime.get_website_crawl(
    #                 user_id=account.id,
    #                 datasource_parameters={"job_id": job_id},
    #                 provider_type=datasource_runtime.datasource_provider_type(),
    #             ):
    #                 website_crawl_results.append(website_message)
    #             return {
    #                 "result": [result for result in website_crawl_results.result],
    #                 "status": website_crawl_results.result.status,
    #                 "provider_type": datasource_node_data.get("provider_type"),
    #             }
    #         case _:
    #             raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> Generator[str, None, None]:
        """
        Run a workflow datasource node (draft or published) and stream its results as JSON strings
        """
        if is_published:
            # fetch published workflow by app_model
            workflow = self.get_published_workflow(pipeline=pipeline)
        else:
            workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # locate the datasource node in the graph
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        # fill missing user inputs from the node's configured defaults
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        for key, value in datasource_parameters.items():
            if not user_inputs.get(key):
                user_inputs[key] = value["value"]

        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
            datasource_type=DatasourceProviderType(datasource_type),
        )
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
            provider=datasource_node_data.get("provider_name"),
            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.ONLINE_DOCUMENT:
                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                    datasource_runtime.get_online_document_pages(
                        user_id=account.id,
                        datasource_parameters=user_inputs,
                        provider_type=datasource_runtime.datasource_provider_type(),
                    )
                )
                start_time = time.time()
                for message in online_document_result:
                    end_time = time.time()
                    online_document_event = DatasourceRunEvent(
                        status="completed", data=message.result, time_consuming=round(end_time - start_time, 2)
                    )
                    yield json.dumps(online_document_event.model_dump())
            case DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                    datasource_runtime.get_website_crawl(
                        user_id=account.id,
                        datasource_parameters=user_inputs,
                        provider_type=datasource_runtime.datasource_provider_type(),
                    )
                )
                start_time = time.time()
                for message in website_crawl_result:
                    end_time = time.time()
                    crawl_event = DatasourceRunEvent(
                        status=message.result.status,
                        data=message.result.web_info_list,
                        total=message.result.total,
                        completed=message.result.completed,
                        time_consuming=round(end_time - start_time, 2),
                    )
                    yield json.dumps(crawl_event.model_dump())
            case _:
                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
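
    # Consumption sketch (hypothetical): the generator above yields one JSON
    # string per datasource message, which a Flask view could stream as
    # server-sent events. The names `service`, `pipeline`, and `account` are
    # assumed to exist in the caller's scope:
    #
    #     events = service.run_datasource_workflow_node(
    #         pipeline=pipeline,
    #         node_id="datasource-node-1",
    #         user_inputs={"url": "https://example.com"},
    #         account=account,
    #         datasource_type="website_crawl",
    #         is_published=True,
    #     )
    #     for event_json in events:
    #         payload = json.loads(event_json)
    #         print(payload["status"], payload.get("completed"), "/", payload.get("total"))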

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a standalone (free) workflow node that is not attached to a workflow
        """
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.node_data.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # populate the workflow node execution from the run result
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # record the failure
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # second step parameters come from the workflow's rag_pipeline_variables
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # keep variables that belong to this node or are shared across nodes
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # get the datasource (first step) node
        datasource_node_data = None
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            # parameters whose value is not a variable reference ({{#node_id.var#}}) need user input
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables

    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # get the datasource (first step) node
        datasource_node_data = None
        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            # parameters whose value is not a variable reference ({{#node_id.var#}}) need user input
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables
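
    # A minimal sketch of the variable-reference check used above: values that
    # look like "{{#node_id.variable#}}" are wired to upstream node outputs and
    # are skipped, while literal values must be supplied by the user (the sample
    # strings here are hypothetical):
    #
    #     _VAR_REF = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
    #     assert _VAR_REF.match("{{#start_node.url#}}")      # wired: skipped
    #     assert not _VAR_REF.match("https://example.com")   # literal: user input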

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")
        # second step parameters come from the workflow's rag_pipeline_variables
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        # keep variables that belong to this node or are shared across nodes
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get workflow run list
        Only returns runs triggered from rag pipeline run or debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
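
    # Keyset-pagination sketch of the cursor logic above (hypothetical in-memory
    # rows, each an (id, created_at) tuple sorted newest first): the `last_id`
    # row anchors the cursor and the next page is everything strictly older.
    #
    #     def next_page(rows, last_id, limit):
    #         anchor = next(r for r in rows if r[0] == last_id)
    #         older = [r for r in rows if r[1] < anchor[1] and r[0] != last_id]
    #         return older[:limit]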

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node executions with ordering
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Dataset not found")
        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )
        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()
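
# End-to-end sketch (hypothetical, not part of the service): a typical publish
# flow syncs the draft first, then publishes it inside a session. `pipeline`,
# `account`, and `graph` are assumed to exist, and `unique_hash` would normally
# be the hash the client received with the current draft:
#
#     service = RagPipelineService()
#     service.sync_draft_workflow(
#         pipeline=pipeline,
#         graph=graph,
#         unique_hash=unique_hash,
#         account=account,
#         environment_variables=[],
#         conversation_variables=[],
#         rag_pipeline_variables=[],
#     )
#     with Session(db.engine) as session:
#         published = service.publish_workflow(session=session, pipeline=pipeline, account=account)
#         session.commit()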