Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

datasource_manager.py 4.8KB

Lines 1–108
  1. import logging
  2. from threading import Lock
  3. from typing import Union
  4. import contexts
  5. from core.datasource.__base.datasource_plugin import DatasourcePlugin
  6. from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
  7. from core.datasource.entities.common_entities import I18nObject
  8. from core.datasource.entities.datasource_entities import DatasourceProviderType
  9. from core.datasource.errors import DatasourceProviderNotFoundError
  10. from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
  11. from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
  12. from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
  13. from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
  14. from core.plugin.impl.datasource import PluginDatasourceManager
  # Module-level logger named after this module.
  # NOTE(review): no code visible in this file logs through it.
  logger: logging.Logger = logging.getLogger(name=__name__)
  16. class DatasourceManager:
  17. _builtin_provider_lock = Lock()
  18. _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
  19. _builtin_providers_loaded = False
  20. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  21. @classmethod
  22. def get_datasource_plugin_provider(
  23. cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
  24. ) -> DatasourcePluginProviderController:
  25. """
  26. get the datasource plugin provider
  27. """
  28. # check if context is set
  29. try:
  30. contexts.datasource_plugin_providers.get()
  31. except LookupError:
  32. contexts.datasource_plugin_providers.set({})
  33. contexts.datasource_plugin_providers_lock.set(Lock())
  34. with contexts.datasource_plugin_providers_lock.get():
  35. datasource_plugin_providers = contexts.datasource_plugin_providers.get()
  36. if provider_id in datasource_plugin_providers:
  37. return datasource_plugin_providers[provider_id]
  38. manager = PluginDatasourceManager()
  39. provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
  40. if not provider_entity:
  41. raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
  42. match datasource_type:
  43. case DatasourceProviderType.ONLINE_DOCUMENT:
  44. controller = OnlineDocumentDatasourcePluginProviderController(
  45. entity=provider_entity.declaration,
  46. plugin_id=provider_entity.plugin_id,
  47. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  48. tenant_id=tenant_id,
  49. )
  50. case DatasourceProviderType.ONLINE_DRIVE:
  51. controller = OnlineDriveDatasourcePluginProviderController(
  52. entity=provider_entity.declaration,
  53. plugin_id=provider_entity.plugin_id,
  54. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  55. tenant_id=tenant_id,
  56. )
  57. case DatasourceProviderType.WEBSITE_CRAWL:
  58. controller = WebsiteCrawlDatasourcePluginProviderController(
  59. entity=provider_entity.declaration,
  60. plugin_id=provider_entity.plugin_id,
  61. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  62. tenant_id=tenant_id,
  63. )
  64. case DatasourceProviderType.LOCAL_FILE:
  65. controller = LocalFileDatasourcePluginProviderController(
  66. entity=provider_entity.declaration,
  67. plugin_id=provider_entity.plugin_id,
  68. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  69. tenant_id=tenant_id,
  70. )
  71. case _:
  72. raise ValueError(f"Unsupported datasource type: {datasource_type}")
  73. datasource_plugin_providers[provider_id] = controller
  74. return controller
  75. @classmethod
  76. def get_datasource_runtime(
  77. cls,
  78. provider_id: str,
  79. datasource_name: str,
  80. tenant_id: str,
  81. datasource_type: DatasourceProviderType,
  82. ) -> DatasourcePlugin:
  83. """
  84. get the datasource runtime
  85. :param provider_type: the type of the provider
  86. :param provider_id: the id of the provider
  87. :param datasource_name: the name of the datasource
  88. :param tenant_id: the tenant id
  89. :return: the datasource plugin
  90. """
  91. return cls.get_datasource_plugin_provider(
  92. provider_id,
  93. tenant_id,
  94. datasource_type,
  95. ).get_datasource(datasource_name)