Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878076: Retry Strategy for JDBC #1548

Merged
merged 10 commits into from
Nov 6, 2023
18 changes: 17 additions & 1 deletion src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class SFLoginInput {
private String oktaUserName;
private String accountName;
private int loginTimeout = -1; // default is invalid
private int retryTimeout = 300;
private int authTimeout = 0;
private String userName;
private String password;
Expand Down Expand Up @@ -139,8 +140,23 @@ int getLoginTimeout() {
return loginTimeout;
}

// We want to choose the smaller of the two values between retryTimeout and loginTimeout for the
// new retry strategy.
SFLoginInput setLoginTimeout(int loginTimeout) {
this.loginTimeout = loginTimeout;
if (loginTimeout > retryTimeout && retryTimeout != 0) {
this.loginTimeout = retryTimeout;
} else {
this.loginTimeout = loginTimeout;
}
return this;
}

int getRetryTimeout() {
return retryTimeout;
}

SFLoginInput setRetryTimeout(int retryTimeout) {
this.retryTimeout = retryTimeout;
return this;
}

Expand Down
26 changes: 23 additions & 3 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public class SFSession extends SFBaseSession {
* Amount of seconds a user is willing to tolerate for establishing the connection with database.
* In our case, it means the first login request to get authorization token.
*
* <p>Default:60 seconds
* <p>Default:300 seconds
*/
private int loginTimeout = 60;
private int loginTimeout = 300;
/**
* Amount of milliseconds a user is willing to tolerate for network related issues (e.g. HTTP
* 503/504) or database transient issues (e.g. GS not responding)
Expand Down Expand Up @@ -108,6 +108,13 @@ public class SFSession extends SFBaseSession {
// Max retries for outgoing http requests.
private int maxHttpRetries = 7;

/**
* Retry timeout in seconds. Cannot be less than 300.
*
* <p>Default: 300
*/
private int retryTimeout = 300;

// This constructor is used only by tests with no real connection.
// For real connections, the other constructor is always used.
@VisibleForTesting
Expand Down Expand Up @@ -369,6 +376,15 @@ public void addSFSessionProperty(String propertyName, Object propertyValue) thro
}
break;

case RETRY_TIMEOUT:
sfc-gh-ext-simba-jf marked this conversation as resolved.
Show resolved Hide resolved
if (propertyValue != null) {
int timeoutValue = (Integer) propertyValue;
if (timeoutValue >= 300 || timeoutValue == 0) {
retryTimeout = timeoutValue;
}
}
break;

default:
break;
}
Expand Down Expand Up @@ -405,7 +421,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
"input: server={}, account={}, user={}, password={}, role={}, database={}, schema={},"
+ " warehouse={}, validate_default_parameters={}, authenticator={}, ocsp_mode={},"
+ " passcode_in_password={}, passcode={}, private_key={}, disable_socks_proxy={},"
+ " application={}, app_id={}, app_version={}, login_timeout={}, network_timeout={},"
+ " application={}, app_id={}, app_version={}, login_timeout={}, retry_timeout={}, network_timeout={},"
+ " query_timeout={}, tracing={}, private_key_file={}, private_key_file_pwd={}."
+ " session_parameters: client_store_temporary_credential={}, gzip_disabled={}",
connectionPropertiesMap.get(SFSessionProperty.SERVER_URL),
Expand Down Expand Up @@ -433,6 +449,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
connectionPropertiesMap.get(SFSessionProperty.APP_ID),
connectionPropertiesMap.get(SFSessionProperty.APP_VERSION),
connectionPropertiesMap.get(SFSessionProperty.LOGIN_TIMEOUT),
connectionPropertiesMap.get(SFSessionProperty.RETRY_TIMEOUT),
connectionPropertiesMap.get(SFSessionProperty.NETWORK_TIMEOUT),
connectionPropertiesMap.get(SFSessionProperty.QUERY_TIMEOUT),
connectionPropertiesMap.get(SFSessionProperty.TRACING),
Expand Down Expand Up @@ -471,6 +488,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setOKTAUserName((String) connectionPropertiesMap.get(SFSessionProperty.OKTA_USERNAME))
.setAccountName((String) connectionPropertiesMap.get(SFSessionProperty.ACCOUNT))
.setLoginTimeout(loginTimeout)
.setRetryTimeout(retryTimeout)
.setAuthTimeout(authTimeout)
.setUserName((String) connectionPropertiesMap.get(SFSessionProperty.USER))
.setPassword((String) connectionPropertiesMap.get(SFSessionProperty.PASSWORD))
Expand Down Expand Up @@ -652,6 +670,7 @@ synchronized void renewSession(String prevSessionToken)
.setIdToken(idToken)
.setMfaToken(mfaToken)
.setLoginTimeout(loginTimeout)
.setRetryTimeout(retryTimeout)
.setDatabaseName(getDatabase())
.setSchemaName(getSchema())
.setRole(getRole())
Expand Down Expand Up @@ -696,6 +715,7 @@ public void close() throws SFException, SnowflakeSQLException {
.setServerUrl(getServerUrl())
.setSessionToken(sessionToken)
.setLoginTimeout(loginTimeout)
.setRetryTimeout(retryTimeout)
.setOCSPMode(getOCSPMode())
.setHttpClientSettingsKey(getHttpClientKey());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ public enum SFSessionProperty {

ENABLE_PUT_GET("enablePutGet", false, Boolean.class),

PUT_GET_MAX_RETRIES("putGetMaxRetries", false, Integer.class);
PUT_GET_MAX_RETRIES("putGetMaxRetries", false, Integer.class),

RETRY_TIMEOUT("retryTimeout", false, Integer.class);

// property key in string
private String propertyKey;
Expand Down
37 changes: 37 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.message.BasicHeader;
Expand Down Expand Up @@ -112,6 +113,10 @@ public class SessionUtil {

static final String SF_HEADER_SERVICE_NAME = "X-Snowflake-Service";

public static final String SF_HEADER_CLIENT_APP_ID = "CLIENT_APP_ID";

public static final String SF_HEADER_CLIENT_APP_VERSION = "CLIENT_APP_VERSION";

private static final String ID_TOKEN_AUTHENTICATOR = "ID_TOKEN";

private static final String NO_QUERY_ID = "";
Expand Down Expand Up @@ -592,6 +597,10 @@ private static SFLoginOutput newSession(
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// Add headers for driver name and version
postRequest.addHeader(SF_HEADER_CLIENT_APP_ID, loginInput.getAppId());
postRequest.addHeader(SF_HEADER_CLIENT_APP_VERSION, loginInput.getAppVersion());

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
input.setContentType("application/json");
Expand All @@ -609,6 +618,7 @@ private static SFLoginOutput newSession(
setServiceNameHeader(loginInput, postRequest);

String theString = null;

int leftRetryTimeout = loginInput.getLoginTimeout();
int leftsocketTimeout = loginInput.getSocketTimeout();
int retryCount = 0;
Expand Down Expand Up @@ -902,6 +912,10 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT

postRequest = new HttpPost(uriBuilder.build());

// Add headers for driver name and version
postRequest.addHeader(SF_HEADER_CLIENT_APP_ID, loginInput.getAppId());
postRequest.addHeader(SF_HEADER_CLIENT_APP_VERSION, loginInput.getAppVersion());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());
Expand Down Expand Up @@ -1259,6 +1273,10 @@ private static JsonNode federatedFlowStep1(SFLoginInput loginInput) throws Snowf
postRequest.setEntity(input);
postRequest.addHeader("accept", "application/json");

// Add headers for driver name and version
postRequest.addHeader(SF_HEADER_CLIENT_APP_ID, loginInput.getAppId());
postRequest.addHeader(SF_HEADER_CLIENT_APP_VERSION, loginInput.getAppVersion());

final String gsResponse =
HttpUtil.executeGeneralRequest(
postRequest,
Expand Down Expand Up @@ -1614,4 +1632,23 @@ public static String generateJWTToken(
privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName);
return s.issueJwtToken();
}

/**
* Helper method to check if the request path is a login/auth request to use for retry strategy.
*
* @param request the post request
* @return true if this is a login/auth request, false otherwise
*/
public static boolean isNewRetryStrategyRequest(HttpRequestBase request) {
URI requestURI = request.getURI();
String requestPath = requestURI.getPath();
if (requestPath != null) {
if (requestPath.equals(SF_PATH_LOGIN_REQUEST)
|| requestPath.equals(SF_PATH_AUTHENTICATOR_REQUEST)
|| requestPath.equals(SF_PATH_TOKEN_REQUEST)) {
return true;
}
}
return false;
}
}
22 changes: 21 additions & 1 deletion src/main/java/net/snowflake/client/jdbc/RestRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ public static CloseableHttpResponse execute(
// when there are transient network/GS issues.
long startTimePerRequest = startTime;

// Used to indicate that this is a login/auth request and will be using the new retry strategy.
boolean isLoginRequest = SessionUtil.isNewRetryStrategyRequest(httpRequest);

// total elapsed time due to transient issues.
long elapsedMilliForTransientIssues = 0;

Expand Down Expand Up @@ -417,7 +420,24 @@ public static CloseableHttpResponse execute(
logger.debug("sleeping in {}(ms)", backoffInMilli);
Thread.sleep(backoffInMilli);
elapsedMilliForTransientIssues += backoffInMilli;
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
if (isLoginRequest) {
long jitteredBackoffInMilli = backoff.getJitterForLogin(backoffInMilli);
backoffInMilli =
(long)
backoff.chooseRandom(
jitteredBackoffInMilli + backoffInMilli,
Math.pow(2, retryCount) + jitteredBackoffInMilli);
} else {
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
}
if (retryTimeoutInMilliseconds > 0
&& (elapsedMilliForTransientIssues + backoffInMilli) > retryTimeoutInMilliseconds) {
// If the timeout will be reached before the next backoff, just use the remaining
// time.
backoffInMilli =
Math.min(
backoffInMilli, retryTimeoutInMilliseconds - elapsedMilliForTransientIssues);
}
} catch (InterruptedException ex1) {
logger.debug("Backoff sleep before retrying login got interrupted", false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,14 @@ public DecorrelatedJitterBackoff(long base, long cap) {
public long nextSleepTime(long sleep) {
return Math.min(cap, ThreadLocalRandom.current().nextLong(base, sleep * 3));
}

public long getJitterForLogin(long currentTime) {
double multiplicationFactor = chooseRandom(-1, 1);
long jitter = (long) (multiplicationFactor * currentTime * 0.5);
return jitter;
}

public double chooseRandom(double min, double max) {
return min + (Math.random() * (max - min));
}
}
46 changes: 45 additions & 1 deletion src/test/java/net/snowflake/client/core/SessionUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
package net.snowflake.client.core;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;

import com.fasterxml.jackson.databind.node.BooleanNode;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.snowflake.client.jdbc.MockConnectionTest;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.junit.Test;

public class SessionUtilTest {
Expand Down Expand Up @@ -83,4 +89,42 @@ public void testConvertSystemPropertyToIntValue() {
HttpUtil.JDBC_MAX_CONNECTIONS_PER_ROUTE_PROPERTY,
HttpUtil.DEFAULT_MAX_CONNECTIONS_PER_ROUTE));
}

@Test
public void testIsLoginRequest() {
List<String> testCases = new ArrayList<String>();
testCases.add("/session/v1/login-request");
testCases.add("/session/token-request");
testCases.add("/session/authenticator-request");

for (String testCase : testCases) {
try {
URIBuilder uriBuilder = new URIBuilder("https://test.snowflakecomputing.com");
uriBuilder.setPath(testCase);
URI uri = uriBuilder.build();
HttpPost postRequest = new HttpPost(uri);
assertTrue(SessionUtil.isNewRetryStrategyRequest(postRequest));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
}

@Test
public void testIsLoginRequestInvalidURIPath() {
List<String> testCases = new ArrayList<String>();
testCases.add("/session/not-a-real-path");

for (String testCase : testCases) {
try {
URIBuilder uriBuilder = new URIBuilder("https://test.snowflakecomputing.com");
uriBuilder.setPath(testCase);
URI uri = uriBuilder.build();
HttpPost postRequest = new HttpPost(uri);
assertFalse(SessionUtil.isNewRetryStrategyRequest(postRequest));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ public void testWrongHostNameTimeout() throws InterruptedException {
equalTo(ErrorCode.NETWORK_ERROR.getMessageCode()));

conEnd = System.currentTimeMillis();
assertThat("Login time out not taking effective", conEnd - connStart < 60000);
assertThat("Login time out not taking effective", conEnd - connStart < 300000);

Thread.sleep(WAIT_FOR_TELEMETRY_REPORT_IN_MILLISECS);
if (TelemetryService.getInstance().isDeploymentEnabled()) {
Expand Down Expand Up @@ -595,7 +595,7 @@ public void testHttpsLoginTimeoutWithSSL() throws InterruptedException {
equalTo(ErrorCode.NETWORK_ERROR.getMessageCode()));

conEnd = System.currentTimeMillis();
assertThat("Login time out not taking effective", conEnd - connStart < 60000);
assertThat("Login time out not taking effective", conEnd - connStart < 300000);
Thread.sleep(WAIT_FOR_TELEMETRY_REPORT_IN_MILLISECS);
if (TelemetryService.getInstance().isDeploymentEnabled()) {
assertThat(
Expand Down
Loading
Loading