Skip to content

Commit

Permalink
Integrate pack_vector_fields into SphericalVector Interpolation m…
Browse files Browse the repository at this point in the history
…ethod. (#224)
  • Loading branch information
odlomax authored Oct 11, 2024
1 parent 10cffa6 commit 96edef9
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "atlas/runtime/Trace.h"
#include "atlas/util/Constants.h"
#include "atlas/util/Geometry.h"
#include "atlas/util/PackVectorFields.h"
#include "eckit/config/LocalConfiguration.h"

namespace atlas {
Expand Down Expand Up @@ -95,10 +96,9 @@ void SphericalVector::do_setup(const FunctionSpace& source,
const auto deltaAlpha =
(alpha.first - alpha.second) * util::Constants::degreesToRadians();

complexTriplets[dataIndex] =
ComplexTriplet{rowIndex, colIndex,
Complex{baseWeight * std::cos(deltaAlpha),
baseWeight * std::sin(deltaAlpha)}};
complexTriplets[dataIndex] = ComplexTriplet{
rowIndex, colIndex,
baseWeight * Complex{std::cos(deltaAlpha), std::sin(deltaAlpha)}};
realTriplets[dataIndex] = RealTriplet{rowIndex, colIndex, baseWeight};
}
}
Expand All @@ -120,17 +120,22 @@ void SphericalVector::do_execute(const FieldSet& sourceFieldSet,
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");
ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size());

for (auto i = 0; i < sourceFieldSet.size(); ++i) {
do_execute(sourceFieldSet[i], targetFieldSet[i], metadata);
const auto packedSourceFieldSet = util::pack_vector_fields(sourceFieldSet);
auto packedTargetFieldSet = util::pack_vector_fields(targetFieldSet);

for (auto i = 0; i < packedSourceFieldSet.size(); ++i) {
do_execute(packedSourceFieldSet[i], packedTargetFieldSet[i], metadata);
}

util::unpack_vector_fields(packedTargetFieldSet, targetFieldSet);
}

void SphericalVector::do_execute(const Field& sourceField, Field& targetField,
Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");

if (targetField.size() == 0) {
return;
return;
}

const auto fieldType = sourceField.metadata().getString("type", "");
Expand All @@ -156,9 +161,15 @@ void SphericalVector::do_execute_adjoint(FieldSet& sourceFieldSet,
"atlas::interpolation::method::SphericalVector::do_execute_adjoint()");
ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size());

for (auto i = 0; i < sourceFieldSet.size(); ++i) {
do_execute_adjoint(sourceFieldSet[i], targetFieldSet[i], metadata);
auto packedSourceFieldSet = util::pack_vector_fields(sourceFieldSet);
const auto packedTargetFieldSet = util::pack_vector_fields(targetFieldSet);

for (auto i = 0; i < packedSourceFieldSet.size(); ++i) {
do_execute_adjoint(packedSourceFieldSet[i], packedTargetFieldSet[i],
metadata);
}

util::unpack_vector_fields(packedSourceFieldSet, sourceFieldSet);
}

void SphericalVector::do_execute_adjoint(Field& sourceField,
Expand All @@ -168,7 +179,7 @@ void SphericalVector::do_execute_adjoint(Field& sourceField,
"atlas::interpolation::method::SphericalVector::do_execute_adjoint()");

if (targetField.size() == 0) {
return;
return;
}

const auto fieldType = sourceField.metadata().getString("type", "");
Expand All @@ -192,7 +203,6 @@ template <typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {

ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3,
"Vector field can only have 2 or 3 components.");

Expand Down
10 changes: 10 additions & 0 deletions src/atlas/util/PackVectorFields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ FieldSet pack_vector_fields(const FieldSet& fields, FieldSet packedFields) {
componentFieldMetadataVector.push_back(componentFieldMetadata);
vectorField.metadata().set("component_field_metadata",
componentFieldMetadataVector);

// If any component is dirty, the whole field is dirty.
if (vectorIndex == 0) {
vectorField.set_dirty(componentField.dirty());
} else {
vectorField.set_dirty(vectorField.dirty() || componentField.dirty());
}


}
return packedFields;
}
Expand Down Expand Up @@ -218,6 +227,7 @@ FieldSet unpack_vector_fields(const FieldSet& fields, FieldSet unpackedFields) {

// Copy metadata.
componentField.metadata() = componentFieldMetadata;
componentField.set_dirty(vectorField.dirty());

++vectorIndex;
}
Expand Down
91 changes: 91 additions & 0 deletions src/tests/interpolation/test_interpolation_spherical_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,97 @@ CASE("structured columns O96 vector interpolation (2d-field, 2-vector, hi-res)")
testInterpolation<Rank2dField>((config));
}

CASE("separate vector field components") {
const auto sourceFunctionSpace =
FunctionSpaceFixtures::get("structured_columns");
const auto targetFunctionSpace =
FunctionSpaceFixtures::get("cubedsphere_mesh");

auto sourceFieldSet = FieldSet{};
auto targetFieldSet = FieldSet{};

const auto sourceLonLatView =
array::make_view<double, 2>(sourceFunctionSpace.lonlat());
const auto targetLonLatView =
array::make_view<double, 2>(targetFunctionSpace.lonlat());

const auto createFieldView = [&](const FunctionSpace& functionSpace,
const std::string& name,
FieldSet& fieldSet) {
// Note: Vector field name can be anything that uniquely identifies field.
auto field = functionSpace.createField<double>(option::name(name));
field.metadata().set("vector_field_name", "wind");
return array::make_view<double, 1>(fieldSet.add(field));
};

auto uSourceView = createFieldView(sourceFunctionSpace, "u", sourceFieldSet);
auto vSourceView = createFieldView(sourceFunctionSpace, "v", sourceFieldSet);
const auto uTargetView =
createFieldView(targetFunctionSpace, "u", targetFieldSet);
const auto vTargetView =
createFieldView(targetFunctionSpace, "v", targetFieldSet);

uSourceView.assign(0.);
vSourceView.assign(0.);
for (auto idx = idx_t{0}; idx < sourceFunctionSpace.size(); idx++) {
std::tie(uSourceView(idx), vSourceView(idx)) =
vortexHorizontal(sourceLonLatView(idx, 0), sourceLonLatView(idx, 1));
}

const auto interpScheme =
InterpSchemeFixtures::get("structured_linear_spherical");

const auto interp =
Interpolation(interpScheme, sourceFunctionSpace, targetFunctionSpace);

interp.execute(sourceFieldSet, targetFieldSet);
targetFieldSet.haloExchange();

auto errorView =
createFieldView(targetFunctionSpace, "error", targetFieldSet);

auto maxError = 0.;
for (auto idx = idx_t{0}; idx < targetFunctionSpace.size(); idx++) {
auto [uTrue, vTrue] =
vortexHorizontal(targetLonLatView(idx, 0), targetLonLatView(idx, 1));
errorView(idx) =
std::hypot(uTrue - uTargetView(idx), vTrue - vTargetView(idx));
maxError = std::max(maxError, errorView(idx));
}
EXPECT_APPROX_EQ(maxError, 0., 0.00017);

gmshOutput("vector_components_source.msh", sourceFieldSet);
gmshOutput("vector_components_target.msh", targetFieldSet);

auto sourceAdjointFieldSet = FieldSet{};
auto targetAdjointFieldSet = FieldSet{};

targetAdjointFieldSet.add(targetFieldSet["u"].clone());
targetAdjointFieldSet.add(targetFieldSet["v"].clone());

targetAdjointFieldSet.adjointHaloExchange();

auto uSourceAdjointView =
createFieldView(sourceFunctionSpace, "u", sourceAdjointFieldSet);
auto vSourceAdjointView =
createFieldView(sourceFunctionSpace, "v", sourceAdjointFieldSet);
uSourceAdjointView.assign(0.);
vSourceAdjointView.assign(0.);

// sourceAdjointFieldSet.set_dirty(false);
interp.execute_adjoint(sourceAdjointFieldSet, targetAdjointFieldSet);

constexpr auto tinyNum = 1e-13;
const auto targetDotTarget = dotProduct(uTargetView, uTargetView) +
dotProduct(vTargetView, vTargetView);
const auto sourceDotSourceAdjoint =
dotProduct(uSourceView, uSourceAdjointView) +
dotProduct(vSourceView, vSourceAdjointView);

const auto dotProdRatio = targetDotTarget / sourceDotSourceAdjoint;
EXPECT_APPROX_EQ(dotProdRatio, 1., tinyNum);
}

} // namespace test
} // namespace atlas

Expand Down

0 comments on commit 96edef9

Please sign in to comment.