diff --git a/bindings/rust/s2n-tls-tokio/tests/common/mod.rs b/bindings/rust/s2n-tls-tokio/tests/common/mod.rs index 8139941ab6b..b9fbf7832d8 100644 --- a/bindings/rust/s2n-tls-tokio/tests/common/mod.rs +++ b/bindings/rust/s2n-tls-tokio/tests/common/mod.rs @@ -95,3 +95,25 @@ where ); Ok((client?, server?)) } + +pub async fn get_tls_streams( + server_builder: A, + client_builder: B, +) -> Result< + ( + TlsStream, + TlsStream, + ), + Box, +> +where + ::Output: Unpin, + ::Output: Unpin, +{ + let (server_stream, client_stream) = get_streams().await?; + let connector = TlsConnector::new(client_builder); + let acceptor = TlsAcceptor::new(server_builder); + let (client_tls, server_tls) = + run_negotiate(&connector, client_stream, &acceptor, server_stream).await?; + Ok((server_tls, client_tls)) +} diff --git a/bindings/rust/s2n-tls-tokio/tests/tcp.rs b/bindings/rust/s2n-tls-tokio/tests/tcp.rs new file mode 100644 index 00000000000..05c2bb50e40 --- /dev/null +++ b/bindings/rust/s2n-tls-tokio/tests/tcp.rs @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +pub mod common; + +async fn assert_read_from_closed(mut reader: S, writer: S) +where + S: AsyncRead + AsyncWrite + Unpin, +{ + drop(writer); + let result = reader.read_u8().await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.kind() == std::io::ErrorKind::UnexpectedEof); +} + +#[tokio::test] +async fn match_tcp_read_from_closed() -> Result<(), Box> { + let (tcp_server, tcp_client) = common::get_streams().await?; + assert_read_from_closed(tcp_server, tcp_client).await; + + let (tls13_server, tls13_client) = common::get_tls_streams( + common::server_config()?.build()?, + common::client_config()?.build()?, + ) + .await?; + assert_read_from_closed(tls13_server, tls13_client).await; + + let (tls12_server, tls12_client) = common::get_tls_streams( + common::server_config_tls12()?.build()?, + common::client_config_tls12()?.build()?, + ) + .await?; + assert_read_from_closed(tls12_server, tls12_client).await; + Result::Ok(()) +} + +async fn assert_write_to_closed(reader: S, mut writer: S) +where + S: AsyncRead + AsyncWrite + Unpin, +{ + drop(reader); + let result = writer.write_u8(0).await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn match_tcp_write_to_closed() -> Result<(), Box> { + let (tcp_server, tcp_client) = common::get_streams().await?; + assert_write_to_closed(tcp_server, tcp_client).await; + + let (tls13_server, tls13_client) = common::get_tls_streams( + common::server_config()?.build()?, + common::client_config()?.build()?, + ) + .await?; + assert_write_to_closed(tls13_server, tls13_client).await; + + let (tls12_server, tls12_client) = common::get_tls_streams( + common::server_config_tls12()?.build()?, + common::client_config_tls12()?.build()?, + ) + .await?; + assert_write_to_closed(tls12_server, tls12_client).await; + Result::Ok(()) +} diff --git a/bindings/rust/s2n-tls/src/error.rs b/bindings/rust/s2n-tls/src/error.rs index 43ef766a310..9cdc731aa45 100644 --- a/bindings/rust/s2n-tls/src/error.rs +++ b/bindings/rust/s2n-tls/src/error.rs @@ -306,13 +306,19 @@ impl TryFrom for Error { impl From for std::io::Error { fn from(input: Error) -> Self { - if let Context::Code(_, errno) = input.0 { - if ErrorType::IOError == input.kind() { - let bare = std::io::Error::from_raw_os_error(errno.0); - return std::io::Error::new(bare.kind(), input); + let kind = match input.kind() { + ErrorType::IOError => { + if let Context::Code(_, errno) = input.0 { + let bare = std::io::Error::from_raw_os_error(errno.0); + bare.kind() + } else { + std::io::ErrorKind::Other + } } - } - std::io::Error::new(std::io::ErrorKind::Other, input) + ErrorType::ConnectionClosed => std::io::ErrorKind::UnexpectedEof, + _ => std::io::ErrorKind::Other, + }; + std::io::Error::new(kind, input) } }