@@ -360,16 +360,16 @@ impl ShaderProcessor {
360
360
}
361
361
} ;
362
362
363
- let shader_defs = HashSet :: < String > :: from_iter ( shader_defs. iter ( ) . cloned ( ) ) ;
363
+ let shader_defs_unique = HashSet :: < String > :: from_iter ( shader_defs. iter ( ) . cloned ( ) ) ;
364
364
let mut scopes = vec ! [ true ] ;
365
365
let mut final_string = String :: new ( ) ;
366
366
for line in shader_str. split ( '\n' ) {
367
367
if let Some ( cap) = self . ifdef_regex . captures ( line) {
368
368
let def = cap. get ( 1 ) . unwrap ( ) ;
369
- scopes. push ( * scopes. last ( ) . unwrap ( ) && shader_defs . contains ( def. as_str ( ) ) ) ;
369
+ scopes. push ( * scopes. last ( ) . unwrap ( ) && shader_defs_unique . contains ( def. as_str ( ) ) ) ;
370
370
} else if let Some ( cap) = self . ifndef_regex . captures ( line) {
371
371
let def = cap. get ( 1 ) . unwrap ( ) ;
372
- scopes. push ( * scopes. last ( ) . unwrap ( ) && !shader_defs . contains ( def. as_str ( ) ) ) ;
372
+ scopes. push ( * scopes. last ( ) . unwrap ( ) && !shader_defs_unique . contains ( def. as_str ( ) ) ) ;
373
373
} else if self . else_regex . is_match ( line) {
374
374
let mut is_parent_scope_truthy = true ;
375
375
if scopes. len ( ) > 1 {
@@ -388,19 +388,32 @@ impl ShaderProcessor {
388
388
. captures ( line)
389
389
{
390
390
let import = ShaderImport :: AssetPath ( cap. get ( 1 ) . unwrap ( ) . as_str ( ) . to_string ( ) ) ;
391
- apply_import ( import_handles, shaders, & import, shader, & mut final_string) ?;
391
+ self . apply_import (
392
+ import_handles,
393
+ shaders,
394
+ & import,
395
+ shader,
396
+ shader_defs,
397
+ & mut final_string,
398
+ ) ?;
392
399
} else if let Some ( cap) = SHADER_IMPORT_PROCESSOR
393
400
. import_custom_path_regex
394
401
. captures ( line)
395
402
{
396
403
let import = ShaderImport :: Custom ( cap. get ( 1 ) . unwrap ( ) . as_str ( ) . to_string ( ) ) ;
397
- apply_import ( import_handles, shaders, & import, shader, & mut final_string) ?;
404
+ self . apply_import (
405
+ import_handles,
406
+ shaders,
407
+ & import,
408
+ shader,
409
+ shader_defs,
410
+ & mut final_string,
411
+ ) ?;
398
412
} else if * scopes. last ( ) . unwrap ( ) {
399
413
final_string. push_str ( line) ;
400
414
final_string. push ( '\n' ) ;
401
415
}
402
416
}
403
-
404
417
final_string. pop ( ) ;
405
418
406
419
if scopes. len ( ) != 1 {
@@ -417,45 +430,51 @@ impl ShaderProcessor {
417
430
}
418
431
}
419
432
}
420
- }
421
433
422
- fn apply_import (
423
- import_handles : & HashMap < ShaderImport , Handle < Shader > > ,
424
- shaders : & HashMap < Handle < Shader > , Shader > ,
425
- import : & ShaderImport ,
426
- shader : & Shader ,
427
- final_string : & mut String ,
428
- ) -> Result < ( ) , ProcessShaderError > {
429
- let imported_shader = import_handles
430
- . get ( import)
431
- . and_then ( |handle| shaders. get ( handle) )
432
- . ok_or_else ( || ProcessShaderError :: UnresolvedImport ( import. clone ( ) ) ) ?;
433
- match & shader. source {
434
- Source :: Wgsl ( _) => {
435
- if let Source :: Wgsl ( import_source) = & imported_shader. source {
436
- final_string. push_str ( import_source) ;
437
- } else {
438
- return Err ( ProcessShaderError :: MismatchedImportFormat ( import. clone ( ) ) ) ;
434
+ fn apply_import (
435
+ & self ,
436
+ import_handles : & HashMap < ShaderImport , Handle < Shader > > ,
437
+ shaders : & HashMap < Handle < Shader > , Shader > ,
438
+ import : & ShaderImport ,
439
+ shader : & Shader ,
440
+ shader_defs : & [ String ] ,
441
+ final_string : & mut String ,
442
+ ) -> Result < ( ) , ProcessShaderError > {
443
+ let imported_shader = import_handles
444
+ . get ( import)
445
+ . and_then ( |handle| shaders. get ( handle) )
446
+ . ok_or_else ( || ProcessShaderError :: UnresolvedImport ( import. clone ( ) ) ) ?;
447
+ let imported_processed =
448
+ self . process ( imported_shader, shader_defs, shaders, import_handles) ?;
449
+
450
+ match & shader. source {
451
+ Source :: Wgsl ( _) => {
452
+ if let ProcessedShader :: Wgsl ( import_source) = & imported_processed {
453
+ final_string. push_str ( import_source) ;
454
+ } else {
455
+ return Err ( ProcessShaderError :: MismatchedImportFormat ( import. clone ( ) ) ) ;
456
+ }
439
457
}
440
- }
441
- Source :: Glsl ( _, _) => {
442
- if let Source :: Glsl ( import_source, _) = & imported_shader. source {
443
- final_string. push_str ( import_source) ;
444
- } else {
445
- return Err ( ProcessShaderError :: MismatchedImportFormat ( import. clone ( ) ) ) ;
458
+ Source :: Glsl ( _, _) => {
459
+ if let ProcessedShader :: Glsl ( import_source, _) = & imported_processed {
460
+ final_string. push_str ( import_source) ;
461
+ } else {
462
+ return Err ( ProcessShaderError :: MismatchedImportFormat ( import. clone ( ) ) ) ;
463
+ }
464
+ }
465
+ Source :: SpirV ( _) => {
466
+ return Err ( ProcessShaderError :: ShaderFormatDoesNotSupportImports ) ;
446
467
}
447
468
}
448
- Source :: SpirV ( _) => {
449
- return Err ( ProcessShaderError :: ShaderFormatDoesNotSupportImports ) ;
450
- }
451
- }
452
469
453
- Ok ( ( ) )
470
+ Ok ( ( ) )
471
+ }
454
472
}
455
473
456
474
#[ cfg( test) ]
457
475
mod tests {
458
- use bevy_asset:: Handle ;
476
+ use bevy_asset:: { Handle , HandleUntyped } ;
477
+ use bevy_reflect:: TypeUuid ;
459
478
use bevy_utils:: HashMap ;
460
479
use naga:: ShaderStage ;
461
480
@@ -1081,4 +1100,106 @@ fn vertex(
1081
1100
. unwrap ( ) ;
1082
1101
assert_eq ! ( result. get_wgsl_source( ) . unwrap( ) , EXPECTED ) ;
1083
1102
}
1103
+
1104
+ #[ test]
1105
+ fn process_import_ifdef ( ) {
1106
+ #[ rustfmt:: skip]
1107
+ const FOO : & str = r"
1108
+ #ifdef IMPORT_MISSING
1109
+ fn in_import_missing() { }
1110
+ #endif
1111
+ #ifdef IMPORT_PRESENT
1112
+ fn in_import_present() { }
1113
+ #endif
1114
+ " ;
1115
+ #[ rustfmt:: skip]
1116
+ const INPUT : & str = r"
1117
+ #import FOO
1118
+ #ifdef MAIN_MISSING
1119
+ fn in_main_missing() { }
1120
+ #endif
1121
+ #ifdef MAIN_PRESENT
1122
+ fn in_main_present() { }
1123
+ #endif
1124
+ " ;
1125
+ #[ rustfmt:: skip]
1126
+ const EXPECTED : & str = r"
1127
+
1128
+ fn in_import_present() { }
1129
+ fn in_main_present() { }
1130
+ " ;
1131
+ let processor = ShaderProcessor :: default ( ) ;
1132
+ let mut shaders = HashMap :: default ( ) ;
1133
+ let mut import_handles = HashMap :: default ( ) ;
1134
+ let foo_handle = Handle :: < Shader > :: default ( ) ;
1135
+ shaders. insert ( foo_handle. clone_weak ( ) , Shader :: from_wgsl ( FOO ) ) ;
1136
+ import_handles. insert (
1137
+ ShaderImport :: Custom ( "FOO" . to_string ( ) ) ,
1138
+ foo_handle. clone_weak ( ) ,
1139
+ ) ;
1140
+ let result = processor
1141
+ . process (
1142
+ & Shader :: from_wgsl ( INPUT ) ,
1143
+ & [ "MAIN_PRESENT" . to_string ( ) , "IMPORT_PRESENT" . to_string ( ) ] ,
1144
+ & shaders,
1145
+ & import_handles,
1146
+ )
1147
+ . unwrap ( ) ;
1148
+ assert_eq ! ( result. get_wgsl_source( ) . unwrap( ) , EXPECTED ) ;
1149
+ }
1150
+
1151
+ #[ test]
1152
+ fn process_import_in_import ( ) {
1153
+ #[ rustfmt:: skip]
1154
+ const BAR : & str = r"
1155
+ #ifdef DEEP
1156
+ fn inner_import() { }
1157
+ #endif
1158
+ " ;
1159
+ const FOO : & str = r"
1160
+ #import BAR
1161
+ fn import() { }
1162
+ " ;
1163
+ #[ rustfmt:: skip]
1164
+ const INPUT : & str = r"
1165
+ #import FOO
1166
+ fn in_main() { }
1167
+ " ;
1168
+ #[ rustfmt:: skip]
1169
+ const EXPECTED : & str = r"
1170
+
1171
+
1172
+ fn inner_import() { }
1173
+ fn import() { }
1174
+ fn in_main() { }
1175
+ " ;
1176
+ let processor = ShaderProcessor :: default ( ) ;
1177
+ let mut shaders = HashMap :: default ( ) ;
1178
+ let mut import_handles = HashMap :: default ( ) ;
1179
+ {
1180
+ let bar_handle = Handle :: < Shader > :: default ( ) ;
1181
+ shaders. insert ( bar_handle. clone_weak ( ) , Shader :: from_wgsl ( BAR ) ) ;
1182
+ import_handles. insert (
1183
+ ShaderImport :: Custom ( "BAR" . to_string ( ) ) ,
1184
+ bar_handle. clone_weak ( ) ,
1185
+ ) ;
1186
+ }
1187
+ {
1188
+ let foo_handle = HandleUntyped :: weak_from_u64 ( Shader :: TYPE_UUID , 1 ) . typed ( ) ;
1189
+ shaders. insert ( foo_handle. clone_weak ( ) , Shader :: from_wgsl ( FOO ) ) ;
1190
+ import_handles. insert (
1191
+ ShaderImport :: Custom ( "FOO" . to_string ( ) ) ,
1192
+ foo_handle. clone_weak ( ) ,
1193
+ ) ;
1194
+ }
1195
+ let result = processor
1196
+ . process (
1197
+ & Shader :: from_wgsl ( INPUT ) ,
1198
+ & [ "DEEP" . to_string ( ) ] ,
1199
+ & shaders,
1200
+ & import_handles,
1201
+ )
1202
+ . unwrap ( ) ;
1203
+ assert_eq ! ( result. get_wgsl_source( ) . unwrap( ) , EXPECTED ) ;
1204
+ }
1084
1205
}
0 commit comments