Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
Add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco committed Aug 2, 2023
1 parent a08bf6c commit 14e43b4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 81 deletions.
51 changes: 50 additions & 1 deletion sqld/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::query_result_builder::QueryResultBuilderError;
use axum::response::IntoResponse;
use hyper::StatusCode;

use crate::{auth::AuthError, query_result_builder::QueryResultBuilderError};

#[allow(clippy::enum_variant_names)]
#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -39,6 +42,52 @@ pub enum Error {
Json(#[from] serde_json::Error),
#[error("Too many concurrent requests")]
TooManyRequests,
#[error("Failed to parse query: `{0}`")]
FailedToParse(String),
#[error("Query error: `{0}`")]
QueryError(String),
#[error("Unauthorized: `{0}`")]
AuthError(#[from] AuthError),
// Catch-all error since we use anyhow in certain places
#[error("Internal Error: `{0}`")]
Anyhow(#[from] anyhow::Error),
}

impl Error {
fn format_err(&self, status: StatusCode) -> axum::response::Response {
let json = serde_json::json!({ "error": self.to_string() });
(status, axum::Json(json)).into_response()
}
}

impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
use Error::*;

match &self {
FailedToParse(_) => self.format_err(StatusCode::BAD_REQUEST),
AuthError(_) => self.format_err(StatusCode::UNAUTHORIZED),
Anyhow(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
LibSqlInvalidQueryParams(_) => self.format_err(StatusCode::BAD_REQUEST),
LibSqlTxTimeout => self.format_err(StatusCode::BAD_REQUEST),
LibSqlTxBusy => self.format_err(StatusCode::TOO_MANY_REQUESTS),
IOError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
RusqliteError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
RpcQueryError(_) => self.format_err(StatusCode::BAD_REQUEST),
RpcQueryExecutionError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
DbValueError(_) => self.format_err(StatusCode::BAD_REQUEST),
Internal(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
InvalidBatchStep(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
NotAuthorized(_) => self.format_err(StatusCode::UNAUTHORIZED),
ReplicatorExited => self.format_err(StatusCode::SERVICE_UNAVAILABLE),
DbCreateTimeout => self.format_err(StatusCode::SERVICE_UNAVAILABLE),
BuilderError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
Blocked(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
Json(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
TooManyRequests => self.format_err(StatusCode::TOO_MANY_REQUESTS),
QueryError(_) => self.format_err(StatusCode::BAD_REQUEST),
}
}
}

impl From<tokio::sync::oneshot::error::RecvError> for Error {
Expand Down
20 changes: 10 additions & 10 deletions sqld/src/http/hrana_over_http_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub(crate) async fn handle_execute<D: Database>(
AxumState(AppState { db_factory, .. }): AxumState<AppState<D>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
) -> hyper::Response<hyper::Body> {
) -> crate::Result<hyper::Response<hyper::Body>> {
#[derive(Debug, Deserialize)]
struct ReqBody {
stmt: hrana::proto::Stmt,
Expand All @@ -41,7 +41,7 @@ pub(crate) async fn handle_execute<D: Database>(
result: hrana::proto::StmtResult,
}

handle_request(db_factory, req, |db, req_body: ReqBody| async move {
let res = handle_request(db_factory, req, |db, req_body: ReqBody| async move {
let query = hrana::stmt::proto_stmt_to_query(
&req_body.stmt,
&HashMap::new(),
Expand All @@ -54,16 +54,16 @@ pub(crate) async fn handle_execute<D: Database>(
.map_err(catch_stmt_error)
.context("Could not execute statement")
})
.await
// TODO(lucio): Handle error
.unwrap()
.await?;

Ok(res)
}

pub(crate) async fn handle_batch<D: Database>(
AxumState(AppState { db_factory, .. }): AxumState<AppState<D>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
) -> hyper::Response<hyper::Body> {
) -> crate::Result<hyper::Response<hyper::Body>> {
#[derive(Debug, Deserialize)]
struct ReqBody {
batch: hrana::proto::Batch,
Expand All @@ -74,7 +74,7 @@ pub(crate) async fn handle_batch<D: Database>(
result: hrana::proto::BatchResult,
}

handle_request(db_factory, req, |db, req_body: ReqBody| async move {
let res = handle_request(db_factory, req, |db, req_body: ReqBody| async move {
let pgm = hrana::batch::proto_batch_to_program(
&req_body.batch,
&HashMap::new(),
Expand All @@ -86,9 +86,9 @@ pub(crate) async fn handle_batch<D: Database>(
.map(|result| RespBody { result })
.context("Could not execute batch")
})
.await
// TODO(lucio): handle errors
.unwrap()
.await?;

Ok(res)
}

async fn handle_request<ReqBody, RespBody, F, Fut, FT>(
Expand Down
102 changes: 32 additions & 70 deletions sqld/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@ use axum::extract::{FromRef, FromRequestParts, State as AxumState};
use axum::http::request::Parts;
use axum::response::{Html, IntoResponse};
use axum::routing::{get, post};
use axum::Router;
use axum::{Json, Router};
use axum_extra::middleware::option_layer;
use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use bytes::Bytes;
use hyper::server::conn::AddrIncoming;
use hyper::{Body, Request, Response, StatusCode};
use hyper::{header, Body, Request, Response, StatusCode};
use serde::Serialize;
use serde_json::Number;
use tokio::sync::{mpsc, oneshot};
use tonic::codegen::http;
use tower_http::trace::DefaultOnResponse;
use tower_http::{compression::CompressionLayer, cors};
use tracing::{Level, Span};
Expand Down Expand Up @@ -73,28 +71,13 @@ struct RowsResponse {
rows: Vec<Vec<serde_json::Value>>,
}

#[derive(Debug, Serialize)]
struct ErrorResponse {
message: String,
}

fn error(msg: &str, code: StatusCode) -> Response<Body> {
let err = serde_json::json!({ "error": msg });
Response::builder()
.status(code)
.body(Body::from(serde_json::to_vec(&err).unwrap()))
.unwrap()
}

fn parse_queries(queries: Vec<QueryObject>) -> anyhow::Result<Vec<Query>> {
fn parse_queries(queries: Vec<QueryObject>) -> crate::Result<Vec<Query>> {
let mut out = Vec::with_capacity(queries.len());
for query in queries {
let mut iter = Statement::parse(&query.q);
let stmt = iter.next().transpose()?.unwrap_or_default();
if iter.next().is_some() {
anyhow::bail!(
"found more than one command in a single statement string. It is allowed to issue only one command per string."
);
return Err(Error::FailedToParse("found more than one command in a single statement string. It is allowed to issue only one command per string.".to_string()));
}
let query = Query {
stmt,
Expand All @@ -106,7 +89,11 @@ fn parse_queries(queries: Vec<QueryObject>) -> anyhow::Result<Vec<Query>> {
}

match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) {
State::Txn => anyhow::bail!("interactive transaction not allowed in HTTP queries"),
State::Txn => {
return Err(Error::QueryError(
"interactive transaction not allowed in HTTP queries".to_string(),
))
}
State::Init => (),
// maybe we should err here, but let's sqlite deal with that.
State::Invalid => (),
Expand All @@ -115,45 +102,25 @@ fn parse_queries(queries: Vec<QueryObject>) -> anyhow::Result<Vec<Query>> {
Ok(out)
}

fn parse_payload(data: &[u8]) -> Result<HttpQuery, Response<Body>> {
match serde_json::from_slice(data) {
Ok(data) => Ok(data),
Err(e) => Err(error(&e.to_string(), http::status::StatusCode::BAD_REQUEST)),
}
}

async fn handle_query<D: Database>(
auth: Authenticated,
AxumState(state): AxumState<AppState<D>>,
body: Bytes,
) -> Response<Body> {
Json(query): Json<HttpQuery>,
) -> Result<axum::response::Response, Error> {
let AppState { db_factory, .. } = state;

let req = match parse_payload(&body) {
Ok(req) => req,
Err(resp) => return resp,
};

let batch = match parse_queries(req.statements) {
Ok(queries) => queries,
Err(e) => return error(&e.to_string(), StatusCode::BAD_REQUEST),
};
let batch = parse_queries(query.statements)?;

// TODO(lucio): convert this error into a status
let db = db_factory.create().await.unwrap();
let db = db_factory.create().await?;

let builder = JsonHttpPayloadBuilder::new();
match db.execute_batch_or_rollback(batch, auth, builder).await {
// TODO(lucio): convert these into axum responses
Ok((builder, _)) => Response::builder()
.header("Content-Type", "application/json")
.body(Body::from(builder.into_ret()))
.unwrap(),
Err(e) => error(
&format!("internal error: {e}"),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
let (builder, _) = db.execute_batch_or_rollback(batch, auth, builder).await?;

let res = (
[(header::CONTENT_TYPE, "application/json")],
builder.into_ret(),
);
Ok(res.into_response())
}

async fn show_console<D>(
Expand Down Expand Up @@ -206,17 +173,20 @@ async fn handle_hrana_v2<D: Database>(
AxumState(state): AxumState<AppState<D>>,
auth: Authenticated,
req: Request<Body>,
) -> Response<Body> {
) -> Result<Response<Body>, Error> {
let server = state.hrana_http_srv;

// TODO(lucio): handle error
server.handle_pipeline(auth, req).await.unwrap()
let res = server.handle_pipeline(auth, req).await?;

Ok(res)
}

async fn handle_fallback() -> impl IntoResponse {
(StatusCode::NOT_FOUND).into_response()
}

/// Router wide state that each request has access too via
/// axum's `State` extractor.
pub(crate) struct AppState<D> {
auth: Arc<Auth>,
db_factory: Arc<dyn DbFactory<Db = D>>,
Expand Down Expand Up @@ -301,11 +271,11 @@ pub async fn run_http<D: Database>(
);

let listener = tokio::net::TcpListener::bind(&addr).await?;
let server = hyper::server::Server::builder(AddrIncoming::from_listener(listener)?)
hyper::server::Server::builder(AddrIncoming::from_listener(listener)?)
.tcp_nodelay(true)
.serve(layered_app.into_make_service());

server.await.context("Http server exited with an error")?;
.serve(layered_app.into_make_service())
.await
.context("foo")?;

Ok(())
}
Expand All @@ -317,21 +287,13 @@ where
Arc<Auth>: FromRef<S>,
S: Send + Sync,
{
type Rejection = axum::response::Response<String>;
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let auth = <Arc<Auth> as FromRef<S>>::from_ref(state);

let auth_header = parts.headers.get(hyper::header::AUTHORIZATION);
let auth = match auth.authenticate_http(auth_header) {
Ok(auth) => auth,
Err(err) => {
return Err(Response::builder()
.status(hyper::StatusCode::UNAUTHORIZED)
.body(err.to_string())
.unwrap());
}
};
let auth = auth.authenticate_http(auth_header)?;

Ok(auth)
}
Expand Down

0 comments on commit 14e43b4

Please sign in to comment.