Merge pull request #31 from VowpalWabbit/byom
Add pytorch policy
cheng-tan authored Nov 28, 2023
2 parents 213e931 + d88ea8e commit f9f5c3d
Showing 16 changed files with 761 additions and 65 deletions.
50 changes: 46 additions & 4 deletions README.md
@@ -29,7 +29,7 @@ Note: all code examples presented here can be found in `notebooks/readme.ipynb`
- Use a custom score function to grade the decision.
- Directly specify the score manually and asynchronously.

The beauty of `learn_to_pick` is its flexibility. Whether you're a fan of VowpalWabbit or prefer PyTorch (coming soon), the library can seamlessly integrate with both, allowing them to be the brain behind your decisions.
The beauty of `learn_to_pick` is its flexibility. Whether you're a fan of VowpalWabbit or prefer PyTorch, the library can seamlessly integrate with both, allowing them to be the brain behind your decisions.

## Installation

@@ -43,6 +43,8 @@ The `PickBest` scenario should be used when:
- Only one option is optimal for a specific criterion or context
- There exists a mechanism to provide feedback on the suitability of the chosen option for the specific criteria

### Scorer

Example usage with the default LLM scorer:

```python
@@ -113,7 +115,46 @@
dummy_score = 1
picker.update_with_delayed_score(dummy_score, result)
```

`PickBest` is highly configurable to work with a VowpalWabbit decision making policy, a PyTorch decision making policy (coming soon), or with a custom user defined decision making policy
### Using a PyTorch policy

Example usage with a PyTorch policy:
```python
import learn_to_pick
from learn_to_pick import PyTorchPolicy

# CustomSelectionScorer is defined in the Advanced Usage section below
pytorch_picker = learn_to_pick.PickBest.create(
    policy=PyTorchPolicy(), selection_scorer=CustomSelectionScorer()
)

pytorch_picker.run(
    pick=learn_to_pick.ToSelectFrom(["option1", "option2"]),
    criteria=learn_to_pick.BasedOn("some criteria"),
)
```
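As in the delayed-scoring example earlier, the value returned by `run` can be captured (e.g. `result = pytorch_picker.run(...)`) and later passed to `update_with_delayed_score` once a score is available.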

You can always create a custom PyTorch policy by implementing the `Policy` interface:

```python
from typing import Any, TypeVar

import learn_to_pick
from learn_to_pick import Policy
from learn_to_pick.base import Event

TEvent = TypeVar("TEvent", bound=Event)


class CustomPytorchPolicy(Policy):
    def __init__(self, **kwargs: Any):
        ...

    def predict(self, event: TEvent) -> Any:
        ...

    def learn(self, event: TEvent) -> None:
        ...

    def log(self, event: TEvent) -> None:
        ...

    def save(self) -> None:
        ...


pytorch_picker = learn_to_pick.PickBest.create(
    policy=CustomPytorchPolicy(), selection_scorer=CustomSelectionScorer()
)
```
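To make the interface concrete, here is a minimal toy policy. This is an illustrative sketch only, not part of the library or this commit: it assumes `predict` may return a list of `(action_index, score)` pairs (mirroring the PMF-style output of the VW policy shown later in this diff) and that events expose their candidates via `event.to_select_from`; check both assumptions against the actual `PickBest` internals.

```python
import random
from typing import Any, List, Tuple


class RandomPolicy(Policy):
    # Toy sketch: scores every candidate uniformly at random.
    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

    def predict(self, event: TEvent) -> List[Tuple[int, float]]:
        # Assumption: event.to_select_from maps input names to candidate lists.
        candidates = next(iter(event.to_select_from.values()))
        return [(i, random.random()) for i in range(len(candidates))]

    def learn(self, event: TEvent) -> None:
        pass  # a real policy would update its model from the scored event here

    def log(self, event: TEvent) -> None:
        pass

    def save(self) -> None:
        pass
```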

`PickBest` is highly configurable: it can work with a VowpalWabbit decision-making policy, a PyTorch decision-making policy, or a custom, user-defined decision-making policy.

The main thing that needs to be decided from the get-go is:

@@ -134,7 +175,8 @@ In all three cases, when a score is calculated or provided, the decision making
## Example Notebooks

- `readme.ipynb` showcases all examples shown in this README
- `news_recommendation.ipynb` showcases a personalization scenario where we have to pick articles for specific users
- `news_recommendation.ipynb` showcases a personalization scenario where we have to pick articles for specific users with a VowpalWabbit policy
- `news_recommendation_pytorch.ipynb` showcases the same personalization scenario with a PyTorch policy
- `prompt_variable_injection.ipynb` showcases learned prompt variable injection and registering callback functionality

### Advanced Usage
@@ -183,7 +225,7 @@ class CustomSelectionScorer(learn_to_pick.SelectionScorer):
        # inputs: the inputs to the picker in Dict[str, Any] format
        # picked: the selection that was made by the policy
        # event: metadata that can be used to determine the score if needed

        # scoring logic goes here

        dummy_score = 1.0
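The surrounding class body is collapsed in this view; for reference, a complete scorer might look like the sketch below. The `score_response` name and signature are inferred from the parameter comments above and should be verified against the actual `SelectionScorer` interface.

```python
import learn_to_pick


class CustomSelectionScorer(learn_to_pick.SelectionScorer):
    # Method name and signature inferred from the comments above --
    # verify against the SelectionScorer base class.
    def score_response(self, inputs, picked, event) -> float:
        # scoring logic goes here: grade how well `picked` fits `inputs`
        dummy_score = 1.0
        return dummy_score
```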
238 changes: 238 additions & 0 deletions notebooks/news_recommendation_pytorch.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion setup.py
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
import os

with open("README.md", "r", encoding="UTF-8") as fh:
long_description = fh.read()
17 changes: 12 additions & 5 deletions src/learn_to_pick/__init__.py
@@ -5,12 +5,9 @@
    BasedOn,
    Embed,
    Featurizer,
    ModelRepository,
    Policy,
    SelectionScorer,
    ToSelectFrom,
    VwPolicy,
    VwLogger,
    embed,
)
from learn_to_pick.pick_best import (
@@ -22,6 +19,14 @@
)


from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger

from learn_to_pick.pytorch.policy import PyTorchPolicy
from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder


def configure_logger() -> None:
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
@@ -48,9 +53,11 @@ def configure_logger() -> None:
"SelectionScorer",
"AutoSelectionScorer",
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"embed",
"ModelRepository",
"VwPolicy",
"VwLogger",
"embed",
]
53 changes: 1 addition & 52 deletions src/learn_to_pick/base.py
@@ -10,15 +10,12 @@
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    Callable,
)

from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger

from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

@@ -89,10 +86,6 @@ def EmbedAndKeep(anything: Any) -> Any:
# helper functions


def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
    return [parser.parse_line(line) for line in input_str.split("\n")]


def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
    return {
        k: v.value
@@ -144,50 +137,6 @@ def save(self) -> None:
        pass


class VwPolicy(Policy):
    def __init__(
        self,
        model_repo: ModelRepository,
        vw_cmd: List[str],
        featurizer: Featurizer,
        formatter: Callable,
        vw_logger: VwLogger,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.model_repo = model_repo
        self.vw_cmd = vw_cmd
        self.workspace = self.model_repo.load(vw_cmd)
        self.featurizer = featurizer
        self.formatter = formatter
        self.vw_logger = vw_logger

    def format(self, event):
        return self.formatter(*self.featurizer.featurize(event))

    def predict(self, event: TEvent) -> Any:
        import vowpal_wabbit_next as vw

        text_parser = vw.TextFormatParser(self.workspace)
        return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

    def learn(self, event: TEvent) -> None:
        import vowpal_wabbit_next as vw

        vw_ex = self.format(event)
        text_parser = vw.TextFormatParser(self.workspace)
        multi_ex = _parse_lines(text_parser, vw_ex)
        self.workspace.learn_one(multi_ex)

    def log(self, event: TEvent) -> None:
        if self.vw_logger.logging_enabled():
            vw_ex = self.format(event)
            self.vw_logger.log(vw_ex)

    def save(self) -> None:
        self.model_repo.save(self.workspace)


class Featurizer(Generic[TEvent], ABC):
    def __init__(self, *args: Any, **kwargs: Any):
        pass
10 changes: 7 additions & 3 deletions src/learn_to_pick/pick_best.py
@@ -7,6 +7,10 @@
import numpy as np

from learn_to_pick import base
from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger


logger = logging.getLogger(__name__)

@@ -333,14 +337,14 @@ def create_policy(

        vw_cmd = interactions + vw_cmd

        return base.VwPolicy(
            model_repo=base.ModelRepository(
        return VwPolicy(
            model_repo=ModelRepository(
                model_save_dir, with_history=True, reset=reset_model
            ),
            vw_cmd=vw_cmd,
            featurizer=featurizer,
            formatter=formatter,
            vw_logger=base.VwLogger(rl_logs),
            vw_logger=VwLogger(rl_logs),
        )

    def _default_policy(self):
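From the values flowing into `create_policy` above (`model_save_dir`, `reset_model`, `rl_logs`, `vw_cmd`), it appears `PickBest.create` forwards such keyword arguments to configure the default VW policy. A purely illustrative sketch, with argument names taken from this diff but not verified against the public signature:

```python
# Hypothetical usage -- keyword names inferred from create_policy, not verified.
picker = learn_to_pick.PickBest.create(
    model_save_dir="./models",  # ModelRepository checkpoint directory
    reset_model=False,          # keep and resume any previously saved model
    rl_logs="./logs/vw.txt",    # VwLogger output path
)
```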
0 changes: 0 additions & 0 deletions src/learn_to_pick/pytorch/__init__.py

Empty file.
69 changes: 69 additions & 0 deletions src/learn_to_pick/pytorch/feature_embedder.py
@@ -0,0 +1,69 @@
from sentence_transformers import SentenceTransformer
import torch
from torch import Tensor

from learn_to_pick import PickBestFeaturizer
from learn_to_pick.base import Event
from learn_to_pick.features import SparseFeatures
from typing import Any, List, Tuple, TypeVar, Union

TEvent = TypeVar("TEvent", bound=Event)


class PyTorchFeatureEmbedder:
    def __init__(self, model: Any = None):
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2")

        self.model = model
        self.featurizer = PickBestFeaturizer(auto_embed=False)

    def encode(self, to_encode: List[str]) -> Tensor:
        # Embed the strings and L2-normalize the embeddings.
        embeddings = self.model.encode(to_encode, convert_to_tensor=True)
        normalized = torch.nn.functional.normalize(embeddings)
        return normalized

    def convert_features_to_text(self, sparse_features: SparseFeatures) -> str:
        # Render each namespace as a "namespace=value" token for the encoder.
        results = []
        for ns, obj in sparse_features.items():
            value = obj.get("default_ft", "")
            results.append(f"{ns}={value}")
        return " ".join(results)

    def format(
        self, event: TEvent
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        context_featurized, actions_featurized, selected = self.featurizer.featurize(
            event
        )

        if len(context_featurized.dense) > 0:
            raise NotImplementedError(
                "pytorch policy doesn't support context with dense features"
            )

        for action_featurized in actions_featurized:
            if len(action_featurized.dense) > 0:
                raise NotImplementedError(
                    "pytorch policy doesn't support action with dense features"
                )

        context_sparse = self.encode(
            [self.convert_features_to_text(context_featurized.sparse)]
        )

        actions_sparse = []
        for action_featurized in actions_featurized:
            actions_sparse.append(
                self.convert_features_to_text(action_featurized.sparse)
            )
        actions_sparse = self.encode(actions_sparse).unsqueeze(0)

        if selected.score is not None:
            # Training: (score, context, the single selected action).
            return (
                torch.Tensor([[selected.score]]),
                context_sparse,
                actions_sparse[:, selected.index, :].unsqueeze(1),
            )
        else:
            # Inference: (context, all candidate actions).
            return context_sparse, actions_sparse
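For a sense of the data flow, a small illustrative sketch (not part of the commit); it assumes `SparseFeatures` behaves like a plain nested dict keyed by namespace, and the 384-dim shape is specific to the default `all-MiniLM-L6-v2` model:

```python
embedder = PyTorchFeatureEmbedder()

# Assumption: SparseFeatures is dict-like, namespace -> {"default_ft": value}.
text = embedder.convert_features_to_text(
    {"user": {"default_ft": "Tom"}, "time_of_day": {"default_ft": "morning"}}
)
print(text)  # user=Tom time_of_day=morning

# encode() returns one L2-normalized embedding per input string.
emb = embedder.encode([text])
print(emb.shape)  # torch.Size([1, 384]) with the default model
```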
21 changes: 21 additions & 0 deletions src/learn_to_pick/pytorch/igw.py
@@ -0,0 +1,21 @@
import torch
from torch import Tensor
from typing import Tuple


def IGW(fhat: torch.Tensor, gamma: float) -> Tuple[Tensor, Tensor]:
    from math import sqrt

    # Greedy action and its predicted reward for each batch row.
    fhatahat, ahat = fhat.max(dim=1)
    A = fhat.shape[1]
    gamma *= sqrt(A)
    # Inverse gap weighting: each action a gets base probability
    # 1 / (A + gamma * (fhat(ahat) - fhat(a))).
    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
    sump = p.sum(dim=1)
    # The greedy action absorbs the remaining probability mass.
    p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
    return torch.multinomial(p, num_samples=1).squeeze(1), ahat


def SamplingIGW(A: Tensor, P: Tensor, gamma: float) -> list:
    # Draw one exploratory action per batch row from the IGW distribution.
    exploreind, _ = IGW(P, gamma)
    explore = [ind for _, ind in zip(A, exploreind)]
    return explore
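For intuition about the exploration behavior, a small illustrative example (not from the commit) of calling `SamplingIGW`; `P` follows the batched `(batch, num_actions)` convention used above:

```python
import torch

# Predicted rewards for a batch of 2 contexts over 3 candidate actions.
P = torch.tensor([[0.9, 0.1, 0.3],
                  [0.2, 0.8, 0.7]])
A = ["ctx0", "ctx1"]  # only its length is used, to size the output

# Larger gamma concentrates probability on the greedy action;
# smaller gamma spreads it more uniformly across actions.
picked = SamplingIGW(A, P, gamma=10.0)
print(picked)  # e.g. [tensor(0), tensor(1)] -- usually the argmax of each row
```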