Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878076: Retry Strategy for JDBC #1548

Merged
merged 10 commits into from
Nov 6, 2023
4 changes: 2 additions & 2 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public class SFSession extends SFBaseSession {
* Amount of seconds a user is willing to tolerate for establishing the connection with database.
* In our case, it means the first login request to get authorization token.
*
* <p>Default:60 seconds
* <p>Default:300 seconds
*/
private int loginTimeout = 60;
private int loginTimeout = 300;
/**
* Amount of milliseconds a user is willing to tolerate for network related issues (e.g. HTTP
* 503/504) or database transient issues (e.g. GS not responding)
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.message.BasicHeader;
Expand Down Expand Up @@ -112,6 +113,14 @@ public class SessionUtil {

static final String SF_HEADER_SERVICE_NAME = "X-Snowflake-Service";

public static final String SF_HEADER_CLIENT_APP_ID = "CLIENT_APP_ID";

public static final String SF_HEADER_CLIENT_APP_VERSION = "CLIENT_APP_VERSION";

private static final String SF_DRIVER_NAME = "Snowflake";

private static final String SF_DRIVER_VERSION = SnowflakeDriver.implementVersion;

private static final String ID_TOKEN_AUTHENTICATOR = "ID_TOKEN";

private static final String NO_QUERY_ID = "";
Expand Down Expand Up @@ -592,6 +601,10 @@ private static SFLoginOutput newSession(
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// Add headers for driver name and version
postRequest.addHeader(SF_HEADER_CLIENT_APP_ID, SF_DRIVER_NAME);
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved
postRequest.addHeader(SF_HEADER_CLIENT_APP_VERSION, SF_DRIVER_VERSION);

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
input.setContentType("application/json");
Expand Down Expand Up @@ -1614,4 +1627,23 @@ public static String generateJWTToken(
privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName);
return s.issueJwtToken();
}

/**
* Helper method to check if the request path is a login/auth request to use for retry strategy.
*
* @param request the post request
* @return true if this is a login/auth request, false otherwise
*/
public static boolean isLoginRequest(HttpRequestBase request) {
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved
URI requestURI = request.getURI();
String requestPath = requestURI.getPath();
if (requestPath != null) {
if (requestPath.equals(SF_PATH_LOGIN_REQUEST)
|| requestPath.equals(SF_PATH_AUTHENTICATOR_REQUEST)
|| requestPath.equals(SF_PATH_TOKEN_REQUEST)) {
return true;
}
}
return false;
}
}
27 changes: 25 additions & 2 deletions src/main/java/net/snowflake/client/jdbc/RestRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public class RestRequest {
// min backoff in milli before we retry due to transient issues
private static final long minBackoffInMilli = 1000;

// min backoff in milli for login/auth requests before we retry
private static final long minLoginBackoffInMilli = 4000;
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved

// max backoff in milli before we retry due to transient issues
// we double the backoff after each retry till we reach the max backoff
private static final long maxBackoffInMilli = 16000;
Expand Down Expand Up @@ -132,14 +135,22 @@ public static CloseableHttpResponse execute(
// when there are transient network/GS issues.
long startTimePerRequest = startTime;

// Used to indicate that this is a login/auth request and will be using the new retry strategy.
boolean isLoginRequest = SessionUtil.isLoginRequest(httpRequest);

// total elapsed time due to transient issues.
long elapsedMilliForTransientIssues = 0;

// retry timeout (ms)
long retryTimeoutInMilliseconds = retryTimeout * 1000;

// amount of time to wait for backing off before retry
long backoffInMilli = minBackoffInMilli;
long backoffInMilli;
if (isLoginRequest) {
backoffInMilli = minLoginBackoffInMilli;
} else {
backoffInMilli = minBackoffInMilli;
}

// auth timeout (ms)
long authTimeoutInMilli = authTimeout * 1000;
Expand Down Expand Up @@ -417,7 +428,19 @@ public static CloseableHttpResponse execute(
logger.debug("sleeping in {}(ms)", backoffInMilli);
Thread.sleep(backoffInMilli);
elapsedMilliForTransientIssues += backoffInMilli;
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
if (isLoginRequest) {
backoffInMilli = backoff.getJitterForLogin(backoffInMilli);
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved
} else {
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
}
if (retryTimeoutInMilliseconds > 0
&& (elapsedMilliForTransientIssues + backoffInMilli) > retryTimeoutInMilliseconds) {
// If the timeout will be reached before the next backoff, just use the remaining
// time.
backoffInMilli =
Math.min(
backoffInMilli, retryTimeoutInMilliseconds - elapsedMilliForTransientIssues);
}
} catch (InterruptedException ex1) {
logger.debug("Backoff sleep before retrying login got interrupted", false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package net.snowflake.client.util;

import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

/**
Expand All @@ -19,4 +20,15 @@ public DecorrelatedJitterBackoff(long base, long cap) {
public long nextSleepTime(long sleep) {
return Math.min(cap, ThreadLocalRandom.current().nextLong(base, sleep * 3));
}

public long getJitterForLogin(long currentTime) {
int mulitplicationFactor = chooseRandom(-1, 1);
long jitter = (long) (mulitplicationFactor * currentTime * 0.5);
return jitter;
}

private int chooseRandom(int min, int max) {
Random random = new Random();
return random.nextInt(max - min) + min;
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved
}
}
46 changes: 45 additions & 1 deletion src/test/java/net/snowflake/client/core/SessionUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
package net.snowflake.client.core;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;

import com.fasterxml.jackson.databind.node.BooleanNode;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.snowflake.client.jdbc.MockConnectionTest;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.junit.Test;

public class SessionUtilTest {
Expand Down Expand Up @@ -83,4 +89,42 @@ public void testConvertSystemPropertyToIntValue() {
HttpUtil.JDBC_MAX_CONNECTIONS_PER_ROUTE_PROPERTY,
HttpUtil.DEFAULT_MAX_CONNECTIONS_PER_ROUTE));
}

@Test
public void testIsLoginRequest() {
List<String> testCases = new ArrayList<String>();
testCases.add("/session/v1/login-request");
testCases.add("/session/token-request");
testCases.add("/session/authenticator-request");

for (String testCase : testCases) {
try {
URIBuilder uriBuilder = new URIBuilder("https://test.snowflakecomputing.com");
uriBuilder.setPath(testCase);
URI uri = uriBuilder.build();
HttpPost postRequest = new HttpPost(uri);
assertTrue(SessionUtil.isLoginRequest(postRequest));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
}

@Test
public void testIsLoginRequestInvalidURIPath() {
List<String> testCases = new ArrayList<String>();
testCases.add("/session/not-a-real-path");

for (String testCase : testCases) {
try {
URIBuilder uriBuilder = new URIBuilder("https://test.snowflakecomputing.com");
uriBuilder.setPath(testCase);
URI uri = uriBuilder.build();
HttpPost postRequest = new HttpPost(uri);
assertFalse(SessionUtil.isLoginRequest(postRequest));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ public void testWrongHostNameTimeout() throws InterruptedException {
equalTo(ErrorCode.NETWORK_ERROR.getMessageCode()));

conEnd = System.currentTimeMillis();
assertThat("Login time out not taking effective", conEnd - connStart < 60000);
assertThat("Login time out not taking effective", conEnd - connStart < 300000);

Thread.sleep(WAIT_FOR_TELEMETRY_REPORT_IN_MILLISECS);
if (TelemetryService.getInstance().isDeploymentEnabled()) {
Expand Down Expand Up @@ -595,7 +595,7 @@ public void testHttpsLoginTimeoutWithSSL() throws InterruptedException {
equalTo(ErrorCode.NETWORK_ERROR.getMessageCode()));

conEnd = System.currentTimeMillis();
assertThat("Login time out not taking effective", conEnd - connStart < 60000);
assertThat("Login time out not taking effective", conEnd - connStart < 300000);
Thread.sleep(WAIT_FOR_TELEMETRY_REPORT_IN_MILLISECS);
if (TelemetryService.getInstance().isDeploymentEnabled()) {
assertThat(
Expand Down
80 changes: 79 additions & 1 deletion src/test/java/net/snowflake/client/jdbc/RestRequestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/** RestRequest unit tests. */
public class RestRequestTest {

static final int DEFAULT_CONNECTION_TIMEOUT = 60000;
static final int DEFAULT_CONNECTION_TIMEOUT = 300000;
static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT = 300000; // ms

private CloseableHttpResponse retryResponse() {
Expand All @@ -42,6 +42,16 @@ private CloseableHttpResponse retryResponse() {
return retryResponse;
}

private CloseableHttpResponse retryLoginResponse() {
StatusLine retryStatusLine = mock(StatusLine.class);
when(retryStatusLine.getStatusCode()).thenReturn(429);

CloseableHttpResponse retryResponse = mock(CloseableHttpResponse.class);
when(retryResponse.getStatusLine()).thenReturn(retryStatusLine);

return retryResponse;
}

private CloseableHttpResponse successResponse() {
StatusLine successStatusLine = mock(StatusLine.class);
when(successStatusLine.getStatusCode()).thenReturn(200);
Expand Down Expand Up @@ -457,6 +467,74 @@ public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwabl
}
}

@Test(expected = SnowflakeSQLException.class)
public void testLoginMaxRetries() throws IOException, SnowflakeSQLException {
boolean telemetryEnabled = TelemetryService.getInstance().isEnabled();

CloseableHttpClient client = mock(CloseableHttpClient.class);
when(client.execute(any(HttpUriRequest.class)))
.thenAnswer(
new Answer<CloseableHttpResponse>() {
int callCount = 0;

@Override
public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwable {
callCount += 1;
if (callCount >= 4) {
return retryLoginResponse();
} else {
return socketTimeoutResponse();
}
}
});

try {
TelemetryService.disable();
execute(client, "/session/v1/login-request", 0, 0, 0, true, false, 1);
fail("testMaxRetries");
} finally {
if (telemetryEnabled) {
TelemetryService.enable();
} else {
TelemetryService.disable();
}
}
}

@Test(expected = SnowflakeSQLException.class)
public void testLoginTimeout() throws IOException, SnowflakeSQLException {
boolean telemetryEnabled = TelemetryService.getInstance().isEnabled();

CloseableHttpClient client = mock(CloseableHttpClient.class);
when(client.execute(any(HttpUriRequest.class)))
.thenAnswer(
new Answer<CloseableHttpResponse>() {
int callCount = 0;

@Override
public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwable {
callCount += 1;
if (callCount >= 4) {
return retryLoginResponse();
} else {
return socketTimeoutResponse();
}
}
});

try {
TelemetryService.disable();
execute(client, "/session/v1/login-request", 1, 0, 0, true, false, 10);
fail("testMaxRetries");
} finally {
if (telemetryEnabled) {
TelemetryService.enable();
} else {
TelemetryService.disable();
}
}
}

@Test
public void testMaxRetriesWithSuccessfulResponse() throws IOException {
boolean telemetryEnabled = TelemetryService.getInstance().isEnabled();
Expand Down