diff --git a/.gitignore b/.gitignore index 8cd9881..766b782 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ build *.egg-info *.vw +__pycache__ +.pytest_cache +dist diff --git a/setup.py b/setup.py index 4111815..8294c48 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,11 @@ 'pyskiplist', 'parameterfree', ], + extras_require={ + 'dev': [ + 'pytest' + ] + }, author="VowpalWabbit", description="", packages=find_packages(where="src"), diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 09e6ed6..2af40fe 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -108,7 +108,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict] if not to_select_from: raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." + "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." ) based_on = { @@ -124,7 +124,7 @@ def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]: """ go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status - """ + """ next_inputs = inputs.copy() for k, v in next_inputs.items(): @@ -256,7 +256,7 @@ def score_response( class AutoSelectionScorer(SelectionScorer[Event]): - def __init__(self, + def __init__(self, llm, prompt: Union[Any, None] = None, scoring_criteria_template_str: Optional[str] = None): @@ -304,7 +304,7 @@ def score_response( return resp except Exception as e: raise RuntimeError( - f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" + f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" ) @@ -365,17 +365,17 @@ def __init__( @abstractmethod def _default_policy(self): ... - + def update_with_delayed_score( self, score: float, chain_response: Dict[str, Any], force_score: bool = False ) -> None: """ Updates the learned policy with the score provided. Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call - """ + """ if self._can_use_selection_scorer() and not force_score: raise RuntimeError( - "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." + "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." ) if self.metrics: self.metrics.on_feedback(score) @@ -387,35 +387,26 @@ def update_with_delayed_score( def deactivate_selection_scorer(self) -> None: """ Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses. - """ + """ self.selection_scorer_activated = False def activate_selection_scorer(self) -> None: """ Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses. - """ + """ self.selection_scorer_activated = True def save_progress(self) -> None: """ This function should be called to save the state of the learned policy model. - """ + """ self.policy.save() - def _validate_inputs(self, inputs: Dict[str, Any]) -> None: - super()._validate_inputs(inputs) - if ( - self.selected_input_key in inputs.keys() - or self.selected_based_on_input_key in inputs.keys() - ): - raise ValueError( - f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." - ) def _can_use_selection_scorer(self) -> bool: """ Returns whether the chain can use the selection scorer to score responses or not. - """ + """ return self.selection_scorer is not None and self.selection_scorer_activated @abstractmethod @@ -506,7 +497,7 @@ def embed_string_type( if namespace is None: raise ValueError( - "The default namespace must be provided when embedding a string or _Embed object." + "The default namespace must be provided when embedding a string or _Embed object." ) return {namespace: keep_str + encoded} @@ -559,7 +550,7 @@ def embed( model: (Any, required) The model to use for embedding Returns: List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value - """ + """ if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance( to_embed, str ): diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index a82d1a2..45191e2 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -48,7 +48,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): Attributes: model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. - """ + """ def __init__( self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any @@ -260,7 +260,7 @@ class PickBest(base.RLLoop[PickBestEvent]): Attributes: feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized. - """ + """ def __init__( self, @@ -272,17 +272,17 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) if not actions: raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." + "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." ) if len(list(actions.values())) > 1: raise ValueError( - "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." + "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." ) if not context: raise ValueError( - "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." + "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." ) event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context) @@ -349,25 +349,26 @@ def create( elif selection_scorer is SENTINEL: selection_scorer = base.AutoSelectionScorer(llm=llm) - feature_embedder = kwargs.pop('feature_embedder', None) - vw_cmd = kwargs.pop('vw_cmd', None) - model_save_dir = kwargs.pop('model_save_dir', "./") - reset_model = kwargs.pop('reset_model', False) - vw_logs = kwargs.pop('vw_logs', None) + policy_args = { + 'feature_embedder': kwargs.pop('feature_embedder', None), + 'vw_cmd': kwargs.pop('vw_cmd', None), + 'model_save_dir': kwargs.pop('model_save_dir', "./"), + 'reset_model': kwargs.pop('reset_model', False), + 'vw_logs': kwargs.pop('vw_logs', None) + } + + if policy and any(policy_args.values()): + logger.warning( + f"{[k for k, v in policy_args.items() if v]} will be ignored since nontrivial policy is provided" + ) return PickBest( - policy = policy or PickBest.create_policy( - feature_embedder=feature_embedder, - vw_cmd=vw_cmd, - model_save_dir=model_save_dir, - reset_model=reset_model, - vw_logs=vw_logs, - **kwargs, - ), - selection_scorer = selection_scorer, + policy=policy or PickBest.create_policy(**policy_args, **kwargs), + selection_scorer=selection_scorer, **kwargs, ) - + + @staticmethod def create_policy( feature_embedder: Optional[base.Embedder] = None, @@ -380,7 +381,7 @@ def create_policy( if feature_embedder: if "auto_embed" in kwargs: logger.warning( - "auto_embed will take no effect when explicit feature_embedder is provided" + "auto_embed will take no effect when explicit feature_embedder is provided" ) # turning auto_embed off for cli setting below auto_embed = False diff --git a/tests/unit_tests/pytest.ini b/tests/unit_tests/pytest.ini new file mode 100644 index 0000000..a171c31 --- /dev/null +++ b/tests/unit_tests/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + requires(name): Mark test to run only when the specified requirement is met. diff --git a/tests/unit_tests/test_pick_best_call.py b/tests/unit_tests/test_pick_best_call.py new file mode 100644 index 0000000..ff5c74f --- /dev/null +++ b/tests/unit_tests/test_pick_best_call.py @@ -0,0 +1,392 @@ +from typing import Any, Dict + +import pytest +from test_utils import MockEncoder, MockEncoderReturnsList + +import learn_to_pick +import learn_to_pick.base as rl_loop + +encoded_keyword = "[encoded]" + +class fake_llm_caller: + def predict(self): + return "hey" + +class fake_llm_caller_with_score: + def predict(self): + return "3" + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_multiple_ToSelectFrom_throws() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + with pytest.raises(ValueError): + pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + another_action=learn_to_pick.ToSelectFrom(actions), + ) + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_missing_basedOn_from_throws() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + with pytest.raises(ValueError): + pick.run(action=learn_to_pick.ToSelectFrom(actions)) + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_ToSelectFrom_not_a_list_throws() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = {"actions": ["0", "1", "2"]} + with pytest.raises(ValueError): + pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + ) + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_update_with_delayed_score_with_auto_validator_throws() -> None: + auto_val_llm = fake_llm_caller_with_score + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=learn_to_pick.AutoSelectionScorer(llm=auto_val_llm), + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + + assert picked_metadata.selected.score == 3.0 # type: ignore + with pytest.raises(RuntimeError): + pick.update_with_delayed_score( + chain_response=response, score=100 # type: ignore + ) + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_update_with_delayed_score_force() -> None: + # this LLM returns a number so that the auto validator will return that + auto_val_llm = fake_llm_caller_with_score + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=learn_to_pick.AutoSelectionScorer(llm=auto_val_llm), + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 3.0 # type: ignore + pick.update_with_delayed_score( + chain_response=response, score=100, force_score=True # type: ignore + ) + assert picked_metadata.selected.score == 100.0 # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_update_with_delayed_score() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=None, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score is None # type: ignore + pick.update_with_delayed_score(chain_response=response, score=100) # type: ignore + assert picked_metadata.selected.score == 100.0 # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_user_defined_scorer() -> None: + class CustomSelectionScorer(learn_to_pick.SelectionScorer): + def score_response( + self, + inputs: Dict[str, Any], + event: learn_to_pick.PickBestEvent, + ) -> float: + score = 200 + return score + + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=CustomSelectionScorer(), + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + actions = ["0", "1", "2"] + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 200.0 # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_everything_embedded() -> None: + feature_embedder = learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, feature_embedder=feature_embedder, auto_embed=False + ) + + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_loop.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_loop.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_loop.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + + encoded_ctx_str_1 = rl_loop.stringify_embedding(list(encoded_keyword + ctx_str_1)) + + expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa + + actions = [str1, str2, str3] + + response = pick.run( + User=rl_loop.EmbedAndKeep(learn_to_pick.BasedOn(ctx_str_1)), + action=rl_loop.EmbedAndKeep(learn_to_pick.ToSelectFrom(actions)), + ) + picked_metadata = response["picked_metadata"] # type: ignore + vw_str = feature_embedder.format(picked_metadata) # type: ignore + assert vw_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_default_auto_embedder_is_off() -> None: + feature_embedder = learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, feature_embedder=feature_embedder + ) + + str1 = "0" + str2 = "1" + str3 = "2" + ctx_str_1 = "context1" + + expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa + + actions = [str1, str2, str3] + + response = pick.run( + User=learn_to_pick.base.BasedOn(ctx_str_1), + action=learn_to_pick.base.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + vw_str = feature_embedder.format(picked_metadata) # type: ignore + assert vw_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_default_w_embeddings_off() -> None: + feature_embedder = learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, feature_embedder=feature_embedder, auto_embed=False + ) + + str1 = "0" + str2 = "1" + str3 = "2" + ctx_str_1 = "context1" + + expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa + + actions = [str1, str2, str3] + + response = pick.run( + User=learn_to_pick.BasedOn(ctx_str_1), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + vw_str = feature_embedder.format(picked_metadata) # type: ignore + assert vw_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_default_w_embeddings_on() -> None: + feature_embedder = learn_to_pick.PickBestFeatureEmbedder( + auto_embed=True, model=MockEncoderReturnsList() + ) + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, feature_embedder=feature_embedder, auto_embed=True + ) + + str1 = "0" + str2 = "1" + ctx_str_1 = "context1" + dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0] + + expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa + + actions = [str1, str2] + + response = pick.run( + User=learn_to_pick.BasedOn(ctx_str_1), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + vw_str = feature_embedder.format(picked_metadata) # type: ignore + assert vw_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: + feature_embedder = learn_to_pick.PickBestFeatureEmbedder( + auto_embed=True, model=MockEncoderReturnsList() + ) + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, feature_embedder=feature_embedder, auto_embed=True + ) + + str1 = "0" + str2 = "1" + encoded_str2 = learn_to_pick.stringify_embedding([1.0, 2.0]) + ctx_str_1 = "context1" + ctx_str_2 = "context2" + encoded_ctx_str_1 = learn_to_pick.stringify_embedding([1.0, 2.0]) + dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0] + + expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa + + actions = [str1, learn_to_pick.Embed(str2)] + + response = pick.run( + User=learn_to_pick.BasedOn(learn_to_pick.Embed(ctx_str_1)), + User2=learn_to_pick.BasedOn(ctx_str_2), + action=learn_to_pick.ToSelectFrom(actions), + ) + picked_metadata = response["picked_metadata"] # type: ignore + vw_str = feature_embedder.format(picked_metadata) # type: ignore + assert vw_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_default_no_scorer_specified() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller_with_score, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(["0", "1", "2"]), + ) + # chain llm used for both basic prompt and for scoring + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 3.0 # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_explicitly_no_scorer() -> None: + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=None, + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(["0", "1", "2"]), + ) + # chain llm used for both basic prompt and for scoring + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score is None # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_auto_scorer_with_user_defined_llm() -> None: + scorer_llm = fake_llm_caller_with_score + pick = learn_to_pick.PickBest.create( + llm=fake_llm_caller, + selection_scorer=learn_to_pick.AutoSelectionScorer(llm=scorer_llm), + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + response = pick.run( + User=learn_to_pick.BasedOn("Context"), + action=learn_to_pick.ToSelectFrom(["0", "1", "2"]), + ) + # chain llm used for both basic prompt and for scoring + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 3 # type: ignore + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_activate_and_deactivate_scorer() -> None: + llm = fake_llm_caller + scorer_llm = fake_llm_caller_with_score + pick = learn_to_pick.PickBest.create( + llm=llm, + selection_scorer=learn_to_pick.base.AutoSelectionScorer(llm=scorer_llm), + feature_embedder=learn_to_pick.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ), + ) + response = pick.run( + User=learn_to_pick.base.BasedOn("Context"), + action=learn_to_pick.base.ToSelectFrom(["0", "1", "2"]), + ) + # chain llm used for both basic prompt and for scoring + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 3 # type: ignore + + pick.deactivate_selection_scorer() + response = pick.run( + User=learn_to_pick.base.BasedOn("Context"), + action=learn_to_pick.base.ToSelectFrom(["0", "1", "2"]), + ) + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score is None # type: ignore + + pick.activate_selection_scorer() + response = pick.run( + User=learn_to_pick.base.BasedOn("Context"), + action=learn_to_pick.base.ToSelectFrom(["0", "1", "2"]), + ) + picked_metadata = response["picked_metadata"] # type: ignore + assert picked_metadata.selected.score == 3 # type: ignore diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py new file mode 100644 index 0000000..45714a8 --- /dev/null +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -0,0 +1,370 @@ +import pytest +from test_utils import MockEncoder + +import learn_to_pick.base as rl_chain +import learn_to_pick.pick_best as pick_best_chain + +encoded_keyword = "[encoded]" + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_missing_context_throws() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_action = {"action": ["0", "1", "2"]} + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_action, based_on={} + ) + with pytest.raises(ValueError): + feature_embedder.format(event) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_missing_actions_throws() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from={}, based_on={"context": "context"} + ) + with pytest.raises(ValueError): + feature_embedder.format(event) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_no_label_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": ["0", "1", "2"]} + expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on={"context": "context"} + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": ["0", "1", "2"]} + expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( + inputs={}, + to_select_from=named_actions, + based_on={"context": "context"}, + selected=selected, + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_w_full_label_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": ["0", "1", "2"]} + expected = ( + """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """ + ) + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, + to_select_from=named_actions, + based_on={"context": "context"}, + selected=selected, + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_w_full_label_w_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + + named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} + context = {"context": rl_chain.Embed(ctx_str_1)} + expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + + named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} + context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} + expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} + context = {"context1": "context1", "context2": "context2"} + expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} + context = {"context1": "context1", "context2": "context2"} + expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} + context = {"context1": "context1", "context2": "context2"} + expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + ctx_str_2 = "context2" + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + + named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])} + context = { + "context1": rl_chain.Embed(ctx_str_1), + "context2": rl_chain.Embed(ctx_str_2), + } + expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 + + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> ( + None +): + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + ctx_str_2 = "context2" + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + + named_actions = { + "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3]) + } + context = { + "context1": rl_chain.EmbedAndKeep(ctx_str_1), + "context2": rl_chain.EmbedAndKeep(ctx_str_2), + } + expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 + + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + ctx_str_2 = "context2" + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + + named_actions = { + "action1": [ + {"a": str1, "b": rl_chain.Embed(str1)}, + str2, + rl_chain.Embed(str3), + ] + } + context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} + expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501 + + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + + str1 = "0" + str2 = "1" + str3 = "2" + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + + ctx_str_1 = "context1" + ctx_str_2 = "context2" + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + + named_actions = { + "action1": [ + {"a": str1, "b": rl_chain.EmbedAndKeep(str1)}, + str2, + rl_chain.EmbedAndKeep(str3), + ] + } + context = { + "context1": ctx_str_1, + "context2": rl_chain.EmbedAndKeep(ctx_str_2), + } + expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 + + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context, selected=selected + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_raw_features_underscored() -> None: + feature_embedder = pick_best_chain.PickBestFeatureEmbedder( + auto_embed=False, model=MockEncoder() + ) + str1 = "this is a long string" + str1_underscored = str1.replace(" ", "_") + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + + ctx_str = "this is a long context" + ctx_str_underscored = ctx_str.replace(" ", "_") + encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str)) + + # No embeddings + named_actions = {"action": [str1]} + context = {"context": ctx_str} + expected_no_embed = ( + f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """ + ) + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected_no_embed + + # Just embeddings + named_actions = {"action": rl_chain.Embed([str1])} + context = {"context": rl_chain.Embed(ctx_str)} + expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """ + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected_embed + + # Embeddings and raw features + named_actions = {"action": rl_chain.EmbedAndKeep([str1])} + context = {"context": rl_chain.EmbedAndKeep(ctx_str)} + expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501 + event = pick_best_chain.PickBestEvent( + inputs={}, to_select_from=named_actions, based_on=context + ) + vw_ex_str = feature_embedder.format(event) + assert vw_ex_str == expected_embed_and_keep diff --git a/tests/unit_tests/test_rl_loop_base_embedder.py b/tests/unit_tests/test_rl_loop_base_embedder.py new file mode 100644 index 0000000..ef0aa1e --- /dev/null +++ b/tests/unit_tests/test_rl_loop_base_embedder.py @@ -0,0 +1,422 @@ +from typing import List, Union + +import pytest +from test_utils import MockEncoder + +import learn_to_pick.base as base + +encoded_keyword = "[encoded]" + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_context_str_no_emb() -> None: + expected = [{"a_namespace": "test"}] + assert base.embed("test", MockEncoder(), "a_namespace") == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_context_str_w_emb() -> None: + str1 = "test" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"a_namespace": encoded_str1}] + assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected + expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}] + assert ( + base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace") + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_context_str_w_nested_emb() -> None: + # nested embeddings, innermost wins + str1 = "test" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"a_namespace": encoded_str1}] + assert ( + base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace") + == expected + ) + + expected2 = [{"a_namespace": str1 + " " + encoded_str1}] + assert ( + base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace") + == expected2 + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_context_w_namespace_no_emb() -> None: + expected = [{"test_namespace": "test"}] + assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_context_w_namespace_w_emb() -> None: + str1 = "test" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"test_namespace": encoded_str1}] + assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected + expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] + assert ( + base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder()) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_context_w_namespace_w_emb2() -> None: + str1 = "test" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"test_namespace": encoded_str1}] + assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected + expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] + assert ( + base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder()) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_context_w_namespace_w_some_emb() -> None: + str1 = "test1" + str2 = "test2" + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}] + assert ( + base.embed( + {"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder() + ) + == expected + ) + expected_embed_and_keep = [ + { + "test_namespace": str1, + "test_namespace2": str2 + " " + encoded_str2, + } + ] + assert ( + base.embed( + {"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)}, + MockEncoder(), + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_action_strlist_no_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}] + to_embed: List[Union[str, base._Embed]] = [str1, str2, str3] + assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_action_strlist_w_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"a_namespace": encoded_str1}, + {"a_namespace": encoded_str2}, + {"a_namespace": encoded_str3}, + ] + assert ( + base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace") + == expected + ) + expected_embed_and_keep = [ + {"a_namespace": str1 + " " + encoded_str1}, + {"a_namespace": str2 + " " + encoded_str2}, + {"a_namespace": str3 + " " + encoded_str3}, + ] + assert ( + base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace") + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_simple_action_strlist_w_some_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"a_namespace": str1}, + {"a_namespace": encoded_str2}, + {"a_namespace": encoded_str3}, + ] + assert ( + base.embed( + [str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace" + ) + == expected + ) + expected_embed_and_keep = [ + {"a_namespace": str1}, + {"a_namespace": str2 + " " + encoded_str2}, + {"a_namespace": str3 + " " + encoded_str3}, + ] + assert ( + base.embed( + [str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)], + MockEncoder(), + "a_namespace", + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_action_w_namespace_no_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + expected = [ + {"test_namespace": str1}, + {"test_namespace": str2}, + {"test_namespace": str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": str1}, + {"test_namespace": str2}, + {"test_namespace": str3}, + ], + MockEncoder(), + ) + == expected + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_action_w_namespace_w_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"test_namespace": encoded_str1}, + {"test_namespace": encoded_str2}, + {"test_namespace": encoded_str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": base.Embed(str1)}, + {"test_namespace": base.Embed(str2)}, + {"test_namespace": base.Embed(str3)}, + ], + MockEncoder(), + ) + == expected + ) + expected_embed_and_keep = [ + {"test_namespace": str1 + " " + encoded_str1}, + {"test_namespace": str2 + " " + encoded_str2}, + {"test_namespace": str3 + " " + encoded_str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": base.EmbedAndKeep(str1)}, + {"test_namespace": base.EmbedAndKeep(str2)}, + {"test_namespace": base.EmbedAndKeep(str3)}, + ], + MockEncoder(), + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_action_w_namespace_w_emb2() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"test_namespace1": encoded_str1}, + {"test_namespace2": encoded_str2}, + {"test_namespace3": encoded_str3}, + ] + assert ( + base.embed( + base.Embed( + [ + {"test_namespace1": str1}, + {"test_namespace2": str2}, + {"test_namespace3": str3}, + ] + ), + MockEncoder(), + ) + == expected + ) + expected_embed_and_keep = [ + {"test_namespace1": str1 + " " + encoded_str1}, + {"test_namespace2": str2 + " " + encoded_str2}, + {"test_namespace3": str3 + " " + encoded_str3}, + ] + assert ( + base.embed( + base.EmbedAndKeep( + [ + {"test_namespace1": str1}, + {"test_namespace2": str2}, + {"test_namespace3": str3}, + ] + ), + MockEncoder(), + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_action_w_namespace_w_some_emb() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"test_namespace": str1}, + {"test_namespace": encoded_str2}, + {"test_namespace": encoded_str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": str1}, + {"test_namespace": base.Embed(str2)}, + {"test_namespace": base.Embed(str3)}, + ], + MockEncoder(), + ) + == expected + ) + expected_embed_and_keep = [ + {"test_namespace": str1}, + {"test_namespace": str2 + " " + encoded_str2}, + {"test_namespace": str3 + " " + encoded_str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": str1}, + {"test_namespace": base.EmbedAndKeep(str2)}, + {"test_namespace": base.EmbedAndKeep(str3)}, + ], + MockEncoder(), + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: + str1 = "test1" + str2 = "test2" + str3 = "test3" + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + expected = [ + {"test_namespace": encoded_str1, "test_namespace2": str1}, + {"test_namespace": encoded_str2, "test_namespace2": str2}, + {"test_namespace": encoded_str3, "test_namespace2": str3}, + ] + assert ( + base.embed( + [ + {"test_namespace": base.Embed(str1), "test_namespace2": str1}, + {"test_namespace": base.Embed(str2), "test_namespace2": str2}, + {"test_namespace": base.Embed(str3), "test_namespace2": str3}, + ], + MockEncoder(), + ) + == expected + ) + expected_embed_and_keep = [ + { + "test_namespace": str1 + " " + encoded_str1, + "test_namespace2": str1, + }, + { + "test_namespace": str2 + " " + encoded_str2, + "test_namespace2": str2, + }, + { + "test_namespace": str3 + " " + encoded_str3, + "test_namespace2": str3, + }, + ] + assert ( + base.embed( + [ + {"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1}, + {"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2}, + {"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3}, + ], + MockEncoder(), + ) + == expected_embed_and_keep + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_one_namespace_w_list_of_features_no_emb() -> None: + str1 = "test1" + str2 = "test2" + expected = [{"test_namespace": [str1, str2]}] + assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_one_namespace_w_list_of_features_w_some_emb() -> None: + str1 = "test1" + str2 = "test2" + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + expected = [{"test_namespace": [str1, encoded_str2]}] + assert ( + base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder()) + == expected + ) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_nested_list_features_throws() -> None: + with pytest.raises(ValueError): + base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_dict_in_list_throws() -> None: + with pytest.raises(ValueError): + base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_nested_dict_throws() -> None: + with pytest.raises(ValueError): + base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) + + +@pytest.mark.requires("vowpal_wabbit_next") +def test_list_of_tuples_throws() -> None: + with pytest.raises(ValueError): + base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py new file mode 100644 index 0000000..b2cc90b --- /dev/null +++ b/tests/unit_tests/test_utils.py @@ -0,0 +1,15 @@ +from typing import Any, List + + +class MockEncoder: + def encode(self, to_encode: str) -> str: + return "[encoded]" + to_encode + + +class MockEncoderReturnsList: + def encode(self, to_encode: Any) -> List: + if isinstance(to_encode, str): + return [1.0, 2.0] + elif isinstance(to_encode, List): + return [[1.0, 2.0] for _ in range(len(to_encode))] + raise ValueError("Invalid input type for unit test")