Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fully templated - Explicit instantiations are made for Matrix and Vector (defined using double) - Consider defining a MatrixF, VectorF, ArrayF and ArrayD to support single-precision #164

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@ endif()
# Tests
option(WITH_TESTS "Build test suite" ON)
if(WITH_TESTS)
enable_testing()
set(GOOGLETEST_VERSION 1.10.0)
add_subdirectory("${PROJECT_SOURCE_DIR}/vendor/googletest-release-1.10.0/googletest")
add_subdirectory(tests)
find_package(GTest REQUIRED)
enable_testing()
add_subdirectory(tests)
endif()


Expand Down
2 changes: 1 addition & 1 deletion cmake/cpd_test.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ function(cpd_test name)
add_executable(${target} ${src})
set_target_properties(${target} PROPERTIES OUTPUT_NAME ${name})
add_test(NAME ${name} COMMAND ${target})
target_link_libraries(${target} PRIVATE Library-C++ ${ARGN} gtest_main)
target_link_libraries(${target} PRIVATE Library-C++ ${ARGN} ${GTEST_LIBRARIES} GTest::gtest_main)
target_include_directories(${target} PRIVATE "${PROJECT_BINARY_DIR}")
endfunction()
2 changes: 1 addition & 1 deletion components/jsoncpp/include/cpd/jsoncpp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ Json::Value to_json(const RigidResult& result);
Json::Value to_json(const AffineResult& result);
Json::Value to_json(const NonrigidResult& result);
Json::Value to_json(const Matrix& matrix);
}
} // namespace cpd
2 changes: 1 addition & 1 deletion components/jsoncpp/src/jsoncpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ Json::Value to_json(const Matrix& matrix) {
std::ostream& operator<<(std::ostream& ostream, const Result& result) {
return ostream;
}
}
} // namespace cpd
2 changes: 1 addition & 1 deletion components/jsoncpp/tests/jsoncpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ TEST(RigidResult, ConvertsToJson) {
RigidResult result;
Json::Value json = cpd::to_json(result);
}
}
} // namespace cpd
14 changes: 6 additions & 8 deletions examples/callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
#include <cpd/nonrigid.hpp>
#include <cpd/rigid.hpp>

void RigidCallback(const cpd::Result &r) {
std::cout << r.points << std::endl << std::endl;
void RigidCallback(const cpd::Result& r) {
std::cout << r.points << std::endl << std::endl;
}

void NonrigidCallback(const cpd::NonrigidResult &r) {
std::cout << r.points << std::endl << std::endl;
void NonrigidCallback(const cpd::NonrigidResult& r) {
std::cout << r.points << std::endl << std::endl;
}

int main(int argc, char** argv) {
Expand All @@ -26,12 +26,12 @@ int main(int argc, char** argv) {

if (method == "rigid") {
cpd::Rigid rigid;
auto *cb = RigidCallback;
auto* cb = RigidCallback;
rigid.add_callback(cb);
auto rigid_result = rigid.run(fixed, moving);
} else if (method == "nonrigid") {
cpd::Nonrigid nonrigid;
auto *cb = NonrigidCallback;
auto* cb = NonrigidCallback;
nonrigid.add_callback(cb);
auto nonrigid_result = nonrigid.run(fixed, moving);
} else {
Expand All @@ -41,5 +41,3 @@ int main(int argc, char** argv) {
std::cout << "Registration completed OK" << std::endl;
return 0;
}


26 changes: 15 additions & 11 deletions include/cpd/affine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,34 @@
namespace cpd {

/// The result of a affine coherent point drift run.
struct AffineResult : public Result {
template <typename M, typename V>
class AffineResult : public Result<M, V> {
public:
/// The affine transformation.
Matrix transform;
M transform;

/// The translation vector.
Vector translation;
V translation;

/// Returns the transform and the translation as one matrix.
Matrix matrix() const;
M matrix() const;

/// Denormalize this result.
void denormalize(const Normalization& normalization);
void denormalize(const Normalization<M, V>& normalization);
};

/// Affine coherent point drift.
class Affine : public Transform<AffineResult> {
template <typename M, typename V>
class Affine : public Transform<M, V, AffineResult> {
public:
Affine()
: Transform()
: Transform<M, V, AffineResult>()
, m_linked(DEFAULT_LINKED) {}

/// Computes one iteration of the affine transformation.
AffineResult compute_one(const Matrix& fixed, const Matrix& moving,
const Probabilities& probabilities,
double sigma2) const;
AffineResult<M, V> compute_one(const M& fixed, const M& moving,
const Probabilities<M, V>& probabilities,
typename M::Scalar sigma2) const;

/// Sets whether the scalings of the two datasets are linked.
Affine& linked(bool linked) {
Expand All @@ -65,5 +68,6 @@ class Affine : public Transform<AffineResult> {
};

/// Runs a affine registration on two matrices.
AffineResult affine(const Matrix& fixed, const Matrix& moving);
template <typename M, typename V>
AffineResult<M, V> affine(const M& fixed, const M& moving);
} // namespace cpd
25 changes: 15 additions & 10 deletions include/cpd/gauss_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,41 @@ namespace cpd {

/// Probability matrices produced by comparing two data sets with a
/// `GaussTransform`.
template <typename M, typename V>
struct Probabilities {
/// The probability matrix, multiplied by the identity matrix.
Vector p1;
V p1;
/// The probability matrix, transposes, multiplied by the identity matrix.
Vector pt1;
V pt1;
/// The probability matrix multiplied by the fixed points.
Matrix px;
M px;
/// The total error.
double l;
typename M::Scalar l;
/// The correspondence vector between the two datasets.
IndexVector correspondence;
};

/// Abstract base class for Gauss transforms.
template <typename M, typename V>
class GaussTransform {
public:
/// Returns the default Gauss transform as a unique ptr.
static std::unique_ptr<GaussTransform> make_default();
static std::unique_ptr<GaussTransform<M, V>> make_default();

/// Computes the Gauss transform.
virtual Probabilities compute(const Matrix& fixed, const Matrix& moving,
double sigma2, double outliers) const = 0;
virtual Probabilities<M, V> compute(const M& fixed, const M& moving,
typename M::Scalar sigma2,
typename M::Scalar outliers) const = 0;

virtual ~GaussTransform() {}
};

/// The direct Gauss transform.
class GaussTransformDirect : public GaussTransform {
template <typename M, typename V>
class GaussTransformDirect : public GaussTransform<M, V> {
public:
Probabilities compute(const Matrix& fixed, const Matrix& moving,
double sigma2, double outliers) const;
Probabilities<M, V> compute(const M& fixed, const M& moving,
typename M::Scalar sigma2,
typename M::Scalar outliers) const;
};
} // namespace cpd
24 changes: 13 additions & 11 deletions include/cpd/gauss_transform_fgt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ enum FgtMethod {
/// The default fgt method
const FgtMethod DEFAULT_FGT_METHOD = FgtMethod::DirectTree;
/// The default switched fgt breakpoint.
const double DEFAULT_BREAKPOINT = 0.2;
const typename M::Scalar DEFAULT_BREAKPOINT = 0.2;
/// The default fgt epsilon.
const double DEFAULT_EPSILON = 1e-4;
const typename M::Scalar DEFAULT_EPSILON = 1e-4;

/// The Gauss transform using the fgt library.
class GaussTransformFgt : public GaussTransform {
template <typename M, typename V>
class GaussTransformFgt : public GaussTransform<M, V> {
public:
GaussTransformFgt()
: GaussTransform()
Expand All @@ -53,13 +54,13 @@ class GaussTransformFgt : public GaussTransform {
, m_method(DEFAULT_FGT_METHOD) {}

/// Sets the ifgt->direct-tree breakpoint.
GaussTransformFgt& breakpoint(double breakpoint) {
GaussTransformFgt& breakpoint(typename M::Scalar breakpoint) {
m_breakpoint = breakpoint;
return *this;
}

/// Sets the epsilon.
GaussTransformFgt& epsilon(double epsilon) {
GaussTransformFgt& epsilon(typename M::Scalar epsilon) {
m_epsilon = epsilon;
return *this;
}
Expand All @@ -70,15 +71,16 @@ class GaussTransformFgt : public GaussTransform {
return *this;
}

Probabilities compute(const Matrix& fixed, const Matrix& moving,
double sigma2, double outliers) const;
Probabilities compute(const M& fixed, const M& moving,
typename M::Scalar sigma2,
typename M::Scalar outliers) const;

private:
std::unique_ptr<fgt::Transform> create_transform(const Matrix& points,
double bandwidth) const;
std::unique_ptr<fgt::Transform> create_transform(
const M& points, typename M::Scalar bandwidth) const;

double m_breakpoint;
double m_epsilon;
typename M::Scalar m_breakpoint;
typename M::Scalar m_epsilon;
FgtMethod m_method;
};
} // namespace cpd
14 changes: 9 additions & 5 deletions include/cpd/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,22 @@ namespace cpd {

/// Our base matrix class.
typedef Eigen::MatrixXd Matrix;
typedef Eigen::MatrixXf MatrixF;

/// Typedef for our specific type of vector.
typedef Eigen::VectorXd Vector;
typedef Eigen::VectorXf VectorF;

/// Typedef for our specific type of array. (TODO: Support this)
typedef Eigen::ArrayXd Array;
typedef Eigen::ArrayXf ArrayF;

/// Typedef for an index vector, used to index other matrices.
typedef Eigen::Matrix<Matrix::Index, Eigen::Dynamic, 1> IndexVector;

/// Typedef for our specific type of array.
typedef Eigen::ArrayXd Array;

/// Apply a transformation matrix to a set of points.
///
/// The transformation matrix should be one column wider than the point matrix.
Matrix apply_transformation_matrix(Matrix points, const Matrix& transform);
/// The transformation matrix should be one column wider than the point matrix
template <typename M, typename V>
M apply_transformation_matrix(M points, const M& transform);
} // namespace cpd
35 changes: 19 additions & 16 deletions include/cpd/nonrigid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,61 +26,64 @@
namespace cpd {

/// Default value for beta.
const double DEFAULT_BETA = 3.0;
const Matrix::Scalar DEFAULT_BETA = 3.0;
/// Default value for lambda.
const double DEFAULT_LAMBDA = 3.0;
const Matrix::Scalar DEFAULT_LAMBDA = 3.0;

/// The result of a nonrigid coherent point drift run.
struct NonrigidResult : public Result {};
template <typename M, typename V>
class NonrigidResult : public Result<M, V> {};

/// Nonrigid coherent point drift.
class Nonrigid : public Transform<NonrigidResult> {
template <typename M, typename V>
class Nonrigid : public Transform<M, V, NonrigidResult> {
public:
Nonrigid()
: Transform()
: Transform<M, V, NonrigidResult>()
, m_lambda(DEFAULT_LAMBDA)
, m_beta(DEFAULT_BETA)
, m_linked(DEFAULT_LINKED) {}

/// Initialize this transform for the provided matrices.
void init(const Matrix& fixed, const Matrix& moving);
void init(const M& fixed, const M& moving);

/// Modifies the probabilities with some affinity and weight information.
void modify_probabilities(Probabilities& probabilities) const;
void modify_probabilities(Probabilities<M, V>& probabilities) const;

/// Sets the beta.
Nonrigid& beta(double beta) {
Nonrigid<M, V>& beta(double beta) {
m_beta = beta;
return *this;
}

/// Sets the lambda.
Nonrigid& lambda(double lambda) {
Nonrigid<M, V>& lambda(double lambda) {
m_lambda = lambda;
return *this;
}

/// Computes one iteration of the nonrigid transformation.
NonrigidResult compute_one(const Matrix& fixed, const Matrix& moving,
const Probabilities& probabilities,
double sigma2) const;
NonrigidResult<M, V> compute_one(const M& fixed, const M& moving,
const Probabilities<M, V>& probabilities,
typename M::Scalar sigma2) const;

/// Sets whether the scalings of the two datasets are linked.
Nonrigid& linked(bool linked) {
Nonrigid<M, V>& linked(bool linked) {
m_linked = linked;
return *this;
}

virtual bool linked() const { return m_linked; }

private:
Matrix m_g;
Matrix m_w;
M m_g;
M m_w;
double m_lambda;
double m_beta;
bool m_linked;
};

/// Runs a nonrigid registration on two matrices.
NonrigidResult nonrigid(const Matrix& fixed, const Matrix& moving);
template <typename M, typename V>
NonrigidResult<M, V> nonrigid(const M& fixed, const M& moving);
} // namespace cpd
16 changes: 8 additions & 8 deletions include/cpd/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,20 @@
namespace cpd {

/// The results of normalizing data to a unit cube (or whatever dimensionality).
template <typename M, typename V>
struct Normalization {
/// The average of the fixed points, that was subtracted from those data.
Vector fixed_mean;
V fixed_mean;
/// The fixed points.
Matrix fixed;
M fixed;
/// The scaling factor for the fixed points.
double fixed_scale;
typename M::Scalar fixed_scale;
/// The average of the moving points, that was subtracted from those data.
Vector moving_mean;
V moving_mean;
/// The moving points.
Matrix moving;
M moving;
/// The scaling factor for the moving points.
double moving_scale;
typename M::Scalar moving_scale;

/// Creates a new normalization for the provided matrices.
///
Expand All @@ -49,7 +50,6 @@ struct Normalization {
/// seperately.
///
/// Myronenko's original implementation only had `linked = false` logic.
Normalization(const Matrix& fixed, const Matrix& moving,
bool linked = true);
Normalization(const M& fixed, const M& moving, bool linked = true);
};
} // namespace cpd
Loading
Loading