Skip to content

Commit

Permalink
Doing some refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 4, 2024
1 parent 9340557 commit 6761ac7
Show file tree
Hide file tree
Showing 15 changed files with 562 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
Expand Down Expand Up @@ -185,7 +185,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
ExplainResponseProcessor.TYPE,
ExplanationResponseProcessor.TYPE,
new ProcessorExplainPublisherFactory()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
SearchShard searchShard = new SearchShard(
searchShardTarget.getIndex(),
searchShardTarget.getShardId().id(),
searchShardTarget.getNodeId()
);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -24,13 +24,13 @@
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR;
import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR;

@Getter
@AllArgsConstructor
public class ExplainResponseProcessor implements SearchResponseProcessor {
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "explain_response_processor";
public static final String TYPE = "explanation_response_processor";

private final String description;
private final String tag;
Expand All @@ -46,10 +46,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
return response;
}
ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ProcessorExplainDto.ExplanationType, Object> explainPayload = processorExplainDto.getExplainPayload();
ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationResponse.ExplanationType, Object> explainPayload = explanationResponse.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = processorExplainDto.getExplanation();
Explanation processorExplanation = explanationResponse.getExplanation();
if (Objects.isNull(processorExplanation)) {
return response;
}
Expand All @@ -62,7 +62,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.create(searchShardTarget);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
Expand All @@ -73,7 +73,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);

for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.create(searchHit.getShard());
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
Explanation normalizedExplanation = Explanation.match(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -146,13 +146,13 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
}
});

ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder()
ExplanationResponse explanationResponse = ExplanationResponse.builder()
.explanation(topLevelExplanationForTechniques)
.explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto);
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

public record SearchShard(String index, int shardId, String nodeId) {

public static SearchShard create(SearchShardTarget searchShardTarget) {
public static SearchShard createSearchShard(SearchShardTarget searchShardTarget) {
return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
@AllArgsConstructor
@Builder
@Getter
/**
* DTO class to hold explain details for normalization and combination
*/
public class CombinedExplainDetails {
private ExplainDetails normalizationExplain;
private ExplainDetails combinationExplain;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.opensearch.neuralsearch.processor.SearchShard;

/**
* Data class to store docId and search shard for a query.
* DTO class to store docId and search shard for a query.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param docId
* @param searchShard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
* @param value
* @param description
*/
public record ExplainDetails(float value, String description, int docId) {

public record ExplainDetails(int docId, float value, String description) {
public ExplainDetails(float value, String description) {
this(value, description, -1);
this(-1, value, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ public static ExplainDetails getScoreCombinationExplainDetailsForDocument(
) {
float combinedScore = combinedNormalizedScoresByDocId.get(docId);
return new ExplainDetails(
docId,
combinedScore,
String.format(
Locale.ROOT,
"normalized scores: %s combined to a final score: %s",
Arrays.toString(normalizedScoresPerDoc),
combinedScore
),
docId
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@AllArgsConstructor
@Builder
@Getter
public class ProcessorExplainDto {
public class ExplanationResponse {
Explanation explanation;
Map<ExplanationType, Object> explainPayload;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

Expand All @@ -21,6 +21,6 @@ public SearchResponseProcessor create(
Map<String, Object> config,
Processor.PipelineContext pipelineContext
) throws Exception {
return new ExplainResponseProcessor(description, tag, ignoreFailure);
return new ExplanationResponseProcessor(description, tag, ignoreFailure);
}
}

This file was deleted.

Loading

0 comments on commit 6761ac7

Please sign in to comment.