Skip to content

Commit

Permalink
FIX problem with extra dimension that makes linear interpolation fail…
Browse files Browse the repository at this point in the history
…ing (#247)

* FIX problem with extra dimension that makes linear interpolation failing
  • Loading branch information
giovastabile authored Nov 15, 2023
1 parent e57fe51 commit 677794b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ezyrb/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ def predict(self, new_point):
:return: the interpolated values.
:rtype: numpy.ndarray
"""
return self.interpolator(new_point)
return self.interpolator(new_point).squeeze()
9 changes: 8 additions & 1 deletion tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import TestCase
from ezyrb import Linear, Database, POD, ReducedOrderModel

class TestKNeighbors(TestCase):
class TestLinear(TestCase):
def test_params(self):
reg = Linear(fill_value=0)
assert reg.fill_value == 0
Expand Down Expand Up @@ -52,6 +52,13 @@ def test_with_db_predict(self):
assert rom.predict([2]) == 5
assert rom.predict([3]) == 3

Y = np.random.uniform(size=(3, 3))
db = Database(np.array([1, 2, 3]), Y)
rom = ReducedOrderModel(db, POD(), Linear())
rom.fit()
assert rom.predict([1.]).shape == (3,)


def test_wrong1(self):
# wrong number of params
with warnings.catch_warnings():
Expand Down

0 comments on commit 677794b

Please sign in to comment.