diff --git a/setup.py b/setup.py index a615085a..82ff49c6 100644 --- a/setup.py +++ b/setup.py @@ -21,4 +21,10 @@ "flask", "rkvst-archivist" ], + extras_require=[ + "oidc": [ + "PyJWT", + "pycrypto", + ] + ], ) diff --git a/tests/test_cli.py b/tests/test_cli.py index f04f2cf5..80099a54 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,10 +2,16 @@ # Licensed under the MIT License. import os import threading +import requests import pytest +import jwt +import jwcrypto.jwk +from flask import Flask, jsonify +from werkzeug.wrappers import Request from werkzeug.serving import make_server from scitt_emulator import cli, server + issuer = "did:web:example.com" content_type = "application/json" payload = '{"foo": "bar"}' @@ -16,16 +22,23 @@ def execute_cli(argv): class Service: - def __init__(self, config): + def __init__(self, config, create_flask_app=None): self.config = config + self.create_flask_app = ( + create_flask_app + if create_flask_app is not None + else server.create_flask_app + ) def __enter__(self): - app = server.create_flask_app(self.config) - self.service_parameters_path = app.service_parameters_path - host = "127.0.0.1" - self.server = make_server(host, 0, app) + app = self.create_flask_app(self.config) + if hasattr(app, "service_parameters_path"): + self.service_parameters_path = app.service_parameters_path + self.host = "127.0.0.1" + self.server = make_server(self.host, 0, app) port = self.server.port - self.url = f"http://{host}:{port}" + self.url = f"http://{self.host}:{port}" + app.url = self.url self.thread = threading.Thread(name="server", target=self.server.serve_forever) self.thread.start() return self @@ -142,3 +155,158 @@ def test_client_cli(use_lro: bool, tmp_path): with open(receipt_path_2, "rb") as f: receipt_2 = f.read() assert receipt == receipt_2 + + +class OIDCAuthMiddleware: + def __init__(self, app, config): + self.app = app + self.config = config + + def __call__(self, environ, start_response): + request = Request(environ) + self.validate_token( + self.config, request.headers["Authorization"].replace("Bearer ", "") + ) + return self.app(environ, start_response) + + @staticmethod + def validate_token(config, token): + oidc_config = requests.get( + f"{config['issuer']}/.well-known/openid-configuration" + ).json() + jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + jwt.decode( + token, + key=jwks_client.get_signing_key_from_jwt(token).key, + algorithms=oidc_config["id_token_signing_alg_values_supported"], + audience=config.get("audience", None), + issuer=oidc_config["issuer"], + options={"strict_aud": config.get("strict_aud", True),}, + leeway=config.get("leeway", 0), + ) + + +def create_flask_app_oidc_server(config): + app = Flask("oidc_server") + + app.config.update(dict(DEBUG=True)) + app.config.update(config) + + @app.route("/.well-known/openid-configuration", methods=["GET"]) + def openid_configuration(): + return jsonify( + { + "issuer": app.url, + "jwks_uri": f"{app.url}/.well-known/jwks", + "response_types_supported": ["id_token"], + "claims_supported": ["sub", "aud", "exp", "iat", "iss"], + "id_token_signing_alg_values_supported": app.config["algorithms"], + "scopes_supported": ["openid"], + } + ) + + @app.route("/.well-known/jwks", methods=["GET"]) + def jwks(): + return jsonify( + { + "keys": [ + { + **app.config["key"].export_public(as_dict=True), + "use": "sig", + "kid": app.config["key"].thumbprint(), + } + ] + } + ) + + return app + + +def test_client_cli_token(tmp_path): + workspace_path = tmp_path / "workspace" + + claim_path = tmp_path / "claim.cose" + receipt_path = tmp_path / "claim.receipt.cbor" + entry_id_path = tmp_path / "claim.entry_id.txt" + retrieved_claim_path = tmp_path / "claim.retrieved.cose" + + key = jwcrypto.jwk.JWK.generate(kty="RSA", size=2048) + audience = "urn:scitt" + algorithm = "RS256" + + oidc_config = { + "key": key, + "audience": audience, + "algorithms": [algorithm], + } + + with Service( + oidc_config, create_flask_app=create_flask_app_oidc_server, + ) as oidc_service: + os.environ["no_proxy"] = ",".join( + os.environ.get("no_proxy", "").split(",") + [oidc_service.host] + ) + payload = {"iss": oidc_service.url, "aud": audience} + token = jwt.encode( + payload, + key.export_to_pem(private_key=True, password=None), + algorithm=algorithm, + headers={"kid": key.thumbprint()}, + ) + with Service( + { + "middleware": lambda app: OIDCAuthMiddleware( + app, {**oidc_config, "issuer": oidc_service.url,}, + ), + "tree_alg": "CCF", + "workspace": workspace_path, + "error_rate": 0.1, + "use_lro": False, + } + ) as service: + # create claim + command = [ + "client", + "create-claim", + "--out", + claim_path, + "--issuer", + issuer, + "--content-type", + content_type, + "--payload", + payload, + ] + execute_cli(command) + assert os.path.exists(claim_path) + + # submit claim without token + command = [ + "client", + "submit-claim", + "--claim", + claim_path, + "--out", + receipt_path, + "--out-entry-id", + entry_id_path, + "--url", + service.url, + ] + check_error = None + try: + execute_cli(command) + except Exception as error: + check_error = error + assert check_error + assert not os.path.exists(receipt_path) + assert not os.path.exists(entry_id_path) + + # submit claim with token + command += [ + "--token", + token, + ] + execute_cli(command) + assert os.path.exists(receipt_path) + assert os.path.exists(entry_id_path)