From 49887fd3bdfae1b1f9228f257a40c12f1c1f5b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20=C5=A0pa=C4=8Dek?= Date: Thu, 17 Aug 2023 12:21:05 +0200 Subject: [PATCH] Fixup after rebase --- sqld/src/connection/mod.rs | 8 +++++ sqld/src/connection/program.rs | 1 + sqld/src/hrana/cursor.rs | 21 +++++++------ sqld/src/hrana/http/mod.rs | 52 +++++++++++++++++++++----------- sqld/src/hrana/http/stream.rs | 4 +-- sqld/src/hrana/result_builder.rs | 2 +- sqld/src/hrana/ws/conn.rs | 27 ++++++++++++----- sqld/src/hrana/ws/handshake.rs | 29 ++++++++++++++---- sqld/src/hrana/ws/session.rs | 7 +++-- sqld/src/http/mod.rs | 13 +++++--- sqld/src/rpc/proxy.rs | 4 ++- 11 files changed, 114 insertions(+), 54 deletions(-) diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index ec6da8ac..68ca51e4 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -99,6 +99,9 @@ pub trait Connection: Send + Sync + 'static { /// Parse the SQL statement and return information about it. async fn describe(&self, sql: String, auth: Authenticated) -> Result; + + /// Check whether the connection is in autocommit mode. + async fn is_autocommit(&self) -> Result; } fn make_batch_program(batch: Vec) -> Vec { @@ -273,6 +276,11 @@ impl Connection for TrackedConnection { async fn describe(&self, sql: String, auth: Authenticated) -> crate::Result { self.inner.describe(sql, auth).await } + + #[inline] + async fn is_autocommit(&self) -> crate::Result { + self.inner.is_autocommit().await + } } #[cfg(test)] diff --git a/sqld/src/connection/program.rs b/sqld/src/connection/program.rs index c85110ac..fabfbd18 100644 --- a/sqld/src/connection/program.rs +++ b/sqld/src/connection/program.rs @@ -57,6 +57,7 @@ pub enum Cond { Not { cond: Box }, Or { conds: Vec }, And { conds: Vec }, + IsAutocommit, } pub type DescribeResult = crate::Result; diff --git a/sqld/src/hrana/cursor.rs b/sqld/src/hrana/cursor.rs index 3e8247b5..189651ae 100644 --- a/sqld/src/hrana/cursor.rs +++ b/sqld/src/hrana/cursor.rs @@ -6,7 +6,8 @@ use std::task; use tokio::sync::{mpsc, oneshot}; use crate::auth::Authenticated; -use crate::database::{Database, Program}; +use crate::connection::program::Program; +use crate::connection::Connection; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -14,8 +15,8 @@ use crate::query_result_builder::{ use super::result_builder::{estimate_cols_json_size, value_json_size, value_to_proto}; use super::{batch, proto, stmt}; -pub struct CursorHandle { - open_tx: Option>>, +pub struct CursorHandle { + open_tx: Option>>, entry_rx: mpsc::Receiver>, } @@ -25,16 +26,16 @@ pub struct SizedEntry { pub size: u64, } -struct OpenReq { - db: Arc, +struct OpenReq { + db: Arc, auth: Authenticated, pgm: Program, } -impl CursorHandle { +impl CursorHandle { pub fn spawn(join_set: &mut tokio::task::JoinSet<()>) -> Self where - D: Database, + C: Connection, { let (open_tx, open_rx) = oneshot::channel(); let (entry_tx, entry_rx) = mpsc::channel(1); @@ -46,7 +47,7 @@ impl CursorHandle { } } - pub fn open(&mut self, db: Arc, auth: Authenticated, pgm: Program) { + pub fn open(&mut self, db: Arc, auth: Authenticated, pgm: Program) { let open_tx = self.open_tx.take().unwrap(); let _: Result<_, _> = open_tx.send(OpenReq { db, auth, pgm }); } @@ -60,8 +61,8 @@ impl CursorHandle { } } -async fn run_cursor( - open_rx: oneshot::Receiver>, +async fn run_cursor( + open_rx: oneshot::Receiver>, entry_tx: mpsc::Sender>, ) { let Ok(open_req) = open_rx.await else { diff --git a/sqld/src/hrana/http/mod.rs b/sqld/src/hrana/http/mod.rs index a48c6d1f..aa60347d 100644 --- a/sqld/src/hrana/http/mod.rs +++ b/sqld/src/hrana/http/mod.rs @@ -42,19 +42,28 @@ impl Server { pub async fn handle_request( &self, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, endpoint: Endpoint, version: Version, encoding: Encoding, ) -> Result> { - handle_request(self, auth, req, endpoint, version, encoding) - .await - .or_else(|err| { - err.downcast::() - .map(|err| stream_error_response(err, encoding)) - }) - .or_else(|err| err.downcast::().map(protocol_error_response)) + handle_request( + self, + connection_maker, + auth, + req, + endpoint, + version, + encoding, + ) + .await + .or_else(|err| { + err.downcast::() + .map(|err| stream_error_response(err, encoding)) + }) + .or_else(|err| err.downcast::().map(protocol_error_response)) } } @@ -65,8 +74,9 @@ pub(crate) async fn handle_index() -> hyper::Response { ) } -async fn handle_request( - server: &Server, +async fn handle_request( + server: &Server, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, endpoint: Endpoint, @@ -74,21 +84,26 @@ async fn handle_request( encoding: Encoding, ) -> Result> { match endpoint { - Endpoint::Pipeline => handle_pipeline(server, auth, req, version, encoding).await, - Endpoint::Cursor => handle_cursor(server, auth, req, version, encoding).await, + Endpoint::Pipeline => { + handle_pipeline(server, connection_maker, auth, req, version, encoding).await + } + Endpoint::Cursor => { + handle_cursor(server, connection_maker, auth, req, version, encoding).await + } } } -async fn handle_pipeline( - server: &Server, - connection_maker: Arc>, +async fn handle_pipeline( + server: &Server, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, version: Version, encoding: Encoding, ) -> Result> { let req_body: proto::PipelineReqBody = read_decode_request(req, encoding).await?; - let mut stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?; + let mut stream_guard = + stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?; let mut results = Vec::with_capacity(req_body.requests.len()); for request in req_body.requests.into_iter() { @@ -106,15 +121,16 @@ async fn handle_pipeline( Ok(encode_response(hyper::StatusCode::OK, &resp_body, encoding)) } -async fn handle_cursor( - server: &Server, +async fn handle_cursor( + server: &Server, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, version: Version, encoding: Encoding, ) -> Result> { let req_body: proto::CursorReqBody = read_decode_request(req, encoding).await?; - let stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?; + let stream_guard = stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?; let mut join_set = tokio::task::JoinSet::new(); let mut cursor_hnd = cursor::CursorHandle::spawn(&mut join_set); diff --git a/sqld/src/hrana/http/stream.rs b/sqld/src/hrana/http/stream.rs index 35c0f283..326f8332 100644 --- a/sqld/src/hrana/http/stream.rs +++ b/sqld/src/hrana/http/stream.rs @@ -104,8 +104,8 @@ impl ServerStreamState { /// otherwise we create a new stream. pub async fn acquire<'srv, D: Connection>( server: &'srv Server, + connection_maker: Arc>, baton: Option<&str>, - db_factory: Arc>, ) -> Result> { let stream = match baton { Some(baton) => { @@ -148,7 +148,7 @@ pub async fn acquire<'srv, D: Connection>( stream } None => { - let db = db_factory + let db = connection_maker .create() .await .context("Could not create a database connection")?; diff --git a/sqld/src/hrana/result_builder.rs b/sqld/src/hrana/result_builder.rs index 7c6343cd..c26b52f1 100644 --- a/sqld/src/hrana/result_builder.rs +++ b/sqld/src/hrana/result_builder.rs @@ -149,7 +149,7 @@ impl QueryResultBuilder for SingleStatementBuilder { let mut f = SizeFormatter::new(); write!(&mut f, "{error}").unwrap(); TOTAL_RESPONSE_SIZE.fetch_sub(self.current_size as usize, Ordering::Relaxed); - self.current_size = f.0; + self.current_size = f.size; TOTAL_RESPONSE_SIZE.fetch_add(self.current_size as usize, Ordering::Relaxed); self.err = Some(error); diff --git a/sqld/src/hrana/ws/conn.rs b/sqld/src/hrana/ws/conn.rs index 6a7c0897..db673212 100644 --- a/sqld/src/hrana/ws/conn.rs +++ b/sqld/src/hrana/ws/conn.rs @@ -53,17 +53,19 @@ pub(super) async fn handle_tcp( socket: tokio::net::TcpStream, conn_id: u64, ) -> Result<()> { - let (ws, version, encoding) = handshake::handshake_tcp(socket) - .await - .context("Could not perform the WebSocket handshake on TCP connection")?; - let (ws, version, encoding, ns) = handshake::handshake_tcp( + let handshake::Output { + ws, + version, + encoding, + namespace, + } = handshake::handshake_tcp( socket, server.disable_default_namespace, server.disable_namespaces, ) .await .context("Could not perform the WebSocket handshake on TCP connection")?; - handle_ws(server, ws, version, encoding, conn_id, ns).await + handle_ws(server, ws, version, encoding, conn_id, namespace).await } pub(super) async fn handle_upgrade( @@ -71,14 +73,19 @@ pub(super) async fn handle_upgrade( upgrade: Upgrade, conn_id: u64, ) -> Result<()> { - let (ws, version, encoding, ns) = handshake::handshake_upgrade( + let handshake::Output { + ws, + version, + encoding, + namespace, + } = handshake::handshake_upgrade( upgrade, server.disable_default_namespace, server.disable_namespaces, ) .await .context("Could not perform the WebSocket handshake on HTTP connection")?; - handle_ws(server, ws, version, encoding, conn_id, ns).await + handle_ws(server, ws, version, encoding, conn_id, namespace).await } async fn handle_ws( @@ -210,7 +217,10 @@ async fn handle_client_msg( } } -async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { +async fn handle_hello_msg( + conn: &mut Conn, + jwt: Option, +) -> Result { let hello_res = match conn.session.as_mut() { None => session::handle_initial_hello(&conn.server, conn.version, jwt) .map(|session| conn.session = Some(session)), @@ -246,6 +256,7 @@ async fn handle_request_msg( }; let response_rx = session::handle_request( + &conn.server, session, &mut conn.join_set, request, diff --git a/sqld/src/hrana/ws/handshake.rs b/sqld/src/hrana/ws/handshake.rs index 14e00eb1..9e25a713 100644 --- a/sqld/src/hrana/ws/handshake.rs +++ b/sqld/src/hrana/ws/handshake.rs @@ -6,7 +6,6 @@ use tungstenite::http; use crate::http::db_factory::namespace_from_headers; -use super::super::Version; use super::super::{Encoding, Version}; use super::Upgrade; @@ -24,11 +23,19 @@ enum Subproto { Hrana3Protobuf, } +#[derive(Debug)] +pub struct Output { + pub ws: WebSocket, + pub version: Version, + pub encoding: Encoding, + pub namespace: Bytes, +} + pub async fn handshake_tcp( socket: tokio::net::TcpStream, disable_default_ns: bool, disable_namespaces: bool, -) -> Result<(WebSocket, Version, Encoding, Bytes)> { +) -> Result { socket .set_nodelay(true) .context("Could not disable Nagle's algorithm")?; @@ -61,17 +68,22 @@ pub async fn handshake_tcp( tokio_tungstenite::accept_hdr_async_with_config(socket, callback, ws_config).await?; let (version, encoding) = subproto.unwrap().version_encoding(); - Ok((WebSocket::Tcp(stream), version, encoding, namespace.unwrap())) + Ok(Output { + ws: WebSocket::Tcp(stream), + version, + encoding, + namespace: namespace.unwrap(), + }) } pub async fn handshake_upgrade( upgrade: Upgrade, disable_default_ns: bool, disable_namespaces: bool, -) -> Result<(WebSocket, Version, Encoding, Bytes)> { +) -> Result { let mut req = upgrade.request; - let ns = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?; + let namespace = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?; let ws_config = Some(get_ws_config()); let (mut resp, stream_fut_subproto_res) = match hyper_tungstenite::upgrade(&mut req, ws_config) { @@ -112,7 +124,12 @@ pub async fn handshake_upgrade( .context("Could not upgrade HTTP request to a WebSocket")?; let (version, encoding) = subproto.version_encoding(); - Ok((WebSocket::Upgraded(stream), version, encoding, ns)) + Ok(Output { + ws: WebSocket::Upgraded(stream), + version, + encoding, + namespace, + }) } fn negotiate_subproto( diff --git a/sqld/src/hrana/ws/session.rs b/sqld/src/hrana/ws/session.rs index 608b23d9..f2e3ee72 100644 --- a/sqld/src/hrana/ws/session.rs +++ b/sqld/src/hrana/ws/session.rs @@ -102,11 +102,12 @@ pub(super) fn handle_repeated_hello( Ok(()) } -pub(super) async fn handle_request( - session: &mut Session, +pub(super) async fn handle_request( + server: &Server, + session: &mut Session<::Connection>, join_set: &mut tokio::task::JoinSet<()>, req: proto::Request, - connection_maker: Arc>, + connection_maker: Arc::Connection>>, ) -> Result>> { // TODO: this function has rotten: it is too long and contains too much duplicated code. It // should be refactored at the next opportunity, together with code in stmt.rs and batch.rs diff --git a/sqld/src/http/mod.rs b/sqld/src/http/mod.rs index 07b6f29e..a7e97888 100644 --- a/sqld/src/http/mod.rs +++ b/sqld/src/http/mod.rs @@ -112,9 +112,9 @@ fn parse_queries(queries: Vec) -> crate::Result> { Ok(out) } -async fn handle_query( +async fn handle_query( auth: Authenticated, - MakeConnectionExtractor(connection_maker): MakeConnectionExtractor, + MakeConnectionExtractor(connection_maker): MakeConnectionExtractor, Json(query): Json, ) -> Result { let batch = parse_queries(query.statements)?; @@ -252,14 +252,17 @@ where macro_rules! handle_hrana { ($endpoint:expr, $version:expr, $encoding:expr,) => {{ - async fn handle_hrana( - AxumState(state): AxumState>, + async fn handle_hrana( + AxumState(state): AxumState>, + MakeConnectionExtractor(connection_maker): MakeConnectionExtractor< + ::Connection, + >, auth: Authenticated, req: Request, ) -> Result, Error> { Ok(state .hrana_http_srv - .handle_request(auth, req, $endpoint, $version, $encoding) + .handle_request(connection_maker, auth, req, $endpoint, $version, $encoding) .await?) } handle_hrana diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index ea03df8d..8751a170 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -249,7 +249,9 @@ pub mod rpc { connection::program::Cond::And { conds } => cond::Cond::And(AndCond { conds: conds.into_iter().map(|c| c.into()).collect(), }), - database::Cond::IsAutocommit => cond::Cond::IsAutocommit(IsAutocommitCond {}), + connection::program::Cond::IsAutocommit => { + cond::Cond::IsAutocommit(IsAutocommitCond {}) + } }; Self { cond: Some(cond) }