diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py index 02fc114..ea34db8 100644 --- a/tests/unit_tests/test_pick_best_text_embedder.py +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -11,7 +11,9 @@ def test_pickbest_textembedder_missing_context_not_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - event = pick_best_chain.PickBestEvent(inputs={"action": ToSelectFrom(["0", "1", "2"])}) + event = pick_best_chain.PickBestEvent( + inputs={"action": ToSelectFrom(["0", "1", "2"])} + ) featurizer.featurize(event) @@ -59,7 +61,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -105,8 +107,9 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: event = pick_best_chain.PickBestEvent( inputs={ "context": rl_chain.Embed(BasedOn("ctx")), - "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])) - }, selected=selected + "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -132,8 +135,9 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: event = pick_best_chain.PickBestEvent( inputs={ "context": rl_chain.EmbedAndKeep(BasedOn("ctx")), - "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])) - }, selected=selected + "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -155,8 +159,8 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - } + "context2": BasedOn("context2"), + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -179,9 +183,9 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - }, - selected=selected + "context2": BasedOn("context2"), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -204,9 +208,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - }, - selected=selected + "context2": BasedOn("context2"), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -236,9 +240,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None inputs={ "context1": BasedOn(rl_chain.Embed(ctx_str_1)), "context2": BasedOn(rl_chain.Embed(ctx_str_2)), - "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])) + "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -270,9 +274,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee inputs={ "context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)), "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), - "action": ToSelectFrom(rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])) + "action": ToSelectFrom( + rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"]) + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -301,9 +307,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N inputs={ "context1": BasedOn(ctx_str_1), "context2": BasedOn(rl_chain.Embed(ctx_str_2)), - "action": ToSelectFrom([{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]) + "action": ToSelectFrom( + [{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")] + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -331,9 +339,15 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() inputs={ "context1": BasedOn(ctx_str_1), "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), - "action": ToSelectFrom([{"a": "0", "b": rl_chain.EmbedAndKeep("0")}, "1", rl_chain.EmbedAndKeep("2")]) + "action": ToSelectFrom( + [ + {"a": "0", "b": rl_chain.EmbedAndKeep("0")}, + "1", + rl_chain.EmbedAndKeep("2"), + ] + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -360,10 +374,7 @@ def test_raw_features_underscored() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={ - "action": ToSelectFrom([str1]), - "context": BasedOn(ctx_str) - } + inputs={"action": ToSelectFrom([str1]), "context": BasedOn(ctx_str)} ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_no_embed) @@ -375,7 +386,7 @@ def test_raw_features_underscored() -> None: event = pick_best_chain.PickBestEvent( inputs={ "action": ToSelectFrom(rl_chain.Embed([str1])), - "context": BasedOn(rl_chain.Embed(ctx_str)) + "context": BasedOn(rl_chain.Embed(ctx_str)), } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) @@ -391,7 +402,7 @@ def test_raw_features_underscored() -> None: event = pick_best_chain.PickBestEvent( inputs={ "action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])), - "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)) + "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)), } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))