-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added lfilter * Ran pre-commit * Fixed env file * Improved implementation * Ran pre-commit * Updated to reflect PR comments * Use size_type * Try to fix tests * Moved one pad out * Removed extra pad call * Run pre-commit
- Loading branch information
1 parent
3b1df5f
commit fb8a672
Showing
6 changed files
with
161 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ dependencies: | |
- openblas | ||
- doctest | ||
- zlib | ||
- pre-commit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#ifndef XTENSOR_SIGNAL_LFILTER_HPP | ||
#define XTENSOR_SIGNAL_LFILTER_HPP | ||
|
||
#include <type_traits> | ||
#include <xtensor-blas/xlinalg.hpp> | ||
#include <xtensor/xarray.hpp> | ||
#include <xtensor/xaxis_slice_iterator.hpp> | ||
#include <xtensor/xbuilder.hpp> | ||
#include <xtensor/xexception.hpp> | ||
#include <xtensor/xindex_view.hpp> | ||
#include <xtensor/xnoalias.hpp> | ||
#include <xtensor/xpad.hpp> | ||
#include <xtensor/xtensor.hpp> | ||
|
||
namespace xt { | ||
namespace signal { | ||
namespace detail { | ||
template <typename E1, typename E2, typename E3, typename E4> | ||
inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, E4 zi) { | ||
using value_type = typename std::decay_t<E3>::value_type; | ||
using size_type = typename std::decay_t<E3>::size_type; | ||
if (zi.shape(0) != x.shape(0)) { | ||
XTENSOR_THROW( | ||
std::runtime_error, | ||
"Accumulator initialization must be the same length as the input"); | ||
} | ||
if (x.dimension() != 1) { | ||
XTENSOR_THROW(std::runtime_error, | ||
"Implementation only works on 1D arguments"); | ||
} | ||
if (a.dimension() != 1) { | ||
XTENSOR_THROW(std::runtime_error, | ||
"Implementation only works on 1D arguments"); | ||
} | ||
if (b.dimension() != 1) { | ||
XTENSOR_THROW(std::runtime_error, | ||
"Implementation only works on 1D arguments"); | ||
} | ||
xt::xtensor<value_type, 1> out = | ||
xt::zeros<value_type>({x.shape(0) + 2 * (a.shape(0) - 1)}); | ||
auto padded_x = xt::pad(x, b.shape(0) - 1); | ||
for (size_type i = 0; i < x.shape(0); i++) { | ||
auto b_accum = | ||
xt::sum(b * | ||
xt::flip(xt::view(padded_x, xt::range(i, i + b.shape(0))))) + | ||
zi(i); | ||
|
||
auto a_accum = | ||
b_accum - | ||
xt::sum(xt::view(a, xt::range(1, xt::placeholders::_)) * | ||
xt::flip(xt::view(out, xt::range(i, i + a.shape(0) - 1)))); | ||
auto result = a_accum / a(0); | ||
out(i + a.shape(0) - 1) = result(); | ||
} | ||
out = xt::view(out, xt::range(a.shape(0) - 1, -(a.shape(0) - 1))); | ||
return out; | ||
} | ||
} // namespace detail | ||
|
||
/* | ||
* @brief performs a 1D filter operation along the specified axis. Performs | ||
* operations immediately. | ||
* @param b the numerator of the filter expression | ||
* @param a the denominator of the filter expression | ||
* @param x input dataset | ||
* @param axis the axis along which to perform the filter operation | ||
* @param zi initial condition of the filter accumulator | ||
* @return filtered version of x | ||
* @todo Add implementation bound to MKL or HPC library for IIR and FIR | ||
*/ | ||
template <typename E1, typename E2, typename E3, | ||
typename E4 = decltype(xt::xnone())> | ||
inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, std::ptrdiff_t axis = -1, | ||
E4 zi = xt::xnone()) { | ||
using value_type = typename std::decay_t<E3>::value_type; | ||
xt::xarray<value_type> out(x); | ||
auto saxis = xt::normalize_axis(out.dimension(), axis); | ||
auto begin = xt::axis_slice_begin(out, saxis); | ||
auto end = xt::axis_slice_end(out, saxis); | ||
|
||
for (auto iter = begin; iter != end; iter++) { | ||
if constexpr (std::is_same<typename std::decay<E4>::type, | ||
decltype(xt::xnone())>::value == false) { | ||
xt::noalias(*iter) = detail::lfilter(b, a, *iter, zi); | ||
} else { | ||
xt::noalias(*iter) = detail::lfilter( | ||
b, a, *iter, xt::zeros<value_type>({(*iter).shape(0)})); | ||
} | ||
} | ||
return out; | ||
} | ||
} // namespace signal | ||
} // namespace xt | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#include "doctest/doctest.h" | ||
#include "xtensor-signal/lfilter.hpp" | ||
#include "xtensor/xio.hpp" | ||
#include "xtensor/xrandom.hpp" | ||
#include "xtensor/xsort.hpp" | ||
#include "xtensor/xview.hpp" | ||
|
||
TEST_SUITE("lfilter") { | ||
|
||
TEST_CASE("3rdOrderButterworth") { | ||
// credit | ||
// https://rosettacode.org/wiki/Apply_a_digital_filter_(direct_form_II_transposed)#C++ | ||
// define the signal | ||
xt::xtensor<float, 1> sig = { | ||
-0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412, | ||
-0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044, | ||
0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195, | ||
0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293, | ||
0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589}; | ||
|
||
xt::xtensor<float, 1> expectation = { | ||
-0.152974, -0.435258, -0.136043, 0.697503, 0.656445, | ||
-0.435483, -1.08924, -0.537677, 0.51705, 1.05225, | ||
0.961854, 0.69569, 0.424356, 0.196262, -0.0278351, | ||
-0.211722, -0.174746, 0.0692584, 0.385446, 0.651771}; | ||
|
||
// Constants for a Butterworth filter (order 3, low pass) | ||
xt::xtensor<float, 1> a = {1.00000000, -2.77555756e-16, 3.33333333e-01, | ||
-1.85037171e-17}; | ||
xt::xtensor<float, 1> b = {0.16666667, 0.5, 0.5, 0.16666667}; | ||
|
||
auto res = xt::signal::lfilter(b, a, sig); | ||
REQUIRE(xt::all(xt::isclose(res, expectation))); | ||
} | ||
TEST_CASE("3rdOrderButterworth_MultipleDims") { | ||
xt::xtensor<float, 1> sig = { | ||
-0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412, | ||
-0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044, | ||
0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195, | ||
0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293, | ||
0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589}; | ||
|
||
xt::xtensor<float, 1> expectation = { | ||
-0.152974, -0.435258, -0.136043, 0.697503, 0.656445, | ||
-0.435483, -1.08924, -0.537677, 0.51705, 1.05225, | ||
0.961854, 0.69569, 0.424356, 0.196262, -0.0278351, | ||
-0.211722, -0.174746, 0.0692584, 0.385446, 0.651771}; | ||
|
||
xt::xarray<float> sig2 = | ||
xt::stack(std::make_tuple(sig, xt::zeros_like(sig)), 1); | ||
|
||
// Constants for a Butterworth filter (order 3, low pass) | ||
xt::xtensor<float, 1> a = {1.00000000, -2.77555756e-16, 3.33333333e-01, | ||
-1.85037171e-17}; | ||
xt::xtensor<float, 1> b = {0.16666667, 0.5, 0.5, 0.16666667}; | ||
auto res = xt::signal::lfilter(b, a, sig2, 0); | ||
REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 0), expectation))); | ||
REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 1), 0))); | ||
} | ||
} |