You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rag_pipeline_run_task.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import contextvars
  2. import logging
  3. import threading
  4. import time
  5. import uuid
  6. import click
  7. from celery import shared_task # type: ignore
  8. from flask import current_app, g
  9. from sqlalchemy.orm import Session, sessionmaker
  10. from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
  11. from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
  12. from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
  13. from extensions.ext_database import db
  14. from extensions.ext_redis import redis_client
  15. from models.account import Account, Tenant
  16. from models.dataset import Pipeline
  17. from models.enums import WorkflowRunTriggeredFrom
  18. from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
    pipeline_id: str,
    application_generate_entity: dict,
    user_id: str,
    tenant_id: str,
    workflow_id: str,
    streaming: bool,
    workflow_execution_id: str | None = None,
    workflow_thread_pool_id: str | None = None,
) -> None:
    """
    Asynchronously run a RAG pipeline workflow on the Celery "pipeline" queue.

    Loads the triggering account, tenant, pipeline, and workflow records,
    rebuilds the generate entity from its serialized dict, wires up the
    SQLAlchemy execution repositories, and then drives
    ``PipelineGenerator._generate`` in a dedicated worker thread (with the
    Flask app context and login user captured into a copied contextvars
    context), blocking until that thread finishes.

    :param pipeline_id: ID of the ``Pipeline`` record to run
    :param application_generate_entity: serialized ``RagPipelineGenerateEntity``
        payload; passed as keyword arguments to the entity constructor
    :param user_id: ID of the ``Account`` that triggered the run
    :param tenant_id: ID of the ``Tenant`` the run executes under
    :param workflow_id: Workflow ID forwarded to the pipeline generator
    :param streaming: whether to stream results
    :param workflow_execution_id: optional pre-assigned execution ID; a fresh
        UUID4 is generated when ``None`` (NOTE(review): the value is not
        forwarded to ``_generate`` below — confirm whether this is intentional)
    :param workflow_thread_pool_id: optional thread pool ID for workflow execution
    :raises ValueError: if the account, tenant, pipeline, or workflow cannot be found
    """
    logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green"))
    start_at = time.perf_counter()
    # Redis key removed in the `finally` block; presumably set by the caller
    # that enqueued this task to mark a run in progress — TODO confirm.
    indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}"
    try:
        # Keep the session open for the whole run: the ORM objects fetched here
        # (account, pipeline) are used by the worker thread below.
        with Session(db.engine) as session:
            account = session.query(Account).filter(Account.id == user_id).first()
            if not account:
                raise ValueError(f"Account {user_id} not found")
            tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
            if not tenant:
                raise ValueError(f"Tenant {tenant_id} not found")
            account.current_tenant = tenant
            pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
            if not pipeline:
                raise ValueError(f"Pipeline {pipeline_id} not found")
            # Fetched only to validate existence; `workflow_id` (the parameter)
            # is what gets forwarded to the generator.
            workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
            if not workflow:
                raise ValueError(f"Workflow {pipeline.workflow_id} not found")
            if workflow_execution_id is None:
                workflow_execution_id = str(uuid.uuid4())
            # Create application generate entity from dict
            entity = RagPipelineGenerateEntity(**application_generate_entity)
            # Create workflow node execution repository.
            # expire_on_commit=False so repository objects stay usable after commit.
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
            )
            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
            # Use app context to ensure Flask globals work properly
            with current_app.app_context():
                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account
                # Copy context for thread (after setting user, so the copied
                # context carries the login user into the worker thread)
                context = contextvars.copy_context()
                # Get Flask app object in the main thread where app context exists
                flask_app = current_app._get_current_object()  # type: ignore

                # Create a wrapper function that passes user context
                def _run_with_user_context() -> None:
                    # Don't create a new app context here - let _generate handle it
                    # Just ensure the user is available in contextvars.
                    # Local import avoids a circular import at module load time
                    # — TODO confirm against the import graph.
                    from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                    pipeline_generator = PipelineGenerator()
                    pipeline_generator._generate(
                        flask_app=flask_app,
                        context=context,
                        pipeline=pipeline,
                        workflow_id=workflow_id,
                        user=account,
                        application_generate_entity=entity,
                        invoke_from=InvokeFrom.PUBLISHED,
                        workflow_execution_repository=workflow_execution_repository,
                        workflow_node_execution_repository=workflow_node_execution_repository,
                        streaming=streaming,
                        workflow_thread_pool_id=workflow_thread_pool_id,
                    )

                # Create and start worker thread
                worker_thread = threading.Thread(target=_run_with_user_context)
                worker_thread.start()
                worker_thread.join()  # Wait for worker thread to complete
        end_at = time.perf_counter()
        logging.info(
            click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green")
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red"))
        raise
    finally:
        # Always clear the in-progress marker and release the scoped session,
        # whether the run succeeded or failed.
        redis_client.delete(indexing_cache_key)
        db.session.close()