Skip to content

Commit

Permalink
Cleanups in featurization (#35)
Browse files Browse the repository at this point in the history
* := -> =

* tests fix

* get_context_actions to featurized

* black

* exceptions in one place

* all cb actions checks to PickBestEvent ctr

* black

* naming cleanup
  • Loading branch information
ataymano authored Nov 16, 2023
1 parent e5f91fe commit 4aee78d
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 212 deletions.
21 changes: 8 additions & 13 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
def get_based_on(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
if isinstance(inputs[k], _BasedOn)
}

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
if isinstance(inputs[k], _ToSelectFrom)
}

return based_on, to_select_from


# end helper functions

Expand Down
76 changes: 30 additions & 46 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = to_select_from
self.based_on = based_on

def context(self, model) -> base.Featurized:
return base.embed(self.based_on or {}, model)

def actions(self, model) -> List[base.Featurized]:
to_select_from_var_name, to_select_from = next(
iter(self.to_select_from.items()), (None, None)
)

action_embs = (
(
base.embed(to_select_from, model, to_select_from_var_name)
if self.to_select_from
else None
self.to_select_from = base.get_to_select_from(inputs)
self.based_on = base.get_based_on(inputs)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if to_select_from
else None
)
if not action_embs:
if len(self.to_select_from) > 1:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)
return action_embs


class VwTxt:
Expand All @@ -77,9 +60,9 @@ def _sparse_2_str(values: base.SparseFeatures) -> str:
def _to_str(v):
import numbers

return v if isinstance(v, numbers.Number) else f"={v}"
return f":{v}" if isinstance(v, numbers.Number) else f"={v}"

return " ".join([f"{k}:{_to_str(v)}" for k, v in values.items()])
return " ".join([f"{k}{_to_str(v)}" for k, v in values.items()])

@staticmethod
def featurized_2_str(obj: base.Featurized) -> str:
Expand Down Expand Up @@ -157,11 +140,29 @@ def _generic_namespaces(context, actions):
for a in actions:
a["#"] = PickBestFeaturizer._generic_namespace(a)

def get_context_and_actions(
self, event
) -> Tuple[base.Featurized, List[base.Featurized]]:
context = base.embed(event.based_on or {}, self.model)
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)

actions = (
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
return context, actions

def featurize(
self, event: PickBestEvent
) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
context = event.context(self.model)
actions = event.actions(self.model)
context, actions = self.get_context_and_actions(event)

if self.auto_embed:
self._dotproducts(context, actions)
Expand Down Expand Up @@ -224,24 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
"""

def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

if len(list(actions.values())) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)

if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)

event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
return PickBestEvent(inputs=inputs)

def _call_after_predict_before_scoring(
self,
Expand Down
35 changes: 17 additions & 18 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ def test_multiple_ToSelectFrom_throws() -> None:
)


def test_missing_basedOn_from_throws() -> None:
def test_missing_basedOn_from_dont_throw() -> None:
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller,
featurizer=learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
pick.run(action=learn_to_pick.ToSelectFrom(actions))
pick.run(action=learn_to_pick.ToSelectFrom(actions))


def test_ToSelectFrom_not_a_list_throws() -> None:
Expand Down Expand Up @@ -169,10 +168,10 @@ def test_everything_embedded() -> None:

expected = "\n".join(
[
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str2}",
f"|action_dense {action_dense} |action_sparse default_ft:={str3}",
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft={str1}",
f"|action_dense {action_dense} |action_sparse default_ft={str2}",
f"|action_dense {action_dense} |action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -198,10 +197,10 @@ def test_default_auto_embedder_is_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -227,10 +226,10 @@ def test_default_w_embeddings_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand Down Expand Up @@ -258,9 +257,9 @@ def test_default_w_embeddings_on() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ",
f"shared |User_sparse default_ft={ctx_str_1} |@_sparse User={ctx_str_1}",
f"|action_sparse default_ft={str1} |{dot_prod} |#_sparse action={str1} ",
f"|action_sparse default_ft={str2} |{dot_prod} |#_sparse action={str2} ",
]
) # noqa

Expand Down
Loading

0 comments on commit 4aee78d

Please sign in to comment.