
Commit b4ccef7

fix: no sense in being opinionated about num_workers + sensible defaults
1 parent fdb611e commit b4ccef7

File tree

1 file changed

sub-packages/bionemo-scspeedtest/examples/annbatch_script.py

Lines changed: 24 additions & 35 deletions
@@ -18,45 +18,24 @@
 
 import argparse
 from datetime import datetime
+import os
 from pathlib import Path
 
 import anndata as ad
 import scipy.sparse as sp
 import torch
 import zarr
+import psutil
 from annbatch import ZarrSparseDataset, create_anndata_collection
 from torch.utils.data import DataLoader
 
 from bionemo.scspeedtest import benchmark_dataloaders_with_configs
 
 
 # TODO: Should num_workers control threading? If scdataset were to wrap zarr, it would still be multithreaded under the processes, so my inclination is "no".
-# TODO: How big is the data? Presumably if it is not big enough to fit in memory, we would want O_DIRECT reading to be on.
-zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline", "io_direct": True})
+zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
 
-# TODO: Should annbatch export this? Probably.
-def collate_annbatch(data_to_collate: list[tuple[sp.csr_matrix, ...]]) -> torch.Tensor:
-    """Collation of anndata tensors from torch."""
-    if torch.cuda.is_available():
-        import cupy as np
-        import cupyx.scipy.sparse as sp
-    else:
-        import numpy as np
-        import scipy.sparse as sp
-    from annbatch.utils import to_torch
-    sparse_mat_not_torch = sp.vstack(
-        [
-            sp.csr_matrix(
-                (np.array(v[0].data), np.array(v[0].indices), np.array(v[0].indptr)),
-                shape=v[0].shape,
-            )
-            for v in data_to_collate
-        ],
-        format="csr",
-    )
-    return to_torch(sparse_mat_not_torch, preload_to_gpu=torch.cuda.is_available())
-
-def create_dataset_factory(adata_path: Path | str) -> ad.AnnData:
+def create_dataset_factory(adata_path: Path | str, num_workers: int) -> ad.AnnData:
     """Generate an `anndata.AnnData` object for use in `annbatch` i.e., zarr v3 on disk, loaded using `anndata.io.sparse_dataset`."""
     def dataset_factory():
         # TODO: Where to dump this data?
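The deleted collate_annbatch helper becomes unnecessary because the next hunk flips the dataset to to_torch=True: batches leave the dataset as ready-made torch tensors, so the DataLoader has nothing to reassemble. A minimal sketch of that pattern, using a hypothetical stand-in dataset rather than the real ZarrSparseDataset:

# Sketch only: a stand-in for a dataset that already yields whole batches,
# so DataLoader(batch_size=None) passes them through with no collate_fn.
import torch
from torch.utils.data import DataLoader, IterableDataset


class PreBatchedDataset(IterableDataset):
    """Hypothetical dataset emitting one (X,) tuple per ready-made batch."""

    def __iter__(self):
        for _ in range(4):
            yield (torch.randn(64, 32),)


loader = DataLoader(PreBatchedDataset(), batch_size=None, num_workers=0)
for (x,) in loader:
    assert x.shape == (64, 32)  # batches arrive intact, no reassembly needed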
@@ -69,29 +48,38 @@ def load_adata(adata_path):
             adata.uns = { k: v.compute() if isinstance(v, da.Array) else v for k, v in adata.uns.items() }
             return adata
         create_anndata_collection([adata_path], output_zarr_collection, zarr_sparse_chunk_size=32768//2, load_adata=load_adata, zarr_compressor=None)
+        # Allocate each worker an even number of threads plus a little
+        # TODO: How big is the data? Presumably if it is not big enough to fit in memory, we would want O_DIRECT reading to be on.
+        # TODO: There are probably faster ways to get a directory size? If you're using this data loader, presumably your data doesn't fit in memory?
         # TODO: Does X always contain the genes of interest? Layers? Raw?
-        return [ad.AnnData(X=ad.io.sparse_dataset(zarr.open(p)["X"])) for p in output_zarr_collection.iterdir()]
+        use_direct_io = psutil.virtual_memory().total < sum(f.stat().st_size for f in output_zarr_collection.glob("**/*") if f.is_file())
+        with zarr.config.set(
+            {
+                "threading.max_workers": (os.cpu_count() // max(num_workers, 1)) + 1,
+                "codec_pipeline.direct_io": use_direct_io,
+            }
+        ):
+            return [ad.AnnData(X=ad.io.sparse_dataset(zarr.open(p)["X"])) for p in output_zarr_collection.iterdir()]
     return dataset_factory
 
 def to_annbatch(adatas: list[ad.AnnData], *, batch_size: int = 64, shuffle: bool = True, block_size: int = 1, fetch_factor: int = 2, num_workers: int = 0) -> ZarrSparseDataset:
     """Generate an `annbatch.ZarrSparseDataset` based on configuration and a list of input `anndata.AnnData` objects backed by zarr v3 on disk."""
-    if block_size == 1 and num_workers > 1:
+    if num_workers > 0:
         ds = ZarrSparseDataset(
-            batch_size=batch_size // num_workers,
+            batch_size=batch_size,
             chunk_size=block_size,
             preload_nchunks=((fetch_factor * batch_size) // block_size),
             preload_to_gpu=False,
             shuffle=shuffle,
-            to_torch=False
+            to_torch=True
         ).add_anndatas(adatas)
         loader = DataLoader(
             ds,
-            batch_size=num_workers,
+            batch_size=None,
             num_workers=num_workers,
             multiprocessing_context="spawn",
-            collate_fn=collate_annbatch
         )
-        return loader
+        return (v[0] for v in iter(loader))
     ds = ZarrSparseDataset(
         batch_size=batch_size,
         chunk_size=block_size,
@@ -114,7 +102,7 @@ def create_annbatch_factory(
     """Factory creator for on-disk zarr v3 sharded anndata file __and__ a `annbatch.ZarrSparseDataset` based on configuration in the arguments to this function as well as that on-disk."""
 
     def factory():
-        adatas = create_dataset_factory(adata_path)()
+        adatas = create_dataset_factory(adata_path, num_workers)()
         return to_annbatch(
             adatas,
             batch_size=batch_size,
@@ -166,6 +154,7 @@ def comprehensive_benchmarking_example(
     print(f"Benchmarking {num_runs} run(s) each")
     print()
    print("Running annbatch...")
+    num_workers = 0 # TODO: why 0? the other scripts seem to have this hardcoded though
     annbatch_configurations = []
     for fetch_factor in fetch_factors:
         for block_size in block_sizes:
@@ -175,7 +164,7 @@ def comprehensive_benchmarking_example(
                 "dataloader_factory": create_annbatch_from_preloaded_anndata_factory(
                     batch_size=64,
                     shuffle=True,
-                    num_workers=4, # TODO: why 0? the other scripts seem to have this hardcoded though
+                    num_workers=num_workers,
                     block_size=block_size,
                     fetch_factor=fetch_factor,
                 ),
@@ -189,7 +178,7 @@ def comprehensive_benchmarking_example(
 
     benchmark_dataloaders_with_configs(
         dataloader_configs=annbatch_configurations,
-        shared_dataset_factory=create_dataset_factory(adata_path),
+        shared_dataset_factory=create_dataset_factory(adata_path, num_workers),
         output_prefix=f"annbatch_benchmark_{timestamp}",
     )
 
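Taken together, the two public helpers now compose as below. This is a hedged usage sketch, not code from the commit: the input path is a placeholder, num_workers=4 is chosen so the DataLoader branch is exercised, and it assumes it runs inside annbatch_script.py where both helpers are defined.

# Illustrative composition of the script's helpers (placeholder path).
from pathlib import Path

adata_path = Path("data/example.h5ad")  # hypothetical input file

make_adatas = create_dataset_factory(adata_path, num_workers=4)
batches = to_annbatch(
    make_adatas(),
    batch_size=64,
    shuffle=True,
    block_size=1,
    fetch_factor=2,
    num_workers=4,
)
for x in batches:  # each x is one pre-batched torch tensor of expression values
    ...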