Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Kondaka <[email protected]>
  • Loading branch information
kkondaka committed Nov 16, 2024
1 parent dfe8f0b commit 0c02d55
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.dataprepper.plugins.lambda.common;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;

import org.checkerframework.common.reflection.qual.Invoke;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory;
Expand Down Expand Up @@ -36,39 +38,27 @@
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

public class LambdaCommonHandler {
private final Logger LOG;
private final LambdaAsyncClient lambdaAsyncClient;
private final String functionName;
private final String invocationType;
private final LambdaCommonConfig config;
private final String whenCondition;
BufferFactory bufferFactory;
final InputCodec responseCodec;
final ExpressionEvaluator expressionEvaluator;
JsonOutputCodecConfig jsonOutputCodecConfig;
private final int maxEvents;
private final ByteCount maxBytes;
private final Duration maxCollectionDuration;
private final ResponseEventHandlingStrategy responseStrategy;

public LambdaCommonHandler(final Logger log,
final LambdaAsyncClient lambdaAsyncClient,
final JsonOutputCodecConfig jsonOutputCodecConfig,
final InputCodec responseCodec,
final String whenCondition,
final ExpressionEvaluator expressionEvaluator,
final ResponseEventHandlingStrategy responseStrategy,
final LambdaCommonConfig lambdaCommonConfig) {
this.LOG = log;
this.lambdaAsyncClient = lambdaAsyncClient;
this.responseStrategy = responseStrategy;
this.config = lambdaCommonConfig;
this.jsonOutputCodecConfig = jsonOutputCodecConfig;
this.whenCondition = whenCondition;
this.responseCodec = responseCodec;
this.expressionEvaluator = expressionEvaluator;
this.functionName = config.getFunctionName();
this.invocationType = config.getInvocationType().getAwsLambdaValue();
maxEvents = lambdaCommonConfig.getBatchOptions().getThresholdOptions().getEventCount();
Expand All @@ -77,13 +67,6 @@ public LambdaCommonHandler(final Logger log,
bufferFactory = new InMemoryBufferFactory();
}

public LambdaCommonHandler(final Logger log,
final LambdaAsyncClient lambdaAsyncClient,
final JsonOutputCodecConfig jsonOutputCodecConfig,
final LambdaCommonConfig lambdaCommonConfig) {
this(log, lambdaAsyncClient, jsonOutputCodecConfig, null, null, null, null, lambdaCommonConfig);
}

public Buffer createBuffer(BufferFactory bufferFactory) {
try {
LOG.debug("Resetting buffer");
Expand Down Expand Up @@ -116,23 +99,15 @@ public void waitForFutures(List<CompletableFuture<Void>> futureList) {
}

public List<Record<Event>> sendRecords(Collection<Record<Event>> records,
BiConsumer<Buffer, List<Record<Event>>> successHandler, BiConsumer<Buffer, List<Record<Event>>> failureHandler) {
BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler,
BiConsumer<Buffer, List<Record<Event>>> failureHandler) {
List<Record<Event>> resultRecords = Collections.synchronizedList(new ArrayList());
boolean createNewBuffer = true;
Buffer currentBufferPerBatch = null;
OutputCodec requestCodec = null;
List futureList = new ArrayList<>();
for (Record<Event> record : records) {
final Event event = record.getData();

// If the condition is false, add the event to resultRecords as-is
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
synchronized(resultRecords) {
resultRecords.add(record);
}
continue;
}

Event event = record.getData();
try {
if (createNewBuffer) {
currentBufferPerBatch = createBuffer(bufferFactory);
Expand All @@ -157,8 +132,9 @@ public List<Record<Event>> sendRecords(Collection<Record<Event>> records,
}

boolean flushToLambdaIfNeeded(List<Record<Event>> resultRecords, Buffer currentBufferPerBatch,
OutputCodec requestCodec, List futureList, BiConsumer<Buffer, List<Record<Event>>> successHandler,
BiConsumer<Buffer, List<Record<Event>>> failureHandler, boolean forceFlush) {
OutputCodec requestCodec, List futureList,
BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler,
BiConsumer<Buffer, List<Record<Event>>> failureHandler, boolean forceFlush) {

LOG.debug("currentBufferPerBatchEventCount:{}, maxEvents:{}, maxBytes:{}, " +
"maxCollectionDuration:{}, forceFlush:{} ", currentBufferPerBatch.getEventCount(),
Expand Down Expand Up @@ -201,8 +177,9 @@ boolean flushToLambdaIfNeeded(List<Record<Event>> resultRecords, Buffer currentB
}

private void handleLambdaResponse(List<Record<Event>> resultRecords, Buffer flushedBuffer,
int eventCount, InvokeResponse response, BiConsumer<Buffer, List<Record<Event>>> successHandler,
BiConsumer<Buffer, List<Record<Event>>> failureHandler) {
int eventCount, InvokeResponse response,
BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler,
BiConsumer<Buffer, List<Record<Event>>> failureHandler) {
boolean success = checkStatusCode(response);
if (success) {
LOG.info("Successfully flushed {} events", eventCount);
Expand All @@ -212,70 +189,16 @@ private void handleLambdaResponse(List<Record<Event>> resultRecords, Buffer flus
Duration latency = flushedBuffer.stopLatencyWatch();
//lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
//totalFlushedEvents += eventCount;

convertLambdaResponseToEvent(resultRecords, response, flushedBuffer, successHandler);
synchronized(resultRecords) {
resultRecords.addAll(successHandler.apply(flushedBuffer, response));
}
//convertLambdaResponseToEvent(resultRecords, response, flushedBuffer, successHandler);
} else {
// Non-2xx status code treated as failure
handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer, resultRecords, failureHandler);
}
}

/*
* Assumption: Lambda always returns json array.
* 1. If response has an array, we assume that we split the individual events.
* 2. If it is not an array, then create one event per response.
*/
void convertLambdaResponseToEvent(final List<Record<Event>> resultRecords, final InvokeResponse lambdaResponse,
Buffer flushedBuffer, BiConsumer<Buffer, List<Record<Event>>> successHandler) {
try {
List<Event> parsedEvents = new ArrayList<>();
List<Record<Event>> originalRecords = flushedBuffer.getRecords();

SdkBytes payload = lambdaResponse.payload();
// Handle null or empty payload
if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events");
// Set metrics
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(0);
} else {
// Set metrics
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(payload.asByteArray().length);

LOG.debug("Response payload:{}", payload.asUtf8String());
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
//Convert to response codec
try {
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
} catch (IOException ex) {
throw new RuntimeException(ex);
}

LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
"FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
flushedBuffer.getSize());
synchronized(resultRecords) {
responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);
successHandler.accept(flushedBuffer, originalRecords);
}
}
} catch (Exception e) {
LOG.error(NOISY, "Error converting Lambda response to Event");
// Metrics update
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(0);
//????? handleFailure(e, flushedBuffer, resultRecords, failureHandler);
}
}

/*
* If one event in the Buffer fails, we consider that the entire
* Batch fails and tag each event in that Batch.
*/
void handleFailure(Throwable e, Buffer flushedBuffer, List<Record<Event>> resultRecords, BiConsumer<Buffer, List<Record<Event>>> failureHandler) {
try {
if (flushedBuffer.getEventCount() > 0) {
Expand All @@ -291,4 +214,5 @@ void handleFailure(Throwable e, Buffer flushedBuffer, List<Record<Event>> result
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,21 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;

@DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class)
public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
Expand Down Expand Up @@ -64,6 +73,7 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Eve
final LambdaProcessorConfig lambdaProcessorConfig;
private final ResponseEventHandlingStrategy responseStrategy;
private final JsonOutputCodecConfig jsonOutputCodecConfig;
private final ThreadLocal<InputCodec> responseCodec;

@DataPrepperPluginConstructor
public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics, final LambdaProcessorConfig lambdaProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) {
Expand All @@ -80,6 +90,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl

tagsOnMatchFailure = lambdaProcessorConfig.getTagsOnMatchFailure();


PluginModel responseCodecConfig = lambdaProcessorConfig.getResponseCodecConfig();

if (responseCodecConfig == null) {
Expand All @@ -89,12 +100,15 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl
codecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), responseCodecConfig.getPluginSettings());
}

responseCodec = ThreadLocal.withInitial(()->pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting));
jsonOutputCodecConfig = new JsonOutputCodecConfig();
jsonOutputCodecConfig.setKeyName(lambdaProcessorConfig.getBatchOptions().getKeyName());

lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(),
lambdaProcessorConfig.getMaxConnectionRetries(), awsCredentialsSupplier, lambdaProcessorConfig.getConnectionTimeout());

lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, lambdaProcessorConfig);

// Select the correct strategy based on the configuration
if (lambdaProcessorConfig.getResponseEventsMatch()) {
this.responseStrategy = new StrictResponseEventHandlingStrategy();
Expand All @@ -111,12 +125,83 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
}
BufferFactory bufferFactory = new InMemoryBufferFactory();
// Setup request codec
List<Record<Event>> resultRecords = new ArrayList<>();
List<Record<Event>> recordsToLambda = new ArrayList<>();
for (Record<Event> record : records) {
final Event event = record.getData();
// If the condition is false, add the event to resultRecords as-is
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
resultRecords.add(record);
continue;
}
recordsToLambda.add(record);
}
resultRecords.addAll(lambdaCommonHandler.sendRecords(recordsToLambda,
(inputBuffer, response)-> {
List<Record<Event>> outputRecords = convertLambdaResponseToEvent(response, inputBuffer);
return outputRecords;
},
(inputBuffer, outputRecords)-> {
addFailureTags(inputBuffer, outputRecords);
})
);
return resultRecords;
}

List<Record<Event>> convertLambdaResponseToEvent(final InvokeResponse lambdaResponse, Buffer flushedBuffer) {
List<Record<Event>> originalRecords = flushedBuffer.getRecords();
try {
List<Event> parsedEvents = new ArrayList<>();

InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting);
lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, responseCodec, whenCondition, expressionEvaluator, responseStrategy, lambdaProcessorConfig);
return lambdaCommonHandler.sendRecords(records, (inputBuffer, resultRecords)->{}, (inputBuffer, resultRecords)->{ addFailureTags(inputBuffer, resultRecords);});

List<Record<Event>> resultRecords = new ArrayList<>();
SdkBytes payload = lambdaResponse.payload();
// Handle null or empty payload
if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events");
// Set metrics
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(0);
} else {
// Set metrics
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(payload.asByteArray().length);

LOG.debug("Response payload:{}", payload.asUtf8String());
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
//Convert to response codec
try {
responseCodec.get().parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
} catch (IOException ex) {
throw new RuntimeException(ex);
}

LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
"FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
flushedBuffer.getSize());
responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);

}
return resultRecords;
} catch (Exception e) {
LOG.error(NOISY, "Error converting Lambda response to Event");
// Metrics update
//requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
//responsePayloadMetric.set(0);
addFailureTags(flushedBuffer, originalRecords);
return originalRecords;
//????? handleFailure(e, flushedBuffer, resultRecords, failureHandler);
}
}

/*
* If one event in the Buffer fails, we consider that the entire
* Batch fails and tag each event in that Batch.
*/

private void addFailureTags(Buffer flushedBuffer, List<Record<Event>> resultRecords) {
// Add failure tags to each event in the batch
for (Record<Event> record : flushedBuffer.getRecords()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ public void output(Collection<Record<Event>> records) {
BufferFactory bufferFactory = new InMemoryBufferFactory();
lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, lambdaSinkConfig);
lambdaCommonHandler.sendRecords(records,
(inputBuffer, resultRecords)->{
(inputBuffer, response)->{
releaseEventHandlesPerBatch(true, inputBuffer);
return null;
},
(inputBuffer, resultRecords)->{
handleFailure(new RuntimeException("failed"), inputBuffer);
Expand Down

0 comments on commit 0c02d55

Please sign in to comment.