diff --git a/Cargo.toml b/Cargo.toml
index 33fe54582..f1c66a57b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
 rust-version = "1.81.0"
 
 [profile.release]
-lto = true
+lto = false
 opt-level = 3
 
 [profile.dev]
diff --git a/nativelink-config/src/cas_server.rs b/nativelink-config/src/cas_server.rs
index 3061c4f44..3ddde236f 100644
--- a/nativelink-config/src/cas_server.rs
+++ b/nativelink-config/src/cas_server.rs
@@ -406,6 +406,10 @@ pub struct ServerConfig {
     /// Services to attach to server.
     pub services: Option<ServicesConfig>,
+
+    /// Do not wait for connections to close during a graceful shutdown.
+    #[serde(default)]
+    pub experimental_connections_dont_block_graceful_shutdown: bool,
 }
 
 #[allow(non_camel_case_types)]
diff --git a/nativelink-macro/src/lib.rs b/nativelink-macro/src/lib.rs
index f37175569..9d5d3a77e 100644
--- a/nativelink-macro/src/lib.rs
+++ b/nativelink-macro/src/lib.rs
@@ -34,6 +34,7 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {
         #[allow(clippy::disallowed_methods)]
         #[tokio::test(#attr)]
         async fn #fn_name(#fn_inputs) #fn_output {
+            nativelink_util::shutdown_manager::ShutdownManager::init(&tokio::runtime::Handle::current());
             // Error means already initialized, which is ok.
             let _ = nativelink_util::init_tracing();
             // If already set it's ok.
diff --git a/nativelink-service/src/worker_api_server.rs b/nativelink-service/src/worker_api_server.rs
index 2f875a2a4..9f37400fe 100644
--- a/nativelink-service/src/worker_api_server.rs
+++ b/nativelink-service/src/worker_api_server.rs
@@ -190,7 +190,7 @@ impl WorkerApiServer {
     ) -> Result<Response<()>, Error> {
         let worker_id: WorkerId = going_away_request.worker_id.try_into()?;
         self.scheduler
-            .remove_worker(&worker_id)
+            .set_drain_worker(&worker_id, true)
             .await
             .err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
         Ok(Response::new(()))
diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel
index ac17063f1..83e41f008 100644
--- a/nativelink-util/BUILD.bazel
+++ b/nativelink-util/BUILD.bazel
@@ -30,6 +30,7 @@ rust_library(
         "src/proto_stream_utils.rs",
         "src/resource_info.rs",
         "src/retry.rs",
+        "src/shutdown_manager.rs",
         "src/store_trait.rs",
         "src/task.rs",
         "src/tls_utils.rs",
diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs
index 17edbf700..9bf74f7bf 100644
--- a/nativelink-util/src/lib.rs
+++ b/nativelink-util/src/lib.rs
@@ -32,6 +32,7 @@ pub mod platform_properties;
 pub mod proto_stream_utils;
 pub mod resource_info;
 pub mod retry;
+pub mod shutdown_manager;
 pub mod store_trait;
 pub mod task;
 pub mod tls_utils;
diff --git a/nativelink-util/src/shutdown_manager.rs b/nativelink-util/src/shutdown_manager.rs
new file mode 100644
index 000000000..89935cd02
--- /dev/null
+++ b/nativelink-util/src/shutdown_manager.rs
@@ -0,0 +1,181 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::future::Future;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Weak};
+
+use futures::future::ready;
+use futures::FutureExt;
+use parking_lot::Mutex;
+use tokio::runtime::Handle;
+#[cfg(target_family = "unix")]
+use tokio::signal::unix::{signal, SignalKind};
+use tokio::sync::{broadcast, oneshot};
+use tracing::{event, Level};
+
+static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
+    is_shutting_down: AtomicBool::new(false),
+    shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
+    maybe_shutdown_weak_sender: Mutex::new(None),
+};
+
+/// Broadcast Channel Capacity
+/// Note: The actual capacity may be greater than the provided capacity.
+const BROADCAST_CAPACITY: usize = 1;
+
+/// ShutdownManager is a singleton that manages the shutdown of the
+/// application. Services can register to be notified when a graceful
+/// shutdown is initiated using [`ShutdownManager::wait_for_shutdown`].
+/// When the future returned by [`ShutdownManager::wait_for_shutdown`] is
+/// completed, the caller will then be handed a [`ShutdownGuard`] which
+/// must be held until the caller has completed its shutdown procedure.
+/// Once the caller has completed its shutdown procedure, the caller
+/// must drop the [`ShutdownGuard`]. When all [`ShutdownGuard`]s have
+/// been dropped, the application will then exit.
+pub struct ShutdownManager {
+    is_shutting_down: AtomicBool,
+    shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
+    maybe_shutdown_weak_sender: Mutex<Option<Weak<oneshot::Sender<()>>>>,
+}
+
+impl ShutdownManager {
+    #[allow(clippy::disallowed_methods)]
+    pub fn init(runtime: &Handle) {
+        let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
+        *SHUTDOWN_MANAGER.shutdown_tx.lock() = Some(shutdown_tx);
+
+        runtime.spawn(async move {
+            tokio::signal::ctrl_c()
+                .await
+                .expect("Failed to listen to SIGINT");
+            event!(Level::WARN, "User terminated process via SIGINT");
+            std::process::exit(130);
+        });
+
+        #[cfg(target_family = "unix")]
+        {
+            runtime.spawn(async move {
+                signal(SignalKind::terminate())
+                    .expect("Failed to listen to SIGTERM")
+                    .recv()
+                    .await;
+                event!(Level::WARN, "Received SIGTERM, beginning shutdown.");
+                Self::graceful_shutdown();
+            });
+        }
+    }
+
+    pub fn is_shutting_down() -> bool {
+        SHUTDOWN_MANAGER.is_shutting_down.load(Ordering::Acquire)
+    }
+
+    #[allow(clippy::disallowed_methods)]
+    fn graceful_shutdown() {
+        if SHUTDOWN_MANAGER
+            .is_shutting_down
+            .swap(true, Ordering::Release)
+        {
+            event!(Level::WARN, "Shutdown already in progress.");
+            return;
+        }
+        let (complete_tx, complete_rx) = oneshot::channel::<()>();
+        let shutdown_guard = Arc::new(complete_tx);
+        let previous_sender = SHUTDOWN_MANAGER
+            .maybe_shutdown_weak_sender
+            .lock()
+            .replace(Arc::downgrade(&shutdown_guard));
+        assert!(previous_sender.is_none(), "Expected maybe_shutdown_weak_sender to be empty");
+        tokio::spawn(async move {
+            {
+                let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
+                // No need to check result of send, since it will only fail if
+                // all receivers have been dropped, in which case it means we
+                // can safely shutdown.
+                let _ = shutdown_tx_lock
+                    .as_ref()
+                    .expect("ShutdownManager was never initialized")
+                    .send(shutdown_guard);
+            }
+            // It is impossible for the result to be anything but Err(RecvError),
+            // which means all receivers have been dropped and we can safely shutdown.
+            let _ = complete_rx.await;
+            event!(Level::WARN, "All services gracefully shutdown.",);
+            std::process::exit(143);
+        });
+    }
+
+    pub fn wait_for_shutdown(service_name: impl Into<String>) -> impl Future<Output = ShutdownGuard> + Send {
+        let service_name = service_name.into();
+        if Self::is_shutting_down() {
+            let maybe_shutdown_weak_sender_lock = SHUTDOWN_MANAGER
+                .maybe_shutdown_weak_sender
+                .lock();
+            let maybe_sender = maybe_shutdown_weak_sender_lock
+                .as_ref()
+                .expect("Expected maybe_shutdown_weak_sender to be set");
+            if let Some(sender) = maybe_sender.upgrade() {
+                event!(
+                    Level::INFO,
+                    "Service {service_name} has been notified of shutdown request"
+                );
+                return ready(ShutdownGuard {
+                    service_name,
+                    _maybe_guard: Some(sender),
+                }).left_future();
+            }
+            return ready(ShutdownGuard {
+                service_name,
+                _maybe_guard: None,
+            }).left_future();
+        }
+        let mut shutdown_receiver = SHUTDOWN_MANAGER
+            .shutdown_tx
+            .lock()
+            .as_ref()
+            .expect("ShutdownManager was never initialized")
+            .subscribe();
+        async move {
+            let sender = shutdown_receiver
+                .recv()
+                .await
+                .expect("Shutdown sender dropped. This should never happen.");
+            event!(
+                Level::INFO,
+                "Service {service_name} has been notified of shutdown request"
+            );
+            ShutdownGuard {
+                service_name,
+                _maybe_guard: Some(sender),
+            }
+        }
+        .right_future()
+    }
+}
+
+#[derive(Clone)]
+pub struct ShutdownGuard {
+    service_name: String,
+    _maybe_guard: Option<Arc<oneshot::Sender<()>>>,
+}
+
+impl Drop for ShutdownGuard {
+    fn drop(&mut self) {
+        event!(
+            Level::INFO,
+            "Service {} has completed shutdown.",
+            self.service_name
+        );
+    }
+}
diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs
index 8a7f8b895..af7b70544 100644
--- a/nativelink-worker/src/local_worker.rs
+++ b/nativelink-worker/src/local_worker.rs
@@ -36,10 +36,11 @@ use nativelink_util::common::fs;
 use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
 use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
 use nativelink_util::origin_context::ActiveOriginContext;
+use nativelink_util::shutdown_manager::ShutdownManager;
 use nativelink_util::store_trait::Store;
 use nativelink_util::{spawn, tls_utils};
 use tokio::process;
-use tokio::sync::{broadcast, mpsc, oneshot};
+use tokio::sync::mpsc;
 use tokio::time::sleep;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tonic::Streaming;
@@ -168,7 +169,6 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
     async fn run(
         &mut self,
         update_for_worker_stream: Streaming<UpdateForWorker>,
-        shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
     ) -> Result<(), Error> {
         // This big block of logic is designed to help simplify upstream components. Upstream
         // components can write standard futures that return a `Result<(), Error>` and this block
@@ -190,9 +190,22 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
 
         let mut update_for_worker_stream = update_for_worker_stream.fuse();
 
+        // If we are shutting down we need to hold onto the shutdown guard
+        // until we are done processing all the futures.
+        let mut _maybe_shutdown_guard = None;
+        let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown("LocalWorker").fuse();
+        tokio::pin!(wait_for_shutdown_fut);
         loop {
             select! {
                 maybe_update = update_for_worker_stream.next() => {
+                    if maybe_update.is_none() && ShutdownManager::is_shutting_down() {
+                        event!(
+                            Level::ERROR,
+                            "Closed stream",
+                        );
+                        // Happy shutdown path, no need to log anything.
+                        continue;
+                    }
                     match maybe_update
                         .err_tip(|| "UpdateForWorker stream closed early")?
                         .err_tip(|| "Got error in UpdateForWorker stream")?
@@ -349,22 +362,39 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                     let fut = res.err_tip(|| "New future stream receives should never be closed")?;
                     futures.push(fut);
                 },
-                res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
-                complete_msg = shutdown_rx.recv().fuse() => {
-                    event!(Level::WARN, "Worker loop reveiced shutdown signal. Shutting down worker...",);
+                res = futures.next() => {
+                    let res = res.err_tip(|| "Keep-alive should always be pending. Likely unable to send data to scheduler")?;
+                    // If we are shutting down and we get an error, we want to
+                    // keep draining, but not reconnect.
+                    if ShutdownManager::is_shutting_down() {
+                        // If we are shutting down and we only have keep alive left,
+                        // we can exit.
+                        if futures.len() == 1 {
+                            return Ok(());
+                        }
+                        if res.is_err() {
+                            event!(
+                                Level::ERROR,
+                                "During shutdown future failed with error: {:?}", res.unwrap_err(),
+                            );
+                            continue;
+                        }
+                    }
+                    // If we are not shutting down and get an error, return the error.
+                    res?;
+                },
+                shutdown_guard = wait_for_shutdown_fut.as_mut() => {
+                    _maybe_shutdown_guard = Some(shutdown_guard);
+                    event!(Level::INFO, "Worker loop received shutdown signal. Shutting down worker...",);
                     let mut grpc_client = self.grpc_client.clone();
                     let worker_id = self.worker_id.clone();
-                    let running_actions_manager = self.running_actions_manager.clone();
-                    let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
-                    let shutdown_future = async move {
+                    futures.push(async move {
                         if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
                             event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
                             return Err(e.into());
                         }
-                        running_actions_manager.complete_actions(complete_msg_clone).await;
                         Ok::<(), Error>(())
-                    };
-                    futures.push(shutdown_future.boxed());
+                    }.boxed());
                 },
             };
         }
@@ -526,10 +556,7 @@ impl LocalWorker {
     }
 
     #[instrument(skip(self), level = Level::INFO)]
-    pub async fn run(
-        mut self,
-        mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
-    ) -> Result<(), Error> {
+    pub async fn run(mut self) -> Result<(), Error> {
         let sleep_fn = self
             .sleep_fn
             .take()
@@ -575,7 +602,11 @@ impl LocalWorker {
         );
 
         // Now listen for connections and run all other services.
-        if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
+        let res = inner.run(update_for_worker_stream).await;
+        if ShutdownManager::is_shutting_down() {
+            return Ok(()); // Do not reconnect if we are shutting down.
+        }
+        if let Err(err) = res {
             'no_more_actions: {
                 // Ensure there are no actions in transit before we try to kill
                 // all our actions.
diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs
index b9c9c13aa..bb84687f8 100644
--- a/nativelink-worker/src/running_actions_manager.rs
+++ b/nativelink-worker/src/running_actions_manager.rs
@@ -1349,11 +1349,6 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
         hasher: DigestHasherFunc,
     ) -> impl Future<Output = Result<(), Error>> + Send;
 
-    fn complete_actions(
-        &self,
-        complete_msg: Arc<oneshot::Sender<()>>,
-    ) -> impl Future<Output = ()> + Send;
-
     fn kill_all(&self) -> impl Future<Output = ()> + Send;
 
     fn kill_operation(
@@ -1884,17 +1879,6 @@ impl RunningActionsManager for RunningActionsManagerImpl {
         Ok(())
     }
 
-    // Waits for all running actions to complete and signals completion.
-    // Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
-    // Dropping the sender automatically notifies the process to terminate.
-    async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
-        let _ = self
-            .action_done_tx
-            .subscribe()
-            .wait_for(|_| self.running_actions.lock().is_empty())
-            .await;
-    }
-
     // Note: When the future returns the process should be fully killed and cleaned up.
     async fn kill_all(&self) {
         self.metrics
diff --git a/nativelink-worker/tests/utils/local_worker_test_utils.rs b/nativelink-worker/tests/utils/local_worker_test_utils.rs
index 6eb349ef4..63d2cb9cc 100644
--- a/nativelink-worker/tests/utils/local_worker_test_utils.rs
+++ b/nativelink-worker/tests/utils/local_worker_test_utils.rs
@@ -28,7 +28,7 @@ use nativelink_util::spawn;
 use nativelink_util::task::JoinHandleDropGuard;
 use nativelink_worker::local_worker::LocalWorker;
 use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait;
-use tokio::sync::{broadcast, mpsc, oneshot};
+use tokio::sync::mpsc;
 use tonic::Status;
 use tonic::{
     codec::Codec, // Needed for .decoder().
@@ -40,10 +40,6 @@ use tonic::{
 
 use super::mock_running_actions_manager::MockRunningActionsManager;
 
-/// Broadcast Channel Capacity
-/// Note: The actual capacity may be greater than the provided capacity.
-const BROADCAST_CAPACITY: usize = 1;
-
 #[derive(Debug)]
 enum WorkerClientApiCalls {
     ConnectWorker(SupportedProperties),
@@ -198,11 +194,8 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
         }),
         Box::new(move |_| Box::pin(async move { /* No sleep */ })),
     );
-    let (shutdown_tx_test, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
-
-    let drop_guard = spawn!("local_worker_spawn", async move {
-        worker.run(shutdown_tx_test.subscribe()).await
-    });
+    let drop_guard = spawn!("local_worker_spawn", worker.run());
 
     let (tx_stream, streaming_response) = setup_grpc_stream();
     TestContext {
diff --git a/nativelink-worker/tests/utils/mock_running_actions_manager.rs b/nativelink-worker/tests/utils/mock_running_actions_manager.rs
index 542ebd93b..680e2f59b 100644
--- a/nativelink-worker/tests/utils/mock_running_actions_manager.rs
+++ b/nativelink-worker/tests/utils/mock_running_actions_manager.rs
@@ -12,18 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::future;
 use std::sync::Arc;
 
 use async_lock::Mutex;
-use futures::Future;
 use nativelink_error::{make_input_err, Error};
 use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::StartExecute;
 use nativelink_util::action_messages::{ActionResult, OperationId};
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::DigestHasherFunc;
 use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::mpsc;
 
 #[derive(Debug)]
 enum RunningActionManagerCalls {
@@ -167,13 +165,6 @@ impl RunningActionsManager for MockRunningActionsManager {
         Ok(())
     }
 
-    fn complete_actions(
-        &self,
-        _complete_msg: Arc<oneshot::Sender<()>>,
-    ) -> impl Future<Output = ()> + Send {
-        future::ready(())
-    }
-
     async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> {
         self.tx_kill_operation
             .send(operation_id.clone())
diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs
index 1c61b9f75..3c64944fe 100644
--- a/src/bin/nativelink.rs
+++ b/src/bin/nativelink.rs
@@ -21,6 +21,7 @@ use async_lock::Mutex as AsyncMutex;
 use axum::Router;
 use clap::Parser;
 use futures::future::{try_join_all, BoxFuture, Either, OptionFuture, TryFutureExt};
+use futures::FutureExt;
 use hyper::{Response, StatusCode};
 use hyper_util::rt::tokio::TokioIo;
 use hyper_util::server::conn::auto;
@@ -53,11 +54,12 @@ use nativelink_util::health_utils::HealthRegistryBuilder;
 use nativelink_util::metrics_utils::{set_metrics_enabled_for_this_thread, Counter};
 use nativelink_util::operation_state_manager::ClientStateManager;
 use nativelink_util::origin_context::OriginContext;
+use nativelink_util::shutdown_manager::ShutdownManager;
 use nativelink_util::store_trait::{
     set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
 };
 use nativelink_util::task::TaskExecutor;
-use nativelink_util::{background_spawn, init_tracing, spawn, spawn_blocking};
+use nativelink_util::{init_tracing, spawn, spawn_blocking};
 use nativelink_worker::local_worker::new_local_worker;
 use opentelemetry::metrics::MeterProvider;
 use opentelemetry_sdk::metrics::SdkMeterProvider;
@@ -67,9 +69,6 @@ use rustls_pemfile::{certs as extract_certs, crls as extract_crls};
 use scopeguard::guard;
 use tokio::net::TcpListener;
 use tokio::select;
-#[cfg(target_family = "unix")]
-use tokio::signal::unix::{signal, SignalKind};
-use tokio::sync::{broadcast, oneshot};
 use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer};
 use tokio_rustls::rustls::server::WebPkiClientVerifier;
 use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig};
@@ -94,10 +93,6 @@ const DEFAULT_HEALTH_STATUS_CHECK_PATH: &str = "/status";
 /// Name of environment variable to disable metrics.
 const METRICS_DISABLE_ENV: &str = "NATIVELINK_DISABLE_METRICS";
 
-/// Broadcast Channel Capacity
-/// Note: The actual capacity may be greater than the provided capacity.
-const BROADCAST_CAPACITY: usize = 1;
-
 /// Backend for bazel remote execution / cache API.
 #[derive(Parser, Debug)]
 #[clap(
@@ -134,7 +129,7 @@ struct RootMetrics {
 
 impl RootMetricsComponent for RootMetrics {}
 
 /// Wrapper to allow us to hash `SocketAddr` for metrics.
-#[derive(Hash, PartialEq, Eq)]
+#[derive(Debug, Hash, PartialEq, Eq)]
 struct SocketAddrWrapper(SocketAddr);
 
 impl MetricsComponent for SocketAddrWrapper {
@@ -166,7 +161,6 @@ impl RootMetricsComponent for ConnectedClientsMetrics {}
 async fn inner_main(
     cfg: CasConfig,
     server_start_timestamp: u64,
-    shutdown_tx: broadcast::Sender<Arc<oneshot::Sender<()>>>,
 ) -> Result<(), Box<dyn std::error::Error>> {
     let health_registry_builder =
         Arc::new(AsyncMutex::new(HealthRegistryBuilder::new("nativelink")));
@@ -243,9 +237,15 @@ async fn inner_main(
         schedulers: action_schedulers.clone(),
     }));
 
-    for (server_cfg, connected_clients_mux) in servers_and_clients {
+    for (i, (server_cfg, connected_clients_mux)) in servers_and_clients.into_iter().enumerate() {
         let services = server_cfg.services.ok_or("'services' must be configured")?;
 
+        let name = if server_cfg.name.is_empty() {
+            format!("{i}")
+        } else {
+            server_cfg.name.clone()
+        };
+
         // Currently we only support http as our socket type.
         let ListenerConfig::http(http_config) = server_cfg.listener;
 
@@ -776,10 +776,16 @@ async fn inner_main(
         if let Some(value) = http_config.experimental_http2_max_header_list_size {
             http.http2().max_header_list_size(value);
         }
-        event!(Level::WARN, "Ready, listening on {socket_addr}",);
+
+        event!(Level::WARN, "Ready, listening on {socket_addr}");
         root_futures.push(Box::pin(async move {
+            let shutdown_guard = Arc::new(Mutex::new(None));
+            let socket_name = format!("TcpSocketListener_{name}");
+            let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown(socket_name.clone()).fuse();
+            tokio::pin!(wait_for_shutdown_fut);
             loop {
                 select! {
+                    biased;
                     accept_result = tcp_listener.accept() => {
                         match accept_result {
                             Ok((tcp_stream, remote_addr)) => {
@@ -796,27 +802,52 @@ async fn inner_main(
                                     .insert(SocketAddrWrapper(remote_addr));
                                 connected_clients_mux.counter.inc();
 
+                                let shutdown_guard = shutdown_guard.clone();
+                                let socket_name_clone = socket_name.clone();
                                 // This is the safest way to guarantee that if our future
                                 // is ever dropped we will cleanup our data.
                                 let scope_guard = guard(
                                     Arc::downgrade(&connected_clients_mux),
                                     move |weak_connected_clients_mux| {
+                                        let socket_name = socket_name_clone;
                                         event!(
                                             target: "nativelink::services",
                                             Level::INFO,
                                             ?remote_addr,
                                             ?socket_addr,
+                                            socket_name,
                                             "Client disconnected"
                                         );
                                         if let Some(connected_clients_mux) = weak_connected_clients_mux.upgrade() {
-                                            connected_clients_mux
-                                                .inner
-                                                .lock()
-                                                .remove(&SocketAddrWrapper(remote_addr));
+                                            let mut connected_clients = connected_clients_mux.inner.lock();
+                                            connected_clients.remove(&SocketAddrWrapper(remote_addr));
+
+                                            if ShutdownManager::is_shutting_down() && !server_cfg.experimental_connections_dont_block_graceful_shutdown {
+                                                if connected_clients.is_empty() {
+                                                    event!(
+                                                        target: "nativelink::services",
+                                                        Level::INFO,
+                                                        ?remote_addr,
+                                                        ?socket_addr,
+                                                        socket_name,
+                                                        "No more clients connected & received shutdown signal."
+                                                    );
+                                                    drop(shutdown_guard.lock().take());
+                                                } else {
+                                                    event!(
+                                                        target: "nativelink::services",
+                                                        Level::INFO,
+                                                        socket_name,
+                                                        ?connected_clients,
+                                                        "Waiting on {} more clients to disconnect before shutting down.",
+                                                        connected_clients.len()
+                                                    );
+                                                }
+                                            }
                                         }
                                     },
                                 );
-
+                                let socket_name = socket_name.clone();
                                 let (http, svc, maybe_tls_acceptor) =
                                     (http.clone(), svc.clone(), maybe_tls_acceptor.clone());
                                 Arc::new(OriginContext::new()).background_spawn(
@@ -826,11 +857,7 @@ async fn inner_main(
                                         ?remote_addr,
                                         ?socket_addr
                                     ),
-                                    async move {},
-                                );
-                                background_spawn!(
-                                    name: "http_connection",
-                                    fut: async move {
+                                    async move {
                                         // Move it into our spawn, so if our spawn dies the cleanup happens.
                                         let _guard = scope_guard;
                                         let serve_connection = if let Some(tls_acceptor) = maybe_tls_acceptor {
@@ -850,19 +877,37 @@ async fn inner_main(
                                                 TowerToHyperService::new(svc),
                                             ))
                                         };
-
-                                        if let Err(err) = serve_connection.await {
-                                            event!(
-                                                target: "nativelink::services",
-                                                Level::ERROR,
-                                                ?err,
-                                                "Failed running service"
-                                            );
+                                        let connection_name = format!("Connection_{socket_name}_{remote_addr}");
+                                        let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown(connection_name.clone()).fuse();
+                                        tokio::pin!(wait_for_shutdown_fut);
+                                        tokio::pin!(serve_connection);
+                                        loop {
+                                            select! {
+                                                biased;
+                                                res = serve_connection.as_mut() => {
+                                                    if let Err(err) = res {
+                                                        event!(
+                                                            target: "nativelink::services",
+                                                            Level::ERROR,
+                                                            ?err,
+                                                            "Failed running service"
+                                                        );
+                                                    }
+                                                    break;
+                                                }
+                                                // Note: We don't need to hold onto this shutdown guard because
+                                                // we already have one captured in the outer scope.
+                                                _shutdown_guard = wait_for_shutdown_fut.as_mut() => {
+                                                    if !server_cfg.experimental_connections_dont_block_graceful_shutdown {
+                                                        match serve_connection.as_mut().as_pin_mut() {
+                                                            Either::Left(conn) => conn.graceful_shutdown(),
+                                                            Either::Right(conn) => conn.graceful_shutdown(),
+                                                        }
+                                                    }
+                                                },
+                                            }
                                         }
-                                    },
-                                    target: "nativelink::services",
-                                    ?remote_addr,
-                                    ?socket_addr,
+                                    }
                                 );
                             },
                             Err(err) => {
@@ -871,6 +916,23 @@
                             }
                         }
                     },
+                    inner_shutdown_guard = wait_for_shutdown_fut.as_mut() => {
+                        if server_cfg.experimental_connections_dont_block_graceful_shutdown {
+                            event!(
+                                target: "nativelink",
+                                Level::INFO,
+                                socket_name,
+                                "Connections will not block graceful shutdown"
+                            );
+                            continue;
+                        }
+                        let connected_clients = connected_clients_mux.inner.lock();
+                        if connected_clients.is_empty() {
+                            drop(shutdown_guard.lock().take());
+                        } else {
+                            *shutdown_guard.lock() = Some(inner_shutdown_guard);
+                        }
+                    }
                 }
             }
             // Unreachable
@@ -942,9 +1004,8 @@
                     }
                     worker_names.insert(name.clone());
                     worker_metrics.insert(name.clone(), metrics);
-                    let shutdown_rx = shutdown_tx.subscribe();
                     let fut = Arc::new(OriginContext::new())
-                        .wrap_async(trace_span!("worker_ctx"), local_worker.run(shutdown_rx));
+                        .wrap_async(trace_span!("worker_ctx"), local_worker.run());
                     spawn!("worker", fut, ?name)
                 }
             };
@@ -1037,41 +1098,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled))
         .build()?;
 
-    // Initiates the shutdown process by broadcasting the shutdown signal via the `oneshot::Sender` to all listeners.
-    // Each listener will perform its cleanup and then drop its `oneshot::Sender`, signaling completion.
-    // Once all `oneshot::Sender` instances are dropped, the worker knows it can safely terminate.
-    let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
-    let shutdown_tx_clone = shutdown_tx.clone();
-    let (complete_tx, complete_rx) = oneshot::channel::<()>();
-
-    runtime.spawn(async move {
-        tokio::signal::ctrl_c()
-            .await
-            .expect("Failed to listen to SIGINT");
-        eprintln!("User terminated process via SIGINT");
-        std::process::exit(130);
-    });
-
-    #[cfg(target_family = "unix")]
-    {
-        let complete_tx = Arc::new(complete_tx);
-        runtime.spawn(async move {
-            signal(SignalKind::terminate())
-                .expect("Failed to listen to SIGTERM")
-                .recv()
-                .await;
-            event!(Level::WARN, "Process terminated via SIGTERM",);
-            let _ = shutdown_tx_clone.send(complete_tx);
-            let _ = complete_rx.await;
-            event!(Level::WARN, "Successfully shut down nativelink.",);
-            std::process::exit(143);
-        });
-    }
+    ShutdownManager::init(runtime.handle());
 
-    let _ = runtime.block_on(Arc::new(OriginContext::new()).wrap_async(
-        trace_span!("main"),
-        inner_main(cfg, server_start_time, shutdown_tx),
-    ));
+    let _ = runtime.block_on(
+        Arc::new(OriginContext::new())
+            .wrap_async(trace_span!("main"), inner_main(cfg, server_start_time)),
+    );
     }
     Ok(())
 }
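
Below is a minimal sketch of how a service is expected to consume the new API in nativelink-util/src/shutdown_manager.rs, following the ShutdownGuard contract described in its doc comment: await ShutdownManager::wait_for_shutdown, finish in-flight work while holding the returned ShutdownGuard, then drop the guard so the process can exit. The service name "ExampleService" and the do_work/drain_in_flight_work helpers are illustrative placeholders, not part of this patch.

    use nativelink_util::shutdown_manager::ShutdownManager;

    // Hypothetical service loop; `do_work` and `drain_in_flight_work` stand in
    // for whatever the real service does.
    async fn example_service_loop() {
        let shutdown_fut = ShutdownManager::wait_for_shutdown("ExampleService");
        tokio::pin!(shutdown_fut);
        loop {
            tokio::select! {
                // A graceful shutdown was requested. The returned ShutdownGuard
                // keeps the process alive until it is dropped.
                guard = shutdown_fut.as_mut() => {
                    drain_in_flight_work().await;
                    // Dropping the last ShutdownGuard lets the process exit (143).
                    drop(guard);
                    return;
                }
                _ = do_work() => {}
            }
        }
    }

    async fn do_work() { /* placeholder */ }
    async fn drain_in_flight_work() { /* placeholder */ }

This mirrors what the LocalWorker loop and the TcpSocketListener/Connection loops in this patch do with their pinned wait_for_shutdown_fut futures.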