@@ -338,7 +338,7 @@ def _render_shapes(
338
338
cax = None
339
339
if aggregate_with_reduction is not None :
340
340
vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
341
- vmax = aggregate_with_reduction [1 ].values if norm .vmin is None else norm .vmax
341
+ vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
342
342
if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
343
343
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
344
344
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
@@ -846,20 +846,22 @@ def _render_images(
846
846
# 2) Image has any number of channels but 1
847
847
else :
848
848
layers = {}
849
- for ch_index , c in enumerate (channels ):
850
- layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
851
-
852
- if not isinstance (render_params .cmap_params , list ):
853
- if render_params .cmap_params .norm is not None :
854
- layers [c ] = render_params .cmap_params .norm (layers [c ])
849
+ for ch_idx , ch in enumerate (channels ):
850
+ layers [ch ] = img .sel (c = ch ).copy (deep = True ).squeeze ()
851
+ if isinstance (render_params .cmap_params , list ):
852
+ ch_norm = render_params .cmap_params [ch_idx ].norm
853
+ ch_cmap_is_default = render_params .cmap_params [ch_idx ].cmap_is_default
855
854
else :
856
- if render_params .cmap_params [ch_index ].norm is not None :
857
- layers [c ] = render_params .cmap_params [ch_index ].norm (layers [c ])
855
+ ch_norm = render_params .cmap_params .norm
856
+ ch_cmap_is_default = render_params .cmap_params .cmap_is_default
857
+
858
+ if not ch_cmap_is_default and ch_norm is not None :
859
+ layers [ch_idx ] = ch_norm (layers [ch_idx ])
858
860
859
861
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
860
862
if palette is None and n_channels == 3 and not isinstance (render_params .cmap_params , list ):
861
863
if render_params .cmap_params .cmap_is_default : # -> use RGB
862
- stacked = np .stack ([layers [c ] for c in channels ], axis = - 1 )
864
+ stacked = np .stack ([layers [ch ] for ch in layers ], axis = - 1 )
863
865
else : # -> use given cmap for each channel
864
866
channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
865
867
stacked = (
@@ -892,12 +894,54 @@ def _render_images(
892
894
# overwrite if n_channels == 2 for intuitive result
893
895
if n_channels == 2 :
894
896
seed_colors = ["#ff0000ff" , "#00ff00ff" ]
895
- else :
897
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
898
+ colored = np .stack (
899
+ [channel_cmaps [ch_ind ](layers [ch ]) for ch_ind , ch in enumerate (channels )],
900
+ 0 ,
901
+ ).sum (0 )
902
+ colored = colored [:, :, :3 ]
903
+ elif n_channels == 3 :
896
904
seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
905
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
906
+ colored = np .stack (
907
+ [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
908
+ 0 ,
909
+ ).sum (0 )
910
+ colored = colored [:, :, :3 ]
911
+ else :
912
+ if isinstance (render_params .cmap_params , list ):
913
+ cmap_is_default = render_params .cmap_params [0 ].cmap_is_default
914
+ else :
915
+ cmap_is_default = render_params .cmap_params .cmap_is_default
897
916
898
- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
899
- colored = np .stack ([channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )], 0 ).sum (0 )
900
- colored = colored [:, :, :3 ]
917
+ if cmap_is_default :
918
+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
919
+ else :
920
+ # Sample n_channels colors evenly from the colormap
921
+ if isinstance (render_params .cmap_params , list ):
922
+ seed_colors = [
923
+ render_params .cmap_params [i ].cmap (i / (n_channels - 1 )) for i in range (n_channels )
924
+ ]
925
+ else :
926
+ seed_colors = [render_params .cmap_params .cmap (i / (n_channels - 1 )) for i in range (n_channels )]
927
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
928
+
929
+ # Stack (n_channels, height, width) → (height*width, n_channels)
930
+ H , W = next (iter (layers .values ())).shape
931
+ comp_rgb = np .zeros ((H , W , 3 ), dtype = float )
932
+
933
+ # For each channel: map to RGBA, apply constant alpha, then add
934
+ for ch_idx , ch in enumerate (channels ):
935
+ layer_arr = layers [ch ]
936
+ rgba = channel_cmaps [ch_idx ](layer_arr )
937
+ rgba [..., 3 ] = render_params .alpha
938
+ comp_rgb += rgba [..., :3 ] * rgba [..., 3 ][..., None ]
939
+
940
+ colored = np .clip (comp_rgb , 0 , 1 )
941
+ logger .info (
942
+ f"Your image has { n_channels } channels. Sampling categorical colors and using "
943
+ f"multichannel strategy 'stack' to render."
944
+ ) # TODO: update when pca is added as strategy
901
945
902
946
_ax_show_and_transform (
903
947
colored ,
@@ -943,6 +987,7 @@ def _render_images(
943
987
zorder = render_params .zorder ,
944
988
)
945
989
990
+ # 2D) Image has n channels, no palette but cmap info
946
991
elif palette is not None and got_multiple_cmaps :
947
992
raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
948
993
0 commit comments