From 6347f8d4dcb39f7f35af612cd019d33d0e2eff80 Mon Sep 17 00:00:00 2001 From: harshith-gowrachari Date: Thu, 3 Oct 2024 14:54:52 +0200 Subject: [PATCH] fix test nnshift --- tests/test_nnshift.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_nnshift.py b/tests/test_nnshift.py index b3eead6..1e28abb 100644 --- a/tests/test_nnshift.py +++ b/tests/test_nnshift.py @@ -27,11 +27,11 @@ def test_constructor(): def test_fit_train(): - seed = 147 + seed = 1 torch.manual_seed(seed) np.random.seed(seed) - interp = ANN([10, 10], torch.nn.Softplus(), 1000, frequency_print=200, lr=0.03) - shift = ANN([], torch.nn.LeakyReLU(), [2500, 1e-3], frequency_print=200, l2_regularization=0, lr=0.0005) + interp = ANN([10, 10], torch.nn.Softplus(), 10000, frequency_print=200, lr=0.03) + shift = ANN([], torch.nn.LeakyReLU(), [1e-4, 5000], frequency_print=200, l2_regularization=0, lr=0.0023) nnspod = AutomaticShiftSnapshots(shift, interp, Linear(fill_value=0.0), barycenter_loss=10.) pod = POD(rank=1) rbf = RBF()