Skip to content

Commit

Permalink
enhance: reduce copy of bitset and id conversion of brurtforce search
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 14, 2024
1 parent 993051b commit 91326ab
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 135 deletions.
4 changes: 0 additions & 4 deletions internal/core/src/common/QueryResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ struct VectorIterator {
heap_.pop();
if (iterators_[top->GetIteratorIdx()]->HasNext()) {
auto origin_pair = iterators_[top->GetIteratorIdx()]->Next();
origin_pair.first = convert_to_segment_offset(
origin_pair.first, top->GetIteratorIdx());
auto off_dis_pair = std::make_shared<OffsetDisPair>(
origin_pair, top->GetIteratorIdx());
heap_.push(off_dis_pair);
Expand Down Expand Up @@ -108,8 +106,6 @@ struct VectorIterator {
for (auto& iter : iterators_) {
if (iter->HasNext()) {
auto origin_pair = iter->Next();
origin_pair.first =
convert_to_segment_offset(origin_pair.first, idx);
auto off_dis_pair =
std::make_shared<OffsetDisPair>(origin_pair, idx++);
heap_.push(off_dis_pair);
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/exec/expression/UnaryExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ PhyUnaryRangeFilterExpr::ExecArrayEqualForIndex(bool reverse) {
};
} else {
auto size_per_chunk = segment_->size_per_chunk();
retrieve = [ size_per_chunk, this ](int64_t offset) -> auto {
retrieve = [ size_per_chunk, this ](int64_t offset) -> auto{
auto chunk_idx = offset / size_per_chunk;
auto chunk_offset = offset % size_per_chunk;
const auto& chunk =
Expand Down
49 changes: 24 additions & 25 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,21 @@ PrepareBFSearchParams(const SearchInfo& search_info) {
}

SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal);
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto topk = dataset.topk;

auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
SubSearchResult sub_result(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal);
auto topk = query_ds.topk;
auto nq = query_ds.num_queries;
auto base_dataset = knowhere::GenDataSet(
raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data, raw_ds.begin_id);
auto query_dataset = knowhere::GenDataSet(
query_ds.num_queries, query_ds.dim, query_ds.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
Expand Down Expand Up @@ -133,7 +132,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
res.what());
}
auto result =
ReGenRangeSearchResult(res.value(), topk, nq, dataset.metric_type);
ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type);
milvus::tracer::AddEvent("ReGenRangeSearchResult");
std::copy_n(
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
Expand Down Expand Up @@ -197,16 +196,16 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
}

SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
auto nq = query_ds.num_queries;
auto base_dataset = knowhere::GenDataSet(
raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data, raw_ds.begin_id);
auto query_dataset = knowhere::GenDataSet(
query_ds.num_queries, query_ds.dim, query_ds.query_data);
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
Expand Down Expand Up @@ -245,10 +244,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
"equal to nq:{} for single chunk",
iterators_val.value().size(),
nq);
SubSearchResult subSearchResult(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal,
SubSearchResult subSearchResult(query_ds.num_queries,
query_ds.topk,
query_ds.metric_type,
query_ds.round_decimal,
iterators_val.value());
return std::move(subSearchResult);
} else {
Expand Down
10 changes: 4 additions & 6 deletions internal/core/src/query/SearchBruteForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@ CheckBruteForceSearchParam(const FieldMeta& field,
const SearchInfo& search_info);

SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearch(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const BitsetView& bitset,
DataType data_type);

SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const BitsetView& bitset,
DataType data_type);
Expand Down
26 changes: 6 additions & 20 deletions internal/core/src/query/SearchOnGrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,15 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;

auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data};
if (info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(search_dataset,
chunk_data,
size_per_chunk,
info,
sub_view,
data_type);
auto sub_qr = BruteForceSearchIterators(
search_dataset, sub_data, info, bitset, data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(search_dataset,
chunk_data,
size_per_chunk,
info,
sub_view,
data_type);

// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
}
auto sub_qr = BruteForceSearch(
search_dataset, sub_data, info, bitset, data_type);
final_qr.merge(sub_qr);
}
}
Expand Down
81 changes: 26 additions & 55 deletions internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();

query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};

auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
Expand All @@ -115,49 +115,19 @@ SearchOnSealed(const Schema& schema,
auto vec_data = column->Data(i);
auto chunk_size = column->chunk_row_nums(i);
const uint8_t* bitset_ptr = nullptr;
bool aligned = false;
if ((offset & 0x7) == 0) {
bitset_ptr = bitview.data() + (offset >> 3);
aligned = true;
} else {
char* bitset_data = new char[(chunk_size + 7) / 8];
std::fill(bitset_data, bitset_data + sizeof(bitset_data), 0);
bitset::detail::ElementWiseBitsetPolicy<char>::op_copy(
reinterpret_cast<const char*>(bitview.data()),
offset,
bitset_data,
0,
chunk_size);
bitset_ptr = reinterpret_cast<const uint8_t*>(bitset_data);
}
BitsetView bitset_view(bitset_ptr, chunk_size);

auto data_id = offset;
auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
chunk_size,
search_info,
bitset_view,
data_type);
auto sub_qr = BruteForceSearchIterators(
query_dataset, raw_dataset, search_info, bitview, data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(dataset,
vec_data,
chunk_size,
search_info,
bitset_view,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
if (o != -1) {
o += offset;
}
}
auto sub_qr = BruteForceSearch(
query_dataset, raw_dataset, search_info, bitview, data_type);
final_qr.merge(sub_qr);
}

if (!aligned) {
delete[] bitset_ptr;
}
offset += chunk_size;
}
if (search_info.group_by_field_id_.has_value()) {
Expand All @@ -169,8 +139,8 @@ SearchOnSealed(const Schema& schema,
result.distances_ = std::move(final_qr.mutable_distances());
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}

void
Expand All @@ -190,28 +160,29 @@ SearchOnSealed(const Schema& schema,
? 0
: field.get_dim();

query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
dim,
query_data};

auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(
dataset, vec_data, row_count, search_info, bitset, data_type);
query_dataset, raw_dataset, search_info, bitset, data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(
dataset, vec_data, row_count, search_info, bitset, data_type);
query_dataset, raw_dataset, search_info, bitset, data_type);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}
result.unity_topK_ = dataset.topk;
result.total_nq_ = dataset.num_queries;
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
}

} // namespace milvus::query
7 changes: 6 additions & 1 deletion internal/core/src/query/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

namespace milvus::query {
namespace dataset {

struct RawDataset {
int64_t begin_id = 0;
int64_t dim;
int64_t num_raw_data;
const void* raw_data;
};
struct SearchDataset {
knowhere::MetricType metric_type;
int64_t num_queries;
Expand Down
4 changes: 2 additions & 2 deletions internal/core/thirdparty/knowhere/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# Update KNOWHERE_VERSION for the first occurrence
milvus_add_pkg_config("knowhere")
set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "")
set( KNOWHERE_VERSION 5935c1f )
set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git")
set( KNOWHERE_VERSION fix-idselector )
set( GIT_REPOSITORY "https://github.com/cqy123456/zilliztech-knowhere.git")
message(STATUS "Knowhere repo: ${GIT_REPOSITORY}")
message(STATUS "Knowhere version: ${KNOWHERE_VERSION}")

Expand Down
9 changes: 5 additions & 4 deletions internal/core/unittest/test_bf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
auto base = GenFloatVecs(dim, nb, metric_type);
auto query = GenFloatVecs(dim, nq, metric_type);

dataset::SearchDataset dataset{
dataset::SearchDataset query_dataset{
metric_type, nq, topk, -1, dim, query.data()};
if (!is_supported_float_metric(metric_type)) {
// Memory leak in knowhere.
Expand All @@ -133,9 +133,10 @@ class TestFloatSearchBruteForce : public ::testing::Test {
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
auto result = BruteForceSearch(dataset,
base.data(),
nb,

auto raw_dataset = query::dataset::RawDataset{0, dim, nb, base.data()};
auto result = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
bitset_view,
DataType::VECTOR_FLOAT);
Expand Down
24 changes: 11 additions & 13 deletions internal/core/unittest/test_bf_sparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,20 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
dataset::SearchDataset dataset{
dataset::SearchDataset query_dataset{
metric_type, nq, topk, -1, kTestSparseDim, query.get()};
auto raw_dataset =
query::dataset::RawDataset{0, kTestSparseDim, nb, base.get()};
if (!is_supported_sparse_float_metric(metric_type)) {
ASSERT_ANY_THROW(BruteForceSearch(dataset,
base.get(),
nb,
ASSERT_ANY_THROW(BruteForceSearch(query_dataset,
raw_dataset,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
}
auto result = BruteForceSearch(dataset,
base.get(),
nb,
auto result = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
Expand All @@ -126,9 +126,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {

search_info.search_params_[RADIUS] = 0.1;
search_info.search_params_[RANGE_FILTER] = 0.5;
auto result2 = BruteForceSearch(dataset,
base.get(),
nb,
auto result2 = BruteForceSearch(query_dataset,
raw_dataset,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
Expand All @@ -139,9 +138,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
AssertMatch(ref, ans);
}

auto result3 = BruteForceSearchIterators(dataset,
base.get(),
nb,
auto result3 = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
Expand Down
Loading

0 comments on commit 91326ab

Please sign in to comment.