diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h index 75d54987ea607..2407a295cdc28 100644 --- a/internal/core/src/common/QueryResult.h +++ b/internal/core/src/common/QueryResult.h @@ -79,8 +79,6 @@ struct VectorIterator { heap_.pop(); if (iterators_[top->GetIteratorIdx()]->HasNext()) { auto origin_pair = iterators_[top->GetIteratorIdx()]->Next(); - origin_pair.first = convert_to_segment_offset( - origin_pair.first, top->GetIteratorIdx()); auto off_dis_pair = std::make_shared( origin_pair, top->GetIteratorIdx()); heap_.push(off_dis_pair); @@ -108,8 +106,6 @@ struct VectorIterator { for (auto& iter : iterators_) { if (iter->HasNext()) { auto origin_pair = iter->Next(); - origin_pair.first = - convert_to_segment_offset(origin_pair.first, idx); auto off_dis_pair = std::make_shared(origin_pair, idx++); heap_.push(off_dis_pair); diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index eca42e9a6f151..9df66690b8396 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -71,28 +71,49 @@ PrepareBFSearchParams(const SearchInfo& search_info, return search_cfg; } -SubSearchResult -BruteForceSearch(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t chunk_rows, - const SearchInfo& search_info, - const std::map& index_info, - const BitsetView& bitset, +std::pair +PrepareBFDataSet(const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, DataType data_type) { - SubSearchResult sub_result(dataset.num_queries, - dataset.topk, - dataset.metric_type, - dataset.round_decimal); - auto nq = dataset.num_queries; - auto dim = dataset.dim; - auto topk = dataset.topk; - - auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); - auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); + auto base_dataset = + knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data); + auto query_dataset = knowhere::GenDataSet( + query_ds.num_queries, query_ds.dim, query_ds.query_data); if (data_type == DataType::VECTOR_SPARSE_FLOAT) { base_dataset->SetIsSparse(true); query_dataset->SetIsSparse(true); + } else if (data_type == DataType::VECTOR_BFLOAT16) { + //todo: if knowhere support real fp16/bf16 bf, remove convert + base_dataset = + knowhere::ConvertFromDataTypeIfNeeded(base_dataset); + query_dataset = + knowhere::ConvertFromDataTypeIfNeeded(query_dataset); + } else if (data_type == DataType::VECTOR_FLOAT16) { + //todo: if knowhere support real fp16/bf16 bf, remove convert + base_dataset = + knowhere::ConvertFromDataTypeIfNeeded(base_dataset); + query_dataset = + knowhere::ConvertFromDataTypeIfNeeded(query_dataset); } + base_dataset->SetTensorBeginId(raw_ds.begin_id); + return std::make_pair(query_dataset, base_dataset); +}; + +SubSearchResult +BruteForceSearch(const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + DataType data_type) { + SubSearchResult sub_result(query_ds.num_queries, + query_ds.topk, + query_ds.metric_type, + query_ds.round_decimal); + auto topk = query_ds.topk; + auto nq = query_ds.num_queries; + auto [query_dataset, base_dataset] = + PrepareBFDataSet(query_ds, raw_ds, data_type); auto search_cfg = PrepareBFSearchParams(search_info, index_info); // `range_search_k` is only used as one of the conditions for iterator early termination. // not gurantee to return exactly `range_search_k` results, which may be more or less. @@ -112,10 +133,12 @@ BruteForceSearch(const dataset::SearchDataset& dataset, res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_FLOAT16) { - res = knowhere::BruteForce::RangeSearch( + //todo: if knowhere support real fp16/bf16 bf, change it + res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BFLOAT16) { - res = knowhere::BruteForce::RangeSearch( + //todo: if knowhere support real fp16/bf16 bf, change it + res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BINARY) { res = knowhere::BruteForce::RangeSearch( @@ -138,7 +161,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset, res.what()); } auto result = - ReGenRangeSearchResult(res.value(), topk, nq, dataset.metric_type); + ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type); milvus::tracer::AddEvent("ReGenRangeSearchResult"); std::copy_n( GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets()); @@ -155,7 +178,8 @@ BruteForceSearch(const dataset::SearchDataset& dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_FLOAT16) { - stat = knowhere::BruteForce::SearchWithBuf( + //todo: if knowhere support real fp16/bf16 bf, change it + stat = knowhere::BruteForce::SearchWithBuf( base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(), @@ -163,7 +187,8 @@ BruteForceSearch(const dataset::SearchDataset& dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BFLOAT16) { - stat = knowhere::BruteForce::SearchWithBuf( + //todo: if knowhere support real fp16/bf16 bf, change it + stat = knowhere::BruteForce::SearchWithBuf( base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(), @@ -202,21 +227,15 @@ BruteForceSearch(const dataset::SearchDataset& dataset, } SubSearchResult -BruteForceSearchIterators(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t chunk_rows, +BruteForceSearchIterators(const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, const SearchInfo& search_info, const std::map& index_info, const BitsetView& bitset, DataType data_type) { - auto nq = dataset.num_queries; - auto dim = dataset.dim; - auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); - auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); - if (data_type == DataType::VECTOR_SPARSE_FLOAT) { - base_dataset->SetIsSparse(true); - query_dataset->SetIsSparse(true); - } + auto nq = query_ds.num_queries; + auto [query_dataset, base_dataset] = + PrepareBFDataSet(query_ds, raw_ds, data_type); auto search_cfg = PrepareBFSearchParams(search_info, index_info); knowhere::expected> @@ -227,11 +246,13 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, base_dataset, query_dataset, search_cfg, bitset); break; case DataType::VECTOR_FLOAT16: - iterators_val = knowhere::BruteForce::AnnIterator( + //todo: if knowhere support real fp16/bf16 bf, change it + iterators_val = knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, search_cfg, bitset); break; case DataType::VECTOR_BFLOAT16: - iterators_val = knowhere::BruteForce::AnnIterator( + //todo: if knowhere support real fp16/bf16 bf, change it + iterators_val = knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, search_cfg, bitset); break; case DataType::VECTOR_SPARSE_FLOAT: @@ -251,10 +272,10 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, "equal to nq:{} for single chunk", iterators_val.value().size(), nq); - SubSearchResult subSearchResult(dataset.num_queries, - dataset.topk, - dataset.metric_type, - dataset.round_decimal, + SubSearchResult subSearchResult(query_ds.num_queries, + query_ds.topk, + query_ds.metric_type, + query_ds.round_decimal, iterators_val.value()); return std::move(subSearchResult); } else { diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index 3cf6863b91d08..15fb5697abfdf 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -24,18 +24,16 @@ CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info); SubSearchResult -BruteForceSearch(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t chunk_rows, +BruteForceSearch(const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, const SearchInfo& search_info, const std::map& index_info, const BitsetView& bitset, DataType data_type); SubSearchResult -BruteForceSearchIterators(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t chunk_rows, +BruteForceSearchIterators(const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, const SearchInfo& search_info, const std::map& index_info, const BitsetView& bitset, diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 7e6606261fa43..2add5f5a1fde8 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -136,31 +136,23 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, std::min(active_count, (chunk_id + 1) * vec_size_per_chunk); auto size_per_chunk = element_end - element_begin; - auto sub_view = bitset.subview(element_begin, size_per_chunk); + auto sub_data = query::dataset::RawDataset{ + element_begin, dim, size_per_chunk, chunk_data}; if (info.group_by_field_id_.has_value()) { auto sub_qr = BruteForceSearchIterators(search_dataset, - chunk_data, - size_per_chunk, + sub_data, info, index_info, - sub_view, + bitset, data_type); final_qr.merge(sub_qr); } else { auto sub_qr = BruteForceSearch(search_dataset, - chunk_data, - size_per_chunk, + sub_data, info, index_info, - sub_view, + bitset, data_type); - - // convert chunk uid to segment uid - for (auto& x : sub_qr.mutable_seg_offsets()) { - if (x != -1) { - x += chunk_id * vec_size_per_chunk; - } - } final_qr.merge(sub_qr); } } diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 2a0dc5f7b078a..59146b2447a0b 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -95,12 +95,12 @@ SearchOnSealed(const Schema& schema, ? 0 : field.get_dim(); - query::dataset::SearchDataset dataset{search_info.metric_type_, - num_queries, - search_info.topk_, - search_info.round_decimal_, - dim, - query_data}; + query::dataset::SearchDataset query_dataset{search_info.metric_type_, + num_queries, + search_info.topk_, + search_info.round_decimal_, + dim, + query_data}; auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); @@ -116,51 +116,27 @@ SearchOnSealed(const Schema& schema, auto vec_data = column->Data(i); auto chunk_size = column->chunk_row_nums(i); const uint8_t* bitset_ptr = nullptr; - bool aligned = false; - if ((offset & 0x7) == 0) { - bitset_ptr = bitview.data() + (offset >> 3); - aligned = true; - } else { - char* bitset_data = new char[(chunk_size + 7) / 8]; - std::fill(bitset_data, bitset_data + sizeof(bitset_data), 0); - bitset::detail::ElementWiseBitsetPolicy::op_copy( - reinterpret_cast(bitview.data()), - offset, - bitset_data, - 0, - chunk_size); - bitset_ptr = reinterpret_cast(bitset_data); - } - BitsetView bitset_view(bitset_ptr, chunk_size); - + auto data_id = offset; + auto raw_dataset = + query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; if (search_info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators(dataset, - vec_data, - chunk_size, + auto sub_qr = BruteForceSearchIterators(query_dataset, + raw_dataset, search_info, index_info, - bitset_view, + bitview, data_type); final_qr.merge(sub_qr); } else { - auto sub_qr = BruteForceSearch(dataset, - vec_data, - chunk_size, + auto sub_qr = BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, - bitset_view, + bitview, data_type); - for (auto& o : sub_qr.mutable_seg_offsets()) { - if (o != -1) { - o += offset; - } - } final_qr.merge(sub_qr); } - if (!aligned) { - delete[] bitset_ptr; - } offset += chunk_size; } if (search_info.group_by_field_id_.has_value()) { @@ -172,8 +148,8 @@ SearchOnSealed(const Schema& schema, result.distances_ = std::move(final_qr.mutable_distances()); result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets()); } - result.unity_topK_ = dataset.topk; - result.total_nq_ = dataset.num_queries; + result.unity_topK_ = query_dataset.topk; + result.total_nq_ = query_dataset.num_queries; } void @@ -194,19 +170,19 @@ SearchOnSealed(const Schema& schema, ? 0 : field.get_dim(); - query::dataset::SearchDataset dataset{search_info.metric_type_, - num_queries, - search_info.topk_, - search_info.round_decimal_, - dim, - query_data}; + query::dataset::SearchDataset query_dataset{search_info.metric_type_, + num_queries, + search_info.topk_, + search_info.round_decimal_, + dim, + query_data}; auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); + auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data}; if (search_info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators(dataset, - vec_data, - row_count, + auto sub_qr = BruteForceSearchIterators(query_dataset, + raw_dataset, search_info, index_info, bitset, @@ -214,9 +190,8 @@ SearchOnSealed(const Schema& schema, result.AssembleChunkVectorIterators( num_queries, 1, {0}, sub_qr.chunk_iterators()); } else { - auto sub_qr = BruteForceSearch(dataset, - vec_data, - row_count, + auto sub_qr = BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, bitset, @@ -224,8 +199,8 @@ SearchOnSealed(const Schema& schema, result.distances_ = std::move(sub_qr.mutable_distances()); result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets()); } - result.unity_topK_ = dataset.topk; - result.total_nq_ = dataset.num_queries; + result.unity_topK_ = query_dataset.topk; + result.total_nq_ = query_dataset.num_queries; } } // namespace milvus::query diff --git a/internal/core/src/query/helper.h b/internal/core/src/query/helper.h index 88bedfe590080..56bf1d1261f6f 100644 --- a/internal/core/src/query/helper.h +++ b/internal/core/src/query/helper.h @@ -19,7 +19,12 @@ namespace milvus::query { namespace dataset { - +struct RawDataset { + int64_t begin_id = 0; + int64_t dim; + int64_t num_raw_data; + const void* raw_data; +}; struct SearchDataset { knowhere::MetricType metric_type; int64_t num_queries; diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index 75c4566145759..d92840c2b11f7 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -124,7 +124,7 @@ class TestFloatSearchBruteForce : public ::testing::Test { auto query = GenFloatVecs(dim, nq, metric_type); auto index_info = std::map{}; - dataset::SearchDataset dataset{ + dataset::SearchDataset query_dataset{ metric_type, nq, topk, -1, dim, query.data()}; if (!is_supported_float_metric(metric_type)) { // Memory leak in knowhere. @@ -134,9 +134,10 @@ class TestFloatSearchBruteForce : public ::testing::Test { SearchInfo search_info; search_info.topk_ = topk; search_info.metric_type_ = metric_type; - auto result = BruteForceSearch(dataset, - base.data(), - nb, + + auto raw_dataset = query::dataset::RawDataset{0, dim, nb, base.data()}; + auto result = BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, bitset_view, diff --git a/internal/core/unittest/test_bf_sparse.cpp b/internal/core/unittest/test_bf_sparse.cpp index 0c5ce6d1a64be..4724874ebc039 100644 --- a/internal/core/unittest/test_bf_sparse.cpp +++ b/internal/core/unittest/test_bf_sparse.cpp @@ -103,21 +103,21 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { SearchInfo search_info; search_info.topk_ = topk; search_info.metric_type_ = metric_type; - dataset::SearchDataset dataset{ + dataset::SearchDataset query_dataset{ metric_type, nq, topk, -1, kTestSparseDim, query.get()}; + auto raw_dataset = + query::dataset::RawDataset{0, kTestSparseDim, nb, base.get()}; if (!is_supported_sparse_float_metric(metric_type)) { - ASSERT_ANY_THROW(BruteForceSearch(dataset, - base.get(), - nb, + ASSERT_ANY_THROW(BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, bitset_view, DataType::VECTOR_SPARSE_FLOAT)); return; } - auto result = BruteForceSearch(dataset, - base.get(), - nb, + auto result = BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, bitset_view, @@ -130,9 +130,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { search_info.search_params_[RADIUS] = 0.1; search_info.search_params_[RANGE_FILTER] = 0.5; - auto result2 = BruteForceSearch(dataset, - base.get(), - nb, + auto result2 = BruteForceSearch(query_dataset, + raw_dataset, search_info, index_info, bitset_view, @@ -144,9 +143,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { AssertMatch(ref, ans); } - auto result3 = BruteForceSearchIterators(dataset, - base.get(), - nb, + auto result3 = BruteForceSearchIterators(query_dataset, + raw_dataset, search_info, index_info, bitset_view, diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 32ead437558a2..bd3678f6f5d13 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -177,9 +177,10 @@ TEST(Indexing, BinaryBruteForce) { search_info.topk_ = topk; search_info.round_decimal_ = round_decimal; search_info.metric_type_ = metric_type; + auto base_dataset = query::dataset::RawDataset{ + int64_t(0), dim, N, (const void*)bin_vec.data()}; auto sub_result = query::BruteForceSearch(search_dataset, - bin_vec.data(), - N, + base_dataset, search_info, index_info, nullptr, diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index b8e1e4c090b44..a8c690e19d629 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -1254,9 +1254,9 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { search_info.topk_ = topk; search_info.round_decimal_ = round_decimal; search_info.metric_type_ = metric_type; + auto raw_dataset = query::dataset::RawDataset{0, dim, N, vec_col.data()}; auto sub_result = BruteForceSearch(search_dataset, - vec_col.data(), - N, + raw_dataset, search_info, index_info, nullptr,