import os
import sys

# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent  # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent  # type: ignore
#
# psycogreen.gevent.patch_psycopg()
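# Illustrative sketch (not part of the diff above): what the gate looks like if the commented-out
# block is enabled. It assumes gevent, grpcio, and psycogreen are installed; the behaviour matches
# the comments above (skip patching when FLASK_DEBUG is truthy, so debuggers keep working).
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey

    # patch the standard library for cooperative scheduling
    monkey.patch_all()

    from grpc.experimental import gevent as grpc_gevent  # type: ignore

    # make grpc cooperate with gevent
    grpc_gevent.init_gevent()

    import psycogreen.gevent  # type: ignore

    # make psycopg2 cooperate with gevent
    psycogreen.gevent.patch_psycopg()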
from app_factory import create_app
        return {"result": "success"}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

from configs import dify_config
from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
    return account


api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
from flask import redirect, request
from flask_login import current_user  # type: ignore
from flask_restful import (  # type: ignore
    Resource,  # type: ignore
    marshal_with,
    reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
    setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
        if not current_user.is_editor:
            raise Forbidden()
        # get all plugin oauth configs
        plugin_oauth_config = (
            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
        )
        if not plugin_oauth_config:
            raise NotFound()
        oauth_handler = OAuthHandler()
        if system_credentials:
            system_credentials["redirect_url"] = redirect_url
        response = oauth_handler.get_authorization_url(
            current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
        )
        return response.model_dump()
class DatasourceOauthCallback(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, plugin_id):
        oauth_handler = OAuthHandler()
        plugin_oauth_config = (
            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
        )
        if not plugin_oauth_config:
            raise NotFound()
        credentials = oauth_handler.get_credentials(
            plugin_id,
            provider,
            system_credentials=plugin_oauth_config.system_credentials,
            request=request,
        )
        datasource_provider = DatasourceProvider(
            plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
        )
        db.session.add(datasource_provider)
        db.session.commit()
        return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class DatasourceAuth(Resource):
    @setup_required
    @login_required

        try:
            datasource_provider_service.datasource_provider_credentials_validate(
                tenant_id=current_user.current_tenant_id,
                provider=provider,
                plugin_id=plugin_id,
                credentials=args["credentials"],
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))
        return {"result": "success"}, 201

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, plugin_id):
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
        )
        return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
    @setup_required
    @login_required

            raise Forbidden()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_datasource_credentials(
            tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
        )
        return {"result": "success"}, 200


# Import Rag Pipeline
api.add_resource(
    DatasourcePluginOauthApi,
    DatasourceAuth,
    "/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
)
        dsl = yaml.safe_load(template.yaml_content)
        return {"data": dsl}, 200
class CustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required

        RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
        return 200


api.add_resource(
    PipelineTemplateListApi,
    "/rag/pipeline/templates",
    @login_required
    @account_initialization_required
    def get(self, pipeline_id):
        return {
            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
        }
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

    @overload
    def generate(
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
        self,
                description=dataset.description,
                chunk_structure=dataset.chunk_structure,
            ).model_dump(),
            "documents": [
                PipelineDocument(
                    id=document.id,
                    position=document.position,
                    data_source_type=document.data_source_type,
                    data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
                    name=document.name,
                    indexing_status=document.indexing_status,
                    error=document.error,
                    enabled=document.enabled,
                ).model_dump()
                for document in documents
            ],
        }

    def _generate(
        )
        # new thread
        worker_thread = threading.Thread(target=worker_with_context)
        worker_thread.start()
        if workflow.rag_pipeline_variables:
            for v in workflow.rag_pipeline_variables:
                rag_pipeline_variable = RAGPipelineVariable(**v)
                if (
                    rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
                    and rag_pipeline_variable.variable in inputs
                ):
                    rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
        variable_pool = VariablePool(
                continue
            real_run_nodes.append(node)
        for edge in edges:
            if edge.get("source") in exclude_node_ids:
                continue
            real_edges.append(edge)
        graph_config = dict(graph_config)
| """ | """ | ||||
| RAG Pipeline Application Generate Entity. | RAG Pipeline Application Generate Entity. | ||||
| """ | """ | ||||
| # pipeline config | # pipeline config | ||||
| pipeline_config: WorkflowUIBasedAppConfig | pipeline_config: WorkflowUIBasedAppConfig | ||||
| datasource_type: str | datasource_type: str |
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom


class DatasourceRuntime(BaseModel):
        if not provider_entity:
            raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")

        match datasource_type:
            case DatasourceProviderType.ONLINE_DOCUMENT:
                controller = OnlineDocumentDatasourcePluginProviderController(
                    entity=provider_entity.declaration,
            tenant_id,
            datasource_type,
        ).get_datasource(datasource_name)
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )
        for resp in response:
| "identity": { | "identity": { | ||||
| "author": "langgenius", | "author": "langgenius", | ||||
| "name": "langgenius/file/file", | "name": "langgenius/file/file", | ||||
| "label": { | |||||
| "zh_Hans": "File", | |||||
| "en_US": "File", | |||||
| "pt_BR": "File", | |||||
| "ja_JP": "File" | |||||
| }, | |||||
| "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||||
| "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", | "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", | ||||
| "description": { | |||||
| "zh_Hans": "File", | |||||
| "en_US": "File", | |||||
| "pt_BR": "File", | |||||
| "ja_JP": "File" | |||||
| } | |||||
| "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||||
| }, | }, | ||||
| "credentials_schema": [], | "credentials_schema": [], | ||||
| "provider_type": "local_file", | "provider_type": "local_file", | ||||
| "datasources": [{ | |||||
| "identity": { | |||||
| "author": "langgenius", | |||||
| "name": "upload-file", | |||||
| "provider": "langgenius", | |||||
| "label": { | |||||
| "zh_Hans": "File", | |||||
| "en_US": "File", | |||||
| "pt_BR": "File", | |||||
| "ja_JP": "File" | |||||
| } | |||||
| }, | |||||
| "parameters": [], | |||||
| "description": { | |||||
| "zh_Hans": "File", | |||||
| "en_US": "File", | |||||
| "pt_BR": "File", | |||||
| "ja_JP": "File" | |||||
| "datasources": [ | |||||
| { | |||||
| "identity": { | |||||
| "author": "langgenius", | |||||
| "name": "upload-file", | |||||
| "provider": "langgenius", | |||||
| "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||||
| }, | |||||
| "parameters": [], | |||||
| "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, | |||||
| } | } | ||||
| }] | |||||
| } | |||||
| ], | |||||
| }, | |||||
| } | } |
        with redis_client.lock(lock_name, timeout=600):
            keyword_table_handler = JiebaKeywordTableHandler()
            keyword_table = self._get_dataset_keyword_table()
            keyword_number = (
                self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
            )
            for text in texts:
                keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                if text.metadata is not None:
                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                keyword_table = self._add_text_to_keyword_table(

            keyword_table = self._get_dataset_keyword_table()
            keywords_list = kwargs.get("keywords_list")
            keyword_number = (
                self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
            )
            for i in range(len(texts)):
                text = texts[i]
                if keywords_list:
                    keywords = keywords_list[i]
                    if not keywords:
                        keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                else:
                    keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                if text.metadata is not None:
                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                keyword_table = self._add_text_to_keyword_table(
                    keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                )
            else:
                keyword_number = (
                    self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
                )
                keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                segment.keywords = list(keywords)
    @abstractmethod
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
        raise NotImplementedError

    @abstractmethod
    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
        raise NotImplementedError
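# Illustrative only (not part of the diff): a minimal subclass satisfying the abstract interface
# above, mirroring the concrete processors shown further below. The base class name is an
# assumption for this sketch.
class NoopIndexProcessor(BaseIndexProcessor):  # assumed base class name
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
        # no-op indexing for the sketch
        pass

    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
        # pass the raw chunks straight through as the preview payload
        return {"preview": chunks}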
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
        keyword = Keyword(dataset)
        keyword.add_texts(documents)

    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
        paragraph = GeneralStructureChunk(**chunks)
        preview = []
        for content in paragraph.general_chunks:
            preview.append({"content": content})
        return {"preview": preview, "total_segments": len(paragraph.general_chunks)}
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
        parent_childs = ParentChildStructureChunk(**chunks)
        preview = []
        for parent_child in parent_childs.parent_child_chunks:
            preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
        return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}
import re
import threading
import uuid
from collections.abc import Mapping
from typing import Any, Optional

import pandas as pd
from flask import Flask, current_app
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule
            doc = Document(page_content=result.page_content, metadata=metadata)
            docs.append(doc)
        return docs

    def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
        pass

    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
        return {"preview": chunks}
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
    pass


class RAGPipelineVariable(BaseModel):
    belong_to_node_id: str = Field(description="belong to which node id, shared means public")
    type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
    label: str = Field(description="label")
    description: str | None = Field(description="description", default="")
    variable: str = Field(description="variable key", default="")
    max_length: int | None = Field(
        description="max length, applicable to text-input, paragraph, and file-list", default=0
    )
    default_value: str | None = Field(description="default value", default="")
    placeholder: str | None = Field(description="placeholder", default="")
    unit: str | None = Field(description="unit, applicable to Number", default="")
    tooltips: str | None = Field(description="helpful text", default="")
    allowed_file_types: list[str] | None = Field(
        description="image, document, audio, video, custom.", default_factory=list
    )
    allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
    allowed_file_upload_methods: list[str] | None = Field(
        description="remote_url, local_file, tool_file.", default_factory=list
    )
    required: bool = Field(description="optional, default false", default=False)
    options: list[str] | None = Field(default_factory=list)
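# Illustrative only (not part of the diff): constructing a RAGPipelineVariable for a shared
# text input. The concrete values below are placeholders chosen for demonstration; only
# belong_to_node_id, type, and label are required, the rest have defaults.
example_variable = RAGPipelineVariable(
    belong_to_node_id="shared",  # "shared" means the variable is visible to all nodes
    type="text-input",
    label="Query",
    variable="query",
    max_length=256,
    required=True,
)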
    )
    rag_pipeline_variables: Mapping[str, Any] = Field(
        description="RAG pipeline variables.",
        default_factory=dict,
    )

    def __init__(
| AGENT_LOG = "agent_log" | AGENT_LOG = "agent_log" | ||||
| ITERATION_ID = "iteration_id" | ITERATION_ID = "iteration_id" | ||||
| ITERATION_INDEX = "iteration_index" | ITERATION_INDEX = "iteration_index" | ||||
| DATASOURCE_INFO = "datasource_info" | |||||
| LOOP_ID = "loop_id" | LOOP_ID = "loop_id" | ||||
| LOOP_INDEX = "loop_index" | LOOP_INDEX = "loop_index" | ||||
| PARALLEL_ID = "parallel_id" | PARALLEL_ID = "parallel_id" |
        root_node_configs = []
        all_node_id_config_mapping: dict[str, dict] = {}
        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id:
            (
                node_config.get("id")
                for node_config in root_node_configs
                if node_config.get("data", {}).get("type", "") == NodeType.START.value
                or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
            ),
            None,
                raise e
            # It may not be necessary, but it is necessary. :)
            if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
                NodeType.END.value,
                NodeType.KNOWLEDGE_INDEX.value,
            ]:
                break
            previous_route_node_state = route_node_state
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
        try:
            from core.datasource.datasource_manager import DatasourceManager

            if datasource_type is None:
                raise DatasourceNodeError("Datasource type is not set")
            )
        except DatasourceNodeError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs={},
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                error=f"Failed to get datasource runtime: {str(e)}",
                error_type=type(e).__name__,
            )
        # get parameters
        datasource_parameters = datasource_runtime.entity.parameters
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        inputs=parameters_for_log,
                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                        outputs={
                            "online_document": online_document_result.result.model_dump(),
                            "datasource_type": datasource_type,
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        inputs=parameters_for_log,
                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                        outputs={
                            "website": datasource_info,
                            "datasource_type": datasource_type,
                        },
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File is not exist")
                    upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")
                    # construct new key list
                    new_key_list = ["file", key]
                    self._append_variables_recursively(
                        variable_pool=variable_pool,
                        node_id=self.node_id,
                        variable_key_list=new_key_list,
                        variable_value=value,
                    )
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        inputs=parameters_for_log,
                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                        outputs={
                            "file_info": datasource_info,
                            "datasource_type": datasource_type,
                        },
                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
        except PluginDaemonClientSideError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                error=f"Failed to transform datasource message: {str(e)}",
                error_type=type(e).__name__,
            )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                error=f"Failed to invoke datasource: {str(e)}",
                error_type=type(e).__name__,
            )
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    def _append_variables_recursively(
        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
    ):
        """
        Append variables recursively
        :param node_id: node id
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
            process_data=None,
            outputs=outputs,
        )
        results = self._invoke_knowledge_index(
            dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
        )
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
        )
        )

    def _invoke_knowledge_index(
        self,
        dataset: Dataset,
        node_data: KnowledgeIndexNodeData,
        chunks: Mapping[str, Any],
        variable_pool: VariablePool,
    ) -> Any:
        document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        if not document_id:
        document.indexing_status = "completed"
        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        db.session.add(document)
        # update document segment status
        db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == document.id,
            DocumentSegment.dataset_id == dataset.id,
| "external_knowledge_api_name": external_knowledge_api.name, | "external_knowledge_api_name": external_knowledge_api.name, | ||||
| "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | ||||
| } | } | ||||
| @property | @property | ||||
| def is_published(self): | def is_published(self): | ||||
| if self.pipeline_id: | if self.pipeline_id: | ||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] | class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] | ||||
| __tablename__ = "pipeline_customized_templates" | __tablename__ = "pipeline_customized_templates" | ||||
| __table_args__ = ( | __table_args__ = ( |
from datetime import datetime

from sqlalchemy.dialects.postgresql import JSONB
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)


class DatasourceProvider(Base):
    __tablename__ = "datasource_providers"
    __table_args__ = (
from calendar import day_abbr
import copy
import datetime
import json
import time
import uuid
from collections import Counter
from typing import Any, Optional

from flask_login import current_user
from sqlalchemy import func, select
        db.session.commit()
        return dataset

    @staticmethod
    def get_dataset(dataset_id) -> Optional[Dataset]:
        dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
        return dataset

    @staticmethod
    def update_rag_pipeline_dataset_settings(
        session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
    ):
        dataset = session.merge(dataset)
        if not has_published:
            dataset.chunk_structure = knowledge_configuration.chunk_structure
        if action:
            deal_dataset_index_update_task.delay(dataset.id, action)

    @staticmethod
    def delete_dataset(dataset_id, user):
        dataset = DatasetService.get_dataset(dataset_id)
import logging
from typing import Optional

from flask_login import current_user

    def __init__(self) -> None:
        self.provider_manager = PluginDatasourceManager()

    def datasource_provider_credentials_validate(
        self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
    ) -> None:
        """
        validate datasource provider credentials.
        :param provider:
        :param credentials:
        """
        credential_valid = self.provider_manager.validate_provider_credentials(
            tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
        )
        if credential_valid:
            # Get all provider configurations of the current workspace
            datasource_provider = (
                db.session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                .first()
            )
            provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
            if not datasource_provider:
                for key, value in credentials.items():
                    if key in provider_credential_secret_variables:
                        # if send [__HIDDEN__] in secret input, it will be same as original value
                        credentials[key] = encrypter.encrypt_token(tenant_id, value)
                datasource_provider = DatasourceProvider(
                    tenant_id=tenant_id,
                    provider=provider,
                    plugin_id=plugin_id,
                    auth_type="api_key",
                    encrypted_credentials=credentials,
                )
                db.session.add(datasource_provider)
                db.session.commit()
            else:
        :return:
        """
        # Get all provider configurations of the current workspace
        datasource_providers: list[DatasourceProvider] = (
            db.session.query(DatasourceProvider)
            .filter(
                DatasourceProvider.tenant_id == tenant_id,
                DatasourceProvider.provider == provider,
                DatasourceProvider.plugin_id == plugin_id,
            )
            .all()
        )
        if not datasource_providers:
            return []
        copy_credentials_list = []
        return copy_credentials_list

    def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
        """
        remove datasource credentials.
        :param plugin_id: plugin id
        :return:
        """
        datasource_provider = (
            db.session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
            .first()
        )
        if datasource_provider:
            db.session.delete(datasource_provider)
            db.session.commit()
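# Illustrative usage sketch (not part of the diff): calling the DatasourceProviderService methods
# shown above. The tenant, provider, and plugin identifiers are placeholders, and the credential
# payload is an assumption for demonstration only.
service = DatasourceProviderService()
service.datasource_provider_credentials_validate(
    tenant_id="tenant-id",
    provider="example-provider",
    plugin_id="example/plugin",
    credentials={"api_key": "..."},
)
credentials_list = service.get_datasource_credentials(
    tenant_id="tenant-id", provider="example-provider", plugin_id="example/plugin"
)
service.remove_datasource_credentials(
    tenant_id="tenant-id", provider="example-provider", plugin_id="example/plugin"
)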
| """ | """ | ||||
| Knowledge Base Configuration. | Knowledge Base Configuration. | ||||
| """ | """ | ||||
| chunk_structure: str | chunk_structure: str | ||||
| indexing_technique: Literal["high_quality", "economy"] | indexing_technique: Literal["high_quality", "economy"] | ||||
| embedding_model_provider: Optional[str] = "" | embedding_model_provider: Optional[str] = "" |
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser
| from typing import Optional | from typing import Optional | ||||
| from flask_login import current_user | |||||
| import yaml | import yaml | ||||
| from flask_login import current_user | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import PipelineCustomizedTemplate | from models.dataset import PipelineCustomizedTemplate | ||||
| from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase | from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase | ||||
| from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType | from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType | ||||
| from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService | |||||
| class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | ||||
| ) | ) | ||||
| recommended_pipelines_results = [] | recommended_pipelines_results = [] | ||||
| for pipeline_customized_template in pipeline_customized_templates: | for pipeline_customized_template in pipeline_customized_templates: | ||||
| recommended_pipeline_result = { | recommended_pipeline_result = { | ||||
| "id": pipeline_customized_template.id, | "id": pipeline_customized_template.id, | ||||
| "name": pipeline_customized_template.name, | "name": pipeline_customized_template.name, | ||||
| return {"pipeline_templates": recommended_pipelines_results} | return {"pipeline_templates": recommended_pipelines_results} | ||||
| @classmethod | @classmethod | ||||
| def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: | def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: | ||||
| """ | """ |
| recommended_pipelines_results = [] | recommended_pipelines_results = [] | ||||
| for pipeline_built_in_template in pipeline_built_in_templates: | for pipeline_built_in_template in pipeline_built_in_templates: | ||||
| recommended_pipeline_result = { | recommended_pipeline_result = { | ||||
| "id": pipeline_built_in_template.id, | "id": pipeline_built_in_template.id, | ||||
| "name": pipeline_built_in_template.name, | "name": pipeline_built_in_template.name, |
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | from libs.infinite_scroll_pagination import InfiniteScrollPagination | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore | |||||
| from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore | |||||
| from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom | from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom | ||||
| from models.model import EndUser | from models.model import EndUser | ||||
| from models.workflow import ( | from models.workflow import ( | ||||
| class RagPipelineService: | class RagPipelineService: | ||||
| @classmethod | @classmethod | ||||
| def get_pipeline_templates( | |||||
| cls, type: str = "built-in", language: str = "en-US" | |||||
| ) -> dict: | |||||
| def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: | |||||
| if type == "built-in": | if type == "built-in": | ||||
| mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE | mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE | ||||
| retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() | retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() | ||||
| session=session, | session=session, | ||||
| dataset=dataset, | dataset=dataset, | ||||
| knowledge_configuration=knowledge_configuration, | knowledge_configuration=knowledge_configuration, | ||||
| has_published=pipeline.is_published | |||||
| has_published=pipeline.is_published, | |||||
| ) | ) | ||||
| # return new workflow | # return new workflow | ||||
| return workflow | return workflow | ||||
| ) | ) | ||||
| if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: | if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: | ||||
| datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) | datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) | ||||
| online_document_result: GetOnlineDocumentPagesResponse = ( | |||||
| datasource_runtime._get_online_document_pages( | |||||
| user_id=account.id, | |||||
| datasource_parameters=user_inputs, | |||||
| provider_type=datasource_runtime.datasource_provider_type(), | |||||
| ) | |||||
| online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( | |||||
| user_id=account.id, | |||||
| datasource_parameters=user_inputs, | |||||
| provider_type=datasource_runtime.datasource_provider_type(), | |||||
| ) | ) | ||||
| return { | return { | ||||
| "result": [page.model_dump() for page in online_document_result.result], | "result": [page.model_dump() for page in online_document_result.result], | ||||
| else: | else: | ||||
| raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") | raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") | ||||
| def run_free_workflow_node( | def run_free_workflow_node( | ||||
| self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] | self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] | ||||
| ) -> WorkflowNodeExecution: | ) -> WorkflowNodeExecution: | ||||
| WorkflowRun.app_id == pipeline.id, | WorkflowRun.app_id == pipeline.id, | ||||
| or_( | or_( | ||||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, | WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, | ||||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value | |||||
| ) | |||||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, | |||||
| ), | |||||
| ) | ) | ||||
| if args.get("last_id"): | if args.get("last_id"): | ||||
| # Use the repository to get the node execution | # Use the repository to get the node execution | ||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=db.engine, | |||||
| app_id=pipeline.id, | |||||
| user=user, | |||||
| triggered_from=None | |||||
| session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None | |||||
| ) | ) | ||||
| # Use the repository to get the node executions with ordering | # Use the repository to get the node executions with ordering | ||||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | order_config = OrderConfig(order_by=["index"], order_direction="desc") | ||||
| node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, | |||||
| order_config=order_config, | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) | |||||
| # Convert domain models to database models | |||||
| node_executions = repository.get_by_workflow_run( | |||||
| workflow_run_id=run_id, | |||||
| order_config=order_config, | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, | |||||
| ) | |||||
| # Convert domain models to database models | |||||
| workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] | workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] | ||||
| return workflow_node_executions | return workflow_node_executions |
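The datasource handling earlier in this file is a type switch: inspect datasource_provider_type(), cast to the concrete plugin class, then call its page-listing method. A condensed sketch of the online-document branch using only the calls visible in this diff; the plugin and response types are imported elsewhere in the file, so their module paths are assumed to already be in scope:

from typing import cast

def list_online_document_pages(datasource_runtime, account, user_inputs: dict) -> dict:
    # Only the online-document branch appears in this diff; other provider types raise.
    if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
        datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
        online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
            user_id=account.id,
            datasource_parameters=user_inputs,
            provider_type=datasource_runtime.datasource_provider_type(),
        )
        return {"result": [page.model_dump() for page in online_document_result.result]}
    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")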
| if node.get("data", {}).get("type") == "knowledge_index": | if node.get("data", {}).get("type") == "knowledge_index": | ||||
| knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) | knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) | ||||
| knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) | knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) | ||||
| if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure: | |||||
| if ( | |||||
| dataset | |||||
| and pipeline.is_published | |||||
| and dataset.chunk_structure != knowledge_configuration.chunk_structure | |||||
| ): | |||||
| raise ValueError("Chunk structure is not compatible with the published pipeline") | raise ValueError("Chunk structure is not compatible with the published pipeline") | ||||
| else: | else: | ||||
| dataset = Dataset( | dataset = Dataset( | ||||
| .filter( | .filter( | ||||
| DatasetCollectionBinding.provider_name | DatasetCollectionBinding.provider_name | ||||
| == knowledge_configuration.embedding_model_provider, | == knowledge_configuration.embedding_model_provider, | ||||
| DatasetCollectionBinding.model_name | |||||
| == knowledge_configuration.embedding_model, | |||||
| DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, | |||||
| DatasetCollectionBinding.type == "dataset", | DatasetCollectionBinding.type == "dataset", | ||||
| ) | ) | ||||
| .order_by(DatasetCollectionBinding.created_at) | .order_by(DatasetCollectionBinding.created_at) | ||||
| db.session.commit() | db.session.commit() | ||||
| dataset_collection_binding_id = dataset_collection_binding.id | dataset_collection_binding_id = dataset_collection_binding.id | ||||
| dataset.collection_binding_id = dataset_collection_binding_id | dataset.collection_binding_id = dataset_collection_binding_id | ||||
| dataset.embedding_model = ( | |||||
| knowledge_configuration.embedding_model | |||||
| ) | |||||
| dataset.embedding_model_provider = ( | |||||
| knowledge_configuration.embedding_model_provider | |||||
| ) | |||||
| dataset.embedding_model = knowledge_configuration.embedding_model | |||||
| dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider | |||||
| elif knowledge_configuration.indexing_technique == "economy": | elif knowledge_configuration.indexing_technique == "economy": | ||||
| dataset.keyword_number = knowledge_configuration.keyword_number | dataset.keyword_number = knowledge_configuration.keyword_number | ||||
| dataset.pipeline_id = pipeline.id | dataset.pipeline_id = pipeline.id | ||||
| .filter( | .filter( | ||||
| DatasetCollectionBinding.provider_name | DatasetCollectionBinding.provider_name | ||||
| == knowledge_configuration.embedding_model_provider, | == knowledge_configuration.embedding_model_provider, | ||||
| DatasetCollectionBinding.model_name | |||||
| == knowledge_configuration.embedding_model, | |||||
| DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, | |||||
| DatasetCollectionBinding.type == "dataset", | DatasetCollectionBinding.type == "dataset", | ||||
| ) | ) | ||||
| .order_by(DatasetCollectionBinding.created_at) | .order_by(DatasetCollectionBinding.created_at) | ||||
| db.session.commit() | db.session.commit() | ||||
| dataset_collection_binding_id = dataset_collection_binding.id | dataset_collection_binding_id = dataset_collection_binding.id | ||||
| dataset.collection_binding_id = dataset_collection_binding_id | dataset.collection_binding_id = dataset_collection_binding_id | ||||
| dataset.embedding_model = ( | |||||
| knowledge_configuration.embedding_model | |||||
| ) | |||||
| dataset.embedding_model_provider = ( | |||||
| knowledge_configuration.embedding_model_provider | |||||
| ) | |||||
| dataset.embedding_model = knowledge_configuration.embedding_model | |||||
| dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider | |||||
| elif knowledge_configuration.indexing_technique == "economy": | elif knowledge_configuration.indexing_technique == "economy": | ||||
| dataset.keyword_number = knowledge_configuration.keyword_number | dataset.keyword_number = knowledge_configuration.keyword_number | ||||
| dataset.pipeline_id = pipeline.id | dataset.pipeline_id = pipeline.id | ||||
| icon_type = "emoji" | icon_type = "emoji" | ||||
| icon = str(pipeline_data.get("icon", "")) | icon = str(pipeline_data.get("icon", "")) | ||||
| # Initialize pipeline based on mode | # Initialize pipeline based on mode | ||||
| workflow_data = data.get("workflow") | workflow_data = data.get("workflow") | ||||
| if not workflow_data or not isinstance(workflow_data, dict): | if not workflow_data or not isinstance(workflow_data, dict): | ||||
| ] | ] | ||||
| rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) | rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) | ||||
| graph = workflow_data.get("graph", {}) | graph = workflow_data.get("graph", {}) | ||||
| for node in graph.get("nodes", []): | for node in graph.get("nodes", []): | ||||
| if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: | if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: | ||||
| pipeline.description = pipeline_data.get("description", pipeline.description) | pipeline.description = pipeline_data.get("description", pipeline.description) | ||||
| pipeline.updated_by = account.id | pipeline.updated_by = account.id | ||||
| else: | else: | ||||
| if account.current_tenant_id is None: | if account.current_tenant_id is None: | ||||
| raise ValueError("Current tenant is not set") | raise ValueError("Current tenant is not set") | ||||
| # commit db session changes | # commit db session changes | ||||
| db.session.commit() | db.session.commit() | ||||
| return pipeline | return pipeline | ||||
| @classmethod | @classmethod | ||||
| except Exception: | except Exception: | ||||
| return None | return None | ||||
| @staticmethod | @staticmethod | ||||
| def create_rag_pipeline_dataset( | def create_rag_pipeline_dataset( | ||||
| tenant_id: str, | tenant_id: str, | ||||
| .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) | .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) | ||||
| .first() | .first() | ||||
| ): | ): | ||||
| raise ValueError( | |||||
| f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." | |||||
| ) | |||||
| raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") | |||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| rag_pipeline_dsl_service = RagPipelineDslService(session) | rag_pipeline_dsl_service = RagPipelineDslService(session) |
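The two near-identical knowledge_index hunks above apply the configuration to a dataset: high_quality indexing binds the dataset to a collection binding for the configured embedding model, while economy indexing only records a keyword count. A condensed sketch of that logic, assuming a get_or_create_collection_binding helper that wraps the DatasetCollectionBinding query shown in the diff (the helper name is illustrative):

def apply_knowledge_configuration(dataset, knowledge_configuration, pipeline) -> None:
    if knowledge_configuration.indexing_technique == "high_quality":
        # Reuse (or create) the collection binding for this provider/model pair.
        binding = get_or_create_collection_binding(  # hypothetical helper, not in this diff
            provider_name=knowledge_configuration.embedding_model_provider,
            model_name=knowledge_configuration.embedding_model,
        )
        dataset.collection_binding_id = binding.id
        dataset.embedding_model = knowledge_configuration.embedding_model
        dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
    elif knowledge_configuration.indexing_technique == "economy":
        dataset.keyword_number = knowledge_configuration.keyword_number
    dataset.pipeline_id = pipeline.id
    db.session.commit()

Factoring the binding lookup into a helper like this would also remove the duplication between the create and update paths shown above.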
| # get all builtin providers | # get all builtin providers | ||||
| manager = PluginDatasourceManager() | manager = PluginDatasourceManager() | ||||
| datasources = manager.fetch_datasource_providers(tenant_id) | |||||
| datasources = manager.fetch_datasource_providers(tenant_id) | |||||
| for datasource in datasources: | for datasource in datasources: | ||||
| datasource_provider_service = DatasourceProviderService() | datasource_provider_service = DatasourceProviderService() | ||||
| credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id, | |||||
| provider=datasource.provider, | |||||
| plugin_id=datasource.plugin_id) | |||||
| credentials = datasource_provider_service.get_datasource_credentials( | |||||
| tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id | |||||
| ) | |||||
| if credentials: | if credentials: | ||||
| datasource.is_authorized = True | datasource.is_authorized = True | ||||
| return datasources | return datasources |
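Finally, the provider-listing hunk marks a datasource as authorized when stored credentials exist for it. A compact sketch of the same flow; the function name is illustrative, and constructing DatasourceProviderService once outside the loop is a minor cleanup over the hunk above rather than something this change does:

def list_datasources_with_auth_state(tenant_id: str):
    # Fetch all datasource providers for the workspace and flag the ones that already
    # have stored credentials, so the UI can skip the authorization step for them.
    manager = PluginDatasourceManager()
    datasource_provider_service = DatasourceProviderService()
    datasources = manager.fetch_datasource_providers(tenant_id)
    for datasource in datasources:
        credentials = datasource_provider_service.get_datasource_credentials(
            tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
        )
        if credentials:
            datasource.is_authorized = True
    return datasources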