@@ -22,69 +22,81 @@ def test__clip_out_of_bounds():
22
22
assert numpy .all (diff < 0.0001 )
23
23
24
24
25
- class Mixin_Transform (object ):
26
- known_input_dims = 1
27
- known_output_dims = 1
28
- known_is_separable = True
29
- known_has_inverse = True
30
25
31
- def test_input_dims (self ):
32
- assert hasattr (self .trans , 'input_dims' )
33
- assert self .trans .input_dims == self .known_input_dims
26
+ @pytest .fixture
27
+ def prob_trans ():
28
+ cls = transforms .ProbTransform
29
+ return cls (_minimal_norm )
34
30
35
- def test_output_dims (self ):
36
- assert hasattr (self .trans , 'output_dims' )
37
- assert self .trans .output_dims == self .known_output_dims
38
31
39
- def test_is_separable (self ):
40
- assert hasattr (self .trans , 'is_separable' )
41
- assert self .trans .is_separable == self .known_is_separable
32
+ @pytest .fixture
33
+ def quant_trans ():
34
+ cls = transforms .QuantileTransform
35
+ return cls (_minimal_norm )
42
36
43
- def test_has_inverse (self ):
44
- assert hasattr (self .trans , 'has_inverse' )
45
- assert self .trans .has_inverse == self .known_has_inverse
46
37
47
- def test_dist ( self ):
48
- assert hasattr ( self . trans , 'dist' )
49
- assert self . trans .dist == _minimal_norm
38
+ @ pytest . mark . parametrize ( 'trans' , [ prob_trans (), quant_trans ()])
39
+ def test_transform_input_dims ( trans ):
40
+ assert trans .input_dims == 1
50
41
51
- def test_transform_non_affine (self ):
52
- assert hasattr (self .trans , 'transform_non_affine' )
53
- diff = numpy .abs (self .trans .transform_non_affine ([0.5 ]) - self .known_tras_na )
54
- assert numpy .all (diff < 0.0001 )
55
42
56
- def test_inverted (self ):
57
- assert hasattr (self .trans , 'inverted' )
43
+ @pytest .mark .parametrize ('trans' , [prob_trans (), quant_trans ()])
44
+ def test_transform_output_dims (trans ):
45
+ assert trans .output_dims == 1
58
46
59
- def test_bad_non_pos (self ):
60
- with pytest .raises (ValueError ):
61
- self ._trans (_minimal_norm , nonpos = 'junk' )
62
47
63
- def test_non_pos_clip (self ):
64
- self ._trans (_minimal_norm , nonpos = 'clip' )
48
+ @pytest .mark .parametrize ('trans' , [prob_trans (), quant_trans ()])
49
+ def test_transform_is_separable (trans ):
50
+ assert trans .is_separable
51
+
65
52
53
+ @pytest .mark .parametrize ('trans' , [prob_trans (), quant_trans ()])
54
+ def test_transform_has_inverse (trans ):
55
+ assert trans .has_inverse
66
56
67
- class Test_ProbTransform (Mixin_Transform ):
68
- def setup (self ):
69
- self ._trans = transforms .ProbTransform
70
- self .trans = transforms .ProbTransform (_minimal_norm )
71
- self .known_tras_na = [- 2.569150498 ]
72
57
73
- def test_inverted (self ):
74
- inv_trans = self .trans .inverted ()
75
- assert self .trans .dist == inv_trans .dist
76
- assert self .trans .factor == inv_trans .factor
77
- assert self .trans .nonpos == inv_trans .nonpos
58
+ @pytest .mark .parametrize ('trans' , [prob_trans (), quant_trans ()])
59
+ def test_transform_dist (trans ):
60
+ trans .dist == _minimal_norm
78
61
79
62
80
- class Test_QuantileTransform (Mixin_Transform ):
81
- def setup (self ):
82
- self ._trans = transforms .QuantileTransform
83
- self .trans = transforms .QuantileTransform (_minimal_norm )
84
- self .known_tras_na = [69.1464492 ]
63
+ @pytest .mark .parametrize (('trans' , 'known_trans_na' ), [
64
+ (prob_trans (), - 2.569150498 ), (quant_trans (), 69.1464492 )
65
+ ])
66
+ def test_transform_non_affine (trans , known_trans_na ):
67
+ diff = numpy .abs (trans .transform_non_affine ([0.5 ]) - known_trans_na )
68
+ assert numpy .all (diff < 0.0001 )
85
69
86
- def test_inverted (self ):
87
- inv_trans = self .trans .inverted ()
88
- assert self .trans .dist == inv_trans .dist
89
- assert self .trans .factor == inv_trans .factor
90
- assert self .trans .nonpos == inv_trans .nonpos
70
+
71
+ @pytest .mark .parametrize (('trans' , 'inver_cls' ), [
72
+ (prob_trans (), transforms .QuantileTransform ),
73
+ (quant_trans (), transforms .ProbTransform ),
74
+ ])
75
+ def test_transform_inverted (trans , inver_cls ):
76
+ t_inv = trans .inverted ()
77
+ assert isinstance (t_inv , inver_cls )
78
+ assert trans .dist == t_inv .dist
79
+ assert trans .as_pct == t_inv .as_pct
80
+ assert trans .out_of_bounds == t_inv .out_of_bounds
81
+
82
+
83
+ @pytest .mark .parametrize ('cls' , [transforms .ProbTransform , transforms .QuantileTransform ])
84
+ def test_bad_out_of_bounds (cls ):
85
+ with pytest .raises (ValueError ):
86
+ cls (_minimal_norm , out_of_bounds = 'junk' )
87
+
88
+
89
+ @pytest .mark .parametrize ('cls' , [transforms .ProbTransform , transforms .QuantileTransform ])
90
+ @pytest .mark .parametrize (('method' , 'func' ), [
91
+ ('clip' , transforms ._clip_out_of_bounds ),
92
+ ('mask' , transforms ._mask_out_of_bounds ),
93
+ ('junk' , None ),
94
+ ])
95
+ def test_out_of_bounds (cls , method , func ):
96
+ if func is None :
97
+ with pytest .raises (ValueError ):
98
+ cls (_minimal_norm , out_of_bounds = method )
99
+ else :
100
+ t = cls (_minimal_norm , out_of_bounds = method )
101
+ assert t .out_of_bounds == method
102
+ assert t ._handle_out_of_bounds == func
0 commit comments