Skip to content

Commit d7e5c9b

Browse files
authored
Wingman: Properly destruct forall-quantified types (#2049)
* Properly destruct forall-quantified types * Better haddock
1 parent 7a41ab7 commit d7e5c9b

File tree

5 files changed

+58
-10
lines changed

5 files changed

+58
-10
lines changed

plugins/hls-tactics-plugin/src/Wingman/CaseSplit.hs

+21-5
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ mkFirstAgda pats body = AgdaMatch pats body
3030
-- | Transform an 'AgdaMatch' whose body is a case over a bound pattern, by
3131
-- splitting it into multiple matches: one for each alternative of the case.
3232
agdaSplit :: AgdaMatch -> [AgdaMatch]
33-
agdaSplit (AgdaMatch pats (Case (HsVar _ (L _ var)) matches)) = do
34-
(pat, body) <- matches
35-
-- TODO(sandy): use an at pattern if necessary
36-
pure $ AgdaMatch (rewriteVarPat var pat pats) $ unLoc body
33+
agdaSplit (AgdaMatch pats (Case (HsVar _ (L _ var)) matches))
34+
-- Ensure the thing we're destructing is actually a pattern that's been
35+
-- bound.
36+
| containsVar var pats
37+
= do
38+
(pat, body) <- matches
39+
-- TODO(sandy): use an at pattern if necessary
40+
pure $ AgdaMatch (rewriteVarPat var pat pats) $ unLoc body
3741
agdaSplit x = [x]
3842

3943

@@ -53,6 +57,19 @@ wildifyT (S.map occNameString -> used) = everywhere $ mkT $ \case
5357
(x :: Pat GhcPs) -> x
5458

5559

60+
------------------------------------------------------------------------------
61+
-- | Determine whether the given 'RdrName' exists as a 'VarPat' inside of @a@.
62+
containsVar :: Data a => RdrName -> a -> Bool
63+
containsVar name = everything (||) $
64+
mkQ False (\case
65+
VarPat _ (L _ var) -> eqRdrName name var
66+
(_ :: Pat GhcPs) -> False
67+
)
68+
`extQ` \case
69+
HsRecField lbl _ True -> eqRdrName name $ unLoc $ rdrNameFieldOcc $ unLoc lbl
70+
(_ :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)) -> False
71+
72+
5673
------------------------------------------------------------------------------
5774
-- | Replace a 'VarPat' with the given @'Pat' GhcPs@.
5875
rewriteVarPat :: Data a => RdrName -> Pat GhcPs -> a -> a
@@ -68,7 +85,6 @@ rewriteVarPat name rep = everywhere $
6885
(x :: HsRecField' (FieldOcc GhcPs) (PatCompat GhcPs)) -> x
6986

7087

71-
7288
------------------------------------------------------------------------------
7389
-- | Construct an 'HsDecl' from a set of 'AgdaMatch'es.
7490
splitToDecl

plugins/hls-tactics-plugin/src/Wingman/GHC.hs

+12-5
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,15 @@ tacticsThetaTy (tcSplitSigmaTy -> (_, theta, _)) = theta
9191
------------------------------------------------------------------------------
9292
-- | Get the data cons of a type, if it has any.
9393
tacticsGetDataCons :: Type -> Maybe ([DataCon], [Type])
94-
tacticsGetDataCons ty | Just _ <- algebraicTyCon ty =
95-
splitTyConApp_maybe ty <&> \(tc, apps) ->
96-
( filter (not . dataConCannotMatch apps) $ tyConDataCons tc
97-
, apps
98-
)
94+
tacticsGetDataCons ty
95+
| Just (_, ty') <- tcSplitForAllTy_maybe ty
96+
= tacticsGetDataCons ty'
97+
tacticsGetDataCons ty
98+
| Just _ <- algebraicTyCon ty
99+
= splitTyConApp_maybe ty <&> \(tc, apps) ->
100+
( filter (not . dataConCannotMatch apps) $ tyConDataCons tc
101+
, apps
102+
)
99103
tacticsGetDataCons _ = Nothing
100104

101105
------------------------------------------------------------------------------
@@ -132,6 +136,9 @@ getRecordFields dc =
132136
------------------------------------------------------------------------------
133137
-- | Is this an algebraic type?
134138
algebraicTyCon :: Type -> Maybe TyCon
139+
algebraicTyCon ty
140+
| Just (_, ty') <- tcSplitForAllTy_maybe ty
141+
= algebraicTyCon ty'
135142
algebraicTyCon (splitTyConApp_maybe -> Just (tycon, _))
136143
| tycon == intTyCon = Nothing
137144
| tycon == floatTyCon = Nothing

plugins/hls-tactics-plugin/test/CodeAction/AutoSpec.hs

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ spec = do
4848
autoTest 2 19 "AutoInfixApplyMany"
4949
autoTest 2 25 "AutoInfixInfix"
5050
autoTest 19 12 "AutoTypeLevel"
51+
autoTest 11 9 "AutoForallClassMethod"
5152

5253
failing "flaky in CI" $
5354
autoTest 2 11 "GoldenApplicativeThen"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{-# LANGUAGE ExplicitForAll #-}
2+
{-# LANGUAGE FlexibleContexts #-}
3+
{-# LANGUAGE MultiParamTypeClasses #-}
4+
5+
import Data.Functor.Contravariant
6+
7+
class Semigroupal cat t1 t2 to f where
8+
combine :: cat (to (f x y) (f x' y')) (f (t1 x x') (t2 y y'))
9+
10+
comux :: forall p a b c d. Semigroupal Op (,) (,) (,) p => p (a, c) (b, d) -> (p a b, p c d)
11+
comux = case combine of { (Op f) -> f }
12+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{-# LANGUAGE ExplicitForAll #-}
2+
{-# LANGUAGE FlexibleContexts #-}
3+
{-# LANGUAGE MultiParamTypeClasses #-}
4+
5+
import Data.Functor.Contravariant
6+
7+
class Semigroupal cat t1 t2 to f where
8+
combine :: cat (to (f x y) (f x' y')) (f (t1 x x') (t2 y y'))
9+
10+
comux :: forall p a b c d. Semigroupal Op (,) (,) (,) p => p (a, c) (b, d) -> (p a b, p c d)
11+
comux = _
12+

0 commit comments

Comments
 (0)