Skip to content

Commit 8055717

Browse files
feat: add tests for save_hyperparameters with ignore behavior and composition handling (#20718)
* add tests to check for `save_hyperparameter: ignore` * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fe79be1 commit 8055717

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/tests_pytorch/models/test_hparams.py

+35
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,41 @@ def __init__(self, same_arg="parent_default", other_arg="other"):
440440
assert parent.child.hparams == {"same_arg": "cocofruit"}
441441

442442

443+
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
444+
def test_save_hyperparameters_ignore(base_class):
445+
"""Test if `save_hyperparameter` applies the ignore list correctly during initialization."""
446+
447+
class PLSubclass(base_class):
448+
def __init__(self, learning_rate=1e-3, optimizer="adam"):
449+
super().__init__()
450+
self.save_hyperparameters(ignore=["learning_rate"])
451+
452+
pl_instance = PLSubclass(learning_rate=0.01, optimizer="sgd")
453+
assert pl_instance.hparams == {"optimizer": "sgd"}
454+
455+
456+
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
457+
def test_save_hyperparameters_ignore_under_composition(base_class):
458+
"""Test that in a composed system, hyperparameter saving skips ignored fields from nested modules."""
459+
460+
class ChildModule(base_class):
461+
def __init__(self, dropout, activation, init_method):
462+
super().__init__()
463+
self.save_hyperparameters(ignore=["dropout", "activation"])
464+
465+
class ParentModule(base_class):
466+
def __init__(self, batch_size, optimizer):
467+
super().__init__()
468+
self.child = ChildModule(dropout=0.1, activation="relu", init_method="xavier")
469+
470+
class PipelineWrapper: # not a Lightning subclass on purpose
471+
def __init__(self, run_id="abc123", seed=42):
472+
self.parent_module = ParentModule(batch_size=64, optimizer="adam")
473+
474+
pipeline = PipelineWrapper()
475+
assert pipeline.parent_module.child.hparams == {"init_method": "xavier", "batch_size": 64, "optimizer": "adam"}
476+
477+
443478
class LocalVariableModelSuperLast(BoringModel):
444479
"""This model has the super().__init__() call at the end."""
445480

0 commit comments

Comments
 (0)