Skip to content

Commit

Permalink
Merge pull request #5039 from ye-luo/address-WalkerLog
Browse files Browse the repository at this point in the history
Improve the integration of WalkerLog classes in batched drivers
  • Loading branch information
ye-luo authored Jun 21, 2024
2 parents 935919f + c242385 commit 5721c74
Show file tree
Hide file tree
Showing 21 changed files with 145 additions and 128 deletions.
10 changes: 8 additions & 2 deletions src/QMCDrivers/CloneManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"

//comment this out to use only method to clone
#define ENABLE_CLONE_PSI_AND_H
Expand Down Expand Up @@ -87,7 +86,6 @@ CloneManager::~CloneManager()
#if !defined(REMOVE_TRACEMANAGER)
delete_iter(traceClones.begin(), traceClones.end());
#endif
delete_iter(wlog_collectors.begin(), wlog_collectors.end());
}

void CloneManager::makeClones(MCWalkerConfiguration& w, TrialWaveFunction& psi, QMCHamiltonian& ham)
Expand Down Expand Up @@ -282,4 +280,12 @@ CloneManager::RealType CloneManager::acceptRatio() const
return static_cast<RealType>(nAcceptTot) / static_cast<RealType>(nAcceptTot + nRejectTot);
}

RefVector<WalkerLogCollector> CloneManager::getWalkerLogCollectorRefs()
{
RefVector<WalkerLogCollector> refs;
for(int i = 0; i < wlog_collectors.size(); i++)
refs.push_back(*wlog_collectors[i]);
return refs;
}

} // namespace qmcplusplus
4 changes: 3 additions & 1 deletion src/QMCDrivers/CloneManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class CloneManager : public QMCTraits
///trace managers
std::vector<TraceManager*> traceClones;
///trace collectors
std::vector<WalkerLogCollector*> wlog_collectors;
UPtrVector<WalkerLogCollector> wlog_collectors;

//for correlated sampling.
static std::vector<UPtrVector<MCWalkerConfiguration>> WPoolClones_uptr;
Expand All @@ -89,6 +89,8 @@ class CloneManager : public QMCTraits

///Walkers per MPI rank
std::vector<int> wPerRank;

RefVector<WalkerLogCollector> getWalkerLogCollectorRefs();
};
} // namespace qmcplusplus
#endif
24 changes: 21 additions & 3 deletions src/QMCDrivers/Crowd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Crowd::Crowd(EstimatorManagerNew& emb,
const TrialWaveFunction& twf,
const QMCHamiltonian& ham,
const MultiWalkerDispatchers& dispatchers)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
: dispatchers_(dispatchers), driverwalker_resource_collection_(driverwalker_res), estimator_manager_crowd_(emb)
{
if (emb.areThereListeners())
{
Expand Down Expand Up @@ -82,16 +82,34 @@ void Crowd::startBlock(int num_steps)
// VMCBatched does no nonlocal moves
n_nonlocal_accept_ = 0;
estimator_manager_crowd_.startBlock(num_steps);
wlog_collector_.startBlock();
if (wlog_collector_)
wlog_collector_->startBlock();
}

void Crowd::stopBlock() { estimator_manager_crowd_.stopBlock(); }

void Crowd::setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&& collector)
{
wlog_collector_ = std::move(collector);
}

void Crowd::collectStepWalkerLog(int current_step)
{
if (!wlog_collector_)
return;

for (int iw = 0; iw < size(); ++iw)
wlog_collector_.collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw], current_step);
wlog_collector_->collect(mcp_walkers_[iw], walker_elecs_[iw], walker_twfs_[iw], walker_hamiltonians_[iw],
current_step);
}

RefVector<WalkerLogCollector> Crowd::getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds)
{
RefVector<WalkerLogCollector> refs;
for (auto& crowd : crowds)
if (crowd && crowd->wlog_collector_)
refs.push_back(*crowd->wlog_collector_);
return refs;
}

} // namespace qmcplusplus
14 changes: 9 additions & 5 deletions src/QMCDrivers/Crowd.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "MultiWalkerDispatchers.h"
#include "DriverWalkerTypes.h"
#include "Estimators/EstimatorManagerCrowd.h"
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -52,9 +52,9 @@ class Crowd
*/
Crowd(EstimatorManagerNew& emb,
const DriverWalkerResourceCollection& driverwalker_res,
const ParticleSet& pset,
const ParticleSet& pset,
const TrialWaveFunction& twf,
const QMCHamiltonian& hamiltonian_temp,
const QMCHamiltonian& hamiltonian_temp,
const MultiWalkerDispatchers& dispatchers);
~Crowd();
/** Because so many vectors allocate them upfront.
Expand Down Expand Up @@ -85,6 +85,8 @@ class Crowd
estimator_manager_crowd_.accumulate(mcp_walkers_, walker_elecs_, walker_twfs_, walker_hamiltonians_, rng);
}

/// activate the collector
void setWalkerLogCollector(std::unique_ptr<WalkerLogCollector>&&);
/// Collect walker log data
void collectStepWalkerLog(int current_step);

Expand All @@ -103,7 +105,6 @@ class Crowd
const RefVector<QMCHamiltonian>& get_walker_hamiltonians() const { return walker_hamiltonians_; }

const EstimatorManagerCrowd& get_estimator_manager_crowd() const { return estimator_manager_crowd_; }
WalkerLogCollector& getWalkerLogCollector() { return wlog_collector_; }

DriverWalkerResourceCollection& getSharedResource() { return driverwalker_resource_collection_; }

Expand All @@ -118,6 +119,9 @@ class Crowd

const MultiWalkerDispatchers& dispatchers_;

/// get refereces of active walker log collectors. If walker logging is disabled, the RefVector size can be zero.
static RefVector<WalkerLogCollector> getWalkerLogCollectorRefs(const UPtrVector<Crowd>& crowds);

private:
/** @name Walker Vectors
*
Expand All @@ -136,7 +140,7 @@ class Crowd
/// per crowd estimator manager
EstimatorManagerCrowd estimator_manager_crowd_;
// collector for walker logs
WalkerLogCollector wlog_collector_;
std::unique_ptr<WalkerLogCollector> wlog_collector_;

/** @name Step State
*
Expand Down
12 changes: 6 additions & 6 deletions src/QMCDrivers/DMC/DMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void DMC::resetUpdateEngines()
Rng.resize(NumThreads);
estimatorClones.resize(NumThreads, nullptr);
traceClones.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads);
FairDivideLow(W.getActiveWalkers(), NumThreads, wPerRank);

{
Expand Down Expand Up @@ -144,7 +144,7 @@ void DMC::resetUpdateEngines()
Movers[ip]->setSpinMass(SpinMass);
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
else
Expand All @@ -163,7 +163,7 @@ void DMC::resetUpdateEngines()

Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
else
Expand All @@ -174,7 +174,7 @@ void DMC::resetUpdateEngines()
Movers[ip] = new DMCUpdateAllWithRejection(*wClones[ip], *psiClones[ip], *hClones[ip], *Rng[ip]);
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
Movers[ip]->initWalkers(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
}
}
Expand Down Expand Up @@ -239,7 +239,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());
IndexType block = 0;
IndexType updatePeriod = (qmc_driver_mode[QMC_UPDATE_MODE]) ? Period4CheckProperties : (nBlocks + 1) * nSteps;
int sample = 0;
Expand Down Expand Up @@ -297,7 +297,7 @@ bool DMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
block++;
if (DumpConfig && block % Period4CheckPoint == 0)
{
Expand Down
19 changes: 10 additions & 9 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "MemoryUsage.h"
#include "QMCWaveFunctions/TWFGrads.hpp"
#include "TauParams.hpp"
#include "WalkerLogManager.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -433,12 +434,12 @@ bool DMCBatched::run()

estimator_manager_->startDriverRun();

//start walker log manager
wlog_manager_ = std::make_unique<WalkerLogManager>(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
std::vector<WalkerLogCollector*> wlog_collectors;
for (auto& c: crowds_)
wlog_collectors.push_back(&c->getWalkerLogCollector());
wlog_manager_->startRun(wlog_collectors);
//initialize WalkerLogManager and collectors
WalkerLogManager wlog_manager(walker_logs_input, allow_walker_logs, get_root_name(), myComm);
for (auto& crowd : crowds_)
crowd->setWalkerLogCollector(wlog_manager.makeCollector());
//register walker log collectors into the manager
wlog_manager.startRun(Crowd::getWalkerLogCollectorRefs(crowds_));

StateForThread dmc_state(qmcdriver_input_, *drift_modifier_, *branch_engine_, population_, steps_per_block_);

Expand Down Expand Up @@ -494,7 +495,7 @@ bool DMCBatched::run()
for (UPtr<QMCHamiltonian>& ham : population_.get_hamiltonians())
setNonLocalMoveHandler(*ham);

dmc_state.step = step;
dmc_state.step = step;
dmc_state.global_step = global_step;
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, std::ref(step_contexts_),
std::ref(crowds_));
Expand All @@ -513,7 +514,7 @@ bool DMCBatched::run()
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("Block " + std::to_string(block));
endBlock();
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager.writeBuffers();
recordBlock(block);
}

Expand All @@ -536,8 +537,8 @@ bool DMCBatched::run()

print_mem("DMCBatched ends", app_log());

wlog_manager.stopRun();
estimator_manager_->stopDriverRun();
wlog_manager_->stopRun();

return finalize(num_blocks, true);
}
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/DMC/DMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class DMCBatched : public QMCDriverNew
SFNBranch& branch_engine;
IndexType recalculate_properties_period;
const size_t steps_per_block;
IndexType step = -1;
IndexType global_step = -1;
IndexType step = -1;
IndexType global_step = -1;
bool is_recomputing_block = false;

StateForThread(const QMCDriverInput& qmci,
Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/DMC/DMCUpdatePbyPFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"
//#define TEST_INNERBRANCH


Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "QMCWaveFunctions/WaveFunctionPool.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "Estimators/EstimatorManagerBase.h"
#include "WalkerLogManager.h"
#include "QMCDrivers/DriverTraits.h"
#include "QMCDrivers/QMCDriverInterface.h"
#include "QMCDrivers/GreenFunctionModifiers/DriftModifierBase.h"
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "Utilities/Timer.h"
#include "Message/UniformCommunicateError.h"
#include "EstimatorInputDelegates.h"
#include "WalkerLogInput.h"
#include "WalkerLogManager.h"


Expand Down
4 changes: 0 additions & 4 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ namespace qmcplusplus
{
//forward declarations: Do not include headers if not needed
class TraceManager;
class WalkerLogManager;
class EstimatorManagerNew;
class TrialWaveFunction;
class QMCHamiltonian;
Expand Down Expand Up @@ -443,9 +442,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
std::unique_ptr<EstimatorManagerNew> estimator_manager_;

/// walker log manager
std::unique_ptr<WalkerLogManager> wlog_manager_;

///record engine for walkers
std::unique_ptr<HDFWalkerOutput> wOut;

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/QMCUpdateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#else
using TraceManager = int;
#endif
#include "WalkerLogManager.h"
#include "WalkerLogCollector.h"

namespace qmcplusplus
{
Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/QMCUpdateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "QMCWaveFunctions/TrialWaveFunction.h"
#include "QMCHamiltonians/QMCHamiltonian.h"
#include "QMCHamiltonians/NonLocalTOperator.h"
#include "WalkerLogManager.h"
#include "GreenFunctionModifiers/DriftModifierBase.h"
#include "SimpleFixedNodeBranch.h"
#include "DriverDebugChecks.h"
Expand Down
8 changes: 4 additions & 4 deletions src/QMCDrivers/VMC/VMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->startRun(nBlocks, traceClones);
#endif
wlog_manager_->startRun(wlog_collectors);
wlog_manager_->startRun(getWalkerLogCollectorRefs());

LoopTimer<> vmc_loop;
RunTimeControl<> runtimeControl(run_time_manager, MaxCPUSecs, myComm->getName(), myComm->rank() == 0);
Expand Down Expand Up @@ -111,7 +111,7 @@ bool VMC::run()
#if !defined(REMOVE_TRACEMANAGER)
Traces->write_buffers(traceClones, block);
#endif
wlog_manager_->writeBuffers(wlog_collectors);
wlog_manager_->writeBuffers();
recordBlock(block);
vmc_loop.stop();

Expand Down Expand Up @@ -169,7 +169,7 @@ void VMC::resetRun()
Movers.resize(NumThreads, nullptr);
estimatorClones.resize(NumThreads, nullptr);
traceClones.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads, nullptr);
wlog_collectors.resize(NumThreads);
Rng.resize(NumThreads);

// hdf_archive::hdf_archive() is not thread-safe
Expand Down Expand Up @@ -267,7 +267,7 @@ void VMC::resetRun()
//int ip=omp_get_thread_num();
Movers[ip]->put(qmcNode);
//Movers[ip]->resetRun(branchEngine.get(), estimatorClones[ip], traceClones[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip], DriftModifier);
Movers[ip]->resetRun2(branchEngine.get(), estimatorClones[ip], traceClones[ip], wlog_collectors[ip].get(), DriftModifier);
if (qmc_driver_mode[QMC_UPDATE_MODE])
Movers[ip]->initWalkersForPbyP(W.begin() + wPerRank[ip], W.begin() + wPerRank[ip + 1]);
else
Expand Down
Loading

0 comments on commit 5721c74

Please sign in to comment.