Skip to content

Commit 60fe2d1

Browse files
authored
Merge pull request #43 from mlabs-haskell/compiler/typeclasses-solver
compiler/typeclass-solver
2 parents 1f758b0 + 1722c8e commit 60fe2d1

File tree

6 files changed

+443
-26
lines changed

6 files changed

+443
-26
lines changed

lambda-buffers-compiler/lambda-buffers-compiler.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ library
121121
LambdaBuffers.Compiler.TypeClass.Pat
122122
LambdaBuffers.Compiler.TypeClass.Pretty
123123
LambdaBuffers.Compiler.TypeClass.Rules
124+
LambdaBuffers.Compiler.TypeClass.Solve
124125
LambdaBuffers.Compiler.TypeClassCheck
125126

126127
hs-source-dirs: src

lambda-buffers-compiler/src/LambdaBuffers/Compiler/TypeClass/Pat.hs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ cost of significantly more complex type signatures.
1919
-}
2020

2121
data Pat
22-
= {- extremely stupid, unfortunately necessary -}
22+
= {- Name / ModuleName / Opaque / TyVarP are literal patterns (or ground terms)
23+
because hey cannot contain any VarPs and therefore "have no holes".
24+
Every TyDef or subcomponent thereof will be translated into a composite
25+
pattern "without any holes". (Nil is also a literal/ground term, I guess) -}
2326
Name Text
24-
| ModuleName [Text] -- also stupid, also necessary -_-
27+
| ModuleName [Text]
2528
| Opaque
29+
| TyVarP Text
2630
| {- Lists (constructed from Nil and :*) with bare types are used to
2731
encode products (where a list of length n encodes an n-tuple)
2832
Lists with field labels (l := t) are used to encode records and sum types
@@ -45,10 +49,8 @@ data Pat
4549
| ProdP Pat {- Pat arg should be a list of "Bare types" -}
4650
| SumP Pat {- where the Pat arg is expected to be (Constr l t :* rest) or Nil, where
4751
rest is either Nil or a tyList of Constrs -}
48-
| VarP Text {- This isn't a type variable. Although it is used to represent them in certain contexts,
49-
it is also used more generally to refer to any "hole" in a pattern to which another pattern
50-
may be substituted. We could have separate constr for type variables but it doesn't appear to be
51-
necessary at this time. -}
52+
| VarP Text {- This isn't a type variable. It is used more generally to refer to any "hole" in a pattern into
53+
to which another pattern may be substituted. TyVarP is the literal pattern / ground term for TyVars -}
5254
| RefP Pat Pat {- 1st arg should be a ModuleName -}
5355
| AppP Pat Pat {- Pattern for Type applications -}
5456
| {- This last one is a bit special. This represents a complete type declaration.

lambda-buffers-compiler/src/LambdaBuffers/Compiler/TypeClass/Pretty.hs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
{-# LANGUAGE OverloadedLabels #-}
21
{-# LANGUAGE OverloadedStrings #-}
32
-- orphans are the whole point of this module!
43
{-# OPTIONS_GHC -Wno-orphans #-}
@@ -10,14 +9,14 @@ module LambdaBuffers.Compiler.TypeClass.Pretty (
109
(<///>),
1110
) where
1211

13-
import Control.Lens ((^.))
1412
import Data.Generics.Labels ()
13+
import Data.Text qualified as T
1514
import LambdaBuffers.Compiler.ProtoCompat qualified as P
16-
import LambdaBuffers.Compiler.TypeClass.Pat (Pat (AppP, DecP, ModuleName, Name, Nil, Opaque, ProdP, RecP, RefP, SumP, VarP, (:*), (:=)), patList)
15+
import LambdaBuffers.Compiler.TypeClass.Pat (Pat (AppP, DecP, ModuleName, Name, Nil, Opaque, ProdP, RecP, RefP, SumP, TyVarP, VarP, (:*), (:=)), patList)
1716
import LambdaBuffers.Compiler.TypeClass.Rules (
1817
Class (Class),
1918
Constraint (C),
20-
Instance,
19+
FQClassName (FQClassName),
2120
Rule ((:<=)),
2221
)
2322
import Prettyprinter (
@@ -34,21 +33,16 @@ import Prettyprinter (
3433
(<+>),
3534
)
3635

37-
instance Pretty P.TyClassRef where
38-
pretty = \case
39-
P.ForeignCI (P.ForeignClassRef cn mn _) -> pretty mn <> "." <> pretty (cn ^. #name)
40-
P.LocalCI (P.LocalClassRef cn _) -> pretty (cn ^. #name)
41-
42-
instance Pretty P.ModuleName where
43-
pretty (P.ModuleName pts _) = hcat . punctuate "." $ map (\x -> pretty $ x ^. #name) pts
36+
instance Pretty FQClassName where
37+
pretty (FQClassName cn mnps) = hcat (punctuate "." . map pretty $ mnps) <> pretty cn
4438

4539
instance Pretty Class where
4640
pretty (Class nm _) = pretty nm
4741

4842
instance Pretty Constraint where
4943
pretty (C cls p) = pretty cls <+> pretty p
5044

51-
instance Pretty Instance where
45+
instance Pretty Rule where
5246
pretty (c :<= []) = pretty c
5347
pretty (c :<= cs) = pretty c <+> "<=" <+> list (pretty <$> cs)
5448

@@ -65,6 +59,7 @@ instance Pretty Pat where
6559
Name t -> pretty t
6660
ModuleName ts -> hcat . punctuate "." . map pretty $ ts
6761
Opaque -> "<OPAQUE>"
62+
TyVarP t -> pretty t
6863
RecP ps -> case patList ps of
6964
Nothing -> pretty ps
7065
Just fields -> case traverse prettyField fields of
@@ -87,7 +82,8 @@ instance Pretty Pat where
8782
RefP mn@(ModuleName _) n@(Name _) -> pretty mn <> "." <> pretty n
8883
RefP Nil (Name n) -> pretty n
8984
RefP p1 p2 -> parens $ "Ref" <+> pretty p1 <+> pretty p2
90-
VarP t -> pretty t
85+
-- Pattern variables are uppercased to distinguish them from proper TyVars
86+
VarP t -> pretty (T.toUpper t)
9187
ap@(AppP p1 p2) -> case prettyApp ap of
9288
Just pap -> parens pap
9389
Nothing -> "App" <+> pretty p1 <+> pretty p2

lambda-buffers-compiler/src/LambdaBuffers/Compiler/TypeClass/Rules.hs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@ module LambdaBuffers.Compiler.TypeClass.Rules (
55
Class (..),
66
Constraint (..),
77
Rule (..),
8-
type Instance,
8+
FQClassName (..),
99
mapPat,
10+
ruleHeadPat,
11+
ruleHeadClass,
1012
) where
1113

12-
import LambdaBuffers.Compiler.ProtoCompat qualified as P
14+
import Data.Text (Text)
15+
1316
import LambdaBuffers.Compiler.TypeClass.Pat (Pat)
1417

18+
data FQClassName = FQClassName {cName :: Text, cModule :: [Text]}
19+
deriving stock (Show, Eq, Ord)
20+
1521
data Class = Class
16-
{ name :: P.TyClassRef
17-
, supers :: [Class]
22+
{ cname :: FQClassName
23+
, csupers :: [Class]
1824
}
1925
deriving stock (Show, Eq, Ord)
2026

@@ -33,9 +39,15 @@ data Rule where
3339
deriving stock (Show, Eq, Ord)
3440
infixl 7 :<=
3541

36-
type Instance = Rule
37-
3842
{- Map over the Pats inside of an Rule
3943
-}
4044
mapPat :: (Pat -> Pat) -> Rule -> Rule
4145
mapPat f (C c ty :<= is) = C c (f ty) :<= map (\(C cx p) -> C cx (f p)) is
46+
47+
{- Extract the inner Pat from a Rule head
48+
-}
49+
ruleHeadPat :: Rule -> Pat
50+
ruleHeadPat (C _ p :<= _) = p
51+
52+
ruleHeadClass :: Rule -> Class
53+
ruleHeadClass (C c _ :<= _) = c
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
{-# LANGUAGE LambdaCase #-}
2+
3+
module LambdaBuffers.Compiler.TypeClass.Solve (solveM, solve, Overlap (..)) where
4+
5+
import LambdaBuffers.Compiler.TypeClass.Pat (
6+
Pat (AppP, DecP, ProdP, RecP, RefP, SumP, VarP, (:*), (:=)),
7+
matches,
8+
)
9+
import LambdaBuffers.Compiler.TypeClass.Rules (
10+
Class (csupers),
11+
Constraint (C),
12+
Rule ((:<=)),
13+
mapPat,
14+
ruleHeadClass,
15+
ruleHeadPat,
16+
)
17+
18+
import Control.Monad.Except (throwError)
19+
import Control.Monad.Reader (ReaderT, runReaderT)
20+
import Control.Monad.Reader.Class (MonadReader (ask))
21+
import Control.Monad.Writer.Class (MonadWriter (tell))
22+
import Control.Monad.Writer.Strict (WriterT, execWriterT)
23+
import Data.Foldable (traverse_)
24+
import Data.List (foldl')
25+
import Data.Set qualified as S
26+
import Data.Text (Text)
27+
28+
{- Pattern/Template/Unification variable substitution.
29+
Given a string that represents a variable name,
30+
and a type to instantiate variables with that name to,
31+
performs the instantiation
32+
-}
33+
subV :: Text -> Pat -> Pat -> Pat
34+
subV varNm t = \case
35+
var@(VarP v) -> if v == varNm then t else var
36+
x :* xs -> subV varNm t x :* subV varNm t xs
37+
l := x -> subV varNm t l := subV varNm t x
38+
ProdP xs -> ProdP (subV varNm t xs)
39+
RecP xs -> RecP (subV varNm t xs)
40+
SumP xs -> SumP (subV varNm t xs)
41+
AppP t1 t2 -> AppP (subV varNm t t1) (subV varNm t t2)
42+
RefP n x -> RefP (subV varNm t n) (subV varNm t x)
43+
DecP a b c -> DecP (subV varNm t a) (subV varNm t b) (subV varNm t c)
44+
other -> other
45+
46+
{- Performs substitution on an entire instance (the first argument) given the
47+
concrete types from a Pat (the second argument).
48+
Note that ONLY PatVars which occur in the Instance *HEAD* are replaced, though they
49+
are replaced in the instance superclasses as well (if they occur there).
50+
-}
51+
subst :: Rule -> Pat -> Rule
52+
subst cst@(C _ t :<= _) ty = mapPat (go (getSubs t ty)) cst
53+
where
54+
go :: [(Text, Pat)] -> Pat -> Pat
55+
go subs tty =
56+
let noflip p1 p2 = uncurry subV p2 p1
57+
in foldl' noflip tty subs
58+
59+
{- Given two patterns (which are hopefully structurally similar), gather a list of all substitutions
60+
from the PatVars in the first argument to the concrete types (hopefully!) in the second argument
61+
-}
62+
getSubs :: Pat -> Pat -> [(Text, Pat)] -- should be a set, whatever
63+
getSubs (VarP s) t = [(s, t)]
64+
getSubs (x :* xs) (x' :* xs') = getSubs x x' <> getSubs xs xs'
65+
getSubs (l := t) (l' := t') = getSubs l l' <> getSubs t t'
66+
getSubs (ProdP xs) (ProdP xs') = getSubs xs xs'
67+
getSubs (RecP xs) (RecP xs') = getSubs xs xs'
68+
getSubs (SumP xs) (SumP xs') = getSubs xs xs'
69+
getSubs (AppP t1 t2) (AppP t1' t2') = getSubs t1 t1' <> getSubs t2 t2'
70+
getSubs (RefP n t) (RefP n' t') = getSubs n n' <> getSubs t t'
71+
getSubs (DecP a b c) (DecP a' b' c') = getSubs a a' <> getSubs b b' <> getSubs c c'
72+
getSubs _ _ = []
73+
74+
-- NoMatch isn't fatal but OverlappingMatches is (i.e. we need to stop when we encounter it)
75+
data MatchError
76+
= NoMatch
77+
| OverlappingMatches [Rule]
78+
79+
-- for SolveM, since we catch NoMatch
80+
data Overlap = Overlap Constraint [Rule]
81+
deriving stock (Show, Eq)
82+
83+
selectMatchingInstance :: Pat -> Class -> [Rule] -> Either MatchError Rule
84+
selectMatchingInstance p c rs = case filter matchPatAndClass rs of
85+
[] -> Left NoMatch
86+
[r] -> Right r
87+
overlaps -> Left $ OverlappingMatches overlaps
88+
where
89+
matchPatAndClass :: Rule -> Bool
90+
matchPatAndClass r =
91+
ruleHeadClass r == c
92+
&& ruleHeadPat r
93+
`matches` p
94+
95+
type SolveM = ReaderT [Rule] (WriterT (S.Set Constraint) (Either Overlap))
96+
97+
{- Given a list of instances (the initial scope), determines whether we can derive
98+
an instance of the Class argument for the Pat argument. A result of [] indicates that there are
99+
no remaining subgoals and that the constraint has been solved.
100+
NOTE: At the moment this handles superclasses differently than you might expect -
101+
instead of assuming that the superclasses for all in-scope classes are defined,
102+
we check that those constraints can be solved before affirmatively judging that the
103+
target constraint has been solved. I *think* that makes sense in this context (whereas in Haskell
104+
it doesn't b/c it's *impossible* to have `instance Foo X` if the definition of Foo is
105+
`class Bar y => Foo y` without an `instance Bar X`)
106+
-}
107+
solveM :: Constraint -> SolveM ()
108+
solveM cst@(C c pat) =
109+
ask >>= \inScope ->
110+
-- First, we look for the most specific instance...
111+
case selectMatchingInstance pat c inScope of
112+
Left e -> case e of
113+
NoMatch -> tell $ S.singleton cst
114+
OverlappingMatches olps -> throwError $ Overlap cst olps
115+
-- If there is, we substitute the argument of the constraint to be solved into the matching rules
116+
Right rule -> case subst rule pat of
117+
-- If there are no additional constraints on the rule, we try to solve the superclasses
118+
C _ p :<= [] -> solveClassesFor p (csupers c)
119+
-- If there are additional constraints on the rule, we try to solve them
120+
C _ _ :<= is -> do
121+
traverse_ solveM is
122+
solveClassesFor pat (csupers c)
123+
where
124+
-- NOTE(@bladyjoker): The version w/ flip is more performant...
125+
-- Given a Pat and a list of Classes, attempt to solve the constraints
126+
-- constructed from the Pat and each Class
127+
solveClassesFor :: Pat -> [Class] -> SolveM ()
128+
solveClassesFor p = traverse_ (\cls -> solveM (C cls p))
129+
130+
solve :: [Rule] -> Constraint -> Either Overlap [Constraint]
131+
solve rules c = fmap S.toList $ execWriterT $ runReaderT (solveM c) rules

0 commit comments

Comments
 (0)