import argparse
+from collections.abc import Callable
from datetime import datetime
+import os
from pathlib import Path

import anndata as ad
import scipy.sparse as sp
import torch
import zarr
+import psutil
from annbatch import ZarrSparseDataset, create_anndata_collection
from torch.utils.data import DataLoader

from bionemo.scspeedtest import benchmark_dataloaders_with_configs


# TODO: Should num_workers control threading? If scdataset were to wrap zarr, it would still be multithreaded under the processes, so my inclination is "no".
-# TODO: How big is the data? Presumably if it is not big enough to fit in memory, we would want O_DIRECT reading to be on.
-zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline", "io_direct": True})
+zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})

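# NOTE: "codec_pipeline.path" routes codec work through the Rust pipeline from the
# `zarrs` package. A minimal sketch (assuming zarr-python >= 3 with `zarrs`
# installed; the path is hypothetical) of scoping the same override with a
# context manager instead of mutating global config:
#
#     with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}):
#         z = zarr.open("example.zarr")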
-# TODO: Should annbatch export this? Probably.
-def collate_annbatch(data_to_collate: list[tuple[sp.csr_matrix, ...]]) -> torch.Tensor:
-    """Collation of anndata tensors from torch."""
-    if torch.cuda.is_available():
-        import cupy as np
-        import cupyx.scipy.sparse as sp
-    else:
-        import numpy as np
-        import scipy.sparse as sp
-    from annbatch.utils import to_torch
-    sparse_mat_not_torch = sp.vstack(
-        [
-            sp.csr_matrix(
-                (np.array(v[0].data), np.array(v[0].indices), np.array(v[0].indptr)),
-                shape=v[0].shape,
-            )
-            for v in data_to_collate
-        ],
-        format="csr",
-    )
-    return to_torch(sparse_mat_not_torch, preload_to_gpu=torch.cuda.is_available())
-
-def create_dataset_factory(adata_path: Path | str) -> ad.AnnData:
+def create_dataset_factory(adata_path: Path | str, num_workers: int) -> Callable[[], list[ad.AnnData]]:
    """Generate `anndata.AnnData` objects for use in `annbatch`, i.e., zarr v3 on disk, loaded using `anndata.io.sparse_dataset`."""
    def dataset_factory():
        # TODO: Where to dump this data?
@@ -69,29 +48,38 @@ def load_adata(adata_path):
            adata.uns = {k: v.compute() if isinstance(v, da.Array) else v for k, v in adata.uns.items()}
            return adata
        create_anndata_collection([adata_path], output_zarr_collection, zarr_sparse_chunk_size=32768 // 2, load_adata=load_adata, zarr_compressor=None)
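        # NOTE: 32768 // 2 = 16384, presumably the rows per on-disk sparse chunk;
        # zarr_compressor=None writes chunks uncompressed, trading disk footprint
        # for cheaper decode on the read path.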
+        # Give each worker an even share of the CPU threads, plus one extra as headroom
+        # TODO: How big is the data? Presumably if it does not fit in memory, we would want O_DIRECT reading to be on.
+        # TODO: There are probably faster ways to get a directory size? If you're using this data loader, presumably your data doesn't fit in memory?
        # TODO: Does X always contain the genes of interest? Layers? Raw?
-        return [ad.AnnData(X=ad.io.sparse_dataset(zarr.open(p)["X"])) for p in output_zarr_collection.iterdir()]
+        use_direct_io = psutil.virtual_memory().total < sum(f.stat().st_size for f in output_zarr_collection.glob("**/*") if f.is_file())
+        with zarr.config.set(
+            {
+                "threading.max_workers": (os.cpu_count() // max(num_workers, 1)) + 1,
+                "codec_pipeline.direct_io": use_direct_io,
+            }
+        ):
+            return [ad.AnnData(X=ad.io.sparse_dataset(zarr.open(p)["X"])) for p in output_zarr_collection.iterdir()]
    return dataset_factory
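
# A minimal usage sketch for the factory above (path and worker count are
# hypothetical, not taken from the PR):
#
#     factory = create_dataset_factory("/data/example.h5ad", num_workers=4)
#     adatas = factory()  # writes the zarr collection, then opens each store lazily
#     total_cells = sum(a.n_obs for a in adatas)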

def to_annbatch(adatas: list[ad.AnnData], *, batch_size: int = 64, shuffle: bool = True, block_size: int = 1, fetch_factor: int = 2, num_workers: int = 0) -> ZarrSparseDataset:
    """Generate an `annbatch.ZarrSparseDataset` based on configuration and a list of input `anndata.AnnData` objects backed by zarr v3 on disk."""
-    if block_size == 1 and num_workers > 1:
+    if num_workers > 0:
        ds = ZarrSparseDataset(
-            batch_size=batch_size // num_workers,
+            batch_size=batch_size,
            chunk_size=block_size,
            preload_nchunks=((fetch_factor * batch_size) // block_size),
            preload_to_gpu=False,
            shuffle=shuffle,
-            to_torch=False
+            to_torch=True,
        ).add_anndatas(adatas)
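        # batch_size=None disables the DataLoader's own batching: each worker's
        # ZarrSparseDataset already yields full torch batches, so the workers add
        # parallel I/O and decode rather than collation.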
        loader = DataLoader(
            ds,
-            batch_size=num_workers,
+            batch_size=None,
            num_workers=num_workers,
            multiprocessing_context="spawn",
-            collate_fn=collate_annbatch
        )
-        return loader
+        return (v[0] for v in loader)
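    # Single-process fallback: the dataset batches in-process, so no DataLoader
    # wrapper is needed.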
    ds = ZarrSparseDataset(
        batch_size=batch_size,
        chunk_size=block_size,
@@ -114,7 +102,7 @@ def create_annbatch_factory(
    """Factory creator for an on-disk zarr v3 sharded anndata file __and__ an `annbatch.ZarrSparseDataset`, configured from this function's arguments and from what is on disk."""

    def factory():
-        adatas = create_dataset_factory(adata_path)()
+        adatas = create_dataset_factory(adata_path, num_workers)()
        return to_annbatch(
            adatas,
            batch_size=batch_size,
@@ -166,6 +154,7 @@ def comprehensive_benchmarking_example(
    print(f"Benchmarking {num_runs} run(s) each")
    print()
    print("Running annbatch...")
+    num_workers = 0  # TODO: why 0? the other scripts seem to have this hardcoded though
    annbatch_configurations = []
    for fetch_factor in fetch_factors:
        for block_size in block_sizes:
@@ -175,7 +164,7 @@ def comprehensive_benchmarking_example(
                "dataloader_factory": create_annbatch_from_preloaded_anndata_factory(
                    batch_size=64,
                    shuffle=True,
-                    num_workers=4,  # TODO: why 0? the other scripts seem to have this hardcoded though
+                    num_workers=num_workers,
                    block_size=block_size,
                    fetch_factor=fetch_factor,
                ),
@@ -189,7 +178,7 @@ def comprehensive_benchmarking_example(

    benchmark_dataloaders_with_configs(
        dataloader_configs=annbatch_configurations,
-        shared_dataset_factory=create_dataset_factory(adata_path),
+        shared_dataset_factory=create_dataset_factory(adata_path, num_workers),
        output_prefix=f"annbatch_benchmark_{timestamp}",
    )
