test_get_db_cancel_safety.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
"""Tests for `get_db` cancel-safety (#1112).

Starlette's BaseHTTPMiddleware cancels the inner task scope when a
client disconnects mid-request. Pre-fix `get_db` only caught `Exception`
(not `BaseException`), so `CancelledError` skipped the rollback path —
the SQLite write lock stayed held until the connection was eventually
GC'd, producing the "database is locked" cascade in @Carter3DP's
support package on #1112.

The fix:

1. Catch `BaseException` so `CancelledError` triggers rollback.
2. `asyncio.shield` rollback + close so the cleanup completes even
   when the await is cancelled by the same cancel scope.
"""
  13. from __future__ import annotations
  14. import asyncio
  15. from unittest.mock import AsyncMock, patch
  16. import pytest
  17. from backend.app.core import database
  18. class _FakeSession:
  19. """Minimal async-context-manager stand-in for `AsyncSession`.
  20. Records which lifecycle methods were invoked so tests can assert on
  21. the cleanup order without a real engine / DB file.
  22. """
  23. def __init__(self):
  24. self.commit = AsyncMock(name="commit")
  25. self.rollback = AsyncMock(name="rollback")
  26. self.close = AsyncMock(name="close")
  27. async def __aenter__(self):
  28. return self
  29. async def __aexit__(self, exc_type, exc, tb):
  30. return False # don't suppress
  31. @pytest.fixture
  32. def fake_session_factory(monkeypatch):
  33. """Patch `database.async_session` to yield a fresh `_FakeSession`."""
  34. session = _FakeSession()
  35. monkeypatch.setattr(database, "async_session", lambda: session)
  36. return session
  37. async def _consume_get_db(action):
  38. """Drive `get_db` like FastAPI's dependency machinery does:
  39. enter the async generator, run `action(session)`, then advance to
  40. completion. Returns the entered session."""
  41. gen = database.get_db()
  42. session = await gen.__anext__()
  43. try:
  44. await action(session)
  45. except StopAsyncIteration:
  46. return session
  47. # Advance to the end so the generator's finally runs.
  48. try:
  49. await gen.__anext__()
  50. except StopAsyncIteration:
  51. pass
  52. return session
  53. class TestCancelSafety:
  54. """Pin the cancel-safety contract end-to-end."""
  55. @pytest.mark.asyncio
  56. async def test_commit_on_clean_exit(self, fake_session_factory):
  57. session = fake_session_factory
  58. async def noop(_s):
  59. pass
  60. await _consume_get_db(noop)
  61. session.commit.assert_awaited_once()
  62. session.rollback.assert_not_awaited()
  63. session.close.assert_awaited_once()
  64. @pytest.mark.asyncio
  65. async def test_rollback_on_regular_exception(self, fake_session_factory):
  66. session = fake_session_factory
  67. gen = database.get_db()
  68. await gen.__anext__()
  69. with pytest.raises(ValueError):
  70. await gen.athrow(ValueError("route handler bug"))
  71. session.commit.assert_not_awaited()
  72. session.rollback.assert_awaited_once()
  73. session.close.assert_awaited_once()
  74. @pytest.mark.asyncio
  75. async def test_rollback_on_cancelled_error(self, fake_session_factory):
  76. """The actual #1112 fix: CancelledError must NOT skip the rollback.
  77. Pre-fix `except Exception` caught nothing because CancelledError
  78. is a BaseException, not an Exception."""
  79. session = fake_session_factory
  80. gen = database.get_db()
  81. await gen.__anext__()
  82. with pytest.raises(asyncio.CancelledError):
  83. await gen.athrow(asyncio.CancelledError("client disconnected"))
  84. session.commit.assert_not_awaited()
  85. session.rollback.assert_awaited_once()
  86. session.close.assert_awaited_once()
  87. @pytest.mark.asyncio
  88. async def test_close_runs_even_if_rollback_raises(self, fake_session_factory):
  89. """A failing rollback (broken connection during cancellation) must
  90. not prevent `close` from running — otherwise the pool would never
  91. reclaim the connection."""
  92. session = fake_session_factory
  93. session.rollback.side_effect = OSError("broken pipe during rollback")
  94. gen = database.get_db()
  95. await gen.__anext__()
  96. with pytest.raises(asyncio.CancelledError):
  97. await gen.athrow(asyncio.CancelledError())
  98. session.rollback.assert_awaited_once()
  99. session.close.assert_awaited_once()
  100. @pytest.mark.asyncio
  101. async def test_close_failure_does_not_propagate(self, fake_session_factory):
  102. """A failing close on the clean-exit path must not raise out of
  103. `get_db` — the request already succeeded."""
  104. session = fake_session_factory
  105. session.close.side_effect = OSError("close failed")
  106. async def noop(_s):
  107. pass
  108. # Must not raise.
  109. await _consume_get_db(noop)
  110. session.commit.assert_awaited_once()
  111. session.close.assert_awaited_once()
  112. @pytest.mark.asyncio
  113. async def test_rollback_uses_shield(self, fake_session_factory):
  114. """Cancellation arriving DURING rollback must not abort the
  115. rollback — `asyncio.shield` keeps it running. Verify the call
  116. path goes through `shield` so future refactors don't silently
  117. drop the protection."""
  118. # The fixture wires the fake session into `database.async_session`;
  119. # we don't need the local handle here.
  120. with patch.object(asyncio, "shield", wraps=asyncio.shield) as shield:
  121. gen = database.get_db()
  122. await gen.__anext__()
  123. with pytest.raises(asyncio.CancelledError):
  124. await gen.athrow(asyncio.CancelledError())
  125. # rollback + close both shielded.
  126. assert shield.call_count == 2