Skip to content

Commit d7b37f5

Browse files
committed
Track shader defs by source, handle asset changes, add test
1 parent 4c1166d commit d7b37f5

File tree

2 files changed

+180
-48
lines changed

2 files changed

+180
-48
lines changed

crates/bevy_render/src/pipeline/pipeline_compiler.rs

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
22
use crate::{
33
pipeline::{BindType, InputStepMode, VertexBufferDescriptor},
44
renderer::RenderResourceContext,
5-
shader::{Shader, ShaderError, ShaderSource},
5+
shader::{Shader, ShaderDefSource, ShaderError, ShaderSource},
66
};
77
use bevy_asset::{Assets, Handle};
88
use bevy_reflect::Reflect;
99
use bevy_utils::{HashMap, HashSet};
1010
use once_cell::sync::Lazy;
11-
use serde::{Deserialize, Serialize};
1211

1312
#[derive(Clone, Eq, PartialEq, Debug, Reflect)]
1413
pub struct PipelineSpecialization {
@@ -40,9 +39,17 @@ impl PipelineSpecialization {
4039
}
4140
}
4241

43-
#[derive(Clone, Eq, PartialEq, Debug, Default, Reflect, Serialize, Deserialize)]
42+
#[derive(Clone, Eq, PartialEq, Debug, Default, Reflect)]
4443
pub struct ShaderSpecialization {
45-
pub shader_defs: HashSet<String>,
44+
/// ShaderDefs tracked per-asset/component.
45+
pub shader_defs: HashMap<ShaderDefSource, Vec<String>>,
46+
}
47+
48+
impl ShaderSpecialization {
49+
// get_specialized_shader should take a &[&str]?
50+
pub fn get_shader_defs(&self) -> Vec<String> {
51+
self.shader_defs.values().flatten().cloned().collect()
52+
}
4653
}
4754

4855
#[derive(Debug)]
@@ -95,13 +102,8 @@ impl PipelineCompiler {
95102
Ok(specialized_shader.shader.clone_weak())
96103
} else {
97104
// if no shader exists with the current configuration, create new shader and compile
98-
let shader_def_vec = shader_specialization
99-
.shader_defs
100-
.iter()
101-
.cloned()
102-
.collect::<Vec<String>>();
103-
let compiled_shader =
104-
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?;
105+
let compiled_shader = render_resource_context
106+
.get_specialized_shader(shader, Some(&shader_specialization.get_shader_defs()))?;
105107
let specialized_handle = shaders.add(compiled_shader);
106108
let weak_specialized_handle = specialized_handle.clone_weak();
107109
specialized_shaders.push(SpecializedShader {
@@ -316,17 +318,10 @@ impl PipelineCompiler {
316318
if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) {
317319
for specialized_shader in specialized_shaders {
318320
// Recompile specialized shader. If it fails, we bail immediately.
319-
let shader_def_vec = specialized_shader
320-
.specialization
321-
.shader_defs
322-
.iter()
323-
.cloned()
324-
.collect::<Vec<String>>();
325-
let new_handle =
326-
shaders.add(render_resource_context.get_specialized_shader(
327-
shaders.get(shader).unwrap(),
328-
Some(&shader_def_vec),
329-
)?);
321+
let new_handle = shaders.add(render_resource_context.get_specialized_shader(
322+
shaders.get(shader).unwrap(),
323+
Some(&specialized_shader.specialization.get_shader_defs()),
324+
)?);
330325

331326
// Replace handle and remove old from assets.
332327
let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle);

crates/bevy_render/src/shader/shader_defs.rs

Lines changed: 163 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
use bevy_asset::{Asset, Assets, Handle};
2-
31
use crate::{pipeline::RenderPipelines, Texture};
2+
use bevy_app::{EventReader, Events};
3+
use bevy_asset::{Asset, AssetEvent, Assets, Handle, HandleUntyped};
44
pub use bevy_derive::ShaderDefs;
5-
use bevy_ecs::{Changed, Mut, Query, Res};
5+
use bevy_ecs::{Changed, Local, Mut, Query, QuerySet, Res};
6+
use bevy_reflect::{Reflect, TypeUuid};
7+
use bevy_utils::{HashSet, Uuid};
68

79
/// Something that can either be "defined" or "not defined". This is used to determine if a "shader def" should be considered "defined"
810
pub trait ShaderDef {
@@ -51,46 +53,181 @@ impl ShaderDef for Option<Handle<Texture>> {
5153
}
5254
}
5355

56+
#[derive(Debug, Clone, Hash, Eq, PartialEq, Reflect)]
57+
pub enum ShaderDefSource {
58+
Component(Uuid),
59+
Asset(HandleUntyped),
60+
}
61+
62+
impl<T: Asset> From<&Handle<T>> for ShaderDefSource {
63+
fn from(h: &Handle<T>) -> Self {
64+
Self::Asset(h.clone_weak_untyped())
65+
}
66+
}
67+
68+
impl From<Uuid> for ShaderDefSource {
69+
fn from(uuid: Uuid) -> Self {
70+
Self::Component(uuid)
71+
}
72+
}
73+
5474
/// Updates [RenderPipelines] with the latest [ShaderDefs]
5575
pub fn shader_defs_system<T>(mut query: Query<(&T, &mut RenderPipelines), Changed<T>>)
5676
where
57-
T: ShaderDefs + Send + Sync + 'static,
77+
T: ShaderDefs + TypeUuid + Send + Sync + 'static,
5878
{
59-
query.iter_mut().for_each(update_render_pipelines)
79+
query
80+
.iter_mut()
81+
.map(|(s, p)| (s, (T::TYPE_UUID).into(), p))
82+
.for_each(update_render_pipelines)
6083
}
6184

62-
/// Insert defined shader defs and remove undefined ones from render pipelines.
63-
fn update_render_pipelines<T>(q: (&T, Mut<RenderPipelines>))
85+
fn update_render_pipelines<T>(q: (&T, ShaderDefSource, Mut<RenderPipelines>))
6486
where
6587
T: ShaderDefs + Send + Sync + 'static,
6688
{
67-
let (shader_defs, mut render_pipelines) = q;
68-
for (shader_def, is_defined) in shader_defs.iter_shader_defs() {
69-
for render_pipeline in render_pipelines.pipelines.iter_mut() {
70-
let shader_defs = &mut render_pipeline
71-
.specialization
72-
.shader_specialization
73-
.shader_defs;
74-
let s = shader_def.to_string();
75-
if is_defined {
76-
shader_defs.insert(s);
77-
} else {
78-
shader_defs.remove(&s);
79-
}
89+
let (shader_defs, src, mut render_pipelines) = q;
90+
91+
let new_defs = shader_defs
92+
.iter_shader_defs()
93+
// FIX: revert macro
94+
.filter_map(|(def, defined)| if defined { Some(def.to_string()) } else { None })
95+
.collect::<Vec<_>>();
96+
render_pipelines.pipelines.iter_mut().for_each(|p| {
97+
*(p.specialization
98+
.shader_specialization
99+
.shader_defs
100+
.entry(src.clone())
101+
.or_default()) = new_defs.clone();
102+
});
103+
}
104+
105+
// FIX: track entities or clean this up
106+
//#[derive(Default)]
107+
pub struct AssetShaderDefsState<T: Asset> {
108+
event_reader: EventReader<AssetEvent<T>>,
109+
//entities: HashMap<Handle<T>, HashSet<Entity>>,
110+
}
111+
112+
impl<T: Asset> Default for AssetShaderDefsState<T> {
113+
fn default() -> Self {
114+
Self {
115+
event_reader: Default::default(),
116+
//entities: Default::default(),
80117
}
81118
}
82119
}
83120

84121
/// Updates [RenderPipelines] with the latest [ShaderDefs] from a given asset type
85-
pub fn asset_shader_defs_system<T: Asset>(
122+
pub fn asset_shader_defs_system<T>(
123+
mut state: Local<AssetShaderDefsState<T>>,
86124
assets: Res<Assets<T>>,
87-
mut query: Query<(&Handle<T>, &mut RenderPipelines), Changed<Handle<T>>>,
125+
events: Res<Events<AssetEvent<T>>>,
126+
mut queries: QuerySet<(
127+
Query<(&Handle<T>, &mut RenderPipelines)>,
128+
Query<(&Handle<T>, &mut RenderPipelines), Changed<Handle<T>>>,
129+
)>,
88130
) where
89-
T: ShaderDefs + Send + Sync + 'static,
131+
T: Default + Asset + ShaderDefs + Send + Sync + 'static,
90132
{
91-
query
133+
let changed = state
134+
.event_reader
135+
.iter(&events)
136+
.fold(HashSet::default(), |mut set, event| {
137+
match event {
138+
AssetEvent::Created { handle } | AssetEvent::Modified { handle } => {
139+
set.insert(handle.clone_weak());
140+
}
141+
AssetEvent::Removed { handle } => {
142+
set.remove(&handle);
143+
}
144+
}
145+
set
146+
});
147+
148+
// Update for changed assets.
149+
if changed.len() > 0 {
150+
queries
151+
.q0_mut()
152+
.iter_mut()
153+
.filter(|(h, _)| changed.contains(h))
154+
.filter_map(|(h, p)| assets.get(h).map(|a| (a, h.into(), p)))
155+
.for_each(update_render_pipelines);
156+
}
157+
158+
// Update for changed asset handles.
159+
queries
160+
.q1_mut()
92161
.iter_mut()
93-
// (Handle<T>, _) -> (&T, _)
94-
.filter_map(|(h, p)| assets.get(h).map(|a| (a, p)))
162+
// Not worth?
163+
//.filter(|(h, _)| !changed.contains(h))
164+
.filter_map(|(h, p)| assets.get(h).map(|a| (a, h.into(), p)))
95165
.for_each(update_render_pipelines);
96166
}
167+
168+
#[cfg(test)]
169+
mod tests {
170+
use super::asset_shader_defs_system;
171+
use super::ShaderDefs;
172+
use crate::{self as bevy_render, pipeline::RenderPipeline, prelude::RenderPipelines};
173+
use bevy_app::App;
174+
use bevy_asset::{AddAsset, AssetPlugin, AssetServer, Assets, HandleId};
175+
use bevy_core::CorePlugin;
176+
use bevy_ecs::{Commands, ResMut};
177+
use bevy_reflect::{ReflectPlugin, TypeUuid};
178+
179+
#[derive(Debug, Default, ShaderDefs, TypeUuid)]
180+
#[uuid = "3130b0bf-46a6-42f2-8556-c1a04da20b7e"]
181+
struct A {
182+
#[shader_def]
183+
d: bool,
184+
}
185+
186+
fn shader_def_len(app: &App) -> usize {
187+
app.world
188+
.query::<&RenderPipelines>()
189+
.next()
190+
.unwrap()
191+
.pipelines[0]
192+
.specialization
193+
.shader_specialization
194+
.shader_defs
195+
.len()
196+
}
197+
198+
#[test]
199+
fn empty_handle() {
200+
// Insert an empty asset handle, and empty render pipelines.
201+
let handle_id = HandleId::random::<A>();
202+
let setup = move |commands: &mut Commands, asset_server: ResMut<AssetServer>| {
203+
let h = asset_server.get_handle::<A, HandleId>(handle_id);
204+
let render_pipelines = RenderPipelines::from_pipelines(vec![RenderPipeline::default()]);
205+
commands.spawn((h, render_pipelines));
206+
};
207+
208+
App::build()
209+
.add_plugin(ReflectPlugin::default())
210+
.add_plugin(CorePlugin::default())
211+
.add_plugin(AssetPlugin::default())
212+
.add_asset::<A>()
213+
.add_system(asset_shader_defs_system::<A>)
214+
.add_startup_system(setup)
215+
.set_runner(move |mut app: App| {
216+
app.initialize();
217+
app.update();
218+
assert_eq!(shader_def_len(&app), 0);
219+
{
220+
let mut asset_server = app.resources.get_mut::<Assets<A>>().unwrap();
221+
asset_server.set(handle_id, A { d: true });
222+
}
223+
224+
// Asset changed events are sent post-update, so we
225+
// have to update twice to see the change.
226+
app.update();
227+
app.update();
228+
229+
assert_eq!(shader_def_len(&app), 1);
230+
})
231+
.run();
232+
}
233+
}

0 commit comments

Comments
 (0)