# encryption.py
"""At-rest encryption for high-value secrets (TOTP keys, OIDC client_secret).

Set the ``MFA_ENCRYPTION_KEY`` environment variable to a URL-safe base64-encoded
32-byte key (generate with ``python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"``)
to enable Fernet symmetric encryption.

When the key is not set, values are stored as plaintext and a warning is emitted.
Existing plaintext values are still read back correctly after the key is added:
values without the ``fernet:`` prefix are treated as legacy plaintext.
"""
  9. from __future__ import annotations
  10. import logging
  11. import os
# Logger for this module.
logger = logging.getLogger(__name__)

# Marker prepended by mfa_encrypt() to every encrypted value; mfa_decrypt()
# treats any value that does not start with it as legacy plaintext.
_FERNET_PREFIX = "fernet:"

# Cached Fernet object. Stays None until _get_fernet() is called while
# MFA_ENCRYPTION_KEY is set, then is reused for all later calls.
_fernet_instance = None

# Set to True once the "key is not set" warning has been logged, so that
# _get_fernet() emits it at most once.
_warn_shown = False
  16. def _get_fernet():
  17. global _fernet_instance, _warn_shown
  18. if _fernet_instance is not None:
  19. return _fernet_instance
  20. key = os.environ.get("MFA_ENCRYPTION_KEY")
  21. if key:
  22. from cryptography.fernet import Fernet
  23. _fernet_instance = Fernet(key.encode() if isinstance(key, str) else key)
  24. return _fernet_instance
  25. if not _warn_shown:
  26. logger.warning(
  27. "MFA_ENCRYPTION_KEY is not set — TOTP secrets and OIDC client_secrets are "
  28. "stored in plaintext. Generate a key with: "
  29. 'python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"'
  30. )
  31. _warn_shown = True
  32. return None
  33. def mfa_encrypt(plaintext: str) -> str:
  34. """Encrypt a secret value. Returns the ciphertext with a ``fernet:`` prefix,
  35. or the original plaintext if ``MFA_ENCRYPTION_KEY`` is not configured."""
  36. f = _get_fernet()
  37. if f is None:
  38. return plaintext
  39. return _FERNET_PREFIX + f.encrypt(plaintext.encode()).decode()
  40. def mfa_decrypt(value: str) -> str:
  41. """Decrypt a value previously encrypted with ``mfa_encrypt``.
  42. Values without the ``fernet:`` prefix are returned as-is (legacy plaintext).
  43. Raises ``RuntimeError`` if the prefix is present but no key is configured.
  44. """
  45. if not value.startswith(_FERNET_PREFIX):
  46. # Nit6: Warn when a key IS configured but the stored value is plaintext.
  47. # This surfaces rows that were written before encryption was enabled so
  48. # operators know they need a migration / re-enroll cycle.
  49. if _get_fernet() is not None:
  50. logger.warning(
  51. "mfa_decrypt: MFA_ENCRYPTION_KEY is set but the stored value has no "
  52. "'fernet:' prefix — returning legacy plaintext. Consider re-enrolling "
  53. "this secret to store it encrypted."
  54. )
  55. return value # Legacy plaintext — backward compatible
  56. f = _get_fernet()
  57. if f is None:
  58. raise RuntimeError(
  59. "MFA_ENCRYPTION_KEY must be set to decrypt MFA secrets that were stored with encryption enabled."
  60. )
  61. from cryptography.fernet import InvalidToken
  62. try:
  63. return f.decrypt(value[len(_FERNET_PREFIX) :].encode()).decode()
  64. except InvalidToken:
  65. raise RuntimeError(
  66. "MFA secret was encrypted under a different MFA_ENCRYPTION_KEY. "
  67. "Key rotation is not currently supported — restore the previous key "
  68. "or have users re-enroll."
  69. )