您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

encryption.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from copy import deepcopy
  2. from typing import Any, Optional, Protocol
  3. from core.entities.provider_entities import BasicProviderConfig
  4. from core.helper import encrypter
  5. from core.helper.provider_cache import SingletonProviderCredentialsCache
  6. from core.tools.__base.tool_provider import ToolProviderController
  7. class ProviderConfigCache(Protocol):
  8. """
  9. Interface for provider configuration cache operations
  10. """
  11. def get(self) -> Optional[dict]:
  12. """Get cached provider configuration"""
  13. ...
  14. def set(self, config: dict[str, Any]) -> None:
  15. """Cache provider configuration"""
  16. ...
  17. def delete(self) -> None:
  18. """Delete cached provider configuration"""
  19. ...
  20. class ProviderConfigEncrypter:
  21. tenant_id: str
  22. config: list[BasicProviderConfig]
  23. provider_config_cache: ProviderConfigCache
  24. def __init__(
  25. self,
  26. tenant_id: str,
  27. config: list[BasicProviderConfig],
  28. provider_config_cache: ProviderConfigCache,
  29. ):
  30. self.tenant_id = tenant_id
  31. self.config = config
  32. self.provider_config_cache = provider_config_cache
  33. def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
  34. """
  35. deep copy data
  36. """
  37. return deepcopy(data)
  38. def encrypt(self, data: dict[str, str]) -> dict[str, str]:
  39. """
  40. encrypt tool credentials with tenant id
  41. return a deep copy of credentials with encrypted values
  42. """
  43. data = self._deep_copy(data)
  44. # get fields need to be decrypted
  45. fields = dict[str, BasicProviderConfig]()
  46. for credential in self.config:
  47. fields[credential.name] = credential
  48. for field_name, field in fields.items():
  49. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  50. if field_name in data:
  51. encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
  52. data[field_name] = encrypted
  53. return data
  54. def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
  55. """
  56. mask tool credentials
  57. return a deep copy of credentials with masked values
  58. """
  59. data = self._deep_copy(data)
  60. # get fields need to be decrypted
  61. fields = dict[str, BasicProviderConfig]()
  62. for credential in self.config:
  63. fields[credential.name] = credential
  64. for field_name, field in fields.items():
  65. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  66. if field_name in data:
  67. if len(data[field_name]) > 6:
  68. data[field_name] = (
  69. data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
  70. )
  71. else:
  72. data[field_name] = "*" * len(data[field_name])
  73. return data
  74. def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
  75. """
  76. decrypt tool credentials with tenant id
  77. return a deep copy of credentials with decrypted values
  78. """
  79. cached_credentials = self.provider_config_cache.get()
  80. if cached_credentials:
  81. return cached_credentials
  82. data = self._deep_copy(data)
  83. # get fields need to be decrypted
  84. fields = dict[str, BasicProviderConfig]()
  85. for credential in self.config:
  86. fields[credential.name] = credential
  87. for field_name, field in fields.items():
  88. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  89. if field_name in data:
  90. try:
  91. # if the value is None or empty string, skip decrypt
  92. if not data[field_name]:
  93. continue
  94. data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
  95. except Exception:
  96. pass
  97. self.provider_config_cache.set(data)
  98. return data
  99. def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
  100. return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
  101. def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
  102. cache = SingletonProviderCredentialsCache(
  103. tenant_id=tenant_id,
  104. provider_type=controller.provider_type.value,
  105. provider_identity=controller.entity.identity.name,
  106. )
  107. encrypt = ProviderConfigEncrypter(
  108. tenant_id=tenant_id,
  109. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  110. provider_config_cache=cache,
  111. )
  112. return encrypt, cache