Skip to content

Commit

Permalink
Remove TrainMode (#47)
Browse files Browse the repository at this point in the history
* SubmeshTopologyHandler treats each subdomain as unique component.

* remove TrainMode

* remove individual test
  • Loading branch information
dreamer2368 authored Jun 19, 2024
1 parent ec9f027 commit 83597be
Show file tree
Hide file tree
Showing 16 changed files with 100 additions and 386 deletions.
2 changes: 1 addition & 1 deletion include/main_workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void RunExample();

MultiBlockSolver* InitSolver();
SampleGenerator* InitSampleGenerator(MPI_Comm comm);
std::vector<std::string> GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, const TrainMode &train_mode, bool separate_variable_basis);
std::vector<std::string> GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, bool separate_variable_basis);

void GenerateSamples(MPI_Comm comm);
void CollectSamples(SampleGenerator *sample_generator);
Expand Down
2 changes: 0 additions & 2 deletions include/multiblock_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ friend class ParameterizedProblem;

// rom variables.
ROMHandlerBase *rom_handler = NULL;
TrainMode train_mode = NUM_TRAINMODE;
bool use_rom = false;
bool separate_variable_basis = false;

Expand Down Expand Up @@ -136,7 +135,6 @@ friend class ParameterizedProblem;
const bool UseRom() const { return use_rom; }
ROMHandlerBase* GetROMHandler() const { return rom_handler; }
TopologyHandler* GetTopologyHandler() const { return topol_handler; }
const TrainMode GetTrainMode() { return train_mode; }
const bool IsVisualizationSaved() const { return visual.save; }
const std::string GetSolutionFilePrefix() const { return sol_prefix; }
const std::string GetVisualizationPrefix() const { return visual.prefix; }
Expand Down
23 changes: 6 additions & 17 deletions include/rom_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
namespace mfem
{

enum TrainMode
{
INDIVIDUAL,
UNIVERSAL,
NUM_TRAINMODE
};

enum ROMBuildingLevel
{
NONE,
Expand All @@ -37,10 +30,8 @@ enum NonlinearHandling
NUM_NLNHNDL
};

const TrainMode SetTrainMode();

const std::string GetBasisTagForComponent(const int &comp_idx, const TrainMode &train_mode, const TopologyHandler *topol_handler, const std::string var_name="");
const std::string GetBasisTag(const int &subdomain_index, const TrainMode &train_mode, const TopologyHandler *topol_handler, const std::string var_name="");
const std::string GetBasisTagForComponent(const int &comp_idx, const TopologyHandler *topol_handler, const std::string var_name="");
const std::string GetBasisTag(const int &subdomain_index, const TopologyHandler *topol_handler, const std::string var_name="");

class ROMHandlerBase
{
Expand All @@ -61,7 +52,6 @@ class ROMHandlerBase
bool component_sampling = false;
bool save_lspg_basis = false;
ROMBuildingLevel save_operator = NUM_BLD_LVL;
TrainMode train_mode = NUM_TRAINMODE;
bool nonlinear_mode = false;
bool separate_variable = false;
NonlinearHandling nlin_handle = NUM_NLNHNDL;
Expand Down Expand Up @@ -120,14 +110,13 @@ class ROMHandlerBase

void ParseInputs();
public:
ROMHandlerBase(const TrainMode &train_mode_, TopologyHandler *input_topol,
const Array<int> &input_var_offsets, const std::vector<std::string> &var_names, const bool separate_variable_basis);
ROMHandlerBase(TopologyHandler *input_topol, const Array<int> &input_var_offsets,
const std::vector<std::string> &var_names, const bool separate_variable_basis);

virtual ~ROMHandlerBase();

// access
const int GetNumSubdomains() { return numSub; }
const TrainMode GetTrainMode() { return train_mode; }
const int GetNumROMRefComps() { return num_rom_comp; }
const int GetNumROMRefBlocks() { return num_rom_ref_blocks; }
const int GetRefNumBasis(const int &basis_idx) { return num_ref_basis[basis_idx]; }
Expand Down Expand Up @@ -232,8 +221,8 @@ class MFEMROMHandler : public ROMHandlerBase
mfem::BlockVector *reduced_sol = NULL;

public:
MFEMROMHandler(const TrainMode &train_mode_, TopologyHandler *input_topol,
const Array<int> &input_var_offsets, const std::vector<std::string> &var_names, const bool separate_variable_basis);
MFEMROMHandler(TopologyHandler *input_topol, const Array<int> &input_var_offsets,
const std::vector<std::string> &var_names, const bool separate_variable_basis);

virtual ~MFEMROMHandler();

Expand Down
2 changes: 1 addition & 1 deletion include/sample_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class SampleGenerator
The appended column indices of each basis tag are stored in col_idxs.
*/
void SaveSnapshot(BlockVector *U_snapshots, std::vector<std::string> &snapshot_basis_tags, Array<int> &col_idxs);
void SaveSnapshotPorts(TopologyHandler *topol_handler, const TrainMode &train_mode, const Array<int> &col_idxs);
void SaveSnapshotPorts(TopologyHandler *topol_handler, const Array<int> &col_idxs);
void AddSnapshotGenerator(const int &fom_vdofs, const std::string &prefix, const std::string &basis_tag);
void WriteSnapshots();
void WriteSnapshotPorts();
Expand Down
1 change: 0 additions & 1 deletion src/advdiff_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ void AdvDiffSolver::BuildCompROMLinElems()
{
mfem_error("AdvDiffSolver::BuildCompROMLinElems is not implemented yet!\n");

assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand Down
3 changes: 0 additions & 3 deletions src/linelast_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,6 @@ void LinElastSolver::ProjectOperatorOnReducedBasis()
// Component-wise assembly
void LinElastSolver::BuildCompROMLinElems()
{
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand All @@ -524,7 +523,6 @@ void LinElastSolver::BuildCompROMLinElems()

void LinElastSolver::BuildBdrROMLinElems()
{
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand Down Expand Up @@ -555,7 +553,6 @@ void LinElastSolver::BuildBdrROMLinElems()
void LinElastSolver::BuildItfaceROMLinElems()
{
assert(topol_mode == TopologyHandlerMode::COMPONENT);
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand Down
41 changes: 14 additions & 27 deletions src/main_workflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,40 +90,28 @@ SampleGenerator* InitSampleGenerator(MPI_Comm comm)
return generator;
}

std::vector<std::string> GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, const TrainMode &train_mode, bool separate_variable_basis)
std::vector<std::string> GetGlobalBasisTagList(const TopologyHandlerMode &topol_mode, bool separate_variable_basis)
{
std::vector<std::string> basis_tags(0);

std::vector<std::string> component_list(0);
if (train_mode == TrainMode::INDIVIDUAL)
if (topol_mode == TopologyHandlerMode::SUBMESH)
{
TopologyHandler *topol_handler;
if (topol_mode == TopologyHandlerMode::SUBMESH) topol_handler = new SubMeshTopologyHandler();
else if (topol_mode == TopologyHandlerMode::COMPONENT) topol_handler = new ComponentTopologyHandler();
else
mfem_error("GetGlobalBasisTagList - TopologyHandlerMode is not set!\n");

for (int c = 0; c < topol_handler->GetNumSubdomains(); c++)
component_list.push_back("dom" + std::to_string(c));
TopologyHandler *topol_handler = new SubMeshTopologyHandler();
for (int c = 0; c < topol_handler->GetNumComponents(); c++)
component_list.push_back(topol_handler->GetComponentName(c));

delete topol_handler;
}
else // if (train_mode == TrainMode::UNIVERSAL)
else if (topol_mode == TopologyHandlerMode::COMPONENT)
{
if (topol_mode == TopologyHandlerMode::SUBMESH)
{
component_list.push_back("comp0");
}
else if (topol_mode == TopologyHandlerMode::COMPONENT)
{
YAML::Node component_dict = config.FindNode("mesh/component-wise/components");
assert(component_dict);
for (int p = 0; p < component_dict.size(); p++)
component_list.push_back(config.GetRequiredOptionFromDict<std::string>("name", component_dict[p]));
}
else
mfem_error("GetGlobalBasisTagList - TopologyHandlerMode is not set!\n");
} // if (train_mode == TrainMode::UNIVERSAL)
YAML::Node component_dict = config.FindNode("mesh/component-wise/components");
assert(component_dict);
for (int p = 0; p < component_dict.size(); p++)
component_list.push_back(config.GetRequiredOptionFromDict<std::string>("name", component_dict[p]));
}
else
mfem_error("GetGlobalBasisTagList - TopologyHandlerMode is not set!\n");

std::vector<std::string> var_list(0);
if (separate_variable_basis)
Expand Down Expand Up @@ -223,11 +211,10 @@ void CollectSamples(SampleGenerator *sample_generator)
assert(sample_generator);

TopologyHandlerMode topol_mode = SetTopologyHandlerMode();
TrainMode train_mode = SetTrainMode();
bool separate_variable_basis = config.GetOption<bool>("model_reduction/separate_variable_basis", false);

// Find the all required basis tags.
std::vector<std::string> basis_tags = GetGlobalBasisTagList(topol_mode, train_mode, separate_variable_basis);
std::vector<std::string> basis_tags = GetGlobalBasisTagList(topol_mode, separate_variable_basis);

// tag-specific optional inputs.
YAML::Node basis_list = config.FindNode("basis/tags");
Expand Down
14 changes: 5 additions & 9 deletions src/multiblock_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ void MultiBlockSolver::ParseInputs()
use_rom = config.GetOption<bool>("main/use_rom", false);
separate_variable_basis = config.GetOption<bool>("model_reduction/separate_variable_basis", false);

train_mode = SetTrainMode();

// save solution if single run.
SetSolutionSaveMode(config.GetOption<bool>("save_solution/enabled", false));
}
Expand Down Expand Up @@ -261,7 +259,6 @@ void MultiBlockSolver::GetComponentFESpaces(Array<FiniteElementSpace *> &comp_fe
void MultiBlockSolver::BuildROMLinElems()
{
assert(topol_mode == TopologyHandlerMode::COMPONENT);
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());

BuildCompROMLinElems();
Expand All @@ -276,7 +273,6 @@ void MultiBlockSolver::BuildROMLinElems()
void MultiBlockSolver::AssembleROMMat()
{
assert(topol_mode == TopologyHandlerMode::COMPONENT);
assert(train_mode == UNIVERSAL);
assert(rom_elems);

const Array<int> *rom_block_offsets = rom_handler->GetBlockOffsets();
Expand Down Expand Up @@ -658,9 +654,9 @@ void MultiBlockSolver::AssembleROMNlinOper()

void MultiBlockSolver::InitROMHandler()
{
rom_handler = new MFEMROMHandler(train_mode, topol_handler, var_offsets, var_names, separate_variable_basis);
rom_handler = new MFEMROMHandler(topol_handler, var_offsets, var_names, separate_variable_basis);

if (!((topol_mode == TopologyHandlerMode::COMPONENT) && (train_mode == UNIVERSAL)))
if (!(topol_mode == TopologyHandlerMode::COMPONENT))
return;

GetComponentFESpaces(comp_fes);
Expand All @@ -674,13 +670,13 @@ void MultiBlockSolver::GetBasisTags(std::vector<std::string> &basis_tags)
basis_tags.resize(numSub * num_var);
for (int m = 0, idx = 0; m < numSub; m++)
for (int v = 0; v < num_var; v++, idx++)
basis_tags[idx] = GetBasisTag(m, train_mode, topol_handler, var_names[v]);
basis_tags[idx] = GetBasisTag(m, topol_handler, var_names[v]);
}
else
{
basis_tags.resize(numSub);
for (int m = 0; m < numSub; m++)
basis_tags[m] = GetBasisTag(m, train_mode, topol_handler);
basis_tags[m] = GetBasisTag(m, topol_handler);
}
}

Expand Down Expand Up @@ -710,7 +706,7 @@ void MultiBlockSolver::SaveSnapshots(SampleGenerator *sample_generator)

Array<int> col_idxs;
sample_generator->SaveSnapshot(U_snapshots, basis_tags, col_idxs);
sample_generator->SaveSnapshotPorts(topol_handler, train_mode, col_idxs);
sample_generator->SaveSnapshotPorts(topol_handler, col_idxs);

/* delete only the view vector, not the data itself. */
delete U_snapshots;
Expand Down
3 changes: 0 additions & 3 deletions src/poisson_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ void PoissonSolver::AssembleInterfaceMatrices()

void PoissonSolver::BuildCompROMLinElems()
{
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand All @@ -368,7 +367,6 @@ void PoissonSolver::BuildCompROMLinElems()

void PoissonSolver::BuildBdrROMLinElems()
{
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand Down Expand Up @@ -399,7 +397,6 @@ void PoissonSolver::BuildBdrROMLinElems()
void PoissonSolver::BuildItfaceROMLinElems()
{
assert(topol_mode == TopologyHandlerMode::COMPONENT);
assert(train_mode == UNIVERSAL);
assert(rom_handler->BasisLoaded());
assert(rom_elems);

Expand Down
Loading

0 comments on commit 83597be

Please sign in to comment.