| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- import logging
- from threading import Lock
- from typing import Union
-
- import contexts
- from core.datasource.__base.datasource_plugin import DatasourcePlugin
- from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
- from core.datasource.entities.common_entities import I18nObject
- from core.datasource.entities.datasource_entities import DatasourceProviderType
- from core.datasource.errors import DatasourceProviderNotFoundError
- from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
- from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
- from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
- from core.plugin.impl.datasource import PluginDatasourceManager
-
- logger = logging.getLogger(__name__)
-
-
class DatasourceManager:
    """
    Entry point for resolving datasource plugin providers and datasource runtimes.

    Provider controllers are cached in the ``contexts.datasource_plugin_providers``
    context variable, keyed by provider name only.
    """

    # NOTE(review): none of the four attributes below is read or written by the
    # methods of this class body; confirm external usage before removing them.
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    @classmethod
    def get_datasource_plugin_provider(
        cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
    ) -> DatasourcePluginProviderController:
        """
        get the controller of a datasource plugin provider, building and caching it on first use

        :param provider: the provider name, also used as the cache key
        :param tenant_id: the tenant id
        :param datasource_type: the provider type, selects which controller class is built

        :return: the cached or newly built provider controller
        :raises DatasourceProviderNotFoundError: if the plugin manager returns no entity for the provider
        :raises ValueError: if the datasource type is not one of the supported types
        """
        # NOTE(review): the cache key does not include datasource_type, so a cached
        # controller is returned as-is even when a different type is requested.

        # make sure the per-context cache and its lock exist before using them
        try:
            contexts.datasource_plugin_providers.get()
        except LookupError:
            contexts.datasource_plugin_providers.set({})
            contexts.datasource_plugin_providers_lock.set(Lock())

        with contexts.datasource_plugin_providers_lock.get():
            cached_controllers = contexts.datasource_plugin_providers.get()
            if provider in cached_controllers:
                return cached_controllers[provider]

            entity = PluginDatasourceManager().fetch_datasource_provider(tenant_id, provider)
            if not entity:
                raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")

            # every supported type takes the same constructor arguments, only the class differs
            controller: DatasourcePluginProviderController
            if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
                controller = OnlineDocumentDatasourcePluginProviderController(
                    entity=entity.declaration,
                    plugin_id=entity.plugin_id,
                    plugin_unique_identifier=entity.plugin_unique_identifier,
                    tenant_id=tenant_id,
                )
            elif datasource_type == DatasourceProviderType.WEBSITE_CRAWL:
                controller = WebsiteCrawlDatasourcePluginProviderController(
                    entity=entity.declaration,
                    plugin_id=entity.plugin_id,
                    plugin_unique_identifier=entity.plugin_unique_identifier,
                    tenant_id=tenant_id,
                )
            elif datasource_type == DatasourceProviderType.LOCAL_FILE:
                controller = LocalFileDatasourcePluginProviderController(
                    entity=entity.declaration,
                    plugin_id=entity.plugin_id,
                    plugin_unique_identifier=entity.plugin_unique_identifier,
                    tenant_id=tenant_id,
                )
            else:
                raise ValueError(f"Unsupported datasource type: {datasource_type}")

            # remember the controller so later lookups in this context skip the fetch
            cached_controllers[provider] = controller

        return controller

    @classmethod
    def get_datasource_runtime(
        cls,
        provider_id: str,
        datasource_name: str,
        tenant_id: str,
        datasource_type: DatasourceProviderType,
    ) -> DatasourcePlugin:
        """
        get the datasource runtime

        :param provider_id: the id of the provider
        :param datasource_name: the name of the datasource
        :param tenant_id: the tenant id
        :param datasource_type: the type of the provider

        :return: the datasource plugin
        """
        provider_controller = cls.get_datasource_plugin_provider(
            provider_id,
            tenant_id,
            datasource_type,
        )
        return provider_controller.get_datasource(datasource_name)
|