Skip to content

[Merged by Bors] - Replace ReadOnlyFetch with ReadOnlyWorldQuery #4626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 73 additions & 24 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
user_generics_with_world.split_for_impl();

let struct_name = ast.ident.clone();
let read_only_struct_name = if fetch_struct_attributes.is_mutable {
Ident::new(&format!("{}ReadOnly", struct_name), Span::call_site())
} else {
struct_name.clone()
};

let item_struct_name = Ident::new(&format!("{}Item", struct_name), Span::call_site());
let read_only_item_struct_name = if fetch_struct_attributes.is_mutable {
Expand Down Expand Up @@ -178,7 +183,8 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#(#ignored_field_idents: #ignored_field_types,)*
}

impl #user_impl_generics_with_world #path::query::Fetch<'__w>
// SAFETY: `update_component_access` and `update_archetype_component_access` are called on every field
unsafe impl #user_impl_generics_with_world #path::query::Fetch<'__w>
for #fetch_struct_name #user_ty_generics_with_world #user_where_clauses_with_world {

type Item = #item_struct_name #user_ty_generics_with_world;
Expand Down Expand Up @@ -248,6 +254,17 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
unsafe fn archetype_filter_fetch(&mut self, _archetype_index: usize) -> bool {
true #(&& self.#field_idents.archetype_filter_fetch(_archetype_index))*
}

fn update_component_access(state: &Self::State, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#( #path::query::#fetch_type_alias::<'static, #field_types> :: update_component_access(&state.#field_idents, _access); )*
}

fn update_archetype_component_access(state: &Self::State, _archetype: &#path::archetype::Archetype, _access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>) {
#(
#path::query::#fetch_type_alias::<'static, #field_types>
:: update_archetype_component_access(&state.#field_idents, _archetype, _access);
)*
}
}
}
};
Expand All @@ -262,23 +279,14 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#(#ignored_field_idents: #ignored_field_types,)*
}

// SAFETY: `update_component_access` and `update_archetype_component_access` are called for each item in the struct
unsafe impl #user_impl_generics #path::query::FetchState for #state_struct_name #user_ty_generics #user_where_clauses {
impl #user_impl_generics #path::query::FetchState for #state_struct_name #user_ty_generics #user_where_clauses {
fn init(world: &mut #path::world::World) -> Self {
#state_struct_name {
#(#field_idents: <<#field_types as #path::query::WorldQuery>::State as #path::query::FetchState>::init(world),)*
#(#ignored_field_idents: Default::default(),)*
}
}

fn update_component_access(&self, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#(self.#field_idents.update_component_access(_access);)*
}

fn update_archetype_component_access(&self, _archetype: &#path::archetype::Archetype, _access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>) {
#(self.#field_idents.update_archetype_component_access(_archetype, _access);)*
}

fn matches_component_set(&self, _set_contains_id: &impl Fn(#path::component::ComponentId) -> bool) -> bool {
true #(&& self.#field_idents.matches_component_set(_set_contains_id))*

Expand All @@ -290,29 +298,66 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
impl_fetch(
true,
read_only_fetch_struct_name.clone(),
read_only_item_struct_name,
read_only_item_struct_name.clone(),
)
} else {
quote! {}
};

let read_only_world_query_impl = if fetch_struct_attributes.is_mutable {
quote! {
#[automatically_derived]
#visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
#( #field_idents: < #field_types as #path::query::WorldQuery >::ReadOnly, )*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}

// SAFETY: `ROQueryFetch<Self>` is the same as `QueryFetch<Self>`
unsafe impl #user_impl_generics #path::query::WorldQuery for #read_only_struct_name #user_ty_generics #user_where_clauses {
type ReadOnly = Self;
type State = #state_struct_name #user_ty_generics;

fn shrink<'__wlong: '__wshort, '__wshort>(item: #path::query::#item_type_alias<'__wlong, Self>)
-> #path::query::#item_type_alias<'__wshort, Self> {
#read_only_item_struct_name {
#(
#field_idents : <
< #field_types as #path::query::WorldQuery >::ReadOnly as #path::query::WorldQuery
> :: shrink( item.#field_idents ),
)*
#(
#ignored_field_idents: item.#ignored_field_idents,
)*
}
}
}

impl #user_impl_generics_with_world #path::query::WorldQueryGats<'__w> for #read_only_struct_name #user_ty_generics #user_where_clauses {
type Fetch = #read_only_fetch_struct_name #user_ty_generics_with_world;
type _State = #state_struct_name #user_ty_generics;
}
}
} else {
quote! {}
};

let read_only_asserts = if fetch_struct_attributes.is_mutable {
quote! {
// Double-check that the data fetched by `ROQueryFetch` is read-only.
// This is technically unnecessary as `<_ as WorldQueryGats<'world>>::ReadOnlyFetch: ReadOnlyFetch`
// but to protect against future mistakes we assert the assoc type implements `ReadOnlyFetch` anyway
#( assert_readonly::<#path::query::ROQueryFetch<'__w, #field_types>>(); )*
// Double-check that the data fetched by `<_ as WorldQuery>::ReadOnly` is read-only.
// This is technically unnecessary as `<_ as WorldQuery>::ReadOnly: ReadOnlyWorldQuery`
// but to protect against future mistakes we assert the assoc type implements `ReadOnlyWorldQuery` anyway
#( assert_readonly::< < #field_types as #path::query::WorldQuery > :: ReadOnly >(); )*
}
} else {
quote! {
// Statically checks that the safety guarantee of `ReadOnlyFetch` for `$fetch_struct_name` actually holds true.
// We need this to make sure that we don't compile `ReadOnlyFetch` if our struct contains nested `WorldQuery`
// Statically checks that the safety guarantee of `ReadOnlyWorldQuery` for `$fetch_struct_name` actually holds true.
// We need this to make sure that we don't compile `ReadOnlyWorldQuery` if our struct contains nested `WorldQuery`
// members that don't implement it. I.e.:
// ```
// #[derive(WorldQuery)]
// pub struct Foo { a: &'static mut MyComponent }
// ```
#( assert_readonly::<#path::query::QueryFetch<'__w, #field_types>>(); )*
#( assert_readonly::<#field_types>(); )*
}
};

Expand All @@ -323,7 +368,12 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {

#read_only_fetch_impl

impl #user_impl_generics #path::query::WorldQuery for #struct_name #user_ty_generics #user_where_clauses {
#read_only_world_query_impl

// SAFETY: if the worldquery is mutable this defers to soundness of the `#field_types: WorldQuery` impl, otherwise
// if the world query is immutable then `#read_only_struct_name #user_ty_generics` is the same type as `#struct_name #user_ty_generics`
unsafe impl #user_impl_generics #path::query::WorldQuery for #struct_name #user_ty_generics #user_where_clauses {
type ReadOnly = #read_only_struct_name #user_ty_generics;
type State = #state_struct_name #user_ty_generics;
fn shrink<'__wlong: '__wshort, '__wshort>(item: #path::query::#item_type_alias<'__wlong, Self>)
-> #path::query::#item_type_alias<'__wshort, Self> {
Expand All @@ -340,19 +390,18 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {

impl #user_impl_generics_with_world #path::query::WorldQueryGats<'__w> for #struct_name #user_ty_generics #user_where_clauses {
type Fetch = #fetch_struct_name #user_ty_generics_with_world;
type ReadOnlyFetch = #read_only_fetch_struct_name #user_ty_generics_with_world;
type _State = #state_struct_name #user_ty_generics;
}

/// SAFETY: each item in the struct is read only
unsafe impl #user_impl_generics_with_world #path::query::ReadOnlyFetch
for #read_only_fetch_struct_name #user_ty_generics_with_world #user_where_clauses_with_world {}
unsafe impl #user_impl_generics #path::query::ReadOnlyWorldQuery
for #read_only_struct_name #user_ty_generics #user_where_clauses {}

#[allow(dead_code)]
const _: () = {
fn assert_readonly<T>()
where
T: #path::query::ReadOnlyFetch,
T: #path::query::ReadOnlyWorldQuery,
{
}

Expand Down
Loading