Skip to content

Commit

Permalink
use shared lock for global mutex of hgraph
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Nov 14, 2024
1 parent a0b12e0 commit 66f1d05
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 49 deletions.
99 changes: 51 additions & 48 deletions src/index/hgraph_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ HGraphIndex::hnsw_add(const DatasetPtr& data) {
auto cur_count = this->bottom_graph_->TotalCount();
vsag::Vector<std::shared_mutex>(total + cur_count, allocator_).swap(this->neighbors_mutex_);

std::mutex add_mutex;

auto build_func = [&](InnerIdType begin, InnerIdType end) -> void {
for (InnerIdType i = begin; i < end; ++i) {
int level = this->get_random_level() - 1;
Expand All @@ -221,60 +223,20 @@ HGraphIndex::hnsw_add(const DatasetPtr& data) {
this->label_lookup_[label] = inner_id;
}

std::unique_lock<std::mutex> lock(this->global_mutex_);
bool need_lock = false;
std::unique_lock<std::mutex> add_lock(add_mutex);
if (level >= int64_t(this->max_level_) || bottom_graph_->TotalCount() == 0) {
std::unique_lock<std::shared_mutex> wlock(this->global_mutex_);
for (int64_t j = max_level_; j <= level; ++j) {
this->route_graphs_.emplace_back(this->generate_one_route_graph());
}
max_level_ = level + 1;
need_lock = true;
} else {
lock.unlock();
}

{
auto ep = this->entry_point_id_;
MaxHeap result(allocator_);
for (auto j = max_level_ - 1; j > level; --j) {
result = search_one_graph(
datas + dim_ * i, route_graphs_[j], basic_flatten_codes_, ep, 1, nullptr);
ep = result.top().second;
}

for (auto j = level; j >= 0; --j) {
if (route_graphs_[j]->TotalCount() != 0) {
result = search_one_graph(datas + dim_ * i,
route_graphs_[j],
basic_flatten_codes_,
ep,
this->ef_construct_,
nullptr);
ep = this->mutually_connect_new_element(
inner_id, result, route_graphs_[j], basic_flatten_codes_, false);
} else {
route_graphs_[j]->InsertNeighborsById(inner_id,
Vector<InnerIdType>(allocator_));
}
route_graphs_[j]->IncreaseTotalCount(1);
}
if (bottom_graph_->TotalCount() != 0) {
result = search_one_graph(datas + dim_ * i,
this->bottom_graph_,
basic_flatten_codes_,
ep,
this->ef_construct_,
nullptr);
this->mutually_connect_new_element(
inner_id, result, this->bottom_graph_, basic_flatten_codes_, false);
} else {
bottom_graph_->InsertNeighborsById(inner_id, Vector<InnerIdType>(allocator_));
}
bottom_graph_->IncreaseTotalCount(1);
}

if (need_lock) {
this->add_one_point(datas + i * dim_, level, inner_id);
entry_point_id_ = inner_id;
add_lock.unlock();
} else {
add_lock.unlock();
std::shared_lock<std::shared_mutex> rlock(this->global_mutex_);
this->add_one_point(datas + i * dim_, level, inner_id);
}
}
};
Expand Down Expand Up @@ -592,6 +554,7 @@ HGraphIndex::cal_serialize_size() const {
this->serialize(writer);
return writer.cursor_;
}

tl::expected<void, Error>
HGraphIndex::serialize(std::ostream& out_stream) const {
try {
Expand Down Expand Up @@ -634,6 +597,7 @@ HGraphIndex::deserialize(const BinarySet& binary_set) {

return {};
}

tl::expected<void, Error>
HGraphIndex::deserialize(std::istream& in_stream) {
SlowTaskTimer t("hgraph deserialize");
Expand Down Expand Up @@ -671,5 +635,44 @@ HGraphIndex::calc_distance_by_id(const float* vector, int64_t id) const {
}
}
}
void
HGraphIndex::add_one_point(const float* data, int level, InnerIdType inner_id) {
auto ep = this->entry_point_id_;
MaxHeap result(allocator_);
for (auto j = max_level_ - 1; j > level; --j) {
result = search_one_graph(
data, route_graphs_[j], basic_flatten_codes_, ep, 1, nullptr);
ep = result.top().second;
}

for (auto j = level; j >= 0; --j) {
if (route_graphs_[j]->TotalCount() != 0) {
result = search_one_graph(data,
route_graphs_[j],
basic_flatten_codes_,
ep,
this->ef_construct_,
nullptr);
ep = this->mutually_connect_new_element(
inner_id, result, route_graphs_[j], basic_flatten_codes_, false);
} else {
route_graphs_[j]->InsertNeighborsById(inner_id, Vector<InnerIdType>(allocator_));
}
route_graphs_[j]->IncreaseTotalCount(1);
}
if (bottom_graph_->TotalCount() != 0) {
result = search_one_graph(data,
this->bottom_graph_,
basic_flatten_codes_,
ep,
this->ef_construct_,
nullptr);
this->mutually_connect_new_element(
inner_id, result, this->bottom_graph_, basic_flatten_codes_, false);
} else {
bottom_graph_->InsertNeighborsById(inner_id, Vector<InnerIdType>(allocator_));
}
bottom_graph_->IncreaseTotalCount(1);
}

} // namespace vsag
5 changes: 4 additions & 1 deletion src/index/hgraph_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ class HGraphIndex : public Index {
uint64_t
cal_serialize_size() const;

void
add_one_point(const float* data, int level, InnerIdType id);

private:
std::default_random_engine level_generator_{2021};
double mult_{1.0};
Expand All @@ -266,7 +269,7 @@ class HGraphIndex : public Index {
uint64_t max_level_{0};

uint64_t ef_construct_{400};
std::mutex global_mutex_;
std::shared_mutex global_mutex_;

// Locks operations with element by label value
mutable vsag::Vector<std::mutex> label_op_mutex_;
Expand Down

0 comments on commit 66f1d05

Please sign in to comment.