
r2

tags/2.0.0-beta.1
jyong 5 months ago
parent
Commit
9cdd2cbb27
35 changed files with 230 additions and 301 deletions
  1. api/app.py (+13, -14)
  2. api/controllers/console/auth/data_source_oauth.py (+0, -2)
  3. api/controllers/console/auth/oauth.py (+1, -7)
  4. api/controllers/console/datasets/rag_pipeline/datasource_auth.py (+21, -35)
  5. api/controllers/console/datasets/rag_pipeline/rag_pipeline.py (+2, -0)
  6. api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py (+0, -1)
  7. api/core/app/apps/pipeline/pipeline_generator.py (+17, -21)
  8. api/core/app/apps/pipeline/pipeline_runner.py (+5, -2)
  9. api/core/app/entities/app_invoke_entities.py (+1, -0)
  10. api/core/datasource/__base/datasource_runtime.py (+0, -1)
  11. api/core/datasource/datasource_manager.py (+1, -3)
  12. api/core/plugin/impl/datasource.py (+14, -33)
  13. api/core/rag/datasource/keyword/jieba/jieba.py (+12, -12)
  14. api/core/rag/index_processor/index_processor_base.py (+1, -1)
  15. api/core/rag/index_processor/processor/paragraph_index_processor.py (+3, -6)
  16. api/core/rag/index_processor/processor/parent_child_index_processor.py (+4, -12)
  17. api/core/rag/index_processor/processor/qa_index_processor.py (+5, -4)
  18. api/core/variables/variables.py (+10, -3)
  19. api/core/workflow/entities/variable_pool.py (+1, -1)
  20. api/core/workflow/entities/workflow_node_execution.py (+1, -0)
  21. api/core/workflow/graph_engine/entities/graph.py (+1, -2)
  22. api/core/workflow/graph_engine/graph_engine.py (+4, -4)
  23. api/core/workflow/nodes/datasource/datasource_node.py (+30, -31)
  24. api/core/workflow/nodes/knowledge_index/knowledge_index_node.py (+9, -6)
  25. api/models/dataset.py (+1, -1)
  26. api/models/oauth.py (+1, -1)
  27. api/services/dataset_service.py (+4, -8)
  28. api/services/datasource_provider_service.py (+34, -33)
  29. api/services/entities/knowledge_entities/rag_pipeline_entities.py (+1, -0)
  30. api/services/rag_pipeline/pipeline_generate_service.py (+0, -1)
  31. api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py (+1, -4)
  32. api/services/rag_pipeline/pipeline_template/database/database_retrieval.py (+0, -1)
  33. api/services/rag_pipeline/rag_pipeline.py (+16, -22)
  34. api/services/rag_pipeline/rag_pipeline_dsl_service.py (+12, -25)
  35. api/services/rag_pipeline/rag_pipeline_manage_service.py (+4, -4)

api/app.py (+13, -14)

@@ -1,4 +1,3 @@
import os
import sys


@@ -18,19 +17,19 @@ else:
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
#from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()

from app_factory import create_app
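
A minimal sketch (not part of this commit) of how the commented-out gevent setup above reads when re-enabled and guarded by the FLASK_DEBUG check the surrounding comment describes; the module names and calls are taken verbatim from the comment block:

import os

if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey

    # patch the standard library for cooperative I/O under gevent
    monkey.patch_all()

    from grpc.experimental import gevent as grpc_gevent  # type: ignore

    # make grpc cooperate with gevent
    grpc_gevent.init_gevent()

    import psycogreen.gevent  # type: ignore

    # make psycopg2 cooperative as well
    psycogreen.gevent.patch_psycopg()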


api/controllers/console/auth/data_source_oauth.py (+0, -2)

@@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200




api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")

api/controllers/console/auth/oauth.py (+1, -7)

@@ -4,24 +4,19 @@ from typing import Optional

import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Unauthorized

from configs import dify_config
from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
@@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account



api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

api/controllers/console/datasets/rag_pipeline/datasource_auth.py (+21, -35)

@@ -1,12 +1,9 @@

from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
marshal_with,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
@@ -16,7 +13,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
@@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
@@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=system_credentials
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()


class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
oauth_handler = OAuthHandler()
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
@@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id,
provider=provider,
auth_type="oauth",
encrypted_credentials=credentials
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")


class DatasourceAuth(Resource):
@setup_required
@login_required
@@ -99,28 +88,27 @@ class DatasourceAuth(Resource):

try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"]
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))

return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": datasources}, 200


class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": "success"}, 200


# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
@@ -149,4 +136,3 @@ api.add_resource(
DatasourceAuth,
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
)


api/controllers/console/datasets/rag_pipeline/rag_pipeline.py (+2, -0)

@@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200


class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200


api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",

api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py (+0, -1)

@@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, pipeline_id):

return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}

api/core/app/apps/pipeline/pipeline_generator.py (+17, -21)

@@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
@@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

@overload
def generate(
@@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]:
...
) -> Mapping[str, Any]: ...

@overload
def generate(
@@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
...
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

def generate(
self,
@@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator):
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump() for document in documents
]
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}

def _generate(
@@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
)

# new thread
worker_thread = threading.Thread(
target=worker_with_context
)
worker_thread = threading.Thread(target=worker_with_context)

worker_thread.start()


api/core/app/apps/pipeline/pipeline_runner.py (+5, -2)

@@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
if (
rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
and rag_pipeline_variable.variable in inputs
):
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]

variable_pool = VariablePool(
@@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids :
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)

api/core/app/entities/app_invoke_entities.py (+1, -0)

@@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""

# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str

api/core/datasource/__base/datasource_runtime.py (+0, -1)

@@ -5,7 +5,6 @@ from pydantic import Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom


class DatasourceRuntime(BaseModel):

api/core/datasource/datasource_manager.py (+1, -3)

@@ -46,7 +46,7 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")

match (datasource_type):
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
@@ -98,5 +98,3 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)



api/core/plugin/impl/datasource.py (+14, -33)

@@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient):
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},

)

for resp in response:
@@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient):
"identity": {
"author": "langgenius",
"name": "langgenius/file/file",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
},
"parameters": [],
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
"datasources": [
{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
}
}]
}
],
},
}

api/core/rag/datasource/keyword/jieba/jieba.py (+12, -12)

@@ -28,12 +28,12 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)

for text in texts:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@@ -51,19 +51,17 @@ class Jieba(BaseKeyword):

keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
else:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@@ -242,7 +240,9 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)

keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)

api/core/rag/index_processor/index_processor_base.py (+1, -1)

@@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError
@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError

api/core/rag/index_processor/processor/paragraph_index_processor.py (+3, -6)

@@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule


@@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)


def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {
"preview": preview,
"total_segments": len(paragraph.general_chunks)
}
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

api/core/rag/index_processor/processor/parent_child_index_processor.py (+4, -12)

@@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule


@@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append(
{
"content": parent_child.parent_content,
"child_chunks": parent_child.child_contents

}
)
return {
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks)
}
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

api/core/rag/index_processor/processor/qa_index_processor.py (+5, -4)

@@ -4,7 +4,8 @@ import logging
import re
import threading
import uuid
from typing import Any, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional

import pandas as pd
from flask import Flask, current_app
@@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule


@@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
pass
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
return {"preview": chunks}


api/core/variables/variables.py (+10, -3)

@@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass


class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
max_length: int | None = Field(
description="max length, applicable to text-input, paragraph, and file-list", default=0
)
default_value: str | None = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
allowed_file_types: list[str] | None = Field(
description="image, document, audio, video, custom.", default_factory=list
)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(
description="remote_url, local_file, tool_file.", default_factory=list
)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
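
An illustrative usage sketch (not part of the commit) of the RAGPipelineVariable model above; the field values are hypothetical and the import assumes the api/ package root is on the path:

from core.variables.variables import RAGPipelineVariable

title_var = RAGPipelineVariable(
    belong_to_node_id="shared",  # "shared" means public, per the field description
    type="text-input",
    label="Title",
    variable="title",
    max_length=256,
    required=True,
)
print(title_var.model_dump())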

api/core/workflow/entities/variable_pool.py (+1, -1)

@@ -49,7 +49,7 @@ class VariablePool(BaseModel):
)
rag_pipeline_variables: Mapping[str, Any] = Field(
description="RAG pipeline variables.",
default_factory=dict,
default_factory=dict,
)

def __init__(

api/core/workflow/entities/workflow_node_execution.py (+1, -0)

@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
DATASOURCE_INFO = "datasource_info"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"

api/core/workflow/graph_engine/entities/graph.py (+1, -2)

@@ -122,7 +122,6 @@ class Graph(BaseModel):
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}

for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
@@ -142,7 +141,7 @@ class Graph(BaseModel):
(
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
),
None,

api/core/workflow/graph_engine/graph_engine.py (+4, -4)

@@ -317,10 +317,10 @@ class GraphEngine:
raise e

# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
):
if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
NodeType.END.value,
NodeType.KNOWLEDGE_INDEX.value,
]:
break

previous_route_node_state = route_node_state

api/core/workflow/nodes/datasource/datasource_node.py (+30, -31)

@@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError

@@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
try:
from core.datasource.datasource_manager import DatasourceManager


if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")

@@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
except DatasourceNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)

status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)

# get parameters
datasource_parameters = datasource_runtime.entity.parameters
@@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,
@@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_type,
"website": datasource_info,
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError(
"File is not exist"
)
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
@@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
variable_pool=variable_pool,
node_id=self.node_id,
variable_key_list=new_key_list,
variable_value=value,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_type}"
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
@@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
@@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []


def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
def _append_variables_recursively(
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
):
"""
Append variables recursively
:param node_id: node id

api/core/workflow/nodes/knowledge_index/knowledge_index_node.py (+9, -6)

@@ -6,7 +6,6 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
variable_pool=variable_pool)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
@@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)

def _invoke_knowledge_index(
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
variable_pool: VariablePool
self,
dataset: Dataset,
node_data: KnowledgeIndexNodeData,
chunks: Mapping[str, Any],
variable_pool: VariablePool,
) -> Any:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
@@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
#update document segment status
# update document segment status
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,

api/models/dataset.py (+1, -1)

@@ -208,6 +208,7 @@ class Dataset(Base):
"external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}

@property
def is_published(self):
if self.pipeline_id:
@@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())



class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (

api/models/oauth.py (+1, -1)

@@ -1,4 +1,3 @@

from datetime import datetime

from sqlalchemy.dialects.postgresql import JSONB
@@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)


class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (

api/services/dataset_service.py (+4, -8)

@@ -1,4 +1,3 @@
from calendar import day_abbr
import copy
import datetime
import json
@@ -7,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional, cast
from typing import Any, Optional

from flask_login import current_user
from sqlalchemy import func, select
@@ -282,7 +281,6 @@ class DatasetService:
db.session.commit()
return dataset


@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@@ -494,10 +492,9 @@ class DatasetService:
return dataset

@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
@@ -616,7 +613,6 @@ class DatasetService:
if action:
deal_dataset_index_update_task.delay(dataset.id, action)


@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)

api/services/datasource_provider_service.py (+34, -33)

@@ -1,5 +1,4 @@
import logging
from typing import Optional

from flask_login import current_user

@@ -22,11 +21,9 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()

def datasource_provider_credentials_validate(self,
tenant_id: str,
provider: str,
plugin_id: str,
credentials: dict) -> None:
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
) -> None:
"""
validate datasource provider credentials.

@@ -34,29 +31,30 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
credentials=credentials)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)

provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
provider=provider
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
@@ -101,11 +99,15 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id
).all()
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
@@ -128,10 +130,7 @@ class DatasourceProviderService:

return copy_credentials_list

def remove_datasource_credentials(self,
tenant_id: str,
provider: str,
plugin_id: str) -> None:
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.

@@ -140,9 +139,11 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()

api/services/entities/knowledge_entities/rag_pipeline_entities.py (+1, -0)

@@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""

chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""

api/services/rag_pipeline/pipeline_generate_service.py (+0, -1)

@@ -3,7 +3,6 @@ from typing import Any, Union

from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser

api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py (+1, -4)

@@ -1,13 +1,12 @@
from typing import Optional

from flask_login import current_user
import yaml
from flask_login import current_user

from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:

recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
@@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

return {"pipeline_templates": recommended_pipelines_results}


@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
"""

api/services/rag_pipeline/pipeline_template/database/database_retrieval.py (+0, -1)

@@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:

recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,

api/services/rag_pipeline/rag_pipeline.py (+16, -22)

@@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi

class RagPipelineService:
@classmethod
def get_pipeline_templates(
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -308,7 +306,7 @@ class RagPipelineService:
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
has_published=pipeline.is_published,
)
# return new workflow
return workflow
@@ -444,12 +442,10 @@ class RagPipelineService:
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
@@ -470,7 +466,6 @@ class RagPipelineService:
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")


def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@@ -689,8 +684,8 @@ class RagPipelineService:
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)

if args.get("last_id"):
@@ -763,18 +758,17 @@ class RagPipelineService:

# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
)

# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
node_executions = repository.get_by_workflow_run(
workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

return workflow_node_executions

api/services/rag_pipeline/rag_pipeline_dsl_service.py (+12, -25)

@@ -279,7 +279,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(
@@ -304,8 +308,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@@ -323,12 +326,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@@ -443,8 +442,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@@ -462,12 +460,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@@ -538,7 +532,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))


# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@@ -554,7 +547,6 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])


graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -576,7 +568,6 @@ class RagPipelineDslService:
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id

else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
@@ -636,7 +627,6 @@ class RagPipelineDslService:
# commit db session changes
db.session.commit()


return pipeline

@classmethod
@@ -874,7 +864,6 @@ class RagPipelineDslService:
except Exception:
return None


@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
@@ -886,9 +875,7 @@ class RagPipelineDslService:
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")

with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)

api/services/rag_pipeline/rag_pipeline_manage_service.py (+4, -4)

@@ -12,12 +12,12 @@ class RagPipelineManageService:

# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources
