Skip to content

python: bypass plotnine auto-closing comms #7657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ numpy==2.0.2; python_version == '3.9'
numpy==2.2.4; python_version >= '3.10'
pandas==2.2.3
plotly==6.0.1
plotnine==0.13.6; python_version == '3.9'
plotnine==0.14.5; python_version >= '3.10'
polars==1.26.0
polars[timezone]==1.26.0; sys_platform == 'win32'
pyarrow==19.0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from ipykernel.zmqshell import ZMQDisplayPublisher, ZMQInteractiveShell
from IPython.core import magic_arguments, oinspect, page
from IPython.core.error import UsageError
from IPython.core.formatters import DisplayFormatter, IPythonDisplayFormatter, catch_format_error
from IPython.core.interactiveshell import ExecutionInfo, ExecutionResult, InteractiveShell
from IPython.core.magic import Magics, MagicsManager, line_magic, magics_class
from IPython.utils import PyColorize
from IPython.utils import PyColorize, dir2

from .access_keys import encode_access_key
from .connections import ConnectionsService
Expand Down Expand Up @@ -220,11 +221,48 @@ def connection_show(self, line: str) -> None:
original_showwarning = warnings.showwarning


class PositronDisplayFormatter(DisplayFormatter):
@traitlets.default("ipython_display_formatter")
def _default_formatter(self):
return PositronIPythonDisplayFormatter(parent=self)


class PositronIPythonDisplayFormatter(IPythonDisplayFormatter):
print_method = traitlets.ObjectName("_ipython_display_")
_return_type = (type(None), bool)

@catch_format_error
def __call__(self, obj):
"""Compute the format for an object."""
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be wrapped in if self.enabled? I'm not really sure how these work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking into this! I am also seeing different plots than anticipated for the retina setting in plotnine, which is maybe related 👀

if obj.__module__ == "plotnine.ggplot":
Copy link
Contributor Author

@isabelizimm isabelizimm May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're special casing plotnine here since there is autoplot opening/closing code that causes Positron's plots comm to be closed and the plot gets sent over a more vanilla "display_data" call. When the get_intrinsic_size RPC call was made, we were getting an error since there was no longer a Positron comm to look for.

We've been able to patch other packages in the posit/patches/ directory, but this requires an interception of _ipython_display_() , not really a patch to the code itself. What do people think about this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good solution!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems totally reasonable to me given we already are doing non-standard stuff with the comms.

obj.draw(show=True)
return True
except AttributeError:
pass
if self.enabled:
# lookup registered printer
try:
printer = self.lookup(obj)
except KeyError:
pass
else:
printer(obj)
return True
# Finally look for special method names
method = dir2.get_real_method(obj, self.print_method)
if method is not None:
method()
return True
return True
Comment on lines +243 to +257
Copy link
Contributor

@seeM seeM May 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we call super here instead of inlining that code?



class PositronShell(ZMQInteractiveShell):
kernel: PositronIPyKernel
object_info_string_level: int
magics_manager: MagicsManager
display_pub: ZMQDisplayPublisher
display_formatter: PositronDisplayFormatter = traitlets.Instance(PositronDisplayFormatter) # type: ignore

inspector_class: type[PositronIPythonInspector] = traitlets.Type(
PositronIPythonInspector, # type: ignore
Expand Down Expand Up @@ -296,6 +334,10 @@ def init_user_ns(self):
}
)

def init_display_formatter(self):
self.display_formatter = PositronDisplayFormatter(parent=self)
self.configurables.append(self.display_formatter) # type: ignore IPython type annotation is wrong

def _handle_pre_run_cell(self, info: ExecutionInfo) -> None:
"""Prior to execution, reset the user environment watch state."""
# If an empty cell is being executed, do nothing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,25 @@ def test_mpl_shutdown(shell: PositronShell, plots_service: PlotsService) -> None
# Plots are closed and cleared.
assert not plots_service._plots # noqa: SLF001
assert all(comm._closed for comm in plot_comms) # noqa: SLF001


def test_plotnine_close_then_show(shell: PositronShell, plots_service: PlotsService) -> None:
"""Test that a plotnine plot renders and then closes comm correctly."""
shell.run_cell("""\
from plotnine import ggplot, geom_point, aes, stat_smooth, facet_wrap
from plotnine.data import mtcars

(
ggplot(mtcars, aes("wt", "mpg", color="factor(gear)"))
+ geom_point()
+ stat_smooth(method="lm")
+ facet_wrap("gear")
)\
""").raise_error()
plot_comm = cast("DummyComm", plots_service._plots[0]._comm.comm) # noqa: SLF001

assert plot_comm.messages == [
comm_open_message(_CommTarget.Plot),
json_rpc_notification("show", {}),
]
assert not plot_comm._closed # noqa: SLF001
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ matplotlib
numpy
pandas
plotly
plotnine
polars
polars[timezone]; sys_platform == 'win32'
pyarrow
Expand Down
Loading