Просмотр исходного кода

fix(oidc): normalise trailing slash on both sides of issuer comparison (#995)

PyJWT compares the iss claim against discovery_issuer with an exact string
match. Authentik (and similar providers) include a trailing slash in the JWT
iss claim while the discovery document issuer may omit it, or vice-versa.

Disable PyJWT built-in issuer validation and compare both sides after
rstrip('/') to make the check slash-agnostic.

Adds a regression test that verifies a login succeeds when the provider is
configured without a trailing slash but the JWT iss claim carries one.
Sn0rrii 1 месяц назад
Родитель
Сommit
071570f754
2 измененных файлов с 127 добавлено и 6 удалено
  1. 9 6
      backend/app/api/routes/mfa.py
  2. 118 0
      backend/tests/integration/test_mfa_api.py

+ 9 - 6
backend/app/api/routes/mfa.py

@@ -1329,8 +1329,8 @@ async def oidc_callback(
         # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
         # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
         # Use the issuer from the discovery document as the canonical value (OIDC Core
         # Use the issuer from the discovery document as the canonical value (OIDC Core
         # §3.1.3.7 requires iss == discovery issuer exactly).  We strip trailing slashes
         # §3.1.3.7 requires iss == discovery issuer exactly).  We strip trailing slashes
-        # from both sides because some providers (e.g. older PocketID versions) are
-        # inconsistent between the discovery issuer and the JWT iss claim.
+        # from both sides because some providers (e.g. Authentik, older PocketID versions)
+        # are inconsistent between the discovery issuer and the JWT iss claim.
         discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
         discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
         try:
         try:
             async with httpx.AsyncClient(timeout=10) as jwks_http:
             async with httpx.AsyncClient(timeout=10) as jwks_http:
@@ -1342,16 +1342,19 @@ async def oidc_callback(
             jwks_client.fetch_data = lambda: jwks_data  # type: ignore[method-assign]
             jwks_client.fetch_data = lambda: jwks_data  # type: ignore[method-assign]
             signing_key = jwks_client.get_signing_key_from_jwt(id_token)
             signing_key = jwks_client.get_signing_key_from_jwt(id_token)
 
 
-            # M-3: Use PyJWT native issuer validation (issuer= parameter) instead of
-            # decoding with verify_iss=False and checking manually.  PyJWT will raise
-            # InvalidIssuerError when iss != discovery_issuer, which is caught below.
+            # M-3: Decode without built-in issuer check, then compare normalised
+            # (both sides rstrip("/")) to handle providers like Authentik that include
+            # a trailing slash in iss but not in the discovery issuer, or vice-versa.
             claims = jwt.decode(
             claims = jwt.decode(
                 id_token,
                 id_token,
                 signing_key.key,
                 signing_key.key,
                 algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
                 algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
                 audience=provider.client_id,
                 audience=provider.client_id,
-                issuer=discovery_issuer,
+                options={"verify_iss": False},
             )
             )
+            token_iss = claims.get("iss", "").rstrip("/")
+            if token_iss != discovery_issuer:
+                raise jwt.exceptions.InvalidIssuerError("Invalid issuer")
         except Exception as exc:
         except Exception as exc:
             logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
             logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
             return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)
             return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)

+ 118 - 0
backend/tests/integration/test_mfa_api.py

@@ -3079,3 +3079,121 @@ class TestOIDCIssuerUrlTrailingSlash:
         assert called_url.endswith("/.well-known/openid-configuration"), (
         assert called_url.endswith("/.well-known/openid-configuration"), (
             f"Expected discovery URL to end with /.well-known/openid-configuration, got: {called_url}"
             f"Expected discovery URL to end with /.well-known/openid-configuration, got: {called_url}"
         )
         )
+
+    @pytest.mark.asyncio
+    @pytest.mark.integration
+    async def test_iss_claim_trailing_slash_accepted(
+        self, async_client: AsyncClient, db_session: AsyncSession
+    ):
+        """Provider configured without trailing slash, Authentik JWT iss has trailing slash.
+
+        Both sides must be normalised before comparison so the login succeeds.
+        """
+        import time
+        from unittest.mock import patch
+
+        import jwt as pyjwt
+
+        private_pem, jwks_data = _make_test_rsa_key()
+        issuer_no_slash = "https://authentik.example.com/application/o/bambuddy"
+        issuer_with_slash = issuer_no_slash + "/"
+        client_id = "bambuddy-client"
+        nonce = secrets.token_urlsafe(16)
+
+        now = int(time.time())
+        id_token = pyjwt.encode(
+            {
+                "sub": "authentik-sub-123",
+                "iss": issuer_with_slash,
+                "aud": client_id,
+                "nonce": nonce,
+                "email": "authentik-user@example.com",
+                "email_verified": True,
+                "iat": now,
+                "exp": now + 300,
+            },
+            private_pem,
+            algorithm="RS256",
+            headers={"kid": "test-kid-1"},
+        )
+
+        admin_token = await _setup_and_login(async_client, "authentikadm", "authentikadm1")
+        create_resp = await async_client.post(
+            "/api/v1/auth/oidc/providers",
+            json={
+                "name": "Authentik-ISS",
+                "issuer_url": issuer_no_slash,
+                "client_id": client_id,
+                "client_secret": "secret",
+                "scopes": "openid email profile",
+                "is_enabled": True,
+                "auto_create_users": True,
+            },
+            headers=_auth_header(admin_token),
+        )
+        assert create_resp.status_code == 201
+        provider_id = create_resp.json()["id"]
+
+        state = secrets.token_urlsafe(32)
+        code_verifier = secrets.token_urlsafe(48)
+        db_session.add(
+            AuthEphemeralToken(
+                token=state,
+                token_type="oidc_state",
+                provider_id=provider_id,
+                nonce=nonce,
+                code_verifier=code_verifier,
+                expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
+            )
+        )
+        await db_session.commit()
+
+        discovery_doc = {
+            "issuer": issuer_with_slash,
+            "authorization_endpoint": f"{issuer_no_slash}/authorize",
+            "token_endpoint": f"{issuer_no_slash}/token",
+            "jwks_uri": f"{issuer_no_slash}/.well-known/jwks.json",
+        }
+        token_response = {"access_token": "mock", "token_type": "Bearer", "id_token": id_token}
+
+        class _MockResp:
+            def __init__(self, data):
+                self._data = data
+                self.is_success = True
+                self.status_code = 200
+                self.text = str(data)
+
+            def json(self):
+                return self._data
+
+            def raise_for_status(self):
+                pass
+
+        class _MockHttpxClient:
+            def __init__(self, *a, **kw):
+                pass
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *a):
+                pass
+
+            async def get(self, url, **kw):
+                return _MockResp(jwks_data if "jwks" in url else discovery_doc)
+
+            async def post(self, url, **kw):
+                return _MockResp(token_response)
+
+        with patch("backend.app.api.routes.mfa.httpx.AsyncClient", _MockHttpxClient):
+            resp = await async_client.get(
+                f"/api/v1/auth/oidc/callback?code=auth-code&state={state}",
+                follow_redirects=False,
+            )
+
+        location = resp.headers.get("location", "")
+        assert resp.status_code == 302, f"Expected redirect, got {resp.status_code}"
+        assert "token_validation_failed" not in location, (
+            "Trailing slash mismatch in iss claim must not cause token_validation_failed"
+        )
+        assert "oidc_token=" in location, f"Expected oidc_token in redirect, got: {location}"