您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rag_pipeline_run_task.py 8.7KB

(line-number gutter: lines 1–196)
  1. import contextvars
  2. import json
  3. import logging
  4. import time
  5. import uuid
  6. from collections.abc import Mapping
  7. from concurrent.futures import ThreadPoolExecutor
  8. from typing import Any
  9. import click
  10. from celery import shared_task # type: ignore
  11. from flask import current_app, g
  12. from sqlalchemy.orm import Session, sessionmaker
  13. from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
  14. from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
  15. from core.repositories.factory import DifyCoreRepositoryFactory
  16. from extensions.ext_database import db
  17. from extensions.ext_redis import redis_client
  18. from models.account import Account, Tenant
  19. from models.dataset import Pipeline
  20. from models.enums import WorkflowRunTriggeredFrom
  21. from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
  22. from services.file_service import FileService
  23. @shared_task(queue="pipeline")
  24. def rag_pipeline_run_task(
  25. rag_pipeline_invoke_entities_file_id: str,
  26. tenant_id: str,
  27. ):
  28. """
  29. Async Run rag pipeline
  30. :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
  31. rag_pipeline_invoke_entities include:
  32. :param pipeline_id: Pipeline ID
  33. :param user_id: User ID
  34. :param tenant_id: Tenant ID
  35. :param workflow_id: Workflow ID
  36. :param invoke_from: Invoke source (debugger, published, etc.)
  37. :param streaming: Whether to stream results
  38. :param datasource_type: Type of datasource
  39. :param datasource_info: Datasource information dict
  40. :param batch: Batch identifier
  41. :param document_id: Document ID (optional)
  42. :param start_node_id: Starting node ID
  43. :param inputs: Input parameters dict
  44. :param workflow_execution_id: Workflow execution ID
  45. :param workflow_thread_pool_id: Thread pool ID for workflow execution
  46. """
  47. # run with threading, thread pool size is 10
  48. try:
  49. start_at = time.perf_counter()
  50. rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
  51. rag_pipeline_invoke_entities_file_id
  52. )
  53. rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
  54. # Get Flask app object for thread context
  55. flask_app = current_app._get_current_object() # type: ignore
  56. with ThreadPoolExecutor(max_workers=10) as executor:
  57. futures = []
  58. for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
  59. # Submit task to thread pool with Flask app
  60. future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
  61. futures.append(future)
  62. # Wait for all tasks to complete
  63. for future in futures:
  64. try:
  65. future.result() # This will raise any exceptions that occurred in the thread
  66. except Exception:
  67. logging.exception("Error in pipeline task")
  68. end_at = time.perf_counter()
  69. logging.info(
  70. click.style(
  71. f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
  72. )
  73. )
  74. except Exception:
  75. logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
  76. raise
  77. finally:
  78. tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
  79. tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
  80. # Check if there are waiting tasks in the queue
  81. # Use rpop to get the next task from the queue (FIFO order)
  82. next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
  83. if next_file_id:
  84. # Process the next waiting task
  85. # Keep the flag set to indicate a task is running
  86. redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
  87. rag_pipeline_run_task.delay( # type: ignore
  88. rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
  89. if isinstance(next_file_id, bytes)
  90. else next_file_id,
  91. tenant_id=tenant_id,
  92. )
  93. else:
  94. # No more waiting tasks, clear the flag
  95. redis_client.delete(tenant_pipeline_task_key)
  96. file_service = FileService(db.engine)
  97. file_service.delete_file(rag_pipeline_invoke_entities_file_id)
  98. db.session.close()
  99. def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
  100. """Run a single RAG pipeline task within Flask app context."""
  101. # Create Flask application context for this thread
  102. with flask_app.app_context():
  103. try:
  104. rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
  105. user_id = rag_pipeline_invoke_entity_model.user_id
  106. tenant_id = rag_pipeline_invoke_entity_model.tenant_id
  107. pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
  108. workflow_id = rag_pipeline_invoke_entity_model.workflow_id
  109. streaming = rag_pipeline_invoke_entity_model.streaming
  110. workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
  111. workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
  112. application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
  113. with Session(db.engine) as session:
  114. # Load required entities
  115. account = session.query(Account).where(Account.id == user_id).first()
  116. if not account:
  117. raise ValueError(f"Account {user_id} not found")
  118. tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
  119. if not tenant:
  120. raise ValueError(f"Tenant {tenant_id} not found")
  121. account.current_tenant = tenant
  122. pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
  123. if not pipeline:
  124. raise ValueError(f"Pipeline {pipeline_id} not found")
  125. workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
  126. if not workflow:
  127. raise ValueError(f"Workflow {pipeline.workflow_id} not found")
  128. if workflow_execution_id is None:
  129. workflow_execution_id = str(uuid.uuid4())
  130. # Create application generate entity from dict
  131. entity = RagPipelineGenerateEntity(**application_generate_entity)
  132. # Create workflow repositories
  133. session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
  134. workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
  135. session_factory=session_factory,
  136. user=account,
  137. app_id=entity.app_config.app_id,
  138. triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
  139. )
  140. workflow_node_execution_repository = (
  141. DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
  142. session_factory=session_factory,
  143. user=account,
  144. app_id=entity.app_config.app_id,
  145. triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
  146. )
  147. )
  148. # Set the user directly in g for preserve_flask_contexts
  149. g._login_user = account
  150. # Copy context for passing to pipeline generator
  151. context = contextvars.copy_context()
  152. # Direct execution without creating another thread
  153. # Since we're already in a thread pool, no need for nested threading
  154. from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
  155. pipeline_generator = PipelineGenerator()
  156. # Using protected method intentionally for async execution
  157. pipeline_generator._generate( # type: ignore[attr-defined]
  158. flask_app=flask_app,
  159. context=context,
  160. pipeline=pipeline,
  161. workflow_id=workflow_id,
  162. user=account,
  163. application_generate_entity=entity,
  164. invoke_from=InvokeFrom.PUBLISHED,
  165. workflow_execution_repository=workflow_execution_repository,
  166. workflow_node_execution_repository=workflow_node_execution_repository,
  167. streaming=streaming,
  168. workflow_thread_pool_id=workflow_thread_pool_id,
  169. )
  170. except Exception:
  171. logging.exception("Error in pipeline task")
  172. raise