Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 16, 2023
1 parent 86f3827 commit 9643dff
Showing 1 changed file with 39 additions and 28 deletions.
67 changes: 39 additions & 28 deletions tests/unit_tests/test_pick_best_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def test_pickbest_textembedder_missing_context_not_throws() -> None:
featurizer = pick_best_chain.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
)
event = pick_best_chain.PickBestEvent(inputs={"action": ToSelectFrom(["0", "1", "2"])})
event = pick_best_chain.PickBestEvent(
inputs={"action": ToSelectFrom(["0", "1", "2"])}
)
featurizer.featurize(event)


Expand Down Expand Up @@ -59,7 +61,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])},
selected=selected
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand Down Expand Up @@ -105,8 +107,9 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
event = pick_best_chain.PickBestEvent(
inputs={
"context": rl_chain.Embed(BasedOn("ctx")),
"action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"]))
}, selected=selected
"action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])),
},
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand All @@ -132,8 +135,9 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
event = pick_best_chain.PickBestEvent(
inputs={
"context": rl_chain.EmbedAndKeep(BasedOn("ctx")),
"action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"]))
}, selected=selected
"action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])),
},
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand All @@ -155,8 +159,8 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
inputs={
"action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
"context1": BasedOn("context1"),
"context2": BasedOn("context2")
}
"context2": BasedOn("context2"),
}
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand All @@ -179,9 +183,9 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
inputs={
"action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
"context1": BasedOn("context1"),
"context2": BasedOn("context2")
},
selected=selected
"context2": BasedOn("context2"),
},
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand All @@ -204,9 +208,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
inputs={
"action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
"context1": BasedOn("context1"),
"context2": BasedOn("context2")
},
selected=selected
"context2": BasedOn("context2"),
},
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand Down Expand Up @@ -236,9 +240,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
inputs={
"context1": BasedOn(rl_chain.Embed(ctx_str_1)),
"context2": BasedOn(rl_chain.Embed(ctx_str_2)),
"action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"]))
"action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])),
},
selected=selected
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand Down Expand Up @@ -270,9 +274,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
inputs={
"context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)),
"context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)),
"action": ToSelectFrom(rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"]))
"action": ToSelectFrom(
rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])
),
},
selected=selected
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand Down Expand Up @@ -301,9 +307,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
inputs={
"context1": BasedOn(ctx_str_1),
"context2": BasedOn(rl_chain.Embed(ctx_str_2)),
"action": ToSelectFrom([{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")])
"action": ToSelectFrom(
[{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]
),
},
selected=selected
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand Down Expand Up @@ -331,9 +339,15 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
inputs={
"context1": BasedOn(ctx_str_1),
"context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)),
"action": ToSelectFrom([{"a": "0", "b": rl_chain.EmbedAndKeep("0")}, "1", rl_chain.EmbedAndKeep("2")])
"action": ToSelectFrom(
[
{"a": "0", "b": rl_chain.EmbedAndKeep("0")},
"1",
rl_chain.EmbedAndKeep("2"),
]
),
},
selected=selected
selected=selected,
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected)
Expand All @@ -360,10 +374,7 @@ def test_raw_features_underscored() -> None:
)

event = pick_best_chain.PickBestEvent(
inputs={
"action": ToSelectFrom([str1]),
"context": BasedOn(ctx_str)
}
inputs={"action": ToSelectFrom([str1]), "context": BasedOn(ctx_str)}
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
assert_vw_ex_equals(vw_ex_str, expected_no_embed)
Expand All @@ -375,7 +386,7 @@ def test_raw_features_underscored() -> None:
event = pick_best_chain.PickBestEvent(
inputs={
"action": ToSelectFrom(rl_chain.Embed([str1])),
"context": BasedOn(rl_chain.Embed(ctx_str))
"context": BasedOn(rl_chain.Embed(ctx_str)),
}
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
Expand All @@ -391,7 +402,7 @@ def test_raw_features_underscored() -> None:
event = pick_best_chain.PickBestEvent(
inputs={
"action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])),
"context": BasedOn(rl_chain.EmbedAndKeep(ctx_str))
"context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)),
}
)
vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
Expand Down

0 comments on commit 9643dff

Please sign in to comment.