diff --git a/glue_plotly/common/common.py b/glue_plotly/common/common.py index 4574ef5..21acd54 100644 --- a/glue_plotly/common/common.py +++ b/glue_plotly/common/common.py @@ -137,7 +137,7 @@ def rgb_colors(layer_state, mask, cmap_att): def color_info(layer_state, mask=None, mode_att="cmap_mode", cmap_att="cmap_att"): - if getattr(layer_state, mode_att) == "Fixed": + if getattr(layer_state, mode_att, "Fixed") == "Fixed": return fixed_color(layer_state) else: return rgb_colors(layer_state, mask, cmap_att) diff --git a/glue_plotly/common/dotplot.py b/glue_plotly/common/dotplot.py new file mode 100644 index 0000000..9b94e28 --- /dev/null +++ b/glue_plotly/common/dotplot.py @@ -0,0 +1,52 @@ +from uuid import uuid4 + +from plotly.graph_objs import Scatter + +from glue.core import BaseData + +from .common import color_info, dimensions + + +def dot_radius(viewer, layer_state): + edges = layer_state.histogram[0] + viewer_state = viewer.state + diam_world = min([edges[i + 1] - edges[i] for i in range(len(edges) - 1)]) + width, height = dimensions(viewer) + diam = diam_world * width / abs(viewer_state.x_max - viewer_state.x_min) + if viewer_state.y_min is not None and viewer_state.y_max is not None: + max_diam_world_v = 1 + diam_pixel_v = max_diam_world_v * height / abs(viewer_state.y_max - viewer_state.y_min) + diam = min(diam_pixel_v, diam) + return diam / 2 + + +def traces_for_layer(viewer, layer_state, add_data_label=True): + legend_group = uuid4().hex + dots_id = uuid4().hex + + x = [] + y = [] + edges, counts = layer_state.histogram + counts = counts.astype(int) + for i in range(len(edges) - 1): + x_i = (edges[i] + edges[i + 1]) / 2 + y_i = range(1, counts[i] + 1) + x.extend([x_i] * counts[i]) + y.extend(y_i) + + radius = dot_radius(viewer, layer_state) + marker = dict(color=color_info(layer_state, mask=None), size=radius) + + name = layer_state.layer.label + if add_data_label and not isinstance(layer_state.layer, BaseData): + name += " ({0})".format(layer_state.layer.data.label) + + return [Scatter( + x=x, + y=y, + mode="markers", + marker=marker, + name=name, + legendgroup=legend_group, + meta=dots_id, + )] diff --git a/glue_plotly/common/scatter2d.py b/glue_plotly/common/scatter2d.py index 0e0482f..fcb740d 100644 --- a/glue_plotly/common/scatter2d.py +++ b/glue_plotly/common/scatter2d.py @@ -245,7 +245,7 @@ def size_info(layer_state, mask=None): return s -def base_marker(layer_state, mask): +def base_marker(layer_state, mask=None): color = color_info(layer_state, mask) marker = dict(size=size_info(layer_state, mask), color=color, diff --git a/glue_plotly/common/tests/test_dotplot.py b/glue_plotly/common/tests/test_dotplot.py new file mode 100644 index 0000000..b0d4521 --- /dev/null +++ b/glue_plotly/common/tests/test_dotplot.py @@ -0,0 +1,78 @@ +from numpy import unique +from plotly.graph_objs import Scatter + +from glue.core import Data +from glue_qt.app import GlueApplication +from glue_qt.viewers.histogram import HistogramViewer + +from glue_plotly.common import sanitize +from glue_plotly.common.dotplot import traces_for_layer + +from glue_plotly.viewers.histogram.viewer import PlotlyHistogramView +from glue_plotly.viewers.histogram.dotplot_layer_artist import PlotlyDotplotLayerArtist + + +class SimpleDotplotViewer(PlotlyHistogramView): + _data_artist_cls = PlotlyDotplotLayerArtist + _subset_artist_cls = PlotlyDotplotLayerArtist + + +class TestDotplot: + + def setup_method(self, method): + x = [86, 86, 76, 78, 93, 100, 90, 87, 73, 61, 71, 68, 78, + 9, 87, 32, 34, 2, 57, 79, 48, 5, 8, 19, 7, 78, + 16, 15, 58, 34, 20, 63, 96, 97, 86, 92, 35, 59, 75, + 0, 53, 45, 59, 74, 59, 4, 69, 76, 97, 77, 24, 99, + 50, 6, 1, 55, 13, 40, 27, 17, 92, 72, 40, 29, 64, + 38, 77, 11, 91, 23, 59, 92, 5, 88, 15, 90, 40, 100, + 47, 28, 3, 44, 89, 75, 13, 94, 95, 43, 17, 88, 6, + 94, 100, 28, 45, 36, 63, 14, 90, 66] + self.data = Data(label="dotplot", x=x) + self.app = GlueApplication() + self.app.session.data_collection.append(self.data) + self.viewer = self.app.new_data_viewer(HistogramViewer) + self.viewer.add_data(self.data) + self.mask, self.sanitized = sanitize(self.data['x']) + + viewer_state = self.viewer.state + viewer_state.hist_n_bin = 18 + viewer_state.x_axislabel_size = 14 + viewer_state.y_axislabel_size = 8 + viewer_state.x_ticklabel_size = 18 + viewer_state.y_ticklabel_size = 20 + viewer_state.x_min = 0 + viewer_state.x_max = 100 + viewer_state.y_min = 0 + viewer_state.y_max = 15 + viewer_state.x_axislabel = 'X Axis' + viewer_state.y_axislabel = 'Y Axis' + + self.layer = self.viewer.layers[0] + self.layer.state.color = '#0e1dab' + self.layer.state.alpha = 0.85 + + def teardown_method(self, method): + self.viewer.close(warn=False) + self.viewer = None + self.app.close() + self.app = None + + def test_basic_dots(self): + traces = traces_for_layer(self.viewer, self.layer.state) + assert len(traces) == 1 + dots = traces[0] + assert isinstance(dots, Scatter) + + assert len(unique(dots.x)) == 18 + expected_y = (1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 1, + 2, 3, 4, 5, 6, 1, 2, 3, 4, 1, 2, 3, 1, 2, + 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, + 3, 4, 5, 1, 2, 1, 2, 3, 4, 5, 6, 7, 1, 2, + 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 1, 2, 3, 4, 5, + 6, 7, 8) + + assert dots.y == expected_y + assert dots.marker.size == 16 # Default figure is 640x480 diff --git a/glue_plotly/viewers/histogram/dotplot_layer_artist.py b/glue_plotly/viewers/histogram/dotplot_layer_artist.py new file mode 100644 index 0000000..8802b5a --- /dev/null +++ b/glue_plotly/viewers/histogram/dotplot_layer_artist.py @@ -0,0 +1,160 @@ +# NB: This dot plot layer artist shouldn't be used together with the +# normalized mode, as a dotplot only makes sense when the heights are integral. + +import numpy as np + +from glue.core.exceptions import IncompatibleAttribute +from glue.viewers.common.layer_artist import LayerArtist +from glue.viewers.histogram.state import HistogramLayerState +from glue_plotly.common.common import fixed_color + +from glue_plotly.common.dotplot import dot_radius, traces_for_layer + +__all__ = ["PlotlyDotplotLayerArtist"] + +SCALE_PROPERTIES = {'y_log', 'normalize', 'cumulative'} +HISTOGRAM_PROPERTIES = SCALE_PROPERTIES | {'layer', 'x_att', 'hist_x_min', + 'hist_x_max', 'hist_n_bin', 'x_log'} + +# Note that, because we need to scale the dots based on pixel space due to how Plotly sizes scatters, +# we need to update the dot sizing when the bounds change +VISUAL_PROPERTIES = {'alpha', 'color', 'zorder', 'visible', 'x_min', 'x_max', 'y_min', 'y_max'} +DATA_PROPERTIES = {'layer', 'x_att', 'y_att'} + + +class PlotlyDotplotLayerArtist(LayerArtist): + + _layer_state_cls = HistogramLayerState + + def __init__(self, view, viewer_state, layer_state=None, layer=None): + super().__init__( + viewer_state, + layer_state=layer_state, + layer=layer + ) + + self.view = view + self.bins = None + self._dots_id = None + + self._viewer_state.add_global_callback(self._update_dotplot) + self.state.add_global_callback(self._update_dotplot) + + def _get_dots(self): + return self.view.figure.select_traces(dict(meta=self._dots_id)) + + def traces(self): + return self._get_dots() + + def _calculate_histogram(self): + try: + self.state.reset_cache() + self.bins, self.hist_unscaled = self.state.histogram + except IncompatibleAttribute: + self.disable('Could not compute histogram') + self.bins = self.hist_unscaled = None + + def _scale_histogram(self): + + if self.bins is None: + return # can happen when the subset is empty + + if self.bins.size == 0: + return + + with self.view.figure.batch_update(): + + # We have to do the following to make sure that we reset the y_max as + # needed. We can't simply reset based on the maximum for this layer + # because other layers might have other values, and we also can't do: + # + # self._viewer_state.y_max = max(self._viewer_state.y_max, result[0].max()) + # + # because this would never allow y_max to get smaller. + + _, hist = self.state.histogram + self.state._y_max = hist.max() + if self._viewer_state.y_log: + self.state._y_max *= 2 + else: + self.state._y_max *= 1.2 + + if self._viewer_state.y_log: + keep = hist > 0 + if np.any(keep): + self.state._y_min = hist[keep].min() / 10 + else: + self.state._y_min = 0 + else: + self.state._y_min = 0 + + largest_y_max = max(getattr(layer, '_y_max', 0) + for layer in self._viewer_state.layers) + if np.isfinite(largest_y_max) and largest_y_max != self._viewer_state.y_max: + self._viewer_state.y_max = largest_y_max + + smallest_y_min = min(getattr(layer, '_y_min', np.inf) + for layer in self._viewer_state.layers) + if np.isfinite(smallest_y_min) and smallest_y_min != self._viewer_state.y_min: + self._viewer_state.y_min = smallest_y_min + + def _update_visual_attributes(self, changed, force=False): + if not self.enabled: + return + + with self.view.figure.batch_update(): + self.view.figure.for_each_trace(self._update_visual_attrs_for_trace, dict(meta=self._dots_id)) + + def _update_visual_attrs_for_trace(self, trace): + marker = trace.marker + marker.update(opacity=self.state.alpha, color=fixed_color(self.state), size=dot_radius(self.view, self.state)) + print(marker) + trace.update(marker=marker, + visible=self.state.visible, + unselected=dict(marker=dict(opacity=self.state.alpha))) + + def _update_data(self): + old_dots = self._get_dots() + if old_dots: + self.view._remove_traces(old_dots) + + dots = traces_for_layer(self.view, self.state, add_data_label=True) + self._dots_id = dots[0].meta if dots else None + self.view.figure.add_traces(dots) + + def _update_zorder(self): + traces = [self.view.selection_layer] + for layer in self.view.layers: + traces += list(layer.traces()) + self.view.figure.data = traces + + def _update_dotplot(self, force=False, **kwargs): + if (self._viewer_state.hist_x_min is None or + self._viewer_state.hist_x_max is None or + self._viewer_state.hist_n_bin is None or + self._viewer_state.x_att is None or + self.state.layer is None): + return + + changed = self.pop_changed_properties() + + if force or len(changed & HISTOGRAM_PROPERTIES) > 0: + self._calculate_histogram() + force = True + + if force or len(changed & DATA_PROPERTIES) > 0: + self._update_data() + force = True + + if force or len(changed & SCALE_PROPERTIES) > 0: + self._scale_histogram() + + if force or len(changed & VISUAL_PROPERTIES) > 0: + self._update_visual_attributes(changed, force=force) + + if force or "zorder" in changed: + self._update_zorder() + + def update(self): + self.state.reset_cache() + self._update_dotplot(force=True)