r2
tags/2.0.0-beta.1
jyong, 3 months ago
commit b5e4ce6c68

api/controllers/console/datasets/rag_pipeline/datasource_auth.py (+4 -2)

@@ -177,7 +177,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
             raise ValueError(str(ex))

         return {"result": "success"}, 201

+
 class DatasourceAuthListApi(Resource):
     @setup_required
     @login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
         )
         return {"result": datasources}, 200

+
 # Import Rag Pipeline
 api.add_resource(
     DatasourcePluginOauthApi,
@@ -211,4 +213,4 @@ api.add_resource(
 api.add_resource(
     DatasourceAuthListApi,
     "/auth/plugin/datasource/list",
-)
+)

api/core/app/apps/pipeline/pipeline_config_manager.py (+3 -1)

@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
             app_id=pipeline.id,
             app_mode=AppMode.RAG_PIPELINE,
             workflow_id=workflow.id,
-            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
+            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
+                workflow=workflow, start_node_id=start_node_id
+            ),
         )

         return pipeline_config

api/core/app/apps/pipeline/pipeline_generator.py (+121 -11)

@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService

 logger = logging.getLogger(__name__)

@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@
             raise ValueError("inputs is required")

         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )

         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@
             raise ValueError("Pipeline dataset is required")

         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )

         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
+        """
+        Format datasource info list.
+        """
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
+        """
+        Get files in a folder.
+        """
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                    last_file_key = file.key
+                is_truncated = files.is_truncated
+
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )
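Note: the new _get_files_in_folder helper combines two recursions: it descends into sub-folder keys (those ending in "/"), and it re-invokes itself on the same prefix with start_after set whenever a page (max_keys=20) comes back truncated. Below is a minimal, runnable sketch of that traversal pattern. FakeDrive, list_page, and walk are hypothetical stand-ins invented for illustration (the real code goes through the plugin's online_drive_browse_files API); only the walk/pagination logic mirrors the commit.

from typing import Optional


class FakeDrive:
    """Hypothetical stand-in for an online-drive listing API that returns
    at most page_size keys per call and flags truncated results."""

    def __init__(self, keys: list[str], page_size: int = 2):
        self._keys = sorted(keys)
        self._page_size = page_size

    def list_page(self, prefix: str, start_after: Optional[str]) -> tuple[list[str], bool]:
        # Direct children of `prefix`: full keys for files, "name/" entries for sub-folders.
        children: set[str] = set()
        for key in self._keys:
            if not key.startswith(prefix) or key == prefix:
                continue
            rest = key[len(prefix):]
            children.add(prefix + rest.split("/", 1)[0] + "/" if "/" in rest else key)
        ordered = sorted(k for k in children if start_after is None or k > start_after)
        return ordered[: self._page_size], len(ordered) > self._page_size


def walk(drive: FakeDrive, prefix: str, start_after: Optional[str], all_files: list[str]) -> None:
    page, is_truncated = drive.list_page(prefix, start_after)
    last_key = None
    for key in page:
        if key.endswith("/"):
            walk(drive, key, None, all_files)  # recurse into the sub-folder
        else:
            all_files.append(key)
        last_key = key
    if is_truncated:
        # Same folder, next page: resume after the last key seen, as the diff does.
        walk(drive, prefix, last_key, all_files)


drive = FakeDrive(["docs/a.txt", "docs/d.txt", "docs/e.txt", "docs/sub/b.txt", "docs/sub/c.txt"])
files: list[str] = []
walk(drive, "docs/", None, files)
print(files)  # all five file keys, collected across pages and nested folders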

api/core/datasource/datasource_file_manager.py (+1 -1)

@@ -1,10 +1,10 @@
 import base64
-from datetime import datetime
 import hashlib
 import hmac
 import logging
 import os
 import time
+from datetime import datetime
 from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4

api/core/datasource/utils/message_transformer.py (+1 -1)

@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
                 mimetype = meta.get("mime_type")
                 if not mimetype:
                     mimetype = guess_type(filename)[0] or "application/octet-stream"
-                # if message is str, encode it to bytes
+
                 if not isinstance(message.message, DatasourceMessage.BlobMessage):

api/core/file/file_manager.py (+5 -3)

@@ -72,9 +72,11 @@ def to_prompt_message_content(


 def download(f: File, /):
-    if f.transfer_method in (FileTransferMethod.TOOL_FILE,
-                             FileTransferMethod.LOCAL_FILE,
-                             FileTransferMethod.DATASOURCE_FILE):
+    if f.transfer_method in (
+        FileTransferMethod.TOOL_FILE,
+        FileTransferMethod.LOCAL_FILE,
+        FileTransferMethod.DATASOURCE_FILE,
+    ):
         return _download_file_content(f._storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)

api/core/plugin/impl/datasource.py (+1 -1)

@@ -56,7 +56,7 @@ class PluginDatasourceManager(BasePluginClient):
                 tool.identity.provider = provider.declaration.identity.name

         return all_response
-
+
     def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
         """
         Fetch datasource providers for the given tenant.

api/core/rag/models/document.py (+0 -1)

@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
     answer: str


-
 class QAStructureChunk(BaseModel):
     """
     QAStructureChunk.

api/services/datasource_provider_service.py (+15 -14)

@@ -131,7 +131,6 @@ class DatasourceProviderService:
         )

         return copy_credentials_list
-

     def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
         """
@@ -144,19 +143,21 @@
         datasources = manager.fetch_installed_datasource_providers(tenant_id)
         datasource_credentials = []
        for datasource in datasources:
-            credentials = self.get_datasource_credentials(tenant_id=tenant_id,
-                                                          provider=datasource.provider,
-                                                          plugin_id=datasource.plugin_id)
-            datasource_credentials.append({
-                "provider": datasource.provider,
-                "plugin_id": datasource.plugin_id,
-                "plugin_unique_identifier": datasource.plugin_unique_identifier,
-                "icon": datasource.declaration.identity.icon,
-                "name": datasource.declaration.identity.name,
-                "description": datasource.declaration.identity.description.model_dump(),
-                "author": datasource.declaration.identity.author,
-                "credentials": credentials,
-            })
+            credentials = self.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
+            datasource_credentials.append(
+                {
+                    "provider": datasource.provider,
+                    "plugin_id": datasource.plugin_id,
+                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
+                    "icon": datasource.declaration.identity.icon,
+                    "name": datasource.declaration.identity.name,
+                    "description": datasource.declaration.identity.description.model_dump(),
+                    "author": datasource.declaration.identity.author,
+                    "credentials": credentials,
+                }
+            )
         return datasource_credentials

     def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
