|
|
|
@@ -10,7 +10,7 @@ from uuid import uuid4 |
|
|
|
|
|
|
|
from flask_login import current_user |
|
|
|
from sqlalchemy import func, or_, select |
|
|
|
from sqlalchemy.orm import Session |
|
|
|
from sqlalchemy.orm import Session, sessionmaker |
|
|
|
|
|
|
|
import contexts |
|
|
|
from configs import dify_config |
|
|
|
@@ -33,6 +33,7 @@ from core.rag.entities.event import ( |
|
|
|
DatasourceErrorEvent, |
|
|
|
DatasourceProcessingEvent, |
|
|
|
) |
|
|
|
from core.repositories.factory import DifyCoreRepositoryFactory |
|
|
|
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository |
|
|
|
from core.variables.variables import Variable |
|
|
|
from core.workflow.entities.variable_pool import VariablePool |
|
|
|
@@ -63,6 +64,7 @@ from models.workflow import ( |
|
|
|
WorkflowRun, |
|
|
|
WorkflowType, |
|
|
|
) |
|
|
|
from repositories.factory import DifyAPIRepositoryFactory |
|
|
|
from services.dataset_service import DatasetService |
|
|
|
from services.datasource_provider_service import DatasourceProviderService |
|
|
|
from services.entities.knowledge_entities.rag_pipeline_entities import ( |
|
|
|
@@ -78,6 +80,16 @@ logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class RagPipelineService: |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, session_maker: sessionmaker | None = None): |
|
|
|
"""Initialize RagPipelineService with repository dependencies.""" |
|
|
|
if session_maker is None: |
|
|
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) |
|
|
|
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( |
|
|
|
session_maker |
|
|
|
) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: |
|
|
|
if type == "built-in": |
|
|
|
@@ -390,7 +402,7 @@ class RagPipelineService: |
|
|
|
|
|
|
|
def run_draft_workflow_node( |
|
|
|
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account |
|
|
|
) -> WorkflowNodeExecutionModel: |
|
|
|
) -> WorkflowNodeExecutionModel | None: |
|
|
|
""" |
|
|
|
Run draft workflow node |
|
|
|
""" |
|
|
|
@@ -435,7 +447,8 @@ class RagPipelineService: |
|
|
|
workflow_node_execution.workflow_id = draft_workflow.id |
|
|
|
|
|
|
|
# Create repository and save the node execution |
|
|
|
repository = SQLAlchemyWorkflowNodeExecutionRepository( |
|
|
|
|
|
|
|
repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( |
|
|
|
session_factory=db.engine, |
|
|
|
user=account, |
|
|
|
app_id=pipeline.id, |
|
|
|
@@ -444,16 +457,17 @@ class RagPipelineService: |
|
|
|
repository.save(workflow_node_execution) |
|
|
|
|
|
|
|
# Convert node_execution to WorkflowNodeExecution after save |
|
|
|
workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution) |
|
|
|
workflow_node_execution_db_model = self._node_execution_service_repo.get_execution_by_id(workflow_node_execution.id) |
|
|
|
|
|
|
|
with Session(bind=db.engine) as session, session.begin(): |
|
|
|
draft_var_saver = DraftVariableSaver( |
|
|
|
session=session, |
|
|
|
app_id=pipeline.id, |
|
|
|
node_id=workflow_node_execution_db_model.node_id, |
|
|
|
node_type=NodeType(workflow_node_execution_db_model.node_type), |
|
|
|
node_id=workflow_node_execution.node_id, |
|
|
|
node_type=NodeType(workflow_node_execution.node_type), |
|
|
|
enclosing_node_id=enclosing_node_id, |
|
|
|
node_execution_id=workflow_node_execution.id, |
|
|
|
user=account, |
|
|
|
) |
|
|
|
draft_var_saver.save( |
|
|
|
process_data=workflow_node_execution.process_data, |