WIP: BYOM
cheng-tan committed Nov 12, 2023
1 parent 97793c8 commit afef4a7
Showing 9 changed files with 403 additions and 19 deletions.
183 changes: 166 additions & 17 deletions notebooks/news_recommendation_byom.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion setup.py
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
import os

with open("README.md", "r", encoding="UTF-8") as fh:
    long_description = fh.read()
9 changes: 9 additions & 0 deletions src/learn_to_pick/__init__.py
@@ -21,6 +21,13 @@
    PickBestSelected,
)

from learn_to_pick.byom.pytorch_policy import (
    PyTorchPolicy
)

from learn_to_pick.byom.pytorch_feature_embedder import (
    PyTorchFeatureEmbedder
)

def configure_logger() -> None:
    logger = logging.getLogger(__name__)
@@ -50,6 +57,8 @@ def configure_logger() -> None:
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"VwPolicy",
"VwLogger",
"embed",
Empty file.
16 changes: 16 additions & 0 deletions src/learn_to_pick/byom/igw.py
@@ -0,0 +1,16 @@
import torch
from math import sqrt

def IGW(fhat, gamma):
    # Inverse Gap Weighting: fhat is batch x actions of estimated rewards.
    # Actions whose estimates are close to the best one receive more probability.
    fhatahat, ahat = fhat.max(dim=1)
    A = fhat.shape[1]
    gamma *= sqrt(A)
    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
    sump = p.sum(dim=1)
    # put any leftover probability mass on the greedy action
    p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
    return torch.multinomial(p, num_samples=1).squeeze(1), ahat

def SamplingIGW(A, P, gamma):
    exploreind, _ = IGW(P, gamma)
    explore = [ind for _, ind in zip(A, exploreind)]
    return explore
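
A minimal usage sketch for SamplingIGW (not part of this commit; the scores and action labels are made up): given a batch of estimated rewards and an exploration parameter gamma, it returns one sampled action index per batch row.

import torch
from learn_to_pick.byom.igw import SamplingIGW

fhat = torch.tensor([[0.9, 0.2, 0.4]])      # 1 x 3 estimated rewards
actions = ["sports", "politics", "tech"]    # illustrative action labels
sampled = SamplingIGW(actions, fhat, 10.0)  # one-element list holding the sampled action index
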
70 changes: 70 additions & 0 deletions src/learn_to_pick/byom/logistic_regression.py
@@ -0,0 +1,70 @@
import math

import parameterfree
import torch
import torch.nn.functional as F

class MLP(torch.nn.Module):
    @staticmethod
    def new_gelu(x):
        # tanh approximation of the GELU activation
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    def __init__(self, dim):
        super().__init__()
        self.c_fc = torch.nn.Linear(dim, 4 * dim)
        self.c_proj = torch.nn.Linear(4 * dim, dim)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layer = MLP(dim)

    def forward(self, x):
        # residual connection around the MLP
        return x + self.layer(x)

class ResidualLogisticRegressor(torch.nn.Module):
    def __init__(self, in_features, depth):
        super().__init__()
        self._in_features = in_features
        self._depth = depth
        self.blocks = torch.nn.Sequential(*[Block(in_features) for _ in range(depth)])
        self.linear = torch.nn.Linear(in_features=in_features, out_features=1)
        self.optim = parameterfree.COCOB(self.parameters())

    def clone(self):
        # copy of the model together with its optimizer state
        other = ResidualLogisticRegressor(self._in_features, self._depth)
        other.load_state_dict(self.state_dict())
        other.optim = parameterfree.COCOB(other.parameters())
        other.optim.load_state_dict(self.optim.state_dict())
        return other

    def forward(self, X, A):
        return self.logits(X, A)

    def logits(self, X, A):
        # X = batch x features
        # A = batch x actionbatch x actionfeatures
        Xreshap = X.unsqueeze(1).expand(-1, A.shape[1], -1)  # batch x actionbatch x features
        XA = torch.cat((Xreshap, A), dim=-1).reshape(X.shape[0], A.shape[1], -1)  # batch x actionbatch x (features + actionfeatures)
        return self.linear(self.blocks(XA)).squeeze(2)  # batch x actionbatch

    def predict(self, X, A):
        # estimated reward in [0, 1] for every candidate action
        self.eval()
        return torch.special.expit(self.logits(X, A))

    def bandit_learn(self, X, A, R):
        # one gradient step on the observed (context, chosen action, reward) triple
        self.train()
        self.optim.zero_grad()
        output = self(X, A)
        loss = F.binary_cross_entropy_with_logits(output, R)
        loss.backward()
        self.optim.step()
        return loss.item()
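
A minimal sketch of how the regressor can be driven (not part of this commit; dimensions and data are illustrative): the input width is context features plus action features, predict scores every candidate action, and bandit_learn updates on the chosen action and its observed reward.

import torch
from learn_to_pick.byom.logistic_regression import ResidualLogisticRegressor

model = ResidualLogisticRegressor(in_features=16, depth=2)  # 8 context + 8 action features
X = torch.randn(1, 8)     # batch x context features
A = torch.randn(1, 3, 8)  # batch x actions x action features
scores = model.predict(X, A)                                      # 1 x 3, sigmoid of the logits
loss = model.bandit_learn(X, A[:, :1, :], torch.tensor([[1.0]]))  # update on the chosen action only
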
87 changes: 87 additions & 0 deletions src/learn_to_pick/byom/pytorch_feature_embedder.py
@@ -0,0 +1,87 @@
import learn_to_pick as rl_chain
from sentence_transformers import SentenceTransformer
import torch

class PyTorchFeatureEmbedder:  # implements rl_chain.Embedder[rl_chain.PickBestEvent]
    def __init__(self, auto_embed, model=None, *args, **kwargs):
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2")

        self.model = model
        self.auto_embed = auto_embed

    def encode(self, stuff):
        # returns row-normalized sentence embeddings as a tensor
        embeddings = self.model.encode(stuff, convert_to_tensor=True)
        normalized = torch.nn.functional.normalize(embeddings)
        return normalized

    def get_label(self, event: rl_chain.PickBestEvent) -> tuple:
        cost = None
        if event.selected:
            chosen_action = event.selected.index
            cost = (
                -1.0 * event.selected.score
                if event.selected.score is not None
                else None
            )
            prob = event.selected.probability
            return chosen_action, cost, prob
        else:
            return None, None, None

    def get_context_and_action_embeddings(self, event: rl_chain.PickBestEvent) -> tuple:
        context_emb = rl_chain.embed(event.based_on, self) if event.based_on else None
        to_select_from_var_name, to_select_from = next(
            iter(event.to_select_from.items()), (None, None)
        )

        action_embs = (
            rl_chain.embed(to_select_from, self, to_select_from_var_name)
            if to_select_from
            else None
        )

        if not context_emb or not action_embs:
            raise ValueError(
                "Context and to_select_from must be provided in the inputs dictionary"
            )
        return context_emb, action_embs

    def format(self, event: rl_chain.PickBestEvent):
        chosen_action, cost, prob = self.get_label(event)
        context_emb, action_embs = self.get_context_and_action_embeddings(event)

        context = ""
        for context_item in context_emb:
            for ns, based_on in context_item.items():
                e = " ".join(based_on) if isinstance(based_on, list) else based_on
                context += f"{ns}={e} "

        if self.auto_embed:
            context = self.encode([context])

        actions = []
        for action in action_embs:
            action_str = ""
            for ns, action_embedding in action.items():
                e = (
                    " ".join(action_embedding)
                    if isinstance(action_embedding, list)
                    else action_embedding
                )
                action_str += f"{ns}={e} "
            actions.append(action_str)

        if self.auto_embed:
            actions = self.encode(actions).unsqueeze(0)

        if cost is None:
            return context, actions
        else:
            # cost is -score, so -cost recovers the reward for the chosen action
            return torch.Tensor([[-1.0 * cost]]), context, actions[:, chosen_action, :].unsqueeze(1)
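
A minimal sketch of the embedder in isolation (not part of this commit; the sentences are illustrative): with the default all-MiniLM-L6-v2 model, encode returns row-normalized 384-dimensional embeddings.

import torch
from learn_to_pick import PyTorchFeatureEmbedder

fe = PyTorchFeatureEmbedder(auto_embed=True)
emb = fe.encode(["sports article", "politics article"])  # 2 x 384 tensor
print(emb.shape, torch.linalg.norm(emb, dim=1))          # row norms are ~1 after normalization
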
54 changes: 54 additions & 0 deletions src/learn_to_pick/byom/pytorch_policy.py
@@ -0,0 +1,54 @@
import math

from learn_to_pick import base, PickBestEvent
from learn_to_pick.byom.logistic_regression import ResidualLogisticRegressor
from learn_to_pick.byom.igw import SamplingIGW

class PyTorchPolicy(base.Policy[PickBestEvent]):
    def __init__(
        self,
        feature_embedder,
        depth: int = 2,
        device: str = "cuda",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.workspace = ResidualLogisticRegressor(
            feature_embedder.model.get_sentence_embedding_dimension() * 2, depth
        ).to(device)
        self.feature_embedder = feature_embedder
        self.device = device
        self.index = 0

    def predict(self, event):
        X, A = self.feature_embedder.format(event)
        # TODO: IGW sampling then create the distro so that the one
        # that was sampled here is the one that will definitely be sampled by
        # the base sampler, and in the future replace the sampler so that it
        # is something that can be plugged in
        p = self.workspace.predict(X, A)
        explore = SamplingIGW(A, p, math.sqrt(self.index))
        self.index += 1
        # probability 1 for the sampled action, 0 for every other candidate
        r = []
        for index in range(p.shape[1]):
            if index == explore[0]:
                r.append((index, 1))
            else:
                r.append((index, 0))
        return r

    def learn(self, event):
        R, X, A = self.feature_embedder.format(event)
        R, X, A = R.to(self.device), X.to(self.device), A.to(self.device)
        self.workspace.bandit_learn(X, A, R)

    def log(self, event):
        pass

    def save(self) -> None:
        pass
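
A possible end-to-end wiring (not part of this commit; it assumes the library's existing ToSelectFrom/BasedOn helpers and run() API, and the input names are illustrative — the full example lives in notebooks/news_recommendation_byom.ipynb):

import learn_to_pick

fe = learn_to_pick.PyTorchFeatureEmbedder(auto_embed=True)
policy = learn_to_pick.PyTorchPolicy(feature_embedder=fe, depth=2, device="cpu")

# selection_scorer=None disables automatic scoring for this sketch;
# in practice a scorer (or an llm for AutoSelectionScorer) would be supplied.
picker = learn_to_pick.PickBest.create(policy=policy, selection_scorer=None)
result = picker.run(
    article=learn_to_pick.ToSelectFrom(["sports", "politics", "tech"]),
    preference=learn_to_pick.BasedOn("I like sports"),
)
print(result["picked"])
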
2 changes: 1 addition & 1 deletion src/learn_to_pick/pick_best.py
@@ -325,7 +325,7 @@ def _call_after_scoring_before_learning(

    @classmethod
    def create(
        # cls: Type[PickBest],
        cls: Type[PickBest],
        policy: Optional[base.Policy] = None,
        llm=None,
        selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,