
r2

tags/2.0.0-beta.1
jyong 3 months ago
Commit
b5e4ce6c68

+ 4
- 2
api/controllers/console/datasets/rag_pipeline/datasource_auth.py

@@ -177,7 +177,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
            raise ValueError(str(ex))

        return {"result": "success"}, 201


class DatasourceAuthListApi(Resource):
    @setup_required
    @login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
        )
        return {"result": datasources}, 200


# Import Rag Pipeline
api.add_resource(
    DatasourcePluginOauthApi,
@@ -211,4 +213,4 @@ api.add_resource(
api.add_resource(
    DatasourceAuthListApi,
    "/auth/plugin/datasource/list",
)
)
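
For context, the registration above follows the standard Flask-RESTful pattern used throughout these controllers: a Resource subclass is bound to a URL rule on the console Api object. A minimal self-contained sketch of that pattern — the blueprint name and the stub get() body are assumptions, not Dify's actual setup:

from flask import Blueprint
from flask_restful import Api, Resource

bp = Blueprint("console", __name__, url_prefix="/console/api")
api = Api(bp)

class DatasourceAuthListApi(Resource):
    def get(self):
        # In Dify this handler is guarded by @setup_required / @login_required
        # and returns the tenant's configured datasource credentials.
        return {"result": []}, 200

# Same registration call as in the hunk above.
api.add_resource(
    DatasourceAuthListApi,
    "/auth/plugin/datasource/list",
)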

+ 3
- 1
api/core/app/apps/pipeline/pipeline_config_manager.py

@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
            app_id=pipeline.id,
            app_mode=AppMode.RAG_PIPELINE,
            workflow_id=workflow.id,
            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
                workflow=workflow, start_node_id=start_node_id
            ),
        )

        return pipeline_config

+ 121
- 11
api/core/app/apps/pipeline/pipeline_generator.py

@@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Optional, Union, cast, overload

from flask import Flask, current_app
from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService

logger = logging.getLogger(__name__)

@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
        inputs: Mapping[str, Any] = args["inputs"]
        start_node_id: str = args["start_node_id"]
        datasource_type: str = args["datasource_type"]
        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
        )
        batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline,
            workflow=workflow,
            start_node_id=start_node_id
            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
        )
        documents = []
        if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")

# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)

dataset = pipeline.dataset
if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")

# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)

# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
        if doc_metadata:
            document.doc_metadata = doc_metadata
        return document

    def _format_datasource_info_list(
        self,
        datasource_type: str,
        datasource_info_list: list[Mapping[str, Any]],
        pipeline: Pipeline,
        workflow: Workflow,
        start_node_id: str,
        user: Union[Account, EndUser],
    ) -> list[Mapping[str, Any]]:
        """
        Format datasource info list.
        """
        if datasource_type == "online_drive":
            all_files = []
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == start_node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_real_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)

            for datasource_info in datasource_info_list:
                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
                    # get all files in the folder
                    self._get_files_in_folder(
                        datasource_runtime,
                        datasource_info.get("key", ""),
                        None,
                        datasource_info.get("bucket", None),
                        user.id,
                        all_files,
                        datasource_info,
                    )
            return all_files
        else:
            return datasource_info_list

    def _get_files_in_folder(
        self,
        datasource_runtime: OnlineDriveDatasourcePlugin,
        prefix: str,
        start_after: Optional[str],
        bucket: Optional[str],
        user_id: str,
        all_files: list,
        datasource_info: Mapping[str, Any],
    ):
        """
        Get files in a folder.
        """
        result_generator = datasource_runtime.online_drive_browse_files(
            user_id=user_id,
            request=OnlineDriveBrowseFilesRequest(
                bucket=bucket,
                prefix=prefix,
                max_keys=20,
                start_after=start_after,
            ),
            provider_type=datasource_runtime.datasource_provider_type(),
        )
        is_truncated = False
        last_file_key = None
        for result in result_generator:
            for files in result.result:
                for file in files.files:
                    if file.key.endswith("/"):
                        self._get_files_in_folder(
                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
                        )
                    else:
                        all_files.append(
                            {
                                "key": file.key,
                                "bucket": bucket,
                            }
                        )
                    last_file_key = file.key
                is_truncated = files.is_truncated

        if is_truncated:
            self._get_files_in_folder(
                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
            )
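
The two helpers added above implement a depth-first walk over an online-drive listing: keys ending with "/" are treated as folders and recursed into, pages are capped at max_keys=20, and when a page reports is_truncated the walk re-requests the same prefix with start_after set to the last key seen. A stripped-down sketch of that pattern, with a hypothetical list_page callable standing in for online_drive_browse_files:

from typing import Callable, Optional

def walk_folder(
    prefix: str,
    start_after: Optional[str],
    all_files: list[str],
    list_page: Callable[..., dict],
) -> None:
    # Fetch one page of results; mirrors OnlineDriveBrowseFilesRequest(max_keys=20).
    page = list_page(prefix=prefix, start_after=start_after, max_keys=20)
    last_key = None
    for key in page["keys"]:
        if key.endswith("/"):
            # Folder marker: descend, starting a fresh listing of the subfolder.
            walk_folder(key, None, all_files, list_page)
        else:
            all_files.append(key)
        last_key = key
    if page["is_truncated"] and last_key:
        # More entries remain under this prefix: continue after the last key seen.
        walk_folder(prefix, last_key, all_files, list_page)

The last_key guard keeps a truncated-but-empty page from recursing forever, which is the same role last_file_key plays in the real method.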

+ 1
- 1
api/core/datasource/datasource_file_manager.py

@@ -1,10 +1,10 @@
import base64
from datetime import datetime
import hashlib
import hmac
import logging
import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4

+ 1
- 1
api/core/datasource/utils/message_transformer.py

@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
                mimetype = meta.get("mime_type")
                if not mimetype:
                    mimetype = guess_type(filename)[0] or "application/octet-stream"
                # if message is str, encode it to bytes

                if not isinstance(message.message, DatasourceMessage.BlobMessage):

+ 5
- 3
api/core/file/file_manager.py

@@ -72,9 +72,11 @@ def to_prompt_message_content(


def download(f: File, /):
    if f.transfer_method in (FileTransferMethod.TOOL_FILE,
                             FileTransferMethod.LOCAL_FILE,
                             FileTransferMethod.DATASOURCE_FILE):
    if f.transfer_method in (
        FileTransferMethod.TOOL_FILE,
        FileTransferMethod.LOCAL_FILE,
        FileTransferMethod.DATASOURCE_FILE,
    ):
        return _download_file_content(f._storage_key)
    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
        response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
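
The reformatted branch above groups the three storage-backed transfer methods into one membership test, while remote URLs go through the SSRF-safe proxy. A minimal self-contained sketch of that dispatch shape — the enum values mirror the diff, but the stub helpers are stand-ins for _download_file_content and ssrf_proxy, not Dify's real implementations:

from enum import Enum

class TransferMethod(Enum):
    TOOL_FILE = "tool_file"
    LOCAL_FILE = "local_file"
    DATASOURCE_FILE = "datasource_file"
    REMOTE_URL = "remote_url"

def _read_from_storage(key: str) -> bytes:  # stand-in for _download_file_content
    return b"stored-bytes"

def _http_get(url: str) -> bytes:  # stand-in for ssrf_proxy.get(...)
    return b"remote-bytes"

def download(method: TransferMethod, storage_key: str = "", url: str = "") -> bytes:
    # Storage-backed methods share one download path.
    if method in (
        TransferMethod.TOOL_FILE,
        TransferMethod.LOCAL_FILE,
        TransferMethod.DATASOURCE_FILE,
    ):
        return _read_from_storage(storage_key)
    elif method == TransferMethod.REMOTE_URL:
        return _http_get(url)
    raise ValueError(f"unsupported transfer method: {method}")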

+ 1
- 1
api/core/plugin/impl/datasource.py

@@ -56,7 +56,7 @@ class PluginDatasourceManager(BasePluginClient):
                tool.identity.provider = provider.declaration.identity.name

        return all_response
    def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
        """
        Fetch datasource providers for the given tenant.

+ 0
- 1
api/core/rag/models/document.py

@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
    answer: str



class QAStructureChunk(BaseModel):
    """
    QAStructureChunk.

+ 15
- 14
api/services/datasource_provider_service.py

@@ -131,7 +131,6 @@ class DatasourceProviderService:
        )

        return copy_credentials_list

    def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
        """
@@ -144,19 +143,21 @@ class DatasourceProviderService:
        datasources = manager.fetch_installed_datasource_providers(tenant_id)
        datasource_credentials = []
        for datasource in datasources:
            credentials = self.get_datasource_credentials(tenant_id=tenant_id,
                                                          provider=datasource.provider,
                                                          plugin_id=datasource.plugin_id)
            datasource_credentials.append({
                "provider": datasource.provider,
                "plugin_id": datasource.plugin_id,
                "plugin_unique_identifier": datasource.plugin_unique_identifier,
                "icon": datasource.declaration.identity.icon,
                "name": datasource.declaration.identity.name,
                "description": datasource.declaration.identity.description.model_dump(),
                "author": datasource.declaration.identity.author,
                "credentials": credentials,
            })
            credentials = self.get_datasource_credentials(
                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
            )
            datasource_credentials.append(
                {
                    "provider": datasource.provider,
                    "plugin_id": datasource.plugin_id,
                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
                    "icon": datasource.declaration.identity.icon,
                    "name": datasource.declaration.identity.name,
                    "description": datasource.declaration.identity.description.model_dump(),
                    "author": datasource.declaration.identity.author,
                    "credentials": credentials,
                }
            )
        return datasource_credentials

    def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
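
get_all_datasource_credentials builds one entry per installed provider, combining identity metadata from the plugin declaration with the tenant's configured credentials. A reduced sketch of that aggregation, using a hypothetical ProviderInfo in place of Dify's PluginDatasourceProviderEntity and a lookup callable in place of get_datasource_credentials:

from dataclasses import dataclass
from typing import Callable

@dataclass
class ProviderInfo:  # hypothetical stand-in for PluginDatasourceProviderEntity
    provider: str
    plugin_id: str
    name: str

def collect_credentials(
    providers: list[ProviderInfo],
    get_creds: Callable[[str, str], list[dict]],
) -> list[dict]:
    out = []
    for p in providers:
        out.append(
            {
                "provider": p.provider,
                "plugin_id": p.plugin_id,
                "name": p.name,
                "credentials": get_creds(p.provider, p.plugin_id),
            }
        )
    return out

# Example: one provider with no credentials configured yet.
print(collect_credentials(
    [ProviderInfo("example_drive", "vendor/example_drive", "Example Drive")],
    lambda provider, plugin_id: [],
))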
