您最多选择25个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。

datasource_manager.py 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import logging
  2. from threading import Lock
  3. from typing import Union
  4. import contexts
  5. from core.datasource.__base.datasource_plugin import DatasourcePlugin
  6. from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
  7. from core.datasource.entities.common_entities import I18nObject
  8. from core.datasource.entities.datasource_entities import DatasourceProviderType
  9. from core.datasource.errors import DatasourceProviderNotFoundError
  10. from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
  11. from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
  12. from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
  13. from core.plugin.impl.datasource import PluginDatasourceManager
  14. logger = logging.getLogger(__name__)
  15. class DatasourceManager:
  16. _builtin_provider_lock = Lock()
  17. _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
  18. _builtin_providers_loaded = False
  19. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  20. @classmethod
  21. def get_datasource_plugin_provider(
  22. cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
  23. ) -> DatasourcePluginProviderController:
  24. """
  25. get the datasource plugin provider
  26. """
  27. # check if context is set
  28. try:
  29. contexts.datasource_plugin_providers.get()
  30. except LookupError:
  31. contexts.datasource_plugin_providers.set({})
  32. contexts.datasource_plugin_providers_lock.set(Lock())
  33. with contexts.datasource_plugin_providers_lock.get():
  34. datasource_plugin_providers = contexts.datasource_plugin_providers.get()
  35. if provider in datasource_plugin_providers:
  36. return datasource_plugin_providers[provider]
  37. manager = PluginDatasourceManager()
  38. provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
  39. if not provider_entity:
  40. raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
  41. match datasource_type:
  42. case DatasourceProviderType.ONLINE_DOCUMENT:
  43. controller = OnlineDocumentDatasourcePluginProviderController(
  44. entity=provider_entity.declaration,
  45. plugin_id=provider_entity.plugin_id,
  46. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  47. tenant_id=tenant_id,
  48. )
  49. case DatasourceProviderType.WEBSITE_CRAWL:
  50. controller = WebsiteCrawlDatasourcePluginProviderController(
  51. entity=provider_entity.declaration,
  52. plugin_id=provider_entity.plugin_id,
  53. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  54. tenant_id=tenant_id,
  55. )
  56. case DatasourceProviderType.LOCAL_FILE:
  57. controller = LocalFileDatasourcePluginProviderController(
  58. entity=provider_entity.declaration,
  59. plugin_id=provider_entity.plugin_id,
  60. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  61. tenant_id=tenant_id,
  62. )
  63. case _:
  64. raise ValueError(f"Unsupported datasource type: {datasource_type}")
  65. datasource_plugin_providers[provider] = controller
  66. return controller
  67. @classmethod
  68. def get_datasource_runtime(
  69. cls,
  70. provider_id: str,
  71. datasource_name: str,
  72. tenant_id: str,
  73. datasource_type: DatasourceProviderType,
  74. ) -> DatasourcePlugin:
  75. """
  76. get the datasource runtime
  77. :param provider_type: the type of the provider
  78. :param provider_id: the id of the provider
  79. :param datasource_name: the name of the datasource
  80. :param tenant_id: the tenant id
  81. :return: the datasource plugin
  82. """
  83. return cls.get_datasource_plugin_provider(
  84. provider_id,
  85. tenant_id,
  86. datasource_type,
  87. ).get_datasource(datasource_name)