Skip to content

Commit

Permalink
Merge pull request #14 from VowpalWabbit/add_tests
Browse files Browse the repository at this point in the history
Add tests
  • Loading branch information
cheng-tan authored Oct 25, 2023
2 parents b4ac6e2 + d56c321 commit 9491aa0
Show file tree
Hide file tree
Showing 9 changed files with 1,231 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
build
*.egg-info
__pycache__
.pytest_cache
dist
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
'pyskiplist',
'parameterfree',
],
extras_require={
'dev': [
'pytest'
]
},
author="VowpalWabbit",
description="",
packages=find_packages(where="src"),
Expand Down
35 changes: 13 additions & 22 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

based_on = {
Expand All @@ -124,7 +124,7 @@ def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Go over all the inputs: if a value is wrapped in _ToSelectFrom or _BasedOn and its inner value is not already an _Embed,
wrap that inner value in EmbedAndKeep while retaining its _ToSelectFrom or _BasedOn status.
"""
"""

next_inputs = inputs.copy()
for k, v in next_inputs.items():
Expand Down Expand Up @@ -238,7 +238,7 @@ def score_response(


class AutoSelectionScorer(SelectionScorer[Event]):
def __init__(self,
def __init__(self,
llm,
prompt: Union[Any, None] = None,
scoring_criteria_template_str: Optional[str] = None):
Expand Down Expand Up @@ -286,7 +286,7 @@ def score_response(
return resp
except Exception as e:
raise RuntimeError(
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}"
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}"
)


Expand Down Expand Up @@ -349,17 +349,17 @@ def __init__(
@abstractmethod
def _default_policy(self):
...

def update_with_delayed_score(
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
"""
"""
if self._can_use_selection_scorer() and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
)
if self.metrics:
self.metrics.on_feedback(score)
Expand All @@ -371,35 +371,26 @@ def update_with_delayed_score(
def deactivate_selection_scorer(self) -> None:
"""
Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
"""
"""
self.selection_scorer_activated = False

def activate_selection_scorer(self) -> None:
"""
Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
"""
"""
self.selection_scorer_activated = True

def save_progress(self) -> None:
"""
This function should be called to save the state of the learned policy model.
"""
"""
self.policy.save()

def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs)
if (
self.selected_input_key in inputs.keys()
or self.selected_based_on_input_key in inputs.keys()
):
raise ValueError(
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
)

def _can_use_selection_scorer(self) -> bool:
"""
Returns whether the chain can use the selection scorer to score responses or not.
"""
"""
return self.selection_scorer is not None and self.selection_scorer_activated

@abstractmethod
Expand Down Expand Up @@ -489,7 +480,7 @@ def embed_string_type(

if namespace is None:
raise ValueError(
"The default namespace must be provided when embedding a string or _Embed object."
"The default namespace must be provided when embedding a string or _Embed object."
)

return {namespace: keep_str + encoded}
Expand Down Expand Up @@ -542,7 +533,7 @@ def embed(
model: (Any, required) The model to use for embedding
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
Expand Down
16 changes: 8 additions & 8 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
Attributes:
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
"""
"""

def __init__(
self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -260,7 +260,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
Attributes:
feature_embedder (PickBestFeatureEmbedder, optional): An advanced attribute responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is used.
"""
"""

def __init__(
self,
Expand All @@ -272,17 +272,17 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

if len(list(actions.values())) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)

if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)

event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
Expand Down Expand Up @@ -352,7 +352,7 @@ def create(
selection_scorer = base.AutoSelectionScorer(llm=llm)
if policy and any(policy_args):
logger.warning(
f"{list(policy_args.keys())} will be ignored since nontrivial policy is provided"
f"{list(policy_args.keys())} will be ignored since nontrivial policy is provided"
)

return PickBest(
Expand All @@ -361,7 +361,7 @@ def create(
metrics_step = metrics_step,
metrics_window_size = metrics_window_size,
)

@staticmethod
def create_policy(
feature_embedder: Optional[base.Embedder] = None,
Expand All @@ -374,7 +374,7 @@ def create_policy(
if feature_embedder:
if "auto_embed" in kwargs:
logger.warning(
"auto_embed will take no effect when explicit feature_embedder is provided"
"auto_embed will take no effect when explicit feature_embedder is provided"
)
# turning auto_embed off for cli setting below
auto_embed = False
Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
requires(name): Mark test to run only when the specified requirement is met.
Loading

0 comments on commit 9491aa0

Please sign in to comment.