Skip to content

Fix (**) and use property tests for Floating #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arrayfire.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ test-suite test
base < 5,
directory,
hspec,
HUnit,
QuickCheck,
quickcheck-classes,
vector
Expand Down
4 changes: 4 additions & 0 deletions src/ArrayFire/Orphans.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where
pi = A.scalar @a 3.14159
exp = A.exp @a
log = A.log @a
sqrt = A.sqrt @a
(**) = A.pow @a
sin = A.sin @a
cos = A.cos @a
tan = A.tan @a
tanh = A.tanh @a
asin = A.asin @a
acos = A.acos @a
atan = A.atan @a
Expand Down
165 changes: 117 additions & 48 deletions test/ArrayFire/ArithSpec.hs
Original file line number Diff line number Diff line change
@@ -1,99 +1,168 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module ArrayFire.ArithSpec where

import ArrayFire hiding (acos)
import Prelude hiding (abs, sqrt, div, and, or, not, isNaN)
import Test.Hspec
import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector)
import qualified ArrayFire
import Control.Exception (throwIO)
import Control.Monad (unless, when)
import Foreign.C
import GHC.Exts (IsList (..))
import GHC.Stack
import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..))
import Test.Hspec
import Test.Hspec.QuickCheck
import Prelude hiding (div)

compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation
compareWith comparator result expected =
unless (comparator result expected) $ do
throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg)
where
expectedMsg = show expected
actualMsg = show result
location = case reverse (toList callStack) of
(_, loc) : _ -> Just loc
[] -> Nothing

class (Num a) => HasEpsilon a where
eps :: a

instance HasEpsilon Float where
eps = 1.1920929e-7

instance HasEpsilon Double where
eps = 2.220446049250313e-16

approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool
approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b))

approx :: (Ord a, HasEpsilon a) => a -> a -> Bool
approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b

shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation
shouldBeApprox = compareWith approx

evalf :: (AFType a) => Array a -> a
evalf = ArrayFire.getScalar

shouldMatchBuiltin ::
(AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
(Array a -> Array a) ->
(a -> a) ->
a ->
Expectation
shouldMatchBuiltin f f' x
| isInfinite y && isInfinite y' = pure ()
| Prelude.isNaN y && Prelude.isNaN y' = pure ()
| otherwise = y `shouldBeApprox` y'
where
y = evalf (f (scalar x))
y' = f' x

shouldMatchBuiltin2 ::
(AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
(Array a -> Array a -> Array a) ->
(a -> a -> a) ->
a ->
a ->
Expectation
shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a)

spec :: Spec
spec =
describe "Arith tests" $ do
it "Should negate scalar value" $ do
negate (scalar @Int 1) `shouldBe` (-1)
it "Should negate a vector" $ do
negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2]
negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2]
it "Should add two scalar arrays" $ do
scalar @Int 1 + 2 `shouldBe` 3
it "Should add two scalar bool arrays" $ do
scalar @CBool 1 + 0 `shouldBe` 1
it "Should subtract two scalar arrays" $ do
scalar @Int 4 - 2 `shouldBe` 2
it "Should multiply two scalar arrays" $ do
scalar @Double 4 `mul` 2 `shouldBe` 8
scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8
it "Should divide two scalar arrays" $ do
div @Double 8 2 `shouldBe` 4
ArrayFire.div @Double 8 2 `shouldBe` 4
it "Should add two matrices" $ do
matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]]
`shouldBe`
matrix @Int (2,2) [[2,2],[2,2]]
-- Exact comparisons of Double don't make sense here, so we just check that the result is
-- accurate up to some epsilon.
it "Should take cubed root" $ do
allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0)
it "Should take square root" $ do
allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0)
matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]]
`shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]]
prop "Should take cubed root" $ \(x :: Double) ->
evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x

it "Should lte Array" $ do
2 `le` (3 :: Array Double) `shouldBe` 1
2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
it "Should gte Array" $ do
2 `ge` (3 :: Array Double) `shouldBe` 0
2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0
it "Should gt Array" $ do
2 `gt` (3 :: Array Double) `shouldBe` 0
2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0
it "Should lt Array" $ do
2 `le` (3 :: Array Double) `shouldBe` 1
2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
it "Should eq Array" $ do
3 == (3 :: Array Double) `shouldBe` True
it "Should and Array" $ do
(mkArray @CBool [1] [0] `and` mkArray [1] [1])
`shouldBe` mkArray [1] [0]
(mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1])
`shouldBe` mkArray [1] [0]
it "Should and Array" $ do
(mkArray @CBool [2] [0,0] `and` mkArray [2] [1,0])
`shouldBe` mkArray [2] [0, 0]
(mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0])
`shouldBe` mkArray [2] [0, 0]
it "Should or Array" $ do
(mkArray @CBool [2] [0,0] `or` mkArray [2] [1,0])
`shouldBe` mkArray [2] [1, 0]
(mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0])
`shouldBe` mkArray [2] [1, 0]
it "Should not Array" $ do
not (mkArray @CBool [2] [1,0]) `shouldBe` mkArray [2] [0,1]
ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1]
it "Should bitwise and array" $ do
bitAnd (scalar @Int 1) (scalar @Int 0)
`shouldBe`
0
ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0)
`shouldBe` 0
it "Should bitwise or array" $ do
bitOr (scalar @Int 1) (scalar @Int 0)
`shouldBe`
1
ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0)
`shouldBe` 1
it "Should bitwise xor array" $ do
bitXor (scalar @Int 1) (scalar @Int 1)
`shouldBe`
0
ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1)
`shouldBe` 0
it "Should bitwise shift left an array" $ do
bitShiftL (scalar @Int 1) (scalar @Int 3)
`shouldBe`
8
ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3)
`shouldBe` 8
it "Should cast an array" $ do
getType (cast (scalar @Int 1) :: Array Double)
`shouldBe`
F64
`shouldBe` ArrayFire.F64
it "Should find the minimum of two arrays" $ do
minOf (scalar @Int 1) (scalar @Int 0)
`shouldBe`
0
`shouldBe` 0
it "Should find the max of two arrays" $ do
maxOf (scalar @Int 1) (scalar @Int 0)
`shouldBe`
1
`shouldBe` 1
it "Should take the clamp of 3 arrays" $ do
clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
`shouldBe`
2
`shouldBe` 2
it "Should check if an array has positive or negative infinities" $ do
isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
isInf (scalar @Double 10) `shouldBe` scalar @Double 0
it "Should check if an array has any NaN values" $ do
isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
it "Should check if an array has any Zero values" $ do
isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
isZero (scalar @Double 0) `shouldBe` scalar @Double 1
isZero (scalar @Double 1) `shouldBe` scalar @Double 0

prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x
prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x
prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x
prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y
prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x
prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x
prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x
prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x
prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x
prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x
prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x
prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x
prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x
prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x
prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x
prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x