CellCLIP - Learning Perturbation Effects in Cell Painting via Text-Guided Contrastive Learning (NeurIPS 2025)
This repository provides code and instructions to reproduce the results presented in our work on CellCLIP. The proposed framework aligns Cell Painting image embeddings with perturbation-level textual descriptions, enabling biologically meaningful representations for downstream retrieval and matching tasks.
By Mingyu Lu, Ethan Weinberger, Chanwoo Kim, Su-In Lee
[Update 8/11 - We now provide a pretrained CellCLIP checkpoint with DINOv2-giant on Hugging Face]
Example usage:

import torch
from huggingface_hub import hf_hub_download
from src.helper import load  # make sure your helper path is correct

# Download the checkpoint from the Hugging Face Hub.
ckpt_path = hf_hub_download(
    repo_id="suinleelab/CellCLIP",
    filename="model.safetensors",
)

# Load the pretrained CellCLIP model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load(
    model_path=ckpt_path,
    device=device,
    model_type="cell_clip",
    input_dim=1536,
    loss_type="cwcl",
)
Repository structure:
├── src/                         # Core source files
│   ├── clip/                    # CellCLIP and contrastive learning modules
│   │   ├── method.py            # Contrastive loss implementations (e.g., CWCL, CLOOB, InfoNCE)
│   │   └── model.py             # CellCLIP and CrossChannelFormer model definitions
│   ├── helper.py                # Utility functions
│   └── ...                      # Other supporting modules
│
├── configs/                     # Model and training configuration files
├── preprocessing/               # Preprocessing scripts
│
├── main.py                      # Main training script
│
├── retrieval.py                 # Cross-modal retrieval evaluation
├── rxrx3-core_efaar_eval.py     # Intra-modal evaluation on RxRx3-core (gene–gene recovery)
└── cpjump1_matching_eval.py     # Replicate detection and sister perturbation matching on CP-JUMP1
To reproduce the results, follow these steps:
- Environment setup and installation
- Preprocessing Cell Painting images and associated metadata
- Training the proposed model
- Evaluating cross-modal and intra-modal retrieval performance
Set up a virtual environment with Python 3.11.5. Before starting, ensure all required packages are installed:
pip install -r requirements.txt
Create a src/constants.py file with the following content:
DATASET_DIR = "dataset_dir"
OUTDIR = "model_out_dir"
LOGDIR = "log_dir"
Add the repo directory to PYTHONPATH:
export PYTHONPATH="$PYTHONPATH:$PWD"
Please first download the Cell Painting images and the corresponding metadata and labels from the links below.
- Bray2017: preprocessed data available at https://ml.jku.at/software/cellpainting/dataset/
- RxRx3-Core: download RxRx3-core from Hugging Face
- CP-JUMP1: available by following the instructions in the official repository
To normalize raw Cell Painting image values into the [0, 255] range, use:
python preprocessing/preprocess_images.py
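For reference, below is a minimal sketch of the kind of intensity rescaling such a step typically performs. It is illustrative only; the actual logic, percentile choices, and per-channel handling are defined in preprocessing/preprocess_images.py.

import numpy as np

def rescale_to_uint8(channel, low_pct=1.0, high_pct=99.0):
    """Clip one fluorescence channel to percentile bounds, then rescale to [0, 255] (illustrative)."""
    lo, hi = np.percentile(channel, [low_pct, high_pct])
    scaled = (np.clip(channel, lo, hi) - lo) / max(hi - lo, 1e-8) * 255.0
    return scaled.astype(np.uint8)

# Normalize each of the 5 Cell Painting channels independently (dummy image below).
image = np.random.rand(5, 520, 696).astype(np.float32)
normalized = np.stack([rescale_to_uint8(c) for c in image])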
Once the preprocessed images are ready, you can generate embeddings using our proposed CrossChannelFormer encoding scheme by running:
python preprocessing/convert_npz_to_avg_emb.py \
--model_card facebook/dino-vitb8 \ # Feature extractor to generate embeddings
--dataset bray2017 \
--input_dir path_to_dataset \
--aggregation_strategy mean \ # Aggregation method (e.g., mean, attention)
--n_crop 1 \ # Number of crops per image
--output_file dino-vitb8_ind.h5 # Output file path
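For intuition only, the sketch below shows one plausible way to extract a per-channel embedding with the facebook/dino-vitb8 backbone, by replicating each grayscale channel to RGB and taking the CLS token. The real script's cropping, aggregation strategy, and HDF5 output format are defined in convert_npz_to_avg_emb.py, so treat the function name and shapes here as assumptions.

import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained("facebook/dino-vitb8")
backbone = AutoModel.from_pretrained("facebook/dino-vitb8").eval()

@torch.no_grad()
def channel_embeddings(image_chw):
    """Embed each Cell Painting channel separately (illustrative sketch, not the actual script)."""
    embs = []
    for channel in image_chw:                       # image_chw: (C, H, W) uint8 array
        rgb = np.stack([channel] * 3, axis=-1)      # replicate grayscale channel to (H, W, 3)
        inputs = processor(images=rgb, return_tensors="pt")
        out = backbone(**inputs)
        embs.append(out.last_hidden_state[:, 0])    # CLS token, 768-dim for dino-vitb8
    return torch.cat(embs)                          # (C, 768): one token per channel

image = (np.random.rand(5, 224, 224) * 255).astype(np.uint8)
print(channel_embeddings(image).shape)              # torch.Size([5, 768])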
To generate molecule-level prompts or fingerprints for contrastive training:
python preprocessing/preprocess_molecules.py \
--dataset [bray2017 | jumpcp | rxrx3-core] \
--output_file output_filename.h5|csv \
--img_dir /path/to/input_data
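For context, molecule fingerprints of this kind are commonly computed with RDKit. The sketch below (with a hypothetical helper name, morgan_fingerprint) is only an illustration; the exact prompt/fingerprint format expected downstream is defined in preprocessing/preprocess_molecules.py.

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def morgan_fingerprint(smiles, radius=2, n_bits=1024):
    """Return a binary Morgan fingerprint for one SMILES string (illustrative only)."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)  # resizes arr to n_bits
    return arr

print(morgan_fingerprint("CC(=O)Oc1ccccc1C(=O)O").sum())  # aspirin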
To train CellCLIP, execute the following command:
python main.py \
# === Dataset and Input Files ===
--dataset [bray2017 | jumpcp] \ # Dataset name
--img_dir /path/to/images_or_embeddings \ # Directory containing the images or image embeddings from Step 2 (preprocessing)
# === Image Preprocessing (Optional) ===
--image_resolution_train 224 \ # Resolution of training image inputs
--image_resolution_val 224 \ # Resolution of validation image inputs
--molecule_path /path/to/perturbation_descriptions \ # Path to the molecule or text inputs from Step 2 (preprocessing)
--unique \ # Treat perturbations as unique (multi-instance mode)
# === Model Configuration ===
--model_type [milcellclip | cloome | molphenix] \ # Type of model architecture
--input_dim [768 | 1024 | 1536] \ # Input feature dimensionality (depends on embedding source)
--loss_type [cwcl | clip | cloob] \ # Contrastive loss function
# === Optimization Hyperparameters ===
--epochs 50 \ # Number of training epochs
--batch_size 512 \ # Batch size
# === Learning Rate and Scheduler ===
--lr 5e-4 \ # Learning rate
--lr_scheduler [cosine | const | const-cooldown] \ # LR scheduler type
--warmup 1000 \ # Number of warmup steps
--num_cycles 5 \ # Number of cosine cycles for LR scheduler
# === Checkpointing & Logging ===
--ckpt_freq 1000 \ # Frequency (in steps) to save checkpoints
--keep_all_ckpts \ # Save all checkpoints (not just latest)
--log_freq 20 \ # Log every N steps
--eval_freq 500 # Evaluate every N steps
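For orientation, the objective selected via --loss_type is a symmetric contrastive loss over matched image/perturbation pairs. The sketch below shows only the standard symmetric InfoNCE (CLIP) formulation, with the inverse temperature mirroring the --init-inv-tau flag; it is not the repository's implementation (see src/clip/method.py for the CWCL/CLOOB variants actually used).

import torch
import torch.nn.functional as F

def symmetric_infonce(img_emb, txt_emb, inv_tau=14.3):
    """Standard CLIP-style loss; illustrative only, not src/clip/method.py."""
    img_emb = F.normalize(img_emb, dim=-1)
    txt_emb = F.normalize(txt_emb, dim=-1)
    logits = inv_tau * img_emb @ txt_emb.t()        # (N, N) similarity matrix
    targets = torch.arange(len(img_emb), device=img_emb.device)
    loss_i = F.cross_entropy(logits, targets)       # image -> text direction
    loss_t = F.cross_entropy(logits.t(), targets)   # text -> image direction
    return 0.5 * (loss_i + loss_t)

loss = symmetric_infonce(torch.randn(8, 512), torch.randn(8, 512))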
To enable distributed training across multiple GPUs, use accelerate:
accelerate launch --config_file configs/your_config.yaml main.py ...
Note: on a setup with 8 × RTX 6000 GPUs, a maximum batch size of 512 has been tested successfully. A launch configuration file can be generated interactively with accelerate config. Below is an example command to train CellCLIP using accelerate.
accelerate launch \
--config_file configs/ddp_config.yaml main.py \
--split 1 \
--is_train \
--resume \
--batch_size 512 \
--epochs 50 \
--model_type mil_cell_clip \
--input_dim 1536 \
--dataset bray2017 \
--img_dir path_to_embeddings \
--unique \
--molecule_path path_to_molecules \
--loss_type cloob \
--lr_scheduler cosine-restarts \
--num_cycles 4 \
--wd 0.1 \
--init-inv-tau 14.3 \
--learnable-inv-tau \
--warmup 1000 \
--ckpt_freq 500 \
--eval_freq 100 \
--opt_seed 42 \
--lr 0.0001
This section describes how to evaluate the trained model on both cross-modal and intra-modal tasks.
Evaluate the alignment between Cell Painting images and perturbation-level text embeddings:
python retrieval.py \
--embedding_type /path/to/eval_embeddings \ # Path to aggregated embeddings
--model_type [milcellclip | cloome | molphenix] \ # Model architecture
--input_dim [768 | 1024 | 1536] \ # Embedding dimensionality
--loss_type [cwcl | clip | cloob] \ # Loss used during training
--ckpt_path /path/to/trained_model.pt \ # Path to model checkpoint
--unique \ # Use multi-instance mode if applicable
--image_resolution_train 224 \ # Resolution used for training
--image_resolution_val 224 # Resolution used for evaluation
For models trained on individual instances and evaluated on pooled profiles, use retrieval_whole.py instead.
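For intuition, cross-modal retrieval is typically scored with recall@k over the cosine-similarity matrix between the two embedding sets. The snippet below is a generic sketch of that metric, not the exact protocol implemented in retrieval.py.

import torch
import torch.nn.functional as F

def recall_at_k(img_emb, txt_emb, k=5):
    """Fraction of images whose matching text is among its top-k most similar texts (illustrative)."""
    sims = F.normalize(img_emb, dim=-1) @ F.normalize(txt_emb, dim=-1).t()  # (N, N) cosine similarities
    topk = sims.topk(k, dim=1).indices                                      # top-k text indices per image
    targets = torch.arange(len(img_emb)).unsqueeze(1)
    return (topk == targets).any(dim=1).float().mean().item()

print(recall_at_k(torch.randn(100, 512), torch.randn(100, 512), k=5))  # ~0.05 for random embeddings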
Use the following script to generate instance-level embeddings for RxRx3-core evaluation:
python preprocessing/convert_emb_to_ind_rxrx3core_emb.py \
--ckpt_path /path/to/trained_model.pt \
--model_type milcellclip \
--loss_type cwcl \
--input_dim 1536 \
--output_file output_embeddings.npz \
--img_dir /path/to/test_embeddings
Then run the zero-shot gene–gene relationship recovery evaluation on RxRx3-Core:
python rxrx3-core_efaar_eval.py --filepath [path to precomputed embeddings from a trained model, e.g., CellCLIP]
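As a rough illustration of gene–gene relationship recovery, the toy function below (relationship_recall, a hypothetical name) checks whether annotated gene pairs are unusually similar in embedding space. The EFAAR-style metrics actually computed by rxrx3-core_efaar_eval.py differ in their details.

import numpy as np

def relationship_recall(gene_emb, known_pairs, top_pct=5.0):
    """Share of annotated gene pairs whose cosine similarity lands in the top percentile
    of all pairwise similarities (toy version of known-relationship recall)."""
    emb = gene_emb / np.linalg.norm(gene_emb, axis=1, keepdims=True)
    sims = emb @ emb.T
    iu = np.triu_indices(len(emb), k=1)
    threshold = np.percentile(sims[iu], 100 - top_pct)
    hits = [sims[i, j] >= threshold for i, j in known_pairs]
    return float(np.mean(hits))

genes = np.random.randn(50, 512)   # dummy per-gene embeddings
pairs = [(0, 1), (2, 3), (4, 5)]   # dummy annotated gene pairs
print(relationship_recall(genes, pairs))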
The CP-JUMP1 evaluation tests the model’s ability to:
- Detect biological replicates (same perturbation, different images)
- Match sister perturbations that target the same biological pathway
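As a rough illustration of replicate detection, a common formulation ("percent replicating") asks whether same-perturbation profiles are more similar than randomly paired ones. The sketch below is a toy version under that assumption, not the exact metric in cpjump1_matching_eval.py.

import numpy as np

def percent_replicating(emb, labels, null_pct=95.0, n_null=10000, seed=0):
    """Toy percent replicating: share of perturbations whose mean within-replicate cosine
    similarity exceeds a percentile of a random-pair null distribution (illustrative only)."""
    rng = np.random.default_rng(seed)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    null = [emb[i] @ emb[j] for i, j in rng.integers(0, len(emb), size=(n_null, 2)) if labels[i] != labels[j]]
    threshold = np.percentile(null, null_pct)
    scores = []
    for lab in set(labels):
        idx = [k for k, l in enumerate(labels) if l == lab]
        if len(idx) < 2:
            continue
        sims = [emb[i] @ emb[j] for a, i in enumerate(idx) for j in idx[a + 1:]]
        scores.append(np.mean(sims) > threshold)
    return float(np.mean(scores))

emb = np.random.randn(40, 256)
labels = [f"cpd_{i % 10}" for i in range(40)]  # 10 dummy perturbations, 4 replicates each
print(percent_replicating(emb, labels))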
Use the following script to generate instance-level embeddings from a trained model:
python preprocessing/convert_emb_to_cellclip_emb.py \
--ckpt_path [path to trained CellCLIP ckpt] \
--model_type mil_cell_clip \
--loss_type cwcl \
--input_dim 1536 \
--pretrained_emb [name of the pretrained embeddings] \
--img_dir path_to_testing_data_embeddings
Then run the CP-JUMP1 evaluation:
python cpjump1_matching_eval.py \
--kernel poly \ # Kernel for batch correction (e.g., poly)
--feature_type [profile | emb] \ # Whether to use raw profiles or embeddings
--batch_correction # Enable batch effect correction
