Skip to content

Commit 90812e0

Browse files
authored
Fix joinMany (#56)
Instead of allocating an array of pointers, joinMany was allocating memory for just one pointer. This was making ArrayFire read out of bounds and fail with various errors. This commit fixes this issue by adding a helper withManyForeignPtr function that acts like withForeignPtr (not unsafeWithForeignPtr!), but for a list of ForeignPtrs.
1 parent 1e4f909 commit 90812e0

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

cabal.project

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
packages: .
12
ignore-project: False
23
write-ghc-environment-files: always
34
tests: True

src/ArrayFire/Data.hs

+13-12
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
module ArrayFire.Data where
3131

3232
import Control.Exception
33-
import Control.Monad
3433
import Data.Complex
3534
import Data.Int
3635
import Data.Proxy
3736
import Data.Word
3837
import Foreign.C.Types
3938
import Foreign.ForeignPtr
4039
import Foreign.Marshal hiding (void)
40+
import Foreign.Ptr (Ptr)
4141
import Foreign.Storable
4242
import System.IO.Unsafe
4343
import Unsafe.Coerce
@@ -357,20 +357,21 @@ joinMany
357357
:: Int
358358
-> [Array a]
359359
-> Array a
360-
joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do
361-
fptrs <- forM arrays $ \(Array fptr) -> pure fptr
362-
newPtr <-
363-
alloca $ \fPtrsPtr -> do
364-
forM_ fptrs $ \fptr ->
365-
withForeignPtr fptr (poke fPtrsPtr)
366-
alloca $ \aPtr -> do
367-
zeroOutArray aPtr
368-
throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr
369-
peek aPtr
360+
joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do
361+
newPtr <- alloca $ \aPtr -> do
362+
zeroOutArray aPtr
363+
(throwAFError =<<) $
364+
withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr ->
365+
af_join_many aPtr n nArrays fPtrsPtr
366+
peek aPtr
370367
Array <$>
371368
newForeignPtr af_release_array_finalizer newPtr
369+
370+
withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b
371+
withManyForeignPtr fptrs action = go [] fptrs
372372
where
373-
nArrays = fromIntegral (length arrays)
373+
go ptrs [] = withArrayLen (reverse ptrs) action
374+
go ptrs (fptr:others) = withForeignPtr fptr $ \ptr -> go (ptr : ptrs) others
374375

375376
-- | Tiles an Array according to specified dimensions
376377
--

test/ArrayFire/DataSpec.hs

+5
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,8 @@ spec =
3232
constant @(Complex Float) [1] (1.0 :+ 1.0)
3333
`shouldBe`
3434
constant @(Complex Float) [1] (1.0 :+ 1.0)
35+
it "Should join Arrays along the specified dimension" $ do
36+
join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
37+
join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
38+
joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
39+
joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]

0 commit comments

Comments
 (0)