Skip to content

Commit

Permalink
Rename some classes
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 4, 2024
1 parent a19de09 commit 0f3813f
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -24,7 +24,7 @@
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR;
import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;

/**
* Processor to add explanation details to search response
Expand All @@ -40,19 +40,21 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor {
private final boolean ignoreFailure;

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
return processResponse(request, response, null);
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
if (Objects.isNull(requestContext)
|| (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))
|| requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) {
return response;
}
ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationResponse.ExplanationType, Object> explainPayload = explanationResponse.getExplainPayload();
ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = explanationResponse.getExplanation();
Explanation processorExplanation = explanationPayload.getExplanation();
if (Objects.isNull(processorExplanation)) {
return response;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand All @@ -42,7 +42,7 @@

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore;
import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;

/**
Expand Down Expand Up @@ -123,36 +123,36 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<

Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);

Map<DocIdAtSearchShard, ExplainDetails> normalizationExplain = scoreNormalizer.explain(
Map<DocIdAtSearchShard, ExplainationDetails> normalizationExplain = scoreNormalizer.explain(
queryTopDocs,
(ExplainableTechnique) request.getNormalizationTechnique()
);
Map<SearchShard, List<ExplainDetails>> combinationExplain = scoreCombiner.explain(
Map<SearchShard, List<ExplainationDetails>> combinationExplain = scoreCombiner.explain(
queryTopDocs,
request.getCombinationTechnique(),
sortForQuery
);
Map<SearchShard, List<CombinedExplainDetails>> combinedExplain = new HashMap<>();

combinationExplain.forEach((searchShard, explainDetails) -> {
for (ExplainDetails explainDetail : explainDetails) {
for (ExplainationDetails explainDetail : explainDetails) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard);
ExplainDetails normalizedExplainDetails = normalizationExplain.get(docIdAtSearchShard);
ExplainationDetails normalizedExplainationDetails = normalizationExplain.get(docIdAtSearchShard);
CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder()
.normalizationExplain(normalizedExplainDetails)
.normalizationExplain(normalizedExplainationDetails)
.combinationExplain(explainDetail)
.build();
combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails);
}
});

ExplanationResponse explanationResponse = ExplanationResponse.builder()
ExplanationPayload explanationPayload = ExplanationPayload.builder()
.explanation(topLevelExplanationForTechniques)
.explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain))
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse);
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on arithmetic mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on geometrical mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on harmonic mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainationDetails;

import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument;

/**
* Abstracts combination of scores in query search results.
Expand Down Expand Up @@ -318,15 +318,15 @@ private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final lon
return new TotalHits(maxHits, totalHits);
}

public Map<SearchShard, List<ExplainDetails>> explain(
public Map<SearchShard, List<ExplainationDetails>> explain(
final List<CompoundTopDocs> queryTopDocs,
final ScoreCombinationTechnique combinationTechnique,
final Sort sort
) {
// In case of duplicate keys, keep the first value
HashMap<SearchShard, List<ExplainDetails>> explanations = new HashMap<>();
HashMap<SearchShard, List<ExplainationDetails>> explanations = new HashMap<>();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
for (Map.Entry<SearchShard, List<ExplainDetails>> docIdAtSearchShardExplainDetailsEntry : explainByShard(
for (Map.Entry<SearchShard, List<ExplainationDetails>> docIdAtSearchShardExplainDetailsEntry : explainByShard(
combinationTechnique,
compoundQueryTopDocs,
sort
Expand All @@ -337,7 +337,7 @@ public Map<SearchShard, List<ExplainDetails>> explain(
return explanations;
}

private Map<SearchShard, List<ExplainDetails>> explainByShard(
private Map<SearchShard, List<ExplainationDetails>> explainByShard(
final ScoreCombinationTechnique scoreCombinationTechnique,
final CompoundTopDocs compoundQueryTopDocs,
Sort sort
Expand All @@ -351,7 +351,7 @@ private Map<SearchShard, List<ExplainDetails>> explainByShard(
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
List<ExplainDetails> listOfExplainsForShard = sortedDocsIds.stream()
List<ExplainationDetails> listOfExplainsForShard = sortedDocsIds.stream()
.map(
docId -> getScoreCombinationExplainDetailsForDocument(
docId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
@Builder
@Getter
public class CombinedExplainDetails {
private ExplainDetails normalizationExplain;
private ExplainDetails combinationExplain;
private ExplainationDetails normalizationExplain;
private ExplainationDetails combinationExplain;
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ default String describe() {
* @param queryTopDocs collection of CompoundTopDocs for each shard result
* @return map of document per shard and corresponding explanation object
*/
default Map<DocIdAtSearchShard, ExplainDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
default Map<DocIdAtSearchShard, ExplainationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
return Map.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
* @param value
* @param description
*/
public record ExplainDetails(int docId, float value, String description) {
public ExplainDetails(float value, String description) {
public record ExplainationDetails(int docId, float value, String description) {
public ExplainationDetails(float value, String description) {
this(-1, value, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
/**
* Utility class for explain functionality
*/
public class ExplainUtils {
public class ExplainationUtils {

/**
* Creates map of DocIdAtQueryPhase to String containing source and normalized scores
* @param normalizedScores map of DocIdAtQueryPhase to normalized scores
* @param sourceScores map of DocIdAtQueryPhase to source scores
* @return map of DocIdAtQueryPhase to String containing source and normalized scores
*/
public static Map<DocIdAtSearchShard, ExplainDetails> getDocIdAtQueryForNormalization(
public static Map<DocIdAtSearchShard, ExplainationDetails> getDocIdAtQueryForNormalization(
final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
final Map<DocIdAtSearchShard, List<Float>> sourceScores
) {
Map<DocIdAtSearchShard, ExplainDetails> explain = sourceScores.entrySet()
Map<DocIdAtSearchShard, ExplainationDetails> explain = sourceScores.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> {
List<Float> srcScores = entry.getValue();
List<Float> normScores = normalizedScores.get(entry.getKey());
return new ExplainDetails(
return new ExplainationDetails(
normScores.stream().reduce(0.0f, Float::max),
String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores)
);
Expand All @@ -49,13 +49,13 @@ public static Map<DocIdAtSearchShard, ExplainDetails> getDocIdAtQueryForNormaliz
* @param normalizedScoresPerDoc
* @return
*/
public static ExplainDetails getScoreCombinationExplainDetailsForDocument(
public static ExplainationDetails getScoreCombinationExplainDetailsForDocument(
final Integer docId,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final float[] normalizedScoresPerDoc
) {
float combinedScore = combinedNormalizedScoresByDocId.get(docId);
return new ExplainDetails(
return new ExplainationDetails(
docId,
combinedScore,
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
@AllArgsConstructor
@Builder
@Getter
public class ExplanationResponse {
public class ExplanationPayload {
Explanation explanation;
Map<ExplanationType, Object> explainPayload;
Map<PayloadType, Object> explainPayload;

public enum ExplanationType {
public enum PayloadType {
NORMALIZATION_PROCESSOR
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import lombok.ToString;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization;

/**
* Abstracts normalization of scores based on L2 method
Expand Down Expand Up @@ -64,7 +64,7 @@ public String describe() {
}

@Override
public Map<DocIdAtSearchShard, ExplainDetails> explain(List<CompoundTopDocs> queryTopDocs) {
public Map<DocIdAtSearchShard, ExplainationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();
List<Float> normsPerSubquery = getL2Norm(queryTopDocs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

import lombok.ToString;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization;

/**
* Abstracts normalization of scores based on min-max method
Expand Down Expand Up @@ -78,7 +78,7 @@ public String describe() {
}

@Override
public Map<DocIdAtSearchShard, ExplainDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
public Map<DocIdAtSearchShard, ExplainationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

public class ScoreNormalizer {
Expand All @@ -30,7 +30,7 @@ private boolean canQueryResultsBeNormalized(final List<CompoundTopDocs> queryTop
return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0);
}

public Map<DocIdAtSearchShard, ExplainDetails> explain(
public Map<DocIdAtSearchShard, ExplainationDetails> explain(
final List<CompoundTopDocs> queryTopDocs,
final ExplainableTechnique scoreNormalizationTechnique
) {
Expand Down
Loading

0 comments on commit 0f3813f

Please sign in to comment.