Skip to content

Commit

Permalink
refactor customNodeclient
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 15, 2024
1 parent 9a98d18 commit 18294a4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException, I

// extend NodeClient since its execute method is final and mockito does not allow to mock final methods
// we can also use spy to overstep the final methods
NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool);
NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, null, false, detector, threadPool);
NodeClient clientSpy = spy(client);
NodeStateManager nodeStateManager = mock(NodeStateManager.class);
clientUtil = new SecurityClientUtil(nodeStateManager, settings);
Expand Down Expand Up @@ -544,45 +544,11 @@ public void testUpdateTextField() throws IOException, InterruptedException {
testUpdateTemplate(TEXT_FIELD_TYPE);
}

public static NodeClient getCustomNodeClient(
SearchResponse detectorResponse,
SearchResponse userIndexResponse,
AnomalyDetector detector,
ThreadPool pool
) {
return new NodeClient(Settings.EMPTY, pool) {
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
try {
if (action.equals(SearchAction.INSTANCE)) {
assertTrue(request instanceof SearchRequest);
SearchRequest searchRequest = (SearchRequest) request;
if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) {
listener.onResponse((Response) detectorResponse);
} else {
listener.onResponse((Response) userIndexResponse);
}
} else {
GetFieldMappingsResponse response = new GetFieldMappingsResponse(
TestHelpers.createFieldMappings(detector.getIndices().get(0), "timestamp", "date")
);
listener.onResponse((Response) response);
}
} catch (IOException e) {
logger.error("Create field mapping threw an exception", e);
}
}
};
}

public static NodeClient getCustomNodeClient(
SearchResponse detectorResponse,
SearchResponse userIndexResponse,
SearchResponse configInputIndicesResponse,
boolean useConfigInputIndicesResponse,
AnomalyDetector detector,
ThreadPool pool
) {
Expand All @@ -602,11 +568,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
searchCallCount++;
if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) {
listener.onResponse((Response) detectorResponse);
} else if (Arrays.equals(searchRequest.indices(), detector.getIndices().toArray(new String[0]))
} else if (useConfigInputIndicesResponse
&& Arrays.equals(searchRequest.indices(), detector.getIndices().toArray(new String[0]))
&& searchRequest.source().aggregations() == null) {
listener.onResponse((Response) configInputIndicesResponse);
// Call for feature validation occurs on the 3rd call.
} else if (searchCallCount == 3) {
// Call for feature validation occurs on the 3rd call and we want to make sure we supplied a response to the
// previous call.
} else if (searchCallCount == 3 && useConfigInputIndicesResponse) {
// This is the third search call, which should be for featureConfig and we want to replicate something like a
// timeout exception
listener.onFailure(new OpenSearchStatusException("timeout", RestStatus.BAD_REQUEST));
Expand Down Expand Up @@ -638,7 +606,7 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException, Interrupte
when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits));
// extend NodeClient since its execute method is final and mockito does not allow to mock final methods
// we can also use spy to overstep the final methods
NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool);
NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, null, false, detector, threadPool);
NodeClient clientSpy = spy(client);
NodeStateManager nodeStateManager = mock(NodeStateManager.class);
clientUtil = new SecurityClientUtil(nodeStateManager, settings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.ad.indices.ADIndex;
import org.opensearch.ad.indices.ADIndexManagement;
Expand Down Expand Up @@ -151,7 +152,7 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc
// extend NodeClient since its execute method is final and mockito does not allow to mock final methods
// we can also use spy to overstep the final methods
NodeClient client = IndexAnomalyDetectorActionHandlerTests
.getCustomNodeClient(detectorResponse, userIndexResponse, singleEntityDetector, threadPool);
.getCustomNodeClient(detectorResponse, userIndexResponse, null, false, singleEntityDetector, threadPool);

NodeClient clientSpy = spy(client);
NodeStateManager nodeStateManager = mock(NodeStateManager.class);
Expand Down Expand Up @@ -209,7 +210,7 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio
// extend NodeClient since its execute method is final and mockito does not allow to mock final methods
// we can also use spy to overstep the final methods
NodeClient client = IndexAnomalyDetectorActionHandlerTests
.getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool);
.getCustomNodeClient(detectorResponse, userIndexResponse, null, false, detector, threadPool);
NodeClient clientSpy = spy(client);
NodeStateManager nodeStateManager = mock(NodeStateManager.class);
SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings);
Expand Down Expand Up @@ -262,8 +263,7 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimitDuplicateNameFailure
SearchResponse detectorResponse = mock(SearchResponse.class);
when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits));
SearchResponse userIndexResponse = mock(SearchResponse.class);
int userIndexHits = 0;
when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits));
when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(0));
AnomalyDetector singleEntityDetector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true);

SearchResponse configInputIndicesResponse = mock(SearchResponse.class);
Expand All @@ -272,7 +272,7 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimitDuplicateNameFailure
// extend NodeClient since its execute method is final and mockito does not allow to mock final methods
// we can also use spy to overstep the final methods
NodeClient client = IndexAnomalyDetectorActionHandlerTests
.getCustomNodeClient(detectorResponse, userIndexResponse, configInputIndicesResponse, singleEntityDetector, threadPool);
.getCustomNodeClient(detectorResponse, userIndexResponse, configInputIndicesResponse, true, singleEntityDetector, threadPool);

NodeClient clientSpy = spy(client);
NodeStateManager nodeStateManager = mock(NodeStateManager.class);
Expand All @@ -297,17 +297,15 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimitDuplicateNameFailure
clock,
settings
);

final CountDownLatch inProgressLatch = new CountDownLatch(1);
handler.start(ActionListener.wrap(r -> {
fail("Should not reach here.");
inProgressLatch.countDown();
}, e -> {
PlainActionFuture<ValidateConfigResponse> future = PlainActionFuture.newFuture();
handler.start(future);
try {
future.actionGet(100, TimeUnit.SECONDS);
fail("should not reach here");
} catch (Exception e) {
assertTrue(e instanceof TimeSeriesException);
assertTrue(e.getMessage().contains("Cannot create anomaly detector with name"));
inProgressLatch.countDown();
}));
assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));
}
verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any());
}
}

0 comments on commit 18294a4

Please sign in to comment.