| @@ -1,4 +1,3 @@ | |||
| import os | |||
| import sys | |||
| @@ -18,19 +17,19 @@ else: | |||
| # so we need to disable gevent in debug mode. | |||
| # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. | |||
| # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: | |||
| # from gevent import monkey | |||
| # | |||
| # # gevent | |||
| # monkey.patch_all() | |||
| # | |||
| # from grpc.experimental import gevent as grpc_gevent # type: ignore | |||
| # | |||
| # # grpc gevent | |||
| # grpc_gevent.init_gevent() | |||
| # import psycogreen.gevent # type: ignore | |||
| # | |||
| # psycogreen.gevent.patch_psycopg() | |||
| from app_factory import create_app | |||
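| If the gevent patching above is ever re-enabled behind the debug check described in the comments, a minimal sketch of the guard (same FLASK_DEBUG convention as the commented condition; illustrative only): | |||
| if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: | |||
|     from gevent import monkey | |||
|     # gevent: patch the standard library for cooperative I/O | |||
|     monkey.patch_all() | |||
|     from grpc.experimental import gevent as grpc_gevent  # type: ignore | |||
|     # grpc gevent: let gRPC calls yield to the gevent loop | |||
|     grpc_gevent.init_gevent() | |||
|     import psycogreen.gevent  # type: ignore | |||
|     # psycogreen: make psycopg2 cooperate with gevent | |||
|     psycogreen.gevent.patch_psycopg() | |||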
| @@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource): | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") | |||
| api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") | |||
| api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") | |||
| @@ -4,24 +4,19 @@ from typing import Optional | |||
| import requests | |||
| from flask import current_app, redirect, request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, NotFound, Unauthorized | |||
| from werkzeug.exceptions import Unauthorized | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from events.tenant_event import tenant_was_created | |||
| from extensions.ext_database import db | |||
| from libs.helper import extract_remote_ip | |||
| from libs.login import login_required | |||
| from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |||
| from models import Account | |||
| from models.account import AccountStatus | |||
| from models.oauth import DatasourceOauthParamConfig, DatasourceProvider | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| from services.errors.account import AccountNotFoundError, AccountRegisterError | |||
| from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError | |||
| @@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): | |||
| return account | |||
| api.add_resource(OAuthLogin, "/oauth/login/<provider>") | |||
| api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") | |||
| @@ -1,12 +1,9 @@ | |||
| from flask import redirect, request | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import ( # type: ignore | |||
| Resource, # type: ignore | |||
| marshal_with, | |||
| reqparse, | |||
| ) | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| from configs import dify_config | |||
| @@ -16,7 +13,6 @@ from controllers.console.wraps import ( | |||
| setup_required, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from extensions.ext_database import db | |||
| from libs.login import login_required | |||
| @@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # get all plugin oauth configs | |||
| plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( | |||
| provider=provider, | |||
| plugin_id=plugin_id | |||
| ).first() | |||
| plugin_oauth_config = ( | |||
| db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() | |||
| ) | |||
| if not plugin_oauth_config: | |||
| raise NotFound() | |||
| oauth_handler = OAuthHandler() | |||
| @@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource): | |||
| if system_credentials: | |||
| system_credentials["redirect_url"] = redirect_url | |||
| response = oauth_handler.get_authorization_url( | |||
| current_user.current_tenant.id, | |||
| current_user.id, | |||
| plugin_id, | |||
| provider, | |||
| system_credentials=system_credentials | |||
| current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials | |||
| ) | |||
| return response.model_dump() | |||
| class DatasourceOauthCallback(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider, plugin_id): | |||
| oauth_handler = OAuthHandler() | |||
| plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( | |||
| provider=provider, | |||
| plugin_id=plugin_id | |||
| ).first() | |||
| plugin_oauth_config = ( | |||
| db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() | |||
| ) | |||
| if not plugin_oauth_config: | |||
| raise NotFound() | |||
| credentials = oauth_handler.get_credentials( | |||
| @@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource): | |||
| plugin_id, | |||
| provider, | |||
| system_credentials=plugin_oauth_config.system_credentials, | |||
| request=request | |||
| request=request, | |||
| ) | |||
| datasource_provider = DatasourceProvider( | |||
| plugin_id=plugin_id, | |||
| provider=provider, | |||
| auth_type="oauth", | |||
| encrypted_credentials=credentials | |||
| plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials | |||
| ) | |||
| db.session.add(datasource_provider) | |||
| db.session.commit() | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}") | |||
| class DatasourceAuth(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -99,28 +88,27 @@ class DatasourceAuth(Resource): | |||
| try: | |||
| datasource_provider_service.datasource_provider_credentials_validate( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id, | |||
| credentials=args["credentials"] | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id, | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider, plugin_id): | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasources = datasource_provider_service.get_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id | |||
| tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id | |||
| ) | |||
| return {"result": datasources}, 200 | |||
| class DatasourceAuthDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource): | |||
| raise Forbidden() | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.remove_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id | |||
| tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| # Import Rag Pipeline | |||
| api.add_resource( | |||
| DatasourcePluginOauthApi, | |||
| @@ -149,4 +136,3 @@ api.add_resource( | |||
| DatasourceAuth, | |||
| "/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>", | |||
| ) | |||
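| For illustration, a hedged client-side sketch of calling the credential endpoint registered above; the base URL, auth header, and provider/plugin/credential values are assumptions, not taken from this change: | |||
| import requests | |||
| BASE = "https://example.com/console/api"  # assumed console API prefix | |||
| HEADERS = {"Authorization": "Bearer <console-session-token>"}  # assumed auth scheme | |||
| URL = f"{BASE}/auth/datasource/provider/notion/plugin/langgenius-notion"  # placeholder provider and plugin id | |||
| # POST validates and saves API-key credentials; the handler returns {"result": "success"} with 201 | |||
| resp = requests.post(URL, json={"credentials": {"api_key": "sk-..."}}, headers=HEADERS) | |||
| assert resp.status_code == 201 | |||
| # GET lists the credentials stored for the same provider; returns {"result": [...]} with 200 | |||
| print(requests.get(URL, headers=HEADERS).json()) | |||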
| @@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource): | |||
| dsl = yaml.safe_load(template.yaml_content) | |||
| return {"data": dsl}, 200 | |||
| class CustomizedPipelineTemplateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource): | |||
| RagPipelineService.publish_customized_pipeline_template(pipeline_id, args) | |||
| return 200 | |||
| api.add_resource( | |||
| PipelineTemplateListApi, | |||
| "/rag/pipeline/templates", | |||
| @@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, pipeline_id): | |||
| return { | |||
| "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, | |||
| } | |||
| @@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from extensions.ext_database import db | |||
| from fields.document_fields import dataset_and_document_fields | |||
| from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.dataset import Document, Pipeline | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| @@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator): | |||
| streaming: Literal[True], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: | |||
| ... | |||
| ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... | |||
| @overload | |||
| def generate( | |||
| @@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator): | |||
| streaming: Literal[False], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any]: | |||
| ... | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator): | |||
| streaming: bool, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: | |||
| ... | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| @@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator): | |||
| description=dataset.description, | |||
| chunk_structure=dataset.chunk_structure, | |||
| ).model_dump(), | |||
| "documents": [PipelineDocument( | |||
| id=document.id, | |||
| position=document.position, | |||
| data_source_type=document.data_source_type, | |||
| data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, | |||
| name=document.name, | |||
| indexing_status=document.indexing_status, | |||
| error=document.error, | |||
| enabled=document.enabled, | |||
| ).model_dump() for document in documents | |||
| ] | |||
| "documents": [ | |||
| PipelineDocument( | |||
| id=document.id, | |||
| position=document.position, | |||
| data_source_type=document.data_source_type, | |||
| data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, | |||
| name=document.name, | |||
| indexing_status=document.indexing_status, | |||
| error=document.error, | |||
| enabled=document.enabled, | |||
| ).model_dump() | |||
| for document in documents | |||
| ], | |||
| } | |||
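| Illustratively, each entry in the "documents" list above serializes to a plain mapping like the following (all values are placeholders): | |||
| { | |||
|     "id": "doc-uuid", | |||
|     "position": 1, | |||
|     "data_source_type": "upload_file",  # placeholder type | |||
|     "data_source_info": {"related_id": "file-uuid"}, | |||
|     "name": "example.pdf", | |||
|     "indexing_status": "indexing", | |||
|     "error": None, | |||
|     "enabled": True, | |||
| } | |||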
| def _generate( | |||
| @@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator): | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=worker_with_context | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| if workflow.rag_pipeline_variables: | |||
| for v in workflow.rag_pipeline_variables: | |||
| rag_pipeline_variable = RAGPipelineVariable(**v) | |||
| if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs: | |||
| if ( | |||
| rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id | |||
| and rag_pipeline_variable.variable in inputs | |||
| ): | |||
| rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable] | |||
| variable_pool = VariablePool( | |||
| @@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| continue | |||
| real_run_nodes.append(node) | |||
| for edge in edges: | |||
| if edge.get("source") in exclude_node_ids : | |||
| if edge.get("source") in exclude_node_ids: | |||
| continue | |||
| real_edges.append(edge) | |||
| graph_config = dict(graph_config) | |||
| @@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): | |||
| """ | |||
| RAG Pipeline Application Generate Entity. | |||
| """ | |||
| # pipeline config | |||
| pipeline_config: WorkflowUIBasedAppConfig | |||
| datasource_type: str | |||
| @@ -5,7 +5,6 @@ from pydantic import Field | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.datasource.entities.datasource_entities import DatasourceInvokeFrom | |||
| from core.tools.entities.tool_entities import ToolInvokeFrom | |||
| class DatasourceRuntime(BaseModel): | |||
| @@ -46,7 +46,7 @@ class DatasourceManager: | |||
| if not provider_entity: | |||
| raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") | |||
| match (datasource_type): | |||
| match datasource_type: | |||
| case DatasourceProviderType.ONLINE_DOCUMENT: | |||
| controller = OnlineDocumentDatasourcePluginProviderController( | |||
| entity=provider_entity.declaration, | |||
| @@ -98,5 +98,3 @@ class DatasourceManager: | |||
| tenant_id, | |||
| datasource_type, | |||
| ).get_datasource(datasource_name) | |||
| @@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient): | |||
| "X-Plugin-ID": datasource_provider_id.plugin_id, | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| @@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient): | |||
| "identity": { | |||
| "author": "langgenius", | |||
| "name": "langgenius/file/file", | |||
| "label": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| }, | |||
| "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||
| "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", | |||
| "description": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| } | |||
| "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||
| }, | |||
| "credentials_schema": [], | |||
| "provider_type": "local_file", | |||
| "datasources": [{ | |||
| "identity": { | |||
| "author": "langgenius", | |||
| "name": "upload-file", | |||
| "provider": "langgenius", | |||
| "label": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| } | |||
| }, | |||
| "parameters": [], | |||
| "description": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| "datasources": [ | |||
| { | |||
| "identity": { | |||
| "author": "langgenius", | |||
| "name": "upload-file", | |||
| "provider": "langgenius", | |||
| "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||
| }, | |||
| "parameters": [], | |||
| "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||
| } | |||
| }] | |||
| } | |||
| ], | |||
| }, | |||
| } | |||
| @@ -28,12 +28,12 @@ class Jieba(BaseKeyword): | |||
| with redis_client.lock(lock_name, timeout=600): | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| keyword_number = ( | |||
| self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| ) | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords( | |||
| text.page_content, keyword_number | |||
| ) | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) | |||
| if text.metadata is not None: | |||
| self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table( | |||
| @@ -51,19 +51,17 @@ class Jieba(BaseKeyword): | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keywords_list = kwargs.get("keywords_list") | |||
| keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| keyword_number = ( | |||
| self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| ) | |||
| for i in range(len(texts)): | |||
| text = texts[i] | |||
| if keywords_list: | |||
| keywords = keywords_list[i] | |||
| if not keywords: | |||
| keywords = keyword_table_handler.extract_keywords( | |||
| text.page_content, keyword_number | |||
| ) | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) | |||
| else: | |||
| keywords = keyword_table_handler.extract_keywords( | |||
| text.page_content, keyword_number | |||
| ) | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) | |||
| if text.metadata is not None: | |||
| self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table( | |||
| @@ -242,7 +240,9 @@ class Jieba(BaseKeyword): | |||
| keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] | |||
| ) | |||
| else: | |||
| keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| keyword_number = ( | |||
| self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk | |||
| ) | |||
| keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number) | |||
| segment.keywords = list(keywords) | |||
| @@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC): | |||
| @abstractmethod | |||
| def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: | |||
| raise NotImplementedError | |||
| @@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import Document, GeneralStructureChunk | |||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | |||
| from libs import helper | |||
| from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||
| @@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| keyword = Keyword(dataset) | |||
| keyword.add_texts(documents) | |||
| def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: | |||
| paragraph = GeneralStructureChunk(**chunks) | |||
| preview = [] | |||
| for content in paragraph.general_chunks: | |||
| preview.append({"content": content}) | |||
| return { | |||
| "preview": preview, | |||
| "total_segments": len(paragraph.general_chunks) | |||
| } | |||
| return {"preview": preview, "total_segments": len(paragraph.general_chunks)} | |||
| @@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment | |||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||
| @@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| parent_childs = ParentChildStructureChunk(**chunks) | |||
| preview = [] | |||
| for parent_child in parent_childs.parent_child_chunks: | |||
| preview.append( | |||
| { | |||
| "content": parent_child.parent_content, | |||
| "child_chunks": parent_child.child_contents | |||
| } | |||
| ) | |||
| return { | |||
| "preview": preview, | |||
| "total_segments": len(parent_childs.parent_child_chunks) | |||
| } | |||
| preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) | |||
| return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)} | |||
| @@ -4,7 +4,8 @@ import logging | |||
| import re | |||
| import threading | |||
| import uuid | |||
| from typing import Any, Mapping, Optional | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| import pandas as pd | |||
| from flask import Flask, current_app | |||
| @@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import Document | |||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | |||
| from libs import helper | |||
| from models.dataset import Dataset, Document as DatasetDocument | |||
| from models.dataset import Dataset | |||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||
| @@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): | |||
| pass | |||
| def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: | |||
| return {"preview": chunks} | |||
| @@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable): | |||
| class ArrayFileVariable(ArrayFileSegment, ArrayVariable): | |||
| pass | |||
| class RAGPipelineVariable(BaseModel): | |||
| belong_to_node_id: str = Field(description="belong to which node id, shared means public") | |||
| type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") | |||
| label: str = Field(description="label") | |||
| description: str | None = Field(description="description", default="") | |||
| variable: str = Field(description="variable key", default="") | |||
| max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0) | |||
| max_length: int | None = Field( | |||
| description="max length, applicable to text-input, paragraph, and file-list", default=0 | |||
| ) | |||
| default_value: str | None = Field(description="default value", default="") | |||
| placeholder: str | None = Field(description="placeholder", default="") | |||
| unit: str | None = Field(description="unit, applicable to Number", default="") | |||
| tooltips: str | None = Field(description="helpful text", default="") | |||
| allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list) | |||
| allowed_file_types: list[str] | None = Field( | |||
| description="image, document, audio, video, custom.", default_factory=list | |||
| ) | |||
| allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) | |||
| allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list) | |||
| allowed_file_upload_methods: list[str] | None = Field( | |||
| description="remote_url, local_file, tool_file.", default_factory=list | |||
| ) | |||
| required: bool = Field(description="optional, default false", default=False) | |||
| options: list[str] | None = Field(default_factory=list) | |||
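| A minimal construction example for the model above (all values are illustrative): | |||
| title_var = RAGPipelineVariable( | |||
|     belong_to_node_id="shared",  # "shared" marks the variable as public, per the field description | |||
|     type="text-input", | |||
|     label="Document title", | |||
|     variable="title", | |||
|     max_length=256, | |||
|     required=True, | |||
| ) | |||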
| @@ -49,7 +49,7 @@ class VariablePool(BaseModel): | |||
| ) | |||
| rag_pipeline_variables: Mapping[str, Any] = Field( | |||
| description="RAG pipeline variables.", | |||
| default_factory=dict, | |||
| ) | |||
| def __init__( | |||
| @@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): | |||
| AGENT_LOG = "agent_log" | |||
| ITERATION_ID = "iteration_id" | |||
| ITERATION_INDEX = "iteration_index" | |||
| DATASOURCE_INFO = "datasource_info" | |||
| LOOP_ID = "loop_id" | |||
| LOOP_INDEX = "loop_index" | |||
| PARALLEL_ID = "parallel_id" | |||
| @@ -122,7 +122,6 @@ class Graph(BaseModel): | |||
| root_node_configs = [] | |||
| all_node_id_config_mapping: dict[str, dict] = {} | |||
| for node_config in node_configs: | |||
| node_id = node_config.get("id") | |||
| if not node_id: | |||
| @@ -142,7 +141,7 @@ class Graph(BaseModel): | |||
| ( | |||
| node_config.get("id") | |||
| for node_config in root_node_configs | |||
| if node_config.get("data", {}).get("type", "") == NodeType.START.value | |||
| if node_config.get("data", {}).get("type", "") == NodeType.START.value | |||
| or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value | |||
| ), | |||
| None, | |||
| @@ -317,10 +317,10 @@ class GraphEngine: | |||
| raise e | |||
| # Stop traversal once the next node is an END or KNOWLEDGE_INDEX node. | |||
| if ( | |||
| self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() | |||
| in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value] | |||
| ): | |||
| if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [ | |||
| NodeType.END.value, | |||
| NodeType.KNOWLEDGE_INDEX.value, | |||
| ]: | |||
| break | |||
| previous_route_node_state = route_node_state | |||
| @@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen | |||
| from core.file import File | |||
| from core.file.enums import FileTransferMethod, FileType | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.variables.segments import ArrayAnySegment, FileSegment | |||
| from core.variables.segments import ArrayAnySegment | |||
| from core.variables.variables import ArrayAnyVariable | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from models.model import UploadFile | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from .entities import DatasourceNodeData | |||
| from .exc import DatasourceNodeError, DatasourceParameterError | |||
| @@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| try: | |||
| from core.datasource.datasource_manager import DatasourceManager | |||
| if datasource_type is None: | |||
| raise DatasourceNodeError("Datasource type is not set") | |||
| @@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| ) | |||
| except DatasourceNodeError as e: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs={}, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| error=f"Failed to get datasource runtime: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs={}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| error=f"Failed to get datasource runtime: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| # get parameters | |||
| datasource_parameters = datasource_runtime.entity.parameters | |||
| @@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| outputs={ | |||
| "online_document": online_document_result.result.model_dump(), | |||
| "datasource_type": datasource_type, | |||
| @@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| outputs={ | |||
| "website": datasource_info, | |||
| "datasource_type": datasource_type, | |||
| "website": datasource_info, | |||
| "datasource_type": datasource_type, | |||
| }, | |||
| ) | |||
| case DatasourceProviderType.LOCAL_FILE: | |||
| related_id = datasource_info.get("related_id") | |||
| if not related_id: | |||
| raise DatasourceNodeError( | |||
| "File is not exist" | |||
| ) | |||
| raise DatasourceNodeError("File is not exist") | |||
| upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() | |||
| if not upload_file: | |||
| raise ValueError("Invalid upload file Info") | |||
| @@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| # construct new key list | |||
| new_key_list = ["file", key] | |||
| self._append_variables_recursively( | |||
| variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value | |||
| variable_pool=variable_pool, | |||
| node_id=self.node_id, | |||
| variable_key_list=new_key_list, | |||
| variable_value=value, | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| outputs={ | |||
| "file_info": datasource_info, | |||
| "datasource_type": datasource_type, | |||
| }, | |||
| ) | |||
| case _: | |||
| raise DatasourceNodeError( | |||
| f"Unsupported datasource provider: {datasource_type}" | |||
| "file_info": datasource_info, | |||
| "datasource_type": datasource_type, | |||
| }, | |||
| ) | |||
| case _: | |||
| raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") | |||
| except PluginDaemonClientSideError as e: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| error=f"Failed to transform datasource message: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| error=f"Failed to invoke datasource: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) | |||
| return list(variable.value) if variable else [] | |||
| def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue): | |||
| def _append_variables_recursively( | |||
| self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | |||
| ): | |||
| """ | |||
| Append variables recursively | |||
| :param node_id: node id | |||
| @@ -6,7 +6,6 @@ from typing import Any, cast | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables.segments import ObjectSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| @@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): | |||
| process_data=None, | |||
| outputs=outputs, | |||
| ) | |||
| results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks, | |||
| variable_pool=variable_pool) | |||
| results = self._invoke_knowledge_index( | |||
| dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results | |||
| ) | |||
| @@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): | |||
| ) | |||
| def _invoke_knowledge_index( | |||
| self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], | |||
| variable_pool: VariablePool | |||
| self, | |||
| dataset: Dataset, | |||
| node_data: KnowledgeIndexNodeData, | |||
| chunks: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| ) -> Any: | |||
| document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) | |||
| if not document_id: | |||
| @@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): | |||
| document.indexing_status = "completed" | |||
| document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| #update document segment status | |||
| # update document segment status | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| @@ -208,6 +208,7 @@ class Dataset(Base): | |||
| "external_knowledge_api_name": external_knowledge_api.name, | |||
| "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | |||
| } | |||
| @property | |||
| def is_published(self): | |||
| if self.pipeline_id: | |||
| @@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] | |||
| __tablename__ = "pipeline_customized_templates" | |||
| __table_args__ = ( | |||
| @@ -1,4 +1,3 @@ | |||
| from datetime import datetime | |||
| from sqlalchemy.dialects.postgresql import JSONB | |||
| @@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] | |||
| provider: Mapped[str] = db.Column(db.String(255), nullable=False) | |||
| system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) | |||
| class DatasourceProvider(Base): | |||
| __tablename__ = "datasource_providers" | |||
| __table_args__ = ( | |||
| @@ -1,4 +1,3 @@ | |||
| from calendar import day_abbr | |||
| import copy | |||
| import datetime | |||
| import json | |||
| @@ -7,7 +6,7 @@ import random | |||
| import time | |||
| import uuid | |||
| from collections import Counter | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, Optional | |||
| from flask_login import current_user | |||
| from sqlalchemy import func, select | |||
| @@ -282,7 +281,6 @@ class DatasetService: | |||
| db.session.commit() | |||
| return dataset | |||
| @staticmethod | |||
| def get_dataset(dataset_id) -> Optional[Dataset]: | |||
| dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() | |||
| @@ -494,10 +492,9 @@ class DatasetService: | |||
| return dataset | |||
| @staticmethod | |||
| def update_rag_pipeline_dataset_settings(session: Session, | |||
| dataset: Dataset, | |||
| knowledge_configuration: KnowledgeConfiguration, | |||
| has_published: bool = False): | |||
| def update_rag_pipeline_dataset_settings( | |||
| session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False | |||
| ): | |||
| dataset = session.merge(dataset) | |||
| if not has_published: | |||
| dataset.chunk_structure = knowledge_configuration.chunk_structure | |||
| @@ -616,7 +613,6 @@ class DatasetService: | |||
| if action: | |||
| deal_dataset_index_update_task.delay(dataset.id, action) | |||
| @staticmethod | |||
| def delete_dataset(dataset_id, user): | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| @@ -1,5 +1,4 @@ | |||
| import logging | |||
| from typing import Optional | |||
| from flask_login import current_user | |||
| @@ -22,11 +21,9 @@ class DatasourceProviderService: | |||
| def __init__(self) -> None: | |||
| self.provider_manager = PluginDatasourceManager() | |||
| def datasource_provider_credentials_validate(self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| plugin_id: str, | |||
| credentials: dict) -> None: | |||
| def datasource_provider_credentials_validate( | |||
| self, tenant_id: str, provider: str, plugin_id: str, credentials: dict | |||
| ) -> None: | |||
| """ | |||
| validate datasource provider credentials. | |||
| @@ -34,29 +31,30 @@ class DatasourceProviderService: | |||
| :param provider: | |||
| :param credentials: | |||
| """ | |||
| credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id, | |||
| user_id=current_user.id, | |||
| provider=provider, | |||
| credentials=credentials) | |||
| credential_valid = self.provider_manager.validate_provider_credentials( | |||
| tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials | |||
| ) | |||
| if credential_valid: | |||
| # Get all provider configurations of the current workspace | |||
| datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id).first() | |||
| datasource_provider = ( | |||
| db.session.query(DatasourceProvider) | |||
| .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) | |||
| .first() | |||
| ) | |||
| provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, | |||
| provider=provider | |||
| ) | |||
| provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) | |||
| if not datasource_provider: | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| # if [__HIDDEN__] is sent for a secret input, it is treated the same as the original value | |||
| credentials[key] = encrypter.encrypt_token(tenant_id, value) | |||
| datasource_provider = DatasourceProvider(tenant_id=tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id, | |||
| auth_type="api_key", | |||
| encrypted_credentials=credentials) | |||
| datasource_provider = DatasourceProvider( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id, | |||
| auth_type="api_key", | |||
| encrypted_credentials=credentials, | |||
| ) | |||
| db.session.add(datasource_provider) | |||
| db.session.commit() | |||
| else: | |||
| @@ -101,11 +99,15 @@ class DatasourceProviderService: | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter( | |||
| DatasourceProvider.tenant_id == tenant_id, | |||
| DatasourceProvider.provider == provider, | |||
| DatasourceProvider.plugin_id == plugin_id | |||
| ).all() | |||
| datasource_providers: list[DatasourceProvider] = ( | |||
| db.session.query(DatasourceProvider) | |||
| .filter( | |||
| DatasourceProvider.tenant_id == tenant_id, | |||
| DatasourceProvider.provider == provider, | |||
| DatasourceProvider.plugin_id == plugin_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| if not datasource_providers: | |||
| return [] | |||
| copy_credentials_list = [] | |||
| @@ -128,10 +130,7 @@ class DatasourceProviderService: | |||
| return copy_credentials_list | |||
| def remove_datasource_credentials(self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| plugin_id: str) -> None: | |||
| def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None: | |||
| """ | |||
| remove datasource credentials. | |||
| @@ -140,9 +139,11 @@ class DatasourceProviderService: | |||
| :param plugin_id: plugin id | |||
| :return: | |||
| """ | |||
| datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, | |||
| provider=provider, | |||
| plugin_id=plugin_id).first() | |||
| datasource_provider = ( | |||
| db.session.query(DatasourceProvider) | |||
| .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) | |||
| .first() | |||
| ) | |||
| if datasource_provider: | |||
| db.session.delete(datasource_provider) | |||
| db.session.commit() | |||
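| A hedged usage sketch of the service methods reformatted above; tenant, provider, and credential values are placeholders: | |||
| service = DatasourceProviderService() | |||
| # validate and persist API-key credentials (secret fields are encrypted before storage) | |||
| service.datasource_provider_credentials_validate( | |||
|     tenant_id="tenant-uuid", | |||
|     provider="notion",  # placeholder provider name | |||
|     plugin_id="langgenius/notion",  # placeholder plugin id | |||
|     credentials={"api_key": "sk-..."}, | |||
| ) | |||
| # read back the stored credentials, then remove them | |||
| creds = service.get_datasource_credentials(tenant_id="tenant-uuid", provider="notion", plugin_id="langgenius/notion") | |||
| service.remove_datasource_credentials(tenant_id="tenant-uuid", provider="notion", plugin_id="langgenius/notion") | |||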
| @@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel): | |||
| """ | |||
| Knowledge Base Configuration. | |||
| """ | |||
| chunk_structure: str | |||
| indexing_technique: Literal["high_quality", "economy"] | |||
| embedding_model_provider: Optional[str] = "" | |||
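| A hedged construction example; only the three fields above are visible in this hunk, while embedding_model is assumed from its use later in this change: | |||
| config = KnowledgeConfiguration( | |||
|     chunk_structure="paragraph",  # placeholder structure name | |||
|     indexing_technique="high_quality", | |||
|     embedding_model_provider="openai",  # placeholder provider | |||
|     embedding_model="text-embedding-3-small",  # placeholder model; field assumed from usage below | |||
| ) | |||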
| @@ -3,7 +3,6 @@ from typing import Any, Union | |||
| from configs import dify_config | |||
| from core.app.apps.pipeline.pipeline_generator import PipelineGenerator | |||
| from core.app.apps.workflow.app_generator import WorkflowAppGenerator | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from models.dataset import Pipeline | |||
| from models.model import Account, App, EndUser | |||
| @@ -1,13 +1,12 @@ | |||
| from typing import Optional | |||
| from flask_login import current_user | |||
| import yaml | |||
| from flask_login import current_user | |||
| from extensions.ext_database import db | |||
| from models.dataset import PipelineCustomizedTemplate | |||
| from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase | |||
| from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType | |||
| from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService | |||
| class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| @@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| ) | |||
| recommended_pipelines_results = [] | |||
| for pipeline_customized_template in pipeline_customized_templates: | |||
| recommended_pipeline_result = { | |||
| "id": pipeline_customized_template.id, | |||
| "name": pipeline_customized_template.name, | |||
| @@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| return {"pipeline_templates": recommended_pipelines_results} | |||
| @classmethod | |||
| def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: | |||
| """ | |||
| @@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| recommended_pipelines_results = [] | |||
| for pipeline_built_in_template in pipeline_built_in_templates: | |||
| recommended_pipeline_result = { | |||
| "id": pipeline_built_in_template.id, | |||
| "name": pipeline_built_in_template.name, | |||
| @@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| from models.account import Account | |||
| from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore | |||
| from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore | |||
| from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom | |||
| from models.model import EndUser | |||
| from models.workflow import ( | |||
| @@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi | |||
| class RagPipelineService: | |||
| @classmethod | |||
| def get_pipeline_templates( | |||
| cls, type: str = "built-in", language: str = "en-US" | |||
| ) -> dict: | |||
| def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: | |||
| if type == "built-in": | |||
| mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE | |||
| retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() | |||
| @@ -308,7 +306,7 @@ class RagPipelineService: | |||
| session=session, | |||
| dataset=dataset, | |||
| knowledge_configuration=knowledge_configuration, | |||
| has_published=pipeline.is_published | |||
| has_published=pipeline.is_published, | |||
| ) | |||
| # return new workflow | |||
| return workflow | |||
| @@ -444,12 +442,10 @@ class RagPipelineService: | |||
| ) | |||
| if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: | |||
| datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) | |||
| online_document_result: GetOnlineDocumentPagesResponse = ( | |||
| datasource_runtime._get_online_document_pages( | |||
| user_id=account.id, | |||
| datasource_parameters=user_inputs, | |||
| provider_type=datasource_runtime.datasource_provider_type(), | |||
| ) | |||
| online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( | |||
| user_id=account.id, | |||
| datasource_parameters=user_inputs, | |||
| provider_type=datasource_runtime.datasource_provider_type(), | |||
| ) | |||
| return { | |||
| "result": [page.model_dump() for page in online_document_result.result], | |||
| @@ -470,7 +466,6 @@ class RagPipelineService: | |||
| else: | |||
| raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") | |||
| def run_free_workflow_node( | |||
| self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] | |||
| ) -> WorkflowNodeExecution: | |||
| @@ -689,8 +684,8 @@ class RagPipelineService: | |||
| WorkflowRun.app_id == pipeline.id, | |||
| or_( | |||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, | |||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value | |||
| ) | |||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, | |||
| ), | |||
| ) | |||
| if args.get("last_id"): | |||
| @@ -763,18 +758,17 @@ class RagPipelineService: | |||
| # Use the repository to get the node execution | |||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=db.engine, | |||
| app_id=pipeline.id, | |||
| user=user, | |||
| triggered_from=None | |||
| session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None | |||
| ) | |||
| # Use the repository to get the node executions with ordering | |||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | |||
| node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, | |||
| order_config=order_config, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) | |||
| # Convert domain models to database models | |||
| node_executions = repository.get_by_workflow_run( | |||
| workflow_run_id=run_id, | |||
| order_config=order_config, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, | |||
| ) | |||
| # Convert domain models to database models | |||
| workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] | |||
| return workflow_node_executions | |||
| @@ -279,7 +279,11 @@ class RagPipelineDslService: | |||
| if node.get("data", {}).get("type") == "knowledge_index": | |||
| knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) | |||
| knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) | |||
| if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure: | |||
| if ( | |||
| dataset | |||
| and pipeline.is_published | |||
| and dataset.chunk_structure != knowledge_configuration.chunk_structure | |||
| ): | |||
| raise ValueError("Chunk structure is not compatible with the published pipeline") | |||
| else: | |||
| dataset = Dataset( | |||
| @@ -304,8 +308,7 @@ class RagPipelineDslService: | |||
| .filter( | |||
| DatasetCollectionBinding.provider_name | |||
| == knowledge_configuration.embedding_model_provider, | |||
| DatasetCollectionBinding.model_name | |||
| == knowledge_configuration.embedding_model, | |||
| DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, | |||
| DatasetCollectionBinding.type == "dataset", | |||
| ) | |||
| .order_by(DatasetCollectionBinding.created_at) | |||
| @@ -323,12 +326,8 @@ class RagPipelineDslService: | |||
| db.session.commit() | |||
| dataset_collection_binding_id = dataset_collection_binding.id | |||
| dataset.collection_binding_id = dataset_collection_binding_id | |||
| dataset.embedding_model = ( | |||
| knowledge_configuration.embedding_model | |||
| ) | |||
| dataset.embedding_model_provider = ( | |||
| knowledge_configuration.embedding_model_provider | |||
| ) | |||
| dataset.embedding_model = knowledge_configuration.embedding_model | |||
| dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider | |||
| elif knowledge_configuration.indexing_technique == "economy": | |||
| dataset.keyword_number = knowledge_configuration.keyword_number | |||
| dataset.pipeline_id = pipeline.id | |||
| @@ -443,8 +442,7 @@ class RagPipelineDslService: | |||
| .filter( | |||
| DatasetCollectionBinding.provider_name | |||
| == knowledge_configuration.embedding_model_provider, | |||
| DatasetCollectionBinding.model_name | |||
| == knowledge_configuration.embedding_model, | |||
| DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, | |||
| DatasetCollectionBinding.type == "dataset", | |||
| ) | |||
| .order_by(DatasetCollectionBinding.created_at) | |||
| @@ -462,12 +460,8 @@ class RagPipelineDslService: | |||
| db.session.commit() | |||
| dataset_collection_binding_id = dataset_collection_binding.id | |||
| dataset.collection_binding_id = dataset_collection_binding_id | |||
| dataset.embedding_model = ( | |||
| knowledge_configuration.embedding_model | |||
| ) | |||
| dataset.embedding_model_provider = ( | |||
| knowledge_configuration.embedding_model_provider | |||
| ) | |||
| dataset.embedding_model = knowledge_configuration.embedding_model | |||
| dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider | |||
| elif knowledge_configuration.indexing_technique == "economy": | |||
| dataset.keyword_number = knowledge_configuration.keyword_number | |||
| dataset.pipeline_id = pipeline.id | |||
| @@ -538,7 +532,6 @@ class RagPipelineDslService: | |||
| icon_type = "emoji" | |||
| icon = str(pipeline_data.get("icon", "")) | |||
| # Initialize pipeline based on mode | |||
| workflow_data = data.get("workflow") | |||
| if not workflow_data or not isinstance(workflow_data, dict): | |||
| @@ -554,7 +547,6 @@ class RagPipelineDslService: | |||
| ] | |||
| rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) | |||
| graph = workflow_data.get("graph", {}) | |||
| for node in graph.get("nodes", []): | |||
| if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: | |||
| @@ -576,7 +568,6 @@ class RagPipelineDslService: | |||
| pipeline.description = pipeline_data.get("description", pipeline.description) | |||
| pipeline.updated_by = account.id | |||
| else: | |||
| if account.current_tenant_id is None: | |||
| raise ValueError("Current tenant is not set") | |||
| @@ -636,7 +627,6 @@ class RagPipelineDslService: | |||
| # commit db session changes | |||
| db.session.commit() | |||
| return pipeline | |||
| @classmethod | |||
| @@ -874,7 +864,6 @@ class RagPipelineDslService: | |||
| except Exception: | |||
| return None | |||
| @staticmethod | |||
| def create_rag_pipeline_dataset( | |||
| tenant_id: str, | |||
| @@ -886,9 +875,7 @@ class RagPipelineDslService: | |||
| .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) | |||
| .first() | |||
| ): | |||
| raise ValueError( | |||
| f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." | |||
| ) | |||
| raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") | |||
| with Session(db.engine) as session: | |||
| rag_pipeline_dsl_service = RagPipelineDslService(session) | |||
| @@ -12,12 +12,12 @@ class RagPipelineManageService: | |||
| # get all builtin providers | |||
| manager = PluginDatasourceManager() | |||
| datasources = manager.fetch_datasource_providers(tenant_id) | |||
| for datasource in datasources: | |||
| datasource_provider_service = DatasourceProviderService() | |||
| credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id, | |||
| provider=datasource.provider, | |||
| plugin_id=datasource.plugin_id) | |||
| credentials = datasource_provider_service.get_datasource_credentials( | |||
| tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id | |||
| ) | |||
| if credentials: | |||
| datasource.is_authorized = True | |||
| return datasources | |||