Skip to content

Render shapes fails for categoricals of different length across coordinate systems #425

Open
@Marius1311

Description

@Marius1311

Hi there, I'm trying to show a leiden clustering in space for an sdata object with many coordinate systems. I'm calling

    (
        sdata
        .pl.render_shapes(fill_alpha=1, method="matplotlib", color="leiden")
        .pl.show(colorbar=False)
    )

which should show segmentation masks, colored by the Leiden clustering. This works for individual samples (when using .filter_by_coordinate_system()), and it also works when all clusters are present in every coordinate system. It fails when the number of categories differs across coordinate systems, with:

File [/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py:937](https://jupyter.euler.hpc.ethz.ch/user/mlange/lab/workspaces/auto-o/tree/cluster/home/mlange/mlange/github/gli3_merscope_analysis/analysis/preprocessing/post_segmentation/v4/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py#line=936), in PlotAccessor.show(self, coordinate_systems, legend_fontsize, legend_fontweight, legend_loc, legend_fontoutline, na_in_legend, colorbar, wspace, hspace, ncols, frameon, figsize, dpi, fig, title, share_extent, pad_extent, ax, return_ax, save)
    932     wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements(
    933         sdata, wanted_elements, params_copy, cs, "shapes"
    934     )
    936     if wanted_shapes_on_this_cs:
--> 937         _render_shapes(
    938             sdata=sdata,
    939             render_params=params_copy,
    940             coordinate_system=cs,
    941             ax=ax,
    942             fig_params=fig_params,
    943             scalebar_params=scalebar_params,
    944             legend_params=legend_params,
    945         )
    947 elif cmd == "render_points" and has_points:
    948     wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
    949         sdata, wanted_elements, params_copy, cs, "points"
    950     )

File [/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py:109](https://jupyter.euler.hpc.ethz.ch/user/mlange/lab/workspaces/auto-o/tree/cluster/home/mlange/mlange/github/gli3_merscope_analysis/analysis/preprocessing/post_segmentation/v4/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py#line=108), in _render_shapes(sdata, render_params, coordinate_system, ax, fig_params, scalebar_params, legend_params)
    106     sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
    108 # get color vector (categorical or continuous)
--> 109 color_source_vector, color_vector, _ = _set_color_source_vec(
    110     sdata=sdata_filt,
    111     element=sdata_filt[element],
    112     element_name=element,
    113     value_to_plot=col_for_color,
    114     groups=groups,
    115     palette=render_params.palette,
    116     na_color=render_params.color or render_params.cmap_params.na_color,
    117     cmap_params=render_params.cmap_params,
    118     table_name=table_name,
    119     table_layer=table_layer,
    120 )
    122 values_are_categorical = color_source_vector is not None
    124 # color_source_vector is None when the values aren't categorical

File [/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py:757](https://jupyter.euler.hpc.ethz.ch/user/mlange/lab/workspaces/auto-o/tree/cluster/home/mlange/mlange/github/gli3_merscope_analysis/analysis/preprocessing/post_segmentation/v4/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py#line=756), in _set_color_source_vec(sdata, element, value_to_plot, na_color, element_name, groups, palette, cmap_params, table_name, table_layer)
    753     return None, color_source_vector, False
    755 color_source_vector = pd.Categorical(color_source_vector)  # convert, e.g., `pd.Series`
--> 757 color_mapping = _get_categorical_color_mapping(
    758     adata=sdata.table,
    759     cluster_key=value_to_plot,
    760     color_source_vector=color_source_vector,
    761     groups=groups,
    762     palette=palette,
    763     na_color=na_color,
    764 )
    766 color_source_vector = color_source_vector.set_categories(color_mapping.keys())
    767 if color_mapping is None:

File [/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py:934](https://jupyter.euler.hpc.ethz.ch/user/mlange/lab/workspaces/auto-o/tree/cluster/home/mlange/mlange/github/gli3_merscope_analysis/analysis/preprocessing/post_segmentation/v4/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py#line=933), in _get_categorical_color_mapping(adata, na_color, cluster_key, color_source_vector, groups, palette)
    932     base_mapping = _get_default_categorial_color_mapping(color_source_vector)
    933 else:
--> 934     base_mapping = _generate_base_categorial_color_mapping(adata, cluster_key, color_source_vector, na_color)
    936 return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette)

File [/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py:867](https://jupyter.euler.hpc.ethz.ch/user/mlange/lab/workspaces/auto-o/tree/cluster/home/mlange/mlange/github/gli3_merscope_analysis/analysis/preprocessing/post_segmentation/v4/cluster/project/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/utils.py#line=866), in _generate_base_categorial_color_mapping(adata, cluster_key, color_source_vector, na_color)
    864     na_color = to_hex(to_rgba(na_color)[:3])
    866     if na_color and len(categories) > len(colors):
--> 867         return dict(zip(categories, colors + [na_color], strict=True))
    869     return dict(zip(categories, colors, strict=True))
    871 return _get_default_categorial_color_mapping(color_source_vector)

ValueError: zip() argument 2 is shorter than argument 1

What seems to be happening is that colors = adata.uns[f"{cluster_key}_colors"] is correctly subsetted, but categories = color_source_vector.categories.tolist() + ["NaN"] is not, so it still contains the unused categories. As a result, the two differ in length and zip(..., strict=True) raises the ValueError above. It's easy to fix this by removing unused categories from the color_source_vector, i.e., changing that one line to:

categories = color_source_vector.remove_unused_categories().categories.tolist() + ["NaN"]

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions