Skip to content

Commit 9c49a29

Browse files
committed
Normalize relation queries to reduce overhead
Also make WorldQuery frozen.
1 parent f2529d9 commit 9c49a29

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

tcod/ecs/query.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55
import warnings
66
from collections import defaultdict
7-
from typing import TYPE_CHECKING, AbstractSet, Any, Final, Iterable, Iterator, TypeVar, overload
7+
from typing import TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator, TypeVar, overload
88
from weakref import WeakKeyDictionary, WeakSet
99

1010
import attrs
@@ -192,6 +192,7 @@ def _get_query(w_query: WorldQuery) -> set[Entity]:
192192
"""Return the entities for the given query and world."""
193193
world = w_query.world
194194
query = w_query._query
195+
assert query == query._normalized(), "Double checks that relations are correct"
195196
cache = _get_query_cache(world)
196197
if cache is not None:
197198
cached_entities = cache.queries.get(query)
@@ -204,6 +205,22 @@ def _get_query(w_query: WorldQuery) -> set[Entity]:
204205
return entities
205206

206207

208+
def _normalize_query_relation(relation: _RelationQuery) -> _RelationQuery:
209+
"""Normalize a relation query.
210+
211+
This adds the inverse lookup to the sub-query so that this only matches entities which have a relation.
212+
"""
213+
if len(relation) == 2: # noqa: PLR2004
214+
tag, targets = relation # type: ignore[misc] # https://github.com/python/mypy/issues/1178
215+
if isinstance(targets, WorldQuery): # (tag, targets)
216+
return tag, targets.all_of(relations=[(..., tag, None)])
217+
return relation
218+
origin, tag, _ = relation # type: ignore[misc] # https://github.com/python/mypy/issues/1178
219+
if isinstance(origin, WorldQuery): # (origins, tag, None)
220+
return origin.all_of(relations=[(tag, ...)]), tag, None
221+
return relation
222+
223+
207224
@attrs.define(frozen=True)
208225
class Query:
209226
"""A set of conditions used to lookup entities in a World."""
@@ -236,7 +253,7 @@ def all_of(
236253
self._none_of_components,
237254
self._all_of_tags.union(tags),
238255
self._none_of_tags,
239-
self._all_of_relations.union(relations),
256+
self._all_of_relations.union(_normalize_query_relation(relation) for relation in relations),
240257
self._none_of_relations,
241258
)
242259

@@ -256,7 +273,7 @@ def none_of(
256273
self._all_of_tags,
257274
self._none_of_tags.union(tags),
258275
self._all_of_relations,
259-
self._none_of_relations.union(relations),
276+
self._none_of_relations.union(_normalize_query_relation(relation) for relation in relations),
260277
)
261278

262279
def _iter_dependencies(self) -> Iterator[WorldQuery]:
@@ -268,14 +285,24 @@ def _iter_dependencies(self) -> Iterator[WorldQuery]:
268285
elif isinstance(relation[0], WorldQuery): # (origins, tag, None)
269286
yield relation[0]
270287

288+
def _normalized(self) -> Query:
289+
"""Return a Query with relations normalized."""
290+
return self.__class__(
291+
self._all_of_components,
292+
self._none_of_components,
293+
self._all_of_tags,
294+
self._none_of_tags,
295+
frozenset(_normalize_query_relation(relation) for relation in self._all_of_relations),
296+
frozenset(_normalize_query_relation(relation) for relation in self._none_of_relations),
297+
)
271298

299+
300+
@attrs.define(frozen=True)
272301
class WorldQuery:
273302
"""Collect a set of entities with the provided conditions."""
274303

275-
def __init__(self, world: World) -> None:
276-
"""Initialize a Query."""
277-
self.world: Final = world
278-
self._query = Query()
304+
world: World
305+
_query: Query = attrs.field(factory=Query)
279306

280307
def _get_entities(self, extra_components: AbstractSet[_ComponentKey[object]] = frozenset()) -> set[Entity]:
281308
return _get_query(self.all_of(components=extra_components))
@@ -288,8 +315,9 @@ def all_of(
288315
relations: Iterable[_RelationQuery] = (),
289316
) -> Self:
290317
"""Filter entities based on having all of the provided elements."""
291-
self._query = self._query.all_of(components=components, tags=tags, relations=relations, _stacklevel=2)
292-
return self
318+
return self.__class__(
319+
self.world, self._query.all_of(components=components, tags=tags, relations=relations, _stacklevel=2)
320+
)
293321

294322
def none_of(
295323
self,
@@ -299,8 +327,9 @@ def none_of(
299327
relations: Iterable[_RelationQuery] = (),
300328
) -> Self:
301329
"""Filter entities based on having none of the provided elements."""
302-
self._query = self._query.none_of(components=components, tags=tags, relations=relations, _stacklevel=2)
303-
return self
330+
return self.__class__(
331+
self.world, self._query.none_of(components=components, tags=tags, relations=relations, _stacklevel=2)
332+
)
304333

305334
def __iter__(self) -> Iterator[Entity]:
306335
"""Iterate over the matching entities."""

tests/test_relations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
def test_conditional_relations() -> None:
1414
world = tcod.ecs.World()
1515
world["A"].relation_tag[ChildOf] = world["B"]
16+
world["C"].components[int] = 42
1617
has_int_query = world.Q.all_of(components=[int])
1718
assert not set(world.Q.all_of(relations=[(ChildOf, has_int_query)]))
1819
assert not set(world.Q.all_of(relations=[(has_int_query, ChildOf, None)]))

0 commit comments

Comments
 (0)