From 5dc74bef7be8db653c42b6939c3bb0ce7769a2f1 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Nov 2024 15:04:58 +0800 Subject: [PATCH 1/5] add impl Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 12 ++- .../SparseEncodingProcessorFactory.java | 29 +++++- .../processor/pruning/PruneUtils.java | 92 +++++++++++++++++++ .../processor/pruning/PruningType.java | 45 +++++++++ .../neuralsearch/util/TokenWeightUtil.java | 21 ++++- 5 files changed, 192 insertions(+), 7 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..d49a6a709 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,6 +14,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.pruning.PruningType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -27,6 +28,8 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; + private final PruningType pruningType; + private final float pruneRatio; public SparseEncodingProcessor( String tag, @@ -34,11 +37,15 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, + PruningType pruningType, + float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.pruningType = pruningType; + this.pruneRatio = pruneRatio; } @Override @@ -49,7 +56,8 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio); + setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @@ -59,7 +67,7 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult( this.modelId, inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) + ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), onException) ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 46055df16..8f9e42eea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -6,6 +6,8 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; @@ -19,6 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.pruning.PruneUtils; +import org.opensearch.neuralsearch.processor.pruning.PruningType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. @@ -40,7 +44,30 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); + float pruneRatio = 0; + if (pruningType != PruningType.NONE) { + // if we have prune type, then prune ratio field must have value + // readDoubleProperty will throw exception if value is not present + pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); + } else { + // if we don't have prune type, then prune ratio field must not have value + if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { + throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); + } + } - return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService); + return new SparseEncodingProcessor( + tag, + description, + batchSize, + modelId, + fieldMap, + pruningType, + pruneRatio, + clientAccessor, + environment, + clusterService + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java new file mode 100644 index 000000000..47aaaeac9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.pruning; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class PruneUtils { + public static final String PRUNE_TYPE_FIELD = "prune_type"; + public static final String PRUNE_RATIO_FIELD = "prune_ratio"; + + public static Map pruningByTopK(Map sparseVector, int k) { + List> list = new ArrayList<>(sparseVector.entrySet()); + list.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + Map result = new HashMap<>(); + for (int i = 0; i < k && i < list.size(); i++) { + Map.Entry entry = list.get(i); + result.put(entry.getKey(), entry.getValue()); + } + return result; + } + + public static Map pruningByMaxRatio(Map sparseVector, float ratio) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map result = new HashMap<>(); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentValue = entry.getValue(); + float currentRatio = currentValue / maxValue; + + if (currentRatio >= ratio) { + result.put(entry.getKey(), entry.getValue()); + } + } + + return result; + } + + public static Map pruningByValue(Map sparseVector, float thresh) { + Map result = new HashMap<>(sparseVector); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentValue = Math.abs(entry.getValue()); + if (currentValue < thresh) { + result.remove(entry.getKey()); + } + } + + return result; + } + + public static Map pruningByAlphaMass(Map sparseVector, float alpha) { + List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); + sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); + float topSum = 0f; + + Map result = new HashMap<>(); + for (Map.Entry entry : sortedEntries) { + float value = entry.getValue(); + topSum += value; + result.put(entry.getKey(), value); + + if (topSum / sum >= alpha) { + break; + } + } + + return result; + } + + public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { + switch (pruningType) { + case TOP_K: + return pruningByTopK(sparseVector, (int) pruneRatio); + case ALPHA_MASS: + return pruningByAlphaMass(sparseVector, pruneRatio); + case MAX_RATIO: + return pruningByMaxRatio(sparseVector, pruneRatio); + case ABS_VALUE: + return pruningByValue(sparseVector, pruneRatio); + default: + return sparseVector; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java new file mode 100644 index 000000000..5a26a1e53 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.pruning; + +import org.apache.commons.lang.StringUtils; + +/** + * Enum representing different types of pruning methods for sparse vectors + */ +public enum PruningType { + NONE("none"), + TOP_K("top_k"), + ALPHA_MASS("alpha_mass"), + MAX_RATIO("max_ratio"), + ABS_VALUE("abs_value"); + + private final String value; + + PruningType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + /** + * Get PruningType from string value + * + * @param value string representation of pruning type + * @return corresponding PruningType enum + * @throws IllegalArgumentException if value doesn't match any pruning type + */ + public static PruningType fromString(String value) { + if (StringUtils.isEmpty(value)) return NONE; + for (PruningType type : PruningType.values()) { + if (type.value.equals(value)) { + return type; + } + } + throw new IllegalArgumentException("Unknown pruning type: " + value); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index e36b42cd6..3189706de 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,6 +4,9 @@ */ package org.opensearch.neuralsearch.util; +import org.opensearch.neuralsearch.processor.pruning.PruneUtils; +import org.opensearch.neuralsearch.processor.pruning.PruningType; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -44,7 +47,11 @@ public class TokenWeightUtil { * * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ - public static List> fetchListOfTokenWeightMap(List> mapResultList) { + public static List> fetchListOfTokenWeightMap( + List> mapResultList, + PruningType pruningType, + float pruneRatio + ) { if (null == mapResultList || mapResultList.isEmpty()) { throw new IllegalArgumentException("The inference result can not be null or empty."); } @@ -58,10 +65,16 @@ public static List> fetchListOfTokenWeightMap(List) map.get("response")); } - return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList()); + return results.stream() + .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruningType, pruneRatio)) + .collect(Collectors.toList()); + } + + public static List> fetchListOfTokenWeightMap(List> mapResultList) { + return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruningType.NONE, 0f); } - private static Map buildTokenWeightMap(Object uncastedMap) { + private static Map buildTokenWeightMap(Object uncastedMap, PruningType pruningType, float pruneRatio) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -72,6 +85,6 @@ private static Map buildTokenWeightMap(Object uncastedMap) { } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return result; + return PruneUtils.pruningSparseVector(pruningType, pruneRatio, result); } } From 08801c09e621de2b4ce8e6fae98d90772960d834 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Nov 2024 15:59:51 +0800 Subject: [PATCH 2/5] add UT Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 7 +- .../SparseEncodingProcessorFactory.java | 8 +- .../processor/pruning/PruneUtils.java | 92 --------- .../neuralsearch/util/TokenWeightUtil.java | 4 +- .../neuralsearch/util/pruning/PruneUtils.java | 180 ++++++++++++++++++ .../pruning/PruningType.java | 2 +- .../util/pruning/PruneUtilsTests.java | 159 ++++++++++++++++ 7 files changed, 353 insertions(+), 99 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java create mode 100644 src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java rename src/main/java/org/opensearch/neuralsearch/{processor => util}/pruning/PruningType.java (95%) create mode 100644 src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index d49a6a709..61851c1d6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,7 +14,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruningType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -67,7 +67,10 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult( this.modelId, inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), onException) + ActionListener.wrap( + resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), + onException + ) ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 8f9e42eea..40a31392c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -21,8 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.processor.pruning.PruneUtils; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.pruning.PruningType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. @@ -44,12 +44,16 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + // if the field is miss, will return PruningType.None PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); float pruneRatio = 0; if (pruningType != PruningType.NONE) { // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); + if (!PruneUtils.isValidPruneRatio(pruningType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruningType.name() + ); } else { // if we don't have prune type, then prune ratio field must not have value if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java deleted file mode 100644 index 47aaaeac9..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor.pruning; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class PruneUtils { - public static final String PRUNE_TYPE_FIELD = "prune_type"; - public static final String PRUNE_RATIO_FIELD = "prune_ratio"; - - public static Map pruningByTopK(Map sparseVector, int k) { - List> list = new ArrayList<>(sparseVector.entrySet()); - list.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); - - Map result = new HashMap<>(); - for (int i = 0; i < k && i < list.size(); i++) { - Map.Entry entry = list.get(i); - result.put(entry.getKey(), entry.getValue()); - } - return result; - } - - public static Map pruningByMaxRatio(Map sparseVector, float ratio) { - float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); - - Map result = new HashMap<>(); - for (Map.Entry entry : sparseVector.entrySet()) { - float currentValue = entry.getValue(); - float currentRatio = currentValue / maxValue; - - if (currentRatio >= ratio) { - result.put(entry.getKey(), entry.getValue()); - } - } - - return result; - } - - public static Map pruningByValue(Map sparseVector, float thresh) { - Map result = new HashMap<>(sparseVector); - for (Map.Entry entry : sparseVector.entrySet()) { - float currentValue = Math.abs(entry.getValue()); - if (currentValue < thresh) { - result.remove(entry.getKey()); - } - } - - return result; - } - - public static Map pruningByAlphaMass(Map sparseVector, float alpha) { - List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); - sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); - - float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); - float topSum = 0f; - - Map result = new HashMap<>(); - for (Map.Entry entry : sortedEntries) { - float value = entry.getValue(); - topSum += value; - result.put(entry.getKey(), value); - - if (topSum / sum >= alpha) { - break; - } - } - - return result; - } - - public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { - switch (pruningType) { - case TOP_K: - return pruningByTopK(sparseVector, (int) pruneRatio); - case ALPHA_MASS: - return pruningByAlphaMass(sparseVector, pruneRatio); - case MAX_RATIO: - return pruningByMaxRatio(sparseVector, pruneRatio); - case ABS_VALUE: - return pruningByValue(sparseVector, pruneRatio); - default: - return sparseVector; - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 3189706de..0ee48fa33 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.util; -import org.opensearch.neuralsearch.processor.pruning.PruneUtils; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.pruning.PruningType; import java.util.ArrayList; import java.util.HashMap; diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java new file mode 100644 index 000000000..d7d2234cf --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** + * Utility class providing methods for pruning sparse vectors using different strategies. + * Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements + * based on various criteria. + */ +public class PruneUtils { + public static final String PRUNE_TYPE_FIELD = "prune_type"; + public static final String PRUNE_RATIO_FIELD = "prune_ratio"; + + /** + * Prunes a sparse vector by keeping only the top K elements with the highest values. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param k The number of top elements to keep + * @return A new map containing only the top K elements + */ + private static Map pruningByTopK(Map sparseVector, int k) { + PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (pq.size() < k) { + pq.offer(entry); + } else if (entry.getValue() > pq.peek().getValue()) { + pq.poll(); + pq.offer(entry); + } + } + + Map result = new HashMap<>(); + while (!pq.isEmpty()) { + Map.Entry entry = pq.poll(); + result.put(entry.getKey(), entry.getValue()); + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose values are within a certain ratio + * of the maximum value in the vector. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param ratio The minimum ratio relative to the maximum value for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByMaxRatio(Map sparseVector, float ratio) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map result = new HashMap<>(); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentRatio = entry.getValue() / maxValue; + + if (currentRatio >= ratio) { + result.put(entry.getKey(), entry.getValue()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by removing elements with values below a certain threshold. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param thresh The minimum absolute value for elements to be kept + * @return A new map containing only elements meeting the threshold + */ + private static Map pruningByValue(Map sparseVector, float thresh) { + Map result = new HashMap<>(sparseVector); + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() < thresh) { + result.remove(entry.getKey()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose cumulative sum of values + * is within a certain ratio of the total sum. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param alpha The minimum ratio relative to the total sum for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByAlphaMass(Map sparseVector, float alpha) { + List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); + sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); + float topSum = 0f; + + Map result = new HashMap<>(); + for (Map.Entry entry : sortedEntries) { + float value = entry.getValue(); + topSum += value; + result.put(entry.getKey(), value); + + if (topSum / sum >= alpha) { + break; + } + } + + return result; + } + + /** + * Prunes a sparse vector using the specified pruning type and ratio. + * + * @param pruningType The type of pruning strategy to use + * @param pruneRatio The ratio or threshold for pruning + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A new map containing the pruned sparse vector + */ + public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruningType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( + "Prune type and prune ratio must be provided" + ); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruningType) { + case TOP_K: + return pruningByTopK(sparseVector, (int) pruneRatio); + case ALPHA_MASS: + return pruningByAlphaMass(sparseVector, pruneRatio); + case MAX_RATIO: + return pruningByMaxRatio(sparseVector, pruneRatio); + case ABS_VALUE: + return pruningByValue(sparseVector, pruneRatio); + default: + return sparseVector; + } + } + + /** + * Validates whether a prune ratio is valid for a given pruning type. + * + * @param pruningType The type of pruning strategy + * @param pruneRatio The ratio or threshold to validate + * @return true if the ratio is valid for the given pruning type, false otherwise + * @throws IllegalArgumentException if pruning type is null + */ + public static boolean isValidPruneRatio(PruningType pruningType, float pruneRatio) { + if (pruningType == null) { + throw new IllegalArgumentException("Pruning type cannot be null"); + } + + switch (pruningType) { + case TOP_K: + return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); + case ALPHA_MASS: + case MAX_RATIO: + return pruneRatio > 0 && pruneRatio < 1; + case ABS_VALUE: + return pruneRatio > 0; + default: + return true; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java similarity index 95% rename from src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java rename to src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java index 5a26a1e53..6629bb937 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.pruning; +package org.opensearch.neuralsearch.util.pruning; import org.apache.commons.lang.StringUtils; diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java new file mode 100644 index 000000000..07a7f11eb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class PruneUtilsTests extends OpenSearchTestCase { + @Test + public void testPruningByTopK() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("c")); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + } + + @Test + public void testPruningByMaxRatio() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 8.0f); + input.put("c", 5.0f); + input.put("d", 2.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.MAX_RATIO, 0.7f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 + assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 + } + + @Test + public void testPruningByValue() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 2.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.ABS_VALUE, 3.0f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + @Test + public void testPruningByAlphaMass() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 6.0f); + input.put("c", 3.0f); + input.put("d", 1.0f); + // Total sum = 20.0 + + Map result = PruneUtils.pruningSparseVector(PruningType.ALPHA_MASS, 0.8f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + @Test + public void testEmptyInput() { + Map input = new HashMap<>(); + + Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 5, input); + assertTrue(result.isEmpty()); + } + + @Test + public void testNegativeValues() { + Map input = new HashMap<>(); + input.put("a", -5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input) + ); + assertEquals(exception.getMessage(), "Pruned values must be positive"); + } + + @Test + public void testInvalidPruningType() { + Map input = new HashMap<>(); + input.put("a", 1.0f); + input.put("b", 2.0f); + + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); + + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); + } + + @Test + public void testIsValidPruneRatio() { + // Test TOP_K validation + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1.5f)); + + // Test ALPHA_MASS validation + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.1f)); + + // Test MAX_RATIO validation + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.1f)); + + // Test ABS_VALUE validation + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, -0.1f)); + + // Test with extreme cases + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, Float.MIN_VALUE)); + } + + @Test + public void testIsValidPruneRatioWithNullType() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); + assertEquals("Pruning type cannot be null", exception.getMessage()); + } +} From 958027cbd205048dfc2d606dbb56582602d06437 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 16:49:09 +0800 Subject: [PATCH 3/5] rename pruneType; UT Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 15 +- .../SparseEncodingProcessorFactory.java | 16 +- .../neuralsearch/util/TokenWeightUtil.java | 12 +- .../{PruningType.java => PruneType.java} | 12 +- .../neuralsearch/util/pruning/PruneUtils.java | 16 +- ...ncodingEmbeddingProcessorFactoryTests.java | 182 ++++++++++++++++++ .../util/TokenWeightUtilTests.java | 33 ++++ .../util/pruning/PruneTypeTests.java | 30 +++ .../util/pruning/PruneUtilsTests.java | 71 +++---- 9 files changed, 313 insertions(+), 74 deletions(-) rename src/main/java/org/opensearch/neuralsearch/util/pruning/{PruningType.java => PruneType.java} (77%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 61851c1d6..a3a6cacbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -9,12 +9,13 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import lombok.Getter; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.util.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -28,7 +29,9 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; - private final PruningType pruningType; + @Getter + private final PruneType pruneType; + @Getter private final float pruneRatio; public SparseEncodingProcessor( @@ -37,14 +40,14 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, - PruningType pruningType, + PruneType pruneType, float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); - this.pruningType = pruningType; + this.pruneType = pruneType; this.pruneRatio = pruneRatio; } @@ -56,7 +59,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -68,7 +71,7 @@ public void doBatchExecute(List inferenceList, Consumer> handler this.modelId, inferenceList, ActionListener.wrap( - resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), + resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio)), onException ) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 40a31392c..19cea9419 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -8,9 +8,9 @@ import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; import java.util.Map; @@ -22,7 +22,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.util.pruning.PruneUtils; -import org.opensearch.neuralsearch.util.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. @@ -44,15 +44,15 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - // if the field is miss, will return PruningType.None - PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); + // if the field is miss, will return PruneType.None + PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); float pruneRatio = 0; - if (pruningType != PruningType.NONE) { + if (pruneType != PruneType.NONE) { // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); - if (!PruneUtils.isValidPruneRatio(pruningType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruningType.name() + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() ); } else { // if we don't have prune type, then prune ratio field must not have value @@ -67,7 +67,7 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description, batchSize, modelId, fieldMap, - pruningType, + pruneType, pruneRatio, clientAccessor, environment, diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 0ee48fa33..0de3610f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.util; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.neuralsearch.util.pruning.PruneUtils; -import org.opensearch.neuralsearch.util.pruning.PruningType; import java.util.ArrayList; import java.util.HashMap; @@ -49,7 +49,7 @@ public class TokenWeightUtil { */ public static List> fetchListOfTokenWeightMap( List> mapResultList, - PruningType pruningType, + PruneType pruneType, float pruneRatio ) { if (null == mapResultList || mapResultList.isEmpty()) { @@ -66,15 +66,15 @@ public static List> fetchListOfTokenWeightMap( results.addAll((List) map.get("response")); } return results.stream() - .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruningType, pruneRatio)) + .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruneType, pruneRatio)) .collect(Collectors.toList()); } public static List> fetchListOfTokenWeightMap(List> mapResultList) { - return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruningType.NONE, 0f); + return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruneType.NONE, 0f); } - private static Map buildTokenWeightMap(Object uncastedMap, PruningType pruningType, float pruneRatio) { + private static Map buildTokenWeightMap(Object uncastedMap, PruneType pruneType, float pruneRatio) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -85,6 +85,6 @@ private static Map buildTokenWeightMap(Object uncastedMap, Prunin } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return PruneUtils.pruningSparseVector(pruningType, pruneRatio, result); + return PruneUtils.pruningSparseVector(pruneType, pruneRatio, result); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java similarity index 77% rename from src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java rename to src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java index 6629bb937..22376b7c5 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java @@ -9,7 +9,7 @@ /** * Enum representing different types of pruning methods for sparse vectors */ -public enum PruningType { +public enum PruneType { NONE("none"), TOP_K("top_k"), ALPHA_MASS("alpha_mass"), @@ -18,7 +18,7 @@ public enum PruningType { private final String value; - PruningType(String value) { + PruneType(String value) { this.value = value; } @@ -27,15 +27,15 @@ public String getValue() { } /** - * Get PruningType from string value + * Get PruneType from string value * * @param value string representation of pruning type - * @return corresponding PruningType enum + * @return corresponding PruneType enum * @throws IllegalArgumentException if value doesn't match any pruning type */ - public static PruningType fromString(String value) { + public static PruneType fromString(String value) { if (StringUtils.isEmpty(value)) return NONE; - for (PruningType type : PruningType.values()) { + for (PruneType type : PruneType.values()) { if (type.value.equals(value)) { return type; } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java index d7d2234cf..87e87cffa 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -122,13 +122,13 @@ private static Map pruningByAlphaMass(Map sparseVe /** * Prunes a sparse vector using the specified pruning type and ratio. * - * @param pruningType The type of pruning strategy to use + * @param pruneType The type of pruning strategy to use * @param pruneRatio The ratio or threshold for pruning * @param sparseVector The input sparse vector as a map of string keys to float values * @return A new map containing the pruned sparse vector */ - public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { - if (Objects.isNull(pruningType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( + public static Map pruningSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( "Prune type and prune ratio must be provided" ); @@ -138,7 +138,7 @@ public static Map pruningSparseVector(PruningType pruningType, fl } } - switch (pruningType) { + switch (pruneType) { case TOP_K: return pruningByTopK(sparseVector, (int) pruneRatio); case ALPHA_MASS: @@ -155,17 +155,17 @@ public static Map pruningSparseVector(PruningType pruningType, fl /** * Validates whether a prune ratio is valid for a given pruning type. * - * @param pruningType The type of pruning strategy + * @param pruneType The type of pruning strategy * @param pruneRatio The ratio or threshold to validate * @return true if the ratio is valid for the given pruning type, false otherwise * @throws IllegalArgumentException if pruning type is null */ - public static boolean isValidPruneRatio(PruningType pruningType, float pruneRatio) { - if (pruningType == null) { + public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { + if (pruneType == null) { throw new IllegalArgumentException("Pruning type cannot be null"); } - switch (pruningType) { + switch (pruneType) { case TOP_K: return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); case ALPHA_MASS: diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..8b1fafe8b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_TYPE_FIELD; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_RATIO_FIELD; + +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class SparseEncodingEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String MODEL_ID = "testModelId"; + private static final int BATCH_SIZE = 1; + + private MLCommonsClientAccessor clientAccessor; + private Environment environment; + private ClusterService clusterService; + private SparseEncodingProcessorFactory sparseEncodingProcessorFactory; + + @Before + public void setup() { + clientAccessor = mock(MLCommonsClientAccessor.class); + environment = mock(Environment.class); + clusterService = mock(ClusterService.class); + sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(clientAccessor, environment, clusterService); + } + + @SneakyThrows + public void testCreateProcessor_whenAllRequiredParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.NONE, processor.getPruneType()); + assertEquals(0f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenPruneParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 2f); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.TOP_K, processor.getPruneType()); + assertEquals(2f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenEmptyFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of()); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unable to create the processor as field_map has invalid key or value", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingModelIdField_thenFail() { + Map config = new HashMap<>(); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[model_id] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[field_map] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneType_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "invalid_prune_type"); + config.put(PRUNE_RATIO_FIELD, 2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unknown pruning type: invalid_prune_type", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 0.2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "alpha_mass"); + + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[prune_ratio] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneType_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_RATIO_FIELD, 0.1); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("prune_ratio field is not supported when prune_type is not provided", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java index 887d8fc17..234a70823 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.test.OpenSearchTestCase; public class TokenWeightUtilTests extends OpenSearchTestCase { @@ -104,4 +105,36 @@ public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_th List> inputData = List.of(Map.of("response", List.of(mockData))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } + + public void testFetchListOfTokenWeightMap_invokeWithPrune() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.MAX_RATIO, 0.8f), List.of(Map.of("world", 2f))); + } + + public void testFetchListOfTokenWeightMap_invokeWithPrune_MultipleObjectsInMultipleResponse() { + + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + },{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA))); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.TOP_K, 1f), + List.of(Map.of("world", 2f), Map.of("world", 2f)) + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java new file mode 100644 index 000000000..a1a823093 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.opensearch.test.OpenSearchTestCase; + +public class PruneTypeTests extends OpenSearchTestCase { + public void testGetValue() { + assertEquals("none", PruneType.NONE.getValue()); + assertEquals("top_k", PruneType.TOP_K.getValue()); + assertEquals("alpha_mass", PruneType.ALPHA_MASS.getValue()); + assertEquals("max_ratio", PruneType.MAX_RATIO.getValue()); + assertEquals("abs_value", PruneType.ABS_VALUE.getValue()); + } + + public void testFromString() { + assertEquals(PruneType.NONE, PruneType.fromString("none")); + assertEquals(PruneType.NONE, PruneType.fromString(null)); + assertEquals(PruneType.NONE, PruneType.fromString("")); + assertEquals(PruneType.TOP_K, PruneType.fromString("top_k")); + assertEquals(PruneType.ALPHA_MASS, PruneType.fromString("alpha_mass")); + assertEquals(PruneType.MAX_RATIO, PruneType.fromString("max_ratio")); + assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value")); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value")); + assertEquals("Unknown pruning type: test_value", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java index 07a7f11eb..74aadf09f 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -4,14 +4,13 @@ */ package org.opensearch.neuralsearch.util.pruning; -import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; import java.util.HashMap; import java.util.Map; public class PruneUtilsTests extends OpenSearchTestCase { - @Test + public void testPruningByTopK() { Map input = new HashMap<>(); input.put("a", 5.0f); @@ -19,7 +18,7 @@ public void testPruningByTopK() { input.put("c", 4.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input); + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); @@ -28,7 +27,6 @@ public void testPruningByTopK() { assertEquals(4.0f, result.get("c"), 0.001); } - @Test public void testPruningByMaxRatio() { Map input = new HashMap<>(); input.put("a", 10.0f); @@ -36,14 +34,13 @@ public void testPruningByMaxRatio() { input.put("c", 5.0f); input.put("d", 2.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.MAX_RATIO, 0.7f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 } - @Test public void testPruningByValue() { Map input = new HashMap<>(); input.put("a", 5.0f); @@ -51,14 +48,13 @@ public void testPruningByValue() { input.put("c", 2.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.ABS_VALUE, 3.0f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); assertTrue(result.containsKey("b")); } - @Test public void testPruningByAlphaMass() { Map input = new HashMap<>(); input.put("a", 10.0f); @@ -67,22 +63,20 @@ public void testPruningByAlphaMass() { input.put("d", 1.0f); // Total sum = 20.0 - Map result = PruneUtils.pruningSparseVector(PruningType.ALPHA_MASS, 0.8f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); assertTrue(result.containsKey("b")); } - @Test public void testEmptyInput() { Map input = new HashMap<>(); - Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 5, input); + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input); assertTrue(result.isEmpty()); } - @Test public void testNegativeValues() { Map input = new HashMap<>(); input.put("a", -5.0f); @@ -91,12 +85,11 @@ public void testNegativeValues() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input) + () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input) ); - assertEquals(exception.getMessage(), "Pruned values must be positive"); + assertEquals("Pruned values must be positive", exception.getMessage()); } - @Test public void testInvalidPruningType() { Map input = new HashMap<>(); input.put("a", 1.0f); @@ -115,43 +108,41 @@ public void testInvalidPruningType() { assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); } - @Test public void testIsValidPruneRatio() { // Test TOP_K validation - assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 100)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 0)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, -1)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1.5f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1.5f)); // Test ALPHA_MASS validation - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0.5f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, -0.1f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.1f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.1f)); // Test MAX_RATIO validation - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.5f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, -0.1f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.1f)); // Test ABS_VALUE validation - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 0.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 1.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 100.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, -0.1f)); // Test with extreme cases - assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, Float.MAX_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, Float.MAX_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, Float.MIN_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, Float.MIN_VALUE)); } - @Test public void testIsValidPruneRatioWithNullType() { IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); assertEquals("Pruning type cannot be null", exception.getMessage()); From 30babbb91136be7e275c8ca812802e48ba34e903 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 16:53:10 +0800 Subject: [PATCH 4/5] changelog Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..36b2a3e7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Enhancements ### Bug Fixes ### Infrastructure From b8d8b7f3ce3627646ef397f4aed6597151e9bcaf Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 17:44:35 +0800 Subject: [PATCH 5/5] ut Signed-off-by: zhichao-aws --- .../SparseEncodingProcessorTests.java | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..d705616a9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -14,10 +14,12 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.Map; import java.util.ArrayList; import java.util.Collections; @@ -49,6 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.util.pruning.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) { return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put("prune_type", pruneType.getValue()); + config.put("prune_ratio", pruneRatio); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); @@ -260,9 +274,98 @@ public void test_batchExecute_exception() { } } + @SuppressWarnings("unchecked") + public void testExecute_withPruningConfig_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> dataAsMapList = Collections.singletonList( + Map.of("response", Arrays.asList(ImmutableMap.of("hello", 1.0f, "world", 0.1f), ImmutableMap.of("test", 0.8f, "low", 0.4f))) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + + ArgumentCaptor docCaptor = ArgumentCaptor.forClass(IngestDocument.class); + verify(handler).accept(docCaptor.capture(), isNull()); + + IngestDocument processedDoc = docCaptor.getValue(); + Map first = (Map) processedDoc.getFieldValue("key1Mapped", Map.class); + Map second = (Map) processedDoc.getFieldValue("key2Mapped", Map.class); + + assertNotNull(first); + assertNotNull(second); + + assertTrue(first.containsKey("hello")); + assertFalse(first.containsKey("world")); + assertEquals(1.0f, first.get("hello"), 0.001f); + + assertTrue(second.containsKey("test")); + assertTrue(second.containsKey("low")); + assertEquals(0.8f, second.get("test"), 0.001f); + assertEquals(0.4f, second.get("low"), 0.001f); + } + + public void test_batchExecute_withPruning_successful() { + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> mockMLResponse = Collections.singletonList( + Map.of( + "response", + Arrays.asList( + ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f), + ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f) + ) + ) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(mockMLResponse); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer> resultHandler = mock(Consumer.class); + Consumer exceptionHandler = mock(Consumer.class); + + List inferenceList = Arrays.asList("test1", "test2"); + processor.doBatchExecute(inferenceList, resultHandler, exceptionHandler); + + ArgumentCaptor>> resultCaptor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCaptor.capture()); + verify(exceptionHandler, never()).accept(any()); + + List> processedResults = resultCaptor.getValue(); + + assertEquals(2, processedResults.size()); + + Map firstResult = processedResults.get(0); + assertEquals(2, firstResult.size()); + assertTrue(firstResult.containsKey("token1")); + assertTrue(firstResult.containsKey("token3")); + assertFalse(firstResult.containsKey("token2")); + + Map secondResult = processedResults.get(1); + assertEquals(2, secondResult.size()); + assertTrue(secondResult.containsKey("token4")); + assertTrue(secondResult.containsKey("token6")); + assertFalse(secondResult.containsKey("token5")); + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult;