Skip to content

Commit c61feb3

Browse files
Magalameocramz
authored andcommitted
Add Fast modules to DLA (#59)
* Added a Fast module to DLA, modified benchmarks accordingly
1 parent 28c00d9 commit c61feb3

File tree

6 files changed

+134
-7
lines changed

6 files changed

+134
-7
lines changed

dense-linear-algebra/bench/ChronosBench.hs

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module Main where
22

33
import qualified Statistics.Matrix as M
4+
import qualified Statistics.Matrix.Fast as F
5+
import qualified Statistics.Matrix.Fast.Algorithms as FA
46
import Statistics.Matrix (Matrix (..))
57
import qualified Statistics.Matrix.Algorithms as A
68

@@ -31,9 +33,11 @@ runtimelight v a = do
3133
v2 = U.take n v
3234

3335
C.defaultMainWith (C.defaultConfig {C.timeout = Just 3}) [
34-
C.bench "norm" M.norm v,
36+
C.bench "norm" M.norm v2,
37+
C.bench "Fast.norm" F.norm v2,
3538

3639
C.bench "multiplyV" (M.multiplyV a) (v2),
40+
C.bench "Fast.multiplyV" (F.multiplyV a) (v2),
3741

3842
C.bench "transpose" M.transpose a ,
3943
C.bench "ident" M.ident n,
@@ -45,7 +49,9 @@ runtimeheavy a b = do
4549

4650
C.defaultMainWith (C.defaultConfig {C.timeout = Just 1}) [
4751
C.bench "multiply" (M.multiply a) b,
48-
C.bench "qr" A.qr a
52+
C.bench "Fast.multiply" (F.multiply a) b,
53+
C.bench "qr" A.qr a,
54+
C.bench "Fast.qr" FA.qr a
4955
]
5056

5157

@@ -60,4 +66,4 @@ main = do
6066
-- we split heavy and light, we lose some precision in the bar plots from chronos
6167
runtimelight v a
6268
putStrLn "---Benchmarking heavy operations---"
63-
runtimeheavy a b
69+
runtimeheavy a b

dense-linear-algebra/bench/WeighBench.hs

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module Main where
22

33
import qualified Statistics.Matrix as M
4+
import qualified Statistics.Matrix.Fast as F
5+
import qualified Statistics.Matrix.Fast.Algorithms as FA
46
import Statistics.Matrix (Matrix (..))
57
import qualified Statistics.Matrix.Algorithms as A
68

@@ -32,14 +34,21 @@ weight v a b = do
3234
v2 = U.take n v
3335
W.mainWith (do
3436
W.func "norm" M.norm v2
37+
W.func "Fast.norm" F.norm v2
3538

36-
W.func "multiply" (M.multiply a) b
3739
W.func "multiplyV" (M.multiplyV a) (v2)
38-
W.func "qr" A.qr a
39-
40+
W.func "Fast.multiplyV" (F.multiplyV a) (v2)
4041
W.func "transpose" M.transpose a
4142
W.func "ident" M.ident n
42-
W.func "diag" M.diag v2)
43+
W.func "diag" M.diag v2
44+
45+
W.func "multiply" (M.multiply a) b
46+
W.func "Fast.multiply" (F.multiply a) b
47+
48+
W.func "qr" A.qr a
49+
W.func "Fast.qr" FA.qr a
50+
51+
)
4352

4453

4554

dense-linear-algebra/dense-linear-algebra.cabal

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ library
2626
Statistics.Matrix.Function
2727
Statistics.Matrix.Mutable
2828
Statistics.Matrix.Types
29+
Statistics.Matrix.Fast
30+
Statistics.Matrix.Fast.Algorithms
2931
build-depends: base >= 4.5 && < 5
3032
, deepseq >= 1.1.0.2
3133
, math-functions >= 0.1.7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{-# LANGUAGE BangPatterns #-}
2+
3+
module Statistics.Matrix.Fast (
4+
multiply,
5+
norm,
6+
multiplyV
7+
) where
8+
9+
import Prelude hiding (exponent, map)
10+
import Control.Monad.ST
11+
import qualified Data.Vector.Unboxed as U
12+
13+
14+
import Statistics.Matrix (row)
15+
import Statistics.Matrix.Function
16+
import Statistics.Matrix.Types
17+
import Statistics.Matrix.Mutable (unsafeNew,unsafeWrite,unsafeFreeze)
18+
19+
-- | Matrix-matrix multiplication in a more imperative fashion. Matrices must be of compatible
20+
-- sizes (/note: not checked/). Faster but less accurate than Statistics.Matrix.multiply
21+
multiply :: Matrix -> Matrix -> Matrix
22+
multiply m1@(Matrix r1 _ _) m2@(Matrix _ c2 _) = runST $ do
23+
m3 <- unsafeNew r1 c2
24+
for 0 c2 $ \j -> do
25+
for 0 r1 $ \i -> do
26+
let
27+
z = accum i m1 j m2
28+
unsafeWrite m3 i j z
29+
unsafeFreeze m3
30+
31+
accum :: Int -> Matrix -> Int -> Matrix -> Double
32+
accum ithrow (Matrix r1 c1 v1) jthcol (Matrix _ c2 v2) = sub 0 0
33+
where sub !acc !ij | ij == r1 = acc
34+
| otherwise = sub ( valRow*valCol + acc ) (ij+1)
35+
where
36+
valRow = U.unsafeIndex v1 (ithrow*c1 + ij)
37+
valCol = U.unsafeIndex v2 (ij*c2+jthcol)
38+
39+
-- | Matrix-vector multiplication, with better performances but not as accurate as
40+
-- Statistics.Matrix.multiplyV
41+
multiplyV :: Matrix -> Vector -> Vector
42+
multiplyV m v
43+
| cols m == c = U.generate (rows m) (U.sum . U.zipWith (*) v . row m)
44+
| otherwise = error $ "matrix/vector unconformable " ++ show (cols m,c)
45+
where c = U.length v
46+
47+
-- | Norm of a vector. Faster but less accurate than Statistics.Matrix.norm
48+
norm :: Vector -> Double
49+
norm = sqrt . U.sum . U.map square
50+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
{-# LANGUAGE BangPatterns #-}
2+
3+
-- |
4+
-- Module : Statistics.Matrix.Fast.Algorithms
5+
-- Copyright : 2019 Magalame
6+
-- License : BSD3
7+
--
8+
-- Useful matrix functions.
9+
10+
module Statistics.Matrix.Fast.Algorithms
11+
(
12+
qr
13+
) where
14+
15+
import Control.Applicative ((<$>), (<*>))
16+
import Control.Monad.ST (ST, runST)
17+
import Prelude hiding (replicate)
18+
import Statistics.Matrix (Matrix (..),dimension, for)
19+
import qualified Statistics.Matrix.Mutable as M
20+
import qualified Data.Vector.Unboxed as U
21+
22+
-- | /O(r*c)/ Compute the QR decomposition of a matrix.
23+
-- The result returned is the matrices (/q/,/r/).
24+
qr :: Matrix -> (Matrix, Matrix)
25+
qr mat = runST $ do
26+
let (m,n) = dimension mat
27+
28+
r <- M.replicate n n 0
29+
a <- M.thaw mat
30+
for 0 n $ \j -> do
31+
cn <- M.immutably a $ \aa -> sqrt $ normCol j aa
32+
M.unsafeWrite r j j cn
33+
for 0 m $ \i -> M.unsafeModify a i j (/ cn)
34+
for (j+1) n $ \jj -> do
35+
p <- innerProduct a j jj
36+
M.unsafeWrite r j jj p
37+
for 0 m $ \i -> do
38+
aij <- M.unsafeRead a i j
39+
M.unsafeModify a i jj $ subtract (p * aij)
40+
(,) <$> M.unsafeFreeze a <*> M.unsafeFreeze r
41+
42+
normCol :: Int -> Matrix -> Double
43+
normCol jthcol (Matrix r c v) = sub 0 0
44+
where sub !acc !ij | ij == r = acc
45+
| otherwise = sub ( valCol*valCol + acc ) (ij+1)
46+
where
47+
valCol = U.unsafeIndex v (ij*c+jthcol)
48+
49+
innerProduct :: M.MMatrix s -> Int -> Int -> ST s Double
50+
innerProduct mmat j k = M.immutably mmat $ \mat ->
51+
dotCol j mat k mat
52+
53+
dotCol :: Int -> Matrix -> Int -> Matrix -> Double
54+
dotCol jthcol (Matrix r1 c1 v1) kthcol (Matrix _ c2 v2) = sub 0 0
55+
where sub !acc !ij | ij == r1 = acc
56+
| otherwise = sub ( valColj*valColk + acc ) (ij+1)
57+
where
58+
valColk = U.unsafeIndex v2 (ij*c2+kthcol)
59+
valColj = U.unsafeIndex v1 (ij*c1+jthcol)

dense-linear-algebra/stack.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extra-deps:
99
- torsor-0.1
1010
- chronos-1.0.5
1111
- chronos-bench-0.2.0.2
12+
- primitive-0.6.4.0
1213

1314
# Override default flag values for local packages and extra-deps
1415
# flags: {}

0 commit comments

Comments
 (0)