Skip to content

Commit

Permalink
enhance: reduce copy of bitset and id conversion of brurtforce search
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 18, 2024
1 parent 0fc0d1a commit 6544d03
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 140 deletions.
4 changes: 0 additions & 4 deletions internal/core/src/common/QueryResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ struct VectorIterator {
heap_.pop();
if (iterators_[top->GetIteratorIdx()]->HasNext()) {
auto origin_pair = iterators_[top->GetIteratorIdx()]->Next();
origin_pair.first = convert_to_segment_offset(
origin_pair.first, top->GetIteratorIdx());
auto off_dis_pair = std::make_shared<OffsetDisPair>(
origin_pair, top->GetIteratorIdx());
heap_.push(off_dis_pair);
Expand Down Expand Up @@ -108,8 +106,6 @@ struct VectorIterator {
for (auto& iter : iterators_) {
if (iter->HasNext()) {
auto origin_pair = iter->Next();
origin_pair.first =
convert_to_segment_offset(origin_pair.first, idx);
auto off_dis_pair =
std::make_shared<OffsetDisPair>(origin_pair, idx++);
heap_.push(off_dis_pair);
Expand Down
99 changes: 60 additions & 39 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,49 @@ PrepareBFSearchParams(const SearchInfo& search_info,
return search_cfg;
}

SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
std::pair<knowhere::DataSetPtr, knowhere::DataSetPtr>
PrepareBFDataSet(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal);
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto topk = dataset.topk;

auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
auto base_dataset =
knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
auto query_dataset = knowhere::GenDataSet(
query_ds.num_queries, query_ds.dim, query_ds.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, remove convert
base_dataset =
knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(base_dataset);
query_dataset =
knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(query_dataset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, remove convert
base_dataset =
knowhere::ConvertFromDataTypeIfNeeded<float16>(base_dataset);
query_dataset =
knowhere::ConvertFromDataTypeIfNeeded<float16>(query_dataset);
}
base_dataset->SetTensorBeginId(raw_ds.begin_id);
return std::make_pair(query_dataset, base_dataset);
};

SubSearchResult
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal);
auto topk = query_ds.topk;
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not gurantee to return exactly `range_search_k` results, which may be more or less.
Expand All @@ -112,10 +133,12 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
res = knowhere::BruteForce::RangeSearch<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
res = knowhere::BruteForce::RangeSearch<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BINARY) {
res = knowhere::BruteForce::RangeSearch<uint8_t>(
Expand All @@ -138,7 +161,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
res.what());
}
auto result =
ReGenRangeSearchResult(res.value(), topk, nq, dataset.metric_type);
ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type);
milvus::tracer::AddEvent("ReGenRangeSearchResult");
std::copy_n(
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
Expand All @@ -155,15 +178,17 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
stat = knowhere::BruteForce::SearchWithBuf<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
stat = knowhere::BruteForce::SearchWithBuf<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
Expand Down Expand Up @@ -202,21 +227,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
}

SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);

knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
Expand All @@ -227,11 +246,13 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_FLOAT16:
iterators_val = knowhere::BruteForce::AnnIterator<float16>(
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(

Check warning on line 250 in internal/core/src/query/SearchBruteForce.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/query/SearchBruteForce.cpp#L250

Added line #L250 was not covered by tests
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_BFLOAT16:
iterators_val = knowhere::BruteForce::AnnIterator<bfloat16>(
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(

Check warning on line 255 in internal/core/src/query/SearchBruteForce.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/query/SearchBruteForce.cpp#L255

Added line #L255 was not covered by tests
base_dataset, query_dataset, search_cfg, bitset);
break;
case DataType::VECTOR_SPARSE_FLOAT:
Expand All @@ -251,10 +272,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
"equal to nq:{} for single chunk",
iterators_val.value().size(),
nq);
SubSearchResult subSearchResult(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal,
SubSearchResult subSearchResult(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal,
iterators_val.value());
return std::move(subSearchResult);
} else {
Expand Down
10 changes: 4 additions & 6 deletions internal/core/src/query/SearchBruteForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,16 @@ CheckBruteForceSearchParam(const FieldMeta& field,
const SearchInfo& search_info);

SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
Expand Down
20 changes: 6 additions & 14 deletions internal/core/src/query/SearchOnGrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,23 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;

auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data};
if (info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(search_dataset,
chunk_data,
size_per_chunk,
sub_data,
info,
index_info,
sub_view,
bitset,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(search_dataset,
chunk_data,
size_per_chunk,
sub_data,
info,
index_info,
sub_view,
bitset,
data_type);

// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
}
final_qr.merge(sub_qr);
}
}
Expand Down
85 changes: 30 additions & 55 deletions internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();

query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};

auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
Expand All @@ -116,51 +116,27 @@ SearchOnSealed(const Schema& schema,
auto vec_data = column->Data(i);
auto chunk_size = column->chunk_row_nums(i);
const uint8_t* bitset_ptr = nullptr;
bool aligned = false;
if ((offset & 0x7) == 0) {
bitset_ptr = bitview.data() + (offset >> 3);
aligned = true;
} else {
char* bitset_data = new char[(chunk_size + 7) / 8];
std::fill(bitset_data, bitset_data + sizeof(bitset_data), 0);
bitset::detail::ElementWiseBitsetPolicy<char>::op_copy(
reinterpret_cast<const char*>(bitview.data()),
offset,
bitset_data,
0,
chunk_size);
bitset_ptr = reinterpret_cast<const uint8_t*>(bitset_data);
}
BitsetView bitset_view(bitset_ptr, chunk_size);

auto data_id = offset;
auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
chunk_size,
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
bitview,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(dataset,
vec_data,
chunk_size,
auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
bitview,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
if (o != -1) {
o += offset;
}
}
final_qr.merge(sub_qr);
}

if (!aligned) {
delete[] bitset_ptr;
}
offset += chunk_size;
}
if (search_info.group_by_field_id_.has_value()) {
Expand All @@ -172,8 +148,8 @@ SearchOnSealed(const Schema& schema,
result.distances_ = std::move(final_qr.mutable_distances());
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}

void
Expand All @@ -194,38 +170,37 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();

query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};

auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
row_count,
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(dataset,
vec_data,
row_count,
auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}

} // namespace milvus::query
Loading

0 comments on commit 6544d03

Please sign in to comment.