Skip to content

Commit

Permalink
langdefs: use case expression for function bodies (#113)
Browse files Browse the repository at this point in the history
## Summary

Instead of a list of assignments of the form `<f> <pattern>* = <body>`,
the body of a meta-function must now be a `case` expression.

## Details

A function such as:
```
function f, x -> x:
  f(z) = 1
  f(1) = 2
```
is now written as:
```
function f, x -> x:
  case _
  of z: 1
  of 1: 2 
```

The two reasons for this change are:

1. there's no point in repeatedly specifying the name of a function:
   it's the only valid identifier that may appear in before the
   parenthesis. It's presence only introduces visual noise and makes
   it harder to rename functions

2. the assignment based syntax is a bit trickier to parse, and it's
   arguably not very intuitive. For people already familiar with
   NimSkull and its syntax, what the new `case` expression syntax does
   is likely easier to understand

Using a `case` expression also makes `else` available to be used for
representing the fallthrough case (the new parser already supports
this).

Two things that might be slightly confusing are that:
1. contrary to the NimSkull semantics, the `of` branches in the
   DSL are ordered
2. in the DSL, the body of the `of` branches also affects whether the
  branch is actually "picked"

The macro's parsing logic is updated to handle the new syntax, and the
source language specification is changed to make use of it.
  • Loading branch information
zerbina authored Feb 1, 2025
1 parent 0944c9f commit 1d2e064
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 78 deletions.
107 changes: 62 additions & 45 deletions languages/source.nim
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,18 @@ const lang* = language:

function desugar, expr -> e:
# FIXME: the sub-expressions need to be desugared too!
desugar And(expr_1, expr_2) =
case _
of And(expr_1, expr_2):
If(expr_1, expr_2, "false")
desugar Or(expr_1, expr_2) =
of Or(expr_1, expr_2):
If(expr_1, "true", expr_2)
desugar Decl(x_1, expr_1) =
of Decl(x_1, expr_1):
Let(x_1, expr_1, TupleCons())
desugar Exprs(*expr_1, Decl(x_1, expr_2), +expr_3) =
of Exprs(*expr_1, Decl(x_1, expr_2), +expr_3):
Exprs(...expr_1, Let(x_1, expr_2, Exprs(...expr_3)))
desugar If(expr_1, expr_2) =
of If(expr_1, expr_2):
If(expr_1, expr_2, TupleCons())
desugar If(Exprs(*expr_1, expr_2), expr_3, expr_4) =
of If(Exprs(*expr_1, expr_2), expr_3, expr_4):
Exprs(...expr_1, If(expr_2, expr_3, expr_4))

## Type Relations
Expand Down Expand Up @@ -174,22 +175,25 @@ const lang* = language:
function common, (typ, typ) -> typ:
## Computes the closest common ancestor type for a type pair. The function
## is not total, as not all two types have a common ancestor type.
common(typ_1, typ_1) = typ_1
common(typ_1, typ_2) = block:
case _
of typ_1, typ_1: typ_1
of typ_1, typ_2:
premise typ_1 <: typ_2
typ_2
common(typ_1, typ_2) = block:
of typ_1, typ_2:
premise typ_2 <: typ_1
typ_1

function uniqueTypes, (+typ) -> z:
## Returns true (1) if all types are unique (in the sense of type
## equality), false (0) otherwise.
uniqueTypes(typ) = 1
uniqueTypes(typ_1, *typ, typ_2, *typ) = block:
case _
of typ: 1
of typ_1, *typ, typ_2, *typ:
premise eq(typ_1, typ_2)
0
uniqueTypes(typ_1, +typ_2) = uniqueTypes(...typ_2)
of typ_1, +typ_2:
uniqueTypes(...typ_2)

inductive ttypes(inp, inp, out), "$1 \\vdash_{\\tau} $2 : $3":
axiom "S-void-type", C, VoidTy(), "void"
Expand Down Expand Up @@ -235,9 +239,10 @@ const lang* = language:
## Returns true (1) when all symbols are unique, false (0) otherwise.
# there are no "real" booleans in the meta-language, hence 1 and 0 being
# used
unique(x) = 1
unique(x_1, *x, x_1, *x) = 0
unique(x_1, x_2, *x_3) = unique(x_2, ...x_3)
case _
of x: 1
of x_1, *x, x_1, *x: 0
of x_1, x_2, *x_3: unique(x_2, ...x_3)

inductive types(inp, inp, out), "$1 \\vdash $2 : $3":
axiom "S-int", C, n, IntTy()
Expand Down Expand Up @@ -492,90 +497,101 @@ const lang* = language:
function substitute, (e, +[x, e]) -> (e, +[x, e]):
## The substitution function, which handles binding values/expressions to
## identifiers. Identifiers cannot be shadowed.
substitute(any_1(*e_1), *any_2) = block:
case _
of any_1(*e_1), *any_2:
where *e_2, ...substitute(e_1, ...any_2)
term any_1(...e_2)
# the actual substitution:
substitute(x_1, *[!x_1, _], [x_1, e_1], *[!x_1, _]) = e_1
substitute(e_1, +any) = e_1 # nothing to replace
of x_1, *[!x_1, _], [x_1, e_1], *[!x_1, _]:
e_1 # the actual substitution
of e_1, +any:
e_1 # nothing to replace

function trunc, r -> n:
## Round towards zero.

function intAdd, (n, n) -> n:
intAdd(n_1, n_2) = block:
case _
of n_1, n_2:
where n_3, n_1 + n_2
condition -(2 ^ 63) <= n_3
condition n_3 < (2 ^ 63)
n_3
intAdd(n_1, n_2) = {}
else: {}

function intSub, (n, n) -> n:
intSub(n_1, n_2) = block:
case _
of n_1, n_2:
where n_3, n_1 - n_2
condition -(2 ^ 63) <= n_3
condition n_3 < (2 ^ 63)
n_3
intSub(n_1, n_2) = {}
else: {}

function intMul, (n, n) -> n:
intMul(n_1, n_2) = block:
case _
of n_1, n_2:
where n_3, n_1 * n_2
condition -(2 ^ 63) <= n_3
condition n_3 < (2 ^ 63)
n_3
intMul(n_1, n_2) = {}
else: {}

function intDiv, (n, n) -> n:
intDiv(n_1, 0) = {}
intDiv(n_1, n_2) = block:
case _
of n_1, 0: {}
of n_1, n_2:
condition n_1 == (-2 ^ 63)
condition n_2 == -1
{}
intDiv(n_1, n_2) = trunc(n_1 / n_2)
of n_1, n_2: trunc(n_1 / n_2)

function intMod, (n, n) -> n:
intMod(n_1, 0) = {}
intMod(n_1, n_2) = n_1 - (n_2 * trunc(n_1 / n_2))
case _
of n_1, 0: {}
of n_1, n_2: n_1 - (n_2 * trunc(n_1 / n_2))

function float_add, (r, r) -> r:
## XXX: not defined
function float_sub, (r, r) -> r:
## XXX: not defined

function valEq, (val, val) -> val:
valEq(val_1, val_1) = "true"
valEq(val_1, val_2) = "false" # otherwise
case _
of val_1, val_1: "true"
else: "false"

function lt, (val, val) -> val:
lt(n_1, n_2) = block:
case _
of n_1, n_2:
condition n_1 < n_2
"true"
lt(r_1, r_2) = block:
of r_1, r_2:
condition r_1 < r_2
"true"
lt(val_1, val_2) = "false" # otherwise
else: "false"

function lessEqual, (val, val) -> val:
lessEqual(val_1, val_2) = block:
case _
of val_1, val_2:
where "true", valEq(val_1, val_2)
"true"
lessEqual(val_1, val_2) = block:
of val_1, val_2:
where "true", lt(val_1, val_2)
"true"
lessEqual(val_1, val_2) = "false" # otherwise
else: "false"

# TODO: the floating-point operations need to be defined according to the
# IEEE 754.2008 standard

function copy, (C, val) -> val:
## The `copy` function takes a context and value and maps them to a value
## that is neither a location nor contains any.
copy(C, c_1) = c_1
copy(C_1, l_1) = copy(C_1, C_1.locs(l_1))
copy(C, `proc`(typ_r, *[x_1, typ_1], e_1)) = `proc`(typ_r, ...[x_1, typ_1], e_1)
copy(C, `array`(*val_1)) = `array`(...val_1)
copy(C, `tuple`(*val_1)) = `tuple`(...val_1)
case _
of C, c_1: c_1
of C_1, l_1: copy(C_1, C_1.locs(l_1))
of C, `proc`(typ_r, *[x_1, typ_1], e_1): `proc`(typ_r, ...[x_1, typ_1], e_1)
of C, `array`(*val_1): `array`(...val_1)
of C, `tuple`(*val_1): `tuple`(...val_1)

function utf8_bytes, x -> (+ch,):
# TODO: the function is mostly self-explanatory, but it should be defined in
Expand All @@ -584,8 +600,9 @@ const lang* = language:

function len, (val) -> z:
## Computes the number of elements in an array value.
len(`array`()) = 0
len(`array`(val_1, *val_2)) = 1 + len(...val_2)
case _
of `array`(): 0
of `array`(val_1, *val_2): 1 + len(...val_2)

## Evaluation Contexts
## ~~~~~~~~~~~~~~~~~~~
Expand Down
83 changes: 50 additions & 33 deletions spec/langdefs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -519,41 +519,58 @@ proc parseNonTerminal(lookup; body: NimNode): TreeNode =

proc parseFunction(lookup; def: NimNode): Function =
result = Function(name: $def[1])
for it in def[3].items:
case it.kind
of nnkCommentStmt:
# TODO: implement
discard
of nnkAsgn:
let left = it[0]
left.expectKind nnkCallKinds
if not left[0].eqIdent(result.name):
error(fmt"name must be '{result.name}'", left[0])

var
impl = FunctionImpl()
c = Context()

for i in 1..<left.len:
impl.params.add parsePattern(lookup, c, left[i])

case it[1].kind
of nnkBlockStmt:
let body = it[1][1]
body.expectKind nnkStmtList
body.expectMinLen 1
# everything prior to the result must be a predicate
for j in 0..<body.len-1:
body[j].expectKind nnkCallKinds
let x = body[j]
impl.predicates.add parsePredicate(lookup, c, x)
impl.output = parseExpr(lookup, c, body[^1])
var i = 0
# skip/parse the leading comment statements
while i < def[3].len and def[3][i].kind == nnkCommentStmt:
# TODO: handle
inc i

if i == def[3].len:
# empty function; this is currently allowed
return

let signature = def[2]
signature.expectKind nnkInfix
signature.expectLen 3

let stmt = def[3][i]
stmt.expectKind nnkCaseStmt
# ignore the selector expression
stmt.expectMinLen 2
for i in 1..<stmt.len:
let branch = stmt[i]
var
impl = FunctionImpl()
c = Context()
case branch.kind
of nnkOfBranch:
for j in 0..<branch.len-1:
impl.params.add parsePattern(lookup, c, branch[j])
of nnkElse, nnkElseExpr:
# take the patterns from the signature
if signature[1].kind == nnkTupleConstr:
for it in signature[1].items:
impl.params.add parsePattern(lookup, c, it)
else:
impl.output = parseExpr(lookup, c, it[1])
result.impls.add impl
# it's only a single argument
impl.params.add parsePattern(lookup, c, signature[1])
else:
error("expected 'else' or 'of' branch", branch)

case branch[^1].kind
of nnkStmtList:
let body = branch[^1]
body.expectKind nnkStmtList
body.expectMinLen 1
# everything prior to the result must be a predicate
for j in 0..<body.len-1:
body[j].expectKind nnkCallKinds
let x = body[j]
impl.predicates.add parsePredicate(lookup, c, x)
impl.output = parseExpr(lookup, c, body[^1])
else:
error(fmt"expected assignment of the form `{result.name}(...) = ...`",
it)
impl.output = parseExpr(lookup, c, branch[^1])
result.impls.add impl

proc parseRelationHeader(n: NimNode): Relation =
n.expectLen 4
Expand Down

0 comments on commit 1d2e064

Please sign in to comment.