add type hint to pytorch
cheng-tan committed Nov 28, 2023
1 parent d6e9c87 commit 20f8a21
Showing 4 changed files with 52 additions and 32 deletions.
src/learn_to_pick/pytorch/feature_embedder.py (25 additions, 6 deletions)
@@ -1,34 +1,53 @@
 from sentence_transformers import SentenceTransformer
 import torch
+from torch import Tensor

 from learn_to_pick import PickBestFeaturizer
+from learn_to_pick.base import Event
+from learn_to_pick.features import SparseFeatures
+from typing import Any, Tuple, TypeVar, Union
+
+TEvent = TypeVar("TEvent", bound=Event)


 class PyTorchFeatureEmbedder:
-    def __init__(self, model=None, *args, **kwargs):
+    def __init__(self, model: Any = None):
         if model is None:
             model = SentenceTransformer("all-MiniLM-L6-v2")

         self.model = model
         self.featurizer = PickBestFeaturizer(auto_embed=False)

-    def encode(self, stuff):
-        embeddings = self.model.encode(stuff, convert_to_tensor=True)
+    def encode(self, to_encode: str) -> Tensor:
+        embeddings = self.model.encode(to_encode, convert_to_tensor=True)
         normalized = torch.nn.functional.normalize(embeddings)
         return normalized

-    def convert_features_to_text(self, sparse_features):
+    def convert_features_to_text(self, sparse_features: SparseFeatures) -> str:
         results = []
         for ns, obj in sparse_features.items():
             value = obj.get("default_ft", "")
             results.append(f"{ns}={value}")
         return " ".join(results)

-    def format(self, event):
-        # TODO: handle dense
+    def format(
+        self, event: TEvent
+    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
         context_featurized, actions_featurized, selected = self.featurizer.featurize(
             event
         )

+        if len(context_featurized.dense) > 0:
+            raise NotImplementedError(
+                "pytorch policy doesn't support context with dense feature"
+            )
+
+        for action_featurized in actions_featurized:
+            if len(action_featurized.dense) > 0:
+                raise NotImplementedError(
+                    "pytorch policy doesn't support action with dense feature"
+                )
+
         context_sparse = self.encode(
             [self.convert_features_to_text(context_featurized.sparse)]
         )
… (remaining diff collapsed)
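Note: the new annotation on encode() is to_encode: str, but format() calls it with a single-element list; sentence-transformers accepts either, returning a 2-D tensor for list input, so a looser Union[str, List[str]] hint would match the call sites. A minimal usage sketch of the embedder as typed in this diff (the 384-dim output assumes the default all-MiniLM-L6-v2 checkpoint):

    # minimal sketch, assuming the package layout shown in this diff
    from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder

    embedder = PyTorchFeatureEmbedder()  # defaults to SentenceTransformer("all-MiniLM-L6-v2")

    # encode() L2-normalizes the model output, so the dot product of two
    # encodings is their cosine similarity
    a = embedder.encode(["user=Tom context=news"])
    b = embedder.encode(["user=Anna context=sports"])
    print(a.shape)           # torch.Size([1, 384])
    print((a @ b.T).item())  # cosine similarity in [-1, 1]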
src/learn_to_pick/pytorch/igw.py (4 additions, 2 deletions)
@@ -1,7 +1,9 @@
 import torch
+from torch import Tensor
+from typing import Tuple


-def IGW(fhat, gamma):
+def IGW(fhat: torch.Tensor, gamma: float) -> Tuple[Tensor, Tensor]:
     from math import sqrt

     fhatahat, ahat = fhat.max(dim=1)
@@ -13,7 +15,7 @@ def IGW(fhat, gamma):
     return torch.multinomial(p, num_samples=1).squeeze(1), ahat


-def SamplingIGW(A, P, gamma):
+def SamplingIGW(A: Tensor, P: Tensor, gamma: float) -> list:
     exploreind, _ = IGW(P, gamma)
     explore = [ind for _, ind in zip(A, exploreind)]
     return explore
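Note: IGW is inverse gap weighting, the contextual-bandit exploration rule in which each non-greedy action is sampled with probability inversely proportional to gamma times its estimated-reward gap from the greedy action. The middle of the first hunk is collapsed above, so the following is a self-contained sketch of the textbook rule, not a line-for-line copy of this file:

    import torch

    def igw_probs(fhat: torch.Tensor, gamma: float) -> torch.Tensor:
        # fhat: batch x actions, estimated reward per action
        num_actions = fhat.shape[1]
        fhatahat, ahat = fhat.max(dim=1, keepdim=True)  # greedy value and action
        gap = fhatahat - fhat                           # >= 0, zero at the greedy slot
        p = 1.0 / (num_actions + gamma * gap)           # small probability for large gaps
        # overwrite the greedy slot so every row sums exactly to 1
        p.scatter_(1, ahat, 0.0)
        p.scatter_(1, ahat, 1.0 - p.sum(dim=1, keepdim=True))
        return p

    probs = igw_probs(torch.tensor([[0.9, 0.5, 0.1]]), gamma=10.0)
    print(probs)  # tensor([[0.7662, 0.1429, 0.0909]])

The visible return statement then draws one action per row with torch.multinomial; in policy.py below, gamma is math.sqrt(self.index), so exploration decays as the policy sees more events.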
src/learn_to_pick/pytorch/logistic_regression.py (12 additions, 11 deletions)
@@ -1,11 +1,12 @@
 import parameterfree
 import torch
+from torch import Tensor
 import torch.nn.functional as F


 class MLP(torch.nn.Module):
     @staticmethod
-    def new_gelu(x):
+    def new_gelu(x: Tensor) -> Tensor:
         import math

         return (
@@ -19,13 +20,13 @@ def new_gelu(x):
             )
         )

-    def __init__(self, dim):
+    def __init__(self, dim: int):
         super().__init__()
         self.c_fc = torch.nn.Linear(dim, 4 * dim)
         self.c_proj = torch.nn.Linear(4 * dim, dim)
         self.dropout = torch.nn.Dropout(0.5)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.c_fc(x)
         x = self.new_gelu(x)
         x = self.c_proj(x)
@@ -34,16 +35,16 @@ def forward(self, x):


 class Block(torch.nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim: int):
         super().__init__()
         self.layer = MLP(dim)

-    def forward(self, x):
+    def forward(self, x: Tensor):
         return x + self.layer(x)


 class ResidualLogisticRegressor(torch.nn.Module):
-    def __init__(self, in_features, depth, device):
+    def __init__(self, in_features: int, depth: int, device: str):
         super().__init__()
         self._in_features = in_features
         self._depth = depth
@@ -52,17 +53,17 @@ def __init__(self, in_features, depth, device):
         self.optim = parameterfree.COCOB(self.parameters())
         self._device = device

-    def clone(self):
+    def clone(self) -> "ResidualLogisticRegressor":
         other = ResidualLogisticRegressor(self._in_features, self._depth, self._device)
         other.load_state_dict(self.state_dict())
         other.optim = parameterfree.COCOB(other.parameters())
         other.optim.load_state_dict(self.optim.state_dict())
         return other

-    def forward(self, X, A):
+    def forward(self, X: Tensor, A: Tensor) -> Tensor:
         return self.logits(X, A)

-    def logits(self, X, A):
+    def logits(self, X: Tensor, A: Tensor) -> Tensor:
         # X = batch x features
         # A = batch x actionbatch x actionfeatures

@@ -76,11 +77,11 @@
         )  # batch x actionbatch x (features + actionfeatures)
         return self.linear(self.blocks(XA)).squeeze(2)  # batch x actionbatch

-    def predict(self, X, A):
+    def predict(self, X: Tensor, A: Tensor) -> Tensor:
         self.eval()
         return torch.special.expit(self.logits(X, A))

-    def bandit_learn(self, X, A, R):
+    def bandit_learn(self, X: Tensor, A: Tensor, R: Tensor) -> float:
         self.train()
         self.optim.zero_grad()
         output = self(X, A)
… (remaining diff collapsed)
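Note: a quick shape-check of the newly annotated API. The collapsed constructor body builds the residual blocks plus a COCOB optimizer from the parameterfree package, and the comments in logits() imply in_features = features + actionfeatures. The reward shape passed to bandit_learn is an assumption here, since the loss computation is collapsed above:

    import torch
    from learn_to_pick.pytorch.logistic_regression import ResidualLogisticRegressor

    batch, actions, feats = 4, 3, 8
    # logits() concatenates context and action features along the last dim
    model = ResidualLogisticRegressor(in_features=2 * feats, depth=2, device="cpu")

    X = torch.randn(batch, feats)           # batch x features
    A = torch.randn(batch, actions, feats)  # batch x actionbatch x actionfeatures
    print(model.predict(X, A).shape)        # torch.Size([4, 3]); expit() maps logits into (0, 1)

    R = torch.rand(batch, actions)          # assumed: one reward per (row, action) pair
    loss = model.bandit_learn(X, A, R)      # one parameter-free COCOB step, returns the loss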
src/learn_to_pick/pytorch/policy.py (11 additions, 13 deletions)
@@ -1,18 +1,22 @@
 from learn_to_pick import base, PickBestEvent
 from learn_to_pick.pytorch.logistic_regression import ResidualLogisticRegressor
 from learn_to_pick.pytorch.igw import SamplingIGW
 from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder
 import torch
 import os
+from typing import Any, Optional, TypeVar, Union
+
+TEvent = TypeVar("TEvent", bound=base.Event)


 class PyTorchPolicy(base.Policy[PickBestEvent]):
     def __init__(
         self,
-        feature_embedder,
+        feature_embedder=PyTorchFeatureEmbedder(),
         depth: int = 2,
         device: str = "cuda" if torch.cuda.is_available() else "cpu",
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         print(f"Device: {device}")
         super().__init__(*args, **kwargs)
@@ -24,40 +28,34 @@ def __init__(
         self.index = 0
         self.loss = None

-    def predict(self, event):
+    def predict(self, event: TEvent) -> list:
         X, A = self.feature_embedder.format(event)
-        # print(f"X shape: {X.shape}")
-        # print(f"A shape: {A.shape}")
         # TODO IGW sampling then create the distro so that the one
         # that was sampled here is the one that will def be sampled by
         # the base sampler, and in the future replace the sampler so that it
         # is something that can be plugged in
         p = self.workspace.predict(X, A)
-        # print(f"p: {p}")
         import math

         explore = SamplingIGW(A, p, math.sqrt(self.index))
         self.index += 1
-        # print(f"explore: {explore}")
         r = []
         for index in range(p.shape[1]):
             if index == explore[0]:
                 r.append((index, 1))
             else:
                 r.append((index, 0))
-        # print(f"returning: {r}")
         return r

-    def learn(self, event):
+    def learn(self, event: TEvent) -> None:
         R, X, A = self.feature_embedder.format(event)
-        # print(f"R: {R}")
         R, X, A = R.to(self.device), X.to(self.device), A.to(self.device)
         self.loss = self.workspace.bandit_learn(X, A, R)

     def log(self, event):
         pass

-    def save(self, path) -> None:
+    def save(self, path: Optional[Union[str, os.PathLike]]) -> None:
         state = {
             "workspace_state_dict": self.workspace.state_dict(),
             "optimizer_state_dict": self.workspace.optim.state_dict(),
@@ -71,7 +69,7 @@ def save(self, path) -> None:
         os.makedirs(dir, exist_ok=True)
         torch.save(state, path)

-    def load(self, path) -> None:
+    def load(self, path: Optional[Union[str, os.PathLike]]) -> None:
         import parameterfree

         if os.path.exists(path):
… (remaining diff collapsed)
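Note: feature_embedder=PyTorchFeatureEmbedder() makes the policy constructible with no arguments, but Python evaluates default arguments once at definition time, so importing this module now eagerly builds the embedder and downloads the sentence-transformers checkpoint; a feature_embedder=None default resolved inside __init__ would defer that cost. A minimal save/load round trip using only the methods visible in this diff (the path and the zero-argument base constructor are assumptions):

    from learn_to_pick.pytorch.policy import PyTorchPolicy

    policy = PyTorchPolicy(depth=2, device="cpu")  # embedder comes from the new default
    policy.save("./models/bandit.pt")   # creates ./models/ if needed, then torch.save()s the state
    policy.load("./models/bandit.pt")   # restores workspace and optimizer state if the file exists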
