1
1
module LambdaBuffers.Codegen.LamVal.Eq (deriveEqImpl ) where
2
2
3
+ import Control.Exception qualified as Exception
3
4
import Data.Map.Ordered qualified as OMap
4
5
import LambdaBuffers.Codegen.LamVal (Field , QProduct , QRecord , QSum , ValueE (CaseE , FieldE , LamE , LetE , RefE ), (@) )
5
6
import LambdaBuffers.Codegen.LamVal.Derive (deriveImpl )
@@ -23,11 +24,7 @@ eqSum qsum =
23
24
r
24
25
( \ (ctorTyR, rxs) ->
25
26
if fst ctorTyL == fst ctorTyR
26
- then
27
- foldl
28
- (\ tot (lx, rx, ty) -> andE @ tot @ (eqE ty @ lx @ rx))
29
- trueE
30
- (zip3 lxs rxs (snd ctorTyL))
27
+ then eqListHelper lxs rxs (snd ctorTyL)
31
28
else falseE
32
29
)
33
30
)
@@ -48,11 +45,7 @@ eqProduct qprod@(_, prodTy) =
48
45
LetE
49
46
qprod
50
47
r
51
- ( \ rxs ->
52
- foldl
53
- (\ tot (lx, rx, ty) -> andE @ tot @ (eqE ty @ lx @ rx))
54
- trueE
55
- (zip3 lxs rxs prodTy)
48
+ ( \ rxs -> eqListHelper lxs rxs prodTy
56
49
)
57
50
)
58
51
)
@@ -65,16 +58,39 @@ eqRecord (qtyN, recTy) =
65
58
( \ l ->
66
59
LamE
67
60
( \ r ->
68
- foldl
69
- (\ tot field -> andE @ tot @ eqField qtyN field l r)
70
- trueE
71
- (OMap. assocs recTy)
61
+ let eqFieldExprs = map (\ field -> eqField qtyN field l r) $ OMap. assocs recTy
62
+ in if null eqFieldExprs
63
+ then trueE
64
+ else
65
+ foldl1
66
+ (\ tot eqFieldExpr -> andE @ tot @ eqFieldExpr)
67
+ eqFieldExprs
72
68
)
73
69
)
74
70
75
71
eqField :: PC. QTyName -> Field -> ValueE -> ValueE -> ValueE
76
72
eqField qtyN (fieldName, fieldTy) l r = eqE fieldTy @ FieldE (qtyN, fieldName) l @ FieldE (qtyN, fieldName) r
77
73
74
+ {- | 'eqListHelper' is an internal function which equates two lists of 'ValueE'
75
+ with their type pairwise.
76
+
77
+ Preconditions:
78
+ - All input lists are the same length
79
+ -}
80
+ eqListHelper :: [ValueE ] -> [ValueE ] -> [LT. Ty ] -> ValueE
81
+ eqListHelper lxs rxs tys =
82
+ Exception. assert preconditionAssertion $
83
+ let eqedLxsRxsTys = map (\ (lx, rx, ty) -> eqE ty @ lx @ rx) $ zip3 lxs rxs tys
84
+ in if null eqedLxsRxsTys
85
+ then trueE
86
+ else foldl1 (\ tot eqExpr -> andE @ tot @ eqExpr) eqedLxsRxsTys
87
+ where
88
+ preconditionAssertion =
89
+ let lxsLength = length lxs
90
+ rxsLength = length rxs
91
+ tysLength = length tys
92
+ in lxsLength == rxsLength && rxsLength == tysLength
93
+
78
94
-- | Hooks
79
95
deriveEqImpl :: PC. ModuleName -> PC. TyDefs -> PC. Ty -> Either P. InternalError ValueE
80
96
deriveEqImpl mn tydefs = deriveImpl mn tydefs eqSum eqProduct eqRecord
0 commit comments