| import json |
| import base64 |
| from cryptography.hazmat.primitives import serialization |
| from cryptography.hazmat.primitives import hashes |
| from cryptography.hazmat.primitives.asymmetric import rsa, padding |
| |
def decode_jwt(token, key=None):
    """Decode a JWT and verify its RS256 signature.

    If ``key`` is provided (the refresh case), it is used for signature
    verification. Otherwise (the registration case), the key sent within
    the JWT payload under ``'key'`` is used instead.

    Args:
        token: The compact JWT string, ``header.payload.signature``.
        key: Optional RSA JWK, as a dict or a JSON string.

    Returns:
        A tuple ``(decoded_header, decoded_payload, verify_succeeded)``.
        On any failure (malformed token, empty header or payload,
        unusable key, or bad signature) this is ``(None, None, False)``.
    """
    try:
        # A well-formed token has exactly three dot-separated segments;
        # anything else fails to unpack and is handled below.
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)

        # If decoding produced nothing usable, return nothing.
        if not decoded_header or not decoded_payload:
            return None, None, False

        # No explicit key: fall back to the one embedded in the payload.
        # Identity comparison is the correct test for the None singleton.
        if key is None:
            key = decoded_payload.get('key')
        public_key = serialization.load_pem_public_key(jwk_to_pem(key))
        # Verifying the signature raises an exception if it fails.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        # Deliberate catch-all: every parsing, key-loading or verification
        # error is reported to the caller as "verification failed".
        return None, None, False
| |
def jwk_to_pem(jwk_data):
    """Convert an RSA JWK into a PEM-encoded public key.

    ``jwk_data`` may be a JSON string (parsed first) or an already-parsed
    mapping. Only ``"kty": "RSA"`` is accepted; anything else raises
    ``ValueError``. The result is the SubjectPublicKeyInfo PEM
    serialization produced by ``public_bytes``.
    """
    if isinstance(jwk_data, str):
        parsed = json.loads(jwk_data)
    else:
        parsed = jwk_data

    kty = parsed.get("kty")
    if kty != "RSA":
        raise ValueError(f"Unsupported key type: {kty}")

    # "n" (modulus) and "e" (exponent) are base64url-encoded big-endian
    # integers; the modulus is decoded first, as in the JWK member order.
    modulus = int.from_bytes(decode_base64url(parsed["n"]), 'big')
    exponent = int.from_bytes(decode_base64url(parsed["e"]), 'big')

    rsa_numbers = rsa.RSAPublicNumbers(exponent, modulus)
    return rsa_numbers.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
| |
def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Verify a JWT signature using PKCS#1 v1.5 padding and SHA-256.

    The signed message is ``"<encoded_header>.<encoded_payload>"`` encoded
    as UTF-8, and ``signature`` is the base64-encoded signature segment.
    Returns nothing; ``public_key.verify`` raises if verification fails.
    """
    signing_input = '{}.{}'.format(encoded_header, encoded_payload).encode('utf-8')
    raw_signature = decode_base64(signature)
    # This will throw an exception if verification fails.
    public_key.verify(raw_signature, signing_input, padding.PKCS1v15(), hashes.SHA256())
| |
def add_base64_padding(encoded_data):
    """Return ``encoded_data`` with '=' appended up to a multiple of four.

    Input whose length is already a multiple of four is returned unchanged.
    """
    # -len % 4 is the number of characters missing from the last 4-char group.
    missing = -len(encoded_data) % 4
    if not missing:
        return encoded_data
    return encoded_data + '=' * missing
| |
def decode_base64url(encoded_data):
    """Decode a base64url string, with or without '=' padding, to bytes.

    The input is padded, its URL-safe characters ('-' and '_') are mapped
    to the standard alphabet ('+' and '/'), and the result is decoded with
    the standard base64 decoder.
    """
    padded = add_base64_padding(encoded_data)
    standard_alphabet = padded.translate(str.maketrans("-_", "+/"))
    return base64.b64decode(standard_alphabet)
| |
def decode_base64(encoded_data):
    """Pad ``encoded_data`` and decode it to bytes.

    Despite the name, decoding uses the URL-safe base64 alphabet
    (``base64.urlsafe_b64decode``).
    """
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))
| |
def decode_base64_json(encoded_data):
    """Decode a base64-encoded segment and parse the bytes as JSON."""
    raw_json = decode_base64(encoded_data)
    return json.loads(raw_json)