From 14e43b48d7b7c3cb6b8d643df3f543c4c4fb8b4c Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Wed, 2 Aug 2023 10:37:23 -0400 Subject: [PATCH] Add error handling --- sqld/src/error.rs | 51 ++++++++++++++- sqld/src/http/hrana_over_http_1.rs | 20 +++--- sqld/src/http/mod.rs | 102 +++++++++-------------------- 3 files changed, 92 insertions(+), 81 deletions(-) diff --git a/sqld/src/error.rs b/sqld/src/error.rs index eab89f52..a39c1202 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -1,4 +1,7 @@ -use crate::query_result_builder::QueryResultBuilderError; +use axum::response::IntoResponse; +use hyper::StatusCode; + +use crate::{auth::AuthError, query_result_builder::QueryResultBuilderError}; #[allow(clippy::enum_variant_names)] #[derive(Debug, thiserror::Error)] @@ -39,6 +42,52 @@ pub enum Error { Json(#[from] serde_json::Error), #[error("Too many concurrent requests")] TooManyRequests, + #[error("Failed to parse query: `{0}`")] + FailedToParse(String), + #[error("Query error: `{0}`")] + QueryError(String), + #[error("Unauthorized: `{0}`")] + AuthError(#[from] AuthError), + // Catch-all error since we use anyhow in certain places + #[error("Internal Error: `{0}`")] + Anyhow(#[from] anyhow::Error), +} + +impl Error { + fn format_err(&self, status: StatusCode) -> axum::response::Response { + let json = serde_json::json!({ "error": self.to_string() }); + (status, axum::Json(json)).into_response() + } +} + +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + use Error::*; + + match &self { + FailedToParse(_) => self.format_err(StatusCode::BAD_REQUEST), + AuthError(_) => self.format_err(StatusCode::UNAUTHORIZED), + Anyhow(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + LibSqlInvalidQueryParams(_) => self.format_err(StatusCode::BAD_REQUEST), + LibSqlTxTimeout => self.format_err(StatusCode::BAD_REQUEST), + LibSqlTxBusy => self.format_err(StatusCode::TOO_MANY_REQUESTS), + IOError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + RusqliteError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + RpcQueryError(_) => self.format_err(StatusCode::BAD_REQUEST), + RpcQueryExecutionError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + DbValueError(_) => self.format_err(StatusCode::BAD_REQUEST), + Internal(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + InvalidBatchStep(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + NotAuthorized(_) => self.format_err(StatusCode::UNAUTHORIZED), + ReplicatorExited => self.format_err(StatusCode::SERVICE_UNAVAILABLE), + DbCreateTimeout => self.format_err(StatusCode::SERVICE_UNAVAILABLE), + BuilderError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + Blocked(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + Json(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + TooManyRequests => self.format_err(StatusCode::TOO_MANY_REQUESTS), + QueryError(_) => self.format_err(StatusCode::BAD_REQUEST), + } + } } impl From for Error { diff --git a/sqld/src/http/hrana_over_http_1.rs b/sqld/src/http/hrana_over_http_1.rs index 1f301259..570ef300 100644 --- a/sqld/src/http/hrana_over_http_1.rs +++ b/sqld/src/http/hrana_over_http_1.rs @@ -30,7 +30,7 @@ pub(crate) async fn handle_execute( AxumState(AppState { db_factory, .. }): AxumState>, auth: Authenticated, req: hyper::Request, -) -> hyper::Response { +) -> crate::Result> { #[derive(Debug, Deserialize)] struct ReqBody { stmt: hrana::proto::Stmt, @@ -41,7 +41,7 @@ pub(crate) async fn handle_execute( result: hrana::proto::StmtResult, } - handle_request(db_factory, req, |db, req_body: ReqBody| async move { + let res = handle_request(db_factory, req, |db, req_body: ReqBody| async move { let query = hrana::stmt::proto_stmt_to_query( &req_body.stmt, &HashMap::new(), @@ -54,16 +54,16 @@ pub(crate) async fn handle_execute( .map_err(catch_stmt_error) .context("Could not execute statement") }) - .await - // TODO(lucio): Handle error - .unwrap() + .await?; + + Ok(res) } pub(crate) async fn handle_batch( AxumState(AppState { db_factory, .. }): AxumState>, auth: Authenticated, req: hyper::Request, -) -> hyper::Response { +) -> crate::Result> { #[derive(Debug, Deserialize)] struct ReqBody { batch: hrana::proto::Batch, @@ -74,7 +74,7 @@ pub(crate) async fn handle_batch( result: hrana::proto::BatchResult, } - handle_request(db_factory, req, |db, req_body: ReqBody| async move { + let res = handle_request(db_factory, req, |db, req_body: ReqBody| async move { let pgm = hrana::batch::proto_batch_to_program( &req_body.batch, &HashMap::new(), @@ -86,9 +86,9 @@ pub(crate) async fn handle_batch( .map(|result| RespBody { result }) .context("Could not execute batch") }) - .await - // TODO(lucio): handle errors - .unwrap() + .await?; + + Ok(res) } async fn handle_request( diff --git a/sqld/src/http/mod.rs b/sqld/src/http/mod.rs index ba682f0c..ec8592f8 100644 --- a/sqld/src/http/mod.rs +++ b/sqld/src/http/mod.rs @@ -11,17 +11,15 @@ use axum::extract::{FromRef, FromRequestParts, State as AxumState}; use axum::http::request::Parts; use axum::response::{Html, IntoResponse}; use axum::routing::{get, post}; -use axum::Router; +use axum::{Json, Router}; use axum_extra::middleware::option_layer; use base64::prelude::BASE64_STANDARD_NO_PAD; use base64::Engine; -use bytes::Bytes; use hyper::server::conn::AddrIncoming; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{header, Body, Request, Response, StatusCode}; use serde::Serialize; use serde_json::Number; use tokio::sync::{mpsc, oneshot}; -use tonic::codegen::http; use tower_http::trace::DefaultOnResponse; use tower_http::{compression::CompressionLayer, cors}; use tracing::{Level, Span}; @@ -73,28 +71,13 @@ struct RowsResponse { rows: Vec>, } -#[derive(Debug, Serialize)] -struct ErrorResponse { - message: String, -} - -fn error(msg: &str, code: StatusCode) -> Response { - let err = serde_json::json!({ "error": msg }); - Response::builder() - .status(code) - .body(Body::from(serde_json::to_vec(&err).unwrap())) - .unwrap() -} - -fn parse_queries(queries: Vec) -> anyhow::Result> { +fn parse_queries(queries: Vec) -> crate::Result> { let mut out = Vec::with_capacity(queries.len()); for query in queries { let mut iter = Statement::parse(&query.q); let stmt = iter.next().transpose()?.unwrap_or_default(); if iter.next().is_some() { - anyhow::bail!( - "found more than one command in a single statement string. It is allowed to issue only one command per string." - ); + return Err(Error::FailedToParse("found more than one command in a single statement string. It is allowed to issue only one command per string.".to_string())); } let query = Query { stmt, @@ -106,7 +89,11 @@ fn parse_queries(queries: Vec) -> anyhow::Result> { } match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) { - State::Txn => anyhow::bail!("interactive transaction not allowed in HTTP queries"), + State::Txn => { + return Err(Error::QueryError( + "interactive transaction not allowed in HTTP queries".to_string(), + )) + } State::Init => (), // maybe we should err here, but let's sqlite deal with that. State::Invalid => (), @@ -115,45 +102,25 @@ fn parse_queries(queries: Vec) -> anyhow::Result> { Ok(out) } -fn parse_payload(data: &[u8]) -> Result> { - match serde_json::from_slice(data) { - Ok(data) => Ok(data), - Err(e) => Err(error(&e.to_string(), http::status::StatusCode::BAD_REQUEST)), - } -} - async fn handle_query( auth: Authenticated, AxumState(state): AxumState>, - body: Bytes, -) -> Response { + Json(query): Json, +) -> Result { let AppState { db_factory, .. } = state; - let req = match parse_payload(&body) { - Ok(req) => req, - Err(resp) => return resp, - }; - - let batch = match parse_queries(req.statements) { - Ok(queries) => queries, - Err(e) => return error(&e.to_string(), StatusCode::BAD_REQUEST), - }; + let batch = parse_queries(query.statements)?; - // TODO(lucio): convert this error into a status - let db = db_factory.create().await.unwrap(); + let db = db_factory.create().await?; let builder = JsonHttpPayloadBuilder::new(); - match db.execute_batch_or_rollback(batch, auth, builder).await { - // TODO(lucio): convert these into axum responses - Ok((builder, _)) => Response::builder() - .header("Content-Type", "application/json") - .body(Body::from(builder.into_ret())) - .unwrap(), - Err(e) => error( - &format!("internal error: {e}"), - StatusCode::INTERNAL_SERVER_ERROR, - ), - } + let (builder, _) = db.execute_batch_or_rollback(batch, auth, builder).await?; + + let res = ( + [(header::CONTENT_TYPE, "application/json")], + builder.into_ret(), + ); + Ok(res.into_response()) } async fn show_console( @@ -206,17 +173,20 @@ async fn handle_hrana_v2( AxumState(state): AxumState>, auth: Authenticated, req: Request, -) -> Response { +) -> Result, Error> { let server = state.hrana_http_srv; - // TODO(lucio): handle error - server.handle_pipeline(auth, req).await.unwrap() + let res = server.handle_pipeline(auth, req).await?; + + Ok(res) } async fn handle_fallback() -> impl IntoResponse { (StatusCode::NOT_FOUND).into_response() } +/// Router wide state that each request has access too via +/// axum's `State` extractor. pub(crate) struct AppState { auth: Arc, db_factory: Arc>, @@ -301,11 +271,11 @@ pub async fn run_http( ); let listener = tokio::net::TcpListener::bind(&addr).await?; - let server = hyper::server::Server::builder(AddrIncoming::from_listener(listener)?) + hyper::server::Server::builder(AddrIncoming::from_listener(listener)?) .tcp_nodelay(true) - .serve(layered_app.into_make_service()); - - server.await.context("Http server exited with an error")?; + .serve(layered_app.into_make_service()) + .await + .context("foo")?; Ok(()) } @@ -317,21 +287,13 @@ where Arc: FromRef, S: Send + Sync, { - type Rejection = axum::response::Response; + type Rejection = Error; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let auth = as FromRef>::from_ref(state); let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); - let auth = match auth.authenticate_http(auth_header) { - Ok(auth) => auth, - Err(err) => { - return Err(Response::builder() - .status(hyper::StatusCode::UNAUTHORIZED) - .body(err.to_string()) - .unwrap()); - } - }; + let auth = auth.authenticate_http(auth_header)?; Ok(auth) }