File tree 2 files changed +21
-4
lines changed
2 files changed +21
-4
lines changed Original file line number Diff line number Diff line change 1
1
"""Functions for creating default chemprop models."""
2
2
3
+ from typing import Any
4
+
3
5
from molpipeline .estimators .chemprop import ChempropModel , ChempropNeuralFP
4
6
from molpipeline .estimators .chemprop .component_wrapper import (
5
7
MPNN ,
@@ -28,16 +30,26 @@ def get_binary_classification_mpnn() -> MPNN:
28
30
return mpnn
29
31
30
32
31
- def get_neural_fp_encoder () -> ChempropNeuralFP :
33
+ def get_neural_fp_encoder (
34
+ init_kwargs : dict [str , Any ] | None = None ,
35
+ ) -> ChempropNeuralFP :
32
36
"""Get the Chemprop model.
33
37
38
+ Parameters
39
+ ----------
40
+ init_kwargs : dict[str, Any], optional
41
+ Additional keyword arguments to pass to `ChempropNeuralFP` during initialization.
42
+
34
43
Returns
35
44
-------
36
45
ChempropNeuralFP
37
46
The Chemprop model.
38
47
"""
39
48
mpnn = get_binary_classification_mpnn ()
40
- chemprop_model = ChempropNeuralFP (model = mpnn , lightning_trainer__accelerator = "cpu" )
49
+ init_kwargs = init_kwargs or {}
50
+ chemprop_model = ChempropNeuralFP (
51
+ model = mpnn , lightning_trainer__accelerator = "cpu" , ** init_kwargs
52
+ )
41
53
return chemprop_model
42
54
43
55
Original file line number Diff line number Diff line change 7
7
8
8
from molpipeline .estimators .chemprop .neural_fingerprint import ChempropNeuralFP
9
9
from molpipeline .utils .json_operations import recursive_from_json , recursive_to_json
10
-
11
- # pylint: disable=relative-beyond-top-level
12
10
from test_extras .test_chemprop .chemprop_test_utils .compare_models import compare_params
13
11
from test_extras .test_chemprop .chemprop_test_utils .default_models import (
14
12
get_neural_fp_encoder ,
@@ -38,3 +36,10 @@ def test_output_type(self) -> None:
38
36
"""Test the output type."""
39
37
chemprop_fp_encoder = get_neural_fp_encoder ()
40
38
self .assertEqual (chemprop_fp_encoder .output_type , "float" )
39
+
40
+ def test_init_with_kwargs (self ) -> None :
41
+ """Test the __init__ method with kwargs."""
42
+ init_kwargs = {"model__message_passing__depth" : 4 }
43
+ chemprop_fp_encoder = get_neural_fp_encoder (init_kwargs = init_kwargs )
44
+ deep_params = chemprop_fp_encoder .get_params (deep = True )
45
+ self .assertEqual (deep_params ["model__message_passing__depth" ], 4 )
You can’t perform that action at this time.
0 commit comments