[Feature] Implement pruning for neural sparse search #988

Draft · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
### Enhancements
### Bug Fixes
### Infrastructure
SparseEncodingProcessor.java
@@ -9,11 +9,13 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import lombok.Getter;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.pruning.PruneType;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
@@ -27,18 +29,26 @@ public final class SparseEncodingProcessor extends InferenceProcessor {

public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
@Getter
private final PruneType pruneType;
@Getter
private final float pruneRatio;

public SparseEncodingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
PruneType pruneType,
float pruneRatio,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
this.pruneType = pruneType;
this.pruneRatio = pruneRatio;
}

@Override
@@ -49,7 +59,8 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio);
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}
@@ -59,7 +70,10 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
this.modelId,
inferenceList,
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
ActionListener.wrap(
resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio)),
onException
)
);
}
}
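
For context, a processor carrying these new parameters would normally be created through an ingest pipeline definition. The snippet below is an illustrative sketch only, not taken from this PR: the description, model id, and field mapping are placeholders, while the prune_type and prune_ratio field names come from the factory change further down.

public class SparsePruningPipelineExample {
    public static void main(String[] args) {
        // Illustrative only: an ingest pipeline body that would exercise the pruning path above.
        // The model id and field names are hypothetical placeholders; prune_ratio 0.1 is assumed
        // to be an acceptable value for max_ratio (validation lives in PruneUtils, not shown here).
        String createPipelineBody = """
            {
              "description": "sparse encoding pipeline with pruning",
              "processors": [
                {
                  "sparse_encoding": {
                    "model_id": "<your-sparse-model-id>",
                    "prune_type": "max_ratio",
                    "prune_ratio": 0.1,
                    "field_map": {
                      "passage_text": "passage_embedding"
                    }
                  }
                }
              ]
            }
            """;
        System.out.println(createPipelineBody);
    }
}

Sending such a body via PUT /_ingest/pipeline/<pipeline-name> would let SparseEncodingProcessorFactory.newProcessor read prune_type and prune_ratio from the processor config.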
SparseEncodingProcessorFactory.java
@@ -6,9 +6,11 @@

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;
import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE;

import java.util.Map;

@@ -19,6 +21,8 @@
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.pruning.PruneUtils;
import org.opensearch.neuralsearch.util.pruning.PruneType;

/**
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
@@ -40,7 +44,34 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
// if the field is missing, fromString will return PruneType.NONE
PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD));
float pruneRatio = 0;
if (pruneType != PruneType.NONE) {
// if prune_type is provided, the prune_ratio field must also be provided
// readDoubleProperty will throw an exception if the value is not present
pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException(
"Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue()
);
} else {
// if prune_type is not provided, the prune_ratio field must not be provided either
if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
}
}

return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
return new SparseEncodingProcessor(
tag,
description,
batchSize,
modelId,
fieldMap,
pruneType,
pruneRatio,
clientAccessor,
environment,
clusterService
);
}
}
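
To make the validation rules concrete, here is a small sketch (not taken from the PR or its tests) of processor config maps that the factory above would accept or reject. The model id and field names are placeholders.

import java.util.HashMap;
import java.util.Map;

public class PruneConfigExamples {
    public static void main(String[] args) {
        // Accepted: prune_type together with an in-range prune_ratio enables pruning.
        Map<String, Object> withPruning = new HashMap<>();
        withPruning.put("model_id", "<model-id>"); // placeholder
        withPruning.put("field_map", Map.of("passage_text", "passage_embedding"));
        withPruning.put("prune_type", "max_ratio");
        withPruning.put("prune_ratio", 0.1);

        // Rejected: prune_ratio without prune_type, newProcessor throws IllegalArgumentException.
        Map<String, Object> ratioOnly = new HashMap<>(withPruning);
        ratioOnly.remove("prune_type");

        // Accepted: neither field present, pruning stays disabled (PruneType.NONE, ratio 0).
        Map<String, Object> noPruning = new HashMap<>(withPruning);
        noPruning.remove("prune_type");
        noPruning.remove("prune_ratio");

        System.out.println(withPruning);
        System.out.println(ratioOnly);
        System.out.println(noPruning);
    }
}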
TokenWeightUtil.java
@@ -4,6 +4,9 @@
*/
package org.opensearch.neuralsearch.util;

import org.opensearch.neuralsearch.util.pruning.PruneType;
import org.opensearch.neuralsearch.util.pruning.PruneUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -44,7 +47,11 @@ public class TokenWeightUtil {
*
* @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor}
*/
public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String, ?>> mapResultList) {
public static List<Map<String, Float>> fetchListOfTokenWeightMap(
List<Map<String, ?>> mapResultList,
PruneType pruneType,
float pruneRatio
) {
if (null == mapResultList || mapResultList.isEmpty()) {
throw new IllegalArgumentException("The inference result can not be null or empty.");
}
@@ -58,10 +65,16 @@ public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String
}
results.addAll((List<?>) map.get("response"));
}
return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList());
return results.stream()
.map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruneType, pruneRatio))
.collect(Collectors.toList());
}

public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String, ?>> mapResultList) {
return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruneType.NONE, 0f);
}

private static Map<String, Float> buildTokenWeightMap(Object uncastedMap) {
private static Map<String, Float> buildTokenWeightMap(Object uncastedMap, PruneType pruneType, float pruneRatio) {
if (!Map.class.isAssignableFrom(uncastedMap.getClass())) {
throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values.");
}
@@ -72,6 +85,6 @@ private static Map<String, Float> buildTokenWeightMap(Object uncastedMap) {
}
result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue());
}
return result;
return PruneUtils.pruningSparseVector(pruneType, pruneRatio, result);
}
}
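
PruneUtils.pruningSparseVector, which performs the actual pruning, is not part of this diff. As a rough mental model only (an assumption about behavior, not the PR's implementation), top_k pruning of a token-weight map could look like the sketch below: keep the k highest-weight tokens and drop the rest.

import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Map;

public class TopKPruneSketch {
    // Assumed behavior: keep only the k highest-weight tokens of a sparse vector.
    static Map<String, Float> pruneTopK(Map<String, Float> tokenWeights, int k) {
        Map<String, Float> pruned = new LinkedHashMap<>();
        tokenWeights.entrySet()
            .stream()
            .sorted(Map.Entry.<String, Float>comparingByValue(Comparator.reverseOrder()))
            .limit(k)
            .forEach(entry -> pruned.put(entry.getKey(), entry.getValue()));
        return pruned;
    }

    public static void main(String[] args) {
        Map<String, Float> vector = Map.of("hello", 2.3f, "world", 1.1f, "the", 0.05f);
        System.out.println(pruneTopK(vector, 2)); // keeps the two strongest tokens
    }
}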
PruneType.java
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.util.pruning;

import org.apache.commons.lang.StringUtils;

/**
* Enum representing different types of pruning methods for sparse vectors
*/
public enum PruneType {
NONE("none"),
TOP_K("top_k"),
ALPHA_MASS("alpha_mass"),
MAX_RATIO("max_ratio"),
ABS_VALUE("abs_value");

private final String value;

PruneType(String value) {
this.value = value;
}

public String getValue() {
return value;
}

/**
* Get PruneType from string value
*
* @param value string representation of pruning type
* @return corresponding PruneType enum
* @throws IllegalArgumentException if value doesn't match any pruning type
*/
public static PruneType fromString(String value) {
if (StringUtils.isEmpty(value)) return NONE;
for (PruneType type : PruneType.values()) {
if (type.value.equals(value)) {
return type;
}
}
throw new IllegalArgumentException("Unknown pruning type: " + value);
}
}
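
A short usage sketch of the parsing contract above (the wrapper class name is only for illustration): missing or empty values fall back to NONE, known values map to their enum constants, and anything else is rejected.

import org.opensearch.neuralsearch.util.pruning.PruneType;

public class PruneTypeExample {
    public static void main(String[] args) {
        System.out.println(PruneType.fromString("top_k")); // TOP_K
        System.out.println(PruneType.fromString(null));    // NONE: a missing value falls back to NONE
        System.out.println(PruneType.fromString(""));      // NONE
        try {
            PruneType.fromString("tok_k");                  // unknown value
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());             // Unknown pruning type: tok_k
        }
    }
}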