@@ -345,19 +345,24 @@ module Expand =
345
345
let upperBound =
346
346
ClArray.upperBoundAndValue clContext workGroupSize
347
347
348
+ let set = ClArray.set clContext workGroupSize
349
+
348
350
let subMatrix =
349
351
CSR.Matrix.subRows clContext workGroupSize
350
352
351
353
let runCOO =
352
354
runCOO opAdd opMul clContext workGroupSize
353
355
354
- fun ( processor : MailboxProcessor < _ >) allocationMode maxAllocSize ( leftMatrix : ClMatrix.CSR < 'a >) segmentLengths rightMatrixRowsNNZ ( rightMatrix : ClMatrix.CSR < 'b >) ->
356
+ fun ( processor : MailboxProcessor < _ >) allocationMode maxAllocSize generalLength ( leftMatrix : ClMatrix.CSR < 'a >) segmentLengths rightMatrixRowsNNZ ( rightMatrix : ClMatrix.CSR < 'b >) ->
355
357
// extract segment lengths by left matrix row pointers
356
358
let segmentPointersByLeftMatrixRows =
357
359
clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, leftMatrix.RowPointers.Length)
358
360
359
361
gather processor leftMatrix.RowPointers segmentLengths segmentPointersByLeftMatrixRows
360
362
363
+ // set last element to one step length
364
+ set processor segmentPointersByLeftMatrixRows ( leftMatrix.RowPointers.Length - 1 ) generalLength
365
+
361
366
// currying
362
367
let upperBound =
363
368
upperBound processor segmentPointersByLeftMatrixRows
@@ -422,10 +427,10 @@ module Expand =
422
427
let rightMatrixRowsNNZ =
423
428
getNNZInRows processor DeviceOnly rightMatrix
424
429
425
- let length , segmentLengths =
430
+ let generalLength , segmentLengths =
426
431
getSegmentPointers processor leftMatrix.Columns rightMatrixRowsNNZ
427
432
428
- if length < maxAllocSize then
433
+ if generalLength < maxAllocSize then
429
434
segmentLengths.Free processor
430
435
431
436
runOneStep processor allocationMode leftMatrix rightMatrixRowsNNZ rightMatrix
@@ -435,6 +440,7 @@ module Expand =
435
440
processor
436
441
allocationMode
437
442
maxAllocSize
443
+ generalLength
438
444
leftMatrix
439
445
segmentLengths
440
446
rightMatrixRowsNNZ
0 commit comments