@@ -200,3 +200,86 @@ func TestCOOTranspose(t *testing.T) {
200
200
}
201
201
}
202
202
}
203
+
204
+ func TestPermutations (t * testing.T ) {
205
+ for _ , test := range []struct {
206
+ m * COO
207
+ permutation []int
208
+ want * mat.Dense
209
+ desc string
210
+ permuteCol bool
211
+ }{
212
+ {
213
+ m : NewCOO (2 , 3 , []int {0 , 0 , 1 }, []int {0 , 1 , 0 }, []float64 {1.0 , 2.0 , 3.0 }),
214
+ permutation : []int {1 , 0 },
215
+ want : mat .NewDense (2 , 3 , []float64 {3.0 , 0.0 , 0.0 , 1.0 , 2.0 , 0.0 }),
216
+ desc : "2x3 matrix row permutation" ,
217
+ permuteCol : false ,
218
+ },
219
+ {
220
+ m : NewCOO (2 , 3 , []int {0 , 0 , 1 }, []int {0 , 1 , 0 }, []float64 {1.0 , 2.0 , 3.0 }),
221
+ permutation : []int {1 , 0 , 2 },
222
+ want : mat .NewDense (2 , 3 , []float64 {2.0 , 1.0 , 0.0 , 0.0 , 3.0 , 0.0 }),
223
+ desc : "2x3 matrix col permutation" ,
224
+ permuteCol : true ,
225
+ },
226
+ {
227
+ m : NewCOO (1 , 1 , []int {0 }, []int {0 }, []float64 {1.0 }),
228
+ permutation : []int {0 },
229
+ want : mat .NewDense (1 , 1 , []float64 {1.0 }),
230
+ desc : "1x1 matrix row permutation" ,
231
+ permuteCol : false ,
232
+ },
233
+ } {
234
+ if test .permuteCol {
235
+ test .m .PermuteCols (test .permutation )
236
+ } else {
237
+ test .m .PermuteRows (test .permutation )
238
+ }
239
+
240
+ result := test .m .ToDense ()
241
+
242
+ if ! mat .Equal (result , test .want ) {
243
+ t .Errorf ("Test: %s: Expected\n %v\n got%v\n " , test .desc , test .want , result )
244
+ }
245
+
246
+ }
247
+ }
248
+
249
+ func TestIsPermutation (t * testing.T ) {
250
+ for _ , test := range []struct {
251
+ N int
252
+ permutation []int
253
+ valid bool
254
+ desc string
255
+ }{
256
+ {
257
+ N : 2 ,
258
+ permutation : []int {0 , 1 },
259
+ valid : true ,
260
+ desc : "Valid 2xN matrix permutation" ,
261
+ },
262
+ {
263
+ N : 3 ,
264
+ permutation : []int {0 , 1 },
265
+ valid : false ,
266
+ desc : "Invalid: missing 2" ,
267
+ },
268
+ {
269
+ N : 3 ,
270
+ permutation : []int {0 , 0 , 1 },
271
+ valid : false ,
272
+ desc : "Invalid: duplicates" ,
273
+ },
274
+ {
275
+ N : 3 ,
276
+ permutation : []int {1 , 0 , 4 },
277
+ valid : false ,
278
+ desc : "Invalid: value out of bounds" ,
279
+ },
280
+ } {
281
+ if result := isPermutation (test .permutation , test .N ); result != test .valid {
282
+ t .Errorf ("Test %s: expected %v got %v\n " , test .desc , test .valid , result )
283
+ }
284
+ }
285
+ }
0 commit comments