Skip to content

Commit

Permalink
change dataloader for single loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Soorya19Pradeep committed Feb 4, 2025
1 parent 30a8892 commit c8eb64d
Showing 1 changed file with 52 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import dash.dependencies as dd
from functools import lru_cache
from collections import defaultdict
import atexit

from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation import dataset_of_tracks
from viscy.utils.log_images import render_images
# from viscy.representation.evaluation import dataset_of_tracks
from viscy.data.triplet import TripletDataModule

# Initialize Dash app
app = dash.Dash(__name__)

# Sample DataFrame for demonstration
fov_name = "/0/6/000000"
features_path = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr"
)
Expand All @@ -40,7 +39,7 @@
)

# Filter data for FOVs starting with '/0/6/000000'
mask = features.coords['fov_name'].str.startswith('/0/6/000000')
mask = features.coords['fov_name'].str.startswith(fov_name)
features = features.sel(sample=mask)

df = pd.DataFrame({k: v for k, v in features.coords.items() if k != "features"})
Expand Down Expand Up @@ -71,7 +70,7 @@
figure=fig,
),
html.Div([
html.Img(id="hover-image", src="", style={"width": "150px", "height": "150px"})
html.Img(id="hover-image", src="", style={"width": "300px", "height": "150px"})
])
])

Expand Down Expand Up @@ -106,46 +105,55 @@ def preload_images(df):
"""Preload all images into memory"""
print("Preloading images into cache...")

groups = df.groupby(['fov_name', 'track_id'])
# groups = df.groupby(['fov_name', 'track_id'])
track_id_list = df['track_id'].unique().tolist()

for (fov_name, track_id), group in groups:
# Find the lowest t value for this group
min_t = group['t'].min()

predict_dataset = dataset_of_tracks(
data_path,
tracks_path,
[fov_name],
[track_id],
z_range=(31,36),
source_channel=["Phase3D", "MultiCam_GFP_mCherry_BF-Prime BSI Express"],
)

# for (fov_name, track_id), group in groups:

data_module = TripletDataModule(
data_path=data_path,
tracks_path=tracks_path,
include_fov_names=[fov_name],
include_track_ids=track_id_list,
source_channel=["Phase3D", "MultiCam_GFP_mCherry_BF-Prime BSI Express"],
z_range=(31,36),
initial_yx_patch_size=(128, 128),
final_yx_patch_size=(128, 128),
batch_size=1,
num_workers=16,
normalizations=None,
predict_cells=True,
)
data_module.setup("predict")

for batch in data_module.predict_dataloader():
images = batch["anchor"].numpy()
indices = batch["index"]
track_id = indices["track_id"].tolist()
t = indices["t"].tolist()
# print(track_id, t)

try:
image_patches = np.stack([p["anchor"].numpy() for p in predict_dataset])
img = np.stack(images)
cache_key = (fov_name, track_id[0], t[0])

for i in range(image_patches.shape[0]):
img = image_patches[i]
t = group['t'].iloc[i]
cache_key = (fov_name, track_id, t)

# Extract and normalize each channel independently
channel1 = normalize_image(img[0, 2]) # First channel at z=2
channel2 = normalize_image(np.max(img[1], axis=0)) # Max projection of second channel

# Ensure both channels are uint8
channel1 = channel1.astype(np.uint8)
channel2 = channel2.astype(np.uint8)

# Concatenate the normalized channels horizontally
combined_img = np.hstack((channel1, channel2))

# Store the base64 string in the cache
try:
image_cache[cache_key] = numpy_to_base64(combined_img)
except Exception as e:
print(f"Error converting image to base64 for {cache_key}: {e}")
continue
# Extract and normalize each channel independently
channel1 = normalize_image(img[0, 0, 2]) # First channel at z=2
channel2 = normalize_image(np.max(img[0,1], axis=0)) # Max projection of second channel

# Ensure both channels are uint8
channel1 = channel1.astype(np.uint8)
channel2 = channel2.astype(np.uint8)

# Concatenate the normalized channels horizontally
combined_img = np.hstack((channel1, channel2))

# Store the base64 string in the cache
try:
image_cache[cache_key] = numpy_to_base64(combined_img)
except Exception as e:
print(f"Error converting image to base64 for {cache_key}: {e}")
continue

except Exception as e:
print(f"Error processing images for {fov_name}, {track_id}: {e}")
Expand Down Expand Up @@ -175,4 +183,4 @@ def update_image(hoverData):
return image_cache.get(cache_key, "")

if __name__ == '__main__':
app.run_server(debug=False)
app.run_server(debug=True)

0 comments on commit c8eb64d

Please sign in to comment.