1
+ {-# LANGUAGE BangPatterns #-}
2
+
3
+ -- |
4
+ -- Module : Statistics.Matrix.Fast.Algorithms
5
+ -- Copyright : 2019 Magalame
6
+ -- License : BSD3
7
+ --
8
+ -- Useful matrix functions.
9
+
10
+ module Statistics.Matrix.Fast.Algorithms
11
+ (
12
+ qr
13
+ ) where
14
+
15
+ import Control.Applicative ((<$>) , (<*>) )
16
+ import Control.Monad.ST (ST , runST )
17
+ import Prelude hiding (replicate )
18
+ import Statistics.Matrix (Matrix (.. ),dimension , for )
19
+ import qualified Statistics.Matrix.Mutable as M
20
+ import qualified Data.Vector.Unboxed as U
21
+
22
+ -- | /O(r*c)/ Compute the QR decomposition of a matrix.
23
+ -- The result returned is the matrices (/q/,/r/).
24
+ qr :: Matrix -> (Matrix , Matrix )
25
+ qr mat = runST $ do
26
+ let (m,n) = dimension mat
27
+
28
+ r <- M. replicate n n 0
29
+ a <- M. thaw mat
30
+ for 0 n $ \ j -> do
31
+ cn <- M. immutably a $ \ aa -> sqrt $ normCol j aa
32
+ M. unsafeWrite r j j cn
33
+ for 0 m $ \ i -> M. unsafeModify a i j (/ cn)
34
+ for (j+ 1 ) n $ \ jj -> do
35
+ p <- innerProduct a j jj
36
+ M. unsafeWrite r j jj p
37
+ for 0 m $ \ i -> do
38
+ aij <- M. unsafeRead a i j
39
+ M. unsafeModify a i jj $ subtract (p * aij)
40
+ (,) <$> M. unsafeFreeze a <*> M. unsafeFreeze r
41
+
42
+ normCol :: Int -> Matrix -> Double
43
+ normCol jthcol (Matrix r c v) = sub 0 0
44
+ where sub ! acc ! ij | ij == r = acc
45
+ | otherwise = sub ( valCol* valCol + acc ) (ij+ 1 )
46
+ where
47
+ valCol = U. unsafeIndex v (ij* c+ jthcol)
48
+
49
+ innerProduct :: M. MMatrix s -> Int -> Int -> ST s Double
50
+ innerProduct mmat j k = M. immutably mmat $ \ mat ->
51
+ dotCol j mat k mat
52
+
53
+ dotCol :: Int -> Matrix -> Int -> Matrix -> Double
54
+ dotCol jthcol (Matrix r1 c1 v1) kthcol (Matrix _ c2 v2) = sub 0 0
55
+ where sub ! acc ! ij | ij == r1 = acc
56
+ | otherwise = sub ( valColj* valColk + acc ) (ij+ 1 )
57
+ where
58
+ valColk = U. unsafeIndex v2 (ij* c2+ kthcol)
59
+ valColj = U. unsafeIndex v1 (ij* c1+ jthcol)
0 commit comments