GrpcStore Write Retry #638

Merged
6 changes: 3 additions & 3 deletions nativelink-service/src/bytestream_server.rs
@@ -396,11 +396,11 @@ impl ByteStreamServer {
             // by counting the number of bytes sent from the client. If they send
             // less than the amount they said they were going to send and then
             // close the stream, we know there's a problem.
-            Ok(None) => return Err(make_input_err!("Client closed stream before sending all data")),
+            None => return Err(make_input_err!("Client closed stream before sending all data")),
             // Code path for client stream error. Probably client disconnect.
-            Err(err) => return Err(err),
+            Some(Err(err)) => return Err(err),
             // Code path for received chunk of data.
-            Ok(Some(write_request)) => write_request,
+            Some(Ok(write_request)) => write_request,
         };

         if write_request.write_offset < 0 {
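These match arms change because the write-request stream is now consumed as a `Stream`, whose `next()` yields `Option<Result<WriteRequest, Error>>` instead of the previous `Result<Option<WriteRequest>, Error>`. A minimal, self-contained sketch of the new item shape (a plain `futures` iterator stream stands in for the NativeLink stream wrapper):

```rust
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Two queued items: one data chunk, then a stream error.
    let mut items = stream::iter(vec![Ok("chunk"), Err("client disconnect")]);
    loop {
        match items.next().await {
            // End-of-stream: the client closed before sending all data.
            None => break println!("stream closed"),
            // An error now arrives wrapped in Some(..).
            Some(Err(err)) => break println!("stream error: {err}"),
            // A successfully received chunk of data.
            Some(Ok(chunk)) => println!("got {chunk}"),
        }
    }
}
```

With this shape, a bare `None` unambiguously means the client closed the stream early, which is exactly the case the byte-counting check above guards against.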
2 changes: 1 addition & 1 deletion nativelink-service/tests/bytestream_server_test.rs
@@ -235,7 +235,7 @@ pub mod write_tests {
             // Now disconnect our stream.
             drop(tx);
             let (result, _bs_server) = join_handle.await?;
-            assert!(result.is_ok(), "Expected success to be returned");
+            result?;
         }
         {
             // Check to make sure our store recorded the data properly.
266 changes: 177 additions & 89 deletions nativelink-store/src/grpc_store.rs
@@ -15,6 +15,7 @@
 use std::marker::Send;
 use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 use std::time::Duration;

 use async_trait::async_trait;
@@ -35,6 +36,7 @@ use nativelink_proto::google::bytestream::{
 };
 use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
 use nativelink_util::common::DigestInfo;
+use nativelink_util::resource_info::ResourceInfo;
 use nativelink_util::retry::{Retrier, RetryResult};
 use nativelink_util::store_trait::{Store, UploadSizeInfo};
 use nativelink_util::tls_utils;
@@ -44,7 +46,6 @@ use prost::Message;
 use rand::rngs::OsRng;
 use rand::Rng;
 use tokio::time::sleep;
-use tonic::transport::Channel;
 use tonic::{transport, IntoRequest, Request, Response, Status, Streaming};
 use tracing::error;
 use uuid::Uuid;
@@ -88,6 +89,132 @@ impl Stream for FirstStream {
     }
 }

+/// This structure wraps all of the information required to perform a write
+/// request on the GrpcStore. It caches the most recent messages so that the
+/// write can resume after a failure, since the UUID in the resource name
+/// allows the server to resume an upload.
+struct WriteState<T, E>
+where
+    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
+    E: Into<Error> + 'static,
+{
+    instance_name: String,
+    read_stream_error: Option<Error>,
+    read_stream: WriteRequestStreamWrapper<T, E>,
+    // Tonic doesn't appear to report an error until it has taken two messages,
+    // so we are required to buffer the last two messages.
+    cached_messages: [Option<WriteRequest>; 2],
+    // When resuming after an error, the previous messages are cloned into this
+    // queue upfront to allow them to be served back.
+    resume_queue: [Option<WriteRequest>; 2],
+    // An optimisation to avoid having to manage resume_queue when it's empty.
+    is_resumed: bool,
+}
+
+impl<T, E> WriteState<T, E>
+where
+    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
+    E: Into<Error> + 'static,
+{
+    fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T, E>) -> Self {
+        Self {
+            instance_name,
+            read_stream_error: None,
+            read_stream,
+            cached_messages: [None, None],
+            resume_queue: [None, None],
+            is_resumed: false,
+        }
+    }
+
+    fn push_message(&mut self, message: WriteRequest) {
+        self.cached_messages.swap(0, 1);
+        self.cached_messages[0] = Some(message);
+    }
+
+    fn resumed_message(&mut self) -> Option<WriteRequest> {
+        if self.is_resumed {
+            // The resume_queue is a circular buffer that we have to shift;
+            // since it's only got two elements, it's a trivial swap.
+            self.resume_queue.swap(0, 1);
+            let message = self.resume_queue[0].take();
+            if message.is_none() {
+                self.is_resumed = false;
+            }
+            message
+        } else {
+            None
+        }
+    }
+
+    fn can_resume(&self) -> bool {
+        self.read_stream_error.is_none()
+            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
+    }
+
+    fn resume(&mut self) {
+        self.resume_queue = self.cached_messages.clone();
+        self.is_resumed = true;
+    }
+}
+
+/// A wrapper around WriteState to allow it to be reclaimed from the underlying
+/// write call in the case of failure.
+struct WriteStateWrapper<T, E>
+where
+    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
+    E: Into<Error> + 'static,
+{
+    shared_state: Arc<Mutex<WriteState<T, E>>>,
+}
+
+impl<T, E> Stream for WriteStateWrapper<T, E>
+where
+    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
+    E: Into<Error> + 'static,
+{
+    type Item = WriteRequest;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        // This should be an uncontended lock since write was called.
+        let mut local_state = self.shared_state.lock();
+        // If this is the first or second call after a failure and we have
+        // cached messages, then use the cached write requests.
+        let cached_message = local_state.resumed_message();
+        if cached_message.is_some() {
+            return Poll::Ready(cached_message);
+        }
+        // Read a new write request from the downstream.
+        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx) else {
+            return Poll::Pending;
+        };
+        // Update the instance name in the write request and forward it on.
+        const IS_UPLOAD_TRUE: bool = true;
+        let result = match maybe_message {
+            Some(Ok(mut message)) => match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
+                Ok(mut resource_name) => {
+                    if resource_name.instance_name != local_state.instance_name {
+                        resource_name.instance_name = &local_state.instance_name;
+                        message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
+                    }
+                    // Cache the last request in case there is an error, to
+                    // allow the upload to be resumed.
+                    local_state.push_message(message.clone());
+                    Some(message)
+                }
+                Err(err) => {
+                    error!("{err:?}");
+                    None
+                }
+            },
+            Some(Err(err)) => {
+                local_state.read_stream_error = Some(err);
+                None
+            }
+            None => None,
+        };
+        Poll::Ready(result)
+    }
+}

 impl GrpcStore {
     pub async fn new(config: &nativelink_config::stores::GrpcStore) -> Result<Self, Error> {
         let jitter_amt = config.retry.jitter;
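To make `WriteState`'s swap-based two-slot buffering concrete, here is a self-contained sketch of the same cache/replay logic, with `String` payloads standing in for `WriteRequest` (illustrative only; `ResumeBuffer` is not a NativeLink type):

```rust
struct ResumeBuffer {
    cached: [Option<String>; 2],
    queue: [Option<String>; 2],
    is_resumed: bool,
}

impl ResumeBuffer {
    // Mirrors WriteState::push_message: slot 0 holds the newest message.
    fn push(&mut self, msg: String) {
        self.cached.swap(0, 1);
        self.cached[0] = Some(msg);
    }
    // Mirrors WriteState::resume: clone the cache into the replay queue.
    fn resume(&mut self) {
        self.queue = self.cached.clone();
        self.is_resumed = true;
    }
    // Mirrors WriteState::resumed_message: replay cached messages, then
    // signal that the caller should fall back to the live stream.
    fn next_resumed(&mut self) -> Option<String> {
        if !self.is_resumed {
            return None;
        }
        self.queue.swap(0, 1);
        let msg = self.queue[0].take();
        if msg.is_none() {
            self.is_resumed = false;
        }
        msg
    }
}

fn main() {
    let mut buf = ResumeBuffer { cached: [None, None], queue: [None, None], is_resumed: false };
    buf.push("msg1".into());
    buf.push("msg2".into());
    buf.resume();
    // Replay preserves the original send order.
    assert_eq!(buf.next_resumed().as_deref(), Some("msg1"));
    assert_eq!(buf.next_resumed().as_deref(), Some("msg2"));
    assert_eq!(buf.next_resumed(), None);
}
```

Replaying oldest-first reproduces the order the server originally saw, which is what lets the UUID-addressed upload pick up where it left off.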
@@ -237,16 +364,12 @@ impl GrpcStore {
     }

     fn get_read_request(&self, mut request: ReadRequest) -> Result<ReadRequest, Error> {
-        // `resource_name` pattern is: "{instance_name}/blobs/{hash}/{size}".
-        let first_slash_pos = request
-            .resource_name
-            .find('/')
-            .err_tip(|| "Resource name expected to follow pattern {instance_name}/blobs/{hash}/{size}")?;
-        request.resource_name = format!(
-            "{}/{}",
-            self.instance_name,
-            request.resource_name.get((first_slash_pos + 1)..).unwrap()
-        );
+        const IS_UPLOAD_FALSE: bool = false;
+        let mut resource_info = ResourceInfo::new(&request.resource_name, IS_UPLOAD_FALSE)?;
+        if resource_info.instance_name != self.instance_name {
+            resource_info.instance_name = &self.instance_name;
+            request.resource_name = resource_info.to_string(IS_UPLOAD_FALSE);
+        }
         Ok(request)
     }
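To illustrate the new round-trip, the hypothetical helper below parses a download-style resource name, swaps the instance name, and re-serializes it. The helper name, sample values, and the `nativelink_error::Error` import are assumptions for this sketch; `ResourceInfo::new` and `to_string` are the APIs shown in this PR:

```rust
use nativelink_error::Error;
use nativelink_util::resource_info::ResourceInfo;

// Hypothetical helper doing the same rewrite as get_read_request above.
fn rewrite_instance<'a>(name: &'a str, instance: &'a str) -> Result<String, Error> {
    const IS_UPLOAD_FALSE: bool = false;
    let mut info = ResourceInfo::new(name, IS_UPLOAD_FALSE)?;
    if info.instance_name != instance {
        info.instance_name = instance;
    }
    Ok(info.to_string(IS_UPLOAD_FALSE))
}

// rewrite_instance("client/blobs/deadbeef/11", "server")
//   -> Ok("server/blobs/deadbeef/11")
```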

@@ -295,81 +418,49 @@ impl GrpcStore {
             "CAS operation on AC store"
         );

-        struct LocalState<T, E>
-        where
-            T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
-            E: Into<Error> + 'static,
-        {
-            instance_name: String,
-            error: Mutex<Option<Error>>,
-            read_stream: Mutex<Option<WriteRequestStreamWrapper<T, E>>>,
-            client: ByteStreamClient<Channel>,
-        }
-
-        let local_state = Arc::new(LocalState {
-            instance_name: self.instance_name.clone(),
-            error: Mutex::new(None),
-            read_stream: Mutex::new(Some(stream)),
-            client: self.bytestream_client.clone(),
-        });
+        let local_state = Arc::new(Mutex::new(WriteState::new(self.instance_name.clone(), stream)));

         let result = self
             .retrier
-            .retry(
-                unfold(local_state, move |local_state| async move {
-                    let stream = unfold((None, local_state.clone()), move |(stream, local_state)| async {
-                        // Only consume the stream on the first request to read,
-                        // then pass it for future requests in the unfold.
-                        let mut stream = stream.or_else(|| local_state.read_stream.lock().take())?;
-                        let maybe_message = stream.next().await;
-                        if let Ok(maybe_message) = maybe_message {
-                            if let Some(mut message) = maybe_message {
-                                // `resource_name` pattern is: "{instance_name}/uploads/{uuid}/blobs/{hash}/{size}".
-                                let first_slash_pos = match message.resource_name.find('/') {
-                                    Some(pos) => pos,
-                                    None => {
-                                        error!("{}", "Resource name should follow pattern {instance_name}/uploads/{uuid}/blobs/{hash}/{size}");
-                                        return None;
-                                    }
-                                };
-                                message.resource_name = format!(
-                                    "{}/{}",
-                                    &local_state.instance_name,
-                                    message.resource_name.get((first_slash_pos + 1)..).unwrap()
-                                );
-                                return Some((message, (Some(stream), local_state)));
-                            }
-                            return None;
-                        }
-                        // TODO(allada) I'm sure there's a way to do this without a mutex, but rust can be super
-                        // picky with borrowing through a stream await.
-                        *local_state.error.lock() = Some(maybe_message.unwrap_err());
-                        None
-                    });
-
-                    let result = local_state.client.clone()
-                        .write(stream)
-                        .await
-                        .err_tip(|| "in GrpcStore::write");
-
-                    // If the stream has been consumed, don't retry, but
-                    // otherwise it's ok to try again.
-                    let result = if local_state.read_stream.lock().is_some() {
-                        result.map_or_else(RetryResult::Retry, RetryResult::Ok)
-                    } else {
-                        result.map_or_else(RetryResult::Err, RetryResult::Ok)
-                    };
-
-                    // If there was an error with the stream, then don't retry.
-                    let result = if let Some(err) = local_state.error.lock().take() {
-                        RetryResult::Err(err)
-                    } else {
-                        result
-                    };
-
-                    Some((result, local_state))
-                }),
-            )
+            .retry(unfold(local_state, move |local_state| async move {
+                let mut client = self.bytestream_client.clone();
+                // The client write may occur on a separate thread, so to share
+                // the state with it we have to wrap it in a Mutex and retrieve
+                // it after the write has completed. There is no way to get the
+                // value back from the client.
+                let result = client
+                    .write(WriteStateWrapper {
+                        shared_state: local_state.clone(),
+                    })
+                    .await;
+
+                // Get the state back from the wrapper; this should be
+                // uncontended since write has returned.
+                let mut local_state_locked = local_state.lock();
+
+                let result = if let Some(err) = local_state_locked.read_stream_error.take() {
+                    // If there was an error with the stream, then don't retry.
+                    RetryResult::Err(err)
+                } else {
+                    // On error, determine whether it is possible to retry.
+                    match result.err_tip(|| "in GrpcStore::write") {
+                        Err(err) => {
+                            if local_state_locked.can_resume() {
+                                local_state_locked.resume();
+                                RetryResult::Retry(err)
+                            } else {
+                                RetryResult::Err(err.append("Retry is not possible"))
+                            }
+                        }
+                        Ok(response) => RetryResult::Ok(response),
+                    }
+                };
+
+                drop(local_state_locked);
+                Some((result, local_state))
+            }))
             .await?;
         Ok(result)
     }
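The retry loop above threads `local_state` through `futures::stream::unfold`, yielding one `RetryResult` per attempt for the retrier to poll. A self-contained toy version of that shape (this `RetryResult` enum and driver loop are stand-ins for `nativelink_util::retry`, with backoff sleeping omitted):

```rust
use std::sync::{Arc, Mutex};

use futures::stream::{unfold, StreamExt};

// Stand-in for nativelink_util::retry::RetryResult.
#[allow(dead_code)]
enum RetryResult<T, E> {
    Ok(T),
    Retry(E),
    Err(E),
}

#[tokio::main]
async fn main() {
    let attempts_left = Arc::new(Mutex::new(2u32));
    // Each unfold iteration is one upload attempt; the shared state is
    // handed back at the end, feeding the next attempt, like local_state above.
    let mut attempts = Box::pin(unfold(attempts_left, |state| async move {
        let result = {
            let mut left = state.lock().unwrap();
            if *left > 0 {
                *left -= 1;
                RetryResult::Retry("transient failure")
            } else {
                RetryResult::Ok("upload complete")
            }
        };
        Some((result, state))
    }));

    // Toy driver standing in for Retrier::retry.
    loop {
        match attempts.next().await.unwrap() {
            RetryResult::Ok(v) => break println!("{v}"),
            RetryResult::Retry(e) => println!("retrying after: {e}"),
            RetryResult::Err(e) => break println!("permanent failure: {e}"),
        }
    }
}
```

Because the `WriteState` travels with the stream rather than living inside the attempt, a failed `write` call leaves the cached messages available for the next attempt.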
@@ -385,15 +476,12 @@ impl GrpcStore {

         let mut request = grpc_request.into_inner();

-        // `resource_name` pattern is: "{instance_name}/uploads/{uuid}/blobs/{hash}/{size}".
-        let first_slash_pos = request.resource_name.find('/').err_tip(|| {
-            "Resource name expected to follow pattern {instance_name}/uploads/{uuid}/blobs/{hash}/{size}"
-        })?;
-        request.resource_name = format!(
-            "{}/{}",
-            self.instance_name,
-            request.resource_name.get((first_slash_pos + 1)..).unwrap()
-        );
+        const IS_UPLOAD_TRUE: bool = true;
+        let mut request_info = ResourceInfo::new(&request.resource_name, IS_UPLOAD_TRUE)?;
+        if request_info.instance_name != self.instance_name {
+            request_info.instance_name = &self.instance_name;
+            request.resource_name = request_info.to_string(IS_UPLOAD_TRUE);
+        }

         self.perform_request(request, |request| async move {
             self.bytestream_client
24 changes: 23 additions & 1 deletion nativelink-util/src/resource_info.rs
@@ -84,6 +84,7 @@ pub struct ResourceInfo<'a> {
     pub compressor: Option<&'a str>,
     pub digest_function: Option<&'a str>,
     pub hash: &'a str,
+    size: &'a str,
     pub expected_size: usize,
     pub optional_metadata: Option<&'a str>,
 }
@@ -129,6 +130,25 @@ impl<'a> ResourceInfo<'a> {
         }
         Ok(output)
     }
+
+    pub fn to_string(&self, is_upload: bool) -> String {
+        [
+            Some(self.instance_name),
+            is_upload.then_some("uploads"),
+            self.uuid,
+            Some(self.compressor.map_or("blobs", |_| "compressed-blobs")),
+            self.compressor,
+            self.digest_function,
+            Some(self.hash),
+            Some(self.size),
+            self.optional_metadata,
+        ]
+        .into_iter()
+        .flatten()
+        .filter(|part| !part.is_empty())
+        .collect::<Vec<&str>>()
+        .join("/")
+    }
 }
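The `to_string` above is an array join in which `None` slots and empty segments drop out, and the presence of a compressor flips `blobs` to `compressed-blobs`. A runnable sketch of just that joining behaviour (all segment values are made up):

```rust
// Same flatten/filter/join shape as ResourceInfo::to_string.
fn join_parts(parts: &[Option<&str>]) -> String {
    parts
        .iter()
        .copied()
        .flatten()
        .filter(|part| !part.is_empty())
        .collect::<Vec<&str>>()
        .join("/")
}

fn main() {
    // Upload with compression:
    // {instance}/uploads/{uuid}/compressed-blobs/{compressor}/{hash}/{size}
    let upload = join_parts(&[
        Some("main"), Some("uploads"), Some("123e4567"),
        Some("compressed-blobs"), Some("zstd"), None,
        Some("deadbeef"), Some("64"), None,
    ]);
    assert_eq!(upload, "main/uploads/123e4567/compressed-blobs/zstd/deadbeef/64");

    // Plain download: {instance}/blobs/{hash}/{size}
    let download = join_parts(&[
        Some("main"), None, None, Some("blobs"), None, None,
        Some("deadbeef"), Some("64"), None,
    ]);
    assert_eq!(download, "main/blobs/deadbeef/64");
}
```

Keeping the original `size` string (rather than re-printing `expected_size`) means a parse/format round-trip reproduces the resource name byte-for-byte.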

#[derive(Debug, PartialEq)]
@@ -177,8 +197,9 @@ fn recursive_parse<'a>(
                 output.compressor = Some(part);
                 *bytes_processed += part.len() + SLASH_SIZE;
                 return Ok(state);
+            } else {
+                return Err(make_input_err!("Expected compressor, got {part}"));
             }
-            continue;
         }
         State::DigestFunction => {
             state = State::Hash;
@@ -196,6 +217,7 @@ fn recursive_parse<'a>(
                 return Ok(State::Size);
             }
             State::Size => {
+                output.size = part;
                 output.expected_size = part
                     .parse::<usize>()
                     .map_err(|_| make_input_err!("Digest size_bytes was not convertible to usize. Got: {}", part))?;