|
| 1 | +import matplotlib.colors as mcolor |
| 2 | +import napari |
| 3 | + |
| 4 | +from .base import NapariMPLWidget |
| 5 | + |
| 6 | +__all__ = ["ScatterWidget"] |
| 7 | + |
| 8 | + |
| 9 | +class ScatterWidget(NapariMPLWidget): |
| 10 | + """ |
| 11 | + Widget to display scatter plot of two similarly shaped layers. |
| 12 | +
|
| 13 | + If there are more than 500 data points, a 2D histogram is displayed instead |
| 14 | + of a scatter plot, to avoid too many scatter points. |
| 15 | +
|
| 16 | + Attributes |
| 17 | + ---------- |
| 18 | + layers : list[`napari.layers.Layer`] |
| 19 | + Current two layers being scattered. |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, napari_viewer: napari.viewer.Viewer): |
| 23 | + super().__init__(napari_viewer) |
| 24 | + self.layers = self.viewer.layers[-2:] |
| 25 | + |
| 26 | + self.viewer.dims.events.current_step.connect( |
| 27 | + self.scatter_current_layers |
| 28 | + ) |
| 29 | + self.viewer.layers.selection.events.changed.connect(self.update_layers) |
| 30 | + |
| 31 | + self.scatter_current_layers() |
| 32 | + |
| 33 | + def update_layers(self, event: napari.utils.events.Event) -> None: |
| 34 | + """ |
| 35 | + Update the currently selected layers. |
| 36 | + """ |
| 37 | + # Update current layer when selection changed in viewer |
| 38 | + layers = self.viewer.layers.selection |
| 39 | + if len(layers) == 2: |
| 40 | + self.layers = list(layers) |
| 41 | + self.scatter_current_layers() |
| 42 | + |
| 43 | + def scatter_current_layers(self) -> None: |
| 44 | + """ |
| 45 | + Clear the axes and scatter the currently selected layers. |
| 46 | + """ |
| 47 | + self.axes.clear() |
| 48 | + data = [layer.data[self.current_z] for layer in self.layers] |
| 49 | + if data[0].size < 500: |
| 50 | + self.axes.scatter(data[0], data[1], alpha=0.5) |
| 51 | + else: |
| 52 | + self.axes.hist2d( |
| 53 | + data[0].ravel(), |
| 54 | + data[1].ravel(), |
| 55 | + bins=100, |
| 56 | + norm=mcolor.LogNorm(), |
| 57 | + ) |
| 58 | + self.axes.set_xlabel(self.layers[0].name) |
| 59 | + self.axes.set_ylabel(self.layers[1].name) |
| 60 | + self.canvas.draw() |
0 commit comments