encryption.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
"""At-rest encryption for high-value secrets (TOTP keys, OIDC client_secret).

The encryption key is resolved on first use in this priority order:

1. ``MFA_ENCRYPTION_KEY`` environment variable (must be a URL-safe base64
   string that decodes to exactly 32 bytes — the Fernet key format). A value
   that is set but invalid is logged as an error and resolution continues
   with step 2.
2. ``DATA_DIR/.mfa_encryption_key`` file (read if present and valid). A
   corrupted or unreadable file falls back to plaintext (step 4) without
   being overwritten — to protect previously encrypted rows.
3. Auto-generate a new Fernet key and write it to
   ``DATA_DIR/.mfa_encryption_key`` with mode ``0o600`` (only when no usable
   env var is set and no key file exists). Falls back to plaintext (step 4)
   on OSError, or when another process created the file concurrently.
4. ``None`` (legacy plaintext fallback) — unreadable or corrupted key file,
   or a DATA_DIR that cannot be written (e.g. a read-only filesystem).

A successfully built Fernet instance is cached for the life of the process.
While no key is available, resolution is attempted again on every call.

Encrypted values are stored with a ``fernet:`` prefix. Existing plaintext
values are read back correctly even after a key is configured — values
without the ``fernet:`` prefix are returned as-is (a warning is logged when
a key is active). This keeps the auto-bootstrap non-breaking for installs
that already wrote plaintext rows before the key existed.
"""
  18. from __future__ import annotations
  19. import base64
  20. import binascii
  21. import logging
  22. import os
  23. from typing import Literal
  24. logger = logging.getLogger(__name__)
  25. _FERNET_PREFIX = "fernet:"
  26. _fernet_instance = None
  27. _warn_shown = False
  28. # Public source values exposed via get_key_source(). Internal failure causes
  29. # (none_write_failed, none_corrupted) are mapped to "none" before exposure
  30. # so the public API stays stable for the EncryptionStatusResponse schema.
  31. _PublicSource = Literal["env", "file", "generated", "none"]
  32. # Internal source carries the specific failure cause for accurate logging.
  33. # "none" remains valid for legacy test stubs (lambda: (None, "none")).
  34. _InternalSource = Literal[
  35. "env",
  36. "file",
  37. "generated",
  38. "none",
  39. "none_write_failed",
  40. "none_corrupted",
  41. ]
  42. _key_source: _PublicSource | None = None
  43. _KEY_FILE_NAME = ".mfa_encryption_key"
  44. def _validate_fernet_key(key: str) -> bool:
  45. try:
  46. decoded = base64.urlsafe_b64decode(key.encode())
  47. except (binascii.Error, ValueError):
  48. return False
  49. return len(decoded) == 32
  50. def _load_or_generate_key() -> tuple[str | None, _InternalSource]:
  51. # Lazy import: keeps cryptography out of import-time even when the helper
  52. # is patched in tests that never invoke encryption.
  53. from cryptography.fernet import Fernet
  54. from backend.app.core.paths import resolve_data_dir
  55. # 1. Environment variable
  56. env_key = os.environ.get("MFA_ENCRYPTION_KEY")
  57. if env_key:
  58. if _validate_fernet_key(env_key):
  59. return env_key, "env"
  60. logger.error(
  61. "MFA_ENCRYPTION_KEY is set but is not a valid Fernet key "
  62. "(must decode to exactly 32 bytes). Falling back to file-based key."
  63. )
  64. data_dir = resolve_data_dir()
  65. key_file = data_dir / _KEY_FILE_NAME
  66. # 2. Existing file in DATA_DIR
  67. if key_file.exists():
  68. try:
  69. file_key = key_file.read_text().strip()
  70. except OSError as exc:
  71. # Refusing to fall through to regeneration — overwriting the file
  72. # would destroy access to every row already encrypted under the
  73. # current key. Operator must fix permissions or pin the key
  74. # explicitly via MFA_ENCRYPTION_KEY.
  75. logger.error(
  76. "Failed to read existing MFA key file %s (%s). "
  77. "Refusing to regenerate — this would destroy all previously encrypted secrets. "
  78. "Fix the file permissions or set MFA_ENCRYPTION_KEY explicitly.",
  79. key_file,
  80. exc,
  81. )
  82. return None, "none_corrupted"
  83. if _validate_fernet_key(file_key):
  84. return file_key, "file"
  85. logger.error(
  86. "%s is present but is not a valid Fernet key. "
  87. "Refusing to overwrite — fix the file or set MFA_ENCRYPTION_KEY. "
  88. "Falling back to plaintext storage.",
  89. key_file,
  90. )
  91. return None, "none_corrupted"
  92. # 3. Generate a new key and persist it.
  93. # S1: Use os.open(O_WRONLY|O_CREAT|O_EXCL, 0o600) to avoid the TOCTOU
  94. # window between write_text() (umask-respecting) and chmod() — the key
  95. # is created with 0o600 from the start, never world-readable.
  96. new_key = Fernet.generate_key().decode()
  97. try:
  98. data_dir.mkdir(parents=True, exist_ok=True)
  99. fd = os.open(str(key_file), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
  100. try:
  101. os.write(fd, new_key.encode())
  102. finally:
  103. os.close(fd)
  104. # S9: Some filesystems (Windows, SMB, FUSE without uid mapping) silently
  105. # ignore mode bits — verify and warn so operators know the key is not
  106. # protected at the FS level.
  107. actual_mode = key_file.stat().st_mode & 0o777
  108. if actual_mode != 0o600:
  109. logger.warning(
  110. "MFA key file %s: filesystem did not enforce 0o600 (actual: 0o%o). "
  111. "Key may be world-readable on Windows / SMB / FUSE mounts.",
  112. key_file,
  113. actual_mode,
  114. )
  115. logger.info("Generated new MFA encryption key and saved to %s", key_file)
  116. return new_key, "generated"
  117. except FileExistsError:
  118. # Race between key_file.exists() check above and O_EXCL — another
  119. # process created the file. Treat as corrupted (do NOT regenerate).
  120. logger.error(
  121. "Race detected creating %s (file appeared between check and create). "
  122. "Refusing to overwrite — set MFA_ENCRYPTION_KEY explicitly to recover.",
  123. key_file,
  124. )
  125. return None, "none_corrupted"
  126. except OSError as exc:
  127. logger.error(
  128. "Could not save MFA encryption key to %s (%s). "
  129. "Falling back to plaintext storage. Set MFA_ENCRYPTION_KEY in the "
  130. "environment or fix the data-dir permissions to enable encryption.",
  131. key_file,
  132. exc,
  133. )
  134. return None, "none_write_failed"
  135. def get_key_source() -> _PublicSource | None:
  136. return _key_source
  137. def is_encryption_active() -> bool:
  138. return _get_fernet() is not None
  139. def _get_fernet():
  140. global _fernet_instance, _warn_shown, _key_source
  141. if _fernet_instance is not None:
  142. return _fernet_instance
  143. key, internal_source = _load_or_generate_key()
  144. # S8: collapse internal failure causes to public "none" while keeping
  145. # the differentiated source for the warning path below.
  146. _key_source = "none" if internal_source.startswith("none") else internal_source
  147. if key is None:
  148. if not _warn_shown:
  149. # S8: only emit the "DATA_DIR not writable" warning when that's
  150. # actually the cause. The corrupted-file path already error-logged
  151. # in _load_or_generate_key with a more specific message.
  152. if internal_source == "none_write_failed":
  153. logger.warning(
  154. "MFA_ENCRYPTION_KEY is not set and DATA_DIR is not writable — "
  155. "TOTP secrets and OIDC client_secrets are stored in plaintext. "
  156. "Generate a key with: "
  157. 'python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"'
  158. )
  159. # Suppresses repetitive warnings across calls; reset together
  160. # with _fernet_instance when re-initializing (e.g. in tests).
  161. _warn_shown = True
  162. return None
  163. from cryptography.fernet import Fernet
  164. _fernet_instance = Fernet(key.encode())
  165. return _fernet_instance
  166. def mfa_encrypt(plaintext: str) -> str:
  167. """Encrypt a secret value. Returns the ciphertext with a ``fernet:`` prefix,
  168. or the original plaintext if no encryption key is available."""
  169. f = _get_fernet()
  170. if f is None:
  171. return plaintext
  172. return _FERNET_PREFIX + f.encrypt(plaintext.encode()).decode()
  173. def mfa_decrypt(value: str) -> str:
  174. """Decrypt a value previously encrypted with ``mfa_encrypt``.
  175. Values without the ``fernet:`` prefix are returned as-is (legacy plaintext).
  176. Raises ``RuntimeError`` if the prefix is present but no key is configured.
  177. """
  178. if not value.startswith(_FERNET_PREFIX):
  179. # S7: Warn when a key IS configured but the stored value is plaintext.
  180. # This surfaces rows that were written before encryption was enabled so
  181. # operators know they need a migration / re-enroll cycle. WARNING level
  182. # so it shows up in normal operator log review.
  183. if _get_fernet() is not None:
  184. logger.warning(
  185. "mfa_decrypt: encryption key is active but the stored value has no "
  186. "'fernet:' prefix — returning legacy plaintext. Consider re-enrolling "
  187. "this secret to store it encrypted."
  188. )
  189. return value # Legacy plaintext — backward compatible
  190. f = _get_fernet()
  191. if f is None:
  192. raise RuntimeError(
  193. "MFA_ENCRYPTION_KEY must be set to decrypt MFA secrets that were stored with encryption enabled."
  194. )
  195. from cryptography.fernet import InvalidToken
  196. try:
  197. return f.decrypt(value[len(_FERNET_PREFIX) :].encode()).decode()
  198. except InvalidToken as exc:
  199. raise RuntimeError(
  200. "MFA secret was encrypted under a different MFA_ENCRYPTION_KEY. "
  201. "Key rotation is not currently supported — restore the previous key "
  202. "or have users re-enroll."
  203. ) from exc