From 5ec3bbb4988e053602de585c89c304db8545151f Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Wed, 16 Oct 2024 18:38:26 +0000 Subject: [PATCH] Allow adding `ClientInterceptor`s to specific services and methods --- .../Call/Client/ClientInterceptor.swift | 9 +- .../ClientInterceptorPipelineOperation.swift | 99 +++++++++++ .../Server/Internal/ServerRPCExecutor.swift | 3 +- .../ServerInterceptorPipelineOperation.swift | 3 +- Sources/GRPCCore/GRPCClient.swift | 49 ++++- Tests/GRPCCoreTests/GRPCClientTests.swift | 168 +++++++++++++++++- 6 files changed, 313 insertions(+), 18 deletions(-) create mode 100644 Sources/GRPCCore/Call/Client/ClientInterceptorPipelineOperation.swift diff --git a/Sources/GRPCCore/Call/Client/ClientInterceptor.swift b/Sources/GRPCCore/Call/Client/ClientInterceptor.swift index 939461e54..68a1fcf45 100644 --- a/Sources/GRPCCore/Call/Client/ClientInterceptor.swift +++ b/Sources/GRPCCore/Call/Client/ClientInterceptor.swift @@ -21,10 +21,11 @@ /// received from the transport. They are typically used for cross-cutting concerns like injecting /// metadata, validating messages, logging additional data, and tracing. /// -/// Interceptors are registered with a client and apply to all RPCs. If you need to modify the -/// behavior of an interceptor on a per-RPC basis then you can use the -/// ``ClientContext/descriptor`` to determine which RPC is being called and -/// conditionalise behavior accordingly. +/// Interceptors are registered with the server via ``ClientInterceptorPipelineOperation``s. +/// You may register them for all services registered with a server, for RPCs directed to specific services, or +/// for RPCs directed to specific methods. If you need to modify the behavior of an interceptor on a +/// per-RPC basis in more detail, then you can use the ``ClientContext/descriptor`` to determine +/// which RPC is being called and conditionalise behavior accordingly. /// /// - TODO: Update example and documentation to show how to register an interceptor. /// diff --git a/Sources/GRPCCore/Call/Client/ClientInterceptorPipelineOperation.swift b/Sources/GRPCCore/Call/Client/ClientInterceptorPipelineOperation.swift new file mode 100644 index 000000000..4ae2df8d5 --- /dev/null +++ b/Sources/GRPCCore/Call/Client/ClientInterceptorPipelineOperation.swift @@ -0,0 +1,99 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// A `ClientInterceptorPipelineOperation` describes to which RPCs a client interceptor should be applied. +/// +/// You can configure a client interceptor to be applied to: +/// - all RPCs and services; +/// - requests directed only to specific services; or +/// - requests directed only to specific methods (of a specific service). +/// +/// - SeeAlso: ``ClientInterceptor`` for more information on client interceptors, and +/// ``ServerInterceptorPipelineOperation`` for the server-side version of this type. +public struct ClientInterceptorPipelineOperation: Sendable { + /// The subject of a ``ClientInterceptorPipelineOperation``. + /// The subject of an interceptor can either be all services and methods, only specific services, or only specific methods. + public struct Subject: Sendable { + internal enum Wrapped: Sendable { + case all + case services(Set) + case methods(Set) + } + + private let wrapped: Wrapped + + /// An operation subject specifying an interceptor that applies to all RPCs across all services will be registered with this client. + public static var all: Self { .init(wrapped: .all) } + + /// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified services. + /// - Parameters: + /// - services: The list of service names for which this interceptor should intercept RPCs. + /// - Returns: A ``ClientInterceptorPipelineOperation``. + public static func services(_ services: Set) -> Self { + Self(wrapped: .services(services)) + } + + /// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified service methods. + /// - Parameters: + /// - methods: The list of method descriptors for which this interceptor should intercept RPCs. + /// - Returns: A ``ClientInterceptorPipelineOperation``. + public static func methods(_ methods: Set) -> Self { + Self(wrapped: .methods(methods)) + } + + @usableFromInline + internal func applies(to descriptor: MethodDescriptor) -> Bool { + switch self.wrapped { + case .all: + return true + + case .services(let services): + return services.map({ $0.fullyQualifiedService }).contains(descriptor.service) + + case .methods(let methods): + return methods.contains(descriptor) + } + } + } + + /// The interceptor specified for this operation. + public let interceptor: any ClientInterceptor + + @usableFromInline + internal let subject: Subject + + private init(interceptor: any ClientInterceptor, appliesTo: Subject) { + self.interceptor = interceptor + self.subject = appliesTo + } + + /// Create an operation, specifying which ``ClientInterceptor`` to apply and to which ``Subject``. + /// - Parameters: + /// - interceptor: The ``ClientInterceptor`` to register with the client. + /// - subject: The ``Subject`` to which the `interceptor` applies. + /// - Returns: A ``ClientInterceptorPipelineOperation``. + public static func apply(_ interceptor: any ClientInterceptor, to subject: Subject) -> Self { + Self(interceptor: interceptor, appliesTo: subject) + } + + /// Returns whether this ``ClientInterceptorPipelineOperation`` applies to the given `descriptor`. + /// - Parameter descriptor: A ``MethodDescriptor`` for which to test whether this interceptor applies. + /// - Returns: `true` if this interceptor applies to the given `descriptor`, or `false` otherwise. + @inlinable + internal func applies(to descriptor: MethodDescriptor) -> Bool { + self.subject.applies(to: descriptor) + } +} diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index d9a35da51..aa2163424 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -23,7 +23,8 @@ struct ServerRPCExecutor { /// - stream: The accepted stream to execute the RPC on. /// - deserializer: A deserializer for messages received from the client. /// - serializer: A serializer for messages to send to the client. - /// - interceptors: Server interceptors to apply to this RPC. + /// - interceptors: Server interceptors to apply to this RPC. The + /// interceptors will be called in the order of the array. /// - handler: A handler which turns the request into a response. @inlinable static func execute( diff --git a/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift b/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift index 3d2731fd4..e511ea3ec 100644 --- a/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift +++ b/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift @@ -21,7 +21,8 @@ /// - requests directed only to specific services registered with your server; or /// - requests directed only to specific methods (of a specific service). /// -/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors. +/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors, and +/// ``ClientInterceptorPipelineOperation`` for the client-side version of this type. public struct ServerInterceptorPipelineOperation: Sendable { /// The subject of a ``ServerInterceptorPipelineOperation``. /// The subject of an interceptor can either be all services and methods, only specific services, or only specific methods. diff --git a/Sources/GRPCCore/GRPCClient.swift b/Sources/GRPCCore/GRPCClient.swift index 98c1c4f3d..19f50d0c7 100644 --- a/Sources/GRPCCore/GRPCClient.swift +++ b/Sources/GRPCCore/GRPCClient.swift @@ -112,13 +112,18 @@ public final class GRPCClient: Sendable { /// The transport which provides a bidirectional communication channel with the server. private let transport: any ClientTransport - /// A collection of interceptors providing cross-cutting functionality to each accepted RPC. + private let interceptorPipeline: [ClientInterceptorPipelineOperation] + + /// A collection of interceptors providing cross-cutting functionality to each accepted RPC, keyed by the method to which they apply. + /// + /// The list of interceptors for each method is computed from `interceptorsPipeline` when calling a method for the first time. + /// This caching is done to avoid having to compute the applicable interceptors for each request made. /// /// The order in which interceptors are added reflects the order in which they are called. The /// first interceptor added will be the first interceptor to intercept each request. The last /// interceptor added will be the final interceptor to intercept each request before calling /// the appropriate handler. - private let interceptors: [any ClientInterceptor] + private let interceptorsPerMethod: Mutex<[MethodDescriptor: [any ClientInterceptor]]> /// The current state of the client. private let state: Mutex @@ -191,17 +196,37 @@ public final class GRPCClient: Sendable { /// /// - Parameters: /// - transport: The transport used to establish a communication channel with a server. - /// - interceptors: A collection of interceptors providing cross-cutting functionality to each + /// - interceptors: A collection of ``ClientInterceptor``s providing cross-cutting functionality to each /// accepted RPC. The order in which interceptors are added reflects the order in which they /// are called. The first interceptor added will be the first interceptor to intercept each /// request. The last interceptor added will be the final interceptor to intercept each /// request before calling the appropriate handler. - public init( + convenience public init( transport: some ClientTransport, interceptors: [any ClientInterceptor] = [] + ) { + self.init( + transport: transport, + interceptorPipeline: interceptors.map { .apply($0, to: .all) } + ) + } + + /// Creates a new client with the given transport, interceptors and configuration. + /// + /// - Parameters: + /// - transport: The transport used to establish a communication channel with a server. + /// - interceptorPipeline: A collection of ``ClientInterceptorPipelineOperation`` providing cross-cutting + /// functionality to each accepted RPC. Only applicable interceptors from the pipeline will be applied to each RPC. + /// The order in which interceptors are added reflects the order in which they are called. + /// The first interceptor added will be the first interceptor to intercept each request. + /// The last interceptor added will be the final interceptor to intercept each request before calling the appropriate handler. + public init( + transport: some ClientTransport, + interceptorPipeline: [ClientInterceptorPipelineOperation] ) { self.transport = transport - self.interceptors = interceptors + self.interceptorPipeline = interceptorPipeline + self.interceptorsPerMethod = Mutex([:]) self.state = Mutex(.notStarted) } @@ -361,6 +386,18 @@ public final class GRPCClient: Sendable { var options = options options.formUnion(with: methodConfig) + let applicableInterceptors = self.interceptorsPerMethod.withLock { + if let interceptors = $0[descriptor] { + return interceptors + } else { + let interceptors = self.interceptorPipeline + .filter { $0.applies(to: descriptor) } + .map { $0.interceptor } + $0[descriptor] = interceptors + return interceptors + } + } + return try await ClientRPCExecutor.execute( request: request, method: descriptor, @@ -368,7 +405,7 @@ public final class GRPCClient: Sendable { serializer: serializer, deserializer: deserializer, transport: self.transport, - interceptors: self.interceptors, + interceptors: applicableInterceptors, handler: handler ) } diff --git a/Tests/GRPCCoreTests/GRPCClientTests.swift b/Tests/GRPCCoreTests/GRPCClientTests.swift index 42a6e3b3b..ed5396da1 100644 --- a/Tests/GRPCCoreTests/GRPCClientTests.swift +++ b/Tests/GRPCCoreTests/GRPCClientTests.swift @@ -16,16 +16,17 @@ import GRPCCore import GRPCInProcessTransport +import Testing import XCTest final class GRPCClientTests: XCTestCase { func withInProcessConnectedClient( services: [any RegistrableRPCService], - interceptors: [any ClientInterceptor] = [], + interceptorPipeline: [ClientInterceptorPipelineOperation] = [], _ body: (GRPCClient, GRPCServer) async throws -> Void ) async throws { let inProcess = InProcessTransport() - let client = GRPCClient(transport: inProcess.client, interceptors: interceptors) + let client = GRPCClient(transport: inProcess.client, interceptorPipeline: interceptorPipeline) let server = GRPCServer(transport: inProcess.server, services: services) try await withThrowingTaskGroup(of: Void.self) { group in @@ -234,10 +235,10 @@ final class GRPCClientTests: XCTestCase { try await self.withInProcessConnectedClient( services: [BinaryEcho()], - interceptors: [ - .requestCounter(counter1), - .rejectAll(with: RPCError(code: .unavailable, message: "")), - .requestCounter(counter2), + interceptorPipeline: [ + .apply(.requestCounter(counter1), to: .all), + .apply(.rejectAll(with: RPCError(code: .unavailable, message: "")), to: .all), + .apply(.requestCounter(counter2), to: .all), ] ) { client, _ in try await client.unary( @@ -409,3 +410,158 @@ final class GRPCClientTests: XCTestCase { task.cancel() } } + +@Suite("GRPC Client Tests") +struct ClientTests { + @Test("Interceptors are applied only to specified services") + func testInterceptorsAreAppliedToSpecifiedServices() async throws { + let onlyBinaryEchoCounter = AtomicCounter() + let allServicesCounter = AtomicCounter() + let onlyHelloWorldCounter = AtomicCounter() + let bothServicesCounter = AtomicCounter() + + try await self.withInProcessConnectedClient( + services: [BinaryEcho(), HelloWorld()], + interceptorPipeline: [ + .apply( + .requestCounter(onlyBinaryEchoCounter), + to: .services([BinaryEcho.serviceDescriptor]) + ), + .apply(.requestCounter(allServicesCounter), to: .all), + .apply( + .requestCounter(onlyHelloWorldCounter), + to: .services([HelloWorld.serviceDescriptor]) + ), + .apply( + .requestCounter(bothServicesCounter), + to: .services([BinaryEcho.serviceDescriptor, HelloWorld.serviceDescriptor]) + ), + ] + ) { client, _ in + // Make a request to the `BinaryEcho` service and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.unary( + request: .init(message: Array("hello".utf8)), + descriptor: BinaryEcho.Methods.get, + serializer: IdentitySerializer(), + deserializer: IdentityDeserializer(), + options: .defaults + ) { response in + let message = try #require(try response.message) + #expect(message == Array("hello".utf8)) + } + + #expect(onlyBinaryEchoCounter.value == 1) + #expect(allServicesCounter.value == 1) + #expect(onlyHelloWorldCounter.value == 0) + #expect(bothServicesCounter.value == 1) + + // Now, make a request to the `HelloWorld` service and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.unary( + request: .init(message: Array("Swift".utf8)), + descriptor: HelloWorld.Methods.sayHello, + serializer: IdentitySerializer(), + deserializer: IdentityDeserializer(), + options: .defaults + ) { response in + let message = try #require(try response.message) + #expect(message == Array("Hello, Swift!".utf8)) + } + + #expect(onlyBinaryEchoCounter.value == 1) + #expect(allServicesCounter.value == 2) + #expect(onlyHelloWorldCounter.value == 1) + #expect(bothServicesCounter.value == 2) + } + } + + @Test("Interceptors are applied only to specified methods") + func testInterceptorsAreAppliedToSpecifiedMethods() async throws { + let onlyBinaryEchoGetCounter = AtomicCounter() + let onlyBinaryEchoCollectCounter = AtomicCounter() + let bothBinaryEchoMethodsCounter = AtomicCounter() + let allMethodsCounter = AtomicCounter() + + try await self.withInProcessConnectedClient( + services: [BinaryEcho()], + interceptorPipeline: [ + .apply( + .requestCounter(onlyBinaryEchoGetCounter), + to: .methods([BinaryEcho.Methods.get]) + ), + .apply(.requestCounter(allMethodsCounter), to: .all), + .apply( + .requestCounter(onlyBinaryEchoCollectCounter), + to: .methods([BinaryEcho.Methods.collect]) + ), + .apply( + .requestCounter(bothBinaryEchoMethodsCounter), + to: .methods([BinaryEcho.Methods.get, BinaryEcho.Methods.collect]) + ), + ] + ) { client, _ in + // Make a request to the `BinaryEcho/get` method and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.unary( + request: .init(message: Array("hello".utf8)), + descriptor: BinaryEcho.Methods.get, + serializer: IdentitySerializer(), + deserializer: IdentityDeserializer(), + options: .defaults + ) { response in + let message = try #require(try response.message) + #expect(message == Array("hello".utf8)) + } + + #expect(onlyBinaryEchoGetCounter.value == 1) + #expect(allMethodsCounter.value == 1) + #expect(onlyBinaryEchoCollectCounter.value == 0) + #expect(bothBinaryEchoMethodsCounter.value == 1) + + // Now, make a request to the `BinaryEcho/collect` method and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.unary( + request: .init(message: Array("hello".utf8)), + descriptor: BinaryEcho.Methods.collect, + serializer: IdentitySerializer(), + deserializer: IdentityDeserializer(), + options: .defaults + ) { response in + let message = try #require(try response.message) + #expect(message == Array("hello".utf8)) + } + + #expect(onlyBinaryEchoGetCounter.value == 1) + #expect(allMethodsCounter.value == 2) + #expect(onlyBinaryEchoCollectCounter.value == 1) + #expect(bothBinaryEchoMethodsCounter.value == 2) + } + } + + func withInProcessConnectedClient( + services: [any RegistrableRPCService], + interceptorPipeline: [ClientInterceptorPipelineOperation] = [], + _ body: (GRPCClient, GRPCServer) async throws -> Void + ) async throws { + let inProcess = InProcessTransport() + let client = GRPCClient(transport: inProcess.client, interceptorPipeline: interceptorPipeline) + let server = GRPCServer(transport: inProcess.server, services: services) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await server.serve() + } + + group.addTask { + try await client.run() + } + + // Make sure both server and client are running + try await Task.sleep(for: .milliseconds(100)) + try await body(client, server) + client.beginGracefulShutdown() + server.beginGracefulShutdown() + } + } +}