Skip to content

Optimize ShaderDefs updates #1046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions crates/bevy_derive/src/shader_defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ pub fn derive_shader_defs(input: TokenStream) -> TokenStream {
#shader_defs_len
}

fn get_shader_def(&self, index: usize) -> Option<&str> {
fn get_shader_def(&self, index: usize) -> Option<(&str, bool)> {
use #bevy_render_path::shader::ShaderDef;
match index {
#(#shader_def_indices => if self.#shader_def_idents.is_defined() {
Some(#shader_defs)
} else {
None
#(#shader_def_indices => {
Some((#shader_defs, self.#shader_def_idents.is_defined()))
},)*
_ => None,
}
Expand Down
6 changes: 1 addition & 5 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,7 @@ impl Plugin for RenderPlugin {
stage::RENDER_GRAPH_SYSTEMS,
render_graph::render_graph_schedule_executor_system.system(),
)
.add_system_to_stage(stage::DRAW, pipeline::draw_render_pipelines_system.system())
.add_system_to_stage(
stage::POST_RENDER,
shader::clear_shader_defs_system.system(),
);
.add_system_to_stage(stage::DRAW, pipeline::draw_render_pipelines_system.system());

if app.resources().get::<Msaa>().is_none() {
app.init_resource::<Msaa>();
Expand Down
39 changes: 17 additions & 22 deletions crates/bevy_render/src/pipeline/pipeline_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
use crate::{
pipeline::{BindType, InputStepMode, VertexBufferDescriptor},
renderer::RenderResourceContext,
shader::{Shader, ShaderError, ShaderSource},
shader::{Shader, ShaderDefSource, ShaderError, ShaderSource},
};
use bevy_asset::{Assets, Handle};
use bevy_reflect::Reflect;
use bevy_utils::{HashMap, HashSet};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};

#[derive(Clone, Eq, PartialEq, Debug, Reflect)]
pub struct PipelineSpecialization {
Expand Down Expand Up @@ -40,9 +39,17 @@ impl PipelineSpecialization {
}
}

#[derive(Clone, Eq, PartialEq, Debug, Default, Reflect, Serialize, Deserialize)]
#[derive(Clone, Eq, PartialEq, Debug, Default, Reflect)]
pub struct ShaderSpecialization {
pub shader_defs: HashSet<String>,
/// ShaderDefs tracked per-asset/component.
pub shader_defs: HashMap<ShaderDefSource, Vec<String>>,
}

impl ShaderSpecialization {
// get_specialized_shader should take a &[&str]?
pub fn get_shader_defs(&self) -> Vec<String> {
self.shader_defs.values().flatten().cloned().collect()
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -95,13 +102,8 @@ impl PipelineCompiler {
Ok(specialized_shader.shader.clone_weak())
} else {
// if no shader exists with the current configuration, create new shader and compile
let shader_def_vec = shader_specialization
.shader_defs
.iter()
.cloned()
.collect::<Vec<String>>();
let compiled_shader =
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?;
let compiled_shader = render_resource_context
.get_specialized_shader(shader, Some(&shader_specialization.get_shader_defs()))?;
let specialized_handle = shaders.add(compiled_shader);
let weak_specialized_handle = specialized_handle.clone_weak();
specialized_shaders.push(SpecializedShader {
Expand Down Expand Up @@ -316,17 +318,10 @@ impl PipelineCompiler {
if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) {
for specialized_shader in specialized_shaders {
// Recompile specialized shader. If it fails, we bail immediately.
let shader_def_vec = specialized_shader
.specialization
.shader_defs
.iter()
.cloned()
.collect::<Vec<String>>();
let new_handle =
shaders.add(render_resource_context.get_specialized_shader(
shaders.get(shader).unwrap(),
Some(&shader_def_vec),
)?);
let new_handle = shaders.add(render_resource_context.get_specialized_shader(
shaders.get(shader).unwrap(),
Some(&specialized_shader.specialization.get_shader_defs()),
)?);

// Replace handle and remove old from assets.
let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle);
Expand Down
218 changes: 168 additions & 50 deletions crates/bevy_render/src/shader/shader_defs.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use bevy_asset::{Asset, Assets, Handle};

use crate::{pipeline::RenderPipelines, Texture};
use bevy_app::{EventReader, Events};
use bevy_asset::{Asset, AssetEvent, Assets, Handle, HandleUntyped};
pub use bevy_derive::ShaderDefs;
use bevy_ecs::{Query, Res};
use bevy_ecs::{Changed, Local, Mut, Query, QuerySet, Res};
use bevy_reflect::{Reflect, TypeUuid};
use bevy_utils::{HashSet, Uuid};

/// Something that can either be "defined" or "not defined". This is used to determine if a "shader def" should be considered "defined"
pub trait ShaderDef {
Expand All @@ -12,7 +14,7 @@ pub trait ShaderDef {
/// A collection of "shader defs", which define compile time definitions for shaders.
pub trait ShaderDefs {
fn shader_defs_len(&self) -> usize;
fn get_shader_def(&self, index: usize) -> Option<&str>;
fn get_shader_def(&self, index: usize) -> Option<(&str, bool)>;
fn iter_shader_defs(&self) -> ShaderDefIterator;
}

Expand All @@ -31,19 +33,11 @@ impl<'a> ShaderDefIterator<'a> {
}
}
impl<'a> Iterator for ShaderDefIterator<'a> {
type Item = &'a str;
type Item = (&'a str, bool);

fn next(&mut self) -> Option<Self::Item> {
loop {
if self.index == self.shader_defs.shader_defs_len() {
return None;
}
let shader_def = self.shader_defs.get_shader_def(self.index);
self.index += 1;
if shader_def.is_some() {
return shader_def;
}
}
self.index += 1;
self.shader_defs.get_shader_def(self.index - 1)
}
}

Expand All @@ -59,56 +53,180 @@ impl ShaderDef for Option<Handle<Texture>> {
}
}

#[derive(Debug, Clone, Hash, Eq, PartialEq, Reflect)]
pub enum ShaderDefSource {
Component(Uuid),
Asset(HandleUntyped),
}

impl<T: Asset> From<&Handle<T>> for ShaderDefSource {
fn from(h: &Handle<T>) -> Self {
Self::Asset(h.clone_weak_untyped())
}
}

impl From<Uuid> for ShaderDefSource {
fn from(uuid: Uuid) -> Self {
Self::Component(uuid)
}
}

/// Updates [RenderPipelines] with the latest [ShaderDefs]
pub fn shader_defs_system<T>(mut query: Query<(&T, &mut RenderPipelines)>)
pub fn shader_defs_system<T>(mut query: Query<(&T, &mut RenderPipelines), Changed<T>>)
where
T: ShaderDefs + TypeUuid + Send + Sync + 'static,
{
query
.iter_mut()
.map(|(s, p)| (s, (T::TYPE_UUID).into(), p))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the TypeId here instead of TypeUuid? So far we aren't requiring TypeUuid on components and we use TypeId as the unique identifier everywhere else.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Assets do require TypeUuid though)

.for_each(update_render_pipelines)
}

fn update_render_pipelines<T>(q: (&T, ShaderDefSource, Mut<RenderPipelines>))
where
T: ShaderDefs + Send + Sync + 'static,
{
for (shader_defs, mut render_pipelines) in query.iter_mut() {
for shader_def in shader_defs.iter_shader_defs() {
for render_pipeline in render_pipelines.pipelines.iter_mut() {
render_pipeline
.specialization
.shader_specialization
.shader_defs
.insert(shader_def.to_string());
}
}
}
let (shader_defs, src, mut render_pipelines) = q;

let new_defs = shader_defs
.iter_shader_defs()
// FIX: revert macro
.filter_map(|(def, defined)| if defined { Some(def.to_string()) } else { None })
.collect::<Vec<_>>();
render_pipelines.pipelines.iter_mut().for_each(|p| {
*(p.specialization
.shader_specialization
.shader_defs
.entry(src.clone())
.or_default()) = new_defs.clone();
});
}

/// Clears each [RenderPipelines]' shader defs collection
pub fn clear_shader_defs_system(mut query: Query<&mut RenderPipelines>) {
for mut render_pipelines in query.iter_mut() {
for render_pipeline in render_pipelines.pipelines.iter_mut() {
render_pipeline
.specialization
.shader_specialization
.shader_defs
.clear();
// FIX: track entities or clean this up
//#[derive(Default)]
pub struct AssetShaderDefsState<T: Asset> {
event_reader: EventReader<AssetEvent<T>>,
//entities: HashMap<Handle<T>, HashSet<Entity>>,
}

impl<T: Asset> Default for AssetShaderDefsState<T> {
fn default() -> Self {
Self {
event_reader: Default::default(),
//entities: Default::default(),
}
}
}

/// Updates [RenderPipelines] with the latest [ShaderDefs] from a given asset type
pub fn asset_shader_defs_system<T: Asset>(
pub fn asset_shader_defs_system<T>(
mut state: Local<AssetShaderDefsState<T>>,
assets: Res<Assets<T>>,
mut query: Query<(&Handle<T>, &mut RenderPipelines)>,
events: Res<Events<AssetEvent<T>>>,
mut queries: QuerySet<(
Query<(&Handle<T>, &mut RenderPipelines)>,
Query<(&Handle<T>, &mut RenderPipelines), Changed<Handle<T>>>,
)>,
) where
T: ShaderDefs + Send + Sync + 'static,
T: Default + Asset + ShaderDefs + Send + Sync + 'static,
{
for (asset_handle, mut render_pipelines) in query.iter_mut() {
if let Some(asset_handle) = assets.get(asset_handle) {
let shader_defs = asset_handle;
for shader_def in shader_defs.iter_shader_defs() {
for render_pipeline in render_pipelines.pipelines.iter_mut() {
render_pipeline
.specialization
.shader_specialization
.shader_defs
.insert(shader_def.to_string());
let changed = state
.event_reader
.iter(&events)
.fold(HashSet::default(), |mut set, event| {
match event {
AssetEvent::Created { handle } | AssetEvent::Modified { handle } => {
set.insert(handle.clone_weak());
}
AssetEvent::Removed { handle } => {
set.remove(&handle);
}
}
}
set
});

// Update for changed assets.
if !changed.is_empty() {
queries
.q0_mut()
.iter_mut()
.filter(|(h, _)| changed.contains(h))
.filter_map(|(h, p)| assets.get(h).map(|a| (a, h.into(), p)))
.for_each(update_render_pipelines);
}

// Update for changed asset handles.
queries
.q1_mut()
.iter_mut()
// Not worth?
//.filter(|(h, _)| !changed.contains(h))
.filter_map(|(h, p)| assets.get(h).map(|a| (a, h.into(), p)))
.for_each(update_render_pipelines);
}

#[cfg(test)]
mod tests {
use super::{asset_shader_defs_system, ShaderDefs};
use crate::{self as bevy_render, pipeline::RenderPipeline, prelude::RenderPipelines};
use bevy_app::App;
use bevy_asset::{AddAsset, AssetPlugin, AssetServer, Assets, Handle, HandleId};
use bevy_core::CorePlugin;
use bevy_ecs::{Commands, IntoSystem, ResMut};
use bevy_reflect::{ReflectPlugin, TypeUuid};

#[derive(Debug, Default, ShaderDefs, TypeUuid)]
#[uuid = "3130b0bf-46a6-42f2-8556-c1a04da20b7e"]
struct A {
#[shader_def]
d: bool,
}

fn shader_def_len(app: &App) -> usize {
app.world
.query::<&RenderPipelines>()
.next()
.unwrap()
.pipelines[0]
.specialization
.shader_specialization
.shader_defs
.len()
}

#[test]
fn empty_handle() {
// Insert an empty asset handle, and empty render pipelines.
fn setup(commands: &mut Commands, asset_server: ResMut<AssetServer>) {
let handle_id = HandleId::random::<A>();
let h = asset_server.get_handle::<A, HandleId>(handle_id);
let render_pipelines = RenderPipelines::from_pipelines(vec![RenderPipeline::default()]);
commands.spawn((h, render_pipelines));
};

App::build()
.add_plugin(ReflectPlugin::default())
.add_plugin(CorePlugin::default())
.add_plugin(AssetPlugin::default())
.add_asset::<A>()
.add_system(asset_shader_defs_system::<A>.system())
.add_startup_system(setup.system())
.set_runner(move |mut app: App| {
app.update();
assert_eq!(shader_def_len(&app), 0);
{
let mut assets = app.resources.get_mut::<Assets<A>>().unwrap();
let handle = app.world.query::<&Handle<A>>().next().unwrap();
assets.set(handle, A { d: true });
}

// Asset changed events are sent post-update, so we
// have to update twice to see the change.
app.update();
app.update();

assert_eq!(shader_def_len(&app), 1);
})
.run();
}
}