Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update masks for scatter layers #109

Merged
merged 3 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions glue_plotly/common/base_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ def bbox_mask(viewer_state, x, y, z):
(z >= viewer_state.z_min) & (z <= viewer_state.z_max)


def clipped_data(viewer_state, layer_state):
x = layer_state.layer[viewer_state.x_att]
y = layer_state.layer[viewer_state.y_att]
z = layer_state.layer[viewer_state.z_att]
def clipped_data(viewer_state, x, y, z):

# Plotly doesn't show anything outside the bounding box
mask = bbox_mask(viewer_state, x, y, z)
Expand Down
3 changes: 2 additions & 1 deletion glue_plotly/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def rgb_colors(layer_state, mask, cmap_att):
return rgba_strs


def color_info(layer_state, mask=None,
def color_info(layer_state,
mask=None,
mode_att="cmap_mode",
cmap_att="cmap_att"):
if getattr(layer_state, mode_att, "Fixed") == "Fixed":
Expand Down
11 changes: 10 additions & 1 deletion glue_plotly/common/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,16 @@
def traces_for_scatter_layer(viewer_state, layer_state, hover_data=None, add_data_label=True):
x = layer_state.layer[viewer_state.x_att].copy()
y = layer_state.layer[viewer_state.y_att].copy()
mask, (x, y) = sanitize(x, y)
arrs = [x, y]
if layer_state.cmap_mode == "Linear":
cvals = layer_state.layer[layer_state.cmap_att].copy()
arrs.append(cvals)
if layer_state.size_mode == "Linear":
svals = layer_state.layer[layer_state.size_att].copy()
arrs.append(svals)

Check warning on line 318 in glue_plotly/common/image.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/image.py#L312-L318

Added lines #L312 - L318 were not covered by tests

mask, sanitized = sanitize(*arrs)
x, y = sanitized[:2]

Check warning on line 321 in glue_plotly/common/image.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/image.py#L320-L321

Added lines #L320 - L321 were not covered by tests

marker = dict(color=color_info(layer_state),
opacity=layer_state.alpha,
Expand Down
11 changes: 10 additions & 1 deletion glue_plotly/common/scatter2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,16 @@

x = layer_state.layer[viewer.state.x_att].copy()
y = layer_state.layer[viewer.state.y_att].copy()
mask, (x, y) = sanitize(x, y)
arrs = [x, y]
if layer_state.cmap_mode == "Linear":
cvals = layer_state.layer[layer_state.cmap_att].copy()
arrs.append(cvals)

Check warning on line 271 in glue_plotly/common/scatter2d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter2d.py#L270-L271

Added lines #L270 - L271 were not covered by tests
if layer_state.size_mode == "Linear":
svals = layer_state.layer[layer_state.size_att].copy()
arrs.append(svals)

Check warning on line 274 in glue_plotly/common/scatter2d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter2d.py#L273-L274

Added lines #L273 - L274 were not covered by tests

mask, sanitized = sanitize(*arrs)
x, y = sanitized[:2]

legend_group = uuid4().hex

Expand Down
41 changes: 33 additions & 8 deletions glue_plotly/common/scatter3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@
from plotly.graph_objs import Cone, Scatter3d
from uuid import uuid4

from glue_plotly.common import color_info
from glue_plotly.common.base_3d import clipped_data
from glue_plotly.common import color_info, sanitize
from glue_plotly.common.base_3d import bbox_mask

try:
from glue_vispy_viewers.scatter.layer_state import ScatterLayerState
except ImportError:
ScatterLayerState = type(None)

Check warning on line 15 in glue_plotly/common/scatter3d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter3d.py#L14-L15

Added lines #L14 - L15 were not covered by tests

def size_info(layer_state, mask):

def size_info(layer_state, mask, size_att="size_attribute"):

# set all points to be the same size, with set scaling
if layer_state.size_mode == 'Fixed':
return layer_state.size_scaling * layer_state.size

# scale size of points by set size scaling
else:
s = ensure_numerical(layer_state.layer[layer_state.size_attribute][mask].ravel())
s = ensure_numerical(layer_state.layer[getattr(layer_state, size_att)][mask].ravel())

Check warning on line 26 in glue_plotly/common/scatter3d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter3d.py#L26

Added line #L26 was not covered by tests
s = ((s - layer_state.size_vmin) /
(layer_state.size_vmax - layer_state.size_vmin))
# The following ensures that the sizes are in the
Expand Down Expand Up @@ -110,11 +115,31 @@

def traces_for_layer(viewer_state, layer_state, hover_data=None, add_data_label=True):

x, y, z, mask = clipped_data(viewer_state, layer_state)
x = layer_state.layer[viewer_state.x_att]
y = layer_state.layer[viewer_state.y_att]
z = layer_state.layer[viewer_state.z_att]

vispy_layer_state = isinstance(layer_state, ScatterLayerState)
cmap_mode_attr = "color_mode" if vispy_layer_state else "cmap_mode"
cmap_attr = "cmap_attribute" if vispy_layer_state else "cmap_att"
size_attr = "size_attribute" if vispy_layer_state else "size_att"
arrs = [x, y, z]
if getattr(layer_state, cmap_mode_attr) == "Linear":
cvals = layer_state.layer[getattr(layer_state, cmap_attr)].copy()
arrs.append(cvals)

Check warning on line 129 in glue_plotly/common/scatter3d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter3d.py#L128-L129

Added lines #L128 - L129 were not covered by tests
if layer_state.size_mode == "Linear":
svals = layer_state.layer[getattr(layer_state, size_attr)].copy()
arrs.append(svals)

Check warning on line 132 in glue_plotly/common/scatter3d.py

View check run for this annotation

Codecov / codecov/patch

glue_plotly/common/scatter3d.py#L131-L132

Added lines #L131 - L132 were not covered by tests

mask, _ = sanitize(*arrs)
bounds_mask = bbox_mask(viewer_state, x, y, z)
mask &= bounds_mask
x, y, z = x[mask], y[mask], z[mask]

marker = dict(color=color_info(layer_state, mask=mask,
mode_att="color_mode",
cmap_att="cmap_attribute"),
size=size_info(layer_state, mask),
mode_att=cmap_mode_attr,
cmap_att=cmap_attr),
size=size_info(layer_state, mask, size_att=size_attr),
opacity=layer_state.alpha,
line=dict(width=0))

Expand Down
Loading