diff --git a/fastapi_jwt/jwt.py b/fastapi_jwt/jwt.py index 4b86a1f..1946a05 100644 --- a/fastapi_jwt/jwt.py +++ b/fastapi_jwt/jwt.py @@ -1,6 +1,7 @@ from abc import ABC from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Set, Type +from os import getenv +from typing import Any, Dict, Literal, Optional, Set, Type from uuid import uuid4 from fastapi.exceptions import HTTPException @@ -20,6 +21,9 @@ else: # pragma: nocover raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'") +ACCESS_COOKIE_NAME = getenv("JWT_ACCESS_COOKIE_NAME", "access_token_cookie") +REFRESH_COOKIE_NAME = getenv("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie") + def force_jwt_backend(cls: Type[AbstractJWTBackend]) -> None: global DEFAULT_JWT_BACKEND @@ -61,11 +65,11 @@ def __getitem__(self, item: str) -> Any: class JwtAuthBase(ABC): class JwtAccessCookie(APIKeyCookie): def __init__(self, *args: Any, **kwargs: Any): - APIKeyCookie.__init__(self, *args, name="access_token_cookie", auto_error=False, **kwargs) + APIKeyCookie.__init__(self, *args, name=ACCESS_COOKIE_NAME, auto_error=False, **kwargs) class JwtRefreshCookie(APIKeyCookie): def __init__(self, *args: Any, **kwargs: Any): - APIKeyCookie.__init__(self, *args, name="refresh_token_cookie", auto_error=False, **kwargs) + APIKeyCookie.__init__(self, *args, name=REFRESH_COOKIE_NAME, auto_error=False, **kwargs) class JwtAccessBearer(HTTPBearer): def __init__(self, *args: Any, **kwargs: Any): @@ -181,12 +185,19 @@ def create_refresh_token( return self.jwt_backend.encode(to_encode, self.secret_key) @staticmethod - def set_access_cookie(response: Response, access_token: str, expires_delta: Optional[timedelta] = None) -> None: + def set_access_cookie( + response: Response, + access_token: str, + samesite: Optional[Literal["lax", "strict", "none"]] = "lax", + httponly: bool = False, + expires_delta: Optional[timedelta] = None, + ) -> None: seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None response.set_cookie( - key="access_token_cookie", + key=ACCESS_COOKIE_NAME, value=access_token, - httponly=False, + httponly=httponly, + samesite=samesite, max_age=seconds_expires, ) @@ -194,23 +205,26 @@ def set_access_cookie(response: Response, access_token: str, expires_delta: Opti def set_refresh_cookie( response: Response, refresh_token: str, + samesite: Optional[Literal["lax", "strict", "none"]] = "lax", + httponly: bool = True, expires_delta: Optional[timedelta] = None, ) -> None: seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None response.set_cookie( - key="refresh_token_cookie", + key=REFRESH_COOKIE_NAME, value=refresh_token, - httponly=True, + httponly=httponly, + samesite=samesite, max_age=seconds_expires, ) @staticmethod def unset_access_cookie(response: Response) -> None: - response.set_cookie(key="access_token_cookie", value="", httponly=False, max_age=-1) + response.set_cookie(key=ACCESS_COOKIE_NAME, value="", httponly=False, max_age=-1) @staticmethod def unset_refresh_cookie(response: Response) -> None: - response.set_cookie(key="refresh_token_cookie", value="", httponly=True, max_age=-1) + response.set_cookie(key=REFRESH_COOKIE_NAME, value="", httponly=True, max_age=-1) class JwtAccess(JwtAuthBase):