Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Http exception handler in the app for token expiration handling #6625

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/api/comments/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def post(
message=message,
user_id=user.id,
project_id=project_id,
timestamp=datetime.now(),
# timestamp=datetime.now(),
timestamp=datetime.utcnow(),
username=user.username,
)
try:
Expand Down
32 changes: 25 additions & 7 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
import sys
from fastapi import FastAPI
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from loguru import logger as log
from starlette.middleware.authentication import AuthenticationMiddleware
from pyinstrument import Profiler
from backend.routes import add_api_end_points
from backend.services.users.authentication_service import TokenAuthBackend
from starlette.middleware.authentication import AuthenticationMiddleware

from backend.config import settings
from backend.db import db_connection
from contextlib import asynccontextmanager
from backend.routes import add_api_end_points
from backend.services.users.authentication_service import TokenAuthBackend


def get_application() -> FastAPI:
Expand Down Expand Up @@ -69,8 +71,24 @@ async def pyinstrument_middleware(request, call_next):
AuthenticationMiddleware, backend=TokenAuthBackend(), on_error=None
)

add_api_end_points(_app)
# Custom exception handler for 401 errors
@_app.exception_handler(HTTPException)
async def custom_http_exception_handler(request: Request, exc: HTTPException):
if exc.status_code == 401 and "InvalidToken" in exc.detail.get("SubCode", ""):
return JSONResponse(
content={
"Error": exc.detail["Error"],
"SubCode": exc.detail["SubCode"],
},
status_code=exc.status_code,
headers={"WWW-Authenticate": "Bearer"},
)
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail},
)

add_api_end_points(_app)
return _app


Expand Down
9 changes: 4 additions & 5 deletions backend/services/messaging/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import threading

# from flask import current_app
from databases import Database

from backend.exceptions import NotFound
from backend.models.dtos.message_dto import ChatMessageDTO, ProjectChatDTO
from backend.models.postgis.project import ProjectStatus
from backend.models.postgis.project_chat import ProjectChat
from backend.models.postgis.project_info import ProjectInfo
from backend.models.postgis.statuses import TeamRoles
from backend.services.messaging.message_service import MessageService
from backend.services.project_service import ProjectService
from backend.services.project_admin_service import ProjectAdminService
from backend.services.project_service import ProjectService
from backend.services.team_service import TeamService
from backend.models.postgis.statuses import TeamRoles
from backend.models.postgis.project import ProjectStatus
from databases import Database


class ChatService:
Expand Down
52 changes: 32 additions & 20 deletions backend/services/users/authentication_service.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import base64
import binascii
import urllib.parse
from backend.models.postgis.user import User
from backend.models.dtos.user_dto import AuthUserDTO
from random import SystemRandom
from typing import Optional

from databases import Database
from fastapi import HTTPException, Security, status
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from loguru import logger
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
SimpleUser,
)
from loguru import logger
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired

from backend.api.utils import TMAPIDecorators
from backend.services.messaging.message_service import MessageService
from backend.services.users.user_service import UserService, NotFound
from random import SystemRandom
from backend.config import settings
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from databases import Database
from typing import Optional
from backend.models.dtos.user_dto import AuthUserDTO
from backend.models.postgis.user import User
from backend.services.messaging.message_service import MessageService
from backend.services.users.user_service import NotFound, UserService

# token_auth = HTTPTokenAuth(scheme="Token")
tm = TMAPIDecorators()
Expand All @@ -33,7 +35,10 @@
# @token_auth.error_handler
def handle_unauthorized_token():
logger.debug("Token not valid")
return {"Error": "Token is expired or invalid", "SubCode": "InvalidToken"}, 401
return JSONResponse(
content={"Error": "Token is expired or invalid", "SubCode": "InvalidToken"},
status_code=401,
)


# @token_auth.verify_token
Expand All @@ -49,7 +54,7 @@ def verify_token(token):
logger.debug("Unable to decode token")
return False # Can't decode token, so fail login

valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800)
valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 120)
if not valid_token:
logger.debug("Token not valid")
return False
Expand Down Expand Up @@ -82,9 +87,14 @@ async def authenticate(self, conn):
decoded_token, 604800
)
if not valid_token:
logger.debug("Token not valid")
return AuthCredentials([]), None

raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"Error": "Token is expired or invalid",
"SubCode": "InvalidToken",
},
headers={"WWW-Authenticate": "Bearer"},
)
tm.authenticated_user_id = (
user_id # Set the user ID on the decorator as a convenience
)
Expand Down Expand Up @@ -245,8 +255,11 @@ async def login_required(
valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800)
if not valid_token:
logger.debug("Token not valid")
raise HTTPException(status_code=401, detail="Token not valid")

raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"Error": "Token is expired or invalid", "SubCode": "InvalidToken"},
headers={"WWW-Authenticate": "Bearer"},
)
return AuthUserDTO(id=user_id)


Expand All @@ -271,6 +284,5 @@ async def login_required_optional(
valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800)
if not valid_token:
logger.debug("Token not valid")
raise HTTPException(status_code=401, detail="Token not valid")

return None
return AuthUserDTO(id=user_id)
Loading