Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bigframes/core/agg_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
import itertools
import typing
from typing import Callable, Mapping, Tuple, TypeVar
from typing import Callable, Hashable, Mapping, Tuple, TypeVar

from bigframes import dtypes
from bigframes.core import expression, window_spec
Expand Down Expand Up @@ -68,7 +68,7 @@ def children(self) -> Tuple[expression.Expression, ...]:
return self.inputs

@property
def free_variables(self) -> typing.Tuple[str, ...]:
def free_variables(self) -> typing.Tuple[Hashable, ...]:
return tuple(
itertools.chain.from_iterable(map(lambda x: x.free_variables, self.inputs))
)
Expand All @@ -92,7 +92,7 @@ def transform_children(

def bind_variables(
self: TExpression,
bindings: Mapping[str, expression.Expression],
bindings: Mapping[Hashable, expression.Expression],
allow_partial_bindings: bool = False,
) -> TExpression:
return self.transform_children(
Expand Down Expand Up @@ -192,7 +192,7 @@ def children(self) -> Tuple[expression.Expression, ...]:
return self.inputs

@property
def free_variables(self) -> typing.Tuple[str, ...]:
def free_variables(self) -> typing.Tuple[Hashable, ...]:
return tuple(
itertools.chain.from_iterable(map(lambda x: x.free_variables, self.inputs))
)
Expand All @@ -216,7 +216,7 @@ def transform_children(

def bind_variables(
self: WindowExpression,
bindings: Mapping[str, expression.Expression],
bindings: Mapping[Hashable, expression.Expression],
allow_partial_bindings: bool = False,
) -> WindowExpression:
return self.transform_children(
Expand Down
119 changes: 119 additions & 0 deletions bigframes/core/col.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from typing import Any, Hashable

import bigframes.core.expression as bf_expression
import bigframes.operations as bf_ops


# Not to be confused with internal Expressions class
# Name collision unintended
@dataclasses.dataclass(frozen=True)
class Expression:
_value: bf_expression.Expression

def _apply_unary(self, op: bf_ops.UnaryOp) -> Expression:
return Expression(op.as_expr(self._value))

def _apply_binary(self, other: Any, op: bf_ops.BinaryOp, reverse: bool = False):
if isinstance(other, Expression):
other_value = other._value
else:
other_value = bf_expression.const(other)
if reverse:
return Expression(op.as_expr(other_value, self._value))
else:
return Expression(op.as_expr(self._value, other_value))

def __add__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.add_op)

def __radd__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.add_op, reverse=True)

def __sub__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.sub_op)

def __rsub__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.sub_op, reverse=True)

def __mul__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.mul_op)

def __rmul__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.mul_op, reverse=True)

def __truediv__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.div_op)

def __rtruediv__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.div_op, reverse=True)

def __floordiv__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.floordiv_op)

def __rfloordiv__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.floordiv_op, reverse=True)

def __ge__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.ge_op)

def __gt__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.gt_op)

def __le__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.le_op)

def __lt__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.lt_op)

def __eq__(self, other: object) -> Expression: # type: ignore
return self._apply_binary(other, bf_ops.eq_op)

def __ne__(self, other: object) -> Expression: # type: ignore
return self._apply_binary(other, bf_ops.ne_op)

def __mod__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.mod_op)

def __rmod__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.mod_op, reverse=True)

def __and__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.and_op)

def __rand__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.and_op, reverse=True)

def __or__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.or_op)

def __ror__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.or_op, reverse=True)

def __xor__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.xor_op)

def __rxor__(self, other: Any) -> Expression:
return self._apply_binary(other, bf_ops.xor_op, reverse=True)

def __invert__(self) -> Expression:
return self._apply_unary(bf_ops.invert_op)


def col(col_name: Hashable) -> Expression:
return Expression(bf_expression.free_var(col_name))
32 changes: 21 additions & 11 deletions bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
import itertools
import typing
from typing import Callable, Generator, Mapping, TypeVar, Union
from typing import Callable, Generator, Hashable, Mapping, TypeVar, Union

import pandas as pd

Expand All @@ -39,7 +39,7 @@ def deref(name: str) -> DerefOp:
return DerefOp(ids.ColumnId(name))


def free_var(id: str) -> UnboundVariableExpression:
def free_var(id: Hashable) -> UnboundVariableExpression:
return UnboundVariableExpression(id)


Expand All @@ -52,7 +52,7 @@ class Expression(abc.ABC):
"""An expression represents a computation taking N scalar inputs and producing a single output scalar."""

@property
def free_variables(self) -> typing.Tuple[str, ...]:
def free_variables(self) -> typing.Tuple[Hashable, ...]:
return ()

@property
Expand Down Expand Up @@ -116,7 +116,9 @@ def bind_refs(

@abc.abstractmethod
def bind_variables(
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
self,
bindings: Mapping[Hashable, Expression],
allow_partial_bindings: bool = False,
) -> Expression:
"""Replace variables with expression given in `bindings`.

Expand Down Expand Up @@ -191,7 +193,9 @@ def output_type(self) -> dtypes.ExpressionType:
return self.dtype

def bind_variables(
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
self,
bindings: Mapping[Hashable, Expression],
allow_partial_bindings: bool = False,
) -> Expression:
return self

Expand Down Expand Up @@ -226,10 +230,10 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio
class UnboundVariableExpression(Expression):
"""A variable expression representing an unbound variable."""

id: str
id: Hashable

@property
def free_variables(self) -> typing.Tuple[str, ...]:
def free_variables(self) -> typing.Tuple[Hashable, ...]:
return (self.id,)

@property
Expand All @@ -256,7 +260,9 @@ def bind_refs(
return self

def bind_variables(
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
self,
bindings: Mapping[Hashable, Expression],
allow_partial_bindings: bool = False,
) -> Expression:
if self.id in bindings.keys():
return bindings[self.id]
Expand Down Expand Up @@ -304,7 +310,9 @@ def output_type(self) -> dtypes.ExpressionType:
raise ValueError(f"Type of variable {self.id} has not been fixed.")

def bind_variables(
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
self,
bindings: Mapping[Hashable, Expression],
allow_partial_bindings: bool = False,
) -> Expression:
return self

Expand Down Expand Up @@ -373,7 +381,7 @@ def column_references(
)

@property
def free_variables(self) -> typing.Tuple[str, ...]:
def free_variables(self) -> typing.Tuple[Hashable, ...]:
return tuple(
itertools.chain.from_iterable(map(lambda x: x.free_variables, self.inputs))
)
Expand Down Expand Up @@ -408,7 +416,9 @@ def output_type(self) -> dtypes.ExpressionType:
return self.op.output_type(*input_types)

def bind_variables(
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
self,
bindings: Mapping[Hashable, Expression],
allow_partial_bindings: bool = False,
) -> OpExpression:
return OpExpression(
self.op,
Expand Down
16 changes: 15 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from bigframes.core import agg_expressions
import bigframes.core.block_transforms as block_ops
import bigframes.core.blocks as blocks
import bigframes.core.col
import bigframes.core.convert
import bigframes.core.explode
import bigframes.core.expression as ex
Expand Down Expand Up @@ -94,7 +95,13 @@
import bigframes.session

SingleItemValue = Union[
bigframes.series.Series, int, float, str, pandas.Timedelta, Callable
bigframes.series.Series,
int,
float,
str,
pandas.Timedelta,
Callable,
bigframes.core.col.Expression,
]
MultiItemValue = Union[
"DataFrame", Sequence[int | float | str | pandas.Timedelta | Callable]
Expand Down Expand Up @@ -2236,6 +2243,13 @@ def _assign_single_item(
) -> DataFrame:
if isinstance(v, bigframes.series.Series):
return self._assign_series_join_on_index(k, v)
elif isinstance(v, bigframes.core.col.Expression):
label_to_col_ref = {
label: ex.deref(id) for id, label in self._block.col_id_to_label.items()
}
resolved_expr = v._value.bind_variables(label_to_col_ref)
block = self._block.project_block_exprs([resolved_expr], labels=[k])
return DataFrame(block)
elif isinstance(v, bigframes.dataframe.DataFrame):
v_df_col_count = len(v._block.value_columns)
if v_df_col_count != 1:
Expand Down
4 changes: 2 additions & 2 deletions bigframes/operations/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
return dtypes.TIMEDELTA_DTYPE

if dtypes.is_numeric(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]): # type: ignore
return dtypes.INT_DTYPE
return input_types[0]

Expand All @@ -224,7 +224,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
# These will change if median is changed to exact implementation.
if not dtypes.is_orderable(input_types[0]):
raise TypeError(f"Type {input_types[0]} is not orderable")
if pd.api.types.is_bool_dtype(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]): # type: ignore
return dtypes.INT_DTYPE
else:
return input_types[0]
Expand Down
2 changes: 2 additions & 0 deletions bigframes/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pandas

import bigframes._config as config
from bigframes.core.col import col
import bigframes.core.global_session as global_session
import bigframes.core.indexes
from bigframes.core.logging import log_adapter
Expand Down Expand Up @@ -415,6 +416,7 @@ def reset_session():
"clean_up_by_session_id",
"concat",
"crosstab",
"col",
"cut",
"deploy_remote_function",
"deploy_udf",
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_dataframe_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,25 @@ def test_assign_new_column(scalars_dfs):
assert_frame_equal(bf_result, pd_result)


def test_assign_using_pd_col(scalars_dfs):
if pd.__version__.startswith("1.") or pd.__version__.startswith("2."):
pytest.skip("col expression interface only supported for pandas 3+")
scalars_df, scalars_pandas_df = scalars_dfs
bf_kwargs = {
"new_col_1": 4 - bpd.col("int64_col"),
"new_col_2": bpd.col("int64_col") / (bpd.col("float64_col") * 0.5),
}
pd_kwargs = {
"new_col_1": 4 - pd.col("int64_col"), # type: ignore
"new_col_2": pd.col("int64_col") / (pd.col("float64_col") * 0.5), # type: ignore
}
df = scalars_df.assign(**bf_kwargs)
bf_result = df.to_pandas()
pd_result = scalars_pandas_df.assign(**pd_kwargs)

assert_frame_equal(bf_result, pd_result)


def test_assign_new_column_w_loc(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
bf_df = scalars_df.copy()
Expand Down
Loading