test_spoolbuddy_ssh.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. """Unit tests for SpoolBuddy SSH update service."""
  2. import os
  3. from unittest.mock import AsyncMock, MagicMock, patch
  4. import pytest
  5. from backend.app.services.spoolbuddy_ssh import (
  6. _get_ssh_key_dir,
  7. _run_ssh_command,
  8. detect_current_branch,
  9. get_or_create_keypair,
  10. get_public_key,
  11. perform_ssh_update,
  12. )
# -- _get_ssh_key_dir ---------------------------------------------------------
  14. def test_get_ssh_key_dir_creates_directory(tmp_path):
  15. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  16. mock_settings.base_dir = tmp_path
  17. key_dir = _get_ssh_key_dir()
  18. assert key_dir == tmp_path / "spoolbuddy" / "ssh"
  19. assert key_dir.exists()
  20. def test_get_ssh_key_dir_returns_existing(tmp_path):
  21. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  22. ssh_dir.mkdir(parents=True)
  23. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  24. mock_settings.base_dir = tmp_path
  25. assert _get_ssh_key_dir() == ssh_dir
# -- get_or_create_keypair -----------------------------------------------------
  27. @pytest.mark.asyncio
  28. async def test_get_or_create_keypair_returns_existing(tmp_path):
  29. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  30. ssh_dir.mkdir(parents=True)
  31. priv = ssh_dir / "id_ed25519"
  32. pub = ssh_dir / "id_ed25519.pub"
  33. priv.write_text("PRIVATE")
  34. pub.write_text("PUBLIC")
  35. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  36. mock_settings.base_dir = tmp_path
  37. result = await get_or_create_keypair()
  38. assert result == (priv, pub)
  39. @pytest.mark.asyncio
  40. async def test_get_or_create_keypair_generates_new(tmp_path):
  41. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  42. mock_settings.base_dir = tmp_path
  43. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  44. async def fake_keygen(*args, **kwargs):
  45. # Simulate ssh-keygen creating the files
  46. ssh_dir.mkdir(parents=True, exist_ok=True)
  47. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  48. (ssh_dir / "id_ed25519.pub").write_text("PUBLIC")
  49. mock_proc = AsyncMock()
  50. mock_proc.communicate = AsyncMock(return_value=(b"", b""))
  51. mock_proc.returncode = 0
  52. return mock_proc
  53. with patch("asyncio.create_subprocess_exec", side_effect=fake_keygen) as mock_exec:
  54. priv, pub = await get_or_create_keypair()
  55. mock_exec.assert_called_once()
  56. args = mock_exec.call_args[0]
  57. assert "ssh-keygen" in args
  58. assert "-t" in args
  59. assert "ed25519" in args
  60. @pytest.mark.asyncio
  61. async def test_get_or_create_keypair_raises_on_failure(tmp_path):
  62. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  63. mock_settings.base_dir = tmp_path
  64. mock_proc = AsyncMock()
  65. mock_proc.communicate = AsyncMock(return_value=(b"", b"keygen error"))
  66. mock_proc.returncode = 1
  67. with (
  68. patch("asyncio.create_subprocess_exec", return_value=mock_proc),
  69. pytest.raises(RuntimeError, match="ssh-keygen failed"),
  70. ):
  71. await get_or_create_keypair()
# -- get_public_key ------------------------------------------------------------
  73. @pytest.mark.asyncio
  74. async def test_get_public_key(tmp_path):
  75. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  76. ssh_dir.mkdir(parents=True)
  77. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  78. (ssh_dir / "id_ed25519.pub").write_text("ssh-ed25519 AAAA bambuddy-spoolbuddy\n")
  79. with patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings:
  80. mock_settings.base_dir = tmp_path
  81. key = await get_public_key()
  82. assert key == "ssh-ed25519 AAAA bambuddy-spoolbuddy"
# -- detect_current_branch ----------------------------------------------------
  84. def test_detect_branch_from_git(tmp_path):
  85. (tmp_path / ".git").mkdir()
  86. with (
  87. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  88. patch("subprocess.run") as mock_run,
  89. ):
  90. mock_settings.base_dir = tmp_path
  91. mock_run.return_value = MagicMock(returncode=0, stdout="dev\n")
  92. assert detect_current_branch() == "dev"
  93. def test_detect_branch_env_fallback(tmp_path):
  94. with (
  95. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  96. patch.dict(os.environ, {"GIT_BRANCH": "staging"}),
  97. ):
  98. mock_settings.base_dir = tmp_path
  99. assert detect_current_branch() == "staging"
  100. def test_detect_branch_default_main(tmp_path):
  101. with (
  102. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  103. patch.dict(os.environ, {}, clear=True),
  104. ):
  105. mock_settings.base_dir = tmp_path
  106. # Remove GIT_BRANCH if present
  107. os.environ.pop("GIT_BRANCH", None)
  108. assert detect_current_branch() == "main"
# -- _run_ssh_command ----------------------------------------------------------
  110. @pytest.mark.asyncio
  111. async def test_run_ssh_command_success(tmp_path):
  112. key_file = tmp_path / "key"
  113. key_file.write_text("KEY")
  114. mock_proc = AsyncMock()
  115. mock_proc.communicate = AsyncMock(return_value=(b"hello\n", b""))
  116. mock_proc.returncode = 0
  117. with patch("asyncio.create_subprocess_exec", return_value=mock_proc) as mock_exec:
  118. rc, stdout, stderr = await _run_ssh_command("10.0.0.1", "echo hello", key_file)
  119. assert rc == 0
  120. assert stdout == "hello\n"
  121. assert stderr == ""
  122. args = mock_exec.call_args[0]
  123. assert "spoolbuddy@10.0.0.1" in args
  124. assert "echo hello" in args
  125. assert "BatchMode=yes" in args
  126. @pytest.mark.asyncio
  127. async def test_run_ssh_command_failure(tmp_path):
  128. key_file = tmp_path / "key"
  129. key_file.write_text("KEY")
  130. mock_proc = AsyncMock()
  131. mock_proc.communicate = AsyncMock(return_value=(b"", b"Connection refused"))
  132. mock_proc.returncode = 255
  133. with patch("asyncio.create_subprocess_exec", return_value=mock_proc):
  134. rc, stdout, stderr = await _run_ssh_command("10.0.0.1", "echo hello", key_file)
  135. assert rc == 255
  136. assert "Connection refused" in stderr
  137. @pytest.mark.asyncio
  138. async def test_run_ssh_command_timeout(tmp_path):
  139. key_file = tmp_path / "key"
  140. key_file.write_text("KEY")
  141. mock_proc = AsyncMock()
  142. mock_proc.communicate = AsyncMock(return_value=(b"", b""))
  143. mock_proc.kill = MagicMock()
  144. async def fake_wait_for(coro, timeout):
  145. # Consume the coroutine to avoid warning
  146. coro.close()
  147. raise TimeoutError
  148. with (
  149. patch("asyncio.create_subprocess_exec", return_value=mock_proc),
  150. patch("backend.app.services.spoolbuddy_ssh.asyncio.wait_for", side_effect=fake_wait_for),
  151. ):
  152. rc, stdout, stderr = await _run_ssh_command("10.0.0.1", "sleep 999", key_file, timeout=1)
  153. assert rc == -1
  154. assert "timed out" in stderr
  155. mock_proc.kill.assert_called_once()
# -- perform_ssh_update --------------------------------------------------------
  157. def _make_update_mocks(tmp_path):
  158. """Create common mocks for perform_ssh_update tests."""
  159. mock_db_device = MagicMock()
  160. mock_db_device.update_status = None
  161. mock_db_device.update_message = None
  162. mock_db_device.pending_command = None
  163. mock_result = MagicMock()
  164. mock_result.scalar_one_or_none.return_value = mock_db_device
  165. mock_session = AsyncMock()
  166. mock_session.execute = AsyncMock(return_value=mock_result)
  167. mock_session.commit = AsyncMock()
  168. mock_ctx = AsyncMock()
  169. mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
  170. mock_ctx.__aexit__ = AsyncMock(return_value=False)
  171. mock_ws = MagicMock()
  172. mock_ws.broadcast = AsyncMock()
  173. return mock_db_device, mock_ctx, mock_ws
  174. @pytest.mark.asyncio
  175. async def test_perform_ssh_update_success(tmp_path):
  176. """Full update flow: all SSH commands succeed."""
  177. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  178. ssh_dir.mkdir(parents=True)
  179. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  180. (ssh_dir / "id_ed25519.pub").write_text("PUBLIC")
  181. ssh_calls = []
  182. async def mock_ssh(ip, cmd, key, timeout=60):
  183. ssh_calls.append(cmd)
  184. return 0, "ok", ""
  185. _, mock_ctx, mock_ws = _make_update_mocks(tmp_path)
  186. with (
  187. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  188. patch("backend.app.services.spoolbuddy_ssh._run_ssh_command", side_effect=mock_ssh),
  189. patch("backend.app.services.spoolbuddy_ssh.detect_current_branch", return_value="dev"),
  190. patch("backend.app.core.database.async_session", return_value=mock_ctx),
  191. patch("backend.app.api.routes.spoolbuddy.ws_manager", mock_ws),
  192. ):
  193. mock_settings.base_dir = tmp_path
  194. await perform_ssh_update("sb-test", "10.0.0.1")
  195. # Should have run: echo ok, git fetch, git checkout+reset, pip install,
  196. # systemctl restart, find (SW cleanup), systemctl restart getty
  197. assert len(ssh_calls) == 7
  198. assert "echo ok" in ssh_calls[0]
  199. assert "fetch" in ssh_calls[1]
  200. assert "checkout" in ssh_calls[2]
  201. assert "pip" in ssh_calls[3]
  202. assert "spoolbuddy.service" in ssh_calls[4]
  203. assert "Service Worker" in ssh_calls[5]
  204. assert "getty" in ssh_calls[6]
  205. assert mock_ws.broadcast.call_count >= 4
  206. @pytest.mark.asyncio
  207. async def test_perform_ssh_update_ssh_failure(tmp_path):
  208. """SSH connectivity check fails — should set error status."""
  209. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  210. ssh_dir.mkdir(parents=True)
  211. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  212. (ssh_dir / "id_ed25519.pub").write_text("PUBLIC")
  213. async def mock_ssh(ip, cmd, key, timeout=60):
  214. if "echo ok" in cmd:
  215. return 255, "", "Connection refused"
  216. return 0, "", ""
  217. mock_device, mock_ctx, mock_ws = _make_update_mocks(tmp_path)
  218. with (
  219. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  220. patch("backend.app.services.spoolbuddy_ssh._run_ssh_command", side_effect=mock_ssh),
  221. patch("backend.app.services.spoolbuddy_ssh.detect_current_branch", return_value="main"),
  222. patch("backend.app.core.database.async_session", return_value=mock_ctx),
  223. patch("backend.app.api.routes.spoolbuddy.ws_manager", mock_ws),
  224. ):
  225. mock_settings.base_dir = tmp_path
  226. await perform_ssh_update("sb-test", "10.0.0.1")
  227. # Should broadcast error status
  228. error_broadcasts = [c for c in mock_ws.broadcast.call_args_list if c[0][0].get("update_status") == "error"]
  229. assert len(error_broadcasts) >= 1
  230. assert "SSH connection failed" in error_broadcasts[0][0][0]["update_message"]
  231. @pytest.mark.asyncio
  232. async def test_perform_ssh_update_git_fetch_failure(tmp_path):
  233. """Git fetch fails — should set error and stop."""
  234. ssh_dir = tmp_path / "spoolbuddy" / "ssh"
  235. ssh_dir.mkdir(parents=True)
  236. (ssh_dir / "id_ed25519").write_text("PRIVATE")
  237. (ssh_dir / "id_ed25519.pub").write_text("PUBLIC")
  238. ssh_calls = []
  239. async def mock_ssh(ip, cmd, key, timeout=60):
  240. ssh_calls.append(cmd)
  241. if "fetch" in cmd:
  242. return 1, "", "fatal: could not read from remote"
  243. return 0, "ok", ""
  244. _, mock_ctx, mock_ws = _make_update_mocks(tmp_path)
  245. with (
  246. patch("backend.app.services.spoolbuddy_ssh.settings") as mock_settings,
  247. patch("backend.app.services.spoolbuddy_ssh._run_ssh_command", side_effect=mock_ssh),
  248. patch("backend.app.services.spoolbuddy_ssh.detect_current_branch", return_value="main"),
  249. patch("backend.app.core.database.async_session", return_value=mock_ctx),
  250. patch("backend.app.api.routes.spoolbuddy.ws_manager", mock_ws),
  251. ):
  252. mock_settings.base_dir = tmp_path
  253. await perform_ssh_update("sb-test", "10.0.0.1")
  254. # Should stop after git fetch — no checkout, pip, restart
  255. assert len(ssh_calls) == 2 # echo ok + git fetch
  256. assert not any("checkout" in c for c in ssh_calls)