feat: Add RS256 signing and JWKS support to common/jwt library#10899
feat: Add RS256 signing and JWKS support to common/jwt library#10899
Conversation
Extend the JWT library to support RS256 asymmetric signing alongside existing HS256, and add JWKS fetching/caching utilities for distributed token validation. Foundation for OAuth2/OIDC Provider (Account Manager). Closes #10898 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
This PR extends the ai.backend.common.jwt library to support RS256 asymmetric JWT signing/validation and adds JWKS parsing + async fetching with caching to enable distributed public-key validation (a prerequisite for upcoming OAuth2/OIDC provider work).
Changes:
- Added
JWTAlgorithm(HS256,RS256) and RSA key path fields toJWTConfig. - Implemented RSA key utilities (
keys.py) and JWKS utilities (jwks.py) plus validator support for JWKS-based key lookup. - Expanded
JWTClaimswith optional OAuth2-style claims and added unit tests covering RS256 and JWKS scenarios.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
src/ai/backend/common/jwt/config.py |
Introduces JWTAlgorithm enum and RSA key path configuration. |
src/ai/backend/common/jwt/types.py |
Adds optional claims (incl. kid, OAuth2-style fields) and conditional serialization. |
src/ai/backend/common/jwt/signer.py |
Extends token generation to support RS256 with RSA private keys and optional kid. |
src/ai/backend/common/jwt/validator.py |
Extends validation to support RS256 with RSA public keys and adds JWKS-based validation. |
src/ai/backend/common/jwt/keys.py |
New RSA key generation/loading/serialization + JWK conversion utilities. |
src/ai/backend/common/jwt/jwks.py |
New JWKS parsing and async fetcher with TTL caching. |
src/ai/backend/common/jwt/exceptions.py |
Adds JWKS-related exception types. |
src/ai/backend/common/jwt/__init__.py |
Exports new public API surface and updates imports. |
tests/unit/common/jwt/test_keys.py |
Tests RSA key generation, PEM round-trip, and JWK conversion. |
tests/unit/common/jwt/test_rs256.py |
Tests RS256 sign/verify behavior and HS256 backward compatibility. |
tests/unit/common/jwt/test_jwks.py |
Tests JWKS parsing and fetcher caching behavior. |
tests/unit/common/jwt/test_jwks_validation.py |
Tests validator JWKS flow, key rotation, and OAuth2 claims round-trip. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| """ | ||
| Parse a JWKS JSON response into a JWKSKeySet. | ||
|
|
||
| Only RSA keys with ``"use": "sig"`` and a ``kid`` field are included. |
There was a problem hiding this comment.
The docstring says only RSA keys with "use": "sig" are included, but from_jwks_dict() doesn’t check the use field. Either enforce use == "sig" when present (and optionally alg == "RS256"), or update the docstring/behavior to match.
| Only RSA keys with ``"use": "sig"`` and a ``kid`` field are included. | |
| Only RSA keys with a ``kid`` field are included. |
| if kid is None: | ||
| continue | ||
| n = _base64url_to_int(jwk["n"]) | ||
| e = _base64url_to_int(jwk["e"]) | ||
| public_numbers = RSAPublicNumbers(e=e, n=n) |
There was a problem hiding this comment.
from_jwks_dict() directly indexes jwk["n"]/jwk["e"] and base64-decodes them. A single malformed JWK entry (missing fields / invalid base64) will raise and abort parsing the whole JWKS; consider skipping invalid entries (per-key try/except) so one bad key doesn’t break all validation.
| Async JWKS fetcher with TTL-based caching. | ||
|
|
||
| Fetches a JWKS endpoint and caches the result for a configurable duration. | ||
| Thread-safe for concurrent async access. | ||
|
|
There was a problem hiding this comment.
JWKSFetcher docstring claims it is “Thread-safe for concurrent async access”, but there’s no lock around _get_key_set() / _fetch_jwks(). Concurrent callers can race and trigger multiple simultaneous fetches and inconsistent cache updates; consider adding an asyncio.Lock (or remove the thread-safe claim).
| except JWKSFetchError: | ||
| raise | ||
| except Exception as e: | ||
| raise JWKSFetchError(f"Failed to fetch JWKS from {self._url}: {e}") from e | ||
|
|
||
| key_set = JWKSKeySet.from_jwks_dict(data) |
There was a problem hiding this comment.
_fetch_jwks() wraps network/JSON errors into JWKSFetchError, but parsing (JWKSKeySet.from_jwks_dict(data)) happens outside the try/except. If parsing raises (e.g., KeyError/ValueError), callers will see a raw exception instead of JWKSFetchError; wrap the parsing step too (or guarantee from_jwks_dict() never raises).
| except JWKSFetchError: | |
| raise | |
| except Exception as e: | |
| raise JWKSFetchError(f"Failed to fetch JWKS from {self._url}: {e}") from e | |
| key_set = JWKSKeySet.from_jwks_dict(data) | |
| key_set = JWKSKeySet.from_jwks_dict(data) | |
| except JWKSFetchError: | |
| raise | |
| except Exception as e: | |
| raise JWKSFetchError(f"Failed to fetch JWKS from {self._url}: {e}") from e |
| def validate_token_with_jwks( | ||
| self, | ||
| token: str, | ||
| jwks_key_set: JWKSKeySet, | ||
| ) -> JWTClaims: |
There was a problem hiding this comment.
validate_token_with_jwks() doesn’t enforce RS256 mode. If JWTConfig.algorithm is left as the default HS256, this method will attempt to validate using an RSA public key and produce confusing errors. Consider checking self._config.algorithm == JWTAlgorithm.RS256 up front and raising a clear JWTDecodeError when misconfigured.
| "access_key": str(self.access_key), | ||
| "role": self.role, | ||
| } | ||
| if self.kid is not None: | ||
| result["kid"] = self.kid |
There was a problem hiding this comment.
kid is a standard JWT header parameter, not a registered claim. Serializing it into the payload here means (1) tokens from external issuers that only set header kid will yield JWTClaims.kid=None, and (2) tokens from this library will carry a redundant/nonstandard payload field. Consider keeping kid exclusively in the header, or explicitly populating JWTClaims.kid from the decoded header during validation (especially in the JWKS flow).
| iat=now, | ||
| access_key=user_context.access_key, | ||
| role=user_context.role, | ||
| kid=kid, |
There was a problem hiding this comment.
generate_token() always injects kid into the JWT payload via JWTClaims(kid=kid) even though kid is used as a header parameter for RS256/JWKS. This duplicates the header for RS256 and also makes kid appear in HS256 tokens if a caller passes it; consider restricting kid to RS256 header-only to keep semantics consistent.
| kid=kid, |
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
HyeockJinKim
left a comment
There was a problem hiding this comment.
I don't really see a reason to keep the existing schema—wouldn't it make more sense to create a separate one for RS256?
Summary
Extend the
common/jwtlibrary to support RS256 asymmetric signing alongside existing HS256, and add JWKS fetching/caching utilities for distributed token validation. This is the foundational step for the OAuth2/OIDC Provider work in Account Manager.Closes #10898 | Parent epic: #10500 (Phase 1, Step 1)
Changes
Modified files
common/jwt/config.py— AddedJWTAlgorithmenum (HS256,RS256), optionalprivate_key_pathandpublic_key_pathfields toJWTConfig. Fully backward compatible.common/jwt/types.py— Added optional OAuth2 claims toJWTClaims:kid,iss,aud,sub,scope,jti. Only serialized when set.common/jwt/signer.py— Extendedgenerate_token()to acceptprivate_key(RSAPrivateKey) +kidfor RS256. HS256 viasecret_keyunchanged.common/jwt/validator.py— Extendedvalidate_token()to acceptpublic_key(RSAPublicKey) for RS256. Addedvalidate_token_with_jwks()for JWKS-based key lookup bykid.common/jwt/exceptions.py— AddedJWKSError,JWKSFetchError,JWKSKeyNotFoundError.common/jwt/__init__.py— Exports all new public symbols; converted to absolute imports.New files
common/jwt/keys.py— RSA key utilities:generate_rsa_key_pair(),load_private_key(),load_public_key(), PEM serialization,public_key_to_jwk()for JWK format conversion.common/jwt/jwks.py—JWKSKeySet(parse JWKS JSON, index by kid) andJWKSFetcher(async HTTP fetcher with TTL-based caching viaaiohttp).New tests (4 files)
test_keys.py— RSA key generation, PEM round-trip, JWK conversiontest_rs256.py— RS256 sign/verify round-trip, kid header, wrong key rejection, expired token, HS256 backward compattest_jwks.py— JWKS parsing, key lookup by kid, non-RSA key filtering, fetcher caching (mocked HTTP)test_jwks_validation.py— End-to-end JWKS-based validation, key rotation scenarios, OAuth2 claims round-tripBackward Compatibility
JWTConfigdefaults to HS256,generate_token(ctx, secret_key)andvalidate_token(token, secret_key)signatures still worktest_signer.py,test_validator.py,test_types.py) pass unchangedTest plan
pants test tests/unit/common/jwt/)pants fmt/fix/lint/checkpass with no warnings🤖 Generated with Claude Code