-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest_neural_fingerprint.py
40 lines (30 loc) · 1.51 KB
/
test_neural_fingerprint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Test Chemprop neural fingerprint."""
import logging
import unittest
from sklearn.base import clone
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
# pylint: disable=relative-beyond-top-level
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
from test_extras.test_chemprop.chemprop_test_utils.default_models import (
get_neural_fp_encoder,
)
# Raise the threshold of Lightning's rank-zero logger to WARNING so that
# lower-severity records from that logger are dropped while the tests run.
_rank_zero_logger = logging.getLogger("lightning.pytorch.utilities.rank_zero")
_rank_zero_logger.setLevel(logging.WARNING)
class TestChempropNeuralFingerprint(unittest.TestCase):
    """Unit tests for the Chemprop neural fingerprint encoder."""

    def test_clone(self) -> None:
        """Check that sklearn's ``clone`` produces a matching ChempropNeuralFP."""
        original = get_neural_fp_encoder()
        duplicate = clone(original)
        # The clone must keep its concrete type, not only its parameters.
        self.assertIsInstance(duplicate, ChempropNeuralFP)
        compare_params(self, original, duplicate)

    def test_json_serialization(self) -> None:
        """Check that a to-JSON / from-JSON round trip keeps the parameters."""
        original = get_neural_fp_encoder()
        restored = recursive_from_json(recursive_to_json(original))
        compare_params(self, original, restored)

    def test_output_type(self) -> None:
        """Check that the encoder reports ``"float"`` as its output type."""
        self.assertEqual(get_neural_fp_encoder().output_type, "float")