From 66d5572996483d357ba2b4c5ce0ec17bcca07982 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 31 Aug 2023 10:50:52 -0400 Subject: [PATCH] fix: stash and restore prediction when calling learn during learn_returns_prediction==false calls (#4632) * fix: stash and restore prediction when calling learn during learn_returns_prediction==false calls * move dont swap * check ci * comment out * go back to swap, fix igl's usage of pred in learn * copy predictions for compat * formatting * undo copy reductions * cats learn returns prediction * update tests * more cats tests * update test --------- Co-authored-by: Jack Gerrits --- test/pred-sets/ref/cats.predict | 114 +++++----- test/pred-sets/ref/cats_load.predict | 20 +- test/pred-sets/ref/cats_room_temp.predict | 200 +++++++++--------- test/pred-sets/ref/cats_save.predict | 10 +- test/train-sets/ref/0001-replay.stderr | 12 +- test/train-sets/ref/cats-predict.stderr | 14 +- test/train-sets/ref/cats-train.stderr | 14 +- test/train-sets/ref/cats_load.stderr | 8 +- test/train-sets/ref/cats_room_temp.stderr | 16 +- .../train-sets/ref/cats_room_temp_pred.stderr | 16 +- test/train-sets/ref/cats_save.stderr | 8 +- vowpalwabbit/core/include/vw/core/example.h | 2 + vowpalwabbit/core/src/example.cc | 46 ++++ vowpalwabbit/core/src/global_data.cc | 7 + vowpalwabbit/core/src/reductions/cats.cc | 3 + .../core/src/reductions/interaction_ground.cc | 66 +++--- 16 files changed, 310 insertions(+), 246 deletions(-) diff --git a/test/pred-sets/ref/cats.predict b/test/pred-sets/ref/cats.predict index df1a5c18b8f..fc99858dc31 100644 --- a/test/pred-sets/ref/cats.predict +++ b/test/pred-sets/ref/cats.predict @@ -1,57 +1,57 @@ -17118.012,0.00016043648 -15724.644,0.00016043648 -15331.798,0.00016043648 -15138.831,0.00016043648 -13393.029,0.00016043648 -12710.978,0.00016043648 -14503.708,0.00016043648 -15936.4,0.00016043648 -17257.076,0.00016043648 -17208.654,0.00016043648 -14921.187,0.00016043648 -16660.594,0.00016043648 -13486.455,0.00016043648 -6641.8545,2.1031378e-06 -14801.345,0.00016043648 -17399.625,0.00016043648 -16820.953,0.00016043648 -13160.789,0.00016043648 -15094.871,0.00016043648 -13487.722,0.00016043648 -17514.398,0.00016043648 -12287.391,0.00016043648 -17394.318,0.00016043648 -12249.621,0.00016043648 -12830.798,0.00016043648 -14759.012,0.00016043648 -17850.668,0.00016043648 -15202.806,0.00016043648 -16561.379,0.00016043648 -12855.774,0.00016043648 -13520.215,0.00016043648 -14063.967,0.00016043648 -16150.232,0.00016043648 -12360.427,0.00016043648 -13487.657,0.00016043648 -17476.367,0.00016043648 -15890.128,0.00016043648 -14077.783,0.00016043648 -17775.543,0.00016043648 -15440.219,0.00016043648 -16326.922,0.00016043648 -13204.399,0.00016043648 -17127.232,0.00016043648 -17641.572,0.00016043648 -13169.023,0.00016043648 -16228.357,0.00016043648 -14428.307,0.00016043648 -16230.57,0.00016043648 -12099.055,0.00016043648 -17408.406,0.00016043648 -17183.14,0.00016043648 -13980.631,0.00016043648 -16711.67,0.00016043648 -1768.2766,2.1031378e-06 -12294.268,0.00016043648 -15582.021,0.00016043648 -17690.637,0.00016043648 +17540.906,0.00016043648 +15536.194,0.00016043648 +17381.867,0.00016043648 +12421.002,0.00016043648 +13341.685,0.00016043648 +15988.09,0.00016043648 +13866.5,0.00016043648 +17106.008,0.00016043648 +15123.611,0.00016043648 +15783.712,0.00016043648 +22992.766,2.1031378e-06 +7194.4995,2.1031378e-06 +15160.824,0.00016043648 +17345.219,0.00016043648 +14788.842,0.00016043648 +14527.182,0.00016043648 +16322.045,0.00016043648 +12985.183,0.00016043648 
+13590.578,0.00016043648 +14753.838,0.00016043648 +17058.365,0.00016043648 +13427.974,0.00016043648 +16906.275,0.00016043648 +14033.336,0.00016043648 +12143.846,0.00016043648 +14244.833,0.00016043648 +16280.45,0.00016043648 +16732.473,0.00016043648 +14155.6875,0.00016043648 +15759.547,0.00016043648 +14163.108,0.00016043648 +14782.351,0.00016043648 +12062.178,0.00016043648 +22556.402,2.1031378e-06 +16126.391,0.00016043648 +17084.584,0.00016043648 +17965.973,0.00016043648 +12816.858,0.00016043648 +14803.017,0.00016043648 +16711.629,0.00016043648 +14877.602,0.00016043648 +15326.174,0.00016043648 +13539.921,0.00016043648 +15957.607,0.00016043648 +15291.404,0.00016043648 +13974.891,0.00016043648 +17413.562,0.00016043648 +13018.842,0.00016043648 +17922.018,0.00016043648 +17293.945,0.00016043648 +15109.656,0.00016043648 +16252.369,0.00016043648 +14785.5625,0.00016043648 +17082.445,0.00016043648 +15951.468,0.00016043648 +12971.434,0.00016043648 +12969.782,0.00016043648 diff --git a/test/pred-sets/ref/cats_load.predict b/test/pred-sets/ref/cats_load.predict index 27c5e3b9f13..1402f6525d5 100644 --- a/test/pred-sets/ref/cats_load.predict +++ b/test/pred-sets/ref/cats_load.predict @@ -1,10 +1,10 @@ -2.3137953,0.40625 -1.8475317,0.40625 -2.1299162,0.40625 -1.7430687,0.40625 -14.327164,0.00625 -1.7164373,0.40625 -1.748255,0.40625 -30.00431,0.00625 -0.51225615,0.40625 -2.3816023,0.40625 +25.271912,0.00625 +2.1290207,0.40625 +4.473629,0.00625 +18.205719,0.00625 +2.4059067,0.40625 +9.585819,0.00625 +0.93930376,0.40625 +1.7801242,0.40625 +2.0926502,0.40625 +0.8038121,0.40625 diff --git a/test/pred-sets/ref/cats_room_temp.predict b/test/pred-sets/ref/cats_room_temp.predict index b0c0474ee4c..6fd96ab4646 100644 --- a/test/pred-sets/ref/cats_room_temp.predict +++ b/test/pred-sets/ref/cats_room_temp.predict @@ -1,100 +1,100 @@ -84.607056,0.005 -3.9181187,0.08833334 -39.260834,0.005 -17.99776,0.08833334 -39.331844,0.005 -6.7627044,0.005 -77.12543,0.005 -69.8972,0.005 -17.184662,0.08833334 -54.76903,0.005 -0.66596967,0.005 -13.026539,0.005 -52.522453,0.08833334 -86.18144,0.005 -49.648746,0.08833334 -53.096367,0.08833334 -41.92059,0.005 -88.03693,0.005 -50.15336,0.08833334 -1.6463974,0.005 -57.090523,0.08833334 -25.416866,0.005 -94.471146,0.005 -43.387676,0.005 -59.151047,0.08833334 -38.3481,0.005 -51.486103,0.08833334 -50.88297,0.08833334 -85.89383,0.005 -55.38167,0.005 -51.260296,0.08833334 -85.17497,0.005 -50.680935,0.08833334 -49.601463,0.08833334 -59.61712,0.005 -76.559,0.005 -51.651062,0.08833334 -62.96363,0.005 -7.8253226,0.005 -32.741436,0.005 -21.816528,0.005 -53.33639,0.08833334 -33.895855,0.005 -48.782,0.08833334 -48.93803,0.08833334 -48.645847,0.08833334 -53.199036,0.08833334 -77.73645,0.005 -62.783302,0.005 -18.494978,0.005 -66.79538,0.005 -53.29303,0.08833334 -53.5938,0.08833334 -47.533524,0.005 -8.389192,0.005 -28.030102,0.005 -52.605915,0.08833334 -53.525967,0.08833334 -50.158722,0.08833334 -47.26561,0.005 -72.41367,0.005 -1.1448574,0.005 -53.4234,0.08833334 -88.777985,0.005 -95.276276,0.005 -49.196915,0.08833334 -48.705658,0.08833334 -51.38816,0.08833334 -9.958978,0.005 -95.176796,0.005 -56.863132,0.005 -48.599724,0.08833334 -42.992935,0.005 -49.75944,0.08833334 -21.298155,0.005 -21.954062,0.005 -54.1602,0.08833334 -53.210964,0.08833334 -52.06177,0.08833334 -55.077793,0.005 -50.11463,0.08833334 -55.921093,0.005 -67.825,0.005 -74.41847,0.005 -79.864365,0.005 -24.734612,0.005 -53.9586,0.08833334 -94.53804,0.005 -48.65285,0.08833334 -67.55201,0.005 -48.724384,0.08833334 -50.91952,0.08833334 
-40.403355,0.005 -46.697994,0.005 -49.63687,0.08833334 -50.199287,0.08833334 -52.23696,0.08833334 -48.77658,0.08833334 -52.563095,0.08833334 -14.676942,0.005 +53.204613,0.08833334 +9.196535,0.005 +17.174236,0.08833334 +77.29088,0.005 +15.437802,0.08833334 +93.60582,0.005 +73.45269,0.005 +19.235268,0.08833334 +40.03132,0.005 +18.64663,0.08833334 +66.666214,0.005 +53.193462,0.08833334 +34.754963,0.005 +34.70198,0.005 +57.403484,0.005 +67.80745,0.005 +86.91858,0.005 +52.78149,0.08833334 +43.365547,0.005 +55.016094,0.08833334 +8.145193,0.005 +57.61373,0.08833334 +50.82185,0.08833334 +48.282825,0.005 +81.51001,0.005 +26.397629,0.005 +23.397736,0.005 +50.563953,0.08833334 +53.253433,0.08833334 +53.1661,0.08833334 +50.818462,0.08833334 +72.26961,0.005 +73.63446,0.005 +70.71592,0.005 +23.83146,0.005 +51.322205,0.08833334 +26.74537,0.005 +54.481422,0.08833334 +52.270885,0.08833334 +48.716354,0.08833334 +9.419978,0.005 +2.7159307,0.005 +51.904087,0.08833334 +51.104538,0.08833334 +0.30527782,0.005 +76.84366,0.005 +20.968102,0.005 +58.275642,0.005 +6.868239,0.005 +40.830856,0.005 +54.12306,0.08833334 +51.63766,0.08833334 +95.59329,0.005 +48.718655,0.08833334 +53.290802,0.08833334 +80.52641,0.005 +50.78907,0.08833334 +12.805883,0.005 +89.139755,0.005 +76.673355,0.005 +55.5548,0.005 +11.593977,0.005 +51.98833,0.08833334 +30.242346,0.005 +66.171425,0.005 +51.027653,0.08833334 +53.797455,0.08833334 +91.31642,0.005 +75.98573,0.005 +51.8337,0.08833334 +95.297424,0.005 +54.301224,0.08833334 +49.036102,0.08833334 +31.043774,0.005 +56.56052,0.005 +48.77767,0.08833334 +53.00767,0.08833334 +49.76527,0.08833334 +28.06632,0.005 +53.554474,0.08833334 +54.24915,0.08833334 +15.149416,0.005 +48.76757,0.08833334 +48.71854,0.08833334 +52.769234,0.08833334 +79.30612,0.005 +2.7464993,0.005 +53.082054,0.08833334 +58.80374,0.005 +49.790363,0.08833334 +56.0045,0.005 +88.90586,0.005 +52.904503,0.08833334 +52.2649,0.08833334 +11.048922,0.005 +53.87539,0.08833334 +49.905197,0.08833334 +42.230206,0.005 +42.714077,0.005 +68.1033,0.005 diff --git a/test/pred-sets/ref/cats_save.predict b/test/pred-sets/ref/cats_save.predict index c1e7939b11b..c2113a7759b 100644 --- a/test/pred-sets/ref/cats_save.predict +++ b/test/pred-sets/ref/cats_save.predict @@ -1,10 +1,10 @@ +0.0012017498,0.00625 0.6113522,0.40625 +0.5820448,0.40625 2.2343192,0.40625 +20.461426,0.00625 1.2286053,0.40625 +10.350056,0.00625 22.592487,0.00625 +2.2877245,0.40625 1.3163464,0.40625 -2.1290207,0.40625 -18.205719,0.00625 -9.585819,0.00625 -1.7801242,0.40625 -0.8038121,0.40625 diff --git a/test/train-sets/ref/0001-replay.stderr b/test/train-sets/ref/0001-replay.stderr index c666aba5bec..fda47007a9c 100644 --- a/test/train-sets/ref/0001-replay.stderr +++ b/test/train-sets/ref/0001-replay.stderr @@ -18,18 +18,18 @@ loss last counter weight label predict feat 0.250000 0.250000 8 8.0 0.0000 0.0000 860 0.312500 0.375000 16 16.0 1.0000 0.0000 128 0.343750 0.375000 32 32.0 0.0000 0.0000 176 -0.350946 0.358141 64 64.0 0.0000 0.0275 350 -0.371181 0.391417 128 128.0 1.0000 0.1510 620 -0.292717 0.214252 256 256.0 0.0000 0.3660 410 -0.149202 0.005688 512 512.0 0.0000 0.0000 278 -0.074603 0.000004 1024 1024.0 1.0000 1.0000 170 +0.359375 0.375000 64 64.0 0.0000 0.0275 350 +0.414062 0.468750 128 128.0 1.0000 0.1510 620 +0.433594 0.453125 256 256.0 0.0000 0.3660 410 +0.441406 0.449219 512 512.0 0.0000 0.0000 278 +0.453125 0.464844 1024 1024.0 1.0000 1.0000 170 finished run number of examples per pass = 200 passes used = 8 weighted example sum = 1600.000000 weighted label sum = 
728.000000 -average loss = 0.047746 +average loss = 0.455000 best constant = 0.455000 best constant's loss = 0.247975 total feature number = 717536 diff --git a/test/train-sets/ref/cats-predict.stderr b/test/train-sets/ref/cats-predict.stderr index f8c61ac13f1..137f0783596 100644 --- a/test/train-sets/ref/cats-predict.stderr +++ b/test/train-sets/ref/cats-predict.stderr @@ -15,17 +15,17 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 {185.12,0.6... 17118.01,0 10 -0.000000 0.000000 2 2.0 {772.59,0.4... 15724.64,0 10 -0.226921 0.453841 4 4.0 {14122,0.02,0} 15138.83,0 10 -0.321934 0.416948 8 8.0 {12715.1,0.... 15936.4,0 10 -0.177189 0.032443 16 16.0 {669.12,0.4... 17399.62,0 10 -0.189219 0.201250 32 32.0 {10786.7,0.... 14063.97,0 10 +0.000000 0.000000 1 1.0 {185.12,0.6... 17540.91,0 10 +0.000000 0.000000 2 2.0 {772.59,0.4... 15536.19,0 10 +0.226921 0.453841 4 4.0 {14122,0.02,0} 12421,0 10 +0.321934 0.416948 8 8.0 {12715.1,0.... 17106.01,0 10 +0.184223 0.046511 16 16.0 {669.12,0.4... 14527.18,0 10 +0.192736 0.201250 32 32.0 {10786.7,0.... 14782.35,0 10 finished run number of examples = 57 weighted example sum = 57.000000 weighted label sum = 57.000000 -average loss = 0.168316 +average loss = 0.174316 total feature number = 570 Learn() count per node: id=0, #l=17; id=1, #l=0; id=2, #l=0; diff --git a/test/train-sets/ref/cats-train.stderr b/test/train-sets/ref/cats-train.stderr index 0c9cd87a436..ca7a52ba115 100644 --- a/test/train-sets/ref/cats-train.stderr +++ b/test/train-sets/ref/cats-train.stderr @@ -15,17 +15,17 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 {185.12,0.6... 6324.15,0 10 -0.000000 0.000000 2 2.0 {772.59,0.4... 16299.34,0 10 -0.226921 0.453841 4 4.0 {14122,0.02,0} 17754.81,0 10 -0.321934 0.416948 8 8.0 {12715.1,0.... 17248.12,0 10 -0.200445 0.078955 16 16.0 {669.12,0.4... 14987.98,0 10 -0.165127 0.129809 32 32.0 {10786.7,0.... 13866.5,0 10 +1.774796 1.774796 1 1.0 {185.12,0.6... 188.57,0 10 +0.887398 0.000000 2 2.0 {772.59,0.4... 12189.73,0 10 +0.670619 0.453841 4 4.0 {14122,0.02,0} 16299.34,0 10 +0.543784 0.416948 8 8.0 {12715.1,0.... 17754.81,0 10 +0.311369 0.078955 16 16.0 {669.12,0.4... 17248.12,0 10 +0.256310 0.201250 32 32.0 {10786.7,0.... 
14987.98,0 10 finished run number of examples = 57 weighted example sum = 57.000000 weighted label sum = 57.000000 -average loss = 0.186361 +average loss = 0.231863 total feature number = 570 Learn() count per node: id=0, #l=17; id=1, #l=0; id=2, #l=0; diff --git a/test/train-sets/ref/cats_load.stderr b/test/train-sets/ref/cats_load.stderr index eef510531fb..4080728be24 100644 --- a/test/train-sets/ref/cats_load.stderr +++ b/test/train-sets/ref/cats_load.stderr @@ -16,10 +16,10 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 {0,0,0.01} 2.31,0.41 6 -0.000000 0.000000 2 2.0 {0.58,0,0.41} 1.85,0.41 6 -0.000000 0.000000 4 4.0 {10.35,0,0.01} 1.74,0.41 6 -0.000000 0.000000 8 8.0 {2.41,0,0.41} 30,0.01 6 +0.000000 0.000000 1 1.0 {0,0,0.01} 25.27,0.01 6 +0.000000 0.000000 2 2.0 {0.58,0,0.41} 2.13,0.41 6 +0.000000 0.000000 4 4.0 {10.35,0,0.01} 18.21,0.01 6 +0.000000 0.000000 8 8.0 {2.41,0,0.41} 1.78,0.41 6 finished run number of examples = 10 diff --git a/test/train-sets/ref/cats_room_temp.stderr b/test/train-sets/ref/cats_room_temp.stderr index 0daabb527fd..588863b1aea 100644 --- a/test/train-sets/ref/cats_room_temp.stderr +++ b/test/train-sets/ref/cats_room_temp.stderr @@ -15,18 +15,18 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -1095.824 1095.824 1 1.0 {0,25,0} 2.14,0.09 3 -547.9122 0.000000 2 2.0 {4.07,21.1,... 41.54,0 3 -273.9561 0.000000 4 4.0 {72.94,5.26,0} 88.24,0 3 -136.9780 0.000000 8 8.0 {67.13,2.93... 71.98,0 3 -68.48903 0.000000 16 16.0 {6.01,19.35,0} 51.44,0.09 3 -34.30662 0.124206 32 32.0 {59.57,0.92... 49.41,0.09 3 -18.39109 2.475568 64 64.0 {23.4,7.08,0} 50.56,0.09 3 +1095.824 1095.824 1 1.0 {0,25,0} 0,0 3 +582.1616 68.49878 2 2.0 {4.07,21.1,... 2.14,0.09 3 +291.0808 0.000000 4 4.0 {72.94,5.26,0} 41.54,0 3 +145.5404 0.000000 8 8.0 {67.13,2.93... 88.24,0 3 +72.80454 0.068661 16 16.0 {6.01,19.35,0} 71.98,0 3 +36.48046 0.156381 32 32.0 {59.57,0.92... 51.44,0.09 3 +18.24775 0.015050 64 64.0 {23.4,7.08,0} 49.41,0.09 3 finished run number of examples = 100 weighted example sum = 100.000000 weighted label sum = 100.000000 -average loss = 11.778620 +average loss = 11.691978 total feature number = 300 Learn() count per node: id=0, #l=32; id=1, #l=18; id=2, #l=44; id=3, #l=10; id=4, #l=20; id=5, #l=48; id=6, #l=12; id=7, #l=8; id=8, #l=6; id=9, #l=7; id=10, #l=18; id=11, #l=28; id=12, #l=9; id=13, #l=9; id=14, #l=3; id=15, #l=0; diff --git a/test/train-sets/ref/cats_room_temp_pred.stderr b/test/train-sets/ref/cats_room_temp_pred.stderr index 93262b1c2d2..8cc65301a19 100644 --- a/test/train-sets/ref/cats_room_temp_pred.stderr +++ b/test/train-sets/ref/cats_room_temp_pred.stderr @@ -15,18 +15,18 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 {0,25,0} 84.61,0 3 -26.04380 52.08761 2 2.0 {4.07,21.1,... 3.92,0.09 3 -13.02190 0.000000 4 4.0 {72.94,5.26,0} 18,0.09 3 -6.510952 0.000000 8 8.0 {67.13,2.93... 69.9,0 3 -31.39586 56.28078 16 16.0 {6.01,19.35,0} 53.1,0.09 3 -15.71233 0.028805 32 32.0 {59.57,0.92... 85.17,0 3 -11.42273 7.133129 64 64.0 {23.4,7.08,0} 88.78,0 3 +0.000000 0.000000 1 1.0 {0,25,0} 53.2,0.09 3 +0.000000 0.000000 2 2.0 {4.07,21.1,... 
9.2,0 3 +0.000000 0.000000 4 4.0 {72.94,5.26,0} 77.29,0 3 +72.07164 144.1432 8 8.0 {67.13,2.93... 19.24,0.09 3 +39.93551 7.799388 16 16.0 {6.01,19.35,0} 67.81,0 3 +20.00477 0.074031 32 32.0 {59.57,0.92... 72.27,0 3 +17.12056 14.23634 64 64.0 {23.4,7.08,0} 30.24,0 3 finished run number of examples = 100 weighted example sum = 100.000000 weighted label sum = 100.000000 -average loss = 8.189957 +average loss = 11.096396 total feature number = 300 Learn() count per node: id=0, #l=32; id=1, #l=18; id=2, #l=44; id=3, #l=10; id=4, #l=20; id=5, #l=48; id=6, #l=12; id=7, #l=8; id=8, #l=6; id=9, #l=7; id=10, #l=18; id=11, #l=28; id=12, #l=9; id=13, #l=9; id=14, #l=3; id=15, #l=0; diff --git a/test/train-sets/ref/cats_save.stderr b/test/train-sets/ref/cats_save.stderr index b7204693915..b6ba4c67f41 100644 --- a/test/train-sets/ref/cats_save.stderr +++ b/test/train-sets/ref/cats_save.stderr @@ -17,10 +17,10 @@ Input label = CONTINUOUS Output pred = ACTION_PDF_VALUE average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 {0,0,0.01} 0.61,0.41 6 -0.000000 0.000000 2 2.0 {0.58,0,0.41} 2.23,0.41 6 -0.000000 0.000000 4 4.0 {10.35,0,0.01} 22.59,0.01 6 -0.000000 0.000000 8 8.0 {2.41,0,0.41} 9.59,0.01 6 +0.000000 0.000000 1 1.0 {0,0,0.01} 0,0.01 6 +0.000000 0.000000 2 2.0 {0.58,0,0.41} 0.61,0.41 6 +0.000000 0.000000 4 4.0 {10.35,0,0.01} 2.23,0.41 6 +0.000000 0.000000 8 8.0 {2.41,0,0.41} 22.59,0.01 6 finished run number of examples = 10 diff --git a/vowpalwabbit/core/include/vw/core/example.h b/vowpalwabbit/core/include/vw/core/example.h index 30c6ba0a562..cd2185a8b08 100644 --- a/vowpalwabbit/core/include/vw/core/example.h +++ b/vowpalwabbit/core/include/vw/core/example.h @@ -18,6 +18,7 @@ #include "vw/core/multiclass.h" #include "vw/core/multilabel.h" #include "vw/core/no_label.h" +#include "vw/core/prediction_type.h" #include "vw/core/prob_dist_cont.h" #include "vw/core/simple_label.h" #include "vw/core/slates_label.h" @@ -82,6 +83,7 @@ class polyprediction }; std::string to_string(const v_array& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION); +void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type); class example : public example_predict // core example datatype. 
{ diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index c49cde7b69f..cb2c88be038 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -16,6 +16,52 @@ #include #include +void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::prediction_type_t prediction_type) +{ + switch (prediction_type) + { + case VW::prediction_type_t::SCALAR: + std::swap(b.scalar, a.scalar); + break; + case VW::prediction_type_t::SCALARS: + std::swap(b.scalars, a.scalars); + break; + case VW::prediction_type_t::ACTION_SCORES: + std::swap(b.a_s, a.a_s); + break; + case VW::prediction_type_t::PDF: + std::swap(b.pdf, a.pdf); + break; + case VW::prediction_type_t::ACTION_PROBS: + std::swap(b.a_s, a.a_s); + break; + case VW::prediction_type_t::MULTICLASS: + std::swap(b.multiclass, a.multiclass); + break; + case VW::prediction_type_t::MULTILABELS: + std::swap(b.multilabels, a.multilabels); + break; + case VW::prediction_type_t::PROB: + std::swap(b.prob, a.prob); + break; + case VW::prediction_type_t::MULTICLASS_PROBS: + std::swap(b.scalars, a.scalars); + break; + case VW::prediction_type_t::DECISION_PROBS: + std::swap(b.decision_scores, a.decision_scores); + break; + case VW::prediction_type_t::ACTION_PDF_VALUE: + std::swap(b.pdf_value, a.pdf_value); + break; + case VW::prediction_type_t::ACTIVE_MULTICLASS: + std::swap(b.active_multiclass, a.active_multiclass); + break; + case VW::prediction_type_t::NOPRED: + // Noop + break; + } +} + float calculate_total_sum_features_squared(bool permutations, VW::example& ec) { float sum_features_squared = 0.f; diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index aa2d3584c27..ff3833d0682 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -5,6 +5,7 @@ #include "vw/core/global_data.h" #include "vw/config/options.h" +#include "vw/core/example.h" #include "vw/core/parse_regressor.h" #include "vw/core/reductions/metrics.h" @@ -91,7 +92,10 @@ void workspace::learn(example& ec) else { VW::LEARNER::require_singleline(l)->predict(ec); + VW::polyprediction saved_prediction; + VW::swap_prediction(ec.pred, saved_prediction, l->get_output_prediction_type()); VW::LEARNER::require_singleline(l)->learn(ec); + VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); } } } @@ -107,7 +111,10 @@ void workspace::learn(multi_ex& ec) else { VW::LEARNER::require_multiline(l)->predict(ec); + VW::polyprediction saved_prediction; + VW::swap_prediction(ec[0]->pred, saved_prediction, l->get_output_prediction_type()); VW::LEARNER::require_multiline(l)->learn(ec); + VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); } } } diff --git a/vowpalwabbit/core/src/reductions/cats.cc b/vowpalwabbit/core/src/reductions/cats.cc index 5c2ffd91da8..d6abbfa9a4e 100644 --- a/vowpalwabbit/core/src/reductions/cats.cc +++ b/vowpalwabbit/core/src/reductions/cats.cc @@ -51,8 +51,10 @@ int cats::learn(example& ec, experimental::api_status* status = nullptr) { assert(!ec.test_only); predict(ec, status); + auto pred = ec.pred.pdf_value; VW_DBG(ec) << "cats::learn(), " << to_string(ec.l.cb_cont) << VW::debug::features_to_string(ec) << endl; _base->learn(ec); + ec.pred.pdf_value = pred; return VW::experimental::error_code::success; } @@ -199,6 +201,7 @@ std::shared_ptr VW::reductions::cats_setup(setup_base_i& s auto l = make_reduction_learner(std::move(p_reduction), p_base, predict_or_learn, predict_or_learn, 
stack_builder.get_setupfn_name(cats_setup)) + .set_learn_returns_prediction(true) .set_input_label_type(VW::label_type_t::CONTINUOUS) .set_output_label_type(VW::label_type_t::CONTINUOUS) .set_input_prediction_type(VW::prediction_type_t::ACTION_PDF_VALUE) diff --git a/vowpalwabbit/core/src/reductions/interaction_ground.cc b/vowpalwabbit/core/src/reductions/interaction_ground.cc index 426b186e1fe..7df7f69ef7d 100644 --- a/vowpalwabbit/core/src/reductions/interaction_ground.cc +++ b/vowpalwabbit/core/src/reductions/interaction_ground.cc @@ -122,8 +122,42 @@ void add_obs_features_to_ik_ex(VW::example& ik_ex, const VW::example& obs_ex) } } +void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) +{ + VW::example* observation_ex = nullptr; + + if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation) + { + observation_ex = ec_seq.back(); + ec_seq.pop_back(); + } + + std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); + + for (auto& ex : ec_seq) + { + ex->l.cb = ex->l.cb_with_observations.event; + ex->l.cb_with_observations.event.reset_to_default(); + } + + base.predict(ec_seq, 1); + std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); + + for (auto& ex : ec_seq) + { + ex->l.cb_with_observations.event = ex->l.cb; + ex->l.cb.reset_to_default(); + } + + if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); } +} + void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) { + predict(igl, base, ec_seq); + + auto stashed_prediction = ec_seq[0]->pred.a_s; + float p_unlabeled_prior = 0.5f; std::swap(igl.ik_ftrl->all->loss_config.loss, igl.ik_all->loss_config.loss); @@ -212,36 +246,7 @@ void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_s } ec_seq.push_back(observation_ex); -} - -void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) -{ - VW::example* observation_ex = nullptr; - - if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation) - { - observation_ex = ec_seq.back(); - ec_seq.pop_back(); - } - - std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); - - for (auto& ex : ec_seq) - { - ex->l.cb = ex->l.cb_with_observations.event; - ex->l.cb_with_observations.event.reset_to_default(); - } - - base.predict(ec_seq, 1); - std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); - - for (auto& ex : ec_seq) - { - ex->l.cb_with_observations.event = ex->l.cb; - ex->l.cb.reset_to_default(); - } - - if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); } + ec_seq[0]->pred.a_s = std::move(stashed_prediction); } void save_load_igl(VW::reductions::igl::igl_data& igl, VW::io_buf& io, bool read, bool text) @@ -415,6 +420,7 @@ std::shared_ptr VW::reductions::interaction_ground_setup(V auto l = make_reduction_learner( std::move(ld), pi_learner, learn, predict, stack_builder.get_setupfn_name(interaction_ground_setup)) .set_feature_width(feature_width) + .set_learn_returns_prediction(true) .set_input_label_type(label_type_t::CB_WITH_OBSERVATIONS) .set_output_label_type(label_type_t::CB) .set_input_prediction_type(pred_type)
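
Note on the pattern used above: the change to workspace::learn() in global_data.cc (and the analogous stash/restore in cats.cc and interaction_ground.cc) follows one idea — when a learner does not promise learn_returns_prediction, call predict(), stash the prediction, call learn() (which may clobber ec.pred), then restore the stashed value so callers still see predict()'s output. The patch does the stash/restore with VW::swap_prediction(), which switches on the learner's output prediction type to avoid copying. Below is a minimal, standalone sketch of that pattern with toy stand-ins; every type and function name in it is hypothetical and not part of the VW API, and a plain copy is used where the real code swaps.

#include <iostream>

namespace sketch
{
struct poly_prediction  // stand-in for VW::polyprediction
{
  float scalar = 0.f;
};

struct toy_learner  // stand-in for a single-line learner
{
  bool learn_returns_prediction = false;

  void predict(poly_prediction& pred, float feature) const
  {
    pred.scalar = 2.f * feature;  // deterministic "model" output
  }

  void learn(poly_prediction& pred, float feature, float label)
  {
    // A learner that does not promise a prediction may leave pred in an
    // arbitrary state after updating its weights.
    pred.scalar = -1.f;
    (void)feature;
    (void)label;
  }
};

// Mirrors the shape of the workspace::learn() change: predict, stash,
// learn, restore.
void learn_with_stable_prediction(toy_learner& l, poly_prediction& pred, float feature, float label)
{
  l.predict(pred, feature);
  if (!l.learn_returns_prediction)
  {
    poly_prediction saved = pred;   // stash the prediction from predict()
    l.learn(pred, feature, label);  // learn() may overwrite pred
    pred = saved;                   // restore so the caller sees predict()'s output
  }
  else { l.learn(pred, feature, label); }
}
}  // namespace sketch

int main()
{
  sketch::toy_learner l;
  sketch::poly_prediction pred;
  sketch::learn_with_stable_prediction(l, pred, /*feature=*/3.f, /*label=*/1.f);
  std::cout << pred.scalar << "\n";  // prints 6, not the clobbered -1
  return 0;
}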