diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..36b2a3e7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..a3a6cacbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -9,11 +9,13 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import lombok.Getter; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -27,6 +29,10 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; + @Getter + private final PruneType pruneType; + @Getter + private final float pruneRatio; public SparseEncodingProcessor( String tag, @@ -34,11 +40,15 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, + PruneType pruneType, + float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, 
ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.pruneType = pruneType; + this.pruneRatio = pruneRatio; } @Override @@ -49,7 +59,8 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio); + setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @@ -59,7 +70,10 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult( this.modelId, inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) + ActionListener.wrap( + resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio)), + onException + ) ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 46055df16..19cea9419 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -6,9 +6,11 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; +import static 
org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; import java.util.Map; @@ -19,6 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.pruning.PruneType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. */ @@ -40,7 +44,34 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + // if the field is missing, will return PruneType.NONE + PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); + float pruneRatio = 0; + if (pruneType != PruneType.NONE) { + // if we have prune type, then prune ratio field must have value + // readDoubleProperty will throw exception if value is not present + pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() + ); + } else { + // if we don't have prune type, then prune ratio field must not have value + if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { + throw new IllegalArgumentException("prune_ratio 
field is not supported when prune_type is not provided"); + } + } - return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService); + return new SparseEncodingProcessor( + tag, + description, + batchSize, + modelId, + fieldMap, + pruneType, + pruneRatio, + clientAccessor, + environment, + clusterService + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index e36b42cd6..0de3610f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,6 +4,9 @@ */ package org.opensearch.neuralsearch.util; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -44,7 +47,11 @@ public class TokenWeightUtil { * * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ - public static List> fetchListOfTokenWeightMap(List> mapResultList) { + public static List> fetchListOfTokenWeightMap( + List> mapResultList, + PruneType pruneType, + float pruneRatio + ) { if (null == mapResultList || mapResultList.isEmpty()) { throw new IllegalArgumentException("The inference result can not be null or empty."); } @@ -58,10 +65,16 @@ public static List> fetchListOfTokenWeightMap(List) map.get("response")); } - return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList()); + return results.stream() + .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruneType, pruneRatio)) + .collect(Collectors.toList()); + } + + public static List> fetchListOfTokenWeightMap(List> mapResultList) { + return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruneType.NONE, 0f); 
} - private static Map buildTokenWeightMap(Object uncastedMap) { + private static Map buildTokenWeightMap(Object uncastedMap, PruneType pruneType, float pruneRatio) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -72,6 +85,6 @@ private static Map buildTokenWeightMap(Object uncastedMap) { } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return result; + return PruneUtils.pruningSparseVector(pruneType, pruneRatio, result); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java new file mode 100644 index 000000000..22376b7c5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.apache.commons.lang.StringUtils; + +/** + * Enum representing different types of pruning methods for sparse vectors + */ +public enum PruneType { + NONE("none"), + TOP_K("top_k"), + ALPHA_MASS("alpha_mass"), + MAX_RATIO("max_ratio"), + ABS_VALUE("abs_value"); + + private final String value; + + PruneType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + /** + * Get PruneType from string value + * + * @param value string representation of pruning type + * @return corresponding PruneType enum + * @throws IllegalArgumentException if value doesn't match any pruning type + */ + public static PruneType fromString(String value) { + if (StringUtils.isEmpty(value)) return NONE; + for (PruneType type : PruneType.values()) { + if (type.value.equals(value)) { + return type; + } + } + throw new IllegalArgumentException("Unknown pruning type: " + value); + } +} diff --git 
a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java new file mode 100644 index 000000000..87e87cffa --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** + * Utility class providing methods for pruning sparse vectors using different strategies. + * Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements + * based on various criteria. + */ +public class PruneUtils { + public static final String PRUNE_TYPE_FIELD = "prune_type"; + public static final String PRUNE_RATIO_FIELD = "prune_ratio"; + + /** + * Prunes a sparse vector by keeping only the top K elements with the highest values. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param k The number of top elements to keep + * @return A new map containing only the top K elements + */ + private static Map pruningByTopK(Map sparseVector, int k) { + PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (pq.size() < k) { + pq.offer(entry); + } else if (entry.getValue() > pq.peek().getValue()) { + pq.poll(); + pq.offer(entry); + } + } + + Map result = new HashMap<>(); + while (!pq.isEmpty()) { + Map.Entry entry = pq.poll(); + result.put(entry.getKey(), entry.getValue()); + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose values are within a certain ratio + * of the maximum value in the vector. 
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param ratio The minimum ratio relative to the maximum value for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByMaxRatio(Map sparseVector, float ratio) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map result = new HashMap<>(); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentRatio = entry.getValue() / maxValue; + + if (currentRatio >= ratio) { + result.put(entry.getKey(), entry.getValue()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by removing elements with values below a certain threshold. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param thresh The minimum absolute value for elements to be kept + * @return A new map containing only elements meeting the threshold + */ + private static Map pruningByValue(Map sparseVector, float thresh) { + Map result = new HashMap<>(sparseVector); + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() < thresh) { + result.remove(entry.getKey()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose cumulative sum of values + * is within a certain ratio of the total sum. 
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param alpha The minimum ratio relative to the total sum for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByAlphaMass(Map sparseVector, float alpha) { + List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); + sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); + float topSum = 0f; + + Map result = new HashMap<>(); + for (Map.Entry entry : sortedEntries) { + float value = entry.getValue(); + topSum += value; + result.put(entry.getKey(), value); + + if (topSum / sum >= alpha) { + break; + } + } + + return result; + } + + /** + * Prunes a sparse vector using the specified pruning type and ratio. + * + * @param pruneType The type of pruning strategy to use + * @param pruneRatio The ratio or threshold for pruning + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A new map containing the pruned sparse vector + */ + public static Map pruningSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( + "Prune type and prune ratio must be provided" + ); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruneType) { + case TOP_K: + return pruningByTopK(sparseVector, (int) pruneRatio); + case ALPHA_MASS: + return pruningByAlphaMass(sparseVector, pruneRatio); + case MAX_RATIO: + return pruningByMaxRatio(sparseVector, pruneRatio); + case ABS_VALUE: + return pruningByValue(sparseVector, pruneRatio); + default: + return sparseVector; + } + } + + /** + * Validates whether a prune ratio is valid for 
a given pruning type. + * + * @param pruneType The type of pruning strategy + * @param pruneRatio The ratio or threshold to validate + * @return true if the ratio is valid for the given pruning type, false otherwise + * @throws IllegalArgumentException if pruning type is null + */ + public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { + if (pruneType == null) { + throw new IllegalArgumentException("Pruning type cannot be null"); + } + + switch (pruneType) { + case TOP_K: + return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); + case ALPHA_MASS: + case MAX_RATIO: + return pruneRatio > 0 && pruneRatio < 1; + case ABS_VALUE: + return pruneRatio > 0; + default: + return true; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..d705616a9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -14,10 +14,12 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.Map; import java.util.ArrayList; import java.util.Collections; @@ -49,6 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.util.pruning.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) { return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private 
SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put("prune_type", pruneType.getValue()); + config.put("prune_ratio", pruneRatio); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); @@ -260,9 +274,98 @@ public void test_batchExecute_exception() { } } + @SuppressWarnings("unchecked") + public void testExecute_withPruningConfig_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> dataAsMapList = Collections.singletonList( + Map.of("response", Arrays.asList(ImmutableMap.of("hello", 1.0f, "world", 0.1f), ImmutableMap.of("test", 0.8f, "low", 0.4f))) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + + ArgumentCaptor docCaptor = ArgumentCaptor.forClass(IngestDocument.class); + verify(handler).accept(docCaptor.capture(), isNull()); + + IngestDocument processedDoc = docCaptor.getValue(); + Map first = (Map) processedDoc.getFieldValue("key1Mapped", 
Map.class); + Map second = (Map) processedDoc.getFieldValue("key2Mapped", Map.class); + + assertNotNull(first); + assertNotNull(second); + + assertTrue(first.containsKey("hello")); + assertFalse(first.containsKey("world")); + assertEquals(1.0f, first.get("hello"), 0.001f); + + assertTrue(second.containsKey("test")); + assertTrue(second.containsKey("low")); + assertEquals(0.8f, second.get("test"), 0.001f); + assertEquals(0.4f, second.get("low"), 0.001f); + } + + public void test_batchExecute_withPruning_successful() { + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> mockMLResponse = Collections.singletonList( + Map.of( + "response", + Arrays.asList( + ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f), + ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f) + ) + ) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(mockMLResponse); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer> resultHandler = mock(Consumer.class); + Consumer exceptionHandler = mock(Consumer.class); + + List inferenceList = Arrays.asList("test1", "test2"); + processor.doBatchExecute(inferenceList, resultHandler, exceptionHandler); + + ArgumentCaptor>> resultCaptor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCaptor.capture()); + verify(exceptionHandler, never()).accept(any()); + + List> processedResults = resultCaptor.getValue(); + + assertEquals(2, processedResults.size()); + + Map firstResult = processedResults.get(0); + assertEquals(2, firstResult.size()); + assertTrue(firstResult.containsKey("token1")); + assertTrue(firstResult.containsKey("token3")); + assertFalse(firstResult.containsKey("token2")); + + Map secondResult = processedResults.get(1); + assertEquals(2, secondResult.size()); + assertTrue(secondResult.containsKey("token4")); + 
assertTrue(secondResult.containsKey("token6")); + assertFalse(secondResult.containsKey("token5")); + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..8b1fafe8b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_TYPE_FIELD; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_RATIO_FIELD; + +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +import 
org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class SparseEncodingEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String MODEL_ID = "testModelId"; + private static final int BATCH_SIZE = 1; + + private MLCommonsClientAccessor clientAccessor; + private Environment environment; + private ClusterService clusterService; + private SparseEncodingProcessorFactory sparseEncodingProcessorFactory; + + @Before + public void setup() { + clientAccessor = mock(MLCommonsClientAccessor.class); + environment = mock(Environment.class); + clusterService = mock(ClusterService.class); + sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(clientAccessor, environment, clusterService); + } + + @SneakyThrows + public void testCreateProcessor_whenAllRequiredParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.NONE, processor.getPruneType()); + assertEquals(0f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenPruneParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 2f); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) 
sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.TOP_K, processor.getPruneType()); + assertEquals(2f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenEmptyFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of()); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unable to create the processor as field_map has invalid key or value", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingModelIdField_thenFail() { + Map config = new HashMap<>(); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[model_id] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[field_map] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneType_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + 
config.put(PRUNE_TYPE_FIELD, "invalid_prune_type"); + config.put(PRUNE_RATIO_FIELD, 2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unknown pruning type: invalid_prune_type", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 0.2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "alpha_mass"); + + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[prune_ratio] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneType_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_RATIO_FIELD, 0.1); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("prune_ratio field is not supported when prune_type is not provided", 
exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java index 887d8fc17..234a70823 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.test.OpenSearchTestCase; public class TokenWeightUtilTests extends OpenSearchTestCase { @@ -104,4 +105,36 @@ public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_th List> inputData = List.of(Map.of("response", List.of(mockData))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } + + public void testFetchListOfTokenWeightMap_invokeWithPrune() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.MAX_RATIO, 0.8f), List.of(Map.of("world", 2f))); + } + + public void testFetchListOfTokenWeightMap_invokeWithPrune_MultipleObjectsInMultipleResponse() { + + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + },{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA))); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.TOP_K, 1f), + List.of(Map.of("world", 2f), Map.of("world", 2f)) + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java new file mode 100644 index 000000000..a1a823093 --- /dev/null +++ 
b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.opensearch.test.OpenSearchTestCase; + +public class PruneTypeTests extends OpenSearchTestCase { + public void testGetValue() { + assertEquals("none", PruneType.NONE.getValue()); + assertEquals("top_k", PruneType.TOP_K.getValue()); + assertEquals("alpha_mass", PruneType.ALPHA_MASS.getValue()); + assertEquals("max_ratio", PruneType.MAX_RATIO.getValue()); + assertEquals("abs_value", PruneType.ABS_VALUE.getValue()); + } + + public void testFromString() { + assertEquals(PruneType.NONE, PruneType.fromString("none")); + assertEquals(PruneType.NONE, PruneType.fromString(null)); + assertEquals(PruneType.NONE, PruneType.fromString("")); + assertEquals(PruneType.TOP_K, PruneType.fromString("top_k")); + assertEquals(PruneType.ALPHA_MASS, PruneType.fromString("alpha_mass")); + assertEquals(PruneType.MAX_RATIO, PruneType.fromString("max_ratio")); + assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value")); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value")); + assertEquals("Unknown pruning type: test_value", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java new file mode 100644 index 000000000..74aadf09f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class PruneUtilsTests extends 
OpenSearchTestCase { + + public void testPruningByTopK() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("c")); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + } + + public void testPruningByMaxRatio() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 8.0f); + input.put("c", 5.0f); + input.put("d", 2.0f); + + Map result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 + assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 + } + + public void testPruningByValue() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 2.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + public void testPruningByAlphaMass() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 6.0f); + input.put("c", 3.0f); + input.put("d", 1.0f); + // Total sum = 20.0 + + Map result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + public void testEmptyInput() { + Map input = new HashMap<>(); + + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input); + assertTrue(result.isEmpty()); + } + + public void testNegativeValues() { + Map input = new HashMap<>(); + input.put("a", -5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + + 
IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input) + ); + assertEquals("Pruned values must be positive", exception.getMessage()); + } + + public void testInvalidPruningType() { + Map input = new HashMap<>(); + input.put("a", 1.0f); + input.put("b", 2.0f); + + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals("Prune type and prune ratio must be provided", exception1.getMessage()); + + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals("Prune type and prune ratio must be provided", exception2.getMessage()); + } + + public void testIsValidPruneRatio() { + // Test TOP_K validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1.5f)); + + // Test ALPHA_MASS validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.1f)); + + // Test MAX_RATIO validation + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, -0.1f)); + 
assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.1f)); + + // Test ABS_VALUE validation + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, -0.1f)); + + // Test with extreme cases + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, Float.MIN_VALUE)); + } + + public void testIsValidPruneRatioWithNullType() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); + assertEquals("Pruning type cannot be null", exception.getMessage()); + } +}