Support graceful shutdown for network connections #1439

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
rust-version = "1.81.0"

[profile.release]
lto = true
lto = false
opt-level = 3

[profile.dev]
4 changes: 4 additions & 0 deletions nativelink-config/src/cas_server.rs
@@ -406,6 +406,10 @@ pub struct ServerConfig {

/// Services to attach to server.
pub services: Option<ServicesConfig>,

/// Do not wait for connections to close during a graceful shutdown.
#[serde(default)]
pub experimental_connections_dont_block_graceful_shutdown: bool,
}

#[allow(non_camel_case_types)]
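For reference, a minimal sketch (not part of this diff) of how a server shutdown path might consult the new flag. `on_shutdown` and `drain_connections` are hypothetical stand-ins; `ServerConfig` and `ShutdownManager` come from this PR:

use nativelink_config::cas_server::ServerConfig;
use nativelink_util::shutdown_manager::ShutdownManager;

// Hypothetical helper: wait for open client connections to close.
async fn drain_connections() {}

async fn on_shutdown(config: &ServerConfig) {
    // Resolves once a graceful shutdown has been initiated (e.g. SIGTERM).
    let guard = ShutdownManager::wait_for_shutdown("Server").await;
    if !config.experimental_connections_dont_block_graceful_shutdown {
        // Default behavior: hold the ShutdownGuard until connections drain.
        drain_connections().await;
    }
    // Dropping the guard tells the ShutdownManager this service is done.
    drop(guard);
}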
1 change: 1 addition & 0 deletions nativelink-macro/src/lib.rs
@@ -34,6 +34,7 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {
#[allow(clippy::disallowed_methods)]
#[tokio::test(#attr)]
async fn #fn_name(#fn_inputs) #fn_output {
nativelink_util::shutdown_manager::ShutdownManager::init(&tokio::runtime::Handle::current());
// Error means already initialized, which is ok.
let _ = nativelink_util::init_tracing();
// If already set it's ok.
2 changes: 1 addition & 1 deletion nativelink-service/src/worker_api_server.rs
@@ -190,7 +190,7 @@ impl WorkerApiServer {
) -> Result<Response<()>, Error> {
let worker_id: WorkerId = going_away_request.worker_id.try_into()?;
self.scheduler
.remove_worker(&worker_id)
.set_drain_worker(&worker_id, true)
.await
.err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
Ok(Response::new(()))
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
@@ -30,6 +30,7 @@ rust_library(
"src/proto_stream_utils.rs",
"src/resource_info.rs",
"src/retry.rs",
"src/shutdown_manager.rs",
"src/store_trait.rs",
"src/task.rs",
"src/tls_utils.rs",
1 change: 1 addition & 0 deletions nativelink-util/src/lib.rs
@@ -32,6 +32,7 @@ pub mod platform_properties;
pub mod proto_stream_utils;
pub mod resource_info;
pub mod retry;
pub mod shutdown_manager;
pub mod store_trait;
pub mod task;
pub mod tls_utils;
181 changes: 181 additions & 0 deletions nativelink-util/src/shutdown_manager.rs
@@ -0,0 +1,181 @@
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};

use futures::future::ready;
use futures::FutureExt;
use parking_lot::Mutex;
use tokio::runtime::Handle;
#[cfg(target_family = "unix")]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, oneshot};
use tracing::{event, Level};

static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
is_shutting_down: AtomicBool::new(false),
shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
maybe_shutdown_weak_sender: Mutex::new(None),
};

/// Broadcast Channel Capacity
/// Note: The actual capacity may be greater than the provided capacity.
const BROADCAST_CAPACITY: usize = 1;

/// ShutdownManager is a singleton that manages graceful shutdown of the
/// application. Services register to be notified when a graceful shutdown
/// is initiated using [`ShutdownManager::wait_for_shutdown`]. When the
/// returned future completes, the caller is handed a [`ShutdownGuard`],
/// which it must hold until its own shutdown procedure is finished and
/// then drop. Once all [`ShutdownGuard`]s have been dropped, the
/// application exits.
pub struct ShutdownManager {
is_shutting_down: AtomicBool,
shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
maybe_shutdown_weak_sender: Mutex<Option<Weak<oneshot::Sender<()>>>>,
}

impl ShutdownManager {
#[allow(clippy::disallowed_methods)]
pub fn init(runtime: &Handle) {
let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
*SHUTDOWN_MANAGER.shutdown_tx.lock() = Some(shutdown_tx);

runtime.spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen to SIGINT");
event!(Level::WARN, "User terminated process via SIGINT");
std::process::exit(130);
});

#[cfg(target_family = "unix")]
{
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;
event!(Level::WARN, "Received SIGTERM, begginning shutdown.");
Self::graceful_shutdown();
});
}
}

pub fn is_shutting_down() -> bool {
SHUTDOWN_MANAGER.is_shutting_down.load(Ordering::Acquire)
}

#[allow(clippy::disallowed_methods)]
fn graceful_shutdown() {
if SHUTDOWN_MANAGER
.is_shutting_down
.swap(true, Ordering::Release)
{
event!(Level::WARN, "Shutdown already in progress.");
return;
}
let (complete_tx, complete_rx) = oneshot::channel::<()>();
let shutdown_guard = Arc::new(complete_tx);
// `Option::replace` returns the previous value, which must be `None` here.
let prev = SHUTDOWN_MANAGER
    .maybe_shutdown_weak_sender
    .lock()
    .replace(Arc::downgrade(&shutdown_guard));
assert!(prev.is_none(), "Expected maybe_shutdown_weak_sender to be empty");
tokio::spawn(async move {
{
let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
// No need to check result of send, since it will only fail if
// all receivers have been dropped, in which case it means we
// can safely shutdown.
let _ = shutdown_tx_lock
.as_ref()
.expect("ShutdownManager was never initialized")
.send(shutdown_guard);
}
// It is impossible for the result to be anything but Err(RecvError),
// which means all receivers have been dropped and we can safely shutdown.
let _ = complete_rx.await;
event!(Level::WARN, "All services gracefully shutdown.",);
std::process::exit(143);
});
}

pub fn wait_for_shutdown(service_name: impl Into<String>) -> impl Future<Output = ShutdownGuard> + Send {
let service_name = service_name.into();
if Self::is_shutting_down() {
let maybe_shutdown_weak_sender_lock = SHUTDOWN_MANAGER
.maybe_shutdown_weak_sender
.lock();
let maybe_sender = maybe_shutdown_weak_sender_lock
.as_ref()
.expect("Expected maybe_shutdown_weak_sender to be set");
if let Some(sender) = maybe_sender.upgrade() {
event!(
Level::INFO,
"Service {service_name} has been notified of shutdown request"
);
return ready(ShutdownGuard {
service_name,
_maybe_guard: Some(sender),
}).left_future();
}
return ready(ShutdownGuard {
service_name,
_maybe_guard: None,
}).left_future();
}
let mut shutdown_receiver = SHUTDOWN_MANAGER
.shutdown_tx
.lock()
.as_ref()
.expect("ShutdownManager was never initialized")
.subscribe();
async move {
let sender = shutdown_receiver
.recv()
.await
.expect("Shutdown sender dropped. This should never happen.");
event!(
Level::INFO,
"Service {service_name} has been notified of shutdown request"
);
ShutdownGuard {
service_name,
_maybe_guard: Some(sender),
}
}
.right_future()
}
}

#[derive(Clone)]
pub struct ShutdownGuard {
service_name: String,
_maybe_guard: Option<Arc<oneshot::Sender<()>>>,
}

impl Drop for ShutdownGuard {
fn drop(&mut self) {
event!(
Level::INFO,
"Service {} has completed shutdown.",
self.service_name
);
}
}
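Taken together, a service participates in shutdown roughly as in the sketch below. `example_service` is illustrative only; it assumes `ShutdownManager::init` has already run at startup, as the test macro change above arranges for tests:

use nativelink_util::shutdown_manager::ShutdownManager;

// Illustrative service loop, assuming ShutdownManager::init(...) already ran.
async fn example_service() {
    // Suspends until a graceful shutdown begins (see SIGTERM handler above).
    let guard = ShutdownManager::wait_for_shutdown("ExampleService").await;
    // ... finish or cancel in-flight work here ...
    // Dropping the guard releases this service's clone of the shared
    // Arc<oneshot::Sender<()>>; when the last clone is dropped, complete_rx
    // resolves and the process exits with status 143.
    drop(guard);
}

Note the design: the oneshot sender is never actually sent on. It acts purely as a completion token, so `complete_rx` wakes with `RecvError` once the final `Arc` clone is dropped.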
63 changes: 47 additions & 16 deletions nativelink-worker/src/local_worker.rs
@@ -36,10 +36,11 @@ use nativelink_util::common::fs;
use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::origin_context::ActiveOriginContext;
use nativelink_util::shutdown_manager::ShutdownManager;
use nativelink_util::store_trait::Store;
use nativelink_util::{spawn, tls_utils};
use tokio::process;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::mpsc;
use tokio::time::sleep;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
@@ -168,7 +169,6 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
async fn run(
&mut self,
update_for_worker_stream: Streaming<UpdateForWorker>,
shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
// This big block of logic is designed to help simplify upstream components. Upstream
// components can write standard futures that return a `Result<(), Error>` and this block
@@ -190,9 +190,22 @@

let mut update_for_worker_stream = update_for_worker_stream.fuse();

// If we are shutting down we need to hold onto the shutdown guard
// until we are done processing all the futures.
let mut _maybe_shutdown_guard = None;
let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown("LocalWorker").fuse();
tokio::pin!(wait_for_shutdown_fut);
loop {
select! {
maybe_update = update_for_worker_stream.next() => {
if maybe_update.is_none() && ShutdownManager::is_shutting_down() {
// Happy shutdown path, no need to log anything.
continue;
}
match maybe_update
.err_tip(|| "UpdateForWorker stream closed early")?
.err_tip(|| "Got error in UpdateForWorker stream")?
@@ -349,22 +362,39 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let fut = res.err_tip(|| "New future stream receives should never be closed")?;
futures.push(fut);
},
res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
complete_msg = shutdown_rx.recv().fuse() => {
event!(Level::WARN, "Worker loop reveiced shutdown signal. Shutting down worker...",);
res = futures.next() => {
let res = res.err_tip(|| "Keep-alive should always be pending. Likely unable to send data to scheduler")?;
// If we are shutting down and we get an error, we want to
// keep draining, but not reconnect.
if ShutdownManager::is_shutting_down() {
// If we are shutting down and we only have keep alive left,
// we can exit.
if futures.len() == 1 {
return Ok(());
}
if res.is_err() {
event!(
Level::ERROR,
"During shutdown future failed with error: {:?}", res.unwrap_err(),
);
continue;
}
}
// If we are not shutting down and get an error, return the error.
res?;
},
shutdown_guard = wait_for_shutdown_fut.as_mut() => {
_maybe_shutdown_guard = Some(shutdown_guard);
event!(Level::INFO, "Worker loop reveiced shutdown signal. Shutting down worker...",);
let mut grpc_client = self.grpc_client.clone();
let worker_id = self.worker_id.clone();
let running_actions_manager = self.running_actions_manager.clone();
let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
let shutdown_future = async move {
futures.push(async move {
if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
return Err(e.into());
}
running_actions_manager.complete_actions(complete_msg_clone).await;
Ok::<(), Error>(())
};
futures.push(shutdown_future.boxed());
}.boxed());
},
};
}
@@ -526,10 +556,7 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
}

#[instrument(skip(self), level = Level::INFO)]
pub async fn run(
mut self,
mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
pub async fn run(mut self) -> Result<(), Error> {
let sleep_fn = self
.sleep_fn
.take()
@@ -575,7 +602,11 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
);

// Now listen for connections and run all other services.
if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
let res = inner.run(update_for_worker_stream).await;
if ShutdownManager::is_shutting_down() {
return Ok(()); // Do not reconnect if we are shutting down.
}
if let Err(err) = res {
'no_more_actions: {
// Ensure there are no actions in transit before we try to kill
// all our actions.
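The select! changes above boil down to the following pattern sketch (all names illustrative, not from the PR, and assuming the futures crate already used by this file): pin the fused shutdown future, stash the ShutdownGuard when it resolves, and drain the remaining futures before returning:

use futures::future::BoxFuture;
use futures::stream::{FuturesUnordered, StreamExt};
use futures::FutureExt;
use nativelink_util::shutdown_manager::{ShutdownGuard, ShutdownManager};

async fn drain_loop(mut work: FuturesUnordered<BoxFuture<'static, ()>>) {
    // Holding the guard keeps the process alive until this loop returns.
    let mut _maybe_guard: Option<ShutdownGuard> = None;
    let shutdown_fut = ShutdownManager::wait_for_shutdown("ExampleLoop").fuse();
    tokio::pin!(shutdown_fut);
    loop {
        futures::select! {
            _ = work.select_next_some() => {
                // After shutdown starts, exit once the queue is drained;
                // returning drops _maybe_guard and signals completion.
                if ShutdownManager::is_shutting_down() && work.is_empty() {
                    return;
                }
            },
            guard = shutdown_fut.as_mut() => _maybe_guard = Some(guard),
            complete => return,
        }
    }
}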
16 changes: 0 additions & 16 deletions nativelink-worker/src/running_actions_manager.rs
@@ -1349,11 +1349,6 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
hasher: DigestHasherFunc,
) -> impl Future<Output = Result<(), Error>> + Send;

fn complete_actions(
&self,
complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send;

fn kill_all(&self) -> impl Future<Output = ()> + Send;

fn kill_operation(
@@ -1884,17 +1879,6 @@ impl RunningActionsManager for RunningActionsManagerImpl {
Ok(())
}

// Waits for all running actions to complete and signals completion.
// Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
// Dropping the sender automatically notifies the process to terminate.
async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
let _ = self
.action_done_tx
.subscribe()
.wait_for(|_| self.running_actions.lock().is_empty())
.await;
}

// Note: When the future returns the process should be fully killed and cleaned up.
async fn kill_all(&self) {
self.metrics