Skip to content

Commit a3c53e6

Browse files
committed
Shader Processor: process imported shader (#3290)
# Objective - I want to be able to use `#ifdef` and other processor directives in an imported shader ## Solution - Process imported shader strings Co-authored-by: François <[email protected]>
1 parent b5d7ff2 commit a3c53e6

File tree

1 file changed

+157
-36
lines changed
  • crates/bevy_render/src/render_resource

1 file changed

+157
-36
lines changed

crates/bevy_render/src/render_resource/shader.rs

Lines changed: 157 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -360,16 +360,16 @@ impl ShaderProcessor {
360360
}
361361
};
362362

363-
let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
363+
let shader_defs_unique = HashSet::<String>::from_iter(shader_defs.iter().cloned());
364364
let mut scopes = vec![true];
365365
let mut final_string = String::new();
366366
for line in shader_str.split('\n') {
367367
if let Some(cap) = self.ifdef_regex.captures(line) {
368368
let def = cap.get(1).unwrap();
369-
scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str()));
369+
scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str()));
370370
} else if let Some(cap) = self.ifndef_regex.captures(line) {
371371
let def = cap.get(1).unwrap();
372-
scopes.push(*scopes.last().unwrap() && !shader_defs.contains(def.as_str()));
372+
scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str()));
373373
} else if self.else_regex.is_match(line) {
374374
let mut is_parent_scope_truthy = true;
375375
if scopes.len() > 1 {
@@ -388,19 +388,32 @@ impl ShaderProcessor {
388388
.captures(line)
389389
{
390390
let import = ShaderImport::AssetPath(cap.get(1).unwrap().as_str().to_string());
391-
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
391+
self.apply_import(
392+
import_handles,
393+
shaders,
394+
&import,
395+
shader,
396+
shader_defs,
397+
&mut final_string,
398+
)?;
392399
} else if let Some(cap) = SHADER_IMPORT_PROCESSOR
393400
.import_custom_path_regex
394401
.captures(line)
395402
{
396403
let import = ShaderImport::Custom(cap.get(1).unwrap().as_str().to_string());
397-
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
404+
self.apply_import(
405+
import_handles,
406+
shaders,
407+
&import,
408+
shader,
409+
shader_defs,
410+
&mut final_string,
411+
)?;
398412
} else if *scopes.last().unwrap() {
399413
final_string.push_str(line);
400414
final_string.push('\n');
401415
}
402416
}
403-
404417
final_string.pop();
405418

406419
if scopes.len() != 1 {
@@ -417,45 +430,51 @@ impl ShaderProcessor {
417430
}
418431
}
419432
}
420-
}
421433

422-
fn apply_import(
423-
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
424-
shaders: &HashMap<Handle<Shader>, Shader>,
425-
import: &ShaderImport,
426-
shader: &Shader,
427-
final_string: &mut String,
428-
) -> Result<(), ProcessShaderError> {
429-
let imported_shader = import_handles
430-
.get(import)
431-
.and_then(|handle| shaders.get(handle))
432-
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
433-
match &shader.source {
434-
Source::Wgsl(_) => {
435-
if let Source::Wgsl(import_source) = &imported_shader.source {
436-
final_string.push_str(import_source);
437-
} else {
438-
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
434+
fn apply_import(
435+
&self,
436+
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
437+
shaders: &HashMap<Handle<Shader>, Shader>,
438+
import: &ShaderImport,
439+
shader: &Shader,
440+
shader_defs: &[String],
441+
final_string: &mut String,
442+
) -> Result<(), ProcessShaderError> {
443+
let imported_shader = import_handles
444+
.get(import)
445+
.and_then(|handle| shaders.get(handle))
446+
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
447+
let imported_processed =
448+
self.process(imported_shader, shader_defs, shaders, import_handles)?;
449+
450+
match &shader.source {
451+
Source::Wgsl(_) => {
452+
if let ProcessedShader::Wgsl(import_source) = &imported_processed {
453+
final_string.push_str(import_source);
454+
} else {
455+
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
456+
}
439457
}
440-
}
441-
Source::Glsl(_, _) => {
442-
if let Source::Glsl(import_source, _) = &imported_shader.source {
443-
final_string.push_str(import_source);
444-
} else {
445-
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
458+
Source::Glsl(_, _) => {
459+
if let ProcessedShader::Glsl(import_source, _) = &imported_processed {
460+
final_string.push_str(import_source);
461+
} else {
462+
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
463+
}
464+
}
465+
Source::SpirV(_) => {
466+
return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
446467
}
447468
}
448-
Source::SpirV(_) => {
449-
return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
450-
}
451-
}
452469

453-
Ok(())
470+
Ok(())
471+
}
454472
}
455473

456474
#[cfg(test)]
457475
mod tests {
458-
use bevy_asset::Handle;
476+
use bevy_asset::{Handle, HandleUntyped};
477+
use bevy_reflect::TypeUuid;
459478
use bevy_utils::HashMap;
460479
use naga::ShaderStage;
461480

@@ -1081,4 +1100,106 @@ fn vertex(
10811100
.unwrap();
10821101
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
10831102
}
1103+
1104+
#[test]
1105+
fn process_import_ifdef() {
1106+
#[rustfmt::skip]
1107+
const FOO: &str = r"
1108+
#ifdef IMPORT_MISSING
1109+
fn in_import_missing() { }
1110+
#endif
1111+
#ifdef IMPORT_PRESENT
1112+
fn in_import_present() { }
1113+
#endif
1114+
";
1115+
#[rustfmt::skip]
1116+
const INPUT: &str = r"
1117+
#import FOO
1118+
#ifdef MAIN_MISSING
1119+
fn in_main_missing() { }
1120+
#endif
1121+
#ifdef MAIN_PRESENT
1122+
fn in_main_present() { }
1123+
#endif
1124+
";
1125+
#[rustfmt::skip]
1126+
const EXPECTED: &str = r"
1127+
1128+
fn in_import_present() { }
1129+
fn in_main_present() { }
1130+
";
1131+
let processor = ShaderProcessor::default();
1132+
let mut shaders = HashMap::default();
1133+
let mut import_handles = HashMap::default();
1134+
let foo_handle = Handle::<Shader>::default();
1135+
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
1136+
import_handles.insert(
1137+
ShaderImport::Custom("FOO".to_string()),
1138+
foo_handle.clone_weak(),
1139+
);
1140+
let result = processor
1141+
.process(
1142+
&Shader::from_wgsl(INPUT),
1143+
&["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()],
1144+
&shaders,
1145+
&import_handles,
1146+
)
1147+
.unwrap();
1148+
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
1149+
}
1150+
1151+
#[test]
1152+
fn process_import_in_import() {
1153+
#[rustfmt::skip]
1154+
const BAR: &str = r"
1155+
#ifdef DEEP
1156+
fn inner_import() { }
1157+
#endif
1158+
";
1159+
const FOO: &str = r"
1160+
#import BAR
1161+
fn import() { }
1162+
";
1163+
#[rustfmt::skip]
1164+
const INPUT: &str = r"
1165+
#import FOO
1166+
fn in_main() { }
1167+
";
1168+
#[rustfmt::skip]
1169+
const EXPECTED: &str = r"
1170+
1171+
1172+
fn inner_import() { }
1173+
fn import() { }
1174+
fn in_main() { }
1175+
";
1176+
let processor = ShaderProcessor::default();
1177+
let mut shaders = HashMap::default();
1178+
let mut import_handles = HashMap::default();
1179+
{
1180+
let bar_handle = Handle::<Shader>::default();
1181+
shaders.insert(bar_handle.clone_weak(), Shader::from_wgsl(BAR));
1182+
import_handles.insert(
1183+
ShaderImport::Custom("BAR".to_string()),
1184+
bar_handle.clone_weak(),
1185+
);
1186+
}
1187+
{
1188+
let foo_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1).typed();
1189+
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
1190+
import_handles.insert(
1191+
ShaderImport::Custom("FOO".to_string()),
1192+
foo_handle.clone_weak(),
1193+
);
1194+
}
1195+
let result = processor
1196+
.process(
1197+
&Shader::from_wgsl(INPUT),
1198+
&["DEEP".to_string()],
1199+
&shaders,
1200+
&import_handles,
1201+
)
1202+
.unwrap();
1203+
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
1204+
}
10841205
}

0 commit comments

Comments
 (0)