Adding TArrow for testing #220

Open · wants to merge 15 commits into msd_update (showing changes from 3 commits)
320 changes: 316 additions & 4 deletions applications/pseudotime_analysis/pca_analysis.py
@@ -3,6 +3,8 @@
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import seaborn as sns
from viscy.representation.embedding_writer import read_embedding_dataset
@@ -159,6 +161,261 @@ def analyze_pc_distributions(
return pd.DataFrame(results)


def analyze_gmm_clustering(
pca_result,
track_ids,
time_points,
tracks_of_interest,
n_components_range=range(2, 7),
seed_timepoint=None,
time_window=None,
):
"""Analyze clusters using Gaussian Mixture Models."""
# Get points from tracks of interest
track_mask = np.isin(track_ids, tracks_of_interest)
points = pca_result[track_mask]
track_ids_subset = track_ids[track_mask]
times = time_points[track_mask]

# Apply time window if specified
if seed_timepoint is not None and time_window is not None:
time_mask = (times >= seed_timepoint - time_window) & (
times <= seed_timepoint + time_window
)
points = points[time_mask]
track_ids_subset = track_ids_subset[time_mask]
times = times[time_mask]

# Try different numbers of components
bic_scores = []
silhouette_scores = []
models = []

for n_components in n_components_range:
gmm = GaussianMixture(
n_components=n_components, random_state=RANDOM_SEED, n_init=10
)
gmm.fit(points)
labels = gmm.predict(points)

bic_scores.append(gmm.bic(points))
silhouette_scores.append(silhouette_score(points, labels))
models.append(gmm)
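
    # Lower BIC is better (it penalizes model complexity), while a higher
    # silhouette is better (cluster cohesion vs. separation); the two criteria
    # can disagree, and BIC drives the final choice below.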

# Plot model selection metrics
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# BIC plot
ax1.plot(list(n_components_range), bic_scores, "bo-")
ax1.set_xlabel("Number of Components")
ax1.set_ylabel("BIC Score")
ax1.set_title("Model Selection: BIC")

# Silhouette plot
ax2.plot(list(n_components_range), silhouette_scores, "ro-")
ax2.set_xlabel("Number of Components")
ax2.set_ylabel("Silhouette Score")
ax2.set_title("Model Selection: Silhouette")

plt.tight_layout()
plt.show()

# Select best model based on BIC
best_idx = np.argmin(bic_scores)
best_n_components = n_components_range[best_idx]
best_model = models[best_idx]

# Get cluster assignments
labels = best_model.predict(points)
probs = best_model.predict_proba(points)

# Plot clustering results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Scatter plot colored by cluster
scatter = ax1.scatter(
points[:, 0], points[:, 1], c=labels, cmap="tab10", alpha=0.6, s=50
)
ax1.set_xlabel("PC1")
ax1.set_ylabel("PC2")
ax1.set_title(f"GMM Clustering (n={best_n_components})")
plt.colorbar(scatter, ax=ax1, label="Cluster")

# Plot cluster assignment probabilities
max_probs = np.max(probs, axis=1)
scatter = ax2.scatter(
points[:, 0], points[:, 1], c=max_probs, cmap="viridis", alpha=0.6, s=50
)
ax2.set_xlabel("PC1")
ax2.set_ylabel("PC2")
ax2.set_title("Cluster Assignment Probability")
plt.colorbar(scatter, ax=ax2, label="Probability")

plt.tight_layout()
plt.show()

# Analyze cluster composition
cluster_stats = []
for i in range(best_n_components):
cluster_mask = labels == i
cluster_tracks = np.unique(track_ids_subset[cluster_mask])
cluster_stats.append(
{
"cluster": i,
"n_points": np.sum(cluster_mask),
"n_tracks": len(cluster_tracks),
"tracks": cluster_tracks,
"mean_prob": np.mean(probs[cluster_mask, i]),
"std_prob": np.std(probs[cluster_mask, i]),
}
)

# Print cluster statistics
print(f"\nBest number of clusters (BIC): {best_n_components}")
print("\nCluster Statistics:")
for stats in cluster_stats:
print(f"\nCluster {stats['cluster']}:")
print(f" Points: {stats['n_points']}")
print(f" Tracks: {stats['n_tracks']}")
print(f" Mean probability: {stats['mean_prob']:.3f} ± {stats['std_prob']:.3f}")
print(f" Tracks in cluster: {stats['tracks']}")

return {
"best_model": best_model,
"best_n_components": best_n_components,
"labels": labels,
"probabilities": probs,
"bic_scores": bic_scores,
"silhouette_scores": silhouette_scores,
"cluster_stats": cluster_stats,
}
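
# A minimal sanity check of the BIC-based selection above (a sketch, not part
# of this PR): on well-separated synthetic blobs the lowest BIC should land on
# the true number of components.
#
#     from sklearn.datasets import make_blobs
#     X, _ = make_blobs(n_samples=300, centers=3, random_state=0)
#     bics = [
#         GaussianMixture(n_components=k, random_state=0).fit(X).bic(X)
#         for k in range(2, 7)
#     ]
#     print(np.argmin(bics) + 2)  # expected: 3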


def analyze_cluster_characteristics(
gmm_results,
pca_result,
track_ids,
time_points,
tracks_of_interest,
pc_analysis=None,
seed_timepoint=None,
time_window=None,
):
"""Analyze characteristics of GMM clusters including temporal patterns and PC contributions."""
# Get points from tracks of interest first
track_mask = np.isin(track_ids, tracks_of_interest)
points = pca_result[track_mask]
track_ids_subset = track_ids[track_mask]
times = time_points[track_mask]

# Apply time window if specified
if seed_timepoint is not None and time_window is not None:
time_mask = (times >= seed_timepoint - time_window) & (
times <= seed_timepoint + time_window
)
points = points[time_mask]
track_ids_subset = track_ids_subset[time_mask]
times = times[time_mask]

# Get cluster assignments for the filtered points
labels = gmm_results["labels"]
probs = gmm_results["probabilities"]
n_clusters = gmm_results["best_n_components"]
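    # NOTE: labels/probs were computed on the points filtered inside
    # analyze_gmm_clustering, so the same seed_timepoint and time_window
    # must be passed here for the arrays to stay aligned.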

# Analyze temporal patterns in each cluster
print("\nTemporal patterns in clusters:")
for i in range(n_clusters):
cluster_mask = labels == i
cluster_times = times[cluster_mask]
if len(cluster_times) > 0:
print(f"\nCluster {i}:")
print(
f" Time range: {np.min(cluster_times):.1f} to {np.max(cluster_times):.1f}"
)
print(
f" Mean time: {np.mean(cluster_times):.1f} ± {np.std(cluster_times):.1f}"
)

# Analyze PC contributions to cluster separation
print("\nPC contributions to cluster separation:")
for pc_idx in range(min(4, points.shape[1])): # Analyze first 4 PCs
pc_values = points[:, pc_idx]
cluster_means = [np.mean(pc_values[labels == i]) for i in range(n_clusters)]
cluster_stds = [np.std(pc_values[labels == i]) for i in range(n_clusters)]

        # Separation score: variance of the cluster means divided by the mean
        # within-cluster standard deviation (higher = better separated)
        between_var = np.var(cluster_means)
        within_var = np.mean(cluster_stds)
        separation_score = between_var / within_var if within_var > 0 else float("inf")
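        # e.g. cluster means [-2, 2] with stds [0.5, 0.5] give 4.0 / 0.5 = 8.0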

print(f"\nPC{pc_idx + 1}:")
print(f" Separation score: {separation_score:.3f}")
if pc_analysis is not None:
pc_info = pc_analysis[pc_analysis["PC"] == pc_idx + 1].iloc[0]
print(
f" Top contributing features: {', '.join(pc_info['Top_Features'][:3])}"
)

# Print cluster-specific stats
for i in range(n_clusters):
cluster_mask = labels == i
print(f" Cluster {i}: {cluster_means[i]:.3f} ± {cluster_stds[i]:.3f}")

# Analyze track transitions between clusters
print("\nTrack transitions between clusters:")
for track_id in tracks_of_interest:
track_mask = track_ids_subset == track_id
track_labels = labels[track_mask]
track_times = times[track_mask]

if len(track_labels) > 1:
# Sort by time
sort_idx = np.argsort(track_times)
track_labels = track_labels[sort_idx]
track_times = track_times[sort_idx]

            # Find transitions: index i marks a label change between
            # consecutive frames i and i+1
transitions = np.where(track_labels[1:] != track_labels[:-1])[0]
if len(transitions) > 0:
print(f"\nTrack {track_id}:")
for trans_idx in transitions:
from_cluster = track_labels[trans_idx]
to_cluster = track_labels[trans_idx + 1]
trans_time = track_times[trans_idx + 1]
print(f" {trans_time:.1f}: {from_cluster} -> {to_cluster}")

    # Recompute the per-PC separation scores for the returned summary
    pc_contributions = {}
    for pc_idx in range(min(4, points.shape[1])):
        cluster_means = [np.mean(points[labels == i, pc_idx]) for i in range(n_clusters)]
        within = np.mean([np.std(points[labels == i, pc_idx]) for i in range(n_clusters)])
        pc_contributions[f"PC{pc_idx + 1}"] = {
            "separation_score": (
                np.var(cluster_means) / within if within > 0 else float("inf")
            )
        }

    return {
        "temporal_patterns": {
            i: {
                "mean_time": np.mean(times[labels == i]),
                "std_time": np.std(times[labels == i]),
            }
            for i in range(n_clusters)
        },
        "pc_contributions": pc_contributions,
    }


def analyze_embeddings_with_pca(
embedding_path,
annotation_path=None,
@@ -504,6 +761,7 @@ def analyze_embeddings_with_pca(
)
print(dist_analysis.to_string(index=False))

# Return PCA results and additional data needed for clustering
return (
pca,
pca_result,
@@ -513,14 +771,16 @@
pc_analysis,
cluster_analysis,
dist_analysis,
track_ids,
time_points,
)


# %%
if __name__ == "__main__":
embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr"
annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv"
# %%

# Using phenotype annotations with specific FOVs
print("\nAnalyzing phenotype 1 in specific FOVs:")
(
@@ -532,13 +792,39 @@ def analyze_embeddings_with_pca(
pc_analysis,
cluster_analysis,
dist_analysis,
track_ids,
time_points,
) = analyze_embeddings_with_pca(
embedding_path,
annotation_path=annotation_path,
phenotype_of_interest=1,
seed_timepoint=55,
time_window=10,
        fov_patterns=["/C/2/", "/B/3/", "/B/2/"],
)

# Run GMM clustering analysis separately
print("\nPerforming GMM clustering analysis...")
gmm_results = analyze_gmm_clustering(
pca_result,
track_ids,
time_points,
tracks,
seed_timepoint=55,
time_window=10,
)
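    # Note: gmm_results["labels"] aligns with the time-window-filtered points
    # inside analyze_gmm_clustering, not with the full pca_result array.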

# Analyze cluster characteristics
print("\nAnalyzing cluster characteristics...")
cluster_characteristics = analyze_cluster_characteristics(
gmm_results,
pca_result,
track_ids,
time_points,
tracks,
pc_analysis=pc_analysis,
seed_timepoint=55,
time_window=10,
)

# Using random tracks from specific FOVs
@@ -552,13 +838,39 @@
pc_analysis,
cluster_analysis,
dist_analysis,
track_ids,
time_points,
) = analyze_embeddings_with_pca(
embedding_path,
        annotation_path=None,
n_random_tracks=10,
seed_timepoint=55,
time_window=30,
        fov_patterns=["/C/2/", "/B/3/", "/B/2/"],
)
# %%
# Run GMM clustering analysis for random tracks
print("\nPerforming GMM clustering analysis for random tracks...")
gmm_results = analyze_gmm_clustering(
pca_result,
track_ids,
time_points,
tracks,
seed_timepoint=55,
time_window=30,
)

# Analyze cluster characteristics for random tracks
print("\nAnalyzing cluster characteristics for random tracks...")
cluster_characteristics = analyze_cluster_characteristics(
gmm_results,
pca_result,
track_ids,
time_points,
tracks,
pc_analysis=pc_analysis,
seed_timepoint=55,
time_window=30,
)

# %%