diff --git a/ibis/selectors.py b/ibis/selectors.py index 9431254d64633..e9c75c6e4fa9d 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -128,6 +128,15 @@ def numeric() -> Selector: return of_type(dt.Numeric) +class OfType(Selector): + predicate: Callable[[dt.DataType], bool] + + def expand_names(self, table: ir.Table) -> frozenset[str]: + return frozenset( + name for name, typ in table.schema().items() if self.predicate(typ) + ) + + @public def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector: """Select columns of type `dtype`. @@ -186,17 +195,20 @@ def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector: "struct": dt.Struct, "temporal": dt.Temporal, } - if cls := abstract.get(dtype.lower()): - predicate = lambda col: isinstance(col.type(), cls) + + if dtype_cls := abstract.get(dtype.lower()): + predicate = lambda typ, dtype_cls=dtype_cls: isinstance(typ, dtype_cls) else: dtype = dt.dtype(dtype) - predicate = lambda col: col.type() == dtype + predicate = lambda typ, dtype=dtype: typ == dtype + elif inspect.isclass(dtype) and issubclass(dtype, dt.DataType): - predicate = lambda col: isinstance(col.type(), dtype) + predicate = lambda typ, dtype_cls=dtype: isinstance(typ, dtype_cls) else: dtype = dt.dtype(dtype) - predicate = lambda col: col.type() == dtype - return where(predicate) + predicate = lambda typ, dtype=dtype: typ == dtype + + return OfType(predicate) class StartsWith(Selector):