Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce structured IR of MtExpr #616

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 148 additions & 65 deletions lib/haskell/natural4/src/LS/Renamer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ scanTypeDeclName tracer mtexprs = do
-- * A GIVETH can be referred to in other rules up the scope hierarchy
-- * The head in DECIDE clauses can also be referred to by other rules in scope hierarchy
-- * WHERE clauses are local to the rule
--
renameRules :: (Traversable f) => Tracer Log -> f Rule -> Renamer (f RnRule)
renameRules tracer rules = do
rulesWithLocalDefs <-
Expand Down Expand Up @@ -443,7 +442,7 @@ renameRule tracer [email protected]{} = do
defaults <- assertEmptyList rule.defaults
symtab <- assertEmptyList rule.symtab
clauses <- traverse (renameHornClause tracer) rule.clauses
name <- renameMultiTerm tracer rule.name
name <- renameMultiTerm tracer RootExpression rule.name
pure $
Hornlike
RnHornlike
Expand Down Expand Up @@ -504,7 +503,7 @@ renameTypeDeclName :: LS.RuleName -> Renamer RnRuleName
renameTypeDeclName mtexprs = do
mt <- assertSingletonMultiTerm mtexprs
rnTyName <- lookupExistingName (NE.singleton mt) RnType
pure [RnExprName rnTyName]
pure $ RnExprName rnTyName

renameUpons ::
Maybe LS.ParamText ->
Expand Down Expand Up @@ -574,36 +573,69 @@ renameGivenInlineEnumParamText params = do
rnParams <- traverse renameEach params
pure $ RnParamText rnParams

-- | Track what "level" an expression has.
-- This level merely tracks how deep we are in the program AST.
-- If we are at the top level, called 'RootExpression', we sometimes have to handle
-- certain things differently. For example, if we encounter an '[LS.MultiTerm]' such as:
--
-- @
-- [MTT "f", MTT "a", MTT "b"]
-- @
--
-- Is this a function declaration (e.g., the @f a b@ part of the Haskell expression @f a b = a + b@)
-- or is this a function application, where @f@ is applied to the variables @a@ and @b@?
-- Without further context, it is impossible to tell, but we do want to be able to tell
-- these two cases apart to simplify transpiler backends.
-- Thus, we track whether we are at the root of the program AST.
data ExprLevel
= -- | We are at the root of an expression tree.
RootExpression
| -- | We are in some sub-expression of an expression tree.
SubExpression
deriving stock (Eq, Show, Ord, Enum, Bounded)

-- | Demote any 'ExprLevel' to 'SubExpression', regardless of the level passed in.
-- Strictly speaking, this does not need to be a function, but giving it a name
-- avoids sprinkling the bare constant over the call sites.
subExpression :: ExprLevel -> ExprLevel
subExpression = const SubExpression

-- | Rename a single horn clause.
--
-- The head (@hc.hHead@) is renamed with 'renameRelationalPredicate' and
-- @hc.hBody@ is traversed with 'renameBoolStruct'; both results are packed
-- into an 'RnHornClause'.
renameHornClause :: Tracer Log -> LS.HornClause2 -> Renamer RnHornClause
renameHornClause tracer hc = do
rnHead <- renameRelationalPredicate tracer hc.hHead
-- The head of a clause is renamed at the 'RootExpression' level.
-- NOTE(review): the @rnHead@ line above, without an 'ExprLevel' argument,
-- appears to be the pre-change side of the diff — confirm.
rnHead <- renameRelationalPredicate tracer RootExpression hc.hHead
rnBody <- traverse (renameBoolStruct tracer) hc.hBody
pure $
RnHornClause
{ rnHcHead = rnHead
, rnHcBody = rnBody
}

-- | Rename a relational predicate, threading the 'ExprLevel' of the
-- surrounding expression into the renaming of its multi-terms.
--
-- * 'LS.RPParamText' is rejected with 'UnsupportedRPParamText'.
-- * 'LS.RPMT' renames its multi-term at the given level.
-- * 'LS.RPConstraint' renames the left-hand side at the given level and the
--   right-hand side at the 'subExpression' level.
-- * 'LS.RPBoolStructR' renames the left-hand side at the given level and the
--   right-hand side with 'renameBoolStruct'.
-- * 'LS.RPnary' renames its first element at the given level and all
--   remaining elements at the 'subExpression' level.
--
-- NOTE(review): this span shows both sides of a diff; the signature and the
-- lines that do not mention @exprLvl@ appear to be the pre-change versions.
renameRelationalPredicate :: Tracer Log -> LS.RelationalPredicate -> Renamer RnRelationalPredicate
renameRelationalPredicate tracer = \case
renameRelationalPredicate :: Tracer Log -> ExprLevel -> LS.RelationalPredicate -> Renamer RnRelationalPredicate
renameRelationalPredicate tracer exprLvl = \case
LS.RPParamText pText ->
throwError $ UnsupportedRPParamText pText
LS.RPMT mt -> RnRelationalTerm <$> renameMultiTerm tracer mt
LS.RPMT mt -> RnRelationalTerm <$> renameMultiTerm tracer exprLvl mt
LS.RPConstraint lhs relationalPredicate rhs -> do
rnLhs <- renameMultiTerm tracer lhs
rnRhs <- renameMultiTerm tracer rhs
-- Only the left-hand side keeps the caller's level; the right-hand side is
-- always renamed as a sub-expression.
rnLhs <- renameMultiTerm tracer exprLvl lhs
rnRhs <- renameMultiTerm tracer (subExpression exprLvl) rhs
pure $ RnConstraint rnLhs relationalPredicate rnRhs
LS.RPBoolStructR lhs relationalPredicate rhs -> do
rnLhs <- renameMultiTerm tracer lhs
rnLhs <- renameMultiTerm tracer exprLvl lhs
rnRhs <- renameBoolStruct tracer rhs
pure $ RnBoolStructR rnLhs relationalPredicate rnRhs
LS.RPnary relationalPredicate rhs -> do
rnRhs <- traverse (renameRelationalPredicate tracer) rhs
pure $ RnNary relationalPredicate rnRhs
-- An n-ary predicate without operands is passed through unchanged.
LS.RPnary relationalPredicate [] -> pure $ RnNary relationalPredicate []
LS.RPnary relationalPredicate (lhs : rhs) -> do
-- We have to handle the first element explicitly and differently.
-- See 'scanDecideHeadClause', which explains why.
rnLhs <- renameRelationalPredicate tracer exprLvl lhs
rnRhs <- traverse (renameRelationalPredicate tracer $ subExpression exprLvl) rhs
pure $ RnNary relationalPredicate (rnLhs : rnRhs)

renameBoolStruct :: Tracer Log -> LS.BoolStructR -> Renamer RnBoolStructR
renameBoolStruct tracer = \case
AA.Leaf p -> AA.Leaf <$> renameRelationalPredicate tracer p
-- No expression in a 'BoolStructR' can be a 'RootExpression' expression.
-- Thus, we hardcode 'SubExpression' here.
AA.Leaf p -> AA.Leaf <$> renameRelationalPredicate tracer SubExpression p
AA.All lbl cs -> do
rnBoolStruct <- traverse (renameBoolStruct tracer) cs
pure $ AA.All lbl rnBoolStruct
Expand Down Expand Up @@ -634,8 +666,17 @@ renameBoolStruct tracer = \case
-- For example, @[MTT "x", MTT "f"]@ will be changed to @[MTT "f", MTT "x"]@,
-- if and only if @"f"@ is a known function variable in scope with associated
-- arity information.
renameMultiTerm :: Tracer Log -> LS.MultiTerm -> Renamer RnMultiTerm
renameMultiTerm tracer multiTerms = do
--
-- Finally, we perform an additional translation which turns the list representation
-- of 'LS.MultiTerm' into a proper AST with dedicated constructors for function
-- application, record projection and variables.
-- We use the 'ExprLevel' to resolve an 'LS.MultiTerm' to either a function
-- declaration or a function application. If the Renamer is currently renaming
-- a root expression, for example the head of a `DECIDE` clause, then we may be
-- introducing a function declaration.
-- Otherwise, we know there can only be function applications.
renameMultiTerm :: Tracer Log -> ExprLevel -> LS.MultiTerm -> Renamer RnExpr
renameMultiTerm tracer exprLvl multiTerms = do
(reversedRnMultiTerms, ctx) <-
foldM
( \(results, state) mt -> do
Expand All @@ -646,57 +687,94 @@ renameMultiTerm tracer multiTerms = do
multiTerms
let
rnMultiTerms = reverse reversedRnMultiTerms
fixFixity ctx rnMultiTerms
multiTermsFixed <- fixFixity ctx rnMultiTerms

case multiTermsFixed of
[expr]
| Just nameOrLitOrBuiltin <- isNameOrLitOrBuiltin expr -> pure nameOrLitOrBuiltin
(varName : attrs)
| Just (name, sels) <- isProjection varName attrs -> pure $ RnProjection name sels
(varName : args)
| Just (name, argExprs) <- isFunctionApp varName args
, mustBeFunctionApplication ->
pure $ RnFunApp name argExprs
(varName : args)
| Just (name, argExprs) <- isFunctionDecl varName args
, mustBeFunctionDeclaration ->
pure $ RnFunDecl name argExprs
exprs -> throwError $ UnknownTermStructure exprs
where
-- Fixing the arity of a function requires us to rewrite infix and postfix
-- notation to prefix notation.
--
-- To rewrite a function application, we first gather the 'FuncInfo' to
-- find the declared arity of the function. Say the arity of the function @f@ is
-- given by the tuple @(p, q)@ where @p@ is the number of arguments before the
-- function name and @q@ is the number of arguments after the function name.
-- This captures functions applied in prefix, infix and postfix notation.
-- Then, we find the index of the function name as it occurs in the 'LS.MultiTerm'
-- and take @p@ elements from the back of the list of @[LS.MTExpr]@ that occur before
-- the function, which we name @ps@, and take @q@ elements from the list of
-- @[LS.MTExpr]@ that occur after the function name, called @qs@.
--
-- Finally, we replace the function application by @[f] ++ ps ++ qs@.
fixFixity ctx rnMultiTerms = case ctx.multiTermContextFunctionCall of
Nothing -> pure rnMultiTerms
Just fnName -> do
funcInfo <- lookupExistingFunction fnName
let
(preNum, postNum) = funcInfo ^. funcArity
(lhs, fnExpr, rhs) <- findFunctionApplication fnName rnMultiTerms
(leftNonArgs, leftArgs) <- processLhs fnName preNum lhs
(rightNonArgs, rightArgs) <- processRhs fnName postNum rhs
pure $ reverse leftNonArgs <> [fnExpr] <> leftArgs <> rightArgs <> rightNonArgs
mustBeFunctionApplication = exprLvl == SubExpression
mustBeFunctionDeclaration = not mustBeFunctionApplication

isProjection var sels = do
varName <- isVariableName var
selNames <- traverse isSelectorName sels
Just (varName, selNames)

isFunctionApp var args = do
varName <- isFunctionName var
Just (varName, args)

findFunctionApplication fnName rnMultiTerms = do
isFunctionDecl var args = do
varName <- isFunctionName var
argNames <- traverse isVariableName args
Just (varName, argNames)

-- | Fixing the arity of a function requires us to rewrite infix and postfix
-- notation to prefix notation.
--
-- To rewrite a function application, we first gather the 'FuncInfo' to
-- find the declared arity of the function. Say the arity of the function @f@ is
-- given by the tuple @(p, q)@ where @p@ is the number of arguments before the
-- function name and @q@ is the number of arguments after the function name.
-- This captures functions applied in prefix, infix and postfix notation.
-- Then, we find the index of the function name as it occurs in the 'LS.MultiTerm'
-- and take @p@ elements from the back of the list of @[LS.MTExpr]@ that occur before
-- the function, which we name @ps@, and take @q@ elements from the list of
-- @[LS.MTExpr]@ that occur after the function name, called @qs@.
--
-- Finally, we replace the function application by @[f] ++ ps ++ qs@.
fixFixity :: MultiTermContext -> [RnExpr] -> Renamer [RnExpr]
fixFixity ctx rnMultiTerms = case ctx.multiTermContextFunctionCall of
Nothing -> pure rnMultiTerms
Just fnName -> do
funcInfo <- lookupExistingFunction fnName
let
(preArgs, postArgsWithName) = List.break (== (RnExprName fnName)) rnMultiTerms
case postArgsWithName of
[] -> throwError $ FixArityFunctionNotFound fnName rnMultiTerms
(fnExpr : postArgs) -> pure (preArgs, fnExpr, postArgs)

processLhs name n lhs = do
case safeSplitAt n (reverse lhs) of
Nothing ->
throwError $ ArityErrorLeft n name lhs
Just (args, nonArgs) -> pure (reverse nonArgs, reverse args)

processRhs name n rhs = do
case safeSplitAt n rhs of
Nothing ->
throwError $ ArityErrorRight n name rhs
Just (nonArgs, args) -> pure (nonArgs, args)

initialMultiTermContext =
MultiTermContext
{ multiTermContextInSelector = False
, multiTermContextFunctionCall = Nothing
}
(preNum, postNum) = funcInfo ^. funcArity
(lhs, fnExpr, rhs) <- findFunctionApplication fnName rnMultiTerms
(leftNonArgs, leftArgs) <- processLhs fnName preNum lhs
(rightNonArgs, rightArgs) <- processRhs fnName postNum rhs
pure $ reverse leftNonArgs <> [fnExpr] <> leftArgs <> rightArgs <> rightNonArgs

-- | Split a list of renamed terms around the first occurrence of the given
-- function name.
--
-- Returns the terms before the name, the name expression itself, and the
-- terms after it. Throws 'FixArityFunctionNotFound' if no element equals
-- @'RnExprName' fnName@.
findFunctionApplication :: RnName -> [RnExpr] -> Renamer ([RnExpr], RnExpr, [RnExpr])
findFunctionApplication fnName rnMultiTerms =
  case List.break isFunctionName' rnMultiTerms of
    (before, fnExpr : after) -> pure (before, fnExpr, after)
    (_, []) -> throwError $ FixArityFunctionNotFound fnName rnMultiTerms
 where
  -- Matches exactly the expression that names the function we are looking for.
  isFunctionName' candidate = candidate == RnExprName fnName

-- | Separate the terms preceding a function name into left-over terms and
-- the function's leading arguments.
--
-- The list is reversed before it is handed to 'safeSplitAt', so the split is
-- taken from the end of @lhs@ (the side adjacent to the function name); both
-- halves are reversed again before they are returned as @(nonArgs, args)@.
-- Throws 'ArityErrorLeft' when 'safeSplitAt' yields 'Nothing'
-- (presumably because fewer than @n@ terms are available — confirm against
-- 'safeSplitAt').
processLhs :: RnName -> Int -> [RnExpr] -> Renamer ([RnExpr], [RnExpr])
processLhs name n lhs =
  maybe arityError restoreOrder (safeSplitAt n (reverse lhs))
 where
  arityError = throwError $ ArityErrorLeft n name lhs
  restoreOrder (revArgs, revNonArgs) = pure (reverse revNonArgs, reverse revArgs)

-- | Separate the terms following a function name into the function's
-- trailing arguments and the left-over terms.
--
-- The first @n@ terms of @rhs@ (the ones adjacent to the function name) are
-- the arguments; whatever follows them is left over. The result is returned
-- as @(nonArgs, args)@, the same order in which 'processLhs' reports its
-- halves.
--
-- Throws 'ArityErrorRight' when 'safeSplitAt' yields 'Nothing'
-- (presumably because fewer than @n@ terms are available — confirm against
-- 'safeSplitAt').
processRhs :: RnName -> Int -> [RnExpr] -> Renamer ([RnExpr], [RnExpr])
processRhs name n rhs =
  case safeSplitAt n rhs of
    Nothing ->
      throwError $ ArityErrorRight n name rhs
    -- The length-@n@ prefix holds the arguments, mirroring how 'processLhs'
    -- reads the first component of 'safeSplitAt'. Previously the two halves
    -- were labelled the other way round, so any surplus terms were emitted
    -- ahead of the actual arguments.
    Just (args, nonArgs) -> pure (nonArgs, args)

-- | The 'MultiTermContext' a multi-term starts out with: no function call
-- has been recorded yet and we are not inside a selector.
initialMultiTermContext :: MultiTermContext
initialMultiTermContext =
  MultiTermContext
    { multiTermContextFunctionCall = Nothing
    , multiTermContextInSelector = False
    }

-- | Rename a single 'LS.MTExpr' to a 'RnExpr'.
renameMultiTermExpression :: Tracer Log -> MultiTermContext -> LS.MTExpr -> Renamer (RnExpr, MultiTermContext)
Expand Down Expand Up @@ -754,6 +832,7 @@ renameMultiTermExpression tracer ctx = \case
where
-- There is no doubt this is a text literal, if it is enclosed in quotes.
-- Strips away the quotes.
-- TODO: this is lossy, we can never rebuild the exact AST with this.
isTextLiteral t = do
('"', t') <- uncons t
(t'', '"') <- unsnoc t'
Expand All @@ -775,6 +854,7 @@ data RenamerError
| UnexpectedRnNameNotFound RnName
| InsertNameUnexpectedType RnNameType RnNameType
| LookupOrInsertNameUnexpectedType RnNameType RnNameType
| UnknownTermStructure [RnExpr]
| AssertErr AssertionError
deriving (Show, Eq, Ord)

Expand Down Expand Up @@ -836,6 +916,8 @@ renderRenamerError = \case
<> Text.pack (show actual)
<> " but expected: "
<> Text.pack (show expected)
UnknownTermStructure terms ->
"Renamed the terms " <> Text.pack (show terms) <> " but failed to determine the program structure."
AssertErr err -> renderAssertionError err

renderAssertionError :: AssertionError -> Text.Text
Expand Down Expand Up @@ -976,7 +1058,8 @@ recordScopeTable act = do
prevUnique <- use scUniqueSupply
a <- act
scTable <- use scScopeTable
let scTableWithNewNames = filterScopeTable (\_ name -> name.rnUniqueId >= prevUnique) scTable
let
scTableWithNewNames = filterScopeTable (\_ name -> name.rnUniqueId >= prevUnique) scTable
pure (a, scTableWithNewNames)

recordScopeTable_ :: Renamer a -> Renamer ScopeTable
Expand Down
Loading