Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
Add cursors to Hrana over HTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Aug 17, 2023
1 parent 3c0801b commit 7b9cb67
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 21 deletions.
12 changes: 10 additions & 2 deletions sqld/src/hrana/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use anyhow::{anyhow, Result};
use rusqlite::types::ValueRef;
use std::mem::take;
use std::sync::Arc;
use std::task;
use tokio::sync::{mpsc, oneshot};

use crate::auth::Authenticated;
Expand Down Expand Up @@ -30,8 +31,11 @@ struct OpenReq<D> {
pgm: Program,
}

impl<D: Database> CursorHandle<D> {
pub fn spawn(join_set: &mut tokio::task::JoinSet<()>) -> Self {
impl<D> CursorHandle<D> {
pub fn spawn(join_set: &mut tokio::task::JoinSet<()>) -> Self
where
D: Database,
{
let (open_tx, open_rx) = oneshot::channel();
let (entry_tx, entry_rx) = mpsc::channel(1);

Expand All @@ -50,6 +54,10 @@ impl<D: Database> CursorHandle<D> {
pub async fn fetch(&mut self) -> Result<Option<SizedEntry>> {
self.entry_rx.recv().await.transpose()
}

pub fn poll_fetch(&mut self, cx: &mut task::Context) -> task::Poll<Option<Result<SizedEntry>>> {
self.entry_rx.poll_recv(cx)
}
}

async fn run_cursor<D: Database>(
Expand Down
111 changes: 98 additions & 13 deletions sqld/src/hrana/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::stream::Stream;
use parking_lot::Mutex;
use serde::{de::DeserializeOwned, Serialize};
use std::pin::Pin;
use std::sync::Arc;
use std::task;

use super::{Encoding, ProtocolError, Version};
use super::{batch, cursor, Encoding, ProtocolError, Version};
use crate::auth::Authenticated;
use crate::connection::{Connection, MakeConnection};
mod proto;
Expand Down Expand Up @@ -84,7 +87,7 @@ async fn handle_pipeline<D: Database>(
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
let req_body: proto::PipelineRequestBody = read_decode_request(req, encoding).await?;
let req_body: proto::PipelineReqBody = read_decode_request(req, encoding).await?;
let mut stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?;

let mut results = Vec::with_capacity(req_body.requests.len());
Expand All @@ -95,7 +98,7 @@ async fn handle_pipeline<D: Database>(
results.push(result);
}

let resp_body = proto::PipelineResponseBody {
let resp_body = proto::PipelineRespBody {
baton: stream_guard.release(),
base_url: server.self_url.clone(),
results,
Expand All @@ -104,13 +107,96 @@ async fn handle_pipeline<D: Database>(
}

async fn handle_cursor<D: Database>(
_server: &Server<D>,
_auth: Authenticated,
_req: hyper::Request<hyper::Body>,
_version: Version,
_encoding: Encoding,
server: &Server<D>,
auth: Authenticated,
req: hyper::Request<hyper::Body>,
version: Version,
encoding: Encoding,
) -> Result<hyper::Response<hyper::Body>> {
bail!("Cursor over HTTP not implemented")
let req_body: proto::CursorReqBody = read_decode_request(req, encoding).await?;
let stream_guard = stream::acquire(server, req_body.baton.as_deref()).await?;

let mut join_set = tokio::task::JoinSet::new();
let mut cursor_hnd = cursor::CursorHandle::spawn(&mut join_set);
let db = stream_guard.get_db_owned()?;
let sqls = stream_guard.sqls();
let pgm = batch::proto_batch_to_program(&req_body.batch, sqls, version)?;
cursor_hnd.open(db, auth, pgm);

let resp_body = proto::CursorRespBody {
baton: stream_guard.release(),
base_url: server.self_url.clone(),
};
let body = hyper::Body::wrap_stream(CursorStream {
resp_body: Some(resp_body),
join_set,
cursor_hnd,
encoding,
});
let content_type = match encoding {
Encoding::Json => "text/plain",
Encoding::Protobuf => "application/octet-stream",
};

Ok(hyper::Response::builder()
.status(hyper::StatusCode::OK)
.header(hyper::http::header::CONTENT_TYPE, content_type)
.body(body)
.unwrap())
}

struct CursorStream<D> {
resp_body: Option<proto::CursorRespBody>,
join_set: tokio::task::JoinSet<()>,
cursor_hnd: cursor::CursorHandle<D>,
encoding: Encoding,
}

impl<D> Stream for CursorStream<D> {
type Item = Result<Bytes>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut task::Context,
) -> task::Poll<Option<Result<Bytes>>> {
let this = self.get_mut();

if let Some(resp_body) = this.resp_body.take() {
let chunk = encode_stream_item(&resp_body, this.encoding);
return task::Poll::Ready(Some(Ok(chunk)));
}

match this.join_set.poll_join_next(cx) {
task::Poll::Pending => {}
task::Poll::Ready(Some(Ok(()))) => {}
task::Poll::Ready(Some(Err(err))) => panic!("Cursor task crashed: {}", err),
task::Poll::Ready(None) => {}
};

match this.cursor_hnd.poll_fetch(cx) {
task::Poll::Pending => task::Poll::Pending,
task::Poll::Ready(None) => task::Poll::Ready(None),
task::Poll::Ready(Some(Ok(entry))) => {
let chunk = encode_stream_item(&entry.entry, this.encoding);
task::Poll::Ready(Some(Ok(chunk)))
}
task::Poll::Ready(Some(Err(err))) => task::Poll::Ready(Some(Err(err))),
}
}
}

fn encode_stream_item<T: Serialize + prost::Message>(item: &T, encoding: Encoding) -> Bytes {
let mut data: Vec<u8>;
match encoding {
Encoding::Json => {
data = serde_json::to_vec(item).unwrap();
data.push(b'\n');
}
Encoding::Protobuf => {
data = <T as prost::Message>::encode_length_delimited_to_vec(item);
}
}
Bytes::from(data)
}

async fn read_decode_request<T: DeserializeOwned + prost::Message + Default>(
Expand Down Expand Up @@ -160,7 +246,6 @@ fn encode_response<T: Serialize + prost::Message>(
"application/x-protobuf",
),
};

hyper::Response::builder()
.status(status)
.header(hyper::http::header::CONTENT_TYPE, content_type)
Expand Down
20 changes: 18 additions & 2 deletions sqld/src/hrana/http/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ pub use super::super::proto::*;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, prost::Message)]
pub struct PipelineRequestBody {
pub struct PipelineReqBody {
#[prost(string, optional, tag = "1")]
pub baton: Option<String>,
#[prost(message, repeated, tag = "2")]
pub requests: Vec<StreamRequest>,
}

#[derive(Serialize, prost::Message)]
pub struct PipelineResponseBody {
pub struct PipelineRespBody {
#[prost(string, optional, tag = "1")]
pub baton: Option<String>,
#[prost(string, optional, tag = "2")]
Expand All @@ -34,6 +34,22 @@ pub enum StreamResult {
},
}

#[derive(Deserialize, prost::Message)]
pub struct CursorReqBody {
#[prost(string, optional, tag = "1")]
pub baton: Option<String>,
#[prost(message, required, tag = "2")]
pub batch: Batch,
}

#[derive(Serialize, prost::Message)]
pub struct CursorRespBody {
#[prost(string, optional, tag = "1")]
pub baton: Option<String>,
#[prost(string, optional, tag = "2")]
pub base_url: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamRequest {
Expand Down
2 changes: 1 addition & 1 deletion sqld/src/hrana/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::connection::Connection;

/// An error from executing a [`proto::StreamRequest`]
#[derive(thiserror::Error, Debug)]
pub enum StreamResponseError {
enum StreamResponseError {
#[error("The server already stores {count} SQL texts, it cannot store more")]
SqlTooMany { count: usize },
#[error(transparent)]
Expand Down
11 changes: 8 additions & 3 deletions sqld/src/hrana/http/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ enum Handle<D> {
struct Stream<D> {
/// The database connection that corresponds to this stream. This is `None` after the `"close"`
/// request was executed.
db: Option<D>,
db: Option<Arc<D>>,
/// The cache of SQL texts stored on the server with `"store_sql"` requests.
sqls: HashMap<i32, String>,
/// Stream id of this stream. The id is generated randomly (it should be unguessable).
Expand Down Expand Up @@ -155,7 +155,7 @@ pub async fn acquire<'srv, D: Connection>(

let mut state = server.stream_state.lock();
let stream = Box::new(Stream {
db: Some(db),
db: Some(Arc::new(db)),
sqls: HashMap::new(),
stream_id: gen_stream_id(&mut state),
// initializing the sequence number randomly makes it much harder to exploit
Expand All @@ -181,7 +181,12 @@ pub async fn acquire<'srv, D: Connection>(
impl<'srv, D: Connection> Guard<'srv, D> {
pub fn get_db(&self) -> Result<&D, ProtocolError> {
let stream = self.stream.as_ref().unwrap();
stream.db.as_ref().ok_or(ProtocolError::BatonStreamClosed)
stream.db.as_deref().ok_or(ProtocolError::BatonStreamClosed)
}

pub fn get_db_owned(&self) -> Result<Arc<D>, ProtocolError> {
let stream = self.stream.as_ref().unwrap();
stream.db.clone().ok_or(ProtocolError::BatonStreamClosed)
}

/// Closes the database connection. The next call to [`Guard::release()`] will then remove the
Expand Down

0 comments on commit 7b9cb67

Please sign in to comment.