@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
 
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("Pipeline dataset is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,117 @@ class PipelineGenerator(BaseAppGenerator):
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
""" |
|
|
|
Format datasource info list. |
|
|
|
""" |
|
|
|
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
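+            # Locate the datasource start node in the workflow graph.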
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
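+            # Build the plugin runtime for the configured online drive provider.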
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
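+            # Resolve the tenant's stored credentials for this provider, if any.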
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+
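+            # Expand folder selections (keys ending with "/") into per-file entries.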
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+                else:
+                    # Keep individually selected files; without this branch
+                    # they would be silently dropped from the result.
+                    all_files.append(datasource_info)
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
""" |
|
|
|
Get files in a folder. |
|
|
|
""" |
|
|
|
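+        # Page through the drive listing, up to 20 keys per request.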
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                    last_file_key = file.key
+                is_truncated = files.is_truncated
+
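+        # A truncated page means more entries remain; fetch the next page after the last key seen.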
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )