Skip to content
This repository was archived by the owner on Aug 5, 2024. It is now read-only.

Commit 4a0e671

Browse files
committed
conditional store fix
1 parent 2ab8fce commit 4a0e671

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

examples/condstore.plv

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
condstore1 (inout v :: double) (w :: double) (b :: bool) :: () :=
2+
v[b] <- w;
3+
4+
condstore2 {n} (inout A :: double[n]) (w :: double) (b :: bool) :: () :=
5+
A[b] <- w;
6+
7+
condstore3 {n} (inout A :: double[n]) (w :: double) :: () :=
8+
A[A < w] <- w;

src/Language/Plover/CodeGen.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,9 @@ compileLoc loc@(Index a idxs) = do aloc <- asLoc $ compileStat a
17051705
vty@(VecType {}) -> do
17061706
idxloc <- asLoc $ compileStat idx
17071707
return $ Left (length $ getIndices vty, vecBaseType vty, idxloc)
1708+
BoolType -> do
1709+
idxloc <- asLoc $ compileStat idx
1710+
return $ Left (0, BoolType, idxloc)
17081711
ty -> do
17091712
idxex <- asExp $ compileStat idx
17101713
return $ Right idxex

src/Language/Plover/Types.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,9 +942,9 @@ getLocType :: Location CExpr -> Type
942942
getLocType (Ref ty v) = ty
943943
getLocType (Index a idxs) = normalizeTypes $ getTypeIdx idxs (normalizeTypes $ getType a)
944944
where getTypeIdx [] aty = aty
945-
getTypeIdx (idx:idxs) aty@(VecType {}) = getTypeIdxty [] (getType idx) idxs aty
945+
getTypeIdx (idx:idxs) aty = getTypeIdxty [] (getType idx) idxs aty
946946

947-
getTypeIdxty acc idxty idxs vty@(VecType {}) =
947+
getTypeIdxty acc idxty idxs vty =
948948
case normalizeTypes idxty of
949949
VecType st' idxs' idxtybase -> VecType st' idxs' $
950950
getTypeIdxty (acc ++ idxs') idxtybase idxs vty

src/Language/Plover/Unify.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -968,11 +968,15 @@ typeCheckLoc pos (Index a idxs) = do -- see note [indexing rules] and see `getLo
968968
idxty <- typeCheck idx
969969
typeCheckIdxty oty [] idxty idxs aty
970970
typeCheckIdx oty (idx:idxs) ty = do
971-
addUError $ UGenTyError pos oty "Too many indices on expression of type"
972-
return ty
971+
idxty <- typeCheck idx
972+
case idxty of
973+
BoolType -> typeCheckIdxty oty [] idxty idxs ty
974+
_ -> do addUError $ UGenTyError pos oty
975+
"Too many indices on expression of type"
976+
return ty
973977

974978
typeCheckIdxty :: Type -> [CExpr] -> Type -> [CExpr] -> Type -> UM Type
975-
typeCheckIdxty oty acc idxty idxs vty@(VecType {}) =
979+
typeCheckIdxty oty acc idxty idxs vty =
976980
case normalizeTypes idxty of
977981
VecType st' idxs' idxtybase -> -- result shape equals shape of index
978982
VecType st' idxs' <$> (typeCheckIdxty oty (acc ++ idxs') idxtybase idxs vty)

0 commit comments

Comments
 (0)