
rag_pipeline_transform_service.py

import json
from datetime import UTC, datetime
from pathlib import Path
from uuid import uuid4

import yaml
from flask_login import current_user

from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService


class RagPipelineTransformService:
    def transform_dataset(self, dataset_id: str):
        """Convert a legacy dataset to the rag_pipeline runtime and return the new pipeline id."""
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError("Dataset not found")
        # Already transformed: return the existing pipeline.
        if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
            return {
                "pipeline_id": dataset.pipeline_id,
                "dataset_id": dataset_id,
                "status": "success",
            }
        if dataset.provider != "vendor":
            raise ValueError("External dataset is not supported")
        datasource_type = dataset.data_source_type
        indexing_technique = dataset.indexing_technique
        if not datasource_type and not indexing_technique:
            return self._transfrom_to_empty_pipeline(dataset)
        doc_form = dataset.doc_form
        if not doc_form:
            return self._transfrom_to_empty_pipeline(dataset)
        retrieval_model = dataset.retrieval_model
        pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
        # Install any missing plugin dependencies referenced by the pipeline template.
        self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
        # Extract the workflow data from the template.
        workflow_data = pipeline_yaml.get("workflow")
        if not workflow_data:
            raise ValueError("Missing workflow data for rag pipeline")
        graph = workflow_data.get("graph", {})
        nodes = graph.get("nodes", [])
        new_nodes = []
        for node in nodes:
            if (
                node.get("data", {}).get("type") == "datasource"
                and node.get("data", {}).get("provider_type") == "local_file"
            ):
                node = self._deal_file_extensions(node)
            if node.get("data", {}).get("type") == "knowledge-index":
                node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
            new_nodes.append(node)
        if new_nodes:
            graph["nodes"] = new_nodes
            workflow_data["graph"] = graph
            pipeline_yaml["workflow"] = workflow_data
        # Create the pipeline and its workflows.
        pipeline = self._create_pipeline(pipeline_yaml)
        # Save the chunk structure to the dataset.
        if doc_form == "hierarchical_model":
            dataset.chunk_structure = "hierarchical_model"
        elif doc_form == "text_model":
            dataset.chunk_structure = "text_model"
        else:
            raise ValueError("Unsupported doc form")
        dataset.runtime_mode = "rag_pipeline"
        dataset.pipeline_id = pipeline.id
        # Migrate existing document records to the new datasource format.
        self._deal_document_data(dataset)
        db.session.commit()
        return {
            "pipeline_id": pipeline.id,
            "dataset_id": dataset_id,
            "status": "success",
        }

    def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
        pipeline_yaml = {}
        if doc_form == "text_model":
            match datasource_type:
                case "upload_file":
                    if indexing_technique == "high_quality":
                        # get graph from transform/file-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    elif indexing_technique == "economy":
                        # get graph from transform/file-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case "notion_import":
                    if indexing_technique == "high_quality":
                        # get graph from transform/notion-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    elif indexing_technique == "economy":
                        # get graph from transform/notion-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case "website_crawl":
                    if indexing_technique == "high_quality":
                        # get graph from transform/website-crawl-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    elif indexing_technique == "economy":
                        # get graph from transform/website-crawl-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case _:
                    raise ValueError("Unsupported datasource type")
        elif doc_form == "hierarchical_model":
            match datasource_type:
                case "upload_file":
                    # get graph from transform/file-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case "notion_import":
                    # get graph from transform/notion-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case "website_crawl":
                    # get graph from transform/website-crawl-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case _:
                    raise ValueError("Unsupported datasource type")
        else:
            raise ValueError("Unsupported doc form")
        return pipeline_yaml

    def _deal_file_extensions(self, node: dict):
        # If the local-file datasource node restricts file extensions, widen it to the
        # full set of document extensions supported by the platform.
        file_extensions = node.get("data", {}).get("fileExtensions", [])
        if not file_extensions:
            return node
        node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
        return node

    def _deal_knowledge_index(
        self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
    ):
        knowledge_configuration_dict = node.get("data", {})
        knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
        if indexing_technique == "high_quality":
            knowledge_configuration.embedding_model = dataset.embedding_model
            knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
        if retrieval_model:
            retrieval_setting = RetrievalSetting(**retrieval_model)
            if indexing_technique == "economy":
                retrieval_setting.search_method = "keyword_search"
            knowledge_configuration.retrieval_model = retrieval_setting
        else:
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
        knowledge_configuration_dict.update(knowledge_configuration.model_dump())
        node["data"] = knowledge_configuration_dict
        return node

    def _create_pipeline(
        self,
        data: dict,
    ) -> Pipeline:
        """Create a new pipeline together with its draft and published workflows."""
        pipeline_data = data.get("rag_pipeline", {})
        workflow_data = data.get("workflow")
        if not workflow_data or not isinstance(workflow_data, dict):
            raise ValueError("Missing workflow data for rag pipeline")
        environment_variables_list = workflow_data.get("environment_variables", [])
        environment_variables = [
            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
        ]
        conversation_variables_list = workflow_data.get("conversation_variables", [])
        conversation_variables = [
            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
        ]
        rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
        graph = workflow_data.get("graph", {})
        # Create the pipeline record.
        pipeline = Pipeline()
        pipeline.id = str(uuid4())
        pipeline.tenant_id = current_user.current_tenant_id
        pipeline.name = pipeline_data.get("name", "")
        pipeline.description = pipeline_data.get("description", "")
        pipeline.created_by = current_user.id
        pipeline.updated_by = current_user.id
        pipeline.is_published = True
        pipeline.is_public = True
        db.session.add(pipeline)
        db.session.flush()
        # Create the draft workflow.
        draft_workflow = Workflow(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            features="{}",
            type=WorkflowType.RAG_PIPELINE.value,
            version="draft",
            graph=json.dumps(graph),
            created_by=current_user.id,
            environment_variables=environment_variables,
            conversation_variables=conversation_variables,
            rag_pipeline_variables=rag_pipeline_variables_list,
        )
        # Create the published workflow and point the pipeline at it.
        published_workflow = Workflow(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            features="{}",
            type=WorkflowType.RAG_PIPELINE.value,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=json.dumps(graph),
            created_by=current_user.id,
            environment_variables=environment_variables,
            conversation_variables=conversation_variables,
            rag_pipeline_variables=rag_pipeline_variables_list,
        )
        db.session.add(draft_workflow)
        db.session.add(published_workflow)
        db.session.flush()
        pipeline.workflow_id = published_workflow.id
        db.session.add(pipeline)
        return pipeline

    def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
        """Install any marketplace plugins required by the pipeline template that are not yet installed."""
        installer_manager = PluginInstaller()
        installed_plugins = installer_manager.list_plugins(tenant_id)
        plugin_migration = PluginMigration()
        installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
        dependencies = pipeline_yaml.get("dependencies", [])
        need_install_plugin_unique_identifiers = []
        for dependency in dependencies:
            if dependency.get("type") == "marketplace":
                plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
                plugin_id = plugin_unique_identifier.split(":")[0]
                if plugin_id not in installed_plugins_ids:
                    plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)  # type: ignore
                    if plugin_unique_identifier:
                        need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
        if need_install_plugin_unique_identifiers:
            PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)

    def _transfrom_to_empty_pipeline(self, dataset: Dataset):
        """Attach an empty pipeline to a dataset that has no datasource or indexing configuration."""
        pipeline = Pipeline(
            tenant_id=dataset.tenant_id,
            name=dataset.name,
            description=dataset.description,
            created_by=current_user.id,
        )
        db.session.add(pipeline)
        db.session.flush()
        dataset.pipeline_id = pipeline.id
        dataset.runtime_mode = "rag_pipeline"
        dataset.updated_by = current_user.id
        dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.add(dataset)
        db.session.commit()
        return {
            "pipeline_id": pipeline.id,
            "dataset_id": dataset.id,
            "status": "success",
        }

    def _deal_document_data(self, dataset: Dataset):
        """Rewrite existing documents' datasource info for the new pipeline and record an execution log for each."""
        # Datasource node IDs expected by the bundled transform pipeline templates.
        file_node_id = "1752479895761"
        notion_node_id = "1752489759475"
        jina_node_id = "1752491761974"
        firecrawl_node_id = "1752565402678"
        documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
        for document in documents:
            data_source_info_dict = document.data_source_info_dict
            if not data_source_info_dict:
                continue
            if document.data_source_type == "upload_file":
                document.data_source_type = "local_file"
                file_id = data_source_info_dict.get("upload_file_id")
                if file_id:
                    file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
                    if file:
                        data_source_info = json.dumps(
                            {
                                "real_file_id": file_id,
                                "name": file.name,
                                "size": file.size,
                                "extension": file.extension,
                                "mime_type": file.mime_type,
                                "url": "",
                                "transfer_method": "local_file",
                            }
                        )
                        document.data_source_info = data_source_info
                        document_pipeline_execution_log = DocumentPipelineExecutionLog(
                            document_id=document.id,
                            pipeline_id=dataset.pipeline_id,
                            datasource_type="local_file",
                            datasource_info=data_source_info,
                            input_data={},
                            created_by=document.created_by,
                            created_at=document.created_at,
                            datasource_node_id=file_node_id,
                        )
                        db.session.add(document)
                        db.session.add(document_pipeline_execution_log)
            elif document.data_source_type == "notion_import":
                document.data_source_type = "online_document"
                data_source_info = json.dumps(
                    {
                        "workspace_id": data_source_info_dict.get("notion_workspace_id"),
                        "page": {
                            "page_id": data_source_info_dict.get("notion_page_id"),
                            "page_name": document.name,
                            "page_icon": data_source_info_dict.get("notion_page_icon"),
                            "type": data_source_info_dict.get("type"),
                            "last_edited_time": data_source_info_dict.get("last_edited_time"),
                            "parent_id": None,
                        },
                    }
                )
                document.data_source_info = data_source_info
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document.id,
                    pipeline_id=dataset.pipeline_id,
                    datasource_type="online_document",
                    datasource_info=data_source_info,
                    input_data={},
                    created_by=document.created_by,
                    created_at=document.created_at,
                    datasource_node_id=notion_node_id,
                )
                db.session.add(document)
                db.session.add(document_pipeline_execution_log)
            elif document.data_source_type == "website_crawl":
                data_source_info = json.dumps(
                    {
                        "source_url": data_source_info_dict.get("url"),
                        "content": "",
                        "title": document.name,
                        "description": "",
                    }
                )
                document.data_source_info = data_source_info
                # Route the document to the datasource node that matches its crawl provider.
                if data_source_info_dict.get("provider") == "firecrawl":
                    datasource_node_id = firecrawl_node_id
                elif data_source_info_dict.get("provider") == "jinareader":
                    datasource_node_id = jina_node_id
                else:
                    continue
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document.id,
                    pipeline_id=dataset.pipeline_id,
                    datasource_type="website_crawl",
                    datasource_info=data_source_info,
                    input_data={},
                    created_by=document.created_by,
                    created_at=document.created_at,
                    datasource_node_id=datasource_node_id,
                )
                db.session.add(document)
                db.session.add(document_pipeline_execution_log)
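

# Usage sketch: a minimal illustration of how this service might be invoked, for
# example from a controller or a one-off migration script. The helper below is
# hypothetical and not part of the original module; it assumes an active Flask
# application/request context with an authenticated current_user, which
# _create_pipeline and _transfrom_to_empty_pipeline rely on.
def migrate_dataset_to_rag_pipeline(dataset_id: str) -> dict:
    """Hypothetical convenience wrapper around RagPipelineTransformService."""
    service = RagPipelineTransformService()
    # Returns {"pipeline_id": ..., "dataset_id": ..., "status": "success"} and
    # commits the transformed dataset, pipeline, and workflows in one session.
    return service.transform_dataset(dataset_id)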