@@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]):
75
75
raise NotImplementedError (f"{ dtype = } not yet supported." )
76
76
77
77
78
+ def get_exampe_csf_arrays (dtype : np .dtype ) -> tuple :
79
+ pos_1 = np .array ([0 , 1 , 3 ], dtype = np .int64 )
80
+ crd_1 = np .array ([1 , 0 , 1 ], dtype = np .int64 )
81
+ pos_2 = np .array ([0 , 3 , 5 , 7 ], dtype = np .int64 )
82
+ crd_2 = np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ], dtype = np .int64 )
83
+ data = np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ], dtype = dtype )
84
+ return pos_1 , crd_1 , pos_2 , crd_2 , data
85
+
86
+
78
87
@parametrize_dtypes
79
88
@pytest .mark .parametrize ("shape" , [(100 ,), (10 , 200 ), (5 , 10 , 20 )])
80
89
def test_dense_format (dtype , shape ):
@@ -176,11 +185,7 @@ def test_add(rng, dtype):
176
185
@parametrize_dtypes
177
186
def test_csf_format (dtype ):
178
187
SHAPE = (2 , 2 , 4 )
179
- pos_1 = np .array ([0 , 1 , 3 ], dtype = np .int64 )
180
- crd_1 = np .array ([1 , 0 , 1 ], dtype = np .int64 )
181
- pos_2 = np .array ([0 , 3 , 5 , 7 ], dtype = np .int64 )
182
- crd_2 = np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ], dtype = np .int64 )
183
- data = np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ], dtype = dtype )
188
+ pos_1 , crd_1 , pos_2 , crd_2 , data = get_exampe_csf_arrays (dtype )
184
189
csf = [pos_1 , crd_1 , pos_2 , crd_2 , data ]
185
190
186
191
csf_tensor = sparse .asarray (csf , shape = SHAPE , dtype = sparse .asdtype (dtype ), format = "csf" )
@@ -192,3 +197,70 @@ def test_csf_format(dtype):
192
197
csf_2 = [pos_1 , crd_1 , pos_2 , crd_2 , data * 2 ]
193
198
for actual , expected in zip (res_tensor , csf_2 , strict = False ):
194
199
np .testing .assert_array_equal (actual , expected )
200
+
201
+
202
+ @parametrize_dtypes
203
+ def test_reshape (rng , dtype ):
204
+ DENSITY = 0.5
205
+ sampler = generate_sampler (dtype , rng )
206
+
207
+ # CSR, CSC, COO
208
+ for shape , new_shape in [((100 , 50 ), (25 , 200 )), ((80 , 1 ), (8 , 10 ))]:
209
+ for format in ["csr" , "csc" , "coo" ]:
210
+ if format == "coo" :
211
+ # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
212
+ continue
213
+ if format == "csc" :
214
+ # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
215
+ continue
216
+
217
+ arr = sps .random_array (
218
+ shape , density = DENSITY , format = format , dtype = dtype , random_state = rng , data_sampler = sampler
219
+ )
220
+ if format == "coo" :
221
+ arr .sum_duplicates ()
222
+
223
+ tensor = sparse .asarray (arr )
224
+
225
+ actual = sparse .reshape (tensor , shape = new_shape ).to_scipy_sparse ()
226
+ expected = arr .todense ().reshape (new_shape )
227
+
228
+ np .testing .assert_array_equal (actual .todense (), expected )
229
+
230
+ # CSF
231
+ csf_shape = (2 , 2 , 4 )
232
+ for shape , new_shape , expected_arrs in [
233
+ (
234
+ csf_shape ,
235
+ (4 , 4 , 1 ),
236
+ [
237
+ np .array ([0 , 0 , 3 , 5 , 7 ]),
238
+ np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ]),
239
+ np .array ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
240
+ np .array ([0 , 0 , 0 , 0 , 0 , 0 , 0 ]),
241
+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
242
+ ],
243
+ ),
244
+ (
245
+ csf_shape ,
246
+ (2 , 1 , 8 ),
247
+ [
248
+ np .array ([0 , 1 , 2 ]),
249
+ np .array ([0 , 0 ]),
250
+ np .array ([0 , 3 , 7 ]),
251
+ np .array ([4 , 5 , 7 , 0 , 3 , 4 , 5 ]),
252
+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
253
+ ],
254
+ ),
255
+ ]:
256
+ csf = get_exampe_csf_arrays (dtype )
257
+ csf_tensor = sparse .asarray (csf , shape = shape , dtype = sparse .asdtype (dtype ), format = "csf" )
258
+
259
+ result = sparse .reshape (csf_tensor , shape = new_shape ).to_scipy_sparse ()
260
+
261
+ for actual , expected in zip (result , expected_arrs , strict = False ):
262
+ np .testing .assert_array_equal (actual , expected )
263
+
264
+ # DENSE
265
+ # NOTE: dense reshape is probably broken in MLIR
266
+ # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
0 commit comments