Source code for clarin.sru.fcs.server.auth

import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import jwt
from clarin.sru.constants import SRUDiagnostics
from clarin.sru.exception import SRUConfigException
from clarin.sru.exception import SRUException
from clarin.sru.server.auth import SRUAuthenticationInfo
from clarin.sru.server.auth import SRUAuthenticationInfoProvider
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from jwt.algorithms import RSAAlgorithm
from werkzeug import Request

# from typing import Self  # 3.11


# ---------------------------------------------------------------------------


LOGGER = logging.getLogger(__name__)


ALGORITHM_NAME = "RS256"


# ---------------------------------------------------------------------------


class AuthenticationInfo(SRUAuthenticationInfo):
    """Authentication details of a request that carried a JWT bearer token.

    Holds only the subject string handed to the constructor.
    """

    def __init__(self, subject: str):
        """Store ``subject`` after initialising the base class."""
        super().__init__()
        self._subject = subject

    @property
    def subject(self) -> str:
        """Return the subject string this instance was created with."""
        return self._subject

    @property
    def authentication_method(self) -> str:
        """Return the name of the authentication method; always ``"JWT"``."""
        return "JWT"
@dataclass(frozen=True)
class Key:
    """Immutable pairing of a key identifier with an RSA public key."""

    # Identifier under which the key is known.
    key_id: str

    # The RSA public key object itself.
    public_key: RSAPublicKey
@dataclass(frozen=True)
class Verifier:
    """Decode and verify JWT tokens against a single RSA public key.

    Tokens are checked with the ``RS256`` algorithm (``ALGORITHM_NAME``).
    """

    # Identifier of the key this verifier uses.
    key_id: str

    # RSA public key the token signature is checked against.
    public_key: RSAPublicKey

    # PyJWT instance that performs the decoding.
    jwt: jwt.PyJWT

    # Extra keyword arguments forwarded verbatim to ``PyJWT.decode``
    # (e.g. ``audience``).
    claims: Dict[str, Any]

    def decode(self, token: str, no_verification: bool = False) -> Dict[str, Any]:
        """Decode ``token`` and return its payload.

        Args:
            token: The encoded JSON Web Token.
            no_verification: If true, decode with signature verification
                switched off; the public key and ``claims`` are not used.

        Returns:
            The decoded token payload.

        Raises:
            jwt.PyJWTError: Whatever ``PyJWT.decode`` raises for the token.
        """
        # TODO: we could probably bypass most checks by using internals
        #       like jwt.PyJWS()._load(token)
        if no_verification:
            return self.jwt.decode(token, options=dict(verify_signature=False))

        return self.jwt.decode(
            token,
            self.public_key,  # type: ignore
            algorithms=[ALGORITHM_NAME],
            **self.claims,
        )

    def verify(self, token: str) -> bool:
        """Check whether ``token`` verifies against this verifier's key.

        Args:
            token: The encoded JSON Web Token.

        Returns:
            ``True`` if decoding with full verification succeeds. ``False``
            if the token cannot be decoded, its signature or algorithm does
            not match, or it is not yet valid.

        Raises:
            jwt.exceptions.InvalidAudienceError: Audience does not match.
            jwt.exceptions.ExpiredSignatureError: Token has expired.
            jwt.exceptions.InvalidIssuedAtError: Invalid ``iat`` claim.
            jwt.exceptions.MissingRequiredClaimError: A required claim is
                missing.
        """
        # All decoding happens inside the ``try`` so that malformed tokens
        # are reported as "not verified" instead of leaking a DecodeError.
        try:
            self.decode(token)
        except (
            # e.g. 'Invalid audience'
            jwt.exceptions.InvalidAudienceError,
            # e.g. 'Signature has expired'
            # TODO: what if verifiers with different leeways?
            jwt.exceptions.ExpiredSignatureError,
            # NOTE: test cases required
            # (InvalidIssuerError is not listed: no issuer is configured,
            #  so it is not expected to be raised.)
            jwt.exceptions.InvalidIssuedAtError,
            # e.g. 'Token is missing the "aud" claim'
            jwt.exceptions.MissingRequiredClaimError,
        ):
            # Claim problems are definitive; let the caller report them.
            raise
        except jwt.exceptions.ImmatureSignatureError:
            # TODO: fail on signature not yet 'live'?
            # e.g. 'The token is not yet valid (iat)'
            return False
        except (
            jwt.exceptions.InvalidSignatureError,  # is a DecodeError
            jwt.exceptions.DecodeError,
            jwt.exceptions.InvalidAlgorithmError,
        ):
            # Here we would optimally try other keys as we could serve
            # multiple clients.
            # e.g. 'Signature verification failed'
            # e.g. 'Invalid crypto padding' / 'Invalid header string: ...'
            return False
        return True
# ---------------------------------------------------------------------------
class AuthenticationProvider(SRUAuthenticationInfoProvider):
    """Extract and verify JWT bearer tokens from incoming requests.

    Instances are normally assembled with :class:`AuthenticationProvider.Builder`.
    """

    def __init__(self, verifiers: List[Verifier]) -> None:
        """Store the verifiers that tokens are checked against, in order."""
        super().__init__()
        self.verifiers = verifiers

    @property
    def key_count(self) -> int:
        """Number of configured verifiers (``0`` if none are set)."""
        if not self.verifiers:
            return 0
        return len(self.verifiers)

    # ----------------------------------------------------

    def get_AuthenticationInfo(
        self, request: Request
    ) -> Optional[SRUAuthenticationInfo]:
        """Build authentication info from the request's bearer token.

        Args:
            request: The incoming request; its ``Authentication`` header is
                inspected.

        Returns:
            ``None`` if the header is absent, blank, does not start with
            ``bearer`` (case-insensitive) or carries no token; otherwise the
            result of checking the token.

        Raises:
            SRUException: If the token cannot be decoded or verified.
        """
        # NOTE(review): reads "Authentication", not the standard HTTP
        # "Authorization" header -- confirm this is intended.
        value = request.headers.get("Authentication")
        if not value or value.isspace():
            return None

        # NOTE(review): this logs the complete header, including the token.
        LOGGER.debug("Found Authentication header with: '%s'", value)

        if not value.lower().startswith("bearer"):
            LOGGER.debug(
                "Authentication header with incorrect format. Expected start: 'Bearer '"
            )
            return None

        # Drop the 6-character "bearer" prefix and surrounding whitespace.
        token = value[6:].strip()
        if not token:
            LOGGER.debug("No bearer token in Authentication header?")
            return None

        return self._check_token(token)

    def _check_token(self, token_raw: str) -> AuthenticationInfo:
        """Verify ``token_raw`` and return info carrying its ``sub`` claim.

        The token is first decoded without signature verification to read
        its claims. If no verifiers are configured the unverified subject is
        returned; otherwise the first verifier that accepts the token wins.

        Raises:
            SRUException: With ``SRUDiagnostics.AUTHENTICATION_ERROR`` if the
                token cannot be decoded, has a mismatching audience, is
                expired, or is accepted by no verifier.
        """
        try:
            token = jwt.PyJWT().decode(token_raw, options=dict(verify_signature=False))
            LOGGER.debug(
                "Token: jti=%s, iss=%s, aud=%s, sub=%s, iat=%s, exp=%s, nbf=%s",
                token.get("jti"),
                token.get("iss"),
                token.get("aud"),
                token.get("sub"),
                token.get("iat"),
                token.get("exp"),
                token.get("nbf"),
            )

            # TODO: sanitize 'sub' of JWT token
            subject = str(token["sub"]) if "sub" in token else ""

            if not self.verifiers:
                LOGGER.warning("No JWT verifiers found. Return unverified 'sub' claim.")
                return AuthenticationInfo(subject)

            for verifier in self.verifiers:
                try:
                    LOGGER.debug(
                        "Trying to verify token with key '%s'", verifier.key_id
                    )
                    if verifier.verify(token_raw):
                        return AuthenticationInfo(subject)
                # TODO: figure out exceptions we want to catch or ignore
                except jwt.exceptions.InvalidAudienceError as ex:
                    # java: InvalidClaimException
                    # e.g. audience does not match
                    raise SRUException(
                        SRUDiagnostics.AUTHENTICATION_ERROR,
                        "error processing request authentication",
                        message=str(ex),
                    ) from ex
                except jwt.exceptions.ExpiredSignatureError as ex:
                    raise SRUException(
                        SRUDiagnostics.AUTHENTICATION_ERROR,
                        "error processing request authentication",
                        message="token expired",
                    ) from ex

            # No verifier accepted the token.
            raise SRUException(
                SRUDiagnostics.AUTHENTICATION_ERROR,
                "error processing request authentication",
                message="Could not verify JSON Web token signature.",
            )
        except jwt.PyJWTError as ex:
            # Covers the unverified decode above as well as any other PyJWT
            # error that a verifier lets through.
            raise SRUException(
                SRUDiagnostics.AUTHENTICATION_ERROR,
                "error processing request authentication",
                message="Could not decode JSON Web token",
            ) from ex

    # ----------------------------------------------------

    class Builder:
        """Fluent builder that collects keys and settings for the provider."""

        def __init__(self):
            # Loaded public keys, by key id.
            self.keys: Dict[str, RSAPublicKey] = {}
            # Accepted audiences (strings); a set, as `with_audience` adds to it.
            self.audiences: set = set()
            self.ignore_IssuedAt = False
            # Leeways; -1 means "not configured".
            self.leeway_IssuedAt = -1
            self.leeway_ExpiresAt = -1
            self.leeway_NotBefore = -1

        @classmethod
        def create(cls) -> "AuthenticationProvider.Builder":
            """Return a fresh builder instance."""
            return cls()

        def with_audience(self, audience: str) -> "AuthenticationProvider.Builder":
            """Add ``audience`` to the accepted audiences."""
            self.audiences.add(audience)
            return self

        def with_ignore_IssuedAt(self) -> "AuthenticationProvider.Builder":
            """Disable validation of the ``iat`` claim."""
            self.ignore_IssuedAt = True
            return self

        def with_IssuedAt(self, leeway: int) -> "AuthenticationProvider.Builder":
            """Record the issued-at leeway.

            Raises:
                ValueError: If ``leeway`` is negative.
            """
            if leeway < 0:
                raise ValueError("leeway < 0")
            self.leeway_IssuedAt = leeway
            return self

        def with_ExpiresAt(self, leeway: int) -> "AuthenticationProvider.Builder":
            """Record the expires-at leeway.

            Raises:
                ValueError: If ``leeway`` is negative.
            """
            if leeway < 0:
                raise ValueError("leeway < 0")
            self.leeway_ExpiresAt = leeway
            return self

        def with_NotBefore(self, leeway: int) -> "AuthenticationProvider.Builder":
            """Record the not-before leeway.

            Raises:
                ValueError: If ``leeway`` is negative.
            """
            if leeway < 0:
                raise ValueError("leeway < 0")
            self.leeway_NotBefore = leeway
            return self

        def with_public_key(
            self, key_id: str, key: Any
        ) -> "AuthenticationProvider.Builder":
            """Load ``key`` and register it under ``key_id``.

            Raises:
                SRUConfigException: If the key cannot be loaded.
            """
            self._load_public_key(key_id, key)
            return self

        def _load_public_key(self, key_id: str, key: Any) -> None:
            """Prepare ``key`` as an RS256 key and store it under ``key_id``.

            Raises:
                SRUConfigException: If preparing the key raises
                    ``ValueError`` or ``TypeError``.
            """
            # RSASSA-PKCS1-v1_5 with SHA-256 ("RS256")
            # see: https://www.rfc-editor.org/rfc/rfc7519#section-8
            # https://pyjwt.readthedocs.io/en/stable/faq.html#how-can-i-extract-a-public-private-key-from-a-x509-certificate
            try:
                algorithm = RSAAlgorithm(RSAAlgorithm.SHA256)
                key_obj: RSAPublicKey = algorithm.prepare_key(key)
                self.keys[key_id] = key_obj
            except (ValueError, TypeError) as ex:
                raise SRUConfigException(f"Failed to load key '{key_id}'") from ex

        def build(self) -> "AuthenticationProvider":
            """Create the provider with one verifier per registered key."""
            # Keyword arguments for ``PyJWT.decode``.
            decode_kwargs: Dict[str, Any] = {}
            if self.audiences:
                decode_kwargs["audience"] = sorted(self.audiences)
            if self.leeway_ExpiresAt > 0:
                # ``leeway`` is an argument of ``decode``, not a PyJWT option.
                decode_kwargs["leeway"] = self.leeway_ExpiresAt
            # Both IssuedAt (IAT) and NotBefore (NBF) should be the same based
            # on fcs-simple-client (auth). Unfortunately PyJWT does not have
            # separate leeways (see: com.auth0.jwt.interfaces.Verification),
            # so `leeway_IssuedAt` and `leeway_NotBefore` are not applied.
            # TODO: warn? or just use min/max/avg from all?

            # Options for the ``PyJWT`` instance itself.
            options: Dict[str, Any] = {}
            if self.ignore_IssuedAt:
                options["verify_iat"] = False

            verifiers: List[Verifier] = [
                Verifier(
                    key_id=key_id,
                    public_key=key,
                    # The constructor takes one ``options`` dict, not kwargs.
                    jwt=jwt.PyJWT(options=options),
                    # Each verifier gets its own copy of the settings.
                    claims=dict(decode_kwargs),
                )
                for key_id, key in self.keys.items()
            ]
            return AuthenticationProvider(verifiers)
# ---------------------------------------------------------------------------