test_spoolbuddy_ssh.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. """Unit tests for SpoolBuddy SSH update service."""
  2. import asyncio
  3. import os
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from backend.app.services.spoolbuddy_ssh import (
  7. _get_ssh_key_dir,
  8. _run_ssh_command,
  9. detect_current_branch,
  10. get_or_create_keypair,
  11. get_public_key,
  12. perform_ssh_update,
  13. )
  14. # -- _get_ssh_key_dir ---------------------------------------------------------
  def test_get_ssh_key_dir_creates_directory(tmp_path):
      """The key directory is ``<base_dir>/spoolbuddy/ssh`` and exists after the call."""
      expected_dir = tmp_path / "spoolbuddy" / "ssh"

      with patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings:
          fake_settings.base_dir = tmp_path
          created = _get_ssh_key_dir()

      assert created == expected_dir
      assert created.exists()
  def test_get_ssh_key_dir_returns_existing(tmp_path):
      """A key directory that is already present is returned unchanged."""
      existing_dir = tmp_path / "spoolbuddy" / "ssh"
      existing_dir.mkdir(parents=True)

      with patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings:
          fake_settings.base_dir = tmp_path
          resolved = _get_ssh_key_dir()

      assert resolved == existing_dir
  27. # -- get_or_create_keypair -----------------------------------------------------
  @pytest.mark.asyncio
  async def test_get_or_create_keypair_returns_existing(tmp_path):
      """When both key files already exist, their paths are what gets returned."""
      key_dir = tmp_path / "spoolbuddy" / "ssh"
      key_dir.mkdir(parents=True)

      private_path = key_dir / "id_ed25519"
      public_path = key_dir / "id_ed25519.pub"
      private_path.write_text("PRIVATE")
      public_path.write_text("PUBLIC")

      with patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings:
          fake_settings.base_dir = tmp_path
          returned = await get_or_create_keypair()

      assert returned == (private_path, public_path)
  @pytest.mark.asyncio
  async def test_get_or_create_keypair_generates_new(tmp_path):
      """A missing keypair is generated in-process with `cryptography`.

      No `ssh-keygen` subprocess is involved. In Docker the container may run
      under an arbitrary PUID that has no /etc/passwd entry, and `ssh-keygen`
      then aborts with "no user exists for uid <N>"; in-process generation
      sidesteps that getpwuid() lookup.
      """
      from cryptography.hazmat.primitives import serialization
      from cryptography.hazmat.primitives.asymmetric import ed25519

      with patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings:
          fake_settings.base_dir = tmp_path
          private_path, public_path = await get_or_create_keypair()

      assert private_path.exists()
      assert public_path.exists()

      # The private key must carry no group/other permission bits.
      group_other_bits = private_path.stat().st_mode & 0o077
      assert group_other_bits == 0

      # The public half is an OpenSSH ed25519 line ending in our comment.
      public_line = public_path.read_text()
      assert public_line.startswith("ssh-ed25519 ")
      assert public_line.rstrip().endswith("bambuddy-spoolbuddy")

      # The private half loads back through the OpenSSH parser as ed25519.
      reloaded = serialization.load_ssh_private_key(
          private_path.read_bytes(), password=None
      )
      assert isinstance(reloaded, ed25519.Ed25519PrivateKey)
  @pytest.mark.asyncio
  async def test_get_or_create_keypair_does_not_shell_out(tmp_path):
      """Regression guard: must not invoke any subprocess (fixes Docker PUID bug).

      Every common subprocess entry point is intercepted, not just
      ``asyncio.create_subprocess_exec``, so a reintroduced ``ssh-keygen`` call
      is caught whichever API is used to spawn it.
      """
      with (
          patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
          patch("asyncio.create_subprocess_exec") as mock_exec,
          patch("asyncio.create_subprocess_shell") as mock_shell,
          patch("subprocess.run") as mock_run,
      ):
          mock_settings.base_dir = tmp_path
          await get_or_create_keypair()

          mock_exec.assert_not_called()
          mock_shell.assert_not_called()
          mock_run.assert_not_called()
  73. # -- get_public_key ------------------------------------------------------------
  @pytest.mark.asyncio
  async def test_get_public_key(tmp_path):
      """The public key text is returned without its trailing newline."""
      key_dir = tmp_path / "spoolbuddy" / "ssh"
      key_dir.mkdir(parents=True)

      private_path = key_dir / "id_ed25519"
      public_path = key_dir / "id_ed25519.pub"
      private_path.write_text("PRIVATE")
      public_path.write_text("ssh-ed25519 AAAA bambuddy-spoolbuddy\n")

      with patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings:
          fake_settings.base_dir = tmp_path
          public_key = await get_public_key()

      assert public_key == "ssh-ed25519 AAAA bambuddy-spoolbuddy"
  84. # -- detect_current_branch ----------------------------------------------------
  def test_detect_branch_from_git_head(tmp_path):
      """The branch comes straight from `.git/HEAD` in the app root, with no subprocess."""
      head_file = tmp_path / ".git" / "HEAD"
      head_file.parent.mkdir()
      head_file.write_text("ref: refs/heads/dev\n")

      with (
          patch("backend.app.services.spoolbuddy_ssh._APP_DIR", tmp_path),
          patch("asyncio.create_subprocess_exec") as exec_spy,
          patch("subprocess.run") as run_spy,
      ):
          branch = detect_current_branch()

          assert branch == "dev"
          # Regression guard: shelling out fails with getpwuid under arbitrary
          # Docker PUIDs, so neither spawn API may be touched.
          exec_spy.assert_not_called()
          run_spy.assert_not_called()
  def test_detect_branch_uses_app_dir_not_data_dir(tmp_path):
      """Branch detection reads the application root, never the data directory.

      Guards against the Docker regression in which `.git` was looked up under
      `settings.base_dir` (`DATA_DIR=/app/data` in Docker). It was never found
      there, so the fallback always produced "main" even when a feature branch
      was bind-mounted at `/app`.
      """
      app_dir = tmp_path / "app"
      data_dir = app_dir / "data"
      app_dir.mkdir()
      data_dir.mkdir()

      # The genuine repository metadata sits at the application root.
      real_git = app_dir / ".git"
      real_git.mkdir()
      (real_git / "HEAD").write_text("ref: refs/heads/dev\n")

      # A decoy inside the data dir: reading from settings.base_dir would
      # surface this branch name instead.
      decoy_git = data_dir / ".git"
      decoy_git.mkdir()
      (decoy_git / "HEAD").write_text("ref: refs/heads/wrong-branch\n")

      with (
          patch("backend.app.services.spoolbuddy_ssh._APP_DIR", app_dir),
          patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings,
      ):
          fake_settings.base_dir = data_dir
          branch = detect_current_branch()

      assert branch == "dev"
  def test_detect_branch_worktree_gitdir_file(tmp_path):
      """A `.git` *file* holding a `gitdir:` pointer (worktree layout) is followed."""
      real_git_dir = tmp_path / "real-git"
      real_git_dir.mkdir()
      (real_git_dir / "HEAD").write_text("ref: refs/heads/feature-x\n")

      pointer_file = tmp_path / ".git"
      pointer_file.write_text(f"gitdir: {real_git_dir}\n")

      with patch("backend.app.services.spoolbuddy_ssh._APP_DIR", tmp_path):
          branch = detect_current_branch()

      assert branch == "feature-x"
  def test_detect_branch_detached_head_falls_back(tmp_path):
      """A HEAD holding a bare commit hash yields the `GIT_BRANCH` env value."""
      head_file = tmp_path / ".git" / "HEAD"
      head_file.parent.mkdir()
      head_file.write_text("deadbeef1234\n")

      with (
          patch("backend.app.services.spoolbuddy_ssh._APP_DIR", tmp_path),
          patch.dict(os.environ, {"GIT_BRANCH": "release"}),
      ):
          branch = detect_current_branch()

      assert branch == "release"
  def test_detect_branch_env_fallback(tmp_path):
      """With no `.git` in the app dir, the `GIT_BRANCH` env value is used."""
      app_dir_patch = patch("backend.app.services.spoolbuddy_ssh._APP_DIR", tmp_path)
      env_patch = patch.dict(os.environ, {"GIT_BRANCH": "staging"})

      with app_dir_patch, env_patch:
          branch = detect_current_branch()

      assert branch == "staging"
  def test_detect_branch_default_main(tmp_path):
      """With no `.git` in the app dir and no `GIT_BRANCH` set, the result is "main"."""
      with (
          patch("backend.app.services.spoolbuddy_ssh._APP_DIR", tmp_path),
          # clear=True empties os.environ for the duration of the block, so
          # GIT_BRANCH is guaranteed absent; patch.dict restores the original
          # environment on exit.
          patch.dict(os.environ, {}, clear=True),
      ):
          assert detect_current_branch() == "main"
  156. # -- _run_ssh_command ----------------------------------------------------------
  157. #
  158. # _run_ssh_command uses asyncssh (pure Python) rather than the OpenSSH `ssh`
  159. # binary. Both `ssh` and `ssh-keygen` call getpwuid(getuid()) during startup
  160. # and abort with "No user exists for uid <N>" when the container runs under
  161. # an arbitrary PUID that is not listed in /etc/passwd — asyncssh avoids the
  162. # subprocess entirely.
  @pytest.mark.asyncio
  async def test_run_ssh_command_success(tmp_path):
      """A successful remote command yields its exit status, stdout and stderr."""
      key_path = tmp_path / "key"
      key_path.write_text("KEY")

      completed = MagicMock(stdout="hello\n", stderr="", exit_status=0)

      connection = AsyncMock()
      connection.run = AsyncMock(return_value=completed)
      connection.__aenter__ = AsyncMock(return_value=connection)
      connection.__aexit__ = AsyncMock(return_value=False)

      with patch(
          "backend.app.services.spoolbuddy_ssh.asyncssh.connect",
          return_value=connection,
      ) as connect_spy:
          exit_code, out, err = await _run_ssh_command("10.0.0.1", "echo hello", key_path)

      assert (exit_code, out, err) == (0, "hello\n", "")

      connect_kwargs = connect_spy.call_args.kwargs
      assert connect_kwargs["host"] == "10.0.0.1"
      assert connect_kwargs["username"] == "spoolbuddy"
      assert connect_kwargs["client_keys"] == [str(key_path)]
      # known_hosts=None turns host-key verification off, the equivalent of
      # StrictHostKeyChecking=no.
      assert connect_kwargs["known_hosts"] is None
      # An empty config list keeps ~/.ssh/config from being loaded; HOME may
      # not resolve under arbitrary Docker PUIDs.
      assert connect_kwargs["config"] == []

      connection.run.assert_awaited_once()
      run_call = connection.run.call_args
      assert run_call.args[0] == "echo hello"
      # check=False: non-zero exit codes are handled by the caller, not raised.
      assert run_call.kwargs.get("check") is False
  @pytest.mark.asyncio
  async def test_run_ssh_command_no_subprocess(tmp_path):
      """Regression guard: _run_ssh_command must not spawn any subprocess.

      The whole point of switching to asyncssh is to avoid `ssh`/`ssh-keygen`
      calling getpwuid() inside Docker containers with arbitrary PUIDs. Every
      common spawn entry point is intercepted, not just
      ``asyncio.create_subprocess_exec``, so the guard holds whichever API a
      regression might use.
      """
      key_file = tmp_path / "key"
      key_file.write_text("KEY")

      mock_result = MagicMock()
      mock_result.stdout = ""
      mock_result.stderr = ""
      mock_result.exit_status = 0

      mock_conn = AsyncMock()
      mock_conn.run = AsyncMock(return_value=mock_result)
      mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
      mock_conn.__aexit__ = AsyncMock(return_value=False)

      with (
          patch("backend.app.services.spoolbuddy_ssh.asyncssh.connect", return_value=mock_conn),
          patch("asyncio.create_subprocess_exec") as mock_exec,
          patch("asyncio.create_subprocess_shell") as mock_shell,
          patch("subprocess.run") as mock_run,
      ):
          await _run_ssh_command("10.0.0.1", "echo hi", key_file)

          mock_exec.assert_not_called()
          mock_shell.assert_not_called()
          mock_run.assert_not_called()
  @pytest.mark.asyncio
  async def test_run_ssh_command_connection_failure(tmp_path):
      """An asyncssh error while connecting is reported as rc=255 plus its message."""
      import asyncssh

      key_path = tmp_path / "key"
      key_path.write_text("KEY")

      refused = asyncssh.Error(code=0, reason="Connection refused")
      connect_patch = patch(
          "backend.app.services.spoolbuddy_ssh.asyncssh.connect",
          side_effect=refused,
      )

      with connect_patch:
          exit_code, out, err = await _run_ssh_command("10.0.0.1", "echo hello", key_path)

      assert exit_code == 255
      assert out == ""
      assert "Connection refused" in err
  @pytest.mark.asyncio
  async def test_run_ssh_command_os_error(tmp_path):
      """An OSError while connecting (DNS, routing) is likewise reported as rc=255."""
      key_path = tmp_path / "key"
      key_path.write_text("KEY")

      connect_patch = patch(
          "backend.app.services.spoolbuddy_ssh.asyncssh.connect",
          side_effect=OSError("Network is unreachable"),
      )

      with connect_patch:
          exit_code, _, err = await _run_ssh_command("10.0.0.1", "echo hello", key_path)

      assert exit_code == 255
      assert "Network is unreachable" in err
  @pytest.mark.asyncio
  async def test_run_ssh_command_timeout(tmp_path):
      """Exceeding the timeout is reported as rc=-1 with a "timed out" message."""
      key_path = tmp_path / "key"
      key_path.write_text("KEY")

      async def never_connects():
          await asyncio.sleep(10)

      # asyncssh.connect() hands back its connection manager synchronously, so
      # the stall has to live in __aenter__ for the enclosing timeout to be able
      # to cancel it.
      stalled_conn = AsyncMock()
      stalled_conn.__aenter__ = AsyncMock(side_effect=never_connects)
      stalled_conn.__aexit__ = AsyncMock(return_value=False)

      with patch(
          "backend.app.services.spoolbuddy_ssh.asyncssh.connect",
          return_value=stalled_conn,
      ):
          exit_code, _, err = await _run_ssh_command(
              "10.0.0.1", "sleep 999", key_path, timeout=0.05
          )

      assert exit_code == -1
      assert "timed out" in err
  259. # -- perform_ssh_update --------------------------------------------------------
  def _make_update_mocks(tmp_path=None):
      """Create common mocks for perform_ssh_update tests.

      Args:
          tmp_path: Unused. Accepted, with a default, so call sites that pass
              the pytest ``tmp_path`` fixture positionally keep working.

      Returns:
          A ``(device, session_ctx, ws)`` tuple:

          * ``device`` - MagicMock with ``update_status``, ``update_message``
            and ``pending_command`` all preset to ``None``.
          * ``session_ctx`` - async context manager whose ``__aenter__`` yields
            a session mock; awaiting ``session.execute(...)`` gives a result
            whose ``scalar_one_or_none()`` returns ``device``.
          * ``ws`` - MagicMock whose ``broadcast`` attribute is an AsyncMock.
      """
      mock_db_device = MagicMock(
          update_status=None,
          update_message=None,
          pending_command=None,
      )

      # Query result that resolves to the device above.
      mock_result = MagicMock()
      mock_result.scalar_one_or_none.return_value = mock_db_device

      mock_session = AsyncMock()
      mock_session.execute = AsyncMock(return_value=mock_result)
      mock_session.commit = AsyncMock()

      # `async with <ctx> as session:` hands out the session mock.
      mock_ctx = AsyncMock()
      mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
      mock_ctx.__aexit__ = AsyncMock(return_value=False)

      mock_ws = MagicMock()
      mock_ws.broadcast = AsyncMock()

      return mock_db_device, mock_ctx, mock_ws
  @pytest.mark.asyncio
  async def test_perform_ssh_update_success(tmp_path):
      """Happy path: every SSH command in the update sequence succeeds."""
      key_dir = tmp_path / "spoolbuddy" / "ssh"
      key_dir.mkdir(parents=True)
      (key_dir / "id_ed25519").write_text("PRIVATE")
      (key_dir / "id_ed25519.pub").write_text("PUBLIC")

      issued = []

      async def record_and_succeed(ip, cmd, key, timeout=60):
          issued.append(cmd)
          return 0, "ok", ""

      _, session_ctx, ws = _make_update_mocks(tmp_path)

      with (
          patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings,
          patch(
              "backend.app.services.spoolbuddy_ssh._run_ssh_command",
              side_effect=record_and_succeed,
          ),
          patch(
              "backend.app.services.spoolbuddy_ssh.detect_current_branch",
              return_value="dev",
          ),
          patch("backend.app.core.database.async_session", return_value=session_ctx),
          patch("backend.app.api.routes.spoolbuddy.ws_manager", ws),
      ):
          fake_settings.base_dir = tmp_path
          await perform_ssh_update("sb-test", "10.0.0.1")

      # Expected order: echo ok, git fetch, git checkout+reset, pip install,
      # systemctl restart, find (SW cleanup), systemctl restart getty.
      expected_fragments = (
          "echo ok",
          "fetch",
          "checkout",
          "pip",
          "spoolbuddy.service",
          "Service Worker",
          "getty",
      )
      assert len(issued) == 7
      for fragment, command in zip(expected_fragments, issued):
          assert fragment in command

      assert ws.broadcast.call_count >= 4
  @pytest.mark.asyncio
  async def test_perform_ssh_update_ssh_failure(tmp_path):
      """A failed SSH connectivity check results in an "error" status broadcast."""
      key_dir = tmp_path / "spoolbuddy" / "ssh"
      key_dir.mkdir(parents=True)
      (key_dir / "id_ed25519").write_text("PRIVATE")
      (key_dir / "id_ed25519.pub").write_text("PUBLIC")

      async def refuse_connectivity_check(ip, cmd, key, timeout=60):
          if "echo ok" not in cmd:
              return 0, "", ""
          return 255, "", "Connection refused"

      _, session_ctx, ws = _make_update_mocks(tmp_path)

      with (
          patch("backend.app.services.spoolbuddy_ssh.settings") as fake_settings,
          patch(
              "backend.app.services.spoolbuddy_ssh._run_ssh_command",
              side_effect=refuse_connectivity_check,
          ),
          patch(
              "backend.app.services.spoolbuddy_ssh.detect_current_branch",
              return_value="main",
          ),
          patch("backend.app.core.database.async_session", return_value=session_ctx),
          patch("backend.app.api.routes.spoolbuddy.ws_manager", ws),
      ):
          fake_settings.base_dir = tmp_path
          await perform_ssh_update("sb-test", "10.0.0.1")

      # Collect the payload (first positional argument) of every broadcast
      # whose status is "error".
      error_payloads = [
          call.args[0]
          for call in ws.broadcast.call_args_list
          if call.args[0].get("update_status") == "error"
      ]
      assert len(error_payloads) >= 1
      assert "SSH connection failed" in error_payloads[0]["update_message"]
  334. @pytest.mark.asyncio
  335. async def test_perform_ssh_update_git_fetch_failure(tmp_path):
  336. """Git fetch fails — should set error and stop."""
  337. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  338. ssh_dir.mkdir(parents=True)
  339. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  340. (ssh_dir / "id_ed25519.pub").write_text("PUBLIC")
  341. ssh_calls = []
  342. async def mock_ssh(ip, cmd, key, timeout=60):
  343. ssh_calls.append(cmd)
  344. if "fetch" in cmd:
  345. return 1, "", "fatal: could not read from remote"
  346. return 0, "ok", ""
  347. _, mock_ctx, mock_ws = _make_update_mocks(tmp_path)
  348. with (
  349. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  350. patch("backend.app.services.spoolbuddy_ssh._run_ssh_command", side_effect=mock_ssh),
  351. patch("backend.app.services.spoolbuddy_ssh.detect_current_branch", return_value="main"),
  352. patch("backend.app.core.database.async_session", return_value=mock_ctx),
  353. patch("backend.app.api.routes.spoolbuddy.ws_manager", mock_ws),
  354. ):
  355. mock_settings.base_dir = tmp_path
  356. await perform_ssh_update("sb-test", "10.0.0.1")
  357. # Should stop after git fetch — no checkout, pip, restart
  358. assert len(ssh_calls) == 2 # echo ok + git fetch
  359. assert not any("checkout" in c for c in ssh_calls)