1
+ import pytest
1
2
import numpy as np
2
3
from CADETPythonSimulator .field import Field , FieldInterpolator
3
4
import matplotlib .pyplot as plt
4
5
# %% Testing utilities
5
6
7
+
6
8
def assert_equal (value , expected , message = "" ):
7
9
message = f"Test failed: { message } . Expected { expected } , got { value } ."
8
10
assert value == expected , message
@@ -26,47 +28,60 @@ def test_field_initialization():
26
28
assert_equal (viscosity .n_dof , 11 * 6 , "Scalar field degrees of freedom" )
27
29
28
30
# Vector field
29
- concentration = Field (name = "concentration" , dimensions = dimensions , n_components = 3 )
31
+ concentration = Field (name = "concentration" ,
32
+ dimensions = dimensions , n_components = 3 )
30
33
assert_shape (concentration .shape , (11 , 6 , 3 ), "Vector field shape" )
31
- assert_equal (concentration .n_dof , 11 * 6 * 3 , "Vector field degrees of freedom" )
34
+ assert_equal (concentration .n_dof , 11 * 6 * 3 ,
35
+ "Vector field degrees of freedom" )
32
36
33
37
# Custom data
34
38
data = np .ones ((11 , 6 , 3 ))
35
- concentration_with_data = Field (name = "concentration" , dimensions = dimensions , n_components = 3 , data = data )
36
- assert_shape (concentration_with_data .shape , (11 , 6 , 3 ), "Custom data field shape" )
39
+ concentration_with_data = Field (
40
+ name = "concentration" , dimensions = dimensions , n_components = 3 , data = data )
41
+ assert_shape (concentration_with_data .shape ,
42
+ (11 , 6 , 3 ), "Custom data field shape" )
43
+
44
+ with pytest .raises (ValueError ):
45
+ viscosity .data = np .ones ((1 , 2 , 3 ))
37
46
38
47
39
48
# %% Plotting
40
49
41
50
def test_plotting ():
42
51
# 1D Plot
43
52
dimensions = {"x" : np .linspace (0 , 10 , 11 )}
44
- field_1D = Field (name = "1D Field" , dimensions = dimensions , n_components = 2 , data = np .random .random ((11 , 2 )))
53
+ field_1D = Field (name = "1D Field" , dimensions = dimensions ,
54
+ n_components = 2 , data = np .random .random ((11 , 2 )))
45
55
fig , ax = field_1D .plot ()
46
56
assert isinstance (ax , plt .Axes ), "1D plot returns one axis"
47
57
48
58
# 2D Plot
49
59
dimensions = {"x" : np .linspace (0 , 10 , 11 ), "y" : np .linspace (0 , 5 , 6 )}
50
- field_2D = Field (name = "2D Field" , dimensions = dimensions , n_components = 3 , data = np .random .random ((11 , 6 , 3 )))
60
+ field_2D = Field (name = "2D Field" , dimensions = dimensions ,
61
+ n_components = 3 , data = np .random .random ((11 , 6 , 3 )))
51
62
fig , axes = field_2D .plot ()
52
63
assert len (axes ) == 3 , "2D plot returns one axis per component"
53
64
54
65
55
66
# %% Slicing
56
67
57
68
def test_field_slicing ():
58
- dimensions = {"axial" : np .linspace (0 , 10 , 11 ), "radial" : np .linspace (0 , 5 , 6 )}
69
+ dimensions = {"axial" : np .linspace (
70
+ 0 , 10 , 11 ), "radial" : np .linspace (0 , 5 , 6 )}
59
71
field = Field (name = "concentration" , dimensions = dimensions , n_components = 3 )
60
72
61
73
# Slice along one dimension
62
74
field_sliced = field [{"axial" : 0 }]
63
- assert_equal (len (field_sliced .dimensions ), 1 , "Field slicing reduces dimensionality" )
75
+ assert_equal (len (field_sliced .dimensions ), 1 ,
76
+ "Field slicing reduces dimensionality" )
64
77
assert_shape (field_sliced .shape , (6 , 3 ), "Field slicing shape" )
65
78
66
79
# Slice along all dimensions
67
80
field_sliced_all = field [{"axial" : 0 , "radial" : 0 }]
68
- assert_equal (len (field_sliced_all .dimensions ), 0 , "Full slicing removes all dimensions" )
69
- assert_shape (field_sliced_all .shape , (3 ,), "Full slicing results in vector" )
81
+ assert_equal (len (field_sliced_all .dimensions ), 0 ,
82
+ "Full slicing removes all dimensions" )
83
+ assert_shape (field_sliced_all .shape , (3 ,),
84
+ "Full slicing results in vector" )
70
85
71
86
72
87
# %% Normalization
@@ -89,14 +104,15 @@ def test_field_normalization():
89
104
90
105
# Test 2: Verify data normalization
91
106
normalized_data = normalized_field .data
92
- assert np .isclose (np .min (normalized_data ), 0.0 ), "Normalized data minimum is not 0."
93
- assert np .isclose (np .max (normalized_data ), 1.0 ), "Normalized data maximum is not 1."
107
+ assert np .isclose (np .min (normalized_data ),
108
+ 0.0 ), "Normalized data minimum is not 0."
109
+ assert np .isclose (np .max (normalized_data ),
110
+ 1.0 ), "Normalized data maximum is not 1."
94
111
95
112
# Test 3: Ensure original field is unchanged
96
113
assert np .array_equal (field .data , z ), "Original field data was modified."
97
114
98
115
99
-
100
116
# %% Interpolation and Resampling
101
117
102
118
def test_temperature_use_case ():
@@ -136,7 +152,8 @@ def test_interpolated_field():
136
152
"radial" : np .linspace (0 , 5 , 6 )
137
153
}
138
154
data = np .random .random ((11 , 6 , 3 ))
139
- concentration = Field (name = "concentration" , dimensions = dimensions , n_components = 3 , data = data )
155
+ concentration = Field (name = "concentration" ,
156
+ dimensions = dimensions , n_components = 3 , data = data )
140
157
141
158
# Interpolated field
142
159
interp_field = FieldInterpolator (concentration )
@@ -146,15 +163,17 @@ def test_interpolated_field():
146
163
147
164
def test_resampling ():
148
165
dimensions = {"x" : np .linspace (0 , 10 , 11 ), "y" : np .linspace (0 , 5 , 6 )}
149
- field = Field (name = "concentration" , dimensions = dimensions , n_components = 2 , data = np .random .random ((11 , 6 , 2 )))
166
+ field = Field (name = "concentration" , dimensions = dimensions ,
167
+ n_components = 2 , data = np .random .random ((11 , 6 , 2 )))
150
168
151
169
# Resample one dimension
152
170
resampled_field = field .resample ({"x" : 50 })
153
171
assert_shape (resampled_field .shape , (50 , 6 , 2 ), "Resampling one dimension" )
154
172
155
173
# Resample all dimensions
156
174
resampled_field_all = field .resample ({"x" : 50 , "y" : 25 })
157
- assert_shape (resampled_field_all .shape , (50 , 25 , 2 ), "Resampling all dimensions" )
175
+ assert_shape (resampled_field_all .shape , (50 , 25 , 2 ),
176
+ "Resampling all dimensions" )
158
177
159
178
160
179
def test_field_interpolation_and_derivatives ():
@@ -203,6 +222,5 @@ def test_field_interpolation_and_derivatives():
203
222
204
223
# %% Run tests
205
224
206
- import pytest
207
225
if __name__ == "__main__" :
208
226
pytest .main ('test_field.py' )
0 commit comments