diff --git a/src/extensions/client/endpoint.rs b/src/extensions/client/endpoint.rs index 3a212d3..f266d29 100644 --- a/src/extensions/client/endpoint.rs +++ b/src/extensions/client/endpoint.rs @@ -20,8 +20,10 @@ pub struct Endpoint { url: String, health: Arc, client_rx: tokio::sync::watch::Receiver>>, + reconnect_tx: tokio::sync::mpsc::Sender<()>, on_client_ready: Arc, background_tasks: Vec>, + connect_counter: Arc, } impl Drop for Endpoint { @@ -35,15 +37,18 @@ impl Endpoint { url: String, request_timeout: Option, connection_timeout: Option, - health_config: HealthCheckConfig, + health_config: Option, ) -> Self { let (client_tx, client_rx) = tokio::sync::watch::channel(None); + let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(1); let on_client_ready = Arc::new(tokio::sync::Notify::new()); let health = Arc::new(Health::new(url.clone(), health_config)); + let connect_counter = Arc::new(AtomicU32::new(0)); let url_ = url.clone(); let health_ = health.clone(); let on_client_ready_ = on_client_ready.clone(); + let connect_counter_ = connect_counter.clone(); // This task will try to connect to the endpoint and keep the connection alive let connection_task = tokio::spawn(async move { @@ -51,6 +56,7 @@ impl Endpoint { loop { tracing::info!("Connecting endpoint: {url_}"); + connect_counter_.fetch_add(1, Ordering::Relaxed); let client = WsClientBuilder::default() .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) @@ -68,7 +74,15 @@ impl Endpoint { on_client_ready_.notify_waiters(); tracing::info!("Endpoint connected: {url_}"); connect_backoff_counter.store(0, Ordering::Relaxed); - client.on_disconnect().await; + + tokio::select! 
{ + _ = reconnect_rx.recv() => { + tracing::debug!("Endpoint reconnect requested: {url_}"); + }, + _ = client.on_disconnect() => { + tracing::debug!("Endpoint disconnected: {url_}"); + } + } } Err(err) => { health_.on_error(&err); @@ -88,8 +102,10 @@ impl Endpoint { url, health, client_rx, + reconnect_tx, on_client_ready, background_tasks: vec![connection_task, health_checker], + connect_counter, } } @@ -108,24 +124,34 @@ impl Endpoint { self.on_client_ready.notified().await; } + pub fn connect_counter(&self) -> u32 { + self.connect_counter.load(Ordering::Relaxed) + } + pub async fn request( &self, method: &str, params: Vec, timeout: Duration, ) -> Result { - let client = self - .client_rx - .borrow() - .clone() - .ok_or(errors::failed("client not connected"))?; - - match tokio::time::timeout(timeout, client.request(method, params.clone())).await { - Ok(Ok(response)) => Ok(response), - Ok(Err(err)) => { - self.health.on_error(&err); - Err(err) + match tokio::time::timeout(timeout, async { + self.connected().await; + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + match client.request(method, params.clone()).await { + Ok(resp) => Ok(resp), + Err(err) => { + self.health.on_error(&err); + Err(err) + } } + }) + .await + { + Ok(res) => res, Err(_) => { tracing::error!("request timed out method: {method} params: {params:?}"); self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); @@ -141,23 +167,27 @@ impl Endpoint { unsubscribe_method: &str, timeout: Duration, ) -> Result, jsonrpsee::core::Error> { - let client = self - .client_rx - .borrow() - .clone() - .ok_or(errors::failed("client not connected"))?; - - match tokio::time::timeout( - timeout, - client.subscribe(subscribe_method, params.clone(), unsubscribe_method), - ) + match tokio::time::timeout(timeout, async { + self.connected().await; + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + 
match client + .subscribe(subscribe_method, params.clone(), unsubscribe_method) + .await + { + Ok(resp) => Ok(resp), + Err(err) => { + self.health.on_error(&err); + Err(err) + } + } + }) .await { - Ok(Ok(response)) => Ok(response), - Ok(Err(err)) => { - self.health.on_error(&err); - Err(err) - } + Ok(res) => res, Err(_) => { tracing::error!("subscribe timed out subscribe: {subscribe_method} params: {params:?}"); self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); @@ -165,4 +195,9 @@ impl Endpoint { } } } + + pub async fn reconnect(&self) { + // notify the client to reconnect + self.reconnect_tx.send(()).await.unwrap(); + } } diff --git a/src/extensions/client/health.rs b/src/extensions/client/health.rs index 7dd5612..36793c7 100644 --- a/src/extensions/client/health.rs +++ b/src/extensions/client/health.rs @@ -35,7 +35,7 @@ impl Event { #[derive(Debug, Default)] pub struct Health { url: String, - config: HealthCheckConfig, + config: Option, score: AtomicU32, unhealthy: tokio::sync::Notify, } @@ -44,7 +44,7 @@ const MAX_SCORE: u32 = 100; const THRESHOLD: u32 = MAX_SCORE / 2; impl Health { - pub fn new(url: String, config: HealthCheckConfig) -> Self { + pub fn new(url: String, config: Option) -> Self { Self { url, config, @@ -104,18 +104,18 @@ impl Health { on_client_ready: Arc, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - // no health method - if health.config.health_method.is_none() { - return; - } + let config = match health.config { + Some(ref config) => config, + None => return, + }; // Wait for the client to be ready before starting the health check on_client_ready.notified().await; - let method_name = health.config.health_method.as_ref().expect("checked above"); - let health_response = health.config.response.clone(); - let interval = Duration::from_secs(health.config.interval_sec); - let healthy_response_time = Duration::from_millis(health.config.healthy_response_time_ms); + let method_name = 
config.health_method.as_ref().expect("Invalid health config"); + let health_response = config.response.clone(); + let interval = Duration::from_secs(config.interval_sec); + let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); let client = match client_rx_.borrow().clone() { Some(client) => client, diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 837c17f..569d3e8 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -209,7 +209,6 @@ impl Client { retries: Option, health_config: Option, ) -> Result { - let health_config = health_config.unwrap_or_default(); let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); if endpoints.is_empty() { @@ -240,34 +239,48 @@ impl Client { let rotation_notify = Arc::new(Notify::new()); let rotation_notify_bg = rotation_notify.clone(); - let endpoints_ = endpoints.clone(); + let endpoints2 = endpoints.clone(); + let has_health_method = health_config.is_some(); + + let mut current_endpoint_idx = 0; + let mut selected_endpoint = endpoints[0].clone(); let background_task = tokio::spawn(async move { let request_backoff_counter = Arc::new(AtomicU32::new(0)); - // Select next endpoint with the highest health score, excluding the current one if provided - let healthiest_endpoint = |exclude: Option>| async { + // Select next endpoint with the highest health score, excluding the current one if possible + let select_healtiest = |endpoints: Vec>, current_idx: usize| async move { if endpoints.len() == 1 { let selected_endpoint = endpoints[0].clone(); // Ensure it's connected selected_endpoint.connected().await; - return selected_endpoint; + return (selected_endpoint, 0); } - let mut endpoints = endpoints.clone(); - // Remove the current endpoint from the list - if let Some(exclude) = exclude { - endpoints.retain(|e| e.url() != exclude.url()); - } // wait for at least one endpoint to connect 
futures::future::select_all(endpoints.iter().map(|x| x.connected().boxed())).await; - // Sort by health score - endpoints.sort_by_key(|endpoint| std::cmp::Reverse(endpoint.health().score())); - // Pick the first one - endpoints[0].clone() + + let (idx, endpoint) = endpoints + .iter() + .enumerate() + .filter(|(idx, _)| *idx != current_idx) + .max_by_key(|(_, endpoint)| endpoint.health().score()) + .expect("No endpoints"); + (endpoint.clone(), idx) }; - let mut selected_endpoint = healthiest_endpoint(None).await; + let select_next = |endpoints: Vec>, current_idx: usize| async move { + let idx = (current_idx + 1) % endpoints.len(); + (endpoints[idx].clone(), idx) + }; + + let next_endpoint = |current_idx| { + if has_health_method { + select_healtiest(endpoints2.clone(), current_idx).boxed() + } else { + select_next(endpoints2.clone(), current_idx).boxed() + } + }; let handle_message = |message: Message, endpoint: Arc, rotation_notify: Arc| { let tx = message_tx_bg.clone(); @@ -422,10 +435,15 @@ impl Client { _ = selected_endpoint.health().unhealthy() => { // Current selected endpoint is unhealthy, try to rotate to another one. // In case of all endpoints are unhealthy, we don't want to keep rotating but stick with the healthiest one. 
- let new_selected_endpoint = healthiest_endpoint(None).await; - if new_selected_endpoint.url() != selected_endpoint.url() { + + // The ws client may be in a state that requires a reconnect + selected_endpoint.reconnect().await; + + let (new_selected_endpoint, new_current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; + if new_current_endpoint_idx != current_endpoint_idx { tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url()); selected_endpoint = new_selected_endpoint; + current_endpoint_idx = new_current_endpoint_idx; rotation_notify_bg.notify_waiters(); } } @@ -434,7 +452,7 @@ impl Client { match message { Some(Message::RotateEndpoint) => { tracing::info!("Rotating endpoint ..."); - selected_endpoint = healthiest_endpoint(Some(selected_endpoint.clone())).await; + (selected_endpoint, current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; rotation_notify_bg.notify_waiters(); } Some(message) => handle_message(message, selected_endpoint.clone(), rotation_notify_bg.clone()), @@ -449,7 +467,7 @@ impl Client { }); Ok(Self { - endpoints: endpoints_, + endpoints, sender: message_tx, rotation_notify, retries: retries.unwrap_or(3), diff --git a/src/extensions/client/tests.rs b/src/extensions/client/tests.rs index c76a75d..4a7f126 100644 --- a/src/extensions/client/tests.rs +++ b/src/extensions/client/tests.rs @@ -61,11 +61,17 @@ async fn multiple_endpoints() { let (addr1, handle1, rx1, _) = dummy_server().await; let (addr2, handle2, rx2, _) = dummy_server().await; let (addr3, handle3, rx3, _) = dummy_server().await; - let client = Client::with_endpoints([ - format!("ws://{addr1}"), - format!("ws://{addr2}"), - format!("ws://{addr3}"), - ]) + let client = Client::new( + [ + format!("ws://{addr1}"), + format!("ws://{addr2}"), + format!("ws://{addr3}"), + ], + None, + None, + None, + Some(Default::default()), + ) .unwrap(); let handle_requests = |mut rx: mpsc::Receiver, n: u32| { @@ -88,7 +94,7 @@ let result = client.request("mock_rpc", 
vec![22.into()]).await.unwrap(); - assert_eq!(result.to_string(), "2"); + assert_eq!(result.to_string(), "3"); client.rotate_endpoint().await; @@ -96,7 +102,7 @@ async fn multiple_endpoints() { let result = client.request("mock_rpc", vec![33.into()]).await.unwrap(); - assert_eq!(result.to_string(), "3"); + assert_eq!(result.to_string(), "2"); handle3.stop().unwrap(); @@ -123,20 +129,23 @@ async fn concurrent_requests() { let req2 = rx.recv().await.unwrap(); let req3 = rx.recv().await.unwrap(); - req1.respond(JsonValue::from_str("1").unwrap()); - req2.respond(JsonValue::from_str("2").unwrap()); - req3.respond(JsonValue::from_str("3").unwrap()); + let p1 = req1.params.clone(); + let p2 = req2.params.clone(); + let p3 = req3.params.clone(); + req1.respond(p1); + req2.respond(p2); + req3.respond(p3); }); - let res1 = client.request("mock_rpc", vec![]); - let res2 = client.request("mock_rpc", vec![]); - let res3 = client.request("mock_rpc", vec![]); + let res1 = client.request("mock_rpc", vec![json!(1)]); + let res2 = client.request("mock_rpc", vec![json!(2)]); + let res3 = client.request("mock_rpc", vec![json!(3)]); let res = tokio::join!(res1, res2, res3); - assert_eq!(res.0.unwrap().to_string(), "1"); - assert_eq!(res.1.unwrap().to_string(), "2"); - assert_eq!(res.2.unwrap().to_string(), "3"); + assert_eq!(res.0.unwrap(), json!([1])); + assert_eq!(res.1.unwrap(), json!([2])); + assert_eq!(res.2.unwrap(), json!([3])); handle.stop().unwrap(); task.await.unwrap(); @@ -290,3 +299,46 @@ async fn health_check_works() { handle1.stop().unwrap(); handle2.stop().unwrap(); } + +#[tokio::test] +async fn reconnect_on_disconnect() { + let (addr1, handle1, mut rx1, _) = dummy_server().await; + let (addr2, handle2, mut rx2, _) = dummy_server().await; + + let client = Client::new( + [format!("ws://{addr1}"), format!("ws://{addr2}")], + Some(Duration::from_millis(100)), + None, + Some(2), + None, + ) + .unwrap(); + + let h1 = tokio::spawn(async move { + let _req = 
rx1.recv().await.unwrap(); + // no response, let it timeout + tokio::time::sleep(Duration::from_millis(200)).await; + }); + + let h2 = tokio::spawn(async move { + let req = rx2.recv().await.unwrap(); + req.respond(json!(1)); + }); + + let h3 = tokio::spawn(async move { + let res = client.request("mock_rpc", vec![]).await; + assert_eq!(res.unwrap(), json!(1)); + + tokio::time::sleep(Duration::from_millis(2000)).await; + + assert_eq!(client.endpoints()[0].connect_counter(), 2); + assert_eq!(client.endpoints()[1].connect_counter(), 1); + }); + + h3.await.unwrap(); + h1.await.unwrap(); + h2.await.unwrap(); + + handle1.stop().unwrap(); + handle2.stop().unwrap(); +} diff --git a/validate_config.yml b/validate_config.yml new file mode 100644 index 0000000..31e6b56 --- /dev/null +++ b/validate_config.yml @@ -0,0 +1,53 @@ +extensions: + client: + endpoints: + - wss://acala-rpc.dwellir.com + - wss://acala-rpc.aca-api.network + - wss://acala-rpc.aca-staging.network + - wss://acala-rpc-node-2.aca-api.network + - wss://acala-internal-rpc.aca-api.network:8443 + health_check: + interval_sec: 30 # check interval, default is 10s + healthy_response_time_ms: 500 # max response time to be considered healthy, default is 500ms + health_method: system_health + response: # response contains { isSyncing: false } + !contains + - - isSyncing + - !eq false + event_bus: + substrate_api: + stale_timeout_seconds: 180 # rotate endpoint if no new blocks for 3 minutes + telemetry: + provider: none + cache: + default_ttl_seconds: 60 + default_size: 500 + merge_subscription: + keep_alive_seconds: 60 + server: + port: 9944 + listen_address: '0.0.0.0' + max_connections: 2000 + http_methods: + - path: /health + method: system_health + - path: /liveness + method: chain_getBlockHash + cors: all + validator: + ignore_methods: + - system_health + - system_name + - system_version + - author_pendingExtrinsics + +middlewares: + methods: + - inject_params + - validate + - upstream + subscriptions: + - 
merge_subscription + - upstream + +rpcs: substrate