
rag_pipeline_transform_service.py

import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Optional
from uuid import uuid4

import yaml
from flask_login import current_user

from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService


class RagPipelineTransformService:
    """Transform a legacy dataset's ingestion configuration into a published RAG pipeline."""

    def transform_dataset(self, dataset_id: str):
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError("Dataset not found")
        if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
            return {
                "pipeline_id": dataset.pipeline_id,
                "dataset_id": dataset_id,
                "status": "success",
            }
        if dataset.provider != "vendor":
            raise ValueError("External dataset is not supported")
        datasource_type = dataset.data_source_type
        indexing_technique = dataset.indexing_technique
        if not datasource_type and not indexing_technique:
            return self._transform_to_empty_pipeline(dataset)
        doc_form = dataset.doc_form
        if not doc_form:
            return self._transform_to_empty_pipeline(dataset)
        retrieval_model = dataset.retrieval_model
        pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
        # deal with dependencies
        self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
        # Extract app data
        workflow_data = pipeline_yaml.get("workflow")
        graph = workflow_data.get("graph", {})
        nodes = graph.get("nodes", [])
        new_nodes = []
        for node in nodes:
            if (
                node.get("data", {}).get("type") == "datasource"
                and node.get("data", {}).get("provider_type") == "local_file"
            ):
                node = self._deal_file_extensions(node)
            if node.get("data", {}).get("type") == "knowledge-index":
                node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
            new_nodes.append(node)
        if new_nodes:
            graph["nodes"] = new_nodes
            workflow_data["graph"] = graph
            pipeline_yaml["workflow"] = workflow_data
        # create pipeline
        pipeline = self._create_pipeline(pipeline_yaml)
        # save chunk structure to dataset
        if doc_form == "hierarchical_model":
            dataset.chunk_structure = "hierarchical_model"
        elif doc_form == "text_model":
            dataset.chunk_structure = "text_model"
        else:
            raise ValueError("Unsupported doc form")
        dataset.runtime_mode = "rag_pipeline"
        dataset.pipeline_id = pipeline.id
        # deal with document data
        self._deal_document_data(dataset)
        db.session.commit()
        return {
            "pipeline_id": pipeline.id,
            "dataset_id": dataset_id,
            "status": "success",
        }

    def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
        if doc_form == "text_model":
            match datasource_type:
                case "upload_file":
                    if indexing_technique == "high_quality":
                        # get graph from transform.file-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    if indexing_technique == "economy":
                        # get graph from transform.file-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case "notion_import":
                    if indexing_technique == "high_quality":
                        # get graph from transform.notion-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    if indexing_technique == "economy":
                        # get graph from transform.notion-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case "website_crawl":
                    if indexing_technique == "high_quality":
                        # get graph from transform.website-crawl-general-high-quality.yml
                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                    if indexing_technique == "economy":
                        # get graph from transform.website-crawl-general-economy.yml
                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
                            pipeline_yaml = yaml.safe_load(f)
                case _:
                    raise ValueError("Unsupported datasource type")
        elif doc_form == "hierarchical_model":
            match datasource_type:
                case "upload_file":
                    # get graph from transform.file-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case "notion_import":
                    # get graph from transform.notion-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case "website_crawl":
                    # get graph from transform.website-crawl-parentchild.yml
                    with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
                        pipeline_yaml = yaml.safe_load(f)
                case _:
                    raise ValueError("Unsupported datasource type")
        else:
            raise ValueError("Unsupported doc form")
        return pipeline_yaml

    def _deal_file_extensions(self, node: dict):
        file_extensions = node.get("data", {}).get("fileExtensions", [])
        if not file_extensions:
            return node
        file_extensions = [file_extension.lower() for file_extension in file_extensions]
        # replace the node's configured extensions with the full set of supported document extensions
        node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
        return node

    def _deal_knowledge_index(
        self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
    ):
        knowledge_configuration_dict = node.get("data", {})
        knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
        if indexing_technique == "high_quality":
            knowledge_configuration.embedding_model = dataset.embedding_model
            knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
        if retrieval_model:
            retrieval_setting = RetrievalSetting(**retrieval_model)
            if indexing_technique == "economy":
                retrieval_setting.search_method = "keyword_search"
            knowledge_configuration.retrieval_model = retrieval_setting
        else:
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
        knowledge_configuration_dict.update(knowledge_configuration.model_dump())
        node["data"] = knowledge_configuration_dict
        return node

    def _create_pipeline(
        self,
        data: dict,
    ) -> Pipeline:
        """Create a new pipeline from the transform template."""
        pipeline_data = data.get("rag_pipeline", {})
        workflow_data = data.get("workflow")
        if not workflow_data or not isinstance(workflow_data, dict):
            raise ValueError("Missing workflow data for rag pipeline")
        environment_variables_list = workflow_data.get("environment_variables", [])
        environment_variables = [
            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
        ]
        conversation_variables_list = workflow_data.get("conversation_variables", [])
        conversation_variables = [
            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
        ]
        rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
        graph = workflow_data.get("graph", {})
        # Create new pipeline
        pipeline = Pipeline()
        pipeline.id = str(uuid4())
        pipeline.tenant_id = current_user.current_tenant_id
        pipeline.name = pipeline_data.get("name", "")
        pipeline.description = pipeline_data.get("description", "")
        pipeline.created_by = current_user.id
        pipeline.updated_by = current_user.id
        pipeline.is_published = True
        pipeline.is_public = True
        db.session.add(pipeline)
        db.session.flush()
        # create draft workflow
        draft_workflow = Workflow(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            features="{}",
            type=WorkflowType.RAG_PIPELINE.value,
            version="draft",
            graph=json.dumps(graph),
            created_by=current_user.id,
            environment_variables=environment_variables,
            conversation_variables=conversation_variables,
            rag_pipeline_variables=rag_pipeline_variables_list,
        )
        # create published workflow
        published_workflow = Workflow(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            features="{}",
            type=WorkflowType.RAG_PIPELINE.value,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=json.dumps(graph),
            created_by=current_user.id,
            environment_variables=environment_variables,
            conversation_variables=conversation_variables,
            rag_pipeline_variables=rag_pipeline_variables_list,
        )
        db.session.add(draft_workflow)
        db.session.add(published_workflow)
        db.session.flush()
        pipeline.workflow_id = published_workflow.id
        db.session.add(pipeline)
        return pipeline

    def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
        installer_manager = PluginInstaller()
        installed_plugins = installer_manager.list_plugins(tenant_id)
        plugin_migration = PluginMigration()
        installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
        dependencies = pipeline_yaml.get("dependencies", [])
        need_install_plugin_unique_identifiers = []
        for dependency in dependencies:
            if dependency.get("type") == "marketplace":
                plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
                plugin_id = plugin_unique_identifier.split(":")[0]
                if plugin_id not in installed_plugins_ids:
                    plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)
                    if plugin_unique_identifier:
                        need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
        if need_install_plugin_unique_identifiers:
            print(need_install_plugin_unique_identifiers)
            PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)

    def _transform_to_empty_pipeline(self, dataset: Dataset):
        pipeline = Pipeline(
            tenant_id=dataset.tenant_id,
            name=dataset.name,
            description=dataset.description,
            created_by=current_user.id,
        )
        db.session.add(pipeline)
        db.session.flush()
        dataset.pipeline_id = pipeline.id
        dataset.runtime_mode = "rag_pipeline"
        dataset.updated_by = current_user.id
        dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.add(dataset)
        db.session.commit()
        return {
            "pipeline_id": pipeline.id,
            "dataset_id": dataset.id,
            "status": "success",
        }

    def _deal_document_data(self, dataset: Dataset):
        # hard-coded datasource node ids used for the generated execution logs
        file_node_id = "1752479895761"
        notion_node_id = "1752489759475"
        jina_node_id = "1752491761974"
        firecrawl_node_id = "1752565402678"
        documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all()
        for document in documents:
            data_source_info_dict = document.data_source_info_dict
            if not data_source_info_dict:
                continue
            if document.data_source_type == "upload_file":
                document.data_source_type = "local_file"
                file_id = data_source_info_dict.get("upload_file_id")
                if file_id:
                    file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
                    if file:
                        data_source_info = json.dumps(
                            {
                                "real_file_id": file_id,
                                "name": file.name,
                                "size": file.size,
                                "extension": file.extension,
                                "mime_type": file.mime_type,
                                "url": "",
                                "transfer_method": "local_file",
                            }
                        )
                        document.data_source_info = data_source_info
                        document_pipeline_execution_log = DocumentPipelineExecutionLog(
                            document_id=document.id,
                            pipeline_id=dataset.pipeline_id,
                            datasource_type="local_file",
                            datasource_info=data_source_info,
                            input_data={},
                            created_by=document.created_by,
                            created_at=document.created_at,
                            datasource_node_id=file_node_id,
                        )
                        db.session.add(document)
                        db.session.add(document_pipeline_execution_log)
            elif document.data_source_type == "notion_import":
                document.data_source_type = "online_document"
                data_source_info = json.dumps(
                    {
                        "workspace_id": data_source_info_dict.get("notion_workspace_id"),
                        "page": {
                            "page_id": data_source_info_dict.get("notion_page_id"),
                            "page_name": document.name,
                            "page_icon": data_source_info_dict.get("notion_page_icon"),
                            "type": data_source_info_dict.get("type"),
                            "last_edited_time": data_source_info_dict.get("last_edited_time"),
                            "parent_id": None,
                        },
                    }
                )
                document.data_source_info = data_source_info
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document.id,
                    pipeline_id=dataset.pipeline_id,
                    datasource_type="online_document",
                    datasource_info=data_source_info,
                    input_data={},
                    created_by=document.created_by,
                    created_at=document.created_at,
                    datasource_node_id=notion_node_id,
                )
                db.session.add(document)
                db.session.add(document_pipeline_execution_log)
            elif document.data_source_type == "website_crawl":
                document.data_source_type = "website_crawl"
                data_source_info = json.dumps(
                    {
                        "source_url": data_source_info_dict.get("url"),
                        "content": "",
                        "title": document.name,
                        "description": "",
                    }
                )
                document.data_source_info = data_source_info
                if data_source_info_dict.get("provider") == "firecrawl":
                    datasource_node_id = firecrawl_node_id
                elif data_source_info_dict.get("provider") == "jinareader":
                    datasource_node_id = jina_node_id
                else:
                    continue
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document.id,
                    pipeline_id=dataset.pipeline_id,
                    datasource_type="website_crawl",
                    datasource_info=data_source_info,
                    input_data={},
                    created_by=document.created_by,
                    created_at=document.created_at,
                    datasource_node_id=datasource_node_id,
                )
                db.session.add(document)
                db.session.add(document_pipeline_execution_log)
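
A minimal caller sketch, not part of the module above: transform_dataset expects an application context in which flask_login.current_user resolves to the acting user and the shared SQLAlchemy session is bound, and it raises ValueError for missing, external, or unsupported datasets. The import path and the wrapper name below are assumptions for illustration only.

# Hypothetical wrapper around RagPipelineTransformService.transform_dataset.
# Assumes a Flask request/app context where flask_login.current_user is set and
# extensions.ext_database.db is initialized against the application database.
from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService  # assumed path


def migrate_dataset_to_pipeline(dataset_id: str) -> dict:
    service = RagPipelineTransformService()
    try:
        # Returns {"pipeline_id": ..., "dataset_id": ..., "status": "success"} on success.
        return service.transform_dataset(dataset_id)
    except ValueError as exc:
        # Raised for "Dataset not found", external ("non-vendor") datasets,
        # unsupported datasource types, or unsupported doc forms.
        return {"dataset_id": dataset_id, "status": "failed", "error": str(exc)}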