Skip to content

Commit

Permalink
Lfilter (#13)
Browse files Browse the repository at this point in the history
* added lfilter

* Ran pre-commit

* Fixed env file

* Improved implementation

* Ran pre-commit

* Updated to reflect PR comments

* Use size_type

* Try to fix tests

* Moved one pad out

* Removed extra pad call

* Run pre-commit
  • Loading branch information
spectre-ns authored Nov 19, 2023
1 parent 3b1df5f commit fb8a672
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 3 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ repos:
- id: detect-private-key
- id: check-merge-conflict
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.4.2
rev: v1.5.4
hooks:
- id: forbid-tabs
- id: remove-tabs
args: [--whitespaces-count, '4']
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.7.0
rev: v2.10.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
Expand All @@ -41,7 +41,7 @@ repos:
files: environment.yaml
# Externally provided executables (so we can use them with editors as well).
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v15.0.7
rev: v16.0.6
hooks:
- id: clang-format
files: .*\.[hc]pp$
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ endif()
set(XTENSOR_SIGNAL_HEADERS
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/xtensor_signal.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/find_peaks.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/lfilter.hpp
)

add_library(xtensor-signal INTERFACE)
Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ dependencies:
- openblas
- doctest
- zlib
- pre-commit
95 changes: 95 additions & 0 deletions include/xtensor-signal/lfilter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#ifndef XTENSOR_SIGNAL_LFILTER_HPP
#define XTENSOR_SIGNAL_LFILTER_HPP

#include <type_traits>
#include <xtensor-blas/xlinalg.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xaxis_slice_iterator.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xexception.hpp>
#include <xtensor/xindex_view.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xpad.hpp>
#include <xtensor/xtensor.hpp>

namespace xt {
namespace signal {
namespace detail {
template <typename E1, typename E2, typename E3, typename E4>
inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, E4 zi) {
using value_type = typename std::decay_t<E3>::value_type;
using size_type = typename std::decay_t<E3>::size_type;
if (zi.shape(0) != x.shape(0)) {
XTENSOR_THROW(
std::runtime_error,
"Accumulator initialization must be the same length as the input");
}
if (x.dimension() != 1) {
XTENSOR_THROW(std::runtime_error,
"Implementation only works on 1D arguments");
}
if (a.dimension() != 1) {
XTENSOR_THROW(std::runtime_error,
"Implementation only works on 1D arguments");
}
if (b.dimension() != 1) {
XTENSOR_THROW(std::runtime_error,
"Implementation only works on 1D arguments");
}
xt::xtensor<value_type, 1> out =
xt::zeros<value_type>({x.shape(0) + 2 * (a.shape(0) - 1)});
auto padded_x = xt::pad(x, b.shape(0) - 1);
for (size_type i = 0; i < x.shape(0); i++) {
auto b_accum =
xt::sum(b *
xt::flip(xt::view(padded_x, xt::range(i, i + b.shape(0))))) +
zi(i);

auto a_accum =
b_accum -
xt::sum(xt::view(a, xt::range(1, xt::placeholders::_)) *
xt::flip(xt::view(out, xt::range(i, i + a.shape(0) - 1))));
auto result = a_accum / a(0);
out(i + a.shape(0) - 1) = result();
}
out = xt::view(out, xt::range(a.shape(0) - 1, -(a.shape(0) - 1)));
return out;
}
} // namespace detail

/*
* @brief performs a 1D filter operation along the specified axis. Performs
* operations immediately.
* @param b the numerator of the filter expression
* @param a the denominator of the filter expression
* @param x input dataset
* @param axis the axis along which to perform the filter operation
* @param zi initial condition of the filter accumulator
* @return filtered version of x
* @todo Add implementation bound to MKL or HPC library for IIR and FIR
*/
template <typename E1, typename E2, typename E3,
typename E4 = decltype(xt::xnone())>
inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, std::ptrdiff_t axis = -1,
E4 zi = xt::xnone()) {
using value_type = typename std::decay_t<E3>::value_type;
xt::xarray<value_type> out(x);
auto saxis = xt::normalize_axis(out.dimension(), axis);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);

for (auto iter = begin; iter != end; iter++) {
if constexpr (std::is_same<typename std::decay<E4>::type,
decltype(xt::xnone())>::value == false) {
xt::noalias(*iter) = detail::lfilter(b, a, *iter, zi);
} else {
xt::noalias(*iter) = detail::lfilter(
b, a, *iter, xt::zeros<value_type>({(*iter).shape(0)}));
}
}
return out;
}
} // namespace signal
} // namespace xt

#endif
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(XTENSOR_SIGNAL_TESTS
main.cpp
test_config.cpp
find_peaks_test.cpp
lfilter_test.cpp
)

if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
Expand Down
60 changes: 60 additions & 0 deletions test/lfilter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "doctest/doctest.h"
#include "xtensor-signal/lfilter.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xrandom.hpp"
#include "xtensor/xsort.hpp"
#include "xtensor/xview.hpp"

TEST_SUITE("lfilter") {

TEST_CASE("3rdOrderButterworth") {
// credit
// https://rosettacode.org/wiki/Apply_a_digital_filter_(direct_form_II_transposed)#C++
// define the signal
xt::xtensor<float, 1> sig = {
-0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412,
-0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044,
0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195,
0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293,
0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589};

xt::xtensor<float, 1> expectation = {
-0.152974, -0.435258, -0.136043, 0.697503, 0.656445,
-0.435483, -1.08924, -0.537677, 0.51705, 1.05225,
0.961854, 0.69569, 0.424356, 0.196262, -0.0278351,
-0.211722, -0.174746, 0.0692584, 0.385446, 0.651771};

// Constants for a Butterworth filter (order 3, low pass)
xt::xtensor<float, 1> a = {1.00000000, -2.77555756e-16, 3.33333333e-01,
-1.85037171e-17};
xt::xtensor<float, 1> b = {0.16666667, 0.5, 0.5, 0.16666667};

auto res = xt::signal::lfilter(b, a, sig);
REQUIRE(xt::all(xt::isclose(res, expectation)));
}
TEST_CASE("3rdOrderButterworth_MultipleDims") {
xt::xtensor<float, 1> sig = {
-0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412,
-0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044,
0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195,
0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293,
0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589};

xt::xtensor<float, 1> expectation = {
-0.152974, -0.435258, -0.136043, 0.697503, 0.656445,
-0.435483, -1.08924, -0.537677, 0.51705, 1.05225,
0.961854, 0.69569, 0.424356, 0.196262, -0.0278351,
-0.211722, -0.174746, 0.0692584, 0.385446, 0.651771};

xt::xarray<float> sig2 =
xt::stack(std::make_tuple(sig, xt::zeros_like(sig)), 1);

// Constants for a Butterworth filter (order 3, low pass)
xt::xtensor<float, 1> a = {1.00000000, -2.77555756e-16, 3.33333333e-01,
-1.85037171e-17};
xt::xtensor<float, 1> b = {0.16666667, 0.5, 0.5, 0.16666667};
auto res = xt::signal::lfilter(b, a, sig2, 0);
REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 0), expectation)));
REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 1), 0)));
}
}

0 comments on commit fb8a672

Please sign in to comment.