You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rsa.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import hashlib
  2. import os
  3. from typing import Union
  4. from Crypto.Cipher import AES
  5. from Crypto.PublicKey import RSA
  6. from Crypto.Random import get_random_bytes
  7. from extensions.ext_redis import redis_client
  8. from extensions.ext_storage import storage
  9. from libs import gmpy2_pkcs10aep_cipher
  10. def generate_key_pair(tenant_id: str) -> str:
  11. private_key = RSA.generate(2048)
  12. public_key = private_key.publickey()
  13. pem_private = private_key.export_key()
  14. pem_public = public_key.export_key()
  15. filepath = os.path.join("privkeys", tenant_id, "private.pem")
  16. storage.save(filepath, pem_private)
  17. return pem_public.decode()
  18. prefix_hybrid = b"HYBRID:"
  19. def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
  20. if isinstance(public_key, str):
  21. public_key = public_key.encode()
  22. aes_key = get_random_bytes(16)
  23. cipher_aes = AES.new(aes_key, AES.MODE_EAX)
  24. ciphertext, tag = cipher_aes.encrypt_and_digest(text.encode())
  25. rsa_key = RSA.import_key(public_key)
  26. cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
  27. enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
  28. encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
  29. return prefix_hybrid + encrypted_data
  30. def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
  31. filepath = os.path.join("privkeys", tenant_id, "private.pem")
  32. cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
  33. private_key = redis_client.get(cache_key)
  34. if not private_key:
  35. try:
  36. private_key = storage.load(filepath)
  37. except FileNotFoundError:
  38. raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}")
  39. redis_client.setex(cache_key, 120, private_key)
  40. rsa_key = RSA.import_key(private_key)
  41. cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
  42. return rsa_key, cipher_rsa
  43. def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
  44. if encrypted_text.startswith(prefix_hybrid):
  45. encrypted_text = encrypted_text[len(prefix_hybrid) :]
  46. enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
  47. nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
  48. tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
  49. ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]
  50. aes_key = cipher_rsa.decrypt(enc_aes_key)
  51. cipher_aes = AES.new(aes_key, AES.MODE_EAX, nonce=nonce)
  52. decrypted_text = cipher_aes.decrypt_and_verify(ciphertext, tag)
  53. else:
  54. decrypted_text = cipher_rsa.decrypt(encrypted_text)
  55. return decrypted_text.decode()
  56. def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
  57. rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
  58. return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
  59. class PrivkeyNotFoundError(Exception):
  60. pass