You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

encryption.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import contextlib
  2. from copy import deepcopy
  3. from typing import Any, Optional, Protocol
  4. from core.entities.provider_entities import BasicProviderConfig
  5. from core.helper import encrypter
  6. from core.helper.provider_cache import SingletonProviderCredentialsCache
  7. from core.tools.__base.tool_provider import ToolProviderController
  8. class ProviderConfigCache(Protocol):
  9. """
  10. Interface for provider configuration cache operations
  11. """
  12. def get(self) -> Optional[dict]:
  13. """Get cached provider configuration"""
  14. ...
  15. def set(self, config: dict[str, Any]) -> None:
  16. """Cache provider configuration"""
  17. ...
  18. def delete(self) -> None:
  19. """Delete cached provider configuration"""
  20. ...
  21. class ProviderConfigEncrypter:
  22. tenant_id: str
  23. config: list[BasicProviderConfig]
  24. provider_config_cache: ProviderConfigCache
  25. def __init__(
  26. self,
  27. tenant_id: str,
  28. config: list[BasicProviderConfig],
  29. provider_config_cache: ProviderConfigCache,
  30. ):
  31. self.tenant_id = tenant_id
  32. self.config = config
  33. self.provider_config_cache = provider_config_cache
  34. def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
  35. """
  36. deep copy data
  37. """
  38. return deepcopy(data)
  39. def encrypt(self, data: dict[str, str]) -> dict[str, str]:
  40. """
  41. encrypt tool credentials with tenant id
  42. return a deep copy of credentials with encrypted values
  43. """
  44. data = self._deep_copy(data)
  45. # get fields need to be decrypted
  46. fields = dict[str, BasicProviderConfig]()
  47. for credential in self.config:
  48. fields[credential.name] = credential
  49. for field_name, field in fields.items():
  50. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  51. if field_name in data:
  52. encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
  53. data[field_name] = encrypted
  54. return data
  55. def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
  56. """
  57. mask tool credentials
  58. return a deep copy of credentials with masked values
  59. """
  60. data = self._deep_copy(data)
  61. # get fields need to be decrypted
  62. fields = dict[str, BasicProviderConfig]()
  63. for credential in self.config:
  64. fields[credential.name] = credential
  65. for field_name, field in fields.items():
  66. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  67. if field_name in data:
  68. if len(data[field_name]) > 6:
  69. data[field_name] = (
  70. data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
  71. )
  72. else:
  73. data[field_name] = "*" * len(data[field_name])
  74. return data
  75. def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
  76. """
  77. decrypt tool credentials with tenant id
  78. return a deep copy of credentials with decrypted values
  79. """
  80. cached_credentials = self.provider_config_cache.get()
  81. if cached_credentials:
  82. return cached_credentials
  83. data = self._deep_copy(data)
  84. # get fields need to be decrypted
  85. fields = dict[str, BasicProviderConfig]()
  86. for credential in self.config:
  87. fields[credential.name] = credential
  88. for field_name, field in fields.items():
  89. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  90. if field_name in data:
  91. with contextlib.suppress(Exception):
  92. # if the value is None or empty string, skip decrypt
  93. if not data[field_name]:
  94. continue
  95. data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
  96. self.provider_config_cache.set(data)
  97. return data
  98. def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
  99. return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
  100. def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
  101. cache = SingletonProviderCredentialsCache(
  102. tenant_id=tenant_id,
  103. provider_type=controller.provider_type.value,
  104. provider_identity=controller.entity.identity.name,
  105. )
  106. encrypt = ProviderConfigEncrypter(
  107. tenant_id=tenant_id,
  108. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  109. provider_config_cache=cache,
  110. )
  111. return encrypt, cache