Skip to content

Commit b11ca95

Browse files
committed
Add code to reproduce error
1 parent a2ff8eb commit b11ca95

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

test_extras/test_chemprop/chemprop_test_utils/default_models.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Functions for creating default chemprop models."""
22

3+
from typing import Any
4+
35
from molpipeline.estimators.chemprop import ChempropModel, ChempropNeuralFP
46
from molpipeline.estimators.chemprop.component_wrapper import (
57
MPNN,
@@ -28,16 +30,26 @@ def get_binary_classification_mpnn() -> MPNN:
2830
return mpnn
2931

3032

31-
def get_neural_fp_encoder() -> ChempropNeuralFP:
33+
def get_neural_fp_encoder(
34+
init_kwargs: dict[str, Any] | None = None,
35+
) -> ChempropNeuralFP:
3236
"""Get the Chemprop model.
3337
38+
Parameters
39+
----------
40+
init_kwargs : dict[str, Any], optional
41+
Additional keyword arguments to pass to `ChempropNeuralFP` during initialization.
42+
3443
Returns
3544
-------
3645
ChempropNeuralFP
3746
The Chemprop model.
3847
"""
3948
mpnn = get_binary_classification_mpnn()
40-
chemprop_model = ChempropNeuralFP(model=mpnn, lightning_trainer__accelerator="cpu")
49+
init_kwargs = init_kwargs or {}
50+
chemprop_model = ChempropNeuralFP(
51+
model=mpnn, lightning_trainer__accelerator="cpu", **init_kwargs
52+
)
4153
return chemprop_model
4254

4355

test_extras/test_chemprop/test_neural_fingerprint.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
99
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
10-
11-
# pylint: disable=relative-beyond-top-level
1210
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
1311
from test_extras.test_chemprop.chemprop_test_utils.default_models import (
1412
get_neural_fp_encoder,
@@ -38,3 +36,10 @@ def test_output_type(self) -> None:
3836
"""Test the output type."""
3937
chemprop_fp_encoder = get_neural_fp_encoder()
4038
self.assertEqual(chemprop_fp_encoder.output_type, "float")
39+
40+
def test_init_with_kwargs(self) -> None:
41+
"""Test the __init__ method with kwargs."""
42+
init_kwargs = {"model__message_passing__depth": 4}
43+
chemprop_fp_encoder = get_neural_fp_encoder(init_kwargs=init_kwargs)
44+
deep_params = chemprop_fp_encoder.get_params(deep=True)
45+
self.assertEqual(deep_params["model__message_passing__depth"], 4)

0 commit comments

Comments
 (0)