Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the integration of WalkerLog classes in batched drivers #5039

Merged
merged 10 commits into from
Jun 21, 2024
9 changes: 8 additions & 1 deletion src/QMCDrivers/CloneManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"

//comment this out to use only method to clone
#define ENABLE_CLONE_PSI_AND_H
Expand Down Expand Up @@ -282,4 +281,12 @@ CloneManager::RealType CloneManager::acceptRatio() const
return static_cast<RealType>(nAcceptTot) / static_cast<RealType>(nAcceptTot + nRejectTot);
}

RefVector<WalkerLogCollector> CloneManager::getWalkerLogCollectorRefs()
{
RefVector<WalkerLogCollector> refs;
for(int i = 0; i < wlog_collectors.size(); i++)
refs.push_back(*wlog_collectors[i]);
return refs;
}

} // namespace qmcplusplus
2 changes: 2 additions & 0 deletions src/QMCDrivers/CloneManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class CloneManager : public QMCTraits

///Walkers per MPI rank
std::vector<int> wPerRank;

RefVector<WalkerLogCollector> getWalkerLogCollectorRefs();
};
} // namespace qmcplusplus
#endif
24 changes: 21 additions & 3 deletions src/QMCDrivers/Crowd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Crowd::Crowd(EstimatorManagerNew& emb,
const TrialWaveFunction& twf,
const QMCHamiltonian& ham,
const MultiWalkerDispatchers& dispatchers)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
{
if (emb.areThereListeners())
{
Expand Down Expand Up @@ -82,16 +82,34 @@ void Crowd::startBlock(int num_steps)
// VMCBatched does no nonlocal moves
n_nonlocal_accept_ = 0;
estimator_manager_crowd_.startBlock(num_steps);
wlog_collector_.startBlock();
if (wlog_collector_)
wlog_collector_->startBlock();
}

void Crowd::stopBlock() { estimator_manager_crowd_.stopBlock(); }

void Crowd::setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&& collector)
{
wlog_collector_ = std::move(collector);
}

void Crowd::collectStepWalkerLog(int current_step)
{
if (!wlog_collector_)
return;

for (int iw = 0; iw < size(); ++iw)
wlog_collector_.collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw], current_step);
wlog_collector_->collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw],
current_step);
}

RefVector<WalkerLogCollector> Crowd::getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds)
{
RefVector<WalkerLogCollector> refs;
for (auto& crowd : crowds)
if (crowd && crowd->wlog_collector_)
refs.push_back(*crowd->wlog_collector_);
return refs;
}

} // namespace qmcplusplus
13 changes: 8 additions & 5 deletions src/QMCDrivers/Crowd.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "MultiWalkerDispatchers.h"
#include "DriverWalkerTypes.h"
#include "Estimators/EstimatorManagerCrowd.h"
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -52,9 +52,9 @@ class Crowd
*/
Crowd(EstimatorManagerNew& emb,
const DriverWalkerResourceCollection& driverwalker_res,
const ParticleSet& pset,
const ParticleSet& pset,
const TrialWaveFunction& twf,
const QMCHamiltonian& hamiltonian_temp,
const QMCHamiltonian& hamiltonian_temp,
const MultiWalkerDispatchers& dispatchers);
~Crowd();
/** Because so many vectors allocate them upfront.
Expand Down Expand Up @@ -85,6 +85,8 @@ class Crowd
estimator_manager_crowd_.accumulate(mcp_walkers_, walker_elecs_, walker_twfs_, walker_hamiltonians_, rng);
}

/// activate the collector
void setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&&);
/// Collect walker log data
void collectStepWalkerLog(int current_step);

Expand All @@ -103,7 +105,6 @@ class Crowd
const RefVector<QMCHamiltonian>& get_walker_hamiltonians() const { return walker_hamiltonians_; }

const EstimatorManagerCrowd& get_estimator_manager_crowd() const { return estimator_manager_crowd_; }
WalkerLogCollector& getWalkerLogCollector() { return wlog_collector_; }

DriverWalkerResourceCollection& getSharedResource() { return driverwalker_resource_collection_; }

Expand All @@ -118,6 +119,8 @@ class Crowd

const MultiWalkerDispatchers& dispatchers_;

static RefVector<WalkerLogCollector> getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds);
PDoakORNL marked this conversation as resolved.
Show resolved Hide resolved

private:
/** @name Walker Vectors
*
Expand All @@ -136,7 +139,7 @@ class Crowd
/// per crowd estimator manager
EstimatorManagerCrowd estimator_manager_crowd_;
// collector for walker logs
WalkerLogCollector wlog_collector_;
std::unique_ptr<WalkerLogCollector> wlog_collector_;

/** @name Step State
*
Expand Down
6 changes: 3 additions & 3 deletions src/QMCDrivers/DMC/DMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void DMC::resetUpdateEngines()
#if !defined(REMOVE_TRACEMANAGER)
traceClones[ip] = Traces->makeClone();
#endif
wlog_collectors[ip] = wlog_manager_->makeCollector();
wlog_collectors[ip] = wlog_manager_->makeCollector().release();
Rng[ip] = rngs_[ip]->makeClone();
hClones[ip]->setRandomGenerator(Rng[ip].get());
if (W.isSpinor())
Expand Down Expand Up @@ -239,7 +239,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());
IndexType block = 0;
IndexType updatePeriod = (qmc_driver_mode[QMC_UPDATE_MODE]) ? Period4CheckProperties : (nBlocks + 1) * nSteps;
int sample = 0;
Expand Down Expand Up @@ -297,7 +297,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
block++;
if (DumpConfig && block % Period4CheckPoint == 0)
{
Expand Down
19 changes: 10 additions & 9 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "MemoryUsage.h"
#include "QMCWaveFunctions/TWFGrads.hpp"
#include "TauParams.hpp"
#include "WalkerLogManager.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -433,12 +434,12 @@ bool DMCBatched::run()

estimator_manager_->startDriverRun();

//start walker log manager
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
std::vector<WalkerLogCollector*> wlog_collectors;
for (auto& c: crowds_)
wlog_collectors.push_back(&c->getWalkerLogCollector());
wlog_manager_->startRun(wlog_collectors);
//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread dmc_state(qmcdriver_input_, *drift_modifier_, *branch_engine_, population_, steps_per_block_);

Expand Down Expand Up @@ -494,7 +495,7 @@ bool DMCBatched::run()
for (UPtr<QMCHamiltonian>& ham : population_.get_hamiltonians())
setNonLocalMoveHandler(*ham);

dmc_state.step = step;
dmc_state.step = step;
dmc_state.global_step = global_step;
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, std::ref(step_contexts_),
std::ref(crowds_));
Expand All @@ -513,7 +514,7 @@ bool DMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand All @@ -536,8 +537,8 @@ bool DMCBatched::run()

print_mem("DMCBatched ends", app_log());

wlog_manager.stopRun();
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/DMC/DMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class DMCBatched : public QMCDriverNew
SFNBranch& branch_engine;
IndexType recalculate_properties_period;
const size_t steps_per_block;
IndexType step = -1;
IndexType global_step = -1;
IndexType step = -1;
IndexType global_step = -1;
bool is_recomputing_block = false;

StateForThread(const QMCDriverInput& qmci,
Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/DMC/DMCUpdatePbyPFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"
//#define TEST_INNERBRANCH


Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "QMCWaveFunctions/WaveFunctionPool.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "Estimators/EstimatorManagerBase.h"
#include "WalkerLogManager.h"
#include "QMCDrivers/DriverTraits.h"
#include "QMCDrivers/QMCDriverInterface.h"
#include "QMCDrivers/GreenFunctionModifiers/DriftModifierBase.h"
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "Utilities/Timer.h"
#include "Message/UniformCommunicateError.h"
#include "EstimatorInputDelegates.h"
#include "WalkerLogInput.h"
#include "WalkerLogManager.h"


Expand Down
4 changes: 0 additions & 4 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ namespace qmcplusplus
{
//forward declarations: Do not include headers if not needed
class TraceManager;
class WalkerLogManager;
class EstimatorManagerNew;
class TrialWaveFunction;
class QMCHamiltonian;
Expand Down Expand Up @@ -443,9 +442,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
std::unique_ptr<EstimatorManagerNew> estimator_manager_;

/// walker log manager
std::unique_ptr<WalkerLogManager> wlog_manager_;

///record engine for walkers
std::unique_ptr<HDFWalkerOutput> wOut;

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/QMCUpdateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCUpdateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "QMCWaveFunctions/TrialWaveFunction.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "QMCHamiltonians/NonLocalTOperator.h"
#include "WalkerLogManager.h"
#include "GreenFunctionModifiers/DriftModifierBase.h"
#include "SimpleFixedNodeBranch.h"
#include "DriverDebugChecks.h"
Expand Down
6 changes: 3 additions & 3 deletions src/QMCDrivers/VMC/VMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());

LoopTimer<> vmc_loop;
RunTimeControl<> runtimeControl(run_time_manager, MaxCPUSecs, myComm->getName(), myComm->rank() == 0);
Expand Down Expand Up @@ -111,7 +111,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
recordBlock(block);
vmc_loop.stop();

Expand Down Expand Up @@ -185,7 +185,7 @@ void VMC::resetRun()
#if !defined(REMOVE_TRACEMANAGER)
traceClones[ip] = Traces->makeClone();
#endif
wlog_collectors[ip] = wlog_manager_->makeCollector();
wlog_collectors[ip] = wlog_manager_->makeCollector().release();
Rng[ip] = rngs_[ip]->makeClone();
hClones[ip]->setRandomGenerator(Rng[ip].get());
if (W.isSpinor())
Expand Down
21 changes: 11 additions & 10 deletions src/QMCDrivers/VMC/VMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,13 @@ bool VMCBatched::run()
IndexType num_blocks = qmcdriver_input_.get_max_blocks();
//start the main estimator
estimator_manager_->startDriverRun();
//start walker log manager
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
std::vector<WalkerLogCollector*> wlog_collectors;
for (auto& c: crowds_)
wlog_collectors.push_back(&c->getWalkerLogCollector());
wlog_manager_->startRun(wlog_collectors);

//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread vmc_state(qmcdriver_input_, vmcdriver_input_, *drift_modifier_, population_, steps_per_block_);

Expand Down Expand Up @@ -386,7 +387,7 @@ bool VMCBatched::run()
for (int step = 0; step < steps_per_block_; ++step, ++global_step)
{
ScopedTimer local_timer(timers_.run_steps_timer);
vmc_state.step = step;
vmc_state.step = step;
vmc_state.global_step = global_step;
crowd_task(crowds_.size(), runVMCStep, vmc_state, timers_, std::ref(step_contexts_), std::ref(crowds_));

Expand All @@ -403,7 +404,7 @@ bool VMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand Down Expand Up @@ -449,8 +450,8 @@ bool VMCBatched::run()

print_mem("VMCBatched ends", app_log());

wlog_manager.stopRun();
jtkrogel marked this conversation as resolved.
Show resolved Hide resolved
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand All @@ -459,7 +460,7 @@ void VMCBatched::enable_sample_collection()
{
assert(steps_per_block_ > 0 && "VMCBatched::enable_sample_collection steps_per_block_ must be positive!");
auto samples = compute_samples_per_rank(qmcdriver_input_.get_max_blocks(), steps_per_block_,
population_.get_num_local_walkers());
population_.get_num_local_walkers());
samples_.setMaxSamples(samples, population_.get_num_ranks());
collect_samples_ = true;

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/VMC/VMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class VMCBatched : public QMCDriverNew
IndexType recalculate_properties_period;
const size_t steps_per_block;
IndexType step = -1;
IndexType global_step = -1;
IndexType global_step = -1;
bool is_recomputing_block = false;

StateForThread(const QMCDriverInput& qmci,
Expand Down
Loading
Loading