You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_manager.py 3.5KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import logging
  2. from threading import Lock
  3. from typing import Union
  4. import contexts
  5. from core.datasource.__base.datasource_plugin import DatasourcePlugin
  6. from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
  7. from core.datasource.entities.common_entities import I18nObject
  8. from core.datasource.entities.datasource_entities import DatasourceProviderType
  9. from core.datasource.errors import ToolProviderNotFoundError
  10. from core.plugin.manager.tool import PluginToolManager
# Module-level logger.
# NOTE(review): nothing in this file references `logger`; remove it if it stays unused.
logger = logging.getLogger(__name__)
  12. class DatasourceManager:
  13. _builtin_provider_lock = Lock()
  14. _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
  15. _builtin_providers_loaded = False
  16. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  17. @classmethod
  18. def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
  19. """
  20. get the datasource plugin provider
  21. """
  22. # check if context is set
  23. try:
  24. contexts.datasource_plugin_providers.get()
  25. except LookupError:
  26. contexts.datasource_plugin_providers.set({})
  27. contexts.datasource_plugin_providers_lock.set(Lock())
  28. with contexts.datasource_plugin_providers_lock.get():
  29. datasource_plugin_providers = contexts.datasource_plugin_providers.get()
  30. if provider in datasource_plugin_providers:
  31. return datasource_plugin_providers[provider]
  32. manager = PluginToolManager()
  33. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  34. if not provider_entity:
  35. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  36. controller = DatasourcePluginProviderController(
  37. entity=provider_entity.declaration,
  38. plugin_id=provider_entity.plugin_id,
  39. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  40. tenant_id=tenant_id,
  41. )
  42. datasource_plugin_providers[provider] = controller
  43. return controller
  44. @classmethod
  45. def get_datasource_runtime(
  46. cls,
  47. provider_type: DatasourceProviderType,
  48. provider_id: str,
  49. datasource_name: str,
  50. tenant_id: str,
  51. ) -> DatasourcePlugin:
  52. """
  53. get the datasource runtime
  54. :param provider_type: the type of the provider
  55. :param provider_id: the id of the provider
  56. :param datasource_name: the name of the datasource
  57. :param tenant_id: the tenant id
  58. :return: the datasource plugin
  59. """
  60. if provider_type == DatasourceProviderType.RAG_PIPELINE:
  61. return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
  62. else:
  63. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  64. @classmethod
  65. def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
  66. """
  67. list all the datasource providers
  68. """
  69. manager = PluginToolManager()
  70. provider_entities = manager.fetch_tool_providers(tenant_id)
  71. return [
  72. DatasourcePluginProviderController(
  73. entity=provider.declaration,
  74. plugin_id=provider.plugin_id,
  75. plugin_unique_identifier=provider.plugin_unique_identifier,
  76. tenant_id=tenant_id,
  77. )
  78. for provider in provider_entities
  79. ]