Skip to content

Commit b739428

Browse files
author
Sonja Stockhaus
committed
code by Tim (#451), only the stacking approach
1 parent c5cb734 commit b739428

File tree

2 files changed

+95
-26
lines changed

2 files changed

+95
-26
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _render_shapes(
338338
cax = None
339339
if aggregate_with_reduction is not None:
340340
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
341-
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
341+
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
342342
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
343343
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
344344
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
@@ -846,20 +846,22 @@ def _render_images(
846846
# 2) Image has any number of channels but 1
847847
else:
848848
layers = {}
849-
for ch_index, c in enumerate(channels):
850-
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
851-
852-
if not isinstance(render_params.cmap_params, list):
853-
if render_params.cmap_params.norm is not None:
854-
layers[c] = render_params.cmap_params.norm(layers[c])
849+
for ch_idx, ch in enumerate(channels):
850+
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
851+
if isinstance(render_params.cmap_params, list):
852+
ch_norm = render_params.cmap_params[ch_idx].norm
853+
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
855854
else:
856-
if render_params.cmap_params[ch_index].norm is not None:
857-
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
855+
ch_norm = render_params.cmap_params.norm
856+
ch_cmap_is_default = render_params.cmap_params.cmap_is_default
857+
858+
if not ch_cmap_is_default and ch_norm is not None:
859+
layers[ch_idx] = ch_norm(layers[ch_idx])
858860

859861
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
860862
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
861863
if render_params.cmap_params.cmap_is_default: # -> use RGB
862-
stacked = np.stack([layers[c] for c in channels], axis=-1)
864+
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
863865
else: # -> use given cmap for each channel
864866
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
865867
stacked = (
@@ -892,12 +894,54 @@ def _render_images(
892894
# overwrite if n_channels == 2 for intuitive result
893895
if n_channels == 2:
894896
seed_colors = ["#ff0000ff", "#00ff00ff"]
895-
else:
897+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
898+
colored = np.stack(
899+
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
900+
0,
901+
).sum(0)
902+
colored = colored[:, :, :3]
903+
elif n_channels == 3:
896904
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
905+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
906+
colored = np.stack(
907+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
908+
0,
909+
).sum(0)
910+
colored = colored[:, :, :3]
911+
else:
912+
if isinstance(render_params.cmap_params, list):
913+
cmap_is_default = render_params.cmap_params[0].cmap_is_default
914+
else:
915+
cmap_is_default = render_params.cmap_params.cmap_is_default
897916

898-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
899-
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
900-
colored = colored[:, :, :3]
917+
if cmap_is_default:
918+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
919+
else:
920+
# Sample n_channels colors evenly from the colormap
921+
if isinstance(render_params.cmap_params, list):
922+
seed_colors = [
923+
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
924+
]
925+
else:
926+
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
927+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
928+
929+
# Stack (n_channels, height, width) → (height*width, n_channels)
930+
H, W = next(iter(layers.values())).shape
931+
comp_rgb = np.zeros((H, W, 3), dtype=float)
932+
933+
# For each channel: map to RGBA, apply constant alpha, then add
934+
for ch_idx, ch in enumerate(channels):
935+
layer_arr = layers[ch]
936+
rgba = channel_cmaps[ch_idx](layer_arr)
937+
rgba[..., 3] = render_params.alpha
938+
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
939+
940+
colored = np.clip(comp_rgb, 0, 1)
941+
logger.info(
942+
f"Your image has {n_channels} channels. Sampling categorical colors and using "
943+
f"multichannel strategy 'stack' to render."
944+
) # TODO: update when pca is added as strategy
901945

902946
_ax_show_and_transform(
903947
colored,
@@ -943,6 +987,7 @@ def _render_images(
943987
zorder=render_params.zorder,
944988
)
945989

990+
# 2D) Image has n channels, no palette but cmap info
946991
elif palette is not None and got_multiple_cmaps:
947992
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
948993

src/spatialdata_plot/pl/utils.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,7 +2006,7 @@ def _validate_col_for_column_table(
20062006
table_name = next(iter(tables))
20072007
if len(tables) > 1:
20082008
warnings.warn(
2009-
f"Multiple tables contain color column, using {table_name}",
2009+
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
20102010
UserWarning,
20112011
stacklevel=2,
20122012
)
@@ -2042,25 +2042,49 @@ def _validate_image_render_params(
20422042
element_params[el] = {}
20432043
spatial_element = param_dict["sdata"][el]
20442044

2045+
# robustly get channel names from image or multiscale image
20452046
spatial_element_ch = (
2046-
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
2047+
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
20472048
)
2048-
if (channel := param_dict["channel"]) is not None and (
2049-
(isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch))
2050-
or all(ch in spatial_element_ch for ch in channel)
2051-
):
2049+
channel = param_dict["channel"]
2050+
if channel is not None:
2051+
# Normalize channel to always be a list of str or a list of int
2052+
if isinstance(channel, str):
2053+
channel = [channel]
2054+
2055+
if isinstance(channel, int):
2056+
channel = [channel]
2057+
2058+
# If channel is a list, ensure all elements are the same type
2059+
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
2060+
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
2061+
2062+
invalid = [c for c in channel if c not in spatial_element_ch]
2063+
if invalid:
2064+
raise ValueError(
2065+
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
2066+
)
20522067
element_params[el]["channel"] = channel
20532068
else:
20542069
element_params[el]["channel"] = None
20552070

20562071
element_params[el]["alpha"] = param_dict["alpha"]
20572072

2058-
if isinstance(palette := param_dict["palette"], list):
2073+
palette = param_dict["palette"]
2074+
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure
2075+
2076+
if isinstance(palette, list):
2077+
# case A: single palette for all channels
20592078
if len(palette) == 1:
20602079
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
20612080
palette = palette * palette_length
2062-
if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch):
2063-
palette = None
2081+
# case B: one palette per channel (either given or derived from channel length)
2082+
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
2083+
if channels_to_use is not None and len(palette) != len(channels_to_use):
2084+
raise ValueError(
2085+
f"Palette length ({len(palette)}) does not match channel length "
2086+
f"({', '.join(str(c) for c in channels_to_use)})."
2087+
)
20642088
element_params[el]["palette"] = palette
20652089
element_params[el]["na_color"] = param_dict["na_color"]
20662090

@@ -2086,7 +2110,7 @@ def _validate_image_render_params(
20862110
def _get_wanted_render_elements(
20872111
sdata: SpatialData,
20882112
sdata_wanted_elements: list[str],
2089-
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
2113+
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
20902114
cs: str,
20912115
element_type: Literal["images", "labels", "points", "shapes"],
20922116
) -> tuple[list[str], list[str], bool]:
@@ -2243,7 +2267,7 @@ def _create_image_from_datashader_result(
22432267

22442268

22452269
def _datashader_aggregate_with_function(
2246-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2270+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
22472271
cvs: Canvas,
22482272
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
22492273
col_for_color: str | None,
@@ -2307,7 +2331,7 @@ def _datashader_aggregate_with_function(
23072331

23082332

23092333
def _datshader_get_how_kw_for_spread(
2310-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2334+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
23112335
) -> str:
23122336
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
23132337
reduction = reduction or "sum"

0 commit comments

Comments
 (0)