Skip to content

Commit

Permalink
Move wlog_manager to function scope.
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Jun 13, 2024
1 parent 2534ad4 commit 91e242d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 26 deletions.
18 changes: 7 additions & 11 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,6 @@ void DMCBatched::process(xmlNodePtr node)
myComm->barrier_and_abort(ue.what());
}

{ //initialize WalkerLogManager and collectors
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager_->makeCollector());
}

{
ReportEngine PRE("DMC", "resetUpdateEngines");
Timer init_timer;
Expand Down Expand Up @@ -440,9 +434,12 @@ bool DMCBatched::run()

estimator_manager_->startDriverRun();

//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
if (wlog_manager_)
wlog_manager_->startRun(Crowd::getWalkerLogCollectorRefs(crowds_));
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread dmc_state(qmcdriver_input_, *drift_modifier_, *branch_engine_, population_, steps_per_block_);

Expand Down Expand Up @@ -517,8 +514,7 @@ bool DMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
if (wlog_manager_)
wlog_manager_->writeBuffers();
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand All @@ -541,8 +537,8 @@ bool DMCBatched::run()

print_mem("DMCBatched ends", app_log());

wlog_manager.stopRun();
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand Down
4 changes: 0 additions & 4 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ namespace qmcplusplus
{
//forward declarations: Do not include headers if not needed
class TraceManager;
class WalkerLogManager;
class EstimatorManagerNew;
class TrialWaveFunction;
class QMCHamiltonian;
Expand Down Expand Up @@ -443,9 +442,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
std::unique_ptr<EstimatorManagerNew> estimator_manager_;

/// walker log manager
std::unique_ptr<WalkerLogManager> wlog_manager_;

///record engine for walkers
std::unique_ptr<HDFWalkerOutput> wOut;

Expand Down
19 changes: 8 additions & 11 deletions src/QMCDrivers/VMC/VMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,6 @@ void VMCBatched::process(xmlNodePtr node)
{
myComm->barrier_and_abort(ue.what());
}

{ //initialize WalkerLogManager and collectors
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager_->makeCollector());
}
}

size_t VMCBatched::compute_samples_per_rank(const size_t num_blocks,
Expand Down Expand Up @@ -315,9 +309,13 @@ bool VMCBatched::run()
IndexType num_blocks = qmcdriver_input_.get_max_blocks();
//start the main estimator
estimator_manager_->startDriverRun();

//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
if (wlog_manager_)
wlog_manager_->startRun(Crowd::getWalkerLogCollectorRefs(crowds_));
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread vmc_state(qmcdriver_input_, vmcdriver_input_, *drift_modifier_, population_, steps_per_block_);

Expand Down Expand Up @@ -406,8 +404,7 @@ bool VMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
if (wlog_manager_)
wlog_manager_->writeBuffers();
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand Down Expand Up @@ -453,8 +450,8 @@ bool VMCBatched::run()

print_mem("VMCBatched ends", app_log());

wlog_manager.stopRun();
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand Down

0 comments on commit 91e242d

Please sign in to comment.