From 051ea341b5573fe3edcd53042f347929b92c2b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Padilla?= Date: Thu, 12 Mar 2026 12:46:08 -0400 Subject: [PATCH] Merge commit from fork MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Padilla CVE: CVE-2026-32597 Upstream-Status: Backport [https://github.com/jpadilla/pyjwt/commit/051ea341b5573fe3edcd53042f347929b92c2b92] Signed-off-by: Hitendra Prajapati --- CHANGELOG.rst | 2 + jwt/api_jws.py | 27 +++++++++++++- tests/test_api_jws.py | 87 +++++++++++++++++++++++++++++++++++++++++++ tests/test_api_jwt.py | 18 +++++++++ 4 files changed, 132 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8bc2319..289f45b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,6 +26,8 @@ Changed Fixed ~~~~~ +- Validate the crit (Critical) Header Parameter defined in RFC 7515 §4.1.11. by @dmbs335 in + `GHSA-752w-5fwx-jx9f `__ Added ~~~~~ diff --git a/jwt/api_jws.py b/jwt/api_jws.py index fa6708c..1750442 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -129,7 +129,7 @@ class PyJWS: header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_} if headers: - self._validate_headers(headers) + self._validate_headers(headers, encoding=True) header.update(headers) if not header["typ"]: @@ -197,6 +197,8 @@ class PyJWS: payload, signing_input, header, signature = self._load(jwt) + self._validate_headers(header) + if header.get("b64", True) is False: if detached_payload is None: raise DecodeError( @@ -309,14 +311,35 @@ class PyJWS: if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") - def _validate_headers(self, headers: dict[str, Any]) -> None: + # Extensions that PyJWT actually understands and supports + _supported_crit: set[str] = {"b64"} + + def _validate_headers( + self, headers: dict[str, Any], *, encoding: bool = False + ) -> None: if "kid" in headers: self._validate_kid(headers["kid"]) + if not encoding and "crit" in headers: + self._validate_crit(headers) def _validate_kid(self, kid: Any) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string") + def _validate_crit(self, headers: dict[str, Any]) -> None: + crit = headers["crit"] + if not isinstance(crit, list) or len(crit) == 0: + raise InvalidTokenError("Invalid 'crit' header: must be a non-empty list") + for ext in crit: + if not isinstance(ext, str): + raise InvalidTokenError("Invalid 'crit' header: values must be strings") + if ext not in self._supported_crit: + raise InvalidTokenError(f"Unsupported critical extension: {ext}") + if ext not in headers: + raise InvalidTokenError( + f"Critical extension '{ext}' is missing from headers" + ) + _jws_global_obj = PyJWS() encode = _jws_global_obj.encode diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 3385716..434874b 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -815,3 +815,90 @@ class TestJWS: ) assert len(record) == 1 assert "foo" in str(record[0].message) + + def test_decode_rejects_unknown_crit_extension( + self, jws: PyJWS, payload: bytes + ) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"}, + ) + with pytest.raises(InvalidTokenError, match="Unsupported critical extension"): + jws.decode(token, secret, algorithms=["HS256"]) + def test_decode_rejects_empty_crit(self, jws: PyJWS, payload: bytes) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": []}, + ) + with pytest.raises(InvalidTokenError, match="must be a non-empty list"): + jws.decode(token, secret, algorithms=["HS256"]) + def test_decode_rejects_non_list_crit(self, jws: PyJWS, payload: bytes) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": "b64"}, + ) + with pytest.raises(InvalidTokenError, match="must be a non-empty list"): + jws.decode(token, secret, algorithms=["HS256"]) + def test_decode_rejects_crit_with_non_string_values( + self, jws: PyJWS, payload: bytes + ) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": [123]}, + ) + with pytest.raises(InvalidTokenError, match="values must be strings"): + jws.decode(token, secret, algorithms=["HS256"]) + def test_decode_rejects_crit_extension_missing_from_header( + self, jws: PyJWS, payload: bytes + ) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": ["b64"]}, + ) + with pytest.raises(InvalidTokenError, match="missing from headers"): + jws.decode(token, secret, algorithms=["HS256"]) + def test_decode_accepts_supported_crit_extension( + self, jws: PyJWS, payload: bytes + ) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": ["b64"], "b64": False}, + is_payload_detached=True, + ) + decoded = jws.decode( + token, + secret, + algorithms=["HS256"], + detached_payload=payload, + ) + assert decoded == payload + def test_get_unverified_header_rejects_unknown_crit( + self, jws: PyJWS, payload: bytes + ) -> None: + secret = "secret" + token = jws.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": ["x-unknown"], "x-unknown": "value"}, + ) + with pytest.raises(InvalidTokenError, match="Unsupported critical extension"): + jws.get_unverified_header(token) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 82b9299..618b60f 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -802,3 +802,21 @@ class TestJWT: options={"strict_aud": True}, algorithms=["HS256"], ) + + # -------------------- Crit Header Tests -------------------- + + def test_decode_rejects_token_with_unknown_crit_extension(self, jwt: PyJWT) -> None: + """RFC 7515 §4.1.11: tokens with unsupported critical extensions MUST be rejected.""" + from jwt.exceptions import InvalidTokenError + + secret = "secret" + payload = {"sub": "attacker", "role": "admin"} + token = jwt.encode( + payload, + secret, + algorithm="HS256", + headers={"crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"}, + ) + + with pytest.raises(InvalidTokenError, match="Unsupported critical extension"): + jwt.decode(token, secret, algorithms=["HS256"]) -- 2.50.1