Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1132,4 +1132,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@
" heart_array = itk.array_from_image(result[\"heart\"])\n",
" vessels_array = itk.array_from_image(result[\"major_vessels\"])\n",
"\n",
" labelmap_essentials = segmenter.trim_mask_to_essentials(result[\"labelmap\"])\n",
" labelmap_essentials_array = itk.array_from_image(labelmap_essentials)\n",
"\n",
" # Select middle slice\n",
" mid_slice = image_array.shape[0] // 2\n",
"\n",
Expand All @@ -380,7 +383,8 @@
"\n",
" axes[0, 1].imshow(image_array[mid_slice, :, :], cmap=\"gray\", vmin=-200, vmax=400)\n",
" labelmap_overlay = np.ma.masked_where(\n",
" labelmap_array[mid_slice, :, :] == 0, labelmap_array[mid_slice, :, :]\n",
" labelmap_essentials_array[mid_slice, :, :] == 0,\n",
" labelmap_essentials_array[mid_slice, :, :],\n",
" )\n",
" axes[0, 1].imshow(labelmap_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
" axes[0, 1].set_title(\"Labelmap Overlay\")\n",
Expand All @@ -407,7 +411,8 @@
" mid_sagittal = image_array.shape[2] // 2\n",
" axes[1, 1].imshow(image_array[:, :, mid_sagittal], cmap=\"gray\", vmin=-200, vmax=400)\n",
" sagittal_overlay = np.ma.masked_where(\n",
" labelmap_array[:, :, mid_sagittal] == 0, labelmap_array[:, :, mid_sagittal]\n",
" labelmap_essentials_array[:, :, mid_sagittal] == 0,\n",
" labelmap_essentials_array[:, :, mid_sagittal],\n",
" )\n",
" axes[1, 1].imshow(sagittal_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
" axes[1, 1].set_title(\"Sagittal View\")\n",
Expand All @@ -417,7 +422,8 @@
" mid_coronal = image_array.shape[1] // 2\n",
" axes[1, 2].imshow(image_array[:, mid_coronal, :], cmap=\"gray\", vmin=-200, vmax=400)\n",
" coronal_overlay = np.ma.masked_where(\n",
" labelmap_array[:, mid_coronal, :] == 0, labelmap_array[:, mid_coronal, :]\n",
" labelmap_essentials_array[:, mid_coronal, :] == 0,\n",
" labelmap_essentials_array[:, mid_coronal, :],\n",
" )\n",
" axes[1, 2].imshow(coronal_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
" axes[1, 2].set_title(\"Coronal View\")\n",
Expand Down Expand Up @@ -457,6 +463,7 @@
"\n",
"# Convert heart mask to VTK\n",
"heart_vtk = itk.vtk_image_from_image(result[\"heart\"])\n",
"heart_essentials_vtk = itk.vtk_image_from_image(labelmap_essentials)\n",
"vessels_vtk = itk.vtk_image_from_image(result[\"major_vessels\"])\n",
"\n",
"# Create PyVista plotter\n",
Expand All @@ -466,7 +473,15 @@
"heart_grid = pv.wrap(heart_vtk)\n",
"heart_surface = heart_grid.contour([0.5])\n",
"if heart_surface.n_points > 0:\n",
" plotter.add_mesh(heart_surface, color=\"red\", opacity=1.0, label=\"Heart\")\n",
" plotter.add_mesh(heart_surface, color=\"red\", opacity=0.5, label=\"Heart\")\n",
"\n",
"# Extract heart surface\n",
"heart_essentials_grid = pv.wrap(heart_essentials_vtk)\n",
"heart_essentials_surface = heart_essentials_grid.contour([0.5])\n",
"if heart_essentials_surface.n_points > 0:\n",
" plotter.add_mesh(\n",
" heart_essentials_surface, color=\"grey\", opacity=1.0, label=\"Heart Essential\"\n",
" )\n",
"\n",
"# Extract vessels surface\n",
"vessels_grid = pv.wrap(vessels_vtk)\n",
Expand All @@ -486,6 +501,14 @@
"print(f\"3D visualization saved to: {screenshot_path}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6046393",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "o3p4q5r6",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ dependencies = [
# AI/ML and segmentation
"monai>=1.3.0",
"torch>=2.0.0,<3.0.0",
"transformers>=4.21.0",
"transformers>=4.21.0,<5.0.0",
"totalsegmentator>=2.0.0",

# Registration
Expand Down
4 changes: 2 additions & 2 deletions src/physiomotion4d/cli/fit_statistical_model_to_patient.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def main() -> int:
pca_model=pca_model,
pca_number_of_modes=args.pca_number_of_modes,
)
if args.use_mask_to_mask:
workflow.set_use_mask_to_mask_registration(args.use_mask_to_mask)
if args.use_mask_to_image:
workflow.set_use_mask_to_image_registration(
True,
Expand All @@ -252,8 +254,6 @@ def main() -> int:
print("\nStarting registration pipeline...")
print("=" * 70)
result = workflow.run_workflow(
use_mask_to_mask_registration=args.use_mask_to_mask,
use_mask_to_image_registration=args.use_mask_to_image,
use_icon_registration_refinement=args.use_icon_refinement,
)

Expand Down
7 changes: 5 additions & 2 deletions src/physiomotion4d/cli/visualize_pca_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,11 @@ def main() -> int:
traceback.print_exc()
return 1

if not isinstance(mean_mesh, pv.PolyData):
print("Error: PCA mean surface must be a PolyData (.vtp).")
if not isinstance(mean_mesh, (pv.PolyData, pv.UnstructuredGrid)):
print(
"Error: PCA mean surface must be PolyData or UnstructuredGrid.",
f"Type: {type(mean_mesh)}",
)
return 1

try:
Expand Down
7 changes: 5 additions & 2 deletions src/physiomotion4d/contour_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ def create_reference_image(

def create_mask_from_mesh(
self,
mesh: pv.DataSet,
mesh: pv.DataSet | pv.UnstructuredGrid,
reference_image: itk.Image,
) -> itk.Image:
ref_spacing = np.array(reference_image.GetSpacing())

# Create trimesh object with LPS coordinates
if isinstance(mesh, pv.UnstructuredGrid):
mesh = mesh.extract_surface()

if hasattr(mesh, "n_faces_strict"):
# PyVista PolyData
faces = mesh.faces.reshape((mesh.n_faces_strict, 4))[:, 1:]
Expand Down Expand Up @@ -248,7 +251,7 @@ def create_mask_from_mesh(

def create_distance_map(
self,
mesh: pv.DataSet,
mesh: pv.DataSet | pv.UnstructuredGrid,
reference_image: itk.Image,
squared_distance: bool = False,
negative_inside: bool = True,
Expand Down
162 changes: 145 additions & 17 deletions src/physiomotion4d/segment_heart_simpleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, log_level: int | str = logging.INFO):
# From Base Class
# self.contrast_mask_ids = {135: "contrast"}

self.trim_mesh_to_essentials = False
self._trim_mask = False

self.set_other_and_all_mask_ids()

Expand All @@ -112,13 +112,13 @@ def __init__(self, log_level: int | str = logging.INFO):
"SimplewareScript_heart_segmentation.py",
)

def set_trim_mesh_to_essentials(self, trim_mesh_to_essentials: bool) -> None:
"""Set whether to trim mesh to common and critical structures.
def set_trim_mask_to_essentials(self, trim_mask: bool) -> None:
"""Set whether to trim mask to common and critical structures.

Args:
trim_mesh_to_essentials (bool): Whether to reduce to essential.
trim_mask (bool): Whether to reduce to essential.
"""
self.trim_mesh_to_essentials = trim_mesh_to_essentials
self._trim_mask = trim_mask

def set_simpleware_executable_path(self, path: str) -> None:
"""Set the path to the Simpleware Medical console executable.
Expand Down Expand Up @@ -283,8 +283,9 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
interior_image = itk.GetImageFromArray(interior_array.astype(np.uint8))
interior_image.CopyInformation(preprocessed_image)
imMath = tube.ImageMath.New(interior_image)
imMath.Dilate(7, 1, 0)
imMath.Erode(4, 1, 0)
spacing = interior_image.GetSpacing()
imMath.Dilate(round(7 / spacing[0]), 1, 0)
imMath.Erode(round(4 / spacing[0]), 1, 0)
exterior_image = imMath.GetOutputUChar()
exterior_array = itk.GetArrayFromImage(exterior_image)
mask_id = 6 # Heart mask id
Expand All @@ -300,17 +301,144 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
"ensure the ASCardio module ran successfully."
)

if self.trim_mesh_to_essentials:
z = labelmap_array.shape[2] - 1
z_classes = np.unique(labelmap_array[z, :, :])
heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
while heart_count < 3 and z > 0:
z -= 1
z_classes = np.unique(labelmap_array[z, :, :])
heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
if z < labelmap_array.shape[2] - 3:
labelmap_array[(z + 3) :, :, :] = 0
labelmap_image = itk.GetImageFromArray(labelmap_array.astype(np.uint8))
labelmap_image.CopyInformation(preprocessed_image)

if self._trim_mask:
labelmap_image = self.trim_mask_to_essentials(labelmap_image)

return labelmap_image

def trim_mask_to_essentials(self, labelmap_image: itk.image) -> itk.image:
"""Trim mask to essentials."""

# Reference code for cropping aorta and pulmonary artery to
# portions adjacent to the heart.
# Trim z-axis
# z = labelmap_array.shape[2] - 1
# z_classes = np.unique(labelmap_array[z, :, :])
# heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
# while heart_count < 3 and z > 0:
# z -= 1
# z_classes = np.unique(labelmap_array[z, :, :])
# heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
# if z < labelmap_array.shape[2] - 3:
# labelmap_array[(z + 3) :, :, :] = 0

# In labelmap,
# if pixel is in keep_mask, was left or right atrium, then keep as
# left or right atrium

# 1) Erase Heart and Myo label
labelmap_arr = itk.array_from_image(labelmap_image)

heart_arr = itk.array_from_image(labelmap_image)
heart_arr[heart_arr == 6] = 0
heart_arr[heart_arr == 5] = 0

img = itk.image_from_array(heart_arr)
img.CopyInformation(labelmap_image)
imMath = tube.ImageMath.New(img)

# 2) Erode then Dilate Left Atrium label to clip vessels
spacing = labelmap_image.GetSpacing()
imMath.Erode(round(7 / spacing[0]), 3, 0)
imMath.Dilate(round(7 / spacing[0]), 3, 0)

# 3) Erode then Dilate Right Atrium label to clip vessels
imMath.Erode(round(7 / spacing[0]), 4, 0)
imMath.Dilate(round(7 / spacing[0]), 4, 0)
simple_img = imMath.GetOutput()
simple_arr = itk.array_from_image(simple_img)

# Keep the largest component of the left atrium
simple_arr_3 = simple_arr.copy()
simple_arr_3[simple_arr_3 != 3] = 0
simple_arr_3[simple_arr_3 == 3] = 1
simple_img_3 = itk.image_from_array(simple_arr_3)
connComp = tube.SegmentConnectedComponents.New(simple_img_3)
connComp.SetKeepOnlyLargestComponent(True)
connComp.Update()
mask_img_3 = connComp.GetOutput()
mask_arr_3 = itk.array_from_image(mask_img_3)
simple_arr_3[mask_arr_3 == 0] = 0

# Keep the largest component of the right atrium
simple_arr_4 = simple_arr.copy()
simple_arr_4[simple_arr_4 != 4] = 0
simple_arr_4[simple_arr_4 == 4] = 1
simple_img_4 = itk.image_from_array(simple_arr_4)
connComp = tube.SegmentConnectedComponents.New(simple_img_4)
connComp.SetKeepOnlyLargestComponent(True)
connComp.Update()
mask_img_4 = connComp.GetOutput()
mask_arr_4 = itk.array_from_image(mask_img_4)
simple_arr_4[mask_arr_4 == 0] = 0

# Replace the left and right atrium labels with the largest components
simple_arr[simple_arr == 3] = 0
simple_arr[simple_arr == 4] = 0
simple_arr[simple_arr_3 > 0] = 3
simple_arr[simple_arr_4 > 0] = 4
simple_img = itk.image_from_array(simple_arr)
simple_img.CopyInformation(labelmap_image)

# 4) Dilate all others = keep_mask
keep_mask_arr = heart_arr.copy()
keep_mask_arr[keep_mask_arr == 2] = 1
keep_mask_arr[keep_mask_arr == 5] = 1
keep_mask_arr[keep_mask_arr != 1] = 0
keep_mask = itk.image_from_array(keep_mask_arr)
keep_mask.CopyInformation(labelmap_image)
imMath.SetInput(keep_mask)
imMath.Dilate(round(7 / spacing[0]), 1, 0)
keep_mask = imMath.GetOutput()
keep_mask_arr = itk.array_from_image(keep_mask)

# Add the left and right atrium labels to the keep_mask
heart_arr = heart_arr * keep_mask_arr
heart_arr[simple_arr == 3] = 3
heart_arr[simple_arr == 4] = 4
heart_img = itk.image_from_array(heart_arr)
heart_img.CopyInformation(labelmap_image)

# Dilate the keep_mask to simulate 3mm (heart)
keep_mask_arr = heart_arr.copy()
keep_mask_arr[keep_mask_arr == 1] = 0
keep_mask_arr[keep_mask_arr > 0] = 1
keep_mask = itk.image_from_array(keep_mask_arr)
keep_mask.CopyInformation(labelmap_image)
imMath.SetInput(keep_mask)
imMath.Dilate(round(5 / spacing[0]), 1, 0)
imMath.Erode(round(2 / spacing[0]), 1, 0)
heart_mask = imMath.GetOutput()

# Insert the heart and myo labels back into the labelmap
heart_mask_arr = itk.array_from_image(heart_mask)
heart_mask_arr[heart_arr > 0] = 0
heart_arr[heart_mask_arr > 0] = 6
heart_arr_myo = itk.array_from_image(labelmap_image)
heart_arr[heart_arr_myo == 5] = 5
heart_arr[heart_arr_myo == 1] = 1
heart_img = itk.image_from_array(heart_arr)
heart_img.CopyInformation(labelmap_image)

# Add in missing pieces / gaps of the myocardium
lv_arr = heart_arr.copy()
lv_arr[lv_arr != 1] = 0
lv_img = itk.image_from_array(lv_arr)
lv_img.CopyInformation(labelmap_image)
imMath.SetInput(lv_img)
imMath.Dilate(round(2 / spacing[0]), 1, 0)
lv_img = imMath.GetOutput()
lv_arr = itk.array_from_image(lv_img)
lv_arr = lv_arr * 5 # Myocardium label is 5

# Add the gap-filled myocardium back into the labelmap
heart_arr = np.where(heart_arr == 0, lv_arr, heart_arr)
# Eliminate overlap with other labels
heart_arr = np.where(labelmap_arr > 6, 0, heart_arr)
heart_img = itk.image_from_array(heart_arr)
heart_img.CopyInformation(labelmap_image)

return heart_img
Loading
Loading