| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
-
- import logging
- from threading import Lock
- from typing import Union
-
- import contexts
- from core.datasource.__base.datasource_plugin import DatasourcePlugin
- from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
- from core.datasource.entities.common_entities import I18nObject
- from core.datasource.entities.datasource_entities import DatasourceProviderType
- from core.datasource.errors import ToolProviderNotFoundError
- from core.plugin.manager.tool import PluginToolManager
-
# Module-level logger named after this module (not referenced by the code in this section).
logger = logging.getLogger(name=__name__)
-
-
class DatasourceManager:
    """
    Resolve and enumerate datasource plugin providers for a tenant.

    Lookups through ``get_datasource_plugin_provider`` are memoised in the
    context variable ``contexts.datasource_plugin_providers`` (guarded by
    ``contexts.datasource_plugin_providers_lock``), so asking for the same
    provider name again within one context reuses the controller built the
    first time. ``list_datasource_providers`` does not use that cache.
    """

    # NOTE(review): none of the methods below read or write these four
    # attributes; they may be vestigial — confirm no external code uses them
    # before removing.
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    @staticmethod
    def _controller_from_entity(entity, tenant_id: str) -> DatasourcePluginProviderController:
        """
        Wrap a provider entity returned by ``PluginToolManager`` in a controller.

        :param entity: provider entity exposing ``declaration``, ``plugin_id``
            and ``plugin_unique_identifier``
        :param tenant_id: the tenant the controller is bound to
        :return: a new provider controller
        """
        return DatasourcePluginProviderController(
            entity=entity.declaration,
            plugin_id=entity.plugin_id,
            plugin_unique_identifier=entity.plugin_unique_identifier,
            tenant_id=tenant_id,
        )

    @classmethod
    def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
        """
        Return the controller for ``provider``, building and caching it on first use.

        :param provider: name of the datasource plugin provider
        :param tenant_id: the tenant the provider is fetched for
        :return: the provider controller (cached for the current context)
        :raises ToolProviderNotFoundError: if the plugin manager returns no provider entity
        """
        # First call in this context: create the cache dict and the lock that guards it.
        try:
            contexts.datasource_plugin_providers.get()
        except LookupError:
            contexts.datasource_plugin_providers.set({})
            contexts.datasource_plugin_providers_lock.set(Lock())

        with contexts.datasource_plugin_providers_lock.get():
            cache = contexts.datasource_plugin_providers.get()

            # NOTE(review): the cache key is the provider name only, while the
            # controller captures tenant_id — a different tenant in the same
            # context would get the first tenant's controller. Confirm a
            # context never serves more than one tenant.
            try:
                return cache[provider]
            except KeyError:
                pass

            # The fetch runs while the lock is held, so concurrent misses in
            # the same context are serialised rather than fetched twice.
            entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
            if not entity:
                raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

            controller = cls._controller_from_entity(entity, tenant_id)
            cache[provider] = controller
            return controller

    @classmethod
    def get_datasource_runtime(
        cls,
        provider_type: DatasourceProviderType,
        provider_id: str,
        datasource_name: str,
        tenant_id: str,
    ) -> DatasourcePlugin:
        """
        Look up a datasource plugin by provider and name.

        :param provider_type: kind of provider; only ``RAG_PIPELINE`` is supported
        :param provider_id: the id of the provider
        :param datasource_name: the name of the datasource within that provider
        :param tenant_id: the tenant id

        :return: the datasource plugin
        :raises ToolProviderNotFoundError: for any provider type other than
            ``RAG_PIPELINE``, or if the provider itself cannot be found
        """
        if provider_type != DatasourceProviderType.RAG_PIPELINE:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")

        provider_controller = cls.get_datasource_plugin_provider(provider_id, tenant_id)
        return provider_controller.get_datasource(datasource_name)

    @classmethod
    def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
        """
        Build a fresh controller for every provider visible to ``tenant_id``.

        Unlike ``get_datasource_plugin_provider``, this neither reads nor
        populates the per-context cache.

        :param tenant_id: the tenant whose providers are listed
        :return: one controller per provider entity, in the order returned
        """
        # NOTE(review): providers come from PluginToolManager's tool-provider
        # listing; confirm it returns datasource declarations here.
        entities = PluginToolManager().fetch_tool_providers(tenant_id)
        return [cls._controller_from_entity(entity, tenant_id) for entity in entities]
|