diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 25a87c0908..c38d98e91f 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -40,6 +40,7 @@ import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.api.gax.rpc.StreamController; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.type.Color; @@ -62,6 +63,7 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -72,6 +74,16 @@ @RunWith(JUnit4.class) public class ChannelPoolTest { + private static final int DEFAULT_AWAIT_TERMINATION_SEC = 10; + private ChannelPool pool; + + @After + public void cleanup() throws InterruptedException { + Preconditions.checkNotNull(pool, "Channel pool was never created"); + pool.shutdown(); + pool.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS); + } + @Test public void testAuthority() throws IOException { ManagedChannel sub1 = Mockito.mock(ManagedChannel.class); @@ -79,7 +91,7 @@ public void testAuthority() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); - ChannelPool pool = + pool = ChannelPool.create( ChannelPoolSettings.staticallySized(2), new FakeChannelFactory(Arrays.asList(sub1, sub2))); @@ -94,7 +106,7 @@ public void testRoundRobin() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); ArrayList channels = Lists.newArrayList(sub1, sub2); - ChannelPool pool = + pool = ChannelPool.create( ChannelPoolSettings.staticallySized(channels.size()), new FakeChannelFactory(channels)); @@ -150,7 +162,7 @@ public void ensureEvenDistribution() throws InterruptedException, IOException { }); } - final ChannelPool pool = + pool = ChannelPool.create( ChannelPoolSettings.staticallySized(numChannels), new FakeChannelFactory(Arrays.asList(channels))); @@ -184,12 +196,13 @@ public void channelPrimerShouldCallPoolConstruction() throws IOException { ManagedChannel channel1 = Mockito.mock(ManagedChannel.class); ManagedChannel channel2 = Mockito.mock(ManagedChannel.class); - ChannelPool.create( - ChannelPoolSettings.staticallySized(2) - .toBuilder() - .setPreemptiveRefreshEnabled(true) - .build(), - new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); + pool = + ChannelPool.create( + ChannelPoolSettings.staticallySized(2) + .toBuilder() + .setPreemptiveRefreshEnabled(true) + .build(), + new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); Mockito.verify(mockChannelPrimer, Mockito.times(2)) .primeChannel(Mockito.any(ManagedChannel.class)); } @@ -221,13 +234,14 @@ public void channelPrimerIsCalledPeriodically() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer); - new ChannelPool( - ChannelPoolSettings.staticallySized(1) - .toBuilder() - .setPreemptiveRefreshEnabled(true) - .build(), - channelFactory, - scheduledExecutorService); + pool = + new ChannelPool( + ChannelPoolSettings.staticallySized(1) + .toBuilder() + .setPreemptiveRefreshEnabled(true) + .build(), + channelFactory, + scheduledExecutorService); // 1 call during the creation Mockito.verify(mockChannelPrimer, Mockito.times(1)) .primeChannel(Mockito.any(ManagedChannel.class)); @@ -251,7 +265,7 @@ public void callShouldCompleteAfterCreation() throws IOException { ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); + pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -300,7 +314,7 @@ public void callShouldCompleteAfterStarted() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); + pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -345,7 +359,7 @@ public void channelShouldShutdown() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); + pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -397,7 +411,7 @@ public void channelRefreshShouldSwapChannels() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel1, underlyingChannel2)); - ChannelPool pool = + pool = new ChannelPool( ChannelPoolSettings.staticallySized(1) .toBuilder() @@ -444,7 +458,7 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro return channel; }; - ChannelPool pool = + pool = new ChannelPool( ChannelPoolSettings.builder() .setInitialChannelCount(2) @@ -525,7 +539,7 @@ public void removedIdleChannelsAreShutdown() throws Exception { return channel; }; - ChannelPool pool = + pool = new ChannelPool( ChannelPoolSettings.builder() .setInitialChannelCount(2) @@ -565,7 +579,7 @@ public void removedActiveChannelsAreShutdown() throws Exception { return channel; }; - ChannelPool pool = + pool = new ChannelPool( ChannelPoolSettings.builder() .setInitialChannelCount(2) @@ -612,11 +626,11 @@ public void testReleasingClientCallCancelEarly() throws IOException { Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall); ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1); ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel)); - ChannelPool channelPool = ChannelPool.create(channelPoolSettings, factory); + pool = ChannelPool.create(channelPoolSettings, factory); ClientContext context = ClientContext.newBuilder() - .setTransportChannel(GrpcTransportChannel.create(channelPool)) - .setDefaultCallContext(GrpcCallContext.of(channelPool, CallOptions.DEFAULT)) + .setTransportChannel(GrpcTransportChannel.create(pool)) + .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT)) .build(); ServerStreamingCallSettings settings = ServerStreamingCallSettings.newBuilder().build(); diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamControllerTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamControllerTest.java index ca2c2f69f4..96d7cf1063 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamControllerTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamControllerTest.java @@ -29,11 +29,11 @@ */ package com.google.api.gax.grpc; +import com.google.api.gax.core.BackgroundResource; import com.google.api.gax.core.NoCredentialsProvider; import com.google.api.gax.grpc.testing.FakeServiceGrpc; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.retrying.StreamResumptionStrategy; -import com.google.api.gax.rpc.Callables; import com.google.api.gax.rpc.ClientContext; import com.google.api.gax.rpc.DeadlineExceededException; import com.google.api.gax.rpc.FixedTransportChannelProvider; @@ -49,6 +49,9 @@ import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.junit.Test; @@ -58,15 +61,13 @@ @RunWith(JUnit4.class) public class GrpcDirectStreamControllerTest { + private static final int DEFAULT_AWAIT_TERMINATION_SEC = 10; @Test(timeout = 180_000) // ms public void testRetryNoRaceCondition() throws Exception { - Server server = ServerBuilder.forPort(1234).addService(new FakeService()).build(); - server.start(); - + Server server = ServerBuilder.forPort(1234).addService(new FakeService()).build().start(); ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 1234).usePlaintext().build(); - StreamResumptionStrategy resumptionStrategy = new StreamResumptionStrategy() { @Nonnull @@ -92,58 +93,68 @@ public boolean canResume() { return true; } }; - - // Set up retry settings. Set total timeout to 1 minute to limit the total runtime of this test. - // Set retry delay to 1 ms so the retries will be scheduled in a loop with no delays. - // Set max attempt to max so there could be as many retries as possible. - ServerStreamingCallSettings callSettigs = + // Set up retry settings. Set total timeout to 1 minute to limit the total runtime of this + // test. Set retry delay to 1 ms so the retries will be scheduled in a loop with no delays. + ServerStreamingCallSettings callSettings = ServerStreamingCallSettings.newBuilder() .setResumptionStrategy(resumptionStrategy) .setRetryableCodes(StatusCode.Code.DEADLINE_EXCEEDED) .setRetrySettings( RetrySettings.newBuilder() .setTotalTimeout(Duration.ofMinutes(1)) - .setMaxAttempts(Integer.MAX_VALUE) + .setInitialRpcTimeout(Duration.ofMillis(1)) + .setMaxRpcTimeout(Duration.ofMillis(1)) .setInitialRetryDelay(Duration.ofMillis(1)) .setMaxRetryDelay(Duration.ofMillis(1)) .build()) .build(); - - StubSettings.Builder builder = - new StubSettings.Builder() { - @Override - public StubSettings build() { - return new StubSettings(this) { - @Override - public Builder toBuilder() { - throw new IllegalStateException(); - } - }; - } - }; - - builder - .setEndpoint("localhost:1234") - .setCredentialsProvider(NoCredentialsProvider.create()) - .setTransportChannelProvider( - FixedTransportChannelProvider.create(GrpcTransportChannel.create(channel))); - - ServerStreamingCallable callable = - GrpcCallableFactory.createServerStreamingCallable( - GrpcCallSettings.create(FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE), - callSettigs, - ClientContext.create(builder.build())); - - ServerStreamingCallable retrying = - Callables.retrying(callable, callSettigs, ClientContext.create(builder.build())); - - Color request = Color.newBuilder().getDefaultInstanceForType(); - + // Store a list of resources to manually close at the end of the test + List backgroundResourceList = new ArrayList<>(); try { - for (Money money : retrying.call(request, GrpcCallContext.createDefault())) {} - + GrpcTransportChannel transportChannel = GrpcTransportChannel.create(channel); + backgroundResourceList.add(transportChannel); + + StubSettings.Builder builder = + new StubSettings.Builder() { + @Override + public StubSettings build() { + return new StubSettings(this) { + @Override + public Builder toBuilder() { + throw new IllegalStateException(); + } + }; + } + }; + + builder + .setEndpoint("localhost:1234") + .setCredentialsProvider(NoCredentialsProvider.create()) + .setTransportChannelProvider(FixedTransportChannelProvider.create(transportChannel)); + + ClientContext clientContext = ClientContext.create(builder.build()); + backgroundResourceList.addAll(clientContext.getBackgroundResources()); + // GrpcCallableFactory's createServerStreamingCallable creates a retrying callable + ServerStreamingCallable callable = + GrpcCallableFactory.createServerStreamingCallable( + GrpcCallSettings.create(FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE), + callSettings, + clientContext); + + Color request = Color.newBuilder().getDefaultInstanceForType(); + for (Money money : callable.call(request, clientContext.getDefaultCallContext())) {} } catch (DeadlineExceededException e) { // Ignore this error + } finally { + // Shutdown all the resources + server.shutdown(); + server.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS); + channel.shutdown(); + channel.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS); + for (BackgroundResource backgroundResource : backgroundResourceList) { + backgroundResource.shutdown(); + backgroundResource.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS); + } } } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeServiceImpl.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeServiceImpl.java index 298e847837..3dccdc76e7 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeServiceImpl.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeServiceImpl.java @@ -81,17 +81,14 @@ public void serverStreamingRecognize( // because the InProcessServer uses a direct executor and will buffer the results ignoring // cancellation Runnable runnable = - new Runnable() { - @Override - public void run() { - try { - Thread.sleep((long) color.getGreen()); - } catch (InterruptedException e) { - Thread.interrupted(); - return; - } + () -> { + try { + Thread.sleep((long) color.getGreen()); responseObserver.onNext(convert(color)); responseObserver.onCompleted(); + } catch (Exception e) { + Thread.interrupted(); + responseObserver.onError(e); } }; @@ -107,9 +104,10 @@ public StreamObserver clientStreamingRecognize(StreamObserver resp } private static Money convert(Color color) { - Money result = - Money.newBuilder().setCurrencyCode("USD").setUnits((long) (color.getRed() * 255)).build(); - return result; + return Money.newBuilder() + .setCurrencyCode("USD") + .setUnits((long) (color.getRed() * 255)) + .build(); } private static class RequestStreamObserver implements StreamObserver {