選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

datasource_manager.py 5.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import logging
  2. from threading import Lock
  3. from typing import Union
  4. import contexts
  5. from core.datasource.__base.datasource_plugin import DatasourcePlugin
  6. from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
  7. from core.datasource.entities.common_entities import I18nObject
  8. from core.datasource.entities.datasource_entities import DatasourceProviderType
  9. from core.datasource.errors import DatasourceProviderNotFoundError
  10. from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
  11. from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
  12. from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
  13. from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
  14. from core.plugin.impl.datasource import PluginDatasourceManager
  15. logger = logging.getLogger(__name__)
  16. class DatasourceManager:
  17. _builtin_provider_lock = Lock()
  18. _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
  19. _builtin_providers_loaded = False
  20. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  21. @classmethod
  22. def get_datasource_plugin_provider(
  23. cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
  24. ) -> DatasourcePluginProviderController:
  25. """
  26. get the datasource plugin provider
  27. """
  28. # check if context is set
  29. try:
  30. contexts.datasource_plugin_providers.get()
  31. except LookupError:
  32. contexts.datasource_plugin_providers.set({})
  33. contexts.datasource_plugin_providers_lock.set(Lock())
  34. with contexts.datasource_plugin_providers_lock.get():
  35. datasource_plugin_providers = contexts.datasource_plugin_providers.get()
  36. if provider_id in datasource_plugin_providers:
  37. return datasource_plugin_providers[provider_id]
  38. manager = PluginDatasourceManager()
  39. provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
  40. if not provider_entity:
  41. raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
  42. controller: DatasourcePluginProviderController | None = None
  43. match datasource_type:
  44. case DatasourceProviderType.ONLINE_DOCUMENT:
  45. controller = OnlineDocumentDatasourcePluginProviderController(
  46. entity=provider_entity.declaration,
  47. plugin_id=provider_entity.plugin_id,
  48. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  49. tenant_id=tenant_id,
  50. )
  51. case DatasourceProviderType.ONLINE_DRIVE:
  52. controller = OnlineDriveDatasourcePluginProviderController(
  53. entity=provider_entity.declaration,
  54. plugin_id=provider_entity.plugin_id,
  55. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  56. tenant_id=tenant_id,
  57. )
  58. case DatasourceProviderType.WEBSITE_CRAWL:
  59. controller = WebsiteCrawlDatasourcePluginProviderController(
  60. entity=provider_entity.declaration,
  61. plugin_id=provider_entity.plugin_id,
  62. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  63. tenant_id=tenant_id,
  64. )
  65. case DatasourceProviderType.LOCAL_FILE:
  66. controller = LocalFileDatasourcePluginProviderController(
  67. entity=provider_entity.declaration,
  68. plugin_id=provider_entity.plugin_id,
  69. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  70. tenant_id=tenant_id,
  71. )
  72. case _:
  73. raise ValueError(f"Unsupported datasource type: {datasource_type}")
  74. if controller:
  75. datasource_plugin_providers[provider_id] = controller
  76. if controller is None:
  77. raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
  78. return controller
  79. @classmethod
  80. def get_datasource_runtime(
  81. cls,
  82. provider_id: str,
  83. datasource_name: str,
  84. tenant_id: str,
  85. datasource_type: DatasourceProviderType,
  86. ) -> DatasourcePlugin:
  87. """
  88. get the datasource runtime
  89. :param provider_type: the type of the provider
  90. :param provider_id: the id of the provider
  91. :param datasource_name: the name of the datasource
  92. :param tenant_id: the tenant id
  93. :return: the datasource plugin
  94. """
  95. return cls.get_datasource_plugin_provider(
  96. provider_id,
  97. tenant_id,
  98. datasource_type,
  99. ).get_datasource(datasource_name)