Skip to content

Create a copy of TypeQuery specialized for bool #9604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
6 changes: 3 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
Instance, NoneType, strip_type, TypeType, TypeOfAny,
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
is_named_instance, union_items, TypeQuery, LiteralType,
is_named_instance, union_items, TypeQueryBool, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType)
from mypy.sametypes import is_same_type
Expand Down Expand Up @@ -5441,11 +5441,11 @@ def is_valid_inferred_type(typ: Type) -> bool:
return not typ.accept(NothingSeeker())


class NothingSeeker(TypeQuery[bool]):
class NothingSeeker(TypeQueryBool):
"""Find any <nothing> types resulting from failed (ambiguous) type inference."""

def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return t.ambiguous
Expand Down
20 changes: 10 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4184,9 +4184,9 @@ def has_any_type(t: Type) -> bool:
return t.accept(HasAnyType())


class HasAnyType(types.TypeQuery[bool]):
class HasAnyType(types.TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(types.TypeQueryBool.STRATEGY_ANY)

def visit_any(self, t: AnyType) -> bool:
return t.type_of_any != TypeOfAny.special_form # special forms are not real Any types
Expand Down Expand Up @@ -4251,24 +4251,24 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


class ArgInferSecondPassQuery(types.TypeQuery[bool]):
class ArgInferSecondPassQuery(types.TypeQueryBool):
"""Query whether an argument type should be inferred in the second pass.

The result is True if the type has a type variable in a callable return
type anywhere. For example, the result for Callable[[], T] is True if t is
a type variable.
"""
def __init__(self) -> None:
super().__init__(any)
super().__init__(types.TypeQueryBool.STRATEGY_ANY)

def visit_callable_type(self, t: CallableType) -> bool:
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())


class HasTypeVarQuery(types.TypeQuery[bool]):
class HasTypeVarQuery(types.TypeQueryBool):
"""Visitor for querying whether a type has a type variable component."""
def __init__(self) -> None:
super().__init__(any)
super().__init__(types.TypeQueryBool.STRATEGY_ANY)

def visit_type_var(self, t: TypeVarType) -> bool:
return True
Expand All @@ -4278,10 +4278,10 @@ def has_erased_component(t: Optional[Type]) -> bool:
return t is not None and t.accept(HasErasedComponentsQuery())


class HasErasedComponentsQuery(types.TypeQuery[bool]):
class HasErasedComponentsQuery(types.TypeQueryBool):
"""Visitor for querying whether a type has an erased component."""
def __init__(self) -> None:
super().__init__(any)
super().__init__(types.TypeQueryBool.STRATEGY_ANY)

def visit_erased_type(self, t: ErasedType) -> bool:
return True
Expand All @@ -4291,10 +4291,10 @@ def has_uninhabited_component(t: Optional[Type]) -> bool:
return t is not None and t.accept(HasUninhabitedComponentsQuery())


class HasUninhabitedComponentsQuery(types.TypeQuery[bool]):
class HasUninhabitedComponentsQuery(types.TypeQueryBool):
"""Visitor for querying whether a type has an UninhabitedType component."""
def __init__(self) -> None:
super().__init__(any)
super().__init__(types.TypeQueryBool.STRATEGY_ANY)

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return True
Expand Down
6 changes: 3 additions & 3 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypy.types import (
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
UninhabitedType, TypeType, TypeVarId, TypeQueryBool, is_named_instance, TypeOfAny, LiteralType,
ProperType, get_proper_type, TypeAliasType
)
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -255,9 +255,9 @@ def is_complete_type(typ: Type) -> bool:
return typ.accept(CompleteTypeVisitor())


class CompleteTypeVisitor(TypeQuery[bool]):
class CompleteTypeVisitor(TypeQueryBool):
def __init__(self) -> None:
super().__init__(all)
super().__init__(TypeQueryBool.STRATEGY_ALL)

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return False
Expand Down
6 changes: 3 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
get_proper_type, get_proper_types, TypeAliasType)
from mypy.typeops import function_type
from mypy.type_visitor import TypeQuery
from mypy.type_visitor import TypeQueryBool
from mypy.nodes import implicit_module_attrs
from mypy.typeanal import (
TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias,
Expand Down Expand Up @@ -4987,9 +4987,9 @@ def is_future_flag_set(self, flag: str) -> bool:
return flag in self.future_import_flags


class HasPlaceholders(TypeQuery[bool]):
class HasPlaceholders(TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_placeholder_type(self, t: PlaceholderType) -> bool:
return True
Expand Down
6 changes: 3 additions & 3 deletions mypy/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mypy.traverser import TraverserVisitor
from mypy.typeanal import collect_all_inner_types
from mypy.types import (
Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQuery, CallableType,
Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQueryBool, CallableType,
TypeOfAny, get_proper_type, get_proper_types
)
from mypy import nodes
Expand Down Expand Up @@ -423,9 +423,9 @@ def is_imprecise(t: Type) -> bool:
return t.accept(HasAnyQuery())


class HasAnyQuery(TypeQuery[bool]):
class HasAnyQuery(TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_any(self, t: AnyType) -> bool:
return not is_special_form_any(t)
Expand Down
121 changes: 120 additions & 1 deletion mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from abc import abstractmethod
from mypy.ordered_dict import OrderedDict
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set, Sequence
from typing_extensions import Final
from mypy_extensions import trait, mypyc_attr

T = TypeVar('T')
Expand Down Expand Up @@ -249,7 +250,8 @@ class TypeQuery(SyntheticTypeVisitor[T]):
"""Visitor for performing queries of types.

strategy is used to combine results for a series of types,
common use cases involve a boolean query using `any` or `all`.
For cases involving a boolean query using `any` or `all`, the specialized
TypeQueryBool is used.

Note: this visitor keeps an internal state (tracks type aliases to avoid
recursion), so it should *never* be re-used for querying different types,
Expand Down Expand Up @@ -351,3 +353,120 @@ def query_types(self, types: Iterable[Type]) -> T:
self.seen_aliases.add(t)
res.append(t.accept(self))
return self.strategy(res)


@mypyc_attr(allow_interpreted_subclasses=True)
class TypeQueryBool(SyntheticTypeVisitor[bool]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related: #9602

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the headsup, I added the decorator for allow_interpreted_subclasses=True

"""Specialized visitor for boolean strategies

strategy is used to combine results for a series of types,
Cases involve a boolean query using `any` or `all`.

Note: this visitor keeps an internal state (tracks type aliases to avoid
recursion), so it should *never* be re-used for querying different types,
create a new visitor instance instead.

# TODO: check that we don't have existing violations of this rule.
"""
STRATEGY_ANY = 0 # type: Final
STRATEGY_ALL = 1 # type: Final

def __init__(self, strategy: int) -> None:
self.strategy = strategy
# Keep track of the type aliases already visited. This is needed to avoid
# infinite recursion on types like A = Union[int, List[A]].
self.seen_aliases = set() # type: Set[TypeAliasType]

def bool_strategy_empty(self) -> bool:
if self.strategy == TypeQueryBool.STRATEGY_ALL:
return True
return False

def visit_unbound_type(self, t: UnboundType) -> bool:
return self.query_types(t.args)

def visit_type_list(self, t: TypeList) -> bool:
return self.query_types(t.items)

def visit_callable_argument(self, t: CallableArgument) -> bool:
return t.typ.accept(self)

def visit_any(self, t: AnyType) -> bool:
return self.bool_strategy_empty()

def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
return self.bool_strategy_empty()

def visit_none_type(self, t: NoneType) -> bool:
return self.bool_strategy_empty()

def visit_erased_type(self, t: ErasedType) -> bool:
return self.bool_strategy_empty()

def visit_deleted_type(self, t: DeletedType) -> bool:
return self.bool_strategy_empty()

def visit_type_var(self, t: TypeVarType) -> bool:
return self.query_types([t.upper_bound] + t.values)

def visit_partial_type(self, t: PartialType) -> bool:
return self.bool_strategy_empty()

def visit_instance(self, t: Instance) -> bool:
return self.query_types(t.args)

def visit_callable_type(self, t: CallableType) -> bool:
# FIX generics
return self.query_types(t.arg_types + [t.ret_type])

def visit_tuple_type(self, t: TupleType) -> bool:
return self.query_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> bool:
return self.query_types(t.items.values())

def visit_raw_expression_type(self, t: RawExpressionType) -> bool:
return self.bool_strategy_empty()

def visit_literal_type(self, t: LiteralType) -> bool:
return self.bool_strategy_empty()

def visit_star_type(self, t: StarType) -> bool:
return t.type.accept(self)

def visit_union_type(self, t: UnionType) -> bool:
return self.query_types(t.items)

def visit_overloaded(self, t: Overloaded) -> bool:
return self.query_types(t.items())

def visit_type_type(self, t: TypeType) -> bool:
return t.item.accept(self)

def visit_ellipsis_type(self, t: EllipsisType) -> bool:
return self.bool_strategy_empty()

def visit_placeholder_type(self, t: PlaceholderType) -> bool:
return self.query_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> bool:
return get_proper_type(t).accept(self)

def query_types(self, types: Iterable[Type]) -> bool:
"""Perform a query for a list of types.

Use the strategy to combine the results.
Skip type aliases already visited types to avoid infinite recursion.
"""
for t in types:
if isinstance(t, TypeAliasType):
# Avoid infinite recursion for recursive type aliases.
if t in self.seen_aliases:
continue
self.seen_aliases.add(t)
res = t.accept(self)
if res and self.strategy == TypeQueryBool.STRATEGY_ANY:
return True
elif not res and self.strategy == TypeQueryBool.STRATEGY_ALL:
return False
return self.strategy == TypeQueryBool.STRATEGY_ALL
12 changes: 6 additions & 6 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from mypy.types import (
Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType,
CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor,
StarType, PartialType, EllipsisType, UninhabitedType, TypeType,
CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType,
StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument,
TypeQuery, TypeQueryBool, union_items, TypeOfAny, LiteralType, RawExpressionType,
PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeDef, ParamSpecDef
)

Expand Down Expand Up @@ -1190,9 +1190,9 @@ def has_explicit_any(t: Type) -> bool:
return t.accept(HasExplicitAny())


class HasExplicitAny(TypeQuery[bool]):
class HasExplicitAny(TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_any(self, t: AnyType) -> bool:
return t.type_of_any == TypeOfAny.explicit
Expand All @@ -1211,9 +1211,9 @@ def has_any_from_unimported_type(t: Type) -> bool:
return t.accept(HasAnyFromUnimportedType())


class HasAnyFromUnimportedType(TypeQuery[bool]):
class HasAnyFromUnimportedType(TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_any(self, t: AnyType) -> bool:
return t.type_of_any == TypeOfAny.from_unimported_type
Expand Down
5 changes: 3 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,6 +1987,7 @@ def get_proper_types(it: Iterable[Optional[Type]]
SyntheticTypeVisitor as SyntheticTypeVisitor,
TypeTranslator as TypeTranslator,
TypeQuery as TypeQuery,
TypeQueryBool as TypeQueryBool
)


Expand Down Expand Up @@ -2272,9 +2273,9 @@ def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type],
return new_tp


class HasTypeVars(TypeQuery[bool]):
class HasTypeVars(TypeQueryBool):
def __init__(self) -> None:
super().__init__(any)
super().__init__(TypeQueryBool.STRATEGY_ANY)

def visit_type_var(self, t: TypeVarType) -> bool:
return True
Expand Down