Skip to content

Commit

Permalink
Refactor evaluation environment
Browse files Browse the repository at this point in the history
Vendored EvalEnvironment from patsy and customised it a lil bit.
Now, the Environment class does raise an exception when pickled.
Although what is pickled is an empty environment.

closes #729
  • Loading branch information
has2k1 committed Jan 5, 2024
1 parent fd4ab67 commit 49ec85b
Show file tree
Hide file tree
Showing 25 changed files with 372 additions and 118 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ title: Changelog
default is set to `True`, which pads the domain with `-inf` and `inf` so
that the ECDF does not have discontinuities at the extremes. To get the
behaviour, set `pad` to `False`. ({{< issue 725 >}})
- Removed the environment parameter from `ggplot`.

### New

Expand Down Expand Up @@ -77,6 +78,8 @@ title: Changelog
- All `__all__` variables are explicitly assigned to help static typecheckers
infer module attributes. ({{< issue 685 >}})

- You can now pickle the drawn matplotlib figures. ({{< issue 729 >}})

## v0.12.1
(2023-05-09)

Expand Down
6 changes: 3 additions & 3 deletions plotnine/coords/coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class coord:
# if the coordinate system needs them
params: dict[str, Any]

def __radd__(self, gg: Ggplot) -> Ggplot:
def __radd__(self, plot: Ggplot) -> Ggplot:
"""
Add coordinates to ggplot object
"""
gg.coordinates = copy(self)
return gg
plot.coordinates = copy(self)
return plot

def setup_data(self, data: list[pd.DataFrame]) -> list[pd.DataFrame]:
"""
Expand Down
29 changes: 16 additions & 13 deletions plotnine/facets/facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from matplotlib.gridspec import GridSpec

from plotnine.iapi import layout_details, panel_view
from plotnine.mapping import Environment
from plotnine.typing import (
Axes,
CanBeStripLabellingFunc,
Coord,
EvalEnvironment,
Figure,
Ggplot,
Layers,
Expand Down Expand Up @@ -124,6 +124,9 @@ class facet:

grid_spec: GridSpec

# The plot environment
environment: Environment

def __init__(
self,
scales: Literal["fixed", "free", "free_x", "free_y"] = "fixed",
Expand All @@ -145,23 +148,23 @@ def __init__(
"y": scales in ("free_y", "free"),
}

def __radd__(self, gg: Ggplot) -> Ggplot:
def __radd__(self, plot: Ggplot) -> Ggplot:
"""
Add facet to ggplot object
"""
gg.facet = copy(self)
gg.facet.plot = gg
return gg
plot.facet = copy(self)
plot.facet.environment = plot.environment
return plot

def set_properties(self, gg: Ggplot):
def set_properties(self, plot: Ggplot):
"""
Copy required properties from ggplot object
"""
self.axs = gg.axs
self.coordinates = gg.coordinates
self.figure = gg.figure
self.layout = gg.layout
self.theme = gg.theme
self.axs = plot.axs
self.coordinates = plot.coordinates
self.figure = plot.figure
self.layout = plot.layout
self.theme = plot.theme
self.strips = Strips.from_facet(self)

def setup_data(self, data: list[pd.DataFrame]) -> list[pd.DataFrame]:
Expand Down Expand Up @@ -500,7 +503,7 @@ def _aspect_ratio(self) -> Optional[float]:

def combine_vars(
data: list[pd.DataFrame],
environment: EvalEnvironment,
environment: Environment,
vars: list[str],
drop: bool = True,
) -> pd.DataFrame:
Expand Down Expand Up @@ -631,7 +634,7 @@ def add_missing_facets(


def eval_facet_vars(
data: pd.DataFrame, vars: list[str], env: EvalEnvironment
data: pd.DataFrame, vars: list[str], env: Environment
) -> pd.DataFrame:
"""
Evaluate facet variables
Expand Down
6 changes: 3 additions & 3 deletions plotnine/facets/facet_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame:
return layout_null()

base_rows = combine_vars(
data, self.plot.environment, self.rows, drop=self.drop
data, self.environment, self.rows, drop=self.drop
)

if not self.as_table:
# Reverse the order of the rows
base_rows = base_rows[::-1]
base_cols = combine_vars(
data, self.plot.environment, self.cols, drop=self.drop
data, self.environment, self.cols, drop=self.drop
)

base = cross_join(base_rows, base_cols)
Expand Down Expand Up @@ -201,7 +201,7 @@ def map(self, data: pd.DataFrame, layout: pd.DataFrame) -> pd.DataFrame:
)
data = add_margins(data, margin_vars, self.margins)

facet_vals = eval_facet_vars(data, vars, self.plot.environment)
facet_vals = eval_facet_vars(data, vars, self.environment)
data, facet_vals = add_missing_facets(data, layout, vars, facet_vals)

# assign each point to a panel
Expand Down
6 changes: 2 additions & 4 deletions plotnine/facets/facet_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ def compute_layout(
if not self.vars:
return layout_null()

base = combine_vars(
data, self.plot.environment, self.vars, drop=self.drop
)
base = combine_vars(data, self.environment, self.vars, drop=self.drop)
n = len(base)
dims = wrap_dims(n, self._nrow, self._ncol)
_id = np.arange(1, n + 1)
Expand Down Expand Up @@ -160,7 +158,7 @@ def map(self, data: pd.DataFrame, layout: pd.DataFrame) -> pd.DataFrame:
)
return data

facet_vals = eval_facet_vars(data, self.vars, self.plot.environment)
facet_vals = eval_facet_vars(data, self.vars, self.environment)
data, facet_vals = add_missing_facets(
data, layout, self.vars, facet_vals
)
Expand Down
6 changes: 3 additions & 3 deletions plotnine/geoms/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ def __init__(
**kwargs,
)

def __radd__(self, gg: Ggplot) -> Ggplot:
def __radd__(self, plot: Ggplot) -> Ggplot:
"""
Add to ggplot
"""
gg += self.to_layer() # Add layer
return gg
plot += self.to_layer() # Add layer
return plot

def to_layer(self) -> Layer:
"""
Expand Down
12 changes: 6 additions & 6 deletions plotnine/geoms/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import pandas as pd

from plotnine.iapi import panel_view
from plotnine.mapping import Environment
from plotnine.typing import (
Aes,
Axes,
Coord,
DataLike,
DrawingArea,
EvalEnvironment,
Ggplot,
Layer,
Layout,
Expand Down Expand Up @@ -64,7 +64,7 @@ class geom(ABC, metaclass=Register):

# Plot namespace, it gets its value when the plot is being
# built.
environment: EvalEnvironment
environment: Environment

# The geom responsible for the legend if draw_legend is
# not implemented
Expand Down Expand Up @@ -427,22 +427,22 @@ def draw_unit(
msg = "The geom should implement this method."
raise NotImplementedError(msg)

def __radd__(self, gg: Ggplot) -> Ggplot:
def __radd__(self, plot: Ggplot) -> Ggplot:
"""
Add layer representing geom object on the right
Parameters
----------
gg :
plot :
ggplot object
Returns
-------
:
ggplot object with added layer.
"""
gg += self.to_layer() # Add layer
return gg
plot += self.to_layer() # Add layer
return plot

def to_layer(self) -> Layer:
"""
Expand Down
25 changes: 9 additions & 16 deletions plotnine/ggplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
Axes,
Coord,
DataLike,
EvalEnvironment,
Facet,
Figure,
Layer,
Expand All @@ -63,11 +62,11 @@ class ggplot:
mapping :
Default aesthetics mapping for the plot. These will be used
by all layers unless specifically overridden.
environment :
If a variable defined in the aesthetic mapping is not
found in the data, ggplot will look for it in this
namespace. It defaults to using the environment/namespace.
in which `ggplot()` is called.
Notes
-----
ggplot object only have partial support for pickling. The mappings used
by pickled objects should not reference variables in the namespace.
"""

figure: Figure
Expand All @@ -80,9 +79,8 @@ def __init__(
self,
data: Optional[DataLike] = None,
mapping: Optional[aes] = None,
environment: Optional[EvalEnvironment] = None,
):
from patsy.eval import EvalEnvironment
from .mapping._env import Environment

# Allow some sloppiness
data, mapping = order_as_data_mapping(data, mapping)
Expand All @@ -95,7 +93,7 @@ def __init__(
self.scales = Scales()
self.theme = theme_get()
self.coordinates: Coord = coord_cartesian()
self.environment = environment or EvalEnvironment.capture(1)
self.environment = Environment.capture(1)
self.layout = Layout()
self.watermarks: list[Watermark] = []

Expand Down Expand Up @@ -133,13 +131,8 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> ggplot:
old = self.__dict__
new = result.__dict__

# don't make a deepcopy of data, or environment
shallow = {
"data",
"environment",
"figure",
"_build_objs",
}
# don't make a deepcopy of data
shallow = {"data", "figure", "_build_objs"}
for key, item in old.items():
if key in shallow:
new[key] = item
Expand Down
8 changes: 4 additions & 4 deletions plotnine/guides/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def __init__(self, **kwargs):
self, ((ae, kwargs[ae]) for ae in kwargs if ae in aes_names)
)

def __radd__(self, gg):
def __radd__(self, plot):
"""
Add guides to the plot
Parameters
----------
gg : ggplot
plot : ggplot
ggplot object being created
Returns
Expand All @@ -91,8 +91,8 @@ def __radd__(self, gg):
new_guides = {}
for k in self:
new_guides[k] = deepcopy(self[k])
gg.guides.update(new_guides)
return gg
plot.guides.update(new_guides)
return plot

def build(self, plot):
"""
Expand Down
6 changes: 3 additions & 3 deletions plotnine/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def __init__(self, **kwargs: str):
raise PlotnineError(f"Cannot deal with these labels: {unknown}")
self.labels = labels_view(**rename_aesthetics(kwargs))

def __radd__(self, gg: p9.ggplot) -> p9.ggplot:
def __radd__(self, plot: p9.ggplot) -> p9.ggplot:
"""
Add labels to ggplot object
"""
gg.labels.update(self.labels)
return gg
plot.labels.update(self.labels)
return plot


class xlab(labs):
Expand Down
18 changes: 9 additions & 9 deletions plotnine/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing
from copy import copy, deepcopy
from typing import Iterable, List, overload
from typing import Iterable, List, cast, overload

import pandas as pd

Expand All @@ -14,11 +14,11 @@
if typing.TYPE_CHECKING:
from typing import Any, Optional, Sequence, SupportsIndex

from plotnine.mapping import Environment
from plotnine.typing import (
Coord,
DataFrameConvertible,
DataLike,
EvalEnvironment,
Geom,
Ggplot,
Layer,
Expand Down Expand Up @@ -125,16 +125,16 @@ def from_geom(geom: Geom) -> Layer:
lkwargs[param] = geom.DEFAULT_PARAMS[param]
return layer(**lkwargs)

def __radd__(self, gg: Ggplot) -> Ggplot:
def __radd__(self, plot: Ggplot) -> Ggplot:
"""
Add layer to ggplot object
"""
try:
gg.layers.append(self)
plot.layers.append(self)
except AttributeError as e:
msg = f"Cannot add layer to object of type {type(gg)!r}"
msg = f"Cannot add layer to object of type {type(plot)!r}"
raise PlotnineError(msg) from e
return gg
return plot

def __deepcopy__(self, memo: dict[Any, Any]) -> layer:
"""
Expand Down Expand Up @@ -176,9 +176,9 @@ def _make_layer_data(self, plot_data: DataLike | None):
if plot_data is None:
data = pd.DataFrame()
elif hasattr(plot_data, "to_pandas"):
data = typing.cast("DataFrameConvertible", plot_data).to_pandas()
data = cast("DataFrameConvertible", plot_data).to_pandas()
else:
data = typing.cast("pd.DataFrame", plot_data)
data = cast("pd.DataFrame", plot_data)

# Each layer that does not have data gets a copy of
# of the ggplot.data. If it has data it is replaced
Expand Down Expand Up @@ -238,7 +238,7 @@ def _make_layer_mapping(self, plot_mapping: aes):
group = f'"{group}"'
self.mapping["group"] = stage(start=group)

def _make_layer_environments(self, plot_environment: EvalEnvironment):
def _make_layer_environments(self, plot_environment: Environment):
"""
Create the aesthetic mappings to be used by this layer
Expand Down
1 change: 1 addition & 0 deletions plotnine/mapping/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Aesthetic Mappings
"""
from ._env import Environment # noqa: F401
from .aes import aes
from .evaluation import after_scale, after_stat, stage

Expand Down
Loading

0 comments on commit 49ec85b

Please sign in to comment.