Skip to content

Commit

Permalink
update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Nov 28, 2023
1 parent 20f8a21 commit af207ff
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 39 deletions.
135 changes: 101 additions & 34 deletions notebooks/news_recommendation_byom.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/learn_to_pick/pytorch/feature_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def format(

if len(context_featurized.dense) > 0:
raise NotImplementedError(
"pytorch policy doesn't support context with dense feature"
"pytorch policy doesn't support context with dense features"
)

for action_featurized in actions_featurized:
if len(action_featurized.dense) > 0:
raise NotImplementedError(
"pytorch policy doesn't support action with dense feature"
"pytorch policy doesn't support action with dense features"
)

context_sparse = self.encode(
Expand Down
6 changes: 3 additions & 3 deletions src/learn_to_pick/pytorch/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder
import torch
import os
from typing import Any, Optional, PathLike, TypeVar, Union
from typing import Any, Optional, TypeVar, Union

TEvent = TypeVar("TEvent", bound=base.Event)

Expand Down Expand Up @@ -55,7 +55,7 @@ def learn(self, event: TEvent) -> None:
def log(self, event):
pass

def save(self, path: Optional[Union[str, PathLike]]) -> None:
def save(self, path: Optional[Union[str, os.PathLike]]) -> None:
state = {
"workspace_state_dict": self.workspace.state_dict(),
"optimizer_state_dict": self.workspace.optim.state_dict(),
Expand All @@ -69,7 +69,7 @@ def save(self, path: Optional[Union[str, PathLike]]) -> None:
os.makedirs(dir, exist_ok=True)
torch.save(state, path)

def load(self, path: Optional[Union[str, PathLike]]) -> None:
def load(self, path: Optional[Union[str, os.PathLike]]) -> None:
import parameterfree

if os.path.exists(path):
Expand Down

0 comments on commit af207ff

Please sign in to comment.