Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the integration of WalkerLog classes in batched drivers #5039

Merged
merged 10 commits into from
Jun 21, 2024
10 changes: 8 additions & 2 deletions src/QMCDrivers/CloneManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"

//comment this out to use only method to clone
#define ENABLE_CLONE_PSI_AND_H
Expand Down Expand Up @@ -87,7 +86,6 @@ CloneManager::~CloneManager()
#if !defined(REMOVE_TRACEMANAGER)
delete_iter(traceClones.begin(), traceClones.end());
#endif
delete_iter(wlog_collectors.begin(), wlog_collectors.end());
}

void CloneManager::makeClones(MCWalkerConfiguration& w, TrialWaveFunction& psi, QMCHamiltonian& ham)
Expand Down Expand Up @@ -282,4 +280,12 @@ CloneManager::RealType CloneManager::acceptRatio() const
return static_cast<RealType>(nAcceptTot) / static_cast<RealType>(nAcceptTot + nRejectTot);
}

/** Build a list of references to the walker log collectors held in wlog_collectors.
 *
 *  Ownership stays with wlog_collectors; the returned references are non-owning.
 *  Entries that are still null (wlog_collectors is resized to NumThreads before
 *  collectors are assigned) are skipped instead of being dereferenced, mirroring
 *  Crowd::getWalkerLogCollectorRefs. The result may therefore contain fewer
 *  elements than wlog_collectors.
 *
 *  @return references to every non-null collector, in storage order
 */
RefVector<WalkerLogCollector> CloneManager::getWalkerLogCollectorRefs()
{
  RefVector<WalkerLogCollector> refs;
  refs.reserve(wlog_collectors.size());
  for (auto& collector : wlog_collectors)
    if (collector) // dereferencing an empty unique_ptr is undefined behavior
      refs.push_back(*collector);
  return refs;
}

} // namespace qmcplusplus
4 changes: 3 additions & 1 deletion src/QMCDrivers/CloneManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class CloneManager : public QMCTraits
///trace managers
std::vector<TraceManager*> traceClones;
///trace collectors
std::vector<WalkerLogCollector*> wlog_collectors;
UPtrVector<WalkerLogCollector> wlog_collectors;

//for correlated sampling.
static std::vector<UPtrVector<MCWalkerConfiguration>> WPoolClones_uptr;
Expand All @@ -89,6 +89,8 @@ class CloneManager : public QMCTraits

///Walkers per MPI rank
std::vector<int> wPerRank;

RefVector<WalkerLogCollector> getWalkerLogCollectorRefs();
};
} // namespace qmcplusplus
#endif
24 changes: 21 additions & 3 deletions src/QMCDrivers/Crowd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Crowd::Crowd(EstimatorManagerNew& emb,
const TrialWaveFunction& twf,
const QMCHamiltonian& ham,
const MultiWalkerDispatchers& dispatchers)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
{
if (emb.areThereListeners())
{
Expand Down Expand Up @@ -82,16 +82,34 @@ void Crowd::startBlock(int num_steps)
// VMCBatched does no nonlocal moves
n_nonlocal_accept_ = 0;
estimator_manager_crowd_.startBlock(num_steps);
wlog_collector_.startBlock();
if (wlog_collector_)
wlog_collector_->startBlock();
}

/// Finish the current block by forwarding to the per-crowd estimator manager.
void Crowd::stopBlock()
{
  estimator_manager_crowd_.stopBlock();
}

/** Hand a walker log collector over to this crowd.
 *
 *  The crowd takes ownership of the collector; a collector held previously is
 *  released by the unique_ptr move assignment. When the stored pointer is empty,
 *  startBlock() and collectStepWalkerLog() skip all walker logging.
 *
 *  @param collector collector to move into this crowd (may be empty)
 */
void Crowd::setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&& collector)
{
  wlog_collector_ = std::move(collector);
}

/// Pass every walker of this crowd, together with its particle set, trial wavefunction
/// and Hamiltonian, to the walker log collector for step current_step.
/// Returns immediately when no collector has been set.
void Crowd::collectStepWalkerLog(int current_step)
{
// no collector installed: nothing to record
if (!wlog_collector_)
return;

// one collect() call per walker index in [0, size())
for (int iw = 0; iw < size(); ++iw)
// NOTE(review): the next line looks like the pre-change form shown by the diff (value-member access);
// the two lines after it are presumably its replacement for the unique_ptr member — confirm against the merged file
wlog_collector_.collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw], current_step);
wlog_collector_->collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw],
current_step);
}

/** Gather references to the walker log collectors of a set of crowds.
 *
 *  Empty crowd slots and crowds without a collector are skipped, so the
 *  returned vector may be shorter than crowds, or empty.
 *
 *  @param crowds crowds to inspect
 *  @return non-owning references to each collector found, in crowd order
 */
RefVector<WalkerLogCollector> Crowd::getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds)
{
  RefVector<WalkerLogCollector> active_collectors;
  for (const auto& crowd_ptr : crowds)
  {
    // nothing to reference for a missing crowd or a crowd with no collector
    if (!crowd_ptr || !crowd_ptr->wlog_collector_)
      continue;
    active_collectors.push_back(*crowd_ptr->wlog_collector_);
  }
  return active_collectors;
}

} // namespace qmcplusplus
14 changes: 9 additions & 5 deletions src/QMCDrivers/Crowd.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "MultiWalkerDispatchers.h"
#include "DriverWalkerTypes.h"
#include "Estimators/EstimatorManagerCrowd.h"
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -52,9 +52,9 @@ class Crowd
*/
Crowd(EstimatorManagerNew& emb,
const DriverWalkerResourceCollection& driverwalker_res,
const ParticleSet& pset,
const ParticleSet& pset,
const TrialWaveFunction& twf,
const QMCHamiltonian& hamiltonian_temp,
const QMCHamiltonian& hamiltonian_temp,
const MultiWalkerDispatchers& dispatchers);
~Crowd();
/** Because so many vectors allocate them upfront.
Expand Down Expand Up @@ -85,6 +85,8 @@ class Crowd
estimator_manager_crowd_.accumulate(mcp_walkers_, walker_elecs_, walker_twfs_, walker_hamiltonians_, rng);
}

/// activate the collector
void setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&&);
/// Collect walker log data
void collectStepWalkerLog(int current_step);

Expand All @@ -103,7 +105,6 @@ class Crowd
const RefVector<QMCHamiltonian>& get_walker_hamiltonians() const { return walker_hamiltonians_; }

const EstimatorManagerCrowd& get_estimator_manager_crowd() const { return estimator_manager_crowd_; }
WalkerLogCollector& getWalkerLogCollector() { return wlog_collector_; }

DriverWalkerResourceCollection& getSharedResource() { return driverwalker_resource_collection_; }

Expand All @@ -118,6 +119,9 @@ class Crowd

const MultiWalkerDispatchers& dispatchers_;

/// Get references to the active walker log collectors. If walker logging is disabled, the returned RefVector may be empty.
static RefVector<WalkerLogCollector> getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds);
PDoakORNL marked this conversation as resolved.
Show resolved Hide resolved

private:
/** @name Walker Vectors
*
Expand All @@ -136,7 +140,7 @@ class Crowd
/// per crowd estimator manager
EstimatorManagerCrowd estimator_manager_crowd_;
// collector for walker logs
WalkerLogCollector wlog_collector_;
std::unique_ptr<WalkerLogCollector> wlog_collector_;

/** @name Step State
*
Expand Down
12 changes: 6 additions & 6 deletions src/QMCDrivers/DMC/DMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void DMC::resetUpdateEngines()
Rng.resize(NumThreads);
estimatorClones.resize(NumThreads, nullptr);
traceClones.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads);
FairDivideLow(W.getActiveWalkers(), NumThreads, wPerRank);

{
Expand Down Expand Up @@ -144,7 +144,7 @@ void DMC::resetUpdateEngines()
Movers[ip]->setSpinMass(SpinMass);
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
else
Expand All @@ -163,7 +163,7 @@ void DMC::resetUpdateEngines()

Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
else
Expand All @@ -174,7 +174,7 @@ void DMC::resetUpdateEngines()
Movers[ip] = new DMCUpdateAllWithRejection(*wClones[ip], *psiClones[ip], *hClones[ip], *Rng[ip]);
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkers(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
}
Expand Down Expand Up @@ -239,7 +239,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());
IndexType block = 0;
IndexType updatePeriod = (qmc_driver_mode[QMC_UPDATE_MODE]) ? Period4CheckProperties : (nBlocks + 1) * nSteps;
int sample = 0;
Expand Down Expand Up @@ -297,7 +297,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
block++;
if (DumpConfig && block % Period4CheckPoint == 0)
{
Expand Down
19 changes: 10 additions & 9 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "MemoryUsage.h"
#include "QMCWaveFunctions/TWFGrads.hpp"
#include "TauParams.hpp"
#include "WalkerLogManager.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -433,12 +434,12 @@ bool DMCBatched::run()

estimator_manager_->startDriverRun();

//start walker log manager
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
std::vector<WalkerLogCollector*> wlog_collectors;
for (auto& c: crowds_)
wlog_collectors.push_back(&c->getWalkerLogCollector());
wlog_manager_->startRun(wlog_collectors);
//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread dmc_state(qmcdriver_input_, *drift_modifier_, *branch_engine_, population_, steps_per_block_);

Expand Down Expand Up @@ -494,7 +495,7 @@ bool DMCBatched::run()
for (UPtr<QMCHamiltonian>& ham : population_.get_hamiltonians())
setNonLocalMoveHandler(*ham);

dmc_state.step = step;
dmc_state.step = step;
dmc_state.global_step = global_step;
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, std::ref(step_contexts_),
std::ref(crowds_));
Expand All @@ -513,7 +514,7 @@ bool DMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand All @@ -536,8 +537,8 @@ bool DMCBatched::run()

print_mem("DMCBatched ends", app_log());

wlog_manager.stopRun();
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/DMC/DMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class DMCBatched : public QMCDriverNew
SFNBranch& branch_engine;
IndexType recalculate_properties_period;
const size_t steps_per_block;
IndexType step = -1;
IndexType global_step = -1;
IndexType step = -1;
IndexType global_step = -1;
bool is_recomputing_block = false;

StateForThread(const QMCDriverInput& qmci,
Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/DMC/DMCUpdatePbyPFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"
//#define TEST_INNERBRANCH


Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "QMCWaveFunctions/WaveFunctionPool.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "Estimators/EstimatorManagerBase.h"
#include "WalkerLogManager.h"
#include "QMCDrivers/DriverTraits.h"
#include "QMCDrivers/QMCDriverInterface.h"
#include "QMCDrivers/GreenFunctionModifiers/DriftModifierBase.h"
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "Utilities/Timer.h"
#include "Message/UniformCommunicateError.h"
#include "EstimatorInputDelegates.h"
#include "WalkerLogInput.h"
#include "WalkerLogManager.h"


Expand Down
4 changes: 0 additions & 4 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ namespace qmcplusplus
{
//forward declarations: Do not include headers if not needed
class TraceManager;
class WalkerLogManager;
class EstimatorManagerNew;
class TrialWaveFunction;
class QMCHamiltonian;
Expand Down Expand Up @@ -443,9 +442,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
std::unique_ptr<EstimatorManagerNew> estimator_manager_;

/// walker log manager
std::unique_ptr<WalkerLogManager> wlog_manager_;

///record engine for walkers
std::unique_ptr<HDFWalkerOutput> wOut;

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/QMCUpdateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCUpdateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "QMCWaveFunctions/TrialWaveFunction.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "QMCHamiltonians/NonLocalTOperator.h"
#include "WalkerLogManager.h"
#include "GreenFunctionModifiers/DriftModifierBase.h"
#include "SimpleFixedNodeBranch.h"
#include "DriverDebugChecks.h"
Expand Down
8 changes: 4 additions & 4 deletions src/QMCDrivers/VMC/VMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());

LoopTimer<> vmc_loop;
RunTimeControl<> runtimeControl(run_time_manager, MaxCPUSecs, myComm->getName(), myComm->rank() == 0);
Expand Down Expand Up @@ -111,7 +111,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
recordBlock(block);
vmc_loop.stop();

Expand Down Expand Up @@ -169,7 +169,7 @@ void VMC::resetRun()
Movers.resize(NumThreads, nullptr);
estimatorClones.resize(NumThreads, nullptr);
traceClones.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads);
Rng.resize(NumThreads);

// hdf_archive::hdf_archive() is not thread-safe
Expand Down Expand Up @@ -267,7 +267,7 @@ void VMC::resetRun()
//int ip=omp_get_thread_num();
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
if (qmc_driver_mode[QMC_UPDATE_MODE])
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
else
Expand Down
Loading
Loading