44import itertools
55import warnings
66from collections import defaultdict
7- from typing import TYPE_CHECKING , AbstractSet , Any , Final , Iterable , Iterator , TypeVar , overload
7+ from typing import TYPE_CHECKING , AbstractSet , Any , Iterable , Iterator , TypeVar , overload
88from weakref import WeakKeyDictionary , WeakSet
99
1010import attrs
@@ -192,6 +192,7 @@ def _get_query(w_query: WorldQuery) -> set[Entity]:
192192 """Return the entities for the given query and world."""
193193 world = w_query .world
194194 query = w_query ._query
195+ assert query == query ._normalized (), "Double checks that relations are correct"
195196 cache = _get_query_cache (world )
196197 if cache is not None :
197198 cached_entities = cache .queries .get (query )
@@ -204,6 +205,22 @@ def _get_query(w_query: WorldQuery) -> set[Entity]:
204205 return entities
205206
206207
208+ def _normalize_query_relation (relation : _RelationQuery ) -> _RelationQuery :
209+ """Normalize a relation query.
210+
211+ This adds the inverse lookup to the sub-query so that this only matches entities which have a relation.
212+ """
213+ if len (relation ) == 2 : # noqa: PLR2004
214+ tag , targets = relation # type: ignore[misc] # https://github.com/python/mypy/issues/1178
215+ if isinstance (targets , WorldQuery ): # (tag, targets)
216+ return tag , targets .all_of (relations = [(..., tag , None )])
217+ return relation
218+ origin , tag , _ = relation # type: ignore[misc] # https://github.com/python/mypy/issues/1178
219+ if isinstance (origin , WorldQuery ): # (origins, tag, None)
220+ return origin .all_of (relations = [(tag , ...)]), tag , None
221+ return relation
222+
223+
207224@attrs .define (frozen = True )
208225class Query :
209226 """A set of conditions used to lookup entities in a World."""
@@ -236,7 +253,7 @@ def all_of(
236253 self ._none_of_components ,
237254 self ._all_of_tags .union (tags ),
238255 self ._none_of_tags ,
239- self ._all_of_relations .union (relations ),
256+ self ._all_of_relations .union (_normalize_query_relation ( relation ) for relation in relations ),
240257 self ._none_of_relations ,
241258 )
242259
@@ -256,7 +273,7 @@ def none_of(
256273 self ._all_of_tags ,
257274 self ._none_of_tags .union (tags ),
258275 self ._all_of_relations ,
259- self ._none_of_relations .union (relations ),
276+ self ._none_of_relations .union (_normalize_query_relation ( relation ) for relation in relations ),
260277 )
261278
262279 def _iter_dependencies (self ) -> Iterator [WorldQuery ]:
@@ -268,14 +285,24 @@ def _iter_dependencies(self) -> Iterator[WorldQuery]:
268285 elif isinstance (relation [0 ], WorldQuery ): # (origins, tag, None)
269286 yield relation [0 ]
270287
288+ def _normalized (self ) -> Query :
289+ """Return a Query with relations normalized."""
290+ return self .__class__ (
291+ self ._all_of_components ,
292+ self ._none_of_components ,
293+ self ._all_of_tags ,
294+ self ._none_of_tags ,
295+ frozenset (_normalize_query_relation (relation ) for relation in self ._all_of_relations ),
296+ frozenset (_normalize_query_relation (relation ) for relation in self ._none_of_relations ),
297+ )
271298
299+
300+ @attrs .define (frozen = True )
272301class WorldQuery :
273302 """Collect a set of entities with the provided conditions."""
274303
275- def __init__ (self , world : World ) -> None :
276- """Initialize a Query."""
277- self .world : Final = world
278- self ._query = Query ()
304+ world : World
305+ _query : Query = attrs .field (factory = Query )
279306
280307 def _get_entities (self , extra_components : AbstractSet [_ComponentKey [object ]] = frozenset ()) -> set [Entity ]:
281308 return _get_query (self .all_of (components = extra_components ))
@@ -288,8 +315,9 @@ def all_of(
288315 relations : Iterable [_RelationQuery ] = (),
289316 ) -> Self :
290317 """Filter entities based on having all of the provided elements."""
291- self ._query = self ._query .all_of (components = components , tags = tags , relations = relations , _stacklevel = 2 )
292- return self
318+ return self .__class__ (
319+ self .world , self ._query .all_of (components = components , tags = tags , relations = relations , _stacklevel = 2 )
320+ )
293321
294322 def none_of (
295323 self ,
@@ -299,8 +327,9 @@ def none_of(
299327 relations : Iterable [_RelationQuery ] = (),
300328 ) -> Self :
301329 """Filter entities based on having none of the provided elements."""
302- self ._query = self ._query .none_of (components = components , tags = tags , relations = relations , _stacklevel = 2 )
303- return self
330+ return self .__class__ (
331+ self .world , self ._query .none_of (components = components , tags = tags , relations = relations , _stacklevel = 2 )
332+ )
304333
305334 def __iter__ (self ) -> Iterator [Entity ]:
306335 """Iterate over the matching entities."""
0 commit comments