Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

A few overload protection mechanisms #552

Merged
merged 7 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 82 additions & 8 deletions sqld/src/database/factory.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::{sync::Arc, time::Duration};
use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};

use futures::Future;
use tokio::{sync::Semaphore, time::timeout};
Expand All @@ -15,11 +21,16 @@ pub trait DbFactory: Send + Sync + 'static {

async fn create(&self) -> Result<Self::Db, Error>;

fn throttled(self, conccurency: usize, timeout: Option<Duration>) -> ThrottledDbFactory<Self>
fn throttled(
self,
conccurency: usize,
timeout: Option<Duration>,
max_total_response_size: u64,
) -> ThrottledDbFactory<Self>
where
Self: Sized,
{
ThrottledDbFactory::new(conccurency, self, timeout)
ThrottledDbFactory::new(conccurency, self, timeout, max_total_response_size)
}
}

Expand All @@ -38,34 +49,96 @@ where
}
}

/// Wraps an inner `DbFactory`, limiting how many connections may be created
/// concurrently via a semaphore and shedding load under memory pressure.
// NOTE(review): `AtomicUsize` does not implement `Clone`, so this derive
// cannot compile as written — and even a manual impl would give each clone an
// independent `waiters` counter instead of a shared limit. Presumably the
// derive was removed (or `waiters` wrapped in an `Arc`) in the final change;
// confirm against the merged code.
#[derive(Clone)]
pub struct ThrottledDbFactory<F> {
    // Concurrency budget; `create` acquires extra units under memory pressure.
    semaphore: Arc<Semaphore>,
    // Inner factory that actually creates database connections.
    factory: F,
    // Optional cap on how long `create` may wait for semaphore units.
    timeout: Option<Duration>,
    // Max memory available for responses. High memory pressure
    // will result in reducing concurrency to prevent out-of-memory errors.
    max_total_response_size: u64,
    // Number of callers currently inside `create`; used to reject new
    // requests outright once too many are queued.
    waiters: AtomicUsize,
}

impl<F> ThrottledDbFactory<F> {
fn new(conccurency: usize, factory: F, timeout: Option<Duration>) -> Self {
fn new(
conccurency: usize,
factory: F,
timeout: Option<Duration>,
max_total_response_size: u64,
) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(conccurency)),
factory,
timeout,
max_total_response_size,
waiters: AtomicUsize::new(0),
}
}

// How many units should be acquired from the semaphore,
// depending on current memory pressure.
fn units_to_take(&self) -> u32 {
let total_response_size = crate::query_result_builder::TOTAL_RESPONSE_SIZE
.load(std::sync::atomic::Ordering::Relaxed) as u64;
if total_response_size * 2 > self.max_total_response_size {
tracing::trace!("High memory pressure, reducing concurrency");
16
} else if total_response_size * 4 > self.max_total_response_size {
tracing::trace!("Medium memory pressure, reducing concurrency");
4
} else {
1
}
}
}

/// RAII guard tracking one in-flight waiter: increments the shared counter on
/// construction and decrements it again when dropped.
// `must_use`: binding the guard is what keeps the waiter counted; an unbound
// guard would be dropped immediately and silently undo the increment.
#[must_use = "the guard decrements the waiter count as soon as it is dropped"]
struct WaitersGuard<'a> {
    pub waiters: &'a AtomicUsize,
}

impl<'a> WaitersGuard<'a> {
    /// Registers a new waiter on `waiters` and returns the guard that will
    /// unregister it.
    fn new(waiters: &'a AtomicUsize) -> Self {
        waiters.fetch_add(1, Ordering::Relaxed);
        Self { waiters }
    }
}

impl Drop for WaitersGuard<'_> {
    fn drop(&mut self) {
        self.waiters.fetch_sub(1, Ordering::Relaxed);
    }
}

#[async_trait::async_trait]
impl<F: DbFactory> DbFactory for ThrottledDbFactory<F> {
type Db = TrackedDb<F::Db>;

async fn create(&self) -> Result<Self::Db, Error> {
let fut = self.semaphore.clone().acquire_owned();
let permit = match self.timeout {
// If the memory pressure is high, request more units to reduce concurrency.
let units = self.units_to_take();
let waiters_guard = WaitersGuard::new(&self.waiters);
if waiters_guard.waiters.load(Ordering::Relaxed) >= 128 {
return Err(Error::TooManyRequests);
}
let fut = self.semaphore.clone().acquire_many_owned(units);
let mut permit = match self.timeout {
Some(t) => timeout(t, fut).await.map_err(|_| Error::DbCreateTimeout)?,
None => fut.await,
}
.expect("semaphore closed");

let units = self.units_to_take();
if units > 1 {
tracing::debug!("Reacquiring {units} units due to high memory pressure");
let fut = self.semaphore.clone().acquire_many_owned(64);
let mem_permit = match self.timeout {
Some(t) => timeout(t, fut).await.map_err(|_| Error::DbCreateTimeout)?,
None => fut.await,
}
.expect("semaphore closed");
permit.merge(mem_permit);
}

let inner = self.factory.create().await?;
Ok(TrackedDb { permit, inner })
}
Expand Down Expand Up @@ -123,7 +196,8 @@ mod test {

#[tokio::test]
async fn throttle_db_creation() {
let factory = (|| async { Ok(DummyDb) }).throttled(10, Some(Duration::from_millis(100)));
let factory =
(|| async { Ok(DummyDb) }).throttled(10, Some(Duration::from_millis(100)), u64::MAX);

let mut conns = Vec::with_capacity(10);
for _ in 0..10 {
Expand Down
5 changes: 5 additions & 0 deletions sqld/src/database/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct LibSqlDbFactory<W: WalHook + 'static> {
config_store: Arc<DatabaseConfigStore>,
extensions: Vec<PathBuf>,
max_response_size: u64,
max_total_response_size: u64,
/// In wal mode, closing the last database takes time, and causes other databases creation to
/// return sqlite busy. To mitigate that, we hold on to one connection
_db: Option<LibSqlDb>,
Expand All @@ -45,6 +46,7 @@ where
W: WalHook + 'static + Sync + Send,
W::Context: Send + 'static,
{
#[allow(clippy::too_many_arguments)]
pub async fn new<F>(
db_path: PathBuf,
hook: &'static WalMethodsHook<W>,
Expand All @@ -53,6 +55,7 @@ where
config_store: Arc<DatabaseConfigStore>,
extensions: Vec<PathBuf>,
max_response_size: u64,
max_total_response_size: u64,
) -> Result<Self>
where
F: Fn() -> W::Context + Sync + Send + 'static,
Expand All @@ -65,6 +68,7 @@ where
config_store,
extensions,
max_response_size,
max_total_response_size,
_db: None,
};

Expand Down Expand Up @@ -113,6 +117,7 @@ where
self.config_store.clone(),
QueryBuilderConfig {
max_size: Some(self.max_response_size),
max_total_size: Some(self.max_total_response_size),
},
)
.await
Expand Down
4 changes: 4 additions & 0 deletions sqld/src/database/write_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct WriteProxyDbFactory {
config_store: Arc<DatabaseConfigStore>,
applied_frame_no_receiver: watch::Receiver<FrameNo>,
max_response_size: u64,
max_total_response_size: u64,
}

impl WriteProxyDbFactory {
Expand All @@ -48,6 +49,7 @@ impl WriteProxyDbFactory {
config_store: Arc<DatabaseConfigStore>,
applied_frame_no_receiver: watch::Receiver<FrameNo>,
max_response_size: u64,
max_total_response_size: u64,
) -> Self {
let client = ProxyClient::with_origin(channel, uri);
Self {
Expand All @@ -58,6 +60,7 @@ impl WriteProxyDbFactory {
config_store,
applied_frame_no_receiver,
max_response_size,
max_total_response_size,
}
}
}
Expand All @@ -75,6 +78,7 @@ impl DbFactory for WriteProxyDbFactory {
self.applied_frame_no_receiver.clone(),
QueryBuilderConfig {
max_size: Some(self.max_response_size),
max_total_size: Some(self.max_total_response_size),
},
)
.await?;
Expand Down
2 changes: 2 additions & 0 deletions sqld/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub enum Error {
Blocked(Option<String>),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error("Too many concurrent requests")]
TooManyRequests,
}

impl From<tokio::sync::oneshot::error::RecvError> for Error {
Expand Down
72 changes: 48 additions & 24 deletions sqld/src/hrana/result_builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::fmt::{self, Write as _};
use std::io;
use std::sync::atomic::Ordering;

use bytes::Bytes;
use rusqlite::types::ValueRef;

use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error};
use crate::query_result_builder::{
Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError,
Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE,
};

use super::proto;
Expand All @@ -21,6 +22,7 @@ pub struct SingleStatementBuilder {
last_insert_rowid: Option<i64>,
current_size: u64,
max_response_size: u64,
max_total_response_size: u64,
}

struct SizeFormatter(u64);
Expand Down Expand Up @@ -61,14 +63,42 @@ fn value_json_size(v: &ValueRef) -> u64 {
f.0
}

impl Drop for SingleStatementBuilder {
    fn drop(&mut self) {
        // Return this builder's accumulated bytes to the process-wide
        // response-size budget; `inc_current_size` added them one chunk at a
        // time as the response was built, so subtracting `current_size` here
        // keeps the global counter balanced on every exit path.
        TOTAL_RESPONSE_SIZE.fetch_sub(self.current_size as usize, Ordering::Relaxed);
    }
}

impl SingleStatementBuilder {
    /// Accounts `size` additional response bytes against both the
    /// per-response cap (`max_response_size`) and the process-wide budget
    /// (`max_total_response_size`), erroring when either would be exceeded.
    fn inc_current_size(&mut self, size: u64) -> Result<(), QueryResultBuilderError> {
        // Per-response cap is checked before committing anything.
        // NOTE(review): `current_size + size` could overflow u64 in theory —
        // a saturating add would be safer; confirm sizes stay far below u64::MAX.
        if self.current_size + size > self.max_response_size {
            return Err(QueryResultBuilderError::ResponseTooLarge(
                self.current_size + size,
            ));
        }

        // Commit locally first: `Drop` subtracts `current_size` from the
        // global counter, so the global `fetch_add` below stays balanced even
        // when we return an error afterwards.
        self.current_size += size;
        // `fetch_add` returns the previous value, so `total_size + size` is
        // the new global total after this addition.
        let total_size = TOTAL_RESPONSE_SIZE.fetch_add(size as usize, Ordering::Relaxed) as u64;
        if total_size + size > self.max_total_response_size {
            tracing::debug!(
                "Total responses exceeded threshold: {}/{}, aborting query",
                total_size + size,
                self.max_total_response_size
            );
            return Err(QueryResultBuilderError::ResponseTooLarge(total_size + size));
        }
        Ok(())
    }
}

impl QueryResultBuilder for SingleStatementBuilder {
type Ret = Result<proto::StmtResult, crate::error::Error>;

fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> {
*self = Self {
max_response_size: config.max_size.unwrap_or(u64::MAX),
..Default::default()
};
let _ = std::mem::take(self);

self.max_response_size = config.max_size.unwrap_or(u64::MAX);
self.max_total_response_size = config.max_total_size.unwrap_or(u64::MAX);

Ok(())
}
Expand Down Expand Up @@ -119,12 +149,7 @@ impl QueryResultBuilder for SingleStatementBuilder {
}
}));

self.current_size += cols_size;
if self.current_size > self.max_response_size {
return Err(QueryResultBuilderError::ResponseTooLarge(
self.max_response_size,
));
}
self.inc_current_size(cols_size)?;

Ok(())
}
Expand All @@ -150,7 +175,7 @@ impl QueryResultBuilder for SingleStatementBuilder {
));
}

self.current_size += estimate_size;
self.inc_current_size(estimate_size)?;

let val = match v {
ValueRef::Null => proto::Value::Null,
Expand Down Expand Up @@ -188,14 +213,14 @@ impl QueryResultBuilder for SingleStatementBuilder {
Ok(())
}

fn into_ret(self) -> Self::Ret {
match self.err {
fn into_ret(mut self) -> Self::Ret {
match std::mem::take(&mut self.err) {
Some(err) => Err(err),
None => Ok(proto::StmtResult {
cols: self.cols,
rows: self.rows,
affected_row_count: self.affected_row_count,
last_insert_rowid: self.last_insert_rowid,
cols: std::mem::take(&mut self.cols),
rows: std::mem::take(&mut self.rows),
affected_row_count: std::mem::take(&mut self.affected_row_count),
last_insert_rowid: std::mem::take(&mut self.last_insert_rowid),
}),
}
}
Expand Down Expand Up @@ -249,12 +274,11 @@ impl QueryResultBuilder for HranaBatchProtoBuilder {
.finish_step(affected_row_count, last_insert_rowid)?;
self.current_size += self.stmt_builder.current_size;

let new_builder = SingleStatementBuilder {
current_size: 0,
max_response_size: self.max_response_size - self.current_size,
..Default::default()
};
match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() {
let max_total_response_size = self.stmt_builder.max_total_response_size;
let previous_builder = std::mem::take(&mut self.stmt_builder);
self.stmt_builder.max_response_size = self.max_response_size - self.current_size;
self.stmt_builder.max_total_response_size = max_total_response_size;
match previous_builder.into_ret() {
Ok(res) => {
self.step_results.push((!self.step_empty).then_some(res));
self.step_errors.push(None);
Expand Down
Loading
Loading