Skip to content

Commit fb4ce18

Browse files
authored
Use tree-diffing for difference (#535)
Context: #364 Also refactor `delete` to create a `deleteFromSubtree` version that can be used at different levels of the tree.
1 parent 59ddae5 commit fb4ce18

File tree

2 files changed

+201
-54
lines changed

2 files changed

+201
-54
lines changed

Data/HashMap/Internal.hs

Lines changed: 175 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ import Data.Functor.Identity (Identity (..))
163163
import Data.Hashable (Hashable)
164164
import Data.Hashable.Lifted (Hashable1, Hashable2)
165165
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
166+
import Data.Maybe (isNothing)
166167
import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid)
167168
import GHC.Exts (Int (..), Int#, TYPE, (==#))
168169
import GHC.Stack (HasCallStack)
@@ -1102,56 +1103,60 @@ delete k m = delete' (hash k) k m
11021103
{-# INLINABLE delete #-}
11031104

11041105
delete' :: Eq k => Hash -> k -> HashMap k v -> HashMap k v
1105-
delete' h0 k0 m0 = go h0 k0 0 m0
1106-
where
1107-
go !_ !_ !_ Empty = Empty
1108-
go h k _ t@(Leaf hy (L ky _))
1109-
| hy == h && ky == k = Empty
1110-
| otherwise = t
1111-
go h k s t@(BitmapIndexed b ary)
1112-
| b .&. m == 0 = t
1113-
| otherwise =
1114-
let !st = A.index ary i
1115-
!st' = go h k (nextShift s) st
1116-
in if st' `ptrEq` st
1117-
then t
1118-
else case st' of
1119-
Empty | A.length ary == 1 -> Empty
1120-
| A.length ary == 2 ->
1121-
case (i, A.index ary 0, A.index ary 1) of
1122-
(0, _, l) | isLeafOrCollision l -> l
1123-
(1, l, _) | isLeafOrCollision l -> l
1124-
_ -> bIndexed
1125-
| otherwise -> bIndexed
1126-
where
1127-
bIndexed = BitmapIndexed (b .&. complement m) (A.delete ary i)
1128-
l | isLeafOrCollision l && A.length ary == 1 -> l
1129-
_ -> BitmapIndexed b (A.update ary i st')
1130-
where m = mask h s
1131-
i = sparseIndex b m
1132-
go h k s t@(Full ary) =
1133-
let !st = A.index ary i
1134-
!st' = go h k (nextShift s) st
1106+
delete' h0 k0 m0 = deleteFromSubtree h0 k0 0 m0
1107+
{-# INLINABLE delete' #-}
1108+
1109+
-- | This version of 'delete' can be used on subtrees when a the
1110+
-- corresponding 'Shift' argument is supplied.
1111+
deleteFromSubtree :: Eq k => Hash -> k -> Shift -> HashMap k v -> HashMap k v
1112+
deleteFromSubtree !_ !_ !_ Empty = Empty
1113+
deleteFromSubtree h k _ t@(Leaf hy (L ky _))
1114+
| hy == h && ky == k = Empty
1115+
| otherwise = t
1116+
deleteFromSubtree h k s t@(BitmapIndexed b ary)
1117+
| b .&. m == 0 = t
1118+
| otherwise =
1119+
let !st = A.index ary i
1120+
!st' = deleteFromSubtree h k (nextShift s) st
11351121
in if st' `ptrEq` st
11361122
then t
11371123
else case st' of
1138-
Empty ->
1139-
let ary' = A.delete ary i
1140-
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1141-
in BitmapIndexed bm ary'
1142-
_ -> Full (A.update ary i st')
1143-
where i = index h s
1144-
go h k _ t@(Collision hy v)
1145-
| h == hy = case indexOf k v of
1146-
Just i
1147-
| A.length v == 2 ->
1148-
if i == 0
1149-
then Leaf h (A.index v 1)
1150-
else Leaf h (A.index v 0)
1151-
| otherwise -> Collision h (A.delete v i)
1152-
Nothing -> t
1153-
| otherwise = t
1154-
{-# INLINABLE delete' #-}
1124+
Empty | A.length ary == 1 -> Empty
1125+
| A.length ary == 2 ->
1126+
case (i, A.index ary 0, A.index ary 1) of
1127+
(0, _, l) | isLeafOrCollision l -> l
1128+
(1, l, _) | isLeafOrCollision l -> l
1129+
_ -> bIndexed
1130+
| otherwise -> bIndexed
1131+
where
1132+
bIndexed = BitmapIndexed (b .&. complement m) (A.delete ary i)
1133+
l | isLeafOrCollision l && A.length ary == 1 -> l
1134+
_ -> BitmapIndexed b (A.update ary i st')
1135+
where m = mask h s
1136+
i = sparseIndex b m
1137+
deleteFromSubtree h k s t@(Full ary) =
1138+
let !st = A.index ary i
1139+
!st' = deleteFromSubtree h k (nextShift s) st
1140+
in if st' `ptrEq` st
1141+
then t
1142+
else case st' of
1143+
Empty ->
1144+
let ary' = A.delete ary i
1145+
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1146+
in BitmapIndexed bm ary'
1147+
_ -> Full (updateFullArray ary i st')
1148+
where i = index h s
1149+
deleteFromSubtree h k _ t@(Collision hy v)
1150+
| h == hy = case indexOf k v of
1151+
Just i
1152+
| A.length v == 2 ->
1153+
if i == 0
1154+
then Leaf h (A.index v 1)
1155+
else Leaf h (A.index v 0)
1156+
| otherwise -> Collision h (A.delete v i)
1157+
Nothing -> t
1158+
| otherwise = t
1159+
{-# INLINABLE deleteFromSubtree #-}
11551160

11561161
-- | Delete optimized for the case when we know the key is in the map.
11571162
--
@@ -1188,7 +1193,7 @@ deleteKeyExists !collPos0 !h0 !k0 !m0 = go collPos0 h0 k0 m0
11881193
let ary' = A.delete ary i
11891194
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
11901195
in BitmapIndexed bm ary'
1191-
_ -> Full (A.update ary i st')
1196+
_ -> Full (updateFullArray ary i st')
11921197
where i = indexSH shiftedHash
11931198
go collPos _shiftedHash _k (Collision h v)
11941199
| A.length v == 2
@@ -1780,14 +1785,131 @@ mapKeys f = fromList . foldrWithKey (\k x xs -> (f k, x) : xs) []
17801785

17811786
-- | \(O(n \log m)\) Difference of two maps. Return elements of the first map
17821787
-- not existing in the second.
1783-
difference :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v
1784-
difference a b = foldlWithKey' go empty a
1788+
difference :: Eq k => HashMap k v -> HashMap k w -> HashMap k v
1789+
difference = go_difference 0
17851790
where
1786-
go m k v = case lookup k b of
1787-
Nothing -> unsafeInsert k v m
1788-
_ -> m
1791+
go_difference !_s Empty _ = Empty
1792+
go_difference s t1@(Leaf h1 (L k1 _)) t2
1793+
= lookupCont (\_ -> t1) (\_ _ -> Empty) h1 k1 s t2
1794+
go_difference _ t1 Empty = t1
1795+
go_difference s t1 (Leaf h2 (L k2 _)) = deleteFromSubtree h2 k2 s t1
1796+
1797+
go_difference s t1@(BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2)
1798+
= differenceArrays s b1 ary1 t1 b2 ary2
1799+
go_difference s t1@(Full ary1) (BitmapIndexed b2 ary2)
1800+
= differenceArrays s fullBitmap ary1 t1 b2 ary2
1801+
go_difference s t1@(BitmapIndexed b1 ary1) (Full ary2)
1802+
= differenceArrays s b1 ary1 t1 fullBitmap ary2
1803+
go_difference s t1@(Full ary1) (Full ary2)
1804+
= differenceArrays s fullBitmap ary1 t1 fullBitmap ary2
1805+
1806+
go_difference s t1@(Collision h1 _) (BitmapIndexed b2 ary2)
1807+
| b2 .&. m == 0 = t1
1808+
| otherwise =
1809+
case A.index# ary2 (sparseIndex b2 m) of
1810+
(# st2 #) -> go_difference (nextShift s) t1 st2
1811+
where m = mask h1 s
1812+
go_difference s t1@(Collision h1 _) (Full ary2)
1813+
= case A.index# ary2 (index h1 s) of
1814+
(# st2 #) -> go_difference (nextShift s) t1 st2
1815+
1816+
go_difference s t1@(BitmapIndexed b1 ary1) t2@(Collision h2 _)
1817+
| b1 .&. m == 0 = t1
1818+
| otherwise =
1819+
let (# !st #) = A.index# ary1 i1
1820+
in case go_difference (nextShift s) st t2 of
1821+
Empty {- | A.length ary1 == 1 -> Empty -- Impossible! -}
1822+
| A.length ary1 == 2 ->
1823+
case (i1, A.index ary1 0, A.index ary1 1) of
1824+
(0, _, l) | isLeafOrCollision l -> l
1825+
(1, l, _) | isLeafOrCollision l -> l
1826+
_ -> bIndexed
1827+
| otherwise -> bIndexed
1828+
where
1829+
bIndexed
1830+
= BitmapIndexed (b1 .&. complement m) (A.delete ary1 i1)
1831+
st' | isLeafOrCollision st' && A.length ary1 == 1 -> st'
1832+
| st `ptrEq` st' -> t1
1833+
| otherwise -> BitmapIndexed b1 (A.update ary1 i1 st')
1834+
where
1835+
m = mask h2 s
1836+
i1 = sparseIndex b1 m
1837+
go_difference s t1@(Full ary1) t2@(Collision h2 _)
1838+
= let (# !st #) = A.index# ary1 i
1839+
in case go_difference (nextShift s) st t2 of
1840+
Empty ->
1841+
let ary1' = A.delete ary1 i
1842+
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1843+
in BitmapIndexed bm ary1'
1844+
st' | st `ptrEq` st' -> t1
1845+
| otherwise -> Full (updateFullArray ary1 i st')
1846+
where i = index h2 s
1847+
1848+
go_difference _ t1@(Collision h1 ary1) (Collision h2 ary2)
1849+
= differenceCollisions h1 ary1 t1 h2 ary2
1850+
1851+
-- TODO: If we keep 'Full' (#399), differenceArrays could be optimized for
1852+
-- each combination of 'Full' and 'BitmapIndexed`.
1853+
differenceArrays !s !b1 !ary1 t1 !b2 !ary2
1854+
| b1 .&. b2 == 0 = t1
1855+
| A.unsafeSameArray ary1 ary2 = Empty
1856+
| otherwise = runST $ do
1857+
mary <- A.new_ $ A.length ary1
1858+
1859+
-- TODO: i == popCount bResult. Not sure if that would be faster.
1860+
-- Also i1 is in some relation with b1'
1861+
let goDA !i !i1 !b1' !bResult !nChanges
1862+
| b1' == 0 = pure (bResult, nChanges)
1863+
| otherwise = do
1864+
!st1 <- A.indexM ary1 i1
1865+
case m .&. b2 of
1866+
0 -> do
1867+
A.write mary i st1
1868+
goDA (i + 1) (i1 + 1) nextB1' (bResult .|. m) nChanges
1869+
_ -> do
1870+
!st2 <- A.indexM ary2 (sparseIndex b2 m)
1871+
case go_difference (nextShift s) st1 st2 of
1872+
Empty -> goDA i (i1 + 1) nextB1' bResult (nChanges + 1)
1873+
st -> do
1874+
A.write mary i st
1875+
let same = I# (Exts.reallyUnsafePtrEquality# st st1)
1876+
let nChanges' = nChanges + (1 - same)
1877+
goDA (i + 1) (i1 + 1) nextB1' (bResult .|. m) nChanges'
1878+
where
1879+
m = b1' .&. negate b1'
1880+
nextB1' = b1' .&. complement m
1881+
1882+
(bResult, nChanges) <- goDA 0 0 b1 0 0
1883+
if nChanges == 0
1884+
then pure t1
1885+
else case popCount bResult of
1886+
0 -> pure Empty
1887+
1 -> do
1888+
l <- A.read mary 0
1889+
if isLeafOrCollision l
1890+
then pure l
1891+
else BitmapIndexed bResult <$> (A.unsafeFreeze =<< A.shrink mary 1)
1892+
n -> bitmapIndexedOrFull bResult <$> (A.unsafeFreeze =<< A.shrink mary n)
17891893
{-# INLINABLE difference #-}
17901894

1895+
-- TODO: This could be faster if we would keep track of which elements of ary2
1896+
-- we've already matched. Those could be skipped when we check the following
1897+
-- elements of ary1.
1898+
differenceCollisions :: Eq k => Hash -> A.Array (Leaf k v1) -> HashMap k v1 -> Hash -> A.Array (Leaf k v2) -> HashMap k v1
1899+
differenceCollisions !h1 !ary1 t1 !h2 !ary2
1900+
| h1 == h2 =
1901+
if A.unsafeSameArray ary1 ary2
1902+
then Empty
1903+
else let ary = A.filter (\(L k1 _) -> isNothing (indexOf k1 ary2)) ary1
1904+
in case A.length ary of
1905+
0 -> Empty
1906+
1 -> case A.index# ary 0 of
1907+
(# l #) -> Leaf h1 l
1908+
n | A.length ary1 == n -> t1
1909+
| otherwise -> Collision h1 ary
1910+
| otherwise = t1
1911+
{-# INLINABLE differenceCollisions #-}
1912+
17911913
-- | \(O(n \log m)\) Difference with a combining function. When two equal keys are
17921914
-- encountered, the combining function is applied to the values of these keys.
17931915
-- If it returns 'Nothing', the element is discarded (proper set difference). If

Data/HashMap/Internal/Array.hs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ module Data.HashMap.Internal.Array
7272
, thaw
7373
, map
7474
, map'
75+
, filter
7576
, traverse
7677
, traverse'
7778
, toList
@@ -113,12 +114,14 @@ import qualified Prelude
113114
if (_k_) < 0 || (_k_) >= (_len_) then error ("Data.HashMap.Internal.Array." ++ (_func_) ++ ": bounds error, offset " ++ show (_k_) ++ ", length " ++ show (_len_)) else
114115
# define CHECK_OP(_func_,_op_,_lhs_,_rhs_) \
115116
if not ((_lhs_) _op_ (_rhs_)) then error ("Data.HashMap.Internal.Array." ++ (_func_) ++ ": Check failed: _lhs_ _op_ _rhs_ (" ++ show (_lhs_) ++ " vs. " ++ show (_rhs_) ++ ")") else
117+
# define CHECK_GE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>=,_lhs_,_rhs_)
116118
# define CHECK_GT(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>,_lhs_,_rhs_)
117119
# define CHECK_LE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,<=,_lhs_,_rhs_)
118120
# define CHECK_EQ(_func_,_lhs_,_rhs_) CHECK_OP(_func_,==,_lhs_,_rhs_)
119121
#else
120122
# define CHECK_BOUNDS(_func_,_len_,_k_)
121123
# define CHECK_OP(_func_,_op_,_lhs_,_rhs_)
124+
# define CHECK_GE(_func_,_lhs_,_rhs_)
122125
# define CHECK_GT(_func_,_lhs_,_rhs_)
123126
# define CHECK_LE(_func_,_lhs_,_rhs_)
124127
# define CHECK_EQ(_func_,_lhs_,_rhs_)
@@ -221,7 +224,7 @@ new_ n = new n undefinedElem
221224
-- | The returned array is the same as the array given, as it is shrunk in place.
222225
shrink :: MArray s a -> Int -> ST s (MArray s a)
223226
shrink mary _n@(I# n#) =
224-
CHECK_GT("shrink", _n, (0 :: Int))
227+
CHECK_GE("shrink", _n, (0 :: Int))
225228
CHECK_LE("shrink", _n, (unsafeLengthM mary))
226229
ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of
227230
s' -> (# s', mary #)
@@ -496,6 +499,28 @@ map' f = \ ary ->
496499
go ary mary (i+1) n
497500
{-# INLINE map' #-}
498501

502+
filter :: (a -> Bool) -> Array a -> Array a
503+
filter f = \ ary ->
504+
let !n = length ary
505+
in run $ do
506+
mary <- new_ n
507+
len <- go_filter ary mary 0 0 n
508+
shrink mary len
509+
where
510+
-- Without the @!@ on @ary@ we end up reboxing the array when using
511+
-- 'differenceCollisions'. See
512+
-- https://gitlab.haskell.org/ghc/ghc/-/issues/26525.
513+
go_filter !ary !mary !iAry !iMary !n
514+
| iAry >= n = return iMary
515+
| otherwise = do
516+
x <- indexM ary iAry
517+
if f x
518+
then do
519+
write mary iMary x
520+
go_filter ary mary (iAry + 1) (iMary + 1) n
521+
else go_filter ary mary (iAry + 1) iMary n
522+
{-# INLINE filter #-}
523+
499524
fromList :: Int -> [a] -> Array a
500525
fromList n xs0 =
501526
CHECK_EQ("fromList", n, Prelude.length xs0)

0 commit comments

Comments
 (0)