Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
Fixup after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Aug 17, 2023
1 parent 7b9cb67 commit 49887fd
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 54 deletions.
8 changes: 8 additions & 0 deletions sqld/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ pub trait Connection: Send + Sync + 'static {

/// Parse the SQL statement and return information about it.
async fn describe(&self, sql: String, auth: Authenticated) -> Result<DescribeResult>;

/// Check whether the connection is in autocommit mode.
async fn is_autocommit(&self) -> Result<bool>;
}

fn make_batch_program(batch: Vec<Query>) -> Vec<Step> {
Expand Down Expand Up @@ -273,6 +276,11 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {
async fn describe(&self, sql: String, auth: Authenticated) -> crate::Result<DescribeResult> {
self.inner.describe(sql, auth).await
}

#[inline]
async fn is_autocommit(&self) -> crate::Result<bool> {
self.inner.is_autocommit().await
}
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions sqld/src/connection/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub enum Cond {
Not { cond: Box<Self> },
Or { conds: Vec<Self> },
And { conds: Vec<Self> },
IsAutocommit,
}

pub type DescribeResult = crate::Result<DescribeResponse>;
Expand Down
21 changes: 11 additions & 10 deletions sqld/src/hrana/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ use std::task;
use tokio::sync::{mpsc, oneshot};

use crate::auth::Authenticated;
use crate::database::{Database, Program};
use crate::connection::program::Program;
use crate::connection::Connection;
use crate::query_result_builder::{
Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError,
};

use super::result_builder::{estimate_cols_json_size, value_json_size, value_to_proto};
use super::{batch, proto, stmt};

pub struct CursorHandle<D> {
open_tx: Option<oneshot::Sender<OpenReq<D>>>,
pub struct CursorHandle<C> {
open_tx: Option<oneshot::Sender<OpenReq<C>>>,
entry_rx: mpsc::Receiver<Result<SizedEntry>>,
}

Expand All @@ -25,16 +26,16 @@ pub struct SizedEntry {
pub size: u64,
}

struct OpenReq<D> {
db: Arc<D>,
struct OpenReq<C> {
db: Arc<C>,
auth: Authenticated,
pgm: Program,
}

impl<D> CursorHandle<D> {
impl<C> CursorHandle<C> {
pub fn spawn(join_set: &mut tokio::task::JoinSet<()>) -> Self
where
D: Database,
C: Connection,
{
let (open_tx, open_rx) = oneshot::channel();
let (entry_tx, entry_rx) = mpsc::channel(1);
Expand All @@ -46,7 +47,7 @@ impl<D> CursorHandle<D> {
}
}

pub fn open(&mut self, db: Arc<D>, auth: Authenticated, pgm: Program) {
pub fn open(&mut self, db: Arc<C>, auth: Authenticated, pgm: Program) {
let open_tx = self.open_tx.take().unwrap();
let _: Result<_, _> = open_tx.send(OpenReq { db, auth, pgm });
}
Expand All @@ -60,8 +61,8 @@ impl<D> CursorHandle<D> {
}
}

async fn run_cursor<D: Database>(
open_rx: oneshot::Receiver<OpenReq<D>>,
async fn run_cursor<C: Connection>(
open_rx: oneshot::Receiver<OpenReq<C>>,
entry_tx: mpsc::Sender<Result<SizedEntry>>,
) {
let Ok(open_req) = open_rx.await else {
Expand Down
52 changes: 34 additions & 18 deletions sqld/src/hrana/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,28 @@ impl<C: Connection> Server<C> {

pub async fn handle_request(
&self,
connection_maker: Arc<dyn MakeConnection<Connection = C>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
endpoint: Endpoint,
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
handle_request(self, auth, req, endpoint, version, encoding)
.await
.or_else(|err| {
err.downcast::<stream::StreamError>()
.map(|err| stream_error_response(err, encoding))
})
.or_else(|err| err.downcast::<ProtocolError>().map(protocol_error_response))
handle_request(
self,
connection_maker,
auth,
req,
endpoint,
version,
encoding,
)
.await
.or_else(|err| {
err.downcast::<stream::StreamError>()
.map(|err| stream_error_response(err, encoding))
})
.or_else(|err| err.downcast::<ProtocolError>().map(protocol_error_response))
}
}

Expand All @@ -65,30 +74,36 @@ pub(crate) async fn handle_index() -> hyper::Response<hyper::Body> {
)
}

async fn handle_request<D: Database>(
server: &Server<D>,
async fn handle_request<C: Connection>(
server: &Server<C>,
connection_maker: Arc<dyn MakeConnection<Connection = C>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
endpoint: Endpoint,
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
match endpoint {
Endpoint::Pipeline => handle_pipeline(server, auth, req, version, encoding).await,
Endpoint::Cursor => handle_cursor(server, auth, req, version, encoding).await,
Endpoint::Pipeline => {
handle_pipeline(server, connection_maker, auth, req, version, encoding).await
}
Endpoint::Cursor => {
handle_cursor(server, connection_maker, auth, req, version, encoding).await
}
}
}

async fn handle_pipeline<D: Database>(
server: &Server<D>,
connection_maker: Arc<dyn MakeConnection<Connection = D>>,
async fn handle_pipeline<C: Connection>(
server: &Server<C>,
connection_maker: Arc<dyn MakeConnection<Connection = C>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
let req_body: proto::PipelineReqBody = read_decode_request(req, encoding).await?;
let mut stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?;
let mut stream_guard =
stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?;

let mut results = Vec::with_capacity(req_body.requests.len());
for request in req_body.requests.into_iter() {
Expand All @@ -106,15 +121,16 @@ async fn handle_pipeline<D: Database>(
Ok(encode_response(hyper::StatusCode::OK, &resp_body, encoding))
}

async fn handle_cursor<D: Database>(
server: &Server<D>,
async fn handle_cursor<C: Connection>(
server: &Server<C>,
connection_maker: Arc<dyn MakeConnection<Connection = C>>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
let req_body: proto::CursorReqBody = read_decode_request(req, encoding).await?;
let stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?;
let stream_guard = stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?;

let mut join_set = tokio::task::JoinSet::new();
let mut cursor_hnd = cursor::CursorHandle::spawn(&mut join_set);
Expand Down
4 changes: 2 additions & 2 deletions sqld/src/hrana/http/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ impl<D> ServerStreamState<D> {
/// otherwise we create a new stream.
pub async fn acquire<'srv, D: Connection>(
server: &'srv Server<D>,
connection_maker: Arc<dyn MakeConnection<Connection = D>>,
baton: Option<&str>,
db_factory: Arc<dyn MakeConnection<Connection = D>>,
) -> Result<Guard<'srv, D>> {
let stream = match baton {
Some(baton) => {
Expand Down Expand Up @@ -148,7 +148,7 @@ pub async fn acquire<'srv, D: Connection>(
stream
}
None => {
let db = db_factory
let db = connection_maker
.create()
.await
.context("Could not create a database connection")?;
Expand Down
2 changes: 1 addition & 1 deletion sqld/src/hrana/result_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl QueryResultBuilder for SingleStatementBuilder {
let mut f = SizeFormatter::new();
write!(&mut f, "{error}").unwrap();
TOTAL_RESPONSE_SIZE.fetch_sub(self.current_size as usize, Ordering::Relaxed);
self.current_size = f.0;
self.current_size = f.size;
TOTAL_RESPONSE_SIZE.fetch_add(self.current_size as usize, Ordering::Relaxed);
self.err = Some(error);

Expand Down
27 changes: 19 additions & 8 deletions sqld/src/hrana/ws/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,39 @@ pub(super) async fn handle_tcp<F: MakeNamespace>(
socket: tokio::net::TcpStream,
conn_id: u64,
) -> Result<()> {
let (ws, version, encoding) = handshake::handshake_tcp(socket)
.await
.context("Could not perform the WebSocket handshake on TCP connection")?;
let (ws, version, encoding, ns) = handshake::handshake_tcp(
let handshake::Output {
ws,
version,
encoding,
namespace,
} = handshake::handshake_tcp(
socket,
server.disable_default_namespace,
server.disable_namespaces,
)
.await
.context("Could not perform the WebSocket handshake on TCP connection")?;
handle_ws(server, ws, version, encoding, conn_id, ns).await
handle_ws(server, ws, version, encoding, conn_id, namespace).await
}

pub(super) async fn handle_upgrade<F: MakeNamespace>(
server: Arc<Server<F>>,
upgrade: Upgrade,
conn_id: u64,
) -> Result<()> {
let (ws, version, encoding, ns) = handshake::handshake_upgrade(
let handshake::Output {
ws,
version,
encoding,
namespace,
} = handshake::handshake_upgrade(
upgrade,
server.disable_default_namespace,
server.disable_namespaces,
)
.await
.context("Could not perform the WebSocket handshake on HTTP connection")?;
handle_ws(server, ws, version, encoding, conn_id, ns).await
handle_ws(server, ws, version, encoding, conn_id, namespace).await
}

async fn handle_ws<F: MakeNamespace>(
Expand Down Expand Up @@ -210,7 +217,10 @@ async fn handle_client_msg<F: MakeNamespace>(
}
}

async fn handle_hello_msg(conn: &mut Conn<impl Database>, jwt: Option<String>) -> Result<bool> {
async fn handle_hello_msg<F: MakeNamespace>(
conn: &mut Conn<F>,
jwt: Option<String>,
) -> Result<bool> {
let hello_res = match conn.session.as_mut() {
None => session::handle_initial_hello(&conn.server, conn.version, jwt)
.map(|session| conn.session = Some(session)),
Expand Down Expand Up @@ -246,6 +256,7 @@ async fn handle_request_msg<F: MakeNamespace>(
};

let response_rx = session::handle_request(
&conn.server,
session,
&mut conn.join_set,
request,
Expand Down
29 changes: 23 additions & 6 deletions sqld/src/hrana/ws/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use tungstenite::http;

use crate::http::db_factory::namespace_from_headers;

use super::super::Version;
use super::super::{Encoding, Version};
use super::Upgrade;

Expand All @@ -24,11 +23,19 @@ enum Subproto {
Hrana3Protobuf,
}

#[derive(Debug)]
pub struct Output {
pub ws: WebSocket,
pub version: Version,
pub encoding: Encoding,
pub namespace: Bytes,
}

pub async fn handshake_tcp(
socket: tokio::net::TcpStream,
disable_default_ns: bool,
disable_namespaces: bool,
) -> Result<(WebSocket, Version, Encoding, Bytes)> {
) -> Result<Output> {
socket
.set_nodelay(true)
.context("Could not disable Nagle's algorithm")?;
Expand Down Expand Up @@ -61,17 +68,22 @@ pub async fn handshake_tcp(
tokio_tungstenite::accept_hdr_async_with_config(socket, callback, ws_config).await?;

let (version, encoding) = subproto.unwrap().version_encoding();
Ok((WebSocket::Tcp(stream), version, encoding, namespace.unwrap()))
Ok(Output {
ws: WebSocket::Tcp(stream),
version,
encoding,
namespace: namespace.unwrap(),
})
}

pub async fn handshake_upgrade(
upgrade: Upgrade,
disable_default_ns: bool,
disable_namespaces: bool,
) -> Result<(WebSocket, Version, Encoding, Bytes)> {
) -> Result<Output> {
let mut req = upgrade.request;

let ns = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?;
let namespace = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?;
let ws_config = Some(get_ws_config());
let (mut resp, stream_fut_subproto_res) = match hyper_tungstenite::upgrade(&mut req, ws_config)
{
Expand Down Expand Up @@ -112,7 +124,12 @@ pub async fn handshake_upgrade(
.context("Could not upgrade HTTP request to a WebSocket")?;

let (version, encoding) = subproto.version_encoding();
Ok((WebSocket::Upgraded(stream), version, encoding, ns))
Ok(Output {
ws: WebSocket::Upgraded(stream),
version,
encoding,
namespace,
})
}

fn negotiate_subproto(
Expand Down
7 changes: 4 additions & 3 deletions sqld/src/hrana/ws/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ pub(super) fn handle_repeated_hello<F: MakeNamespace>(
Ok(())
}

pub(super) async fn handle_request<D: Connection>(
session: &mut Session<D>,
pub(super) async fn handle_request<F: MakeNamespace>(
server: &Server<F>,
session: &mut Session<<F::Database as Database>::Connection>,
join_set: &mut tokio::task::JoinSet<()>,
req: proto::Request,
connection_maker: Arc<dyn MakeConnection<Connection = D>>,
connection_maker: Arc<dyn MakeConnection<Connection = <F::Database as Database>::Connection>>,
) -> Result<oneshot::Receiver<Result<proto::Response>>> {
// TODO: this function has rotten: it is too long and contains too much duplicated code. It
// should be refactored at the next opportunity, together with code in stmt.rs and batch.rs
Expand Down
Loading

0 comments on commit 49887fd

Please sign in to comment.