-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #31 from VowpalWabbit/byom
Add pytorch policy
- Loading branch information
Showing
16 changed files
with
761 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import Any, List, Tuple, TypeVar, Union

import torch
from sentence_transformers import SentenceTransformer
from torch import Tensor

from learn_to_pick import PickBestFeaturizer
from learn_to_pick.base import Event
from learn_to_pick.features import SparseFeatures
|
||
TEvent = TypeVar("TEvent", bound=Event) | ||
|
||
|
||
class PyTorchFeatureEmbedder:
    """Turn a `learn_to_pick` `Event` into normalized embedding tensors.

    Sparse features produced by `PickBestFeaturizer` are rendered to
    "namespace=value" text and embedded with a sentence-transformer model,
    yielding tensors that a PyTorch policy can consume. Dense features are
    not supported and raise `NotImplementedError`.
    """

    def __init__(self, model: Any = None):
        """
        Args:
            model: any object exposing ``encode(texts, convert_to_tensor=True)``.
                Defaults to ``SentenceTransformer("all-MiniLM-L6-v2")``.
        """
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2")

        self.model = model
        self.featurizer = PickBestFeaturizer(auto_embed=False)

    def encode(self, to_encode: Union[str, List[str]]) -> Tensor:
        """Embed ``to_encode`` and L2-normalize each embedding row.

        Accepts a single string or a list of strings; ``format`` always
        passes a list, so the original ``str``-only annotation was too
        narrow.
        """
        embeddings = self.model.encode(to_encode, convert_to_tensor=True)
        # normalize() defaults to dim=1: each row becomes a unit vector.
        return torch.nn.functional.normalize(embeddings)

    def convert_features_to_text(self, sparse_features: SparseFeatures) -> str:
        """Render sparse features as space-separated "namespace=value" pairs.

        Missing "default_ft" entries render as an empty value.
        """
        results = []
        for ns, obj in sparse_features.items():
            value = obj.get("default_ft", "")
            results.append(f"{ns}={value}")
        return " ".join(results)

    def format(
        self, event: TEvent
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        """Featurize ``event`` into tensors for training or inference.

        Returns:
            ``(score, context, selected_action)`` tensors when the event
            carries a selection score (training), otherwise
            ``(context, actions)`` tensors (inference).

        Raises:
            NotImplementedError: if the featurized context or any action
                contains dense features.
        """
        context_featurized, actions_featurized, selected = self.featurizer.featurize(
            event
        )

        if len(context_featurized.dense) > 0:
            raise NotImplementedError(
                "pytorch policy doesn't support context with dense features"
            )

        for action_featurized in actions_featurized:
            if len(action_featurized.dense) > 0:
                raise NotImplementedError(
                    "pytorch policy doesn't support action with dense features"
                )

        context_sparse = self.encode(
            [self.convert_features_to_text(context_featurized.sparse)]
        )

        action_texts = []
        for action_featurized in actions_featurized:
            action_texts.append(
                self.convert_features_to_text(action_featurized.sparse)
            )
        # Presumably (1, num_actions, embedding_dim) after unsqueeze — relies
        # on model.encode returning one row per input string; confirm for
        # custom models.
        actions_sparse = self.encode(action_texts).unsqueeze(0)

        if selected.score is not None:
            # Training example: label score, context, and only the chosen action.
            return (
                torch.Tensor([[selected.score]]),
                context_sparse,
                actions_sparse[:, selected.index, :].unsqueeze(1),
            )
        else:
            return context_sparse, actions_sparse
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
from torch import Tensor | ||
from typing import Tuple | ||
|
||
|
||
def IGW(fhat: torch.Tensor, gamma: float) -> Tuple[Tensor, Tensor]:
    """Inverse-gap-weighted exploration over a batch of reward estimates.

    Given estimates ``fhat`` indexed as (row, action), each action's
    probability is ``1 / (A + gamma * sqrt(A) * gap)`` where ``gap`` is its
    distance to the row's best estimate; leftover mass is assigned to the
    greedy action, and one action per row is then sampled.

    Returns:
        A tuple ``(sampled, greedy)`` of per-row action indices.
    """
    from math import sqrt

    best_value, greedy_action = fhat.max(dim=1)
    num_actions = fhat.shape[1]
    scaled_gamma = gamma * sqrt(num_actions)
    probs = 1 / (num_actions + scaled_gamma * (best_value.unsqueeze(1) - fhat))
    # Any probability mass not yet allocated goes to the greedy action.
    leftover = torch.clamp(1 - probs.sum(dim=1), min=0, max=None)
    probs[range(probs.shape[0]), greedy_action] += leftover
    sampled_action = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_action, greedy_action
|
||
|
||
def SamplingIGW(A: Tensor, P: Tensor, gamma: float) -> list:
    """Sample one action index per row of ``P`` via inverse-gap weighting.

    ``A`` is consumed only through ``zip``: it truncates the sampled indices
    to ``A``'s length (one result per entry of ``A``). The returned list
    holds the elements yielded by iterating the sampled-index tensor.
    """
    sampled, _greedy = IGW(P, gamma)
    chosen = [index for _, index in zip(A, sampled)]
    return chosen
Oops, something went wrong.