Skip to content

Commit f4be01d

Browse files
authored
Merge pull request #6 from dstansby/scatter
Scatter widget
2 parents 45c4f4b + 239a168 commit f4be01d

File tree

6 files changed

+101
-0
lines changed

6 files changed

+101
-0
lines changed

examples/scatter.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Scatter plots
3+
=============
4+
"""
5+
import napari
6+
7+
viewer = napari.Viewer()
8+
viewer.open_sample("napari", "kidney")
9+
10+
viewer.window.add_plugin_dock_widget(
11+
plugin_name="napari-matplotlib", widget_name="Scatter"
12+
)
13+
14+
if __name__ == "__main__":
15+
napari.run()

src/napari_matplotlib/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55

66

77
from .histogram import * # NoQA
8+
from .scatter import * # NoQA

src/napari_matplotlib/base.py

+7
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,10 @@ def __init__(self, napari_viewer: napari.viewer.Viewer):
3636

3737
self.setLayout(QVBoxLayout())
3838
self.layout().addWidget(self.canvas)
39+
40+
@property
41+
def current_z(self) -> int:
42+
"""
43+
Current z-step of the viewer.
44+
"""
45+
return self.viewer.dims.current_step[0]

src/napari_matplotlib/napari.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ contributions:
66
python_name: napari_matplotlib:HistogramWidget
77
title: Make a histogram
88

9+
- id: napari-matplotlib.scatter
10+
python_name: napari_matplotlib:ScatterWidget
11+
title: Make a scatter plot
12+
913
widgets:
1014
- command: napari-matplotlib.histogram
1115
display_name: Histogram
16+
17+
- command: napari-matplotlib.scatter
18+
display_name: Scatter

src/napari_matplotlib/scatter.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import matplotlib.colors as mcolor
2+
import napari
3+
4+
from .base import NapariMPLWidget
5+
6+
__all__ = ["ScatterWidget"]
7+
8+
9+
class ScatterWidget(NapariMPLWidget):
10+
"""
11+
Widget to display scatter plot of two similarly shaped layers.
12+
13+
If there are more than 500 data points, a 2D histogram is displayed instead
14+
of a scatter plot, to avoid too many scatter points.
15+
16+
Attributes
17+
----------
18+
layers : list[`napari.layers.Layer`]
19+
Current two layers being scattered.
20+
"""
21+
22+
def __init__(self, napari_viewer: napari.viewer.Viewer):
23+
super().__init__(napari_viewer)
24+
self.layers = self.viewer.layers[-2:]
25+
26+
self.viewer.dims.events.current_step.connect(
27+
self.scatter_current_layers
28+
)
29+
self.viewer.layers.selection.events.changed.connect(self.update_layers)
30+
31+
self.scatter_current_layers()
32+
33+
def update_layers(self, event: napari.utils.events.Event) -> None:
34+
"""
35+
Update the currently selected layers.
36+
"""
37+
# Update current layer when selection changed in viewer
38+
layers = self.viewer.layers.selection
39+
if len(layers) == 2:
40+
self.layers = list(layers)
41+
self.scatter_current_layers()
42+
43+
def scatter_current_layers(self) -> None:
44+
"""
45+
Clear the axes and scatter the currently selected layers.
46+
"""
47+
self.axes.clear()
48+
data = [layer.data[self.current_z] for layer in self.layers]
49+
if data[0].size < 500:
50+
self.axes.scatter(data[0], data[1], alpha=0.5)
51+
else:
52+
self.axes.hist2d(
53+
data[0].ravel(),
54+
data[1].ravel(),
55+
bins=100,
56+
norm=mcolor.LogNorm(),
57+
)
58+
self.axes.set_xlabel(self.layers[0].name)
59+
self.axes.set_ylabel(self.layers[1].name)
60+
self.canvas.draw()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import numpy as np
2+
3+
from napari_matplotlib import ScatterWidget
4+
5+
6+
def test_scatter(make_napari_viewer):
7+
# Smoke test adding a histogram widget
8+
viewer = make_napari_viewer()
9+
viewer.add_image(np.random.random((100, 100)))
10+
viewer.add_image(np.random.random((100, 100)))
11+
ScatterWidget(viewer)

0 commit comments

Comments
 (0)