Skip to content

Commit

Permalink
fix: stash and restore prediction when calling learn during learn_ret…
Browse files Browse the repository at this point in the history
…urns_prediction==false calls (#4632)

* fix: stash and restore prediction when calling learn during learn_returns_prediction==false calls

* move dont swap

* check ci

* comment out

* go back to swap, fix igl's usage of pred in learn

* copy predictions for compat

* formatting

* undo copy reductions

* cats learn returns prediction

* update tests

* more cats tests

* update test

---------

Co-authored-by: Jack Gerrits <[email protected]>
  • Loading branch information
jackgerrits and Jack Gerrits authored Aug 31, 2023
1 parent c4c85ed commit 66d5572
Show file tree
Hide file tree
Showing 16 changed files with 310 additions and 246 deletions.
114 changes: 57 additions & 57 deletions test/pred-sets/ref/cats.predict
Original file line number Diff line number Diff line change
@@ -1,57 +1,57 @@
17118.012,0.00016043648
15724.644,0.00016043648
15331.798,0.00016043648
15138.831,0.00016043648
13393.029,0.00016043648
12710.978,0.00016043648
14503.708,0.00016043648
15936.4,0.00016043648
17257.076,0.00016043648
17208.654,0.00016043648
14921.187,0.00016043648
16660.594,0.00016043648
13486.455,0.00016043648
6641.8545,2.1031378e-06
14801.345,0.00016043648
17399.625,0.00016043648
16820.953,0.00016043648
13160.789,0.00016043648
15094.871,0.00016043648
13487.722,0.00016043648
17514.398,0.00016043648
12287.391,0.00016043648
17394.318,0.00016043648
12249.621,0.00016043648
12830.798,0.00016043648
14759.012,0.00016043648
17850.668,0.00016043648
15202.806,0.00016043648
16561.379,0.00016043648
12855.774,0.00016043648
13520.215,0.00016043648
14063.967,0.00016043648
16150.232,0.00016043648
12360.427,0.00016043648
13487.657,0.00016043648
17476.367,0.00016043648
15890.128,0.00016043648
14077.783,0.00016043648
17775.543,0.00016043648
15440.219,0.00016043648
16326.922,0.00016043648
13204.399,0.00016043648
17127.232,0.00016043648
17641.572,0.00016043648
13169.023,0.00016043648
16228.357,0.00016043648
14428.307,0.00016043648
16230.57,0.00016043648
12099.055,0.00016043648
17408.406,0.00016043648
17183.14,0.00016043648
13980.631,0.00016043648
16711.67,0.00016043648
1768.2766,2.1031378e-06
12294.268,0.00016043648
15582.021,0.00016043648
17690.637,0.00016043648
17540.906,0.00016043648
15536.194,0.00016043648
17381.867,0.00016043648
12421.002,0.00016043648
13341.685,0.00016043648
15988.09,0.00016043648
13866.5,0.00016043648
17106.008,0.00016043648
15123.611,0.00016043648
15783.712,0.00016043648
22992.766,2.1031378e-06
7194.4995,2.1031378e-06
15160.824,0.00016043648
17345.219,0.00016043648
14788.842,0.00016043648
14527.182,0.00016043648
16322.045,0.00016043648
12985.183,0.00016043648
13590.578,0.00016043648
14753.838,0.00016043648
17058.365,0.00016043648
13427.974,0.00016043648
16906.275,0.00016043648
14033.336,0.00016043648
12143.846,0.00016043648
14244.833,0.00016043648
16280.45,0.00016043648
16732.473,0.00016043648
14155.6875,0.00016043648
15759.547,0.00016043648
14163.108,0.00016043648
14782.351,0.00016043648
12062.178,0.00016043648
22556.402,2.1031378e-06
16126.391,0.00016043648
17084.584,0.00016043648
17965.973,0.00016043648
12816.858,0.00016043648
14803.017,0.00016043648
16711.629,0.00016043648
14877.602,0.00016043648
15326.174,0.00016043648
13539.921,0.00016043648
15957.607,0.00016043648
15291.404,0.00016043648
13974.891,0.00016043648
17413.562,0.00016043648
13018.842,0.00016043648
17922.018,0.00016043648
17293.945,0.00016043648
15109.656,0.00016043648
16252.369,0.00016043648
14785.5625,0.00016043648
17082.445,0.00016043648
15951.468,0.00016043648
12971.434,0.00016043648
12969.782,0.00016043648
20 changes: 10 additions & 10 deletions test/pred-sets/ref/cats_load.predict
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
2.3137953,0.40625
1.8475317,0.40625
2.1299162,0.40625
1.7430687,0.40625
14.327164,0.00625
1.7164373,0.40625
1.748255,0.40625
30.00431,0.00625
0.51225615,0.40625
2.3816023,0.40625
25.271912,0.00625
2.1290207,0.40625
4.473629,0.00625
18.205719,0.00625
2.4059067,0.40625
9.585819,0.00625
0.93930376,0.40625
1.7801242,0.40625
2.0926502,0.40625
0.8038121,0.40625
200 changes: 100 additions & 100 deletions test/pred-sets/ref/cats_room_temp.predict
Original file line number Diff line number Diff line change
@@ -1,100 +1,100 @@
84.607056,0.005
3.9181187,0.08833334
39.260834,0.005
17.99776,0.08833334
39.331844,0.005
6.7627044,0.005
77.12543,0.005
69.8972,0.005
17.184662,0.08833334
54.76903,0.005
0.66596967,0.005
13.026539,0.005
52.522453,0.08833334
86.18144,0.005
49.648746,0.08833334
53.096367,0.08833334
41.92059,0.005
88.03693,0.005
50.15336,0.08833334
1.6463974,0.005
57.090523,0.08833334
25.416866,0.005
94.471146,0.005
43.387676,0.005
59.151047,0.08833334
38.3481,0.005
51.486103,0.08833334
50.88297,0.08833334
85.89383,0.005
55.38167,0.005
51.260296,0.08833334
85.17497,0.005
50.680935,0.08833334
49.601463,0.08833334
59.61712,0.005
76.559,0.005
51.651062,0.08833334
62.96363,0.005
7.8253226,0.005
32.741436,0.005
21.816528,0.005
53.33639,0.08833334
33.895855,0.005
48.782,0.08833334
48.93803,0.08833334
48.645847,0.08833334
53.199036,0.08833334
77.73645,0.005
62.783302,0.005
18.494978,0.005
66.79538,0.005
53.29303,0.08833334
53.5938,0.08833334
47.533524,0.005
8.389192,0.005
28.030102,0.005
52.605915,0.08833334
53.525967,0.08833334
50.158722,0.08833334
47.26561,0.005
72.41367,0.005
1.1448574,0.005
53.4234,0.08833334
88.777985,0.005
95.276276,0.005
49.196915,0.08833334
48.705658,0.08833334
51.38816,0.08833334
9.958978,0.005
95.176796,0.005
56.863132,0.005
48.599724,0.08833334
42.992935,0.005
49.75944,0.08833334
21.298155,0.005
21.954062,0.005
54.1602,0.08833334
53.210964,0.08833334
52.06177,0.08833334
55.077793,0.005
50.11463,0.08833334
55.921093,0.005
67.825,0.005
74.41847,0.005
79.864365,0.005
24.734612,0.005
53.9586,0.08833334
94.53804,0.005
48.65285,0.08833334
67.55201,0.005
48.724384,0.08833334
50.91952,0.08833334
40.403355,0.005
46.697994,0.005
49.63687,0.08833334
50.199287,0.08833334
52.23696,0.08833334
48.77658,0.08833334
52.563095,0.08833334
14.676942,0.005
53.204613,0.08833334
9.196535,0.005
17.174236,0.08833334
77.29088,0.005
15.437802,0.08833334
93.60582,0.005
73.45269,0.005
19.235268,0.08833334
40.03132,0.005
18.64663,0.08833334
66.666214,0.005
53.193462,0.08833334
34.754963,0.005
34.70198,0.005
57.403484,0.005
67.80745,0.005
86.91858,0.005
52.78149,0.08833334
43.365547,0.005
55.016094,0.08833334
8.145193,0.005
57.61373,0.08833334
50.82185,0.08833334
48.282825,0.005
81.51001,0.005
26.397629,0.005
23.397736,0.005
50.563953,0.08833334
53.253433,0.08833334
53.1661,0.08833334
50.818462,0.08833334
72.26961,0.005
73.63446,0.005
70.71592,0.005
23.83146,0.005
51.322205,0.08833334
26.74537,0.005
54.481422,0.08833334
52.270885,0.08833334
48.716354,0.08833334
9.419978,0.005
2.7159307,0.005
51.904087,0.08833334
51.104538,0.08833334
0.30527782,0.005
76.84366,0.005
20.968102,0.005
58.275642,0.005
6.868239,0.005
40.830856,0.005
54.12306,0.08833334
51.63766,0.08833334
95.59329,0.005
48.718655,0.08833334
53.290802,0.08833334
80.52641,0.005
50.78907,0.08833334
12.805883,0.005
89.139755,0.005
76.673355,0.005
55.5548,0.005
11.593977,0.005
51.98833,0.08833334
30.242346,0.005
66.171425,0.005
51.027653,0.08833334
53.797455,0.08833334
91.31642,0.005
75.98573,0.005
51.8337,0.08833334
95.297424,0.005
54.301224,0.08833334
49.036102,0.08833334
31.043774,0.005
56.56052,0.005
48.77767,0.08833334
53.00767,0.08833334
49.76527,0.08833334
28.06632,0.005
53.554474,0.08833334
54.24915,0.08833334
15.149416,0.005
48.76757,0.08833334
48.71854,0.08833334
52.769234,0.08833334
79.30612,0.005
2.7464993,0.005
53.082054,0.08833334
58.80374,0.005
49.790363,0.08833334
56.0045,0.005
88.90586,0.005
52.904503,0.08833334
52.2649,0.08833334
11.048922,0.005
53.87539,0.08833334
49.905197,0.08833334
42.230206,0.005
42.714077,0.005
68.1033,0.005
10 changes: 5 additions & 5 deletions test/pred-sets/ref/cats_save.predict
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
0.0012017498,0.00625
0.6113522,0.40625
0.5820448,0.40625
2.2343192,0.40625
20.461426,0.00625
1.2286053,0.40625
10.350056,0.00625
22.592487,0.00625
2.2877245,0.40625
1.3163464,0.40625
2.1290207,0.40625
18.205719,0.00625
9.585819,0.00625
1.7801242,0.40625
0.8038121,0.40625
12 changes: 6 additions & 6 deletions test/train-sets/ref/0001-replay.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ loss last counter weight label predict feat
0.250000 0.250000 8 8.0 0.0000 0.0000 860
0.312500 0.375000 16 16.0 1.0000 0.0000 128
0.343750 0.375000 32 32.0 0.0000 0.0000 176
0.350946 0.358141 64 64.0 0.0000 0.0275 350
0.371181 0.391417 128 128.0 1.0000 0.1510 620
0.292717 0.214252 256 256.0 0.0000 0.3660 410
0.149202 0.005688 512 512.0 0.0000 0.0000 278
0.074603 0.000004 1024 1024.0 1.0000 1.0000 170
0.359375 0.375000 64 64.0 0.0000 0.0275 350
0.414062 0.468750 128 128.0 1.0000 0.1510 620
0.433594 0.453125 256 256.0 0.0000 0.3660 410
0.441406 0.449219 512 512.0 0.0000 0.0000 278
0.453125 0.464844 1024 1024.0 1.0000 1.0000 170

finished run
number of examples per pass = 200
passes used = 8
weighted example sum = 1600.000000
weighted label sum = 728.000000
average loss = 0.047746
average loss = 0.455000
best constant = 0.455000
best constant's loss = 0.247975
total feature number = 717536
Loading

0 comments on commit 66d5572

Please sign in to comment.