From da4600c4c7421b10f78486c46f4986fc49413ab5 Mon Sep 17 00:00:00 2001 From: Alex Yakunin Date: Wed, 22 Nov 2023 18:58:45 -0800 Subject: [PATCH] feat: add WebSocketOwner --- .../RpcWebSocketServer.cs | 10 ++-- src/Stl.Rpc.Server/RpcWebSocketServer.cs | 8 ++- src/Stl.Rpc/Clients/RpcWebSocketClient.cs | 58 ++++++++++--------- .../Configuration/RpcDefaultDelegates.cs | 5 +- src/Stl.Rpc/RpcBuilder.cs | 5 +- src/Stl.Rpc/WebSockets/WebSocketChannel.cs | 33 +++++------ src/Stl.Rpc/WebSockets/WebSocketOwner.cs | 46 +++++++++++++++ 7 files changed, 109 insertions(+), 56 deletions(-) create mode 100644 src/Stl.Rpc/WebSockets/WebSocketOwner.cs diff --git a/src/Stl.Rpc.Server.NetFx/RpcWebSocketServer.cs b/src/Stl.Rpc.Server.NetFx/RpcWebSocketServer.cs index 41acb16fb..bf2de98b9 100644 --- a/src/Stl.Rpc.Server.NetFx/RpcWebSocketServer.cs +++ b/src/Stl.Rpc.Server.NetFx/RpcWebSocketServer.cs @@ -67,14 +67,16 @@ public HttpStatusCode Invoke(IOwinContext context) [RequiresUnreferencedCode(UnreferencedCode.Serialization)] private async Task HandleWebSocket(IOwinContext context, WebSocketContext wsContext) { - var cancellationToken = context.Request.CallCancelled; - var webSocket = wsContext.WebSocket; - var peerRef = PeerRefFactory.Invoke(this, context); var peer = Hub.GetServerPeer(peerRef); + var cancellationToken = context.Request.CallCancelled; try { + var webSocket = wsContext.WebSocket; + var webSocketOwner = new WebSocketOwner(peer.Ref.Key, webSocket, Services); var channel = new WebSocketChannel( - Settings.WebSocketChannelOptions, webSocket, null, Services, cancellationToken); + Settings.WebSocketChannelOptions, webSocketOwner, cancellationToken) { + OwnsWebSocketOwner = false, + }; var options = ImmutableOptionSet.Empty.Set(context).Set(webSocket); var connection = await ServerConnectionFactory .Invoke(peer, channel, options, cancellationToken) diff --git a/src/Stl.Rpc.Server/RpcWebSocketServer.cs b/src/Stl.Rpc.Server/RpcWebSocketServer.cs index 6e4bdcc9e..fa02300b1 100644 --- a/src/Stl.Rpc.Server/RpcWebSocketServer.cs +++ b/src/Stl.Rpc.Server/RpcWebSocketServer.cs @@ -51,8 +51,11 @@ public async Task Invoke(HttpContext context) #endif var webSocket = await acceptWebSocketTask.ConfigureAwait(false); try { + var webSocketOwner = new WebSocketOwner(peer.Ref.Key, webSocket, Services); var channel = new WebSocketChannel( - Settings.WebSocketChannelOptions, webSocket, null, Services, cancellationToken); + Settings.WebSocketChannelOptions, webSocketOwner, cancellationToken) { + OwnsWebSocketOwner = false, + }; var options = ImmutableOptionSet.Empty.Set(context).Set(webSocket); var connection = await ServerConnectionFactory .Invoke(peer, channel, options, cancellationToken) @@ -64,5 +67,8 @@ public async Task Invoke(HttpContext context) catch (Exception e) when (e.IsCancellationOf(cancellationToken)) { // Intended: this is typically a normal connection termination } + finally { + webSocket.Dispose(); + } } } diff --git a/src/Stl.Rpc/Clients/RpcWebSocketClient.cs b/src/Stl.Rpc/Clients/RpcWebSocketClient.cs index b69615db8..669a9003b 100644 --- a/src/Stl.Rpc/Clients/RpcWebSocketClient.cs +++ b/src/Stl.Rpc/Clients/RpcWebSocketClient.cs @@ -2,25 +2,27 @@ using System.Net.WebSockets; using System.Text.Encodings.Web; using Stl.Rpc.Infrastructure; -using Stl.Rpc.Internal; using Stl.Rpc.WebSockets; namespace Stl.Rpc.Clients; public class RpcWebSocketClient( - RpcWebSocketClient.Options settings, - IServiceProvider services - ) : RpcClient(services) + RpcWebSocketClient.Options settings, + IServiceProvider services + ) : RpcClient(services) { public record Options { public static Options Default { get; set; } = new(); - public Func HostUrlResolver { get; init; } = DefaultHostUrlResolver; - public Func ConnectionUriResolver { get; init; } = DefaultConnectionUriResolver; - public WebSocketChannel.Options WebSocketChannelOptions { get; init; } = WebSocketChannel.Options.Default; - public Func> - WebSocketConnector { get; set; } = DefaultWebSocketConnector; + public Func HostUrlResolver { get; init; } + = DefaultHostUrlResolver; + public Func ConnectionUriResolver { get; init; } + = DefaultConnectionUriResolver; + public Func WebSocketOwnerFactory { get; init; } + = DefaultWebSocketOwnerFactory; + public WebSocketChannel.Options WebSocketChannelOptions { get; init; } + = WebSocketChannel.Options.Default; public TimeSpan ConnectTimeout { get; init; } = TimeSpan.FromSeconds(10); public string RequestPath { get; init; } = "/rpc/ws"; @@ -54,19 +56,8 @@ public static Uri DefaultConnectionUriResolver(RpcWebSocketClient client, RpcCli return uriBuilder.Uri; } - public static async Task<(WebSocket WebSocket, IDisposable? Helper)> DefaultWebSocketConnector( - RpcWebSocketClient client, Uri uri, CancellationToken cancellationToken) - { - var ws = client.Services.GetRequiredService(); - try { - await ws.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); - return (ws, null); - } - catch { - ws.Dispose(); - throw; - } - } + public static WebSocketOwner DefaultWebSocketOwnerFactory(RpcWebSocketClient client, RpcClientPeer peer) + => new(peer.Ref.Key, new ClientWebSocket(), client.Services); } public Options Settings { get; } = settings; @@ -79,13 +70,28 @@ public override async Task CreateConnection(RpcClientPeer peer, C var ctsToken = cts.Token; // ReSharper disable once UseAwaitUsing using var _ = cancellationToken.Register(static x => (x as CancellationTokenSource)?.Cancel(), cts); - var (webSocket, helper) = await Task - .Run(() => Settings.WebSocketConnector.Invoke(this, uri, ctsToken), ctsToken) + var webSocketOwner = await Task + .Run(async () => { + WebSocketOwner? o = null; + try { + o = Settings.WebSocketOwnerFactory.Invoke(this, peer); + await o.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); + return o; + } + catch { + if (o != null) + await o.DisposeAsync().ConfigureAwait(false); + throw; + } + }, ctsToken) .WaitAsync(ctsToken) // MAUI sometimes stuck in sync part of ConnectAsync .ConfigureAwait(false); - var channel = new WebSocketChannel(Settings.WebSocketChannelOptions, webSocket, helper, Services); - var options = ImmutableOptionSet.Empty.Set(uri).Set(webSocket); + var channel = new WebSocketChannel(Settings.WebSocketChannelOptions, webSocketOwner); + var options = ImmutableOptionSet.Empty + .Set(uri) + .Set(webSocketOwner) + .Set(webSocketOwner.WebSocket); return new RpcConnection(channel, options); } } diff --git a/src/Stl.Rpc/Configuration/RpcDefaultDelegates.cs b/src/Stl.Rpc/Configuration/RpcDefaultDelegates.cs index 54fb66f08..2c45591f4 100644 --- a/src/Stl.Rpc/Configuration/RpcDefaultDelegates.cs +++ b/src/Stl.Rpc/Configuration/RpcDefaultDelegates.cs @@ -1,9 +1,10 @@ -using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; using Stl.Generators; using Stl.Interception; -using Stl.Internal; +using Stl.Rpc.Clients; using Stl.Rpc.Diagnostics; using Stl.Rpc.Infrastructure; +using Stl.Rpc.WebSockets; namespace Stl.Rpc; diff --git a/src/Stl.Rpc/RpcBuilder.cs b/src/Stl.Rpc/RpcBuilder.cs index 4d024306b..2ec8a2af5 100644 --- a/src/Stl.Rpc/RpcBuilder.cs +++ b/src/Stl.Rpc/RpcBuilder.cs @@ -7,7 +7,6 @@ using Stl.Rpc.Diagnostics; using Stl.Rpc.Infrastructure; using Stl.Rpc.Internal; -using Stl.Rpc.Testing; using Errors = Stl.Rpc.Internal.Errors; namespace Stl.Rpc; @@ -73,11 +72,11 @@ internal RpcBuilder( services.TryAddSingleton(_ => RpcDefaultDelegates.CallRouter); services.TryAddSingleton(_ => RpcDefaultDelegates.InboundContextFactory); services.TryAddSingleton(_ => RpcDefaultDelegates.PeerFactory); + services.TryAddSingleton(_ => RpcDefaultDelegates.ClientConnectionFactory); + services.TryAddSingleton(_ => RpcDefaultDelegates.ServerConnectionFactory); services.TryAddSingleton(_ => RpcDefaultDelegates.ClientIdGenerator); services.TryAddSingleton(_ => RpcDefaultDelegates.BackendServiceDetector); services.TryAddSingleton(_ => RpcDefaultDelegates.UnrecoverableErrorDetector); - services.TryAddSingleton(_ => RpcDefaultDelegates.ClientConnectionFactory); - services.TryAddSingleton(_ => RpcDefaultDelegates.ServerConnectionFactory); services.TryAddSingleton(_ => RpcDefaultDelegates.MethodTracerFactory); services.TryAddSingleton(_ => RpcArgumentSerializer.Default); services.TryAddSingleton(c => new RpcInboundMiddlewares(c)); diff --git a/src/Stl.Rpc/WebSockets/WebSocketChannel.cs b/src/Stl.Rpc/WebSockets/WebSocketChannel.cs index fc5a0f960..02381f59f 100644 --- a/src/Stl.Rpc/WebSockets/WebSocketChannel.cs +++ b/src/Stl.Rpc/WebSockets/WebSocketChannel.cs @@ -16,7 +16,6 @@ public record Options { public static readonly Options Default = new(); - public bool OwnsWebSocket { get; init; } = true; public int WriteFrameSize { get; init; } = 4400; public int WriteBufferSize { get; init; } = 16_000; // Rented ~just once, so it can be large public int ReadBufferSize { get; init; } = 16_000; // Rented ~just once, so it can be large @@ -52,12 +51,13 @@ public record Options // ReSharper disable once InconsistentlySynchronizedField public Options Settings { get; } + public WebSocketOwner WebSocketOwner { get; } public WebSocket WebSocket { get; } - public IDisposable? WebSocketHelper { get; } public DualSerializer Serializer { get; } public CancellationToken StopToken { get; } public ILogger? Log { get; } public ILogger? ErrorLog { get; } + public bool OwnsWebSocketOwner { get; init; } = true; public Task WhenReadCompleted { get; } public Task WhenWriteCompleted { get; } @@ -65,26 +65,22 @@ public record Options [RequiresUnreferencedCode(UnreferencedCode.Serialization)] public WebSocketChannel( - WebSocket webSocket, - IDisposable? webSocketHelper, - IServiceProvider services, + WebSocketOwner webSocketOwner, CancellationToken cancellationToken = default) - : this(Options.Default, webSocket, webSocketHelper, services, cancellationToken) + : this(Options.Default, webSocketOwner, cancellationToken) { } [RequiresUnreferencedCode(UnreferencedCode.Serialization)] public WebSocketChannel( Options settings, - WebSocket webSocket, - IDisposable? webSocketHelper, - IServiceProvider services, + WebSocketOwner webSocketOwner, CancellationToken cancellationToken = default) { Settings = settings; - WebSocket = webSocket; - WebSocketHelper = webSocketHelper; + WebSocketOwner = webSocketOwner; + WebSocket = webSocketOwner.WebSocket; Serializer = settings.Serializer; - Log = services.LogFor(GetType()); + Log = webSocketOwner.Services.LogFor(GetType()); ErrorLog = Log.IfEnabled(LogLevel.Error); _stopCts = cancellationToken.CreateLinkedTokenSource(); @@ -135,13 +131,8 @@ public async ValueTask Close() stopCts.CancelAndDisposeSilently(); await WhenClosed.SilentAwait(false); - if (Settings.OwnsWebSocket) { - WebSocket.Dispose(); - if (WebSocketHelper is IAsyncDisposable ad) - await ad.DisposeAsync().SilentAwait(false); - else - WebSocketHelper?.Dispose(); - } + if (OwnsWebSocketOwner) + await WebSocketOwner.DisposeAsync().ConfigureAwait(false); _writeBuffer.Dispose(); } @@ -345,7 +336,9 @@ private bool TrySerialize(T value, ArrayPoolBuffer buffer) } catch (Exception e) { buffer.Index = startOffset; - ErrorLog?.LogError(e, "Couldn't serialize the value of type '{Type}'", value?.GetType().FullName ?? "null"); + ErrorLog?.LogError(e, + "Couldn't serialize the value of type '{Type}'", + value?.GetType().FullName ?? "null"); return false; } } diff --git a/src/Stl.Rpc/WebSockets/WebSocketOwner.cs b/src/Stl.Rpc/WebSockets/WebSocketOwner.cs new file mode 100644 index 000000000..949c75568 --- /dev/null +++ b/src/Stl.Rpc/WebSockets/WebSocketOwner.cs @@ -0,0 +1,46 @@ +using System.Net.WebSockets; +using Stl.Internal; + +namespace Stl.Rpc.WebSockets; + +public class WebSocketOwner( + string name, + WebSocket webSocket, + IServiceProvider services) + : SafeAsyncDisposableBase +{ + private ILogger? _log; + + public IServiceProvider Services { get; } = services; + public string Name { get; } = name; + public WebSocket WebSocket { get; } = webSocket; + public object? Handler { get; init; } + public LogLevel LogLevel { get; init; } = LogLevel.Information; + + protected ILogger Log => _log ??= Services.LogFor(GetType()); + + public virtual Task ConnectAsync(Uri uri, CancellationToken cancellationToken = default) + { + if (WebSocket is not ClientWebSocket webSocket) + throw Errors.MustBeAssignableTo(WebSocket.GetType()); + + Log.IfEnabled(LogLevel)?.Log(LogLevel, "'{Name}': connecting to {Uri}", Name, uri); +#if NET7_0_OR_GREATER + if (Handler is HttpMessageHandler handler) + return webSocket.ConnectAsync(uri, new HttpMessageInvoker(handler), cancellationToken); +#endif + return webSocket.ConnectAsync(uri, cancellationToken); + } + + protected override async Task DisposeAsync(bool disposing) + { + if (!disposing) + return; + + WebSocket.Dispose(); + if (Handler is IAsyncDisposable ad) + await ad.DisposeAsync().ConfigureAwait(false); + else if (Handler is IDisposable d) + d.Dispose(); + } +}