Skip to content

Commit

Permalink
feat: add WebSocketOwner
Browse files Browse the repository at this point in the history
  • Loading branch information
alexyakunin committed Nov 23, 2023
1 parent 1b8eb03 commit da4600c
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 56 deletions.
10 changes: 6 additions & 4 deletions src/Stl.Rpc.Server.NetFx/RpcWebSocketServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,16 @@ public HttpStatusCode Invoke(IOwinContext context)
[RequiresUnreferencedCode(UnreferencedCode.Serialization)]
private async Task HandleWebSocket(IOwinContext context, WebSocketContext wsContext)
{
var cancellationToken = context.Request.CallCancelled;
var webSocket = wsContext.WebSocket;

var peerRef = PeerRefFactory.Invoke(this, context);
var peer = Hub.GetServerPeer(peerRef);
var cancellationToken = context.Request.CallCancelled;
try {
var webSocket = wsContext.WebSocket;
var webSocketOwner = new WebSocketOwner(peer.Ref.Key, webSocket, Services);
var channel = new WebSocketChannel<RpcMessage>(
Settings.WebSocketChannelOptions, webSocket, null, Services, cancellationToken);
Settings.WebSocketChannelOptions, webSocketOwner, cancellationToken) {
OwnsWebSocketOwner = false,
};
var options = ImmutableOptionSet.Empty.Set(context).Set(webSocket);
var connection = await ServerConnectionFactory
.Invoke(peer, channel, options, cancellationToken)
Expand Down
8 changes: 7 additions & 1 deletion src/Stl.Rpc.Server/RpcWebSocketServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ public async Task Invoke(HttpContext context)
#endif
var webSocket = await acceptWebSocketTask.ConfigureAwait(false);
try {
var webSocketOwner = new WebSocketOwner(peer.Ref.Key, webSocket, Services);
var channel = new WebSocketChannel<RpcMessage>(
Settings.WebSocketChannelOptions, webSocket, null, Services, cancellationToken);
Settings.WebSocketChannelOptions, webSocketOwner, cancellationToken) {
OwnsWebSocketOwner = false,
};
var options = ImmutableOptionSet.Empty.Set(context).Set(webSocket);
var connection = await ServerConnectionFactory
.Invoke(peer, channel, options, cancellationToken)
Expand All @@ -64,5 +67,8 @@ public async Task Invoke(HttpContext context)
catch (Exception e) when (e.IsCancellationOf(cancellationToken)) {
// Intended: this is typically a normal connection termination
}
finally {
webSocket.Dispose();
}
}
}
58 changes: 32 additions & 26 deletions src/Stl.Rpc/Clients/RpcWebSocketClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
using System.Net.WebSockets;
using System.Text.Encodings.Web;
using Stl.Rpc.Infrastructure;
using Stl.Rpc.Internal;
using Stl.Rpc.WebSockets;

namespace Stl.Rpc.Clients;

public class RpcWebSocketClient(
RpcWebSocketClient.Options settings,
IServiceProvider services
) : RpcClient(services)
RpcWebSocketClient.Options settings,
IServiceProvider services
) : RpcClient(services)
{
public record Options
{
public static Options Default { get; set; } = new();

public Func<RpcWebSocketClient, RpcClientPeer, string> HostUrlResolver { get; init; } = DefaultHostUrlResolver;
public Func<RpcWebSocketClient, RpcClientPeer, Uri> ConnectionUriResolver { get; init; } = DefaultConnectionUriResolver;
public WebSocketChannel<RpcMessage>.Options WebSocketChannelOptions { get; init; } = WebSocketChannel<RpcMessage>.Options.Default;
public Func<RpcWebSocketClient, Uri, CancellationToken, Task<(WebSocket WebSocket, IDisposable? Helper)>>
WebSocketConnector { get; set; } = DefaultWebSocketConnector;
public Func<RpcWebSocketClient, RpcClientPeer, string> HostUrlResolver { get; init; }
= DefaultHostUrlResolver;
public Func<RpcWebSocketClient, RpcClientPeer, Uri> ConnectionUriResolver { get; init; }
= DefaultConnectionUriResolver;
public Func<RpcWebSocketClient, RpcClientPeer, WebSocketOwner> WebSocketOwnerFactory { get; init; }
= DefaultWebSocketOwnerFactory;
public WebSocketChannel<RpcMessage>.Options WebSocketChannelOptions { get; init; }
= WebSocketChannel<RpcMessage>.Options.Default;

public TimeSpan ConnectTimeout { get; init; } = TimeSpan.FromSeconds(10);
public string RequestPath { get; init; } = "/rpc/ws";
Expand Down Expand Up @@ -54,19 +56,8 @@ public static Uri DefaultConnectionUriResolver(RpcWebSocketClient client, RpcCli
return uriBuilder.Uri;
}

public static async Task<(WebSocket WebSocket, IDisposable? Helper)> DefaultWebSocketConnector(
RpcWebSocketClient client, Uri uri, CancellationToken cancellationToken)
{
var ws = client.Services.GetRequiredService<ClientWebSocket>();
try {
await ws.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);
return (ws, null);
}
catch {
ws.Dispose();
throw;
}
}
public static WebSocketOwner DefaultWebSocketOwnerFactory(RpcWebSocketClient client, RpcClientPeer peer)
=> new(peer.Ref.Key, new ClientWebSocket(), client.Services);
}

public Options Settings { get; } = settings;
Expand All @@ -79,13 +70,28 @@ public override async Task<RpcConnection> CreateConnection(RpcClientPeer peer, C
var ctsToken = cts.Token;
// ReSharper disable once UseAwaitUsing
using var _ = cancellationToken.Register(static x => (x as CancellationTokenSource)?.Cancel(), cts);
var (webSocket, helper) = await Task
.Run(() => Settings.WebSocketConnector.Invoke(this, uri, ctsToken), ctsToken)
var webSocketOwner = await Task
.Run(async () => {
WebSocketOwner? o = null;
try {
o = Settings.WebSocketOwnerFactory.Invoke(this, peer);
await o.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);
return o;
}
catch {
if (o != null)
await o.DisposeAsync().ConfigureAwait(false);
throw;
}
}, ctsToken)
.WaitAsync(ctsToken) // MAUI sometimes stuck in sync part of ConnectAsync
.ConfigureAwait(false);

var channel = new WebSocketChannel<RpcMessage>(Settings.WebSocketChannelOptions, webSocket, helper, Services);
var options = ImmutableOptionSet.Empty.Set(uri).Set(webSocket);
var channel = new WebSocketChannel<RpcMessage>(Settings.WebSocketChannelOptions, webSocketOwner);
var options = ImmutableOptionSet.Empty
.Set(uri)
.Set(webSocketOwner)
.Set(webSocketOwner.WebSocket);
return new RpcConnection(channel, options);
}
}
5 changes: 3 additions & 2 deletions src/Stl.Rpc/Configuration/RpcDefaultDelegates.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets;
using Stl.Generators;
using Stl.Interception;
using Stl.Internal;
using Stl.Rpc.Clients;
using Stl.Rpc.Diagnostics;
using Stl.Rpc.Infrastructure;
using Stl.Rpc.WebSockets;

namespace Stl.Rpc;

Expand Down
5 changes: 2 additions & 3 deletions src/Stl.Rpc/RpcBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Stl.Rpc.Diagnostics;
using Stl.Rpc.Infrastructure;
using Stl.Rpc.Internal;
using Stl.Rpc.Testing;
using Errors = Stl.Rpc.Internal.Errors;

namespace Stl.Rpc;
Expand Down Expand Up @@ -73,11 +72,11 @@ internal RpcBuilder(
services.TryAddSingleton(_ => RpcDefaultDelegates.CallRouter);
services.TryAddSingleton(_ => RpcDefaultDelegates.InboundContextFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.PeerFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.ClientConnectionFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.ServerConnectionFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.ClientIdGenerator);
services.TryAddSingleton(_ => RpcDefaultDelegates.BackendServiceDetector);
services.TryAddSingleton(_ => RpcDefaultDelegates.UnrecoverableErrorDetector);
services.TryAddSingleton(_ => RpcDefaultDelegates.ClientConnectionFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.ServerConnectionFactory);
services.TryAddSingleton(_ => RpcDefaultDelegates.MethodTracerFactory);
services.TryAddSingleton(_ => RpcArgumentSerializer.Default);
services.TryAddSingleton(c => new RpcInboundMiddlewares(c));
Expand Down
33 changes: 13 additions & 20 deletions src/Stl.Rpc/WebSockets/WebSocketChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ public record Options
{
public static readonly Options Default = new();

public bool OwnsWebSocket { get; init; } = true;
public int WriteFrameSize { get; init; } = 4400;
public int WriteBufferSize { get; init; } = 16_000; // Rented ~just once, so it can be large
public int ReadBufferSize { get; init; } = 16_000; // Rented ~just once, so it can be large
Expand Down Expand Up @@ -52,39 +51,36 @@ public record Options
// ReSharper disable once InconsistentlySynchronizedField

public Options Settings { get; }
public WebSocketOwner WebSocketOwner { get; }
public WebSocket WebSocket { get; }
public IDisposable? WebSocketHelper { get; }
public DualSerializer<T> Serializer { get; }
public CancellationToken StopToken { get; }
public ILogger? Log { get; }
public ILogger? ErrorLog { get; }
public bool OwnsWebSocketOwner { get; init; } = true;

public Task WhenReadCompleted { get; }
public Task WhenWriteCompleted { get; }
public Task WhenClosed { get; }

[RequiresUnreferencedCode(UnreferencedCode.Serialization)]
public WebSocketChannel(
WebSocket webSocket,
IDisposable? webSocketHelper,
IServiceProvider services,
WebSocketOwner webSocketOwner,
CancellationToken cancellationToken = default)
: this(Options.Default, webSocket, webSocketHelper, services, cancellationToken)
: this(Options.Default, webSocketOwner, cancellationToken)
{ }

[RequiresUnreferencedCode(UnreferencedCode.Serialization)]
public WebSocketChannel(
Options settings,
WebSocket webSocket,
IDisposable? webSocketHelper,
IServiceProvider services,
WebSocketOwner webSocketOwner,
CancellationToken cancellationToken = default)
{
Settings = settings;
WebSocket = webSocket;
WebSocketHelper = webSocketHelper;
WebSocketOwner = webSocketOwner;
WebSocket = webSocketOwner.WebSocket;
Serializer = settings.Serializer;
Log = services.LogFor(GetType());
Log = webSocketOwner.Services.LogFor(GetType());
ErrorLog = Log.IfEnabled(LogLevel.Error);

_stopCts = cancellationToken.CreateLinkedTokenSource();
Expand Down Expand Up @@ -135,13 +131,8 @@ public async ValueTask Close()

stopCts.CancelAndDisposeSilently();
await WhenClosed.SilentAwait(false);
if (Settings.OwnsWebSocket) {
WebSocket.Dispose();
if (WebSocketHelper is IAsyncDisposable ad)
await ad.DisposeAsync().SilentAwait(false);
else
WebSocketHelper?.Dispose();
}
if (OwnsWebSocketOwner)
await WebSocketOwner.DisposeAsync().ConfigureAwait(false);
_writeBuffer.Dispose();
}

Expand Down Expand Up @@ -345,7 +336,9 @@ private bool TrySerialize(T value, ArrayPoolBuffer<byte> buffer)
}
catch (Exception e) {
buffer.Index = startOffset;
ErrorLog?.LogError(e, "Couldn't serialize the value of type '{Type}'", value?.GetType().FullName ?? "null");
ErrorLog?.LogError(e,
"Couldn't serialize the value of type '{Type}'",
value?.GetType().FullName ?? "null");
return false;
}
}
Expand Down
46 changes: 46 additions & 0 deletions src/Stl.Rpc/WebSockets/WebSocketOwner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System.Net.WebSockets;
using Stl.Internal;

namespace Stl.Rpc.WebSockets;

public class WebSocketOwner(
string name,
WebSocket webSocket,
IServiceProvider services)
: SafeAsyncDisposableBase
{
private ILogger? _log;

public IServiceProvider Services { get; } = services;
public string Name { get; } = name;
public WebSocket WebSocket { get; } = webSocket;
public object? Handler { get; init; }
public LogLevel LogLevel { get; init; } = LogLevel.Information;

protected ILogger Log => _log ??= Services.LogFor(GetType());

public virtual Task ConnectAsync(Uri uri, CancellationToken cancellationToken = default)
{
if (WebSocket is not ClientWebSocket webSocket)
throw Errors.MustBeAssignableTo<ClientWebSocket>(WebSocket.GetType());

Log.IfEnabled(LogLevel)?.Log(LogLevel, "'{Name}': connecting to {Uri}", Name, uri);
#if NET7_0_OR_GREATER
if (Handler is HttpMessageHandler handler)
return webSocket.ConnectAsync(uri, new HttpMessageInvoker(handler), cancellationToken);
#endif
return webSocket.ConnectAsync(uri, cancellationToken);
}

protected override async Task DisposeAsync(bool disposing)
{
if (!disposing)
return;

WebSocket.Dispose();
if (Handler is IAsyncDisposable ad)
await ad.DisposeAsync().ConfigureAwait(false);
else if (Handler is IDisposable d)
d.Dispose();
}
}

0 comments on commit da4600c

Please sign in to comment.