Skip to content

Commit 6026e57

Browse files
authored
[wgsl] Add more complete function calling support (gfx-rs#144)
* Add function calling support to wgsl frontend * Fix external namespace with multiple namespaces * changes after code review * Don't re-tokenize std_namespace every time
1 parent 5035362 commit 6026e57

File tree

3 files changed

+127
-47
lines changed

3 files changed

+127
-47
lines changed

src/front/wgsl.rs

Lines changed: 111 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ pub enum Error<'a> {
209209
ZeroStride,
210210
#[error("not a composite type: {0:?}")]
211211
NotCompositeType(crate::TypeInner),
212+
#[error("function redefinition: `{0}`")]
213+
FunctionRedefinition(&'a str),
212214
//MutabilityViolation(&'a str),
213215
// TODO: these could be replaced with more detailed errors
214216
#[error("other error")]
@@ -237,7 +239,7 @@ impl<'a> Lexer<'a> {
237239
self.clone().next()
238240
}
239241

240-
fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
242+
fn expect(&mut self, expected: Token<'_>) -> Result<(), Error<'a>> {
241243
let token = self.next();
242244
if token == expected {
243245
Ok(())
@@ -246,7 +248,7 @@ impl<'a> Lexer<'a> {
246248
}
247249
}
248250

249-
fn skip(&mut self, what: Token<'a>) -> bool {
251+
fn skip(&mut self, what: Token<'_>) -> bool {
250252
let (token, rest) = lex::consume_token(self.input);
251253
if token == what {
252254
self.input = rest;
@@ -296,7 +298,7 @@ impl<'a> Lexer<'a> {
296298
Ok(pair)
297299
}
298300

299-
fn take_until(&mut self, what: Token<'a>) -> Result<Lexer<'a>, Error<'a>> {
301+
fn take_until(&mut self, what: Token<'_>) -> Result<Lexer<'a>, Error<'a>> {
300302
let original_input = self.input;
301303
let initial_len = self.input.len();
302304
let mut used_len = 0;
@@ -449,14 +451,16 @@ pub struct ParseError<'a> {
449451
pub struct Parser {
450452
scopes: Vec<Scope>,
451453
lookup_type: FastHashMap<String, Handle<crate::Type>>,
452-
std_namespace: Option<String>,
454+
function_lookup: FastHashMap<String, Handle<crate::Function>>,
455+
std_namespace: Option<Vec<String>>,
453456
}
454457

455458
impl Parser {
456459
pub fn new() -> Self {
457460
Parser {
458461
scopes: Vec::new(),
459462
lookup_type: FastHashMap::default(),
463+
function_lookup: FastHashMap::default(),
460464
std_namespace: None,
461465
}
462466
}
@@ -543,6 +547,49 @@ impl Parser {
543547
}
544548
}
545549

550+
fn parse_function_call<'a>(
551+
&mut self,
552+
lexer: &Lexer<'a>,
553+
mut ctx: ExpressionContext<'a, '_, '_>,
554+
) -> Result<Option<(crate::Expression, Lexer<'a>)>, Error<'a>> {
555+
let mut lexer = lexer.clone();
556+
557+
let external_function = if let Some(std_namespaces) = self.std_namespace.as_deref() {
558+
std_namespaces.iter().all(|namespace| {
559+
lexer.skip(Token::Word(namespace)) && lexer.skip(Token::DoubleColon)
560+
})
561+
} else {
562+
false
563+
};
564+
565+
let origin = if external_function {
566+
let function = lexer.next_ident()?;
567+
crate::FunctionOrigin::External(function.to_string())
568+
} else if let Ok(function) = lexer.next_ident() {
569+
if let Some(&function) = self.function_lookup.get(function) {
570+
crate::FunctionOrigin::Local(function)
571+
} else {
572+
return Ok(None);
573+
}
574+
} else {
575+
return Ok(None);
576+
};
577+
578+
if !lexer.skip(Token::Paren('(')) {
579+
return Ok(None);
580+
}
581+
582+
let mut arguments = Vec::new();
583+
while !lexer.skip(Token::Paren(')')) {
584+
if !arguments.is_empty() {
585+
lexer.expect(Token::Separator(','))?;
586+
}
587+
let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?;
588+
arguments.push(arg);
589+
}
590+
Ok(Some((crate::Expression::Call { origin, arguments }, lexer)))
591+
}
592+
546593
fn parse_const_expression<'a>(
547594
&mut self,
548595
lexer: &mut Lexer<'a>,
@@ -654,22 +701,11 @@ impl Parser {
654701
self.scopes.pop();
655702
return Ok(*handle);
656703
}
657-
if self.std_namespace.as_deref() == Some(word) {
658-
lexer.expect(Token::DoubleColon)?;
659-
let name = lexer.next_ident()?;
660-
let mut arguments = Vec::new();
661-
lexer.expect(Token::Paren('('))?;
662-
while !lexer.skip(Token::Paren(')')) {
663-
if !arguments.is_empty() {
664-
lexer.expect(Token::Separator(','))?;
665-
}
666-
let arg = self.parse_general_expression(lexer, ctx.reborrow())?;
667-
arguments.push(arg);
668-
}
669-
crate::Expression::Call {
670-
origin: crate::FunctionOrigin::External(name.to_owned()),
671-
arguments,
672-
}
704+
if let Some((expr, new_lexer)) =
705+
self.parse_function_call(&backup, ctx.reborrow())?
706+
{
707+
*lexer = new_lexer;
708+
expr
673709
} else {
674710
*lexer = backup;
675711
let ty = self.parse_type_decl(lexer, ctx.types)?;
@@ -1295,6 +1331,7 @@ impl Parser {
12951331
lexer: &mut Lexer<'a>,
12961332
mut context: StatementContext<'a, '_, '_>,
12971333
) -> Result<Option<crate::Statement>, Error<'a>> {
1334+
let backup = lexer.clone();
12981335
match lexer.next() {
12991336
Token::Separator(';') => Ok(Some(crate::Statement::Empty)),
13001337
Token::Paren('}') => Ok(None),
@@ -1387,15 +1424,26 @@ impl Parser {
13871424
"continue" => crate::Statement::Continue,
13881425
ident => {
13891426
// assignment
1390-
let var_expr = context.lookup_ident.lookup(ident)?;
1391-
let left = self.parse_postfix(lexer, context.as_expression(), var_expr)?;
1392-
lexer.expect(Token::Operation('='))?;
1393-
let value =
1394-
self.parse_general_expression(lexer, context.as_expression())?;
1395-
lexer.expect(Token::Separator(';'))?;
1396-
crate::Statement::Store {
1397-
pointer: left,
1398-
value,
1427+
if let Some(&var_expr) = context.lookup_ident.get(ident) {
1428+
let left =
1429+
self.parse_postfix(lexer, context.as_expression(), var_expr)?;
1430+
lexer.expect(Token::Operation('='))?;
1431+
let value =
1432+
self.parse_general_expression(lexer, context.as_expression())?;
1433+
lexer.expect(Token::Separator(';'))?;
1434+
crate::Statement::Store {
1435+
pointer: left,
1436+
value,
1437+
}
1438+
} else if let Some((expr, new_lexer)) =
1439+
self.parse_function_call(&backup, context.as_expression())?
1440+
{
1441+
*lexer = new_lexer;
1442+
context.expressions.append(expr);
1443+
lexer.expect(Token::Separator(';'))?;
1444+
crate::Statement::Empty
1445+
} else {
1446+
return Err(Error::UnknownIdent(ident));
13991447
}
14001448
}
14011449
};
@@ -1459,35 +1507,45 @@ impl Parser {
14591507
} else {
14601508
Some(self.parse_type_decl(lexer, &mut module.types)?)
14611509
};
1510+
1511+
let fun_handle = module.functions.append(crate::Function {
1512+
name: Some(fun_name.to_string()),
1513+
parameter_types,
1514+
return_type,
1515+
global_usage: Vec::new(),
1516+
local_variables: Arena::new(),
1517+
expressions,
1518+
body: Vec::new(),
1519+
});
1520+
if self
1521+
.function_lookup
1522+
.insert(fun_name.to_string(), fun_handle)
1523+
.is_some()
1524+
{
1525+
return Err(Error::FunctionRedefinition(fun_name));
1526+
}
1527+
let fun = module.functions.get_mut(fun_handle);
1528+
14621529
// read body
1463-
let mut local_variables = Arena::new();
14641530
let mut typifier = Typifier::new();
1465-
let body = self.parse_block(
1531+
fun.body = self.parse_block(
14661532
lexer,
14671533
StatementContext {
14681534
lookup_ident: &mut lookup_ident,
14691535
typifier: &mut typifier,
1470-
variables: &mut local_variables,
1471-
expressions: &mut expressions,
1536+
variables: &mut fun.local_variables,
1537+
expressions: &mut fun.expressions,
14721538
types: &mut module.types,
14731539
constants: &mut module.constants,
14741540
global_vars: &module.global_variables,
14751541
},
14761542
)?;
14771543
// done
1478-
let global_usage = crate::GlobalUse::scan(&expressions, &body, &module.global_variables);
1544+
fun.global_usage =
1545+
crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables);
14791546
self.scopes.pop();
14801547

1481-
let fun = crate::Function {
1482-
name: Some(fun_name.to_owned()),
1483-
parameter_types,
1484-
return_type,
1485-
global_usage,
1486-
local_variables,
1487-
expressions,
1488-
body,
1489-
};
1490-
Ok(module.functions.append(fun))
1548+
Ok(fun_handle)
14911549
}
14921550

14931551
fn parse_global_decl<'a>(
@@ -1554,10 +1612,16 @@ impl Parser {
15541612
other => return Err(Error::Unexpected(other)),
15551613
};
15561614
lexer.expect(Token::Word("as"))?;
1557-
let namespace = lexer.next_ident()?;
1558-
lexer.expect(Token::Separator(';'))?;
1615+
let mut namespaces = Vec::new();
1616+
loop {
1617+
namespaces.push(lexer.next_ident()?.to_owned());
1618+
if lexer.skip(Token::Separator(';')) {
1619+
break;
1620+
}
1621+
lexer.expect(Token::DoubleColon)?;
1622+
}
15591623
match path {
1560-
"GLSL.std.450" => self.std_namespace = Some(namespace.to_owned()),
1624+
"GLSL.std.450" => self.std_namespace = Some(namespaces),
15611625
_ => return Err(Error::UnknownImport(path)),
15621626
}
15631627
self.scopes.pop();

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ pub enum DerivativeAxis {
423423
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
424424
pub enum FunctionOrigin {
425425
Local(Handle<Function>),
426+
// External {
427+
// namespace: String, // Maybe this should be a handle to a namespace Arena?
428+
// function: String,
429+
// },
426430
External(String),
427431
}
428432

test-data/function.wgsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import "GLSL.std.450" as std::glsl;
2+
3+
fn test_function(test: f32) -> f32 {
4+
return test;
5+
}
6+
7+
fn main_vert() -> void {
8+
var foo: f32 = std::glsl::distance(0.0, 1.0);
9+
var test: f32 = test_function(1.0);
10+
}
11+
12+
entry_point vertex as "main" = main_vert;

0 commit comments

Comments
 (0)