Skip to content

Commit

Permalink
Added unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
andreachild committed Nov 4, 2024
1 parent 2fa6c47 commit 13cada7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
14 changes: 7 additions & 7 deletions gremlin-driver/src/main/java/examples/Connections.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ private static void withRemote() throws Exception {
// See reference/#gremlin-java-configuration for full list of configurations
private static void withCluster() throws Exception {
Cluster cluster = Cluster.build("localhost").
maxConnectionPoolSize(8).
path("/gremlin").
port(8182).
serializer(new GraphBinaryMessageSerializerV4()).
create();
maxConnectionPoolSize(8).
path("/gremlin").
port(8182).
serializer(new GraphBinaryMessageSerializerV4()).
create();
GraphTraversalSource g = traversal().withRemote(DriverRemoteConnection.using(cluster, "g"));

g.addV().iterate();
Expand All @@ -96,8 +96,8 @@ private static void withSerializer() throws Exception {
TypeSerializerRegistry typeSerializerRegistry = TypeSerializerRegistry.build().addRegistry(registry).create();
MessageSerializer serializer = new GraphBinaryMessageSerializerV4(typeSerializerRegistry);
Cluster cluster = Cluster.build("localhost").
serializer(serializer).
create();
serializer(serializer).
create();
Client client = cluster.connect();
GraphTraversalSource g = traversal().withRemote(DriverRemoteConnection.using(client, "g"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ public Sigv4(final String regionName, final AwsCredentialsProvider awsCredential
@Override
public HttpRequest apply(final HttpRequest httpRequest) {
try {
final ContentStreamProvider content = toContentStream(httpRequest);
// Convert Http request into an AWS SDK signable request
final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest);
final AwsCredentials credentials = awsCredentialsProvider.resolveCredentials();
final ContentStreamProvider content = toContentStream(httpRequest);

// Sign the AWS SDK signable request (which internally adds some HTTP headers)
final SignedRequest signed = aws4Signer.sign(r -> r.identity(credentials)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,30 @@
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;

import static org.hamcrest.CoreMatchers.allOf;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.when;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.AUTHORIZATION;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.HOST;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_CONTENT_SHA256;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DATE;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN;

public class Sigv4Test {
public static final String REGION = "us-west-2";
public static final String SERVICE_NAME = "service-name";
private static final String REGION = "us-west-2";
private static final String SERVICE_NAME = "service-name";
private static final byte[] REQUEST_BODY = "{\"gremlin\":\"2-1\"}".getBytes(StandardCharsets.UTF_8);
private static final String HOST = "localhost";
private static final String URI_WITH_QUERY_PARAMS = "http://" + HOST + ":8182?a=1&b=2";
private static final String KEY = "foo";
private static final String SECRET = "bar";
@Rule
public MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock
Expand All @@ -44,7 +51,7 @@ public void setup() {

@Test
public void shouldAddSignedHeaders() throws Exception {
when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create("foo", "bar"));
when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET));
HttpRequest httpRequest = createRequest();
sigv4.apply(httpRequest);
validateExpectedHeaders(httpRequest);
Expand All @@ -53,25 +60,47 @@ public void shouldAddSignedHeaders() throws Exception {
@Test
public void shouldAddSignedHeadersAndSessionToken() throws Exception {
String sessionToken = "foobarz";
when(credentialsProvider.resolveCredentials()).thenReturn(AwsSessionCredentials.create("foo", "bar", sessionToken));
when(credentialsProvider.resolveCredentials()).thenReturn(AwsSessionCredentials.create(KEY, SECRET, sessionToken));
HttpRequest httpRequest = createRequest();
sigv4.apply(httpRequest);
validateExpectedHeaders(httpRequest);
assertEquals(sessionToken, httpRequest.headers().get(X_AMZ_SECURITY_TOKEN));
}

@Test
public void shouldThrowIfRequestNonByteArray() {
Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class,
() -> sigv4.apply(new HttpRequest(new HashMap<>(), "not byte array", new URI(URI_WITH_QUERY_PARAMS))));
assertTrue(ex.getMessage().contains("Expected byte[] in HttpRequest body"));
}

@Test
public void shouldThrowIfNoRequestMethod() {
Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class,
() -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS), null)));
assertTrue(ex.getMessage().contains("The request method must not be null"));
}

@Test
public void shouldThrowIfNoRequestURI() {
Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class,
() -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, null)));
assertTrue(ex.getMessage().contains("The request URI must not be null"));
}

private HttpRequest createRequest() throws URISyntaxException {
HttpRequest httpRequest = new HttpRequest(new HashMap<>(), "{\"gremlin\":\"2-1\"}".getBytes(StandardCharsets.UTF_8), new URI("http://localhost:8182?a=1&b=2"));
HttpRequest httpRequest = new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS));
httpRequest.headers().put("Content-Type", "application/json");
httpRequest.headers().put("Host", "this-should-be-ignored-for-signed-host-header");
return httpRequest;
}

private void validateExpectedHeaders(HttpRequest httpRequest) {
assertEquals("localhost", httpRequest.headers().get(HOST));
assertEquals(HOST, httpRequest.headers().get(SignerConstant.HOST));
assertNotNull(httpRequest.headers().get(X_AMZ_DATE));
assertNotNull(httpRequest.headers().get(X_AMZ_CONTENT_SHA256));
assertThat(httpRequest.headers().get(AUTHORIZATION),
allOf(startsWith("AWS4-HMAC-SHA256 Credential=foo"),
allOf(startsWith("AWS4-HMAC-SHA256 Credential=" + KEY),
containsString("/" + REGION + "/service-name/aws4_request"),
containsString("Signature=")));
}
Expand Down

0 comments on commit 13cada7

Please sign in to comment.