diff --git a/Google.Api.Gax.Grpc.IntegrationTests/ChannelPoolTest.cs b/Google.Api.Gax.Grpc.IntegrationTests/ChannelPoolTest.cs index 37c20650..e87da357 100644 --- a/Google.Api.Gax.Grpc.IntegrationTests/ChannelPoolTest.cs +++ b/Google.Api.Gax.Grpc.IntegrationTests/ChannelPoolTest.cs @@ -14,6 +14,7 @@ public class ChannelPoolTest { private static readonly GrpcAdapter Grpc = GrpcCoreAdapter.Instance; private static readonly ServiceMetadata ServiceMetadata = TestServiceMetadata.TestService; + private const string DefaultUniverseDomain = TestServiceMetadata.DefaultUniverseDomain; [Fact] public void SameEndpoint_SameChannel() @@ -21,9 +22,9 @@ public void SameEndpoint_SameChannel() var pool = new ChannelPool(ServiceMetadata); using (var fixture = new TestServiceFixture()) { - var channel1 = pool.GetChannel(Grpc, fixture.Endpoint, GrpcChannelOptions.Empty); - var channel2 = pool.GetChannel(Grpc, fixture.Endpoint, GrpcChannelOptions.Empty); - Assert.Same(channel1, channel2); + var channel1 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, GrpcChannelOptions.Empty); + var channel2 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, GrpcChannelOptions.Empty); + Assert.Same(channel1, channel2); } } @@ -33,8 +34,8 @@ public void DifferentEndpoint_DifferentChannel() var pool = new ChannelPool(ServiceMetadata); using (TestServiceFixture fixture1 = new TestServiceFixture(), fixture2 = new TestServiceFixture()) { - var channel1 = pool.GetChannel(Grpc, fixture1.Endpoint, GrpcChannelOptions.Empty); - var channel2 = pool.GetChannel(Grpc, fixture2.Endpoint, GrpcChannelOptions.Empty); + var channel1 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture1.Endpoint, GrpcChannelOptions.Empty); + var channel2 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture2.Endpoint, GrpcChannelOptions.Empty); Assert.NotSame(channel1, channel2); } } @@ -47,8 +48,8 @@ public void SameOptions_SameChannel() var pool = new ChannelPool(ServiceMetadata); using (var fixture = new TestServiceFixture()) { - var channel1 = pool.GetChannel(Grpc, fixture.Endpoint, options1); - var channel2 = pool.GetChannel(Grpc, fixture.Endpoint, options2); + var channel1 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, options1); + var channel2 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, options2); Assert.Same(channel1, channel2); } } @@ -61,8 +62,32 @@ public void DifferentOptions_DifferentChannel() var pool = new ChannelPool(ServiceMetadata); using (var fixture = new TestServiceFixture()) { - var channel1 = pool.GetChannel(Grpc, fixture.Endpoint, options1); - var channel2 = pool.GetChannel(Grpc, fixture.Endpoint, options2); + var channel1 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, options1); + var channel2 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, options2); + Assert.NotSame(channel1, channel2); + } + } + + [Fact] + public void SameUniverseDomain_SameChannel() + { + var pool = new ChannelPool(ServiceMetadata); + using (var fixture = new TestServiceFixture()) + { + var channel1 = pool.GetChannel(Grpc, "nowhere.com", fixture.Endpoint, GrpcChannelOptions.Empty); + var channel2 = pool.GetChannel(Grpc, "nowhere.com", fixture.Endpoint, GrpcChannelOptions.Empty); + Assert.Same(channel1, channel2); + } + } + + [Fact] + public void DifferentUniverseDomain_DifferentChannel() + { + var pool = new ChannelPool(ServiceMetadata); + using (TestServiceFixture fixture = new TestServiceFixture()) + { + var channel1 = pool.GetChannel(Grpc, "nowhere.com", fixture.Endpoint, GrpcChannelOptions.Empty); + var channel2 = pool.GetChannel(Grpc, "somewhere.com", fixture.Endpoint, GrpcChannelOptions.Empty); Assert.NotSame(channel1, channel2); } } @@ -73,7 +98,7 @@ public async Task ShutdownAsync_ShutsDownChannel() var pool = new ChannelPool(ServiceMetadata); using (var fixture = new TestServiceFixture()) { - var channel = (Channel) pool.GetChannel(Grpc, fixture.Endpoint, GrpcChannelOptions.Empty); + var channel = (Channel) pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, GrpcChannelOptions.Empty); Assert.NotEqual(ChannelState.Shutdown, channel.State); await pool.ShutdownChannelsAsync(); Assert.Equal(ChannelState.Shutdown, channel.State); @@ -86,10 +111,10 @@ public void ShutdownAsync_EmptiesPool() var pool = new ChannelPool(ServiceMetadata); using (var fixture = new TestServiceFixture()) { - var channel1 = pool.GetChannel(Grpc, fixture.Endpoint, GrpcChannelOptions.Empty); + var channel1 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, GrpcChannelOptions.Empty); // Note: *not* waiting for this to complete. pool.ShutdownChannelsAsync(); - var channel2 = pool.GetChannel(Grpc, fixture.Endpoint, GrpcChannelOptions.Empty); + var channel2 = pool.GetChannel(Grpc, DefaultUniverseDomain, fixture.Endpoint, GrpcChannelOptions.Empty); Assert.NotSame(channel1, channel2); } } diff --git a/Google.Api.Gax.Grpc.IntegrationTests/ClientBuilderBaseTest.cs b/Google.Api.Gax.Grpc.IntegrationTests/ClientBuilderBaseTest.cs index 1e303abd..c5f9513e 100644 --- a/Google.Api.Gax.Grpc.IntegrationTests/ClientBuilderBaseTest.cs +++ b/Google.Api.Gax.Grpc.IntegrationTests/ClientBuilderBaseTest.cs @@ -54,7 +54,7 @@ public async Task DefaultsToChannelPool() { var builder = new SampleClientBuilder(); - ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); + ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); Action validator = invoker => { @@ -71,8 +71,8 @@ public async Task DifferentEndpoint_StillFromChannelPool() var endpoint = "custom.nowhere.com"; var builder = new SampleClientBuilder { Endpoint = endpoint }; - ChannelBase channelFromPoolWithDefaultEndpoint = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); - ChannelBase channelFromPoolWithCustomEndpoint = builder.ChannelPool.GetChannel(fakeGrpcAdapter, "custom.nowhere.com", SampleClientBuilder.DefaultOptions); + ChannelBase channelFromPoolWithDefaultEndpoint = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); + ChannelBase channelFromPoolWithCustomEndpoint = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, "custom.nowhere.com", SampleClientBuilder.DefaultOptions); Action validator = invoker => { @@ -84,6 +84,26 @@ public async Task DifferentEndpoint_StillFromChannelPool() await ValidateResultAsync(builder, validator); } + [Fact] + public async Task DifferentUniverseDomain_StillFromChannelPool() + { + var universeDomain = "nowhere.com"; + var builder = new SampleClientBuilder { UniverseDomain = universeDomain }; + + ChannelBase channelFromPoolWithDefaultUniverseDomain = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); + // The endpoint changes with the universe domain so we cannot use the default endpoint here, even if we are just changing the universe domain. + ChannelBase channelFromPoolWithCustomUniverseDomain = builder.ChannelPool.GetChannel(fakeGrpcAdapter, universeDomain, builder.EffectiveEndpoint, SampleClientBuilder.DefaultOptions); + + Action validator = invoker => + { + var channelFromBuilder = GetChannel(invoker); + Assert.Same(channelFromPoolWithCustomUniverseDomain, channelFromBuilder); + Assert.NotSame(channelFromPoolWithDefaultUniverseDomain, channelFromBuilder); + Assert.Null(builder.LastCreatedChannel); + }; + await ValidateResultAsync(builder, validator); + } + [Fact] public async Task CustomChannelCredentials() { @@ -195,7 +215,7 @@ public async Task TokenAccessMethod() public async Task JwtClientEnabledTest(bool clientUsesJwt, bool poolUsesJwt) { var builder = new SampleClientBuilder(clientUsesJwt, poolUsesJwt) { JsonCredentials = DummyServiceAccountCredentialFileContents }; - ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); + ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); // Jwt of client does not match pool, so we won't use channel pool await ValidateResultAsync(builder, AssertNonChannelPool(builder)); @@ -206,7 +226,7 @@ public async Task JwtClientEnabledTest(bool clientUsesJwt, bool poolUsesJwt) public async Task JwtClientAndPoolEnabledTest(bool enabledJwts) { var builder = new SampleClientBuilder(enabledJwts, enabledJwts); - ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); + ChannelBase channelFromPool = builder.ChannelPool.GetChannel(fakeGrpcAdapter, TestServiceMetadata.DefaultUniverseDomain, TestServiceMetadata.TestService.DefaultEndpoint, SampleClientBuilder.DefaultOptions); // Jwt is either enabled or disabled for both client and pool // We use channel pool @@ -290,6 +310,10 @@ public class SampleClientBuilder : ClientBuilderBase public string EndpointUsedToCreateChannel { get; private set; } public ChannelCredentials CredentialsUsedToCreateChannel { get; private set; } + public new string EffectiveUniverseDomain => base.EffectiveUniverseDomain; + + public new string EffectiveEndpoint => base.EffectiveEndpoint; + private readonly string _name; /// @@ -331,7 +355,7 @@ public override CallInvoker Build() } protected override ChannelPool GetChannelPool() => ChannelPool; - + public void ResetChannelCreation() { EndpointUsedToCreateChannel = null; diff --git a/Google.Api.Gax.Grpc.IntegrationTests/Gcp/GrpcCallInvokerPoolTest.cs b/Google.Api.Gax.Grpc.IntegrationTests/Gcp/GrpcCallInvokerPoolTest.cs index c182c70f..af1ca28f 100644 --- a/Google.Api.Gax.Grpc.IntegrationTests/Gcp/GrpcCallInvokerPoolTest.cs +++ b/Google.Api.Gax.Grpc.IntegrationTests/Gcp/GrpcCallInvokerPoolTest.cs @@ -6,9 +6,6 @@ */ using Google.Api.Gax.Grpc.IntegrationTests; -using Google.Protobuf.Reflection; -using System.Collections.Generic; -using System.Linq; using Xunit; namespace Google.Api.Gax.Grpc.Gcp.IntegrationTests @@ -20,23 +17,23 @@ public class GcpCallInvokerPoolTest private static readonly ApiConfig Config1 = new ApiConfig { ChannelPool = new ChannelPoolConfig { MaxSize = 5 } }; [Fact] - public void SameEndpointAndOptions_SameCallInvoker() + public void SameEndpointAndOptionsAndUniverseDomain_SameCallInvoker() { var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); var options = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); - var callInvoker1 = pool.GetCallInvoker("endpoint", options, Config1, FakeAdapter); - var callInvoker2 = pool.GetCallInvoker("endpoint", options, Config1, FakeAdapter); + var callInvoker1 = pool.GetCallInvoker("domain", "endpoint", options, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain", "endpoint", options, Config1, FakeAdapter); Assert.Same(callInvoker1, callInvoker2); } [Fact] - public void SameEndpointAndEqualOptions_SameCallInvoker() + public void SameEndpointAndUniverseDomainAndEqualOptions_SameCallInvoker() { var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); var options1 = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); var options2 = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); - var callInvoker1 = pool.GetCallInvoker("endpoint", options1, Config1, FakeAdapter); - var callInvoker2 = pool.GetCallInvoker("endpoint", options2, Config1, FakeAdapter); + var callInvoker1 = pool.GetCallInvoker("domain", "endpoint", options1, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain", "endpoint", options2, Config1, FakeAdapter); Assert.Same(callInvoker1, callInvoker2); } @@ -45,8 +42,8 @@ public void DifferentEndpoint_DifferentCallInvoker() { var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); var options = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); - var callInvoker1 = pool.GetCallInvoker("endpoint1", options, Config1, FakeAdapter); - var callInvoker2 = pool.GetCallInvoker("endpoint2", options, Config1, FakeAdapter); + var callInvoker1 = pool.GetCallInvoker("domain", "endpoint1", options, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain", "endpoint2", options, Config1, FakeAdapter); Assert.NotSame(callInvoker1, callInvoker2); } @@ -56,23 +53,33 @@ public void DifferentOptions_DifferentCallInvoker() var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); var options1 = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); var options2 = GrpcChannelOptions.Empty.WithPrimaryUserAgent("def"); - var callInvoker1 = pool.GetCallInvoker("endpoint", options1, Config1, FakeAdapter); - var callInvoker2 = pool.GetCallInvoker("endpoint", options2, Config1, FakeAdapter); + var callInvoker1 = pool.GetCallInvoker("domain", "endpoint", options1, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain", "endpoint", options2, Config1, FakeAdapter); Assert.NotSame(callInvoker1, callInvoker2); - var callInvoker3 = pool.GetCallInvoker("endpoint", options: null, Config1, FakeAdapter); + var callInvoker3 = pool.GetCallInvoker("domain", "endpoint", options: null, Config1, FakeAdapter); Assert.NotSame(callInvoker1, callInvoker3); } + [Fact] + public void DifferentUniverseDomain_DifferentCallInvoker() + { + var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); + var options = GrpcChannelOptions.Empty.WithPrimaryUserAgent("abc"); + var callInvoker1 = pool.GetCallInvoker("domain1", "endpoint", options, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain2", "endpoint", options, Config1, FakeAdapter); + Assert.NotSame(callInvoker1, callInvoker2); + } + // TODO: equal/non-equal configs, different adapters. [Fact] public void ShutdownAsync_EmptiesPool() { var pool = new GcpCallInvokerPool(TestServiceMetadata.TestService); - var callInvoker1 = pool.GetCallInvoker("endpoint", options: null, Config1, FakeAdapter); + var callInvoker1 = pool.GetCallInvoker("domain", "endpoint", options: null, Config1, FakeAdapter); // Note: *not* waiting for this to complete. pool.ShutdownChannelsAsync(); - var callInvoker2 = pool.GetCallInvoker("endpoint", options: null, Config1, FakeAdapter); + var callInvoker2 = pool.GetCallInvoker("domain", "endpoint", options: null, Config1, FakeAdapter); Assert.NotSame(callInvoker1, callInvoker2); } } diff --git a/Google.Api.Gax.Grpc.IntegrationTests/TestServiceMetadata.cs b/Google.Api.Gax.Grpc.IntegrationTests/TestServiceMetadata.cs index fab4661f..3c35fc50 100644 --- a/Google.Api.Gax.Grpc.IntegrationTests/TestServiceMetadata.cs +++ b/Google.Api.Gax.Grpc.IntegrationTests/TestServiceMetadata.cs @@ -9,6 +9,8 @@ namespace Google.Api.Gax.Grpc.IntegrationTests { internal static class TestServiceMetadata { + internal const string DefaultUniverseDomain = "googleapis.com"; + internal static ApiMetadata ApiMetadata { get; } = new ApiMetadata("Google.Api.Gax.Grpc.IntegrationTests", new[] { TestServiceReflection.Descriptor }); internal static ServiceMetadata TestService { get; } = diff --git a/Google.Api.Gax.Grpc/ChannelPool.cs b/Google.Api.Gax.Grpc/ChannelPool.cs index 8ee7daa7..42ab4b44 100644 --- a/Google.Api.Gax.Grpc/ChannelPool.cs +++ b/Google.Api.Gax.Grpc/ChannelPool.cs @@ -64,15 +64,18 @@ public Task ShutdownChannelsAsync() /// The specified channel options are applied, but only those options. /// /// The gRPC implementation to use. Must not be null. + /// The universe domain configured for the service client, + /// to validate against the one configured for the credential. Must not be null. /// The endpoint to connect to. Must not be null. /// The channel options to include. May be null. /// A channel for the specified endpoint. - internal ChannelBase GetChannel(GrpcAdapter grpcAdapter, string endpoint, GrpcChannelOptions channelOptions) + internal ChannelBase GetChannel(GrpcAdapter grpcAdapter, string universeDomain, string endpoint, GrpcChannelOptions channelOptions) { GaxPreconditions.CheckNotNull(grpcAdapter, nameof(grpcAdapter)); + GaxPreconditions.CheckNotNull(universeDomain, nameof(universeDomain)); GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); - var credentials = _credentialCache.GetCredentials(); - return GetChannel(grpcAdapter, endpoint, channelOptions, credentials); + var credentials = _credentialCache.GetCredentials(universeDomain); + return GetChannel(grpcAdapter, universeDomain, endpoint, channelOptions, credentials); } /// @@ -81,22 +84,25 @@ internal ChannelBase GetChannel(GrpcAdapter grpcAdapter, string endpoint, GrpcCh /// The specified channel options are applied, but only those options. /// /// The gRPC implementation to use. Must not be null. + /// The universe domain configured for the service client, + /// to validate against the one configured for the credential. Must not be null. /// The endpoint to connect to. Must not be null. /// The channel options to include. May be null. /// A cancellation token for the operation. /// A task representing the asynchronous operation. The value of the completed /// task will be channel for the specified endpoint. - internal async Task GetChannelAsync(GrpcAdapter grpcAdapter, string endpoint, GrpcChannelOptions channelOptions, CancellationToken cancellationToken) + internal async Task GetChannelAsync(GrpcAdapter grpcAdapter, string universeDomain, string endpoint, GrpcChannelOptions channelOptions, CancellationToken cancellationToken) { GaxPreconditions.CheckNotNull(grpcAdapter, nameof(grpcAdapter)); + GaxPreconditions.CheckNotNull(universeDomain, nameof(universeDomain)); GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); - var credentials = await _credentialCache.GetCredentialsAsync(cancellationToken).ConfigureAwait(false); - return GetChannel(grpcAdapter, endpoint, channelOptions, credentials); + var credentials = await _credentialCache.GetCredentialsAsync(universeDomain, cancellationToken).ConfigureAwait(false); + return GetChannel(grpcAdapter, universeDomain, endpoint, channelOptions, credentials); } - private ChannelBase GetChannel(GrpcAdapter grpcAdapter, string endpoint, GrpcChannelOptions channelOptions, ChannelCredentials credentials) + private ChannelBase GetChannel(GrpcAdapter grpcAdapter, string universeDomain, string endpoint, GrpcChannelOptions channelOptions, ChannelCredentials credentials) { - var key = new Key(grpcAdapter, endpoint, channelOptions); + var key = new Key(grpcAdapter, universeDomain, endpoint, channelOptions); lock (_lock) { @@ -112,23 +118,25 @@ private ChannelBase GetChannel(GrpcAdapter grpcAdapter, string endpoint, GrpcCha private struct Key : IEquatable { + public readonly string UniverseDomain; public readonly string Endpoint; public readonly GrpcChannelOptions Options; public readonly GrpcAdapter GrpcAdapter; - public Key(GrpcAdapter grpcAdapter, string endpoint, GrpcChannelOptions options) => - (GrpcAdapter, Endpoint, Options) = (grpcAdapter, endpoint, options); + public Key(GrpcAdapter grpcAdapter, string universeDomain, string endpoint, GrpcChannelOptions options) => + (GrpcAdapter, UniverseDomain, Endpoint, Options) = (grpcAdapter, universeDomain, endpoint, options); public override int GetHashCode() => GaxEqualityHelpers.CombineHashCodes( GrpcAdapter.GetHashCode(), + UniverseDomain.GetHashCode(), Endpoint.GetHashCode(), Options.GetHashCode()); public override bool Equals(object obj) => obj is Key other && Equals(other); public bool Equals(Key other) => - GrpcAdapter.Equals(other.GrpcAdapter) && Endpoint.Equals(other.Endpoint) && Options.Equals(other.Options); + GrpcAdapter.Equals(other.GrpcAdapter) && UniverseDomain.Equals(other.UniverseDomain) && Endpoint.Equals(other.Endpoint) && Options.Equals(other.Options); } } } diff --git a/Google.Api.Gax.Grpc/ClientBuilderBase.cs b/Google.Api.Gax.Grpc/ClientBuilderBase.cs index 9fc44646..3478fc61 100644 --- a/Google.Api.Gax.Grpc/ClientBuilderBase.cs +++ b/Google.Api.Gax.Grpc/ClientBuilderBase.cs @@ -469,7 +469,7 @@ protected virtual CallInvoker CreateCallInvoker() ChannelBase channel; if (CanUseChannelPool) { - channel = GetChannelPool().GetChannel(EffectiveGrpcAdapter, endpoint, GetChannelOptions()); + channel = GetChannelPool().GetChannel(EffectiveGrpcAdapter, EffectiveUniverseDomain, endpoint, GetChannelOptions()); } else { @@ -497,7 +497,7 @@ protected virtual async Task CreateCallInvokerAsync(CancellationTok if (CanUseChannelPool) { channel = await GetChannelPool() - .GetChannelAsync(EffectiveGrpcAdapter, endpoint, GetChannelOptions(), cancellationToken) + .GetChannelAsync(EffectiveGrpcAdapter, EffectiveUniverseDomain, endpoint, GetChannelOptions(), cancellationToken) .ConfigureAwait(false); } else @@ -513,7 +513,7 @@ protected virtual async Task CreateCallInvokerAsync(CancellationTok /// credential mechanisms are supported. /// protected virtual ChannelCredentials GetChannelCredentials() => - MaybeGetSimpleChannelCredentials() ?? GetGoogleCredential().ToChannelCredentials(); + MaybeGetSimpleChannelCredentials() ?? GetGoogleCredential().ToChannelCredentials(EffectiveUniverseDomain); /// /// Obtains channel credentials asynchronously. Override this method in a concrete builder type if more @@ -521,7 +521,7 @@ protected virtual ChannelCredentials GetChannelCredentials() => /// protected async virtual Task GetChannelCredentialsAsync(CancellationToken cancellationToken) => MaybeGetSimpleChannelCredentials() - ?? (await GetGoogleCredentialAsync(cancellationToken).ConfigureAwait(false)).ToChannelCredentials(); + ?? (await GetGoogleCredentialAsync(cancellationToken).ConfigureAwait(false)).ToChannelCredentials(EffectiveUniverseDomain); /// /// Obtains channel credentials synchronously if they've been supplied in a ready-to-go fashion. diff --git a/Google.Api.Gax.Grpc/DefaultChannelCredentialsCache.cs b/Google.Api.Gax.Grpc/DefaultChannelCredentialsCache.cs index e8b40159..6312d7a8 100644 --- a/Google.Api.Gax.Grpc/DefaultChannelCredentialsCache.cs +++ b/Google.Api.Gax.Grpc/DefaultChannelCredentialsCache.cs @@ -6,11 +6,10 @@ */ using Google.Apis.Auth.OAuth2; -using Grpc.Auth; using Grpc.Core; using System; using System.Collections.Generic; -using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace Google.Api.Gax.Grpc @@ -31,7 +30,7 @@ internal sealed class DefaultChannelCredentialsCache /// The same channel credentials are used by all pools. The field is initialized in the constructor, as it uses /// _scopes, and you can't refer to an instance field within an instance field initializer. /// - private readonly Lazy> _lazyScopedDefaultChannelCredentials; + private readonly Lazy> _lazyScopedDefaultChannelCredentials; /// /// Creates a cache which will apply the specified scopes to the default application credentials @@ -46,7 +45,7 @@ internal DefaultChannelCredentialsCache(ServiceMetadata serviceMetadata) // However, it won't be any more efficient, and having the scopes easily available when debugging could be handy. _scopes = serviceMetadata.DefaultScopes; _lazyScopedDefaultChannelCredentials = - new Lazy>(() => Task.Run(async () => + new Lazy>(() => Task.Run(async () => { var appDefaultCredentials = await GoogleCredential.GetApplicationDefaultAsync().ConfigureAwait(false); if (appDefaultCredentials.IsCreateScopedRequired) @@ -59,14 +58,17 @@ internal DefaultChannelCredentialsCache(ServiceMetadata serviceMetadata) { appDefaultCredentials = GoogleCredential.FromServiceAccountCredential(serviceCredential.WithUseJwtAccessWithScopes(UseJwtAccessWithScopes)); } - return appDefaultCredentials.ToChannelCredentials(); + return appDefaultCredentials; })); } - internal ChannelCredentials GetCredentials() => - GetCredentialsAsync(default).ResultWithUnwrappedExceptions(); + internal ChannelCredentials GetCredentials(string universeDomain) => + GetCredentialsAsync(universeDomain, default).ResultWithUnwrappedExceptions(); - internal Task GetCredentialsAsync(CancellationToken cancellationToken) => - _lazyScopedDefaultChannelCredentials.Value.WithCancellationToken(cancellationToken); + internal async Task GetCredentialsAsync(string universeDomain, CancellationToken cancellationToken) + { + var googleCredential = await _lazyScopedDefaultChannelCredentials.Value.WithCancellationToken(cancellationToken).ConfigureAwait(false); + return googleCredential.ToChannelCredentials(universeDomain); + } } } diff --git a/Google.Api.Gax.Grpc/Gcp/GcpCallInvokerPool.cs b/Google.Api.Gax.Grpc/Gcp/GcpCallInvokerPool.cs index 041ee743..09f2dccb 100644 --- a/Google.Api.Gax.Grpc/Gcp/GcpCallInvokerPool.cs +++ b/Google.Api.Gax.Grpc/Gcp/GcpCallInvokerPool.cs @@ -9,6 +9,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace Google.Api.Gax.Grpc.Gcp @@ -62,18 +63,64 @@ public Task ShutdownChannelsAsync() /// Returns a call invoker from this pool, creating a new one if there is no call invoker /// already associated with and . /// + /// The universe domain configured for the service client, + /// to validate against the one configured for the credential. Must not be null. /// The endpoint to connect to. Must not be null. /// The options to use for each channel created by the call invoker. May be null. /// The API configuration used to determine channel keys. Must not be null. /// The gRPC adapter to use to create call invokers. Must not be null. /// A call invoker for the specified endpoint. + public GcpCallInvoker GetCallInvoker(string universeDomain, string endpoint, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter) + { + GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); + var credentials = _credentialsCache.GetCredentials(universeDomain); + GaxPreconditions.CheckNotNull(apiConfig, nameof(apiConfig)); + GaxPreconditions.CheckNotNull(adapter, nameof(adapter)); + return GetCallInvoker(universeDomain, endpoint, credentials, options, apiConfig, adapter); + } + + /// + /// Asynchronously returns a call invoker from this pool, creating a new one if there is no call invoker + /// already associated with and . + /// + /// The universe domain configured for the service client, + /// to validate against the one configured for the credential. Must not be null. + /// The endpoint to connect to. Must not be null. + /// The options to use for each channel created by the call invoker. May be null. + /// The API configuration used to determine channel keys. Must not be null. + /// The gRPC adapter to use to create call invokers. Must not be null. + /// The cancellation token to cancel the operation. + /// A task representing the asynchronous operation. The value of the completed + /// task will be a call invoker for the specified endpoint. + public async Task GetCallInvokerAsync(string universeDomain, string endpoint, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter, CancellationToken cancellationToken) + { + GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); + GaxPreconditions.CheckNotNull(apiConfig, nameof(apiConfig)); + GaxPreconditions.CheckNotNull(adapter, nameof(adapter)); + var credentials = await _credentialsCache.GetCredentialsAsync(universeDomain, cancellationToken).ConfigureAwait(false); + return GetCallInvoker(universeDomain, endpoint, credentials, options, apiConfig, adapter); + } + + /// + /// Returns a call invoker from this pool, creating a new one if there is no call invoker + /// already associated with and . + /// + /// The endpoint to connect to. Must not be null. + /// The options to use for each channel created by the call invoker. May be null. + /// The API configuration used to determine channel keys. Must not be null. + /// The gRPC adapter to use to create call invokers. Must not be null. + /// A call invoker for the specified endpoint. + [Obsolete("Please use the overloads that accept a universe domain to make certain the credentials used are valid in the target universe domain.")] public GcpCallInvoker GetCallInvoker(string endpoint, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter) { + // We use the dedault universe domain here for obtaining credentials and the call invoker as this is obsolete code + // that's not multi universe domain enabled, so it must only ever execute in the default universe domain. + // If the credential being used is not from the default universe domain, validation will fail and no requests will be made. GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); - var credentials = _credentialsCache.GetCredentials(); + var credentials = _credentialsCache.GetCredentials(ServiceMetadata.DefaultUniverseDomain); GaxPreconditions.CheckNotNull(apiConfig, nameof(apiConfig)); GaxPreconditions.CheckNotNull(adapter, nameof(adapter)); - return GetCallInvoker(endpoint, credentials, options, apiConfig, adapter); + return GetCallInvoker(ServiceMetadata.DefaultUniverseDomain, endpoint, credentials, options, apiConfig, adapter); } /// @@ -86,20 +133,24 @@ public GcpCallInvoker GetCallInvoker(string endpoint, GrpcChannelOptions options /// The gRPC adapter to use to create call invokers. Must not be null. /// A task representing the asynchronous operation. The value of the completed /// task will be a call invoker for the specified endpoint. + [Obsolete("Please use the overloads that accept a universe domain to make certain the credentials used are valid in the target universe domain.")] public async Task GetCallInvokerAsync(string endpoint, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter) { + // We use the dedault universe domain here for obtaining credentials and the call invoker as this is obsolete code + // that's not multi universe domain enabled, so it must only ever execute in the default universe domain. + // If the credential being used is not from the default universe domain, validation will fail and no requests will be made. GaxPreconditions.CheckNotNull(endpoint, nameof(endpoint)); GaxPreconditions.CheckNotNull(apiConfig, nameof(apiConfig)); GaxPreconditions.CheckNotNull(adapter, nameof(adapter)); - var credentials = await _credentialsCache.GetCredentialsAsync(default).ConfigureAwait(false); - return GetCallInvoker(endpoint, credentials, options, apiConfig, adapter); + var credentials = await _credentialsCache.GetCredentialsAsync(ServiceMetadata.DefaultUniverseDomain, default).ConfigureAwait(false); + return GetCallInvoker(ServiceMetadata.DefaultUniverseDomain, endpoint, credentials, options, apiConfig, adapter); } - private GcpCallInvoker GetCallInvoker(string endpoint, ChannelCredentials credentials, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter) + private GcpCallInvoker GetCallInvoker(string universeDomain, string endpoint, ChannelCredentials credentials, GrpcChannelOptions options, ApiConfig apiConfig, GrpcAdapter adapter) { var effectiveOptions = s_defaultOptions.MergedWith(options ?? GrpcChannelOptions.Empty); apiConfig = apiConfig.Clone(); - var key = new Key(endpoint, effectiveOptions, apiConfig, adapter); + var key = new Key(universeDomain, endpoint, effectiveOptions, apiConfig, adapter); lock (_lock) { @@ -114,13 +165,15 @@ private GcpCallInvoker GetCallInvoker(string endpoint, ChannelCredentials creden private struct Key : IEquatable { + public readonly string UniverseDomain; public readonly string Endpoint; public readonly GrpcChannelOptions Options; public readonly ApiConfig Config; public readonly GrpcAdapter GrpcAdapter; - public Key(string endpoint, GrpcChannelOptions options, ApiConfig config, GrpcAdapter adapter) + public Key(string universedomain, string endpoint, GrpcChannelOptions options, ApiConfig config, GrpcAdapter adapter) { + UniverseDomain = universedomain; Endpoint = endpoint; Options = options; Config = config; @@ -129,6 +182,7 @@ public Key(string endpoint, GrpcChannelOptions options, ApiConfig config, GrpcAd public override int GetHashCode() => GaxEqualityHelpers.CombineHashCodes( + UniverseDomain.GetHashCode(), Endpoint.GetHashCode(), Options.GetHashCode(), Config.GetHashCode(), @@ -137,6 +191,7 @@ public override int GetHashCode() => public override bool Equals(object obj) => obj is Key other && Equals(other); public bool Equals(Key other) => + UniverseDomain.Equals(other.UniverseDomain) && Endpoint.Equals(other.Endpoint) && Options.Equals(other.Options) && Config.Equals(other.Config) && diff --git a/Google.Api.Gax.Grpc/GoogleCredentialExtensions.cs b/Google.Api.Gax.Grpc/GoogleCredentialExtensions.cs new file mode 100644 index 00000000..aa5fab9c --- /dev/null +++ b/Google.Api.Gax.Grpc/GoogleCredentialExtensions.cs @@ -0,0 +1,68 @@ +/* + * Copyright 2024 Google LLC + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file or at + * https://developers.google.com/open-source/licenses/bsd + */ + +using Google.Apis.Auth.OAuth2; +using Grpc.Auth; +using Grpc.Core; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Google.Api.Gax.Grpc; + +/// +/// Extension methods for Google credential universe domain validation. +/// +internal static class GoogleCredentialExtensions +{ + /// + /// Returns a channel credential based on , that will validate its own universe domain + /// against . + /// + /// The Google credential to build the channel credentials from. Must not be null. + /// The universe domain to validate against. Must not be null. + /// should result in the same value as this. + internal static ChannelCredentials ToChannelCredentials(this GoogleCredential googleCredential, string universeDomain) => + new GoogleCredentialWithUniverseDomainValidation(googleCredential, universeDomain).ToChannelCredentials(); + + private class GoogleCredentialWithUniverseDomainValidation : ITokenAccessWithHeaders + { + private readonly ITokenAccessWithHeaders _underlying; + private readonly string _universeDomain; + private readonly Lazy _universeDomainsMatchCheckCache; + + public GoogleCredentialWithUniverseDomainValidation(GoogleCredential googleCredential, string universeDomain) + { + _underlying = GaxPreconditions.CheckNotNull(googleCredential, nameof(googleCredential)); + _universeDomain = GaxPreconditions.CheckNotNull(universeDomain, nameof(universeDomain)); + _universeDomainsMatchCheckCache = new Lazy(UniverseDomainsMatchCheckUncached); + } + + public async Task GetAccessTokenForRequestAsync(string authUri = null, CancellationToken cancellationToken = default) + { + await _universeDomainsMatchCheckCache.Value.WithCancellationToken(cancellationToken).ConfigureAwait(false); + return await _underlying.GetAccessTokenForRequestAsync(authUri, cancellationToken).ConfigureAwait(false); + } + + public async Task GetAccessTokenWithHeadersForRequestAsync(string authUri = null, CancellationToken cancellationToken = default) + { + await _universeDomainsMatchCheckCache.Value.WithCancellationToken(cancellationToken).ConfigureAwait(false); + return await _underlying.GetAccessTokenWithHeadersForRequestAsync(authUri, cancellationToken).ConfigureAwait(false); + } + + private async Task UniverseDomainsMatchCheckUncached() + { + string credentialUniverseDomain = await (_underlying as GoogleCredential).GetUniverseDomainAsync(default).ConfigureAwait(false); + if (credentialUniverseDomain != _universeDomain) + { + throw new InvalidOperationException( + $"The service client universe domain {_universeDomain} does not match the credential universe domain {credentialUniverseDomain}." + + $"You can configure the universe domain for your service client by using the UniverseDomain property on the corresponding client builder."); + } + } + } +}