Skip to content

Commit 4a70218

Browse files
authored
Merge pull request #90 from artemiipatov/msbfs
MSBFS
2 parents 12dd993 + a94fd07 commit 4a70218

File tree

32 files changed

+1817
-38
lines changed

32 files changed

+1817
-38
lines changed

benchmarks/GraphBLAS-sharp.Benchmarks/Algorithms/BFS.fs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,3 @@ type BFSWithTransferBenchmarkInt32() =
187187

188188
static member InputMatrixProvider =
189189
Benchmarks<_>.InputMatrixProviderBuilder "BFSBenchmarks.txt"
190-

benchmarks/GraphBLAS-sharp.Benchmarks/GraphBLAS-sharp.Benchmarks.fsproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project Sdk="Microsoft.NET.Sdk">
33
<PropertyGroup>
44
<OutputType>Exe</OutputType>

src/GraphBLAS-sharp.Backend/Algorithms/Algorithms.fs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,10 @@ module Algorithms =
1212

1313
let singleSourcePushPull = BFS.singleSourcePushPull
1414

15+
module MSBFS =
16+
let runLevels = MSBFS.Levels.run
17+
18+
let runParents = MSBFS.Parents.run
19+
1520
module SSSP =
16-
let singleSource = SSSP.run
21+
let run = SSSP.run
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
namespace GraphBLAS.FSharp.Backend.Algorithms
2+
3+
open Brahma.FSharp
4+
open FSharp.Quotations
5+
open GraphBLAS.FSharp
6+
open GraphBLAS.FSharp.Objects
7+
open GraphBLAS.FSharp.Common
8+
open GraphBLAS.FSharp.Objects.ClMatrix
9+
open GraphBLAS.FSharp.Objects.ArraysExtensions
10+
open GraphBLAS.FSharp.Objects.ClContextExtensions
11+
open GraphBLAS.FSharp.Objects.ClCellExtensions
12+
open GraphBLAS.FSharp.Backend.Quotes
13+
open GraphBLAS.FSharp.Backend.Matrix.LIL
14+
open GraphBLAS.FSharp.Backend.Matrix.COO
15+
16+
module internal MSBFS =
17+
let private frontExclude (clContext: ClContext) workGroupSize =
18+
19+
let invert =
20+
ClArray.mapInPlace ArithmeticOperations.intNotQ clContext workGroupSize
21+
22+
let prefixSum =
23+
PrefixSum.standardExcludeInPlace clContext workGroupSize
24+
25+
let scatterIndices =
26+
Scatter.lastOccurrence clContext workGroupSize
27+
28+
let scatterValues =
29+
Scatter.lastOccurrence clContext workGroupSize
30+
31+
fun (queue: MailboxProcessor<_>) allocationMode (front: ClMatrix.COO<_>) (intersection: ClArray<int>) ->
32+
33+
invert queue intersection
34+
35+
let length =
36+
(prefixSum queue intersection).ToHostAndFree queue
37+
38+
if length = 0 then
39+
None
40+
else
41+
let rows =
42+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, length)
43+
44+
let columns =
45+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, length)
46+
47+
let values =
48+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, length)
49+
50+
scatterIndices queue intersection front.Rows rows
51+
scatterIndices queue intersection front.Columns columns
52+
scatterValues queue intersection front.Values values
53+
54+
{ Context = clContext
55+
Rows = rows
56+
Columns = columns
57+
Values = values
58+
RowCount = front.RowCount
59+
ColumnCount = front.ColumnCount }
60+
|> Some
61+
62+
module Levels =
63+
let private updateFrontAndLevels (clContext: ClContext) workGroupSize =
64+
65+
let updateFront = frontExclude clContext workGroupSize
66+
67+
let mergeDisjoint =
68+
Matrix.mergeDisjoint clContext workGroupSize
69+
70+
let setLevel = ClArray.fill clContext workGroupSize
71+
72+
let findIntersection =
73+
Intersect.findKeysIntersection clContext workGroupSize
74+
75+
fun (queue: MailboxProcessor<_>) allocationMode (level: int) (front: ClMatrix.COO<_>) (levels: ClMatrix.COO<_>) ->
76+
77+
// Find intersection of levels and front indices.
78+
let intersection =
79+
findIntersection queue DeviceOnly front levels
80+
81+
// Remove mutual elements
82+
let newFront =
83+
updateFront queue allocationMode front intersection
84+
85+
intersection.Free queue
86+
87+
match newFront with
88+
| Some f ->
89+
let levelClCell = clContext.CreateClCell level
90+
91+
// Set current level value to all remaining front positions
92+
setLevel queue levelClCell 0 f.Values.Length f.Values
93+
94+
levelClCell.Free queue
95+
96+
// Update levels
97+
let newLevels = mergeDisjoint queue levels f
98+
99+
newLevels, newFront
100+
| _ -> levels, None
101+
102+
let run<'a when 'a: struct>
103+
(add: Expr<int -> int -> int option>)
104+
(mul: Expr<int -> 'a -> int option>)
105+
(clContext: ClContext)
106+
workGroupSize
107+
=
108+
109+
let spGeMM =
110+
Operations.SpGeMM.COO.expand add mul clContext workGroupSize
111+
112+
let copy = Matrix.copy clContext workGroupSize
113+
114+
let updateFrontAndLevels =
115+
updateFrontAndLevels clContext workGroupSize
116+
117+
fun (queue: MailboxProcessor<Msg>) (matrix: ClMatrix<'a>) (source: int list) ->
118+
let vertexCount = matrix.RowCount
119+
let sourceVertexCount = source.Length
120+
121+
let source = source |> List.sort
122+
123+
let startMatrix =
124+
source |> List.mapi (fun i vertex -> i, vertex, 1)
125+
126+
let mutable levels =
127+
startMatrix
128+
|> Matrix.ofList clContext DeviceOnly sourceVertexCount vertexCount
129+
130+
let mutable front = copy queue DeviceOnly levels
131+
132+
let mutable level = 1
133+
let mutable stop = false
134+
135+
while not stop do
136+
level <- level + 1
137+
138+
//Getting new frontier
139+
match spGeMM queue DeviceOnly (ClMatrix.COO front) matrix with
140+
| None ->
141+
front.Dispose queue
142+
stop <- true
143+
144+
| Some newFrontier ->
145+
front.Dispose queue
146+
147+
//Filtering visited vertices
148+
match updateFrontAndLevels queue DeviceOnly level newFrontier levels with
149+
| l, Some f ->
150+
front <- f
151+
152+
levels.Dispose queue
153+
154+
levels <- l
155+
156+
newFrontier.Dispose queue
157+
158+
| _, None ->
159+
stop <- true
160+
newFrontier.Dispose queue
161+
162+
ClMatrix.COO levels
163+
164+
module Parents =
165+
let private updateFrontAndParents (clContext: ClContext) workGroupSize =
166+
let frontExclude = frontExclude clContext workGroupSize
167+
168+
let mergeDisjoint =
169+
Matrix.mergeDisjoint clContext workGroupSize
170+
171+
let findIntersection =
172+
Intersect.findKeysIntersection clContext workGroupSize
173+
174+
let copyIndices = ClArray.copyTo clContext workGroupSize
175+
176+
fun (queue: MailboxProcessor<Msg>) allocationMode (front: ClMatrix.COO<_>) (parents: ClMatrix.COO<_>) ->
177+
178+
// Find intersection of levels and front indices.
179+
let intersection =
180+
findIntersection queue DeviceOnly front parents
181+
182+
// Remove mutual elements
183+
let newFront =
184+
frontExclude queue allocationMode front intersection
185+
186+
intersection.Free queue
187+
188+
match newFront with
189+
| Some f ->
190+
// Update parents
191+
let newParents = mergeDisjoint queue parents f
192+
193+
copyIndices queue f.Columns f.Values
194+
195+
newParents, Some f
196+
197+
| _ -> parents, None
198+
199+
let run<'a when 'a: struct> (clContext: ClContext) workGroupSize =
200+
201+
let spGeMM =
202+
Operations.SpGeMM.COO.expand
203+
(ArithmeticOperations.min)
204+
(ArithmeticOperations.fst)
205+
clContext
206+
workGroupSize
207+
208+
let updateFrontAndParents =
209+
updateFrontAndParents clContext workGroupSize
210+
211+
fun (queue: MailboxProcessor<Msg>) (inputMatrix: ClMatrix<'a>) (source: int list) ->
212+
let vertexCount = inputMatrix.RowCount
213+
let sourceVertexCount = source.Length
214+
215+
let source = source |> List.sort
216+
217+
let matrix =
218+
match inputMatrix with
219+
| ClMatrix.CSR m ->
220+
{ Context = clContext
221+
RowPointers = m.RowPointers
222+
Columns = m.Columns
223+
Values = m.Columns
224+
RowCount = m.RowCount
225+
ColumnCount = m.ColumnCount }
226+
|> ClMatrix.CSR
227+
| _ -> failwith "Incorrect format"
228+
229+
let mutable parents =
230+
source
231+
|> List.mapi (fun i vertex -> i, vertex, -1)
232+
|> Matrix.ofList clContext DeviceOnly sourceVertexCount vertexCount
233+
234+
let mutable front =
235+
source
236+
|> List.mapi (fun i vertex -> i, vertex, vertex)
237+
|> Matrix.ofList clContext DeviceOnly sourceVertexCount vertexCount
238+
239+
let mutable stop = false
240+
241+
while not stop do
242+
//Getting new frontier
243+
match spGeMM queue DeviceOnly (ClMatrix.COO front) matrix with
244+
| None ->
245+
front.Dispose queue
246+
stop <- true
247+
248+
| Some newFrontier ->
249+
front.Dispose queue
250+
251+
//Filtering visited vertices
252+
match updateFrontAndParents queue DeviceOnly newFrontier parents with
253+
| p, Some f ->
254+
front <- f
255+
256+
parents.Dispose queue
257+
parents <- p
258+
259+
newFrontier.Dispose queue
260+
261+
| _, None ->
262+
stop <- true
263+
newFrontier.Dispose queue
264+
265+
ClMatrix.COO parents

src/GraphBLAS-sharp.Backend/Algorithms/SSSP.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module SSSP =
1111
let run (clContext: ClContext) workGroupSize =
1212

1313
let less = ArithmeticOperations.less<int>
14-
let min = ArithmeticOperations.min<int>
14+
let min = ArithmeticOperations.minOption<int>
1515
let plus = ArithmeticOperations.intSumAsMul
1616

1717
let spMVInPlace =

0 commit comments

Comments
 (0)