Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 212 additions & 20 deletions slide2vec/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence

import numpy as np

Expand Down Expand Up @@ -173,15 +173,24 @@ def embed_slides(
tiling_results=tiling_results,
)
emit_progress("embedding.started", slide_count=len(prepared_slides))
local_persist_callback = None
if execution.output_dir is not None and execution.num_gpus <= 1:
local_persist_callback, _, _ = _build_incremental_persist_callback(
model=model,
preprocessing=preprocessing,
execution=execution,
process_list_path=process_list_path,
)
embedded_slides = _select_embedding_path(
model=model,
slide_records=prepared_slides,
tiling_results=tiling_results,
preprocessing=preprocessing,
execution=execution,
work_dir=work_dir,
on_embedded_slide=local_persist_callback,
)
if execution.output_dir is not None:
if execution.output_dir is not None and execution.num_gpus > 1:
for embedded_slide, tiling_result in zip(embedded_slides, tiling_results):
_persist_embedded_slide(
model,
Expand Down Expand Up @@ -216,6 +225,7 @@ def _select_embedding_path(
preprocessing: PreprocessingConfig,
execution: ExecutionOptions,
work_dir: Path,
on_embedded_slide: Callable[[SlideSpec, Any, EmbeddedSlide], None] | None = None,
):
if execution.num_gpus > 1:
if len(slide_records) == 1:
Expand Down Expand Up @@ -243,6 +253,7 @@ def _select_embedding_path(
tiling_results,
preprocessing=preprocessing,
execution=execution,
on_embedded_slide=on_embedded_slide,
)


Expand Down Expand Up @@ -447,24 +458,54 @@ def run_pipeline(
process_list_path=process_list_path,
)

embedded_slides = _compute_embedded_slides(
model,
persist_tile_embeddings = _should_persist_tile_embeddings(model, execution)
include_slide_embeddings = model.level == "slide"
pending_slides, pending_tiling_results = _pending_local_embedding_records(
successful_slides,
tiling_results,
preprocessing=preprocessing,
execution=execution,
process_list_path=process_list_path,
output_dir=output_dir,
output_format=execution.output_format,
persist_tile_embeddings=persist_tile_embeddings,
include_slide_embeddings=include_slide_embeddings,
save_latents=execution.save_latents,
resume=preprocessing.resume,
)
tile_artifacts, slide_artifacts = _collect_local_pipeline_artifacts(
local_persist_callback, _, _ = _build_incremental_persist_callback(
model=model,
embedded_slides=embedded_slides,
tiling_results=tiling_results,
preprocessing=preprocessing,
execution=execution,
process_list_path=process_list_path,
)
embedded_slides: list[EmbeddedSlide] = []
if pending_slides:
embedded_slides = _compute_embedded_slides(
model,
pending_slides,
pending_tiling_results,
preprocessing=preprocessing,
execution=execution,
on_embedded_slide=local_persist_callback,
)
tile_artifacts, slide_artifacts = _collect_pipeline_artifacts(
successful_slides,
output_dir=output_dir,
output_format=execution.output_format,
include_tile_embeddings=persist_tile_embeddings,
include_slide_embeddings=include_slide_embeddings,
)
_update_process_list_after_embedding(
process_list_path,
successful_slides=successful_slides,
persist_tile_embeddings=persist_tile_embeddings,
include_slide_embeddings=include_slide_embeddings,
tile_artifacts=tile_artifacts,
slide_artifacts=slide_artifacts,
)
emit_progress(
"embedding.finished",
slide_count=len(successful_slides),
slides_completed=len(embedded_slides),
slides_completed=len(successful_slides),
tile_artifacts=len(tile_artifacts),
slide_artifacts=len(slide_artifacts),
)
Expand Down Expand Up @@ -508,6 +549,146 @@ def _collect_local_pipeline_artifacts(
return tile_artifacts, slide_artifacts


def _build_incremental_persist_callback(
    *,
    model,
    preprocessing: PreprocessingConfig,
    execution: ExecutionOptions,
    process_list_path: Path | None = None,
) -> tuple[
    Callable[[SlideSpec, Any, EmbeddedSlide], None] | None,
    list[TileEmbeddingArtifact],
    list[SlideEmbeddingArtifact],
]:
    """Create a per-slide persistence callback plus its artifact accumulators.

    Returns ``(callback, tile_artifacts, slide_artifacts)``. The callback
    persists one embedded slide as soon as it is produced and appends the
    resulting artifacts to the two returned lists, which the caller can
    inspect afterwards. When ``execution.output_dir`` is ``None`` there is
    nothing to persist, so the callback is ``None`` and the lists stay empty.
    """
    collected_tiles: list[TileEmbeddingArtifact] = []
    collected_slides: list[SlideEmbeddingArtifact] = []
    if execution.output_dir is None:
        # No output directory configured: nothing to write to disk.
        return None, collected_tiles, collected_slides

    write_tile_embeddings = _should_persist_tile_embeddings(model, execution)
    write_slide_embeddings = model.level == "slide"

    def _persist_completed_slide(slide: SlideSpec, tiling_result, embedded_slide: EmbeddedSlide) -> None:
        # Persist immediately so partial progress survives an interruption,
        # then record whatever artifacts were produced.
        tile_artifact, slide_artifact = _persist_embedded_slide(
            model,
            embedded_slide,
            tiling_result,
            preprocessing=preprocessing,
            execution=execution,
        )
        if tile_artifact is not None:
            collected_tiles.append(tile_artifact)
        if slide_artifact is not None:
            collected_slides.append(slide_artifact)
        if process_list_path is not None and process_list_path.is_file():
            # Mark just this slide as embedded in the on-disk process list.
            _update_process_list_after_embedding(
                process_list_path,
                successful_slides=[slide],
                persist_tile_embeddings=write_tile_embeddings,
                include_slide_embeddings=write_slide_embeddings,
                tile_artifacts=[] if tile_artifact is None else [tile_artifact],
                slide_artifacts=[] if slide_artifact is None else [slide_artifact],
            )

    return _persist_completed_slide, collected_tiles, collected_slides


def _pending_local_embedding_records(
    successful_slides: Sequence[SlideSpec],
    tiling_results,
    *,
    process_list_path: Path,
    output_dir: Path,
    output_format: str,
    persist_tile_embeddings: bool,
    include_slide_embeddings: bool,
    save_latents: bool,
    resume: bool,
) -> tuple[list[SlideSpec], list[Any]]:
    """Filter out slides whose local embedding outputs already exist.

    With ``resume`` disabled, every successful slide is (re)processed.
    Otherwise, slides recorded as fully complete (per the process list and
    on-disk artifacts) are dropped. The two returned lists stay aligned
    pairwise with each other.
    """
    if not resume:
        return list(successful_slides), list(tiling_results)

    done = _completed_local_embedding_sample_ids(
        process_list_path,
        output_dir=output_dir,
        output_format=output_format,
        persist_tile_embeddings=persist_tile_embeddings,
        include_slide_embeddings=include_slide_embeddings,
        save_latents=save_latents,
    )
    remaining = [
        (slide, tiling_result)
        for slide, tiling_result in zip(successful_slides, tiling_results)
        if slide.sample_id not in done
    ]
    pending_slides: list[SlideSpec] = [slide for slide, _ in remaining]
    pending_tiling_results: list[Any] = [tiling_result for _, tiling_result in remaining]
    return pending_slides, pending_tiling_results


def _completed_local_embedding_sample_ids(
    process_list_path: Path,
    *,
    output_dir: Path,
    output_format: str,
    persist_tile_embeddings: bool,
    include_slide_embeddings: bool,
    save_latents: bool,
) -> set[str]:
    """Return the sample ids whose embedding work is already fully done.

    A sample counts as complete only when every relevant status column in the
    process list reads "success" *and* all expected output files are present
    on disk — status flags alone are not trusted.
    """
    process_df = _load_process_df(
        process_list_path,
        include_feature_status=persist_tile_embeddings or include_slide_embeddings,
        include_aggregation_status=include_slide_embeddings,
    )

    def _row_is_complete(row: dict) -> bool:
        # Tiling must have succeeded unconditionally; the feature and
        # aggregation statuses only matter when those stages apply.
        if row.get("tiling_status") != "success":
            return False
        if persist_tile_embeddings and row.get("feature_status") != "success":
            return False
        if include_slide_embeddings and row.get("aggregation_status") != "success":
            return False
        return _has_complete_local_embedding_outputs(
            str(row["sample_id"]),
            output_dir=output_dir,
            output_format=output_format,
            persist_tile_embeddings=persist_tile_embeddings,
            include_slide_embeddings=include_slide_embeddings,
            save_latents=save_latents,
        )

    return {
        str(row["sample_id"])
        for row in process_df.to_dict("records")
        if _row_is_complete(row)
    }


def _has_complete_local_embedding_outputs(
sample_id: str,
*,
output_dir: Path,
output_format: str,
persist_tile_embeddings: bool,
include_slide_embeddings: bool,
save_latents: bool,
) -> bool:
if persist_tile_embeddings:
tile_artifact_path = output_dir / "tile_embeddings" / f"{sample_id}.{output_format}"
tile_metadata_path = output_dir / "tile_embeddings" / f"{sample_id}.meta.json"
if not tile_artifact_path.is_file() or not tile_metadata_path.is_file():
return False
if include_slide_embeddings:
slide_artifact_path = output_dir / "slide_embeddings" / f"{sample_id}.{output_format}"
slide_metadata_path = output_dir / "slide_embeddings" / f"{sample_id}.meta.json"
if not slide_artifact_path.is_file() or not slide_metadata_path.is_file():
return False
if save_latents:
latent_suffix = "pt" if output_format == "pt" else "npz"
latent_path = output_dir / "slide_latents" / f"{sample_id}.{latent_suffix}"
if not latent_path.is_file():
return False
return True


def _collect_distributed_pipeline_artifacts(
*,
model,
Expand Down Expand Up @@ -551,6 +732,7 @@ def _compute_embedded_slides(
*,
preprocessing: PreprocessingConfig,
execution: ExecutionOptions,
on_embedded_slide: Callable[[SlideSpec, Any, EmbeddedSlide], None] | None = None,
) -> list[EmbeddedSlide]:
loaded = model._load_backend()
embedded_slides: list[EmbeddedSlide] = []
Expand Down Expand Up @@ -589,15 +771,16 @@ def _compute_embedded_slides(
sample_id=slide.sample_id,
has_latents=latents is not None,
)
embedded_slides.append(
_make_embedded_slide(
slide=slide,
tiling_result=tiling_result,
tile_embeddings=tile_embeddings,
slide_embedding=slide_embedding,
latents=latents,
)
embedded_slide = _make_embedded_slide(
slide=slide,
tiling_result=tiling_result,
tile_embeddings=tile_embeddings,
slide_embedding=slide_embedding,
latents=latents,
)
embedded_slides.append(embedded_slide)
if on_embedded_slide is not None:
on_embedded_slide(slide, tiling_result, embedded_slide)
emit_progress(
"embedding.slide.finished",
sample_id=slide.sample_id,
Expand Down Expand Up @@ -1140,10 +1323,19 @@ def _build_hs2p_configs(preprocessing: PreprocessingConfig):
)


def _load_process_df(process_list_path: Path):
def _load_process_df(
process_list_path: Path,
*,
include_feature_status: bool = False,
include_aggregation_status: bool = False,
):
from slide2vec.utils.tiling_io import load_process_df

return load_process_df(process_list_path)
return load_process_df(
process_list_path,
include_feature_status=include_feature_status,
include_aggregation_status=include_aggregation_status,
)


def _load_tiling_result_from_row(row):
Expand Down
10 changes: 8 additions & 2 deletions tests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,15 @@ def test_run_pipeline_emits_local_progress_events_in_order(monkeypatch, tmp_path
)
monkeypatch.setattr(
inference,
"_collect_local_pipeline_artifacts",
lambda **kwargs: (["tile-artifact"], ["slide-artifact"]),
"_build_incremental_persist_callback",
lambda **kwargs: (None, [], []),
)
monkeypatch.setattr(
inference,
"_collect_pipeline_artifacts",
lambda *args, **kwargs: (["tile-artifact"], ["slide-artifact"]),
)
monkeypatch.setattr(inference, "_update_process_list_after_embedding", lambda *args, **kwargs: None)

model = SimpleNamespace(
name="prov-gigapath",
Expand Down
Loading
Loading