Skip to content

Commit ab0bf5d

Browse files
committed
feat: add efficient permutation operations for COO
1 parent 1e6c7dd commit ab0bf5d

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

coordinate.go

+46
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,52 @@ func (c *COO) ToCOO() *COO {
154154
return c
155155
}
156156

157+
// a []int is a row permutation if all row numbers are present once
158+
func isPermutation(permutation []int, N int) bool {
159+
if len(permutation) != N {
160+
return false
161+
}
162+
163+
distinctValues := make(map[int]bool)
164+
for _, v := range permutation {
165+
if v < N && v >= 0 {
166+
distinctValues[v] = true
167+
}
168+
}
169+
return len(distinctValues) == N
170+
}
171+
172+
// PermuteRows swaps rows in the matrix given by the permutations.
173+
// If the provided array is not a valid permutation, it panics.
174+
// A valid permutation contains all rows exactly once
175+
//
176+
// Example:
177+
// For a 3x2 matrix a permutation that swaps row 1 and 3 is
178+
// {2, 1, 0}
179+
func (c *COO) PermuteRows(permutations []int) {
180+
if !isPermutation(permutations, c.r) {
181+
panic("invalid permutation (some rows are missing)")
182+
}
183+
mapSlice(c.rows, permutations)
184+
}
185+
186+
// PermuteCols swaps columns int the matrix given by permutations
187+
// If the provided array is not a valid permutation, it panics
188+
// See PermuteRows for a detailed explination of valid permutations
189+
// (works the same except "row" are replaced by "col")
190+
func (c *COO) PermuteCols(permutations []int) {
191+
if !isPermutation(permutations, c.c) {
192+
panic("invalid permutation (some cols missing)")
193+
}
194+
mapSlice(c.cols, permutations)
195+
}
196+
197+
func mapSlice(data []int, permutation []int) {
198+
for i := range data {
199+
data[i] = permutation[data[i]]
200+
}
201+
}
202+
157203
func cumsum(p []int, c []int, n int) int {
158204
nz := 0
159205
for i := 0; i < n; i++ {

coordinate_test.go

+83
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,86 @@ func TestCOOTranspose(t *testing.T) {
200200
}
201201
}
202202
}
203+
204+
func TestPermutations(t *testing.T) {
205+
for _, test := range []struct {
206+
m *COO
207+
permutation []int
208+
want *mat.Dense
209+
desc string
210+
permuteCol bool
211+
}{
212+
{
213+
m: NewCOO(2, 3, []int{0, 0, 1}, []int{0, 1, 0}, []float64{1.0, 2.0, 3.0}),
214+
permutation: []int{1, 0},
215+
want: mat.NewDense(2, 3, []float64{3.0, 0.0, 0.0, 1.0, 2.0, 0.0}),
216+
desc: "2x3 matrix row permutation",
217+
permuteCol: false,
218+
},
219+
{
220+
m: NewCOO(2, 3, []int{0, 0, 1}, []int{0, 1, 0}, []float64{1.0, 2.0, 3.0}),
221+
permutation: []int{1, 0, 2},
222+
want: mat.NewDense(2, 3, []float64{2.0, 1.0, 0.0, 0.0, 3.0, 0.0}),
223+
desc: "2x3 matrix col permutation",
224+
permuteCol: true,
225+
},
226+
{
227+
m: NewCOO(1, 1, []int{0}, []int{0}, []float64{1.0}),
228+
permutation: []int{0},
229+
want: mat.NewDense(1, 1, []float64{1.0}),
230+
desc: "1x1 matrix row permutation",
231+
permuteCol: false,
232+
},
233+
} {
234+
if test.permuteCol {
235+
test.m.PermuteCols(test.permutation)
236+
} else {
237+
test.m.PermuteRows(test.permutation)
238+
}
239+
240+
result := test.m.ToDense()
241+
242+
if !mat.Equal(result, test.want) {
243+
t.Errorf("Test: %s: Expected\n%v\ngot%v\n", test.desc, test.want, result)
244+
}
245+
246+
}
247+
}
248+
249+
func TestIsPermutation(t *testing.T) {
250+
for _, test := range []struct {
251+
N int
252+
permutation []int
253+
valid bool
254+
desc string
255+
}{
256+
{
257+
N: 2,
258+
permutation: []int{0, 1},
259+
valid: true,
260+
desc: "Valid 2xN matrix permutation",
261+
},
262+
{
263+
N: 3,
264+
permutation: []int{0, 1},
265+
valid: false,
266+
desc: "Invalid: missing 2",
267+
},
268+
{
269+
N: 3,
270+
permutation: []int{0, 0, 1},
271+
valid: false,
272+
desc: "Invalid: duplicates",
273+
},
274+
{
275+
N: 3,
276+
permutation: []int{1, 0, 4},
277+
valid: false,
278+
desc: "Invalid: value out of bounds",
279+
},
280+
} {
281+
if result := isPermutation(test.permutation, test.N); result != test.valid {
282+
t.Errorf("Test %s: expected %v got %v\n", test.desc, test.valid, result)
283+
}
284+
}
285+
}

0 commit comments

Comments
 (0)