Commit 371f262

feat: add verbose option in optimize_fn (#654)
* add verbose option in optimize_fn
* tests
1 parent 54cb63b commit 371f262

File tree

3 files changed: +68 -24 lines changed

src/litdata/processing/data_processor.py
src/litdata/processing/functions.py
tests/processing/test_functions.py


src/litdata/processing/data_processor.py

Lines changed: 43 additions & 23 deletions

```diff
@@ -1089,6 +1089,7 @@ def __init__(
         start_method: Optional[str] = None,
         storage_options: dict[str, Any] = {},
         keep_data_ordered: bool = True,
+        verbose: bool = True,
     ):
         """Provides an efficient way to process data across multiple machine into chunks to make training faster.
@@ -1115,6 +1116,7 @@ def __init__(
                 inside an interactive shell like Ipython.
             storage_options: Storage options for the cloud provider.
             keep_data_ordered: Whether to use a shared queue for the workers or not.
+            verbose: Whether to print the progress & logs of the workers. Defaults to True.
         """
         # spawn doesn't work in IPython
         start_method = start_method or ("fork" if in_notebook() else "spawn")
@@ -1124,7 +1126,8 @@ def __init__(
             msg += "Tip: Libraries relying on lock can hang with `fork`. To use `spawn` in notebooks, "
             msg += "move your code to files and import it within the notebook."

-            print(msg)
+            if verbose:
+                print(msg)

         multiprocessing.set_start_method(start_method, force=True)
@@ -1166,9 +1169,13 @@ def __init__(
         if self.output_dir:
             # Ensure the output dir is the same across all nodes
             self.output_dir = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank())
-            print(f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}")
+            if verbose:
+                print(
+                    f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}"
+                )

         self.random_seed = random_seed
+        self.verbose = verbose

     def run(self, data_recipe: DataRecipe) -> None:
         """Triggers the data recipe processing over your dataset."""
@@ -1179,7 +1186,8 @@ def run(self, data_recipe: DataRecipe) -> None:
         self._cleanup_checkpoints()

         t0 = time()
-        print(f"Setup started with fast_dev_run={self.fast_dev_run}.")
+        if self.verbose:
+            print(f"Setup started with fast_dev_run={self.fast_dev_run}.")

         # Force random seed to be fixed
         random.seed(self.random_seed)
@@ -1231,7 +1239,8 @@ def run(self, data_recipe: DataRecipe) -> None:
             if isinstance(user_items, list)
             else "Using a Queue to process items on demand."
         )
-        print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}")
+        if self.verbose:
+            print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}")

         if self.use_checkpoint:
             if isinstance(user_items, multiprocessing.queues.Queue):
@@ -1244,49 +1253,56 @@ def run(self, data_recipe: DataRecipe) -> None:
                 # Checkpoint feature is not supported for generators for now.
                 raise ValueError("Checkpoint feature is not supported for generators, yet.")
             # get the last checkpoint details
-            print("Resuming from last saved checkpoint...")
+            if self.verbose:
+                print("Resuming from last saved checkpoint...")
             self._load_checkpoint_config(workers_user_items)

             assert isinstance(self.checkpoint_next_index, list)

             if all(self.checkpoint_next_index[i] == 0 for i in range(self.num_workers)):
                 # save the current configuration in the checkpoints.json file
-                print("No checkpoints found. Saving current configuration...")
+                if self.verbose:
+                    print("No checkpoints found. Saving current configuration...")
                 self._save_current_config(workers_user_items)
             else:
                 # load the last checkpoint details
                 assert isinstance(self.checkpoint_next_index, list)
                 workers_user_items = [w[self.checkpoint_next_index[i] :] for i, w in enumerate(workers_user_items)]
-                print("Checkpoints loaded successfully.")
+                if self.verbose:
+                    print("Checkpoints loaded successfully.")

         if self.fast_dev_run and not isinstance(user_items, multiprocessing.queues.Queue):
             assert isinstance(workers_user_items, list)

             items_to_keep = self.fast_dev_run if isinstance(self.fast_dev_run, int) else _DEFAULT_FAST_DEV_RUN_ITEMS
             workers_user_items = [w[:items_to_keep] for w in workers_user_items]
-            print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.")
+            if self.verbose:
+                print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.")

         self._cleanup_cache()

         num_items = sum([len(items) for items in workers_user_items]) if workers_user_items is not None else -1

-        if workers_user_items is not None:
-            print(
-                f"Starting {self.num_workers} workers with {num_items} items."
-                f" The progress bar is only updated when a worker finishes."
-            )
-        else:
-            print(f"Starting {self.num_workers} workers with a Queue to process items on demand.")
+        if self.verbose:
+            if workers_user_items is not None:
+                print(
+                    f"Starting {self.num_workers} workers with {num_items} items."
+                    f" The progress bar is only updated when a worker finishes."
+                )
+            else:
+                print(f"Starting {self.num_workers} workers with a Queue to process items on demand.")

         if self.input_dir is None and self.src_resolver is not None and self.input_dir:
             self.input_dir = self.src_resolver(self.input_dir)
-            print(f"The remote_dir is `{self.input_dir}`.")
+            if self.verbose:
+                print(f"The remote_dir is `{self.input_dir}`.")

         signal.signal(signal.SIGINT, self._signal_handler)

         self._create_process_workers(data_recipe, workers_user_items)

-        print("Workers are ready ! Starting data processing...")
+        if self.verbose:
+            print("Workers are ready ! Starting data processing...")

         current_total = 0
         if _TQDM_AVAILABLE:
@@ -1306,7 +1322,8 @@ def run(self, data_recipe: DataRecipe) -> None:
         total_num_items = len(user_items) if isinstance(user_items, list) else -1

         while True:
-            flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)
+            if self.verbose:
+                flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)

             # Exit early if all the workers are done.
             # This means either there were some kinda of errors, or optimize function was very small.
@@ -1315,7 +1332,8 @@ def run(self, data_recipe: DataRecipe) -> None:
                 error = self.error_queue.get(timeout=0.01)
                 self._exit_on_error(error)
             except Empty:
-                print("All workers are done. Exiting!")
+                if self.verbose:
+                    print("All workers are done. Exiting!")
                 break

         try:
@@ -1349,13 +1367,15 @@ def run(self, data_recipe: DataRecipe) -> None:
                 with open("status.json", "w") as f:
                     json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)

-        flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)
+        if self.verbose:
+            flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)

         if _TQDM_AVAILABLE:
             pbar.clear()
             pbar.close()

-        print("Workers are finished.")
+        if self.verbose:
+            print("Workers are finished.")
         size = len(workers_user_items) if workers_user_items is not None else None
         result = data_recipe._done(size, self.delete_cached_files, self.output_dir)

@@ -1375,8 +1395,8 @@ def run(self, data_recipe: DataRecipe) -> None:
                 num_chunks=result.num_chunks,
                 num_bytes_per_chunk=result.num_bytes_per_chunk,
             )
-
-        print("Finished data processing!")
+        if self.verbose:
+            print("Finished data processing!")
         if self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe):
             # clean up checkpoints
             self._cleanup_checkpoints()
```
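
Every console side effect in `DataProcessor` is now gated the same way: each `print` and `flush_msg_queue` call runs only when the flag is set, and the constructor stores it as `self.verbose` so `run()` can check it later. A minimal sketch of silencing the processor directly, with hypothetical directories and assuming the remaining constructor arguments keep their defaults:

```python
from litdata.processing.data_processor import DataProcessor

# Hypothetical paths; verbose=False silences setup messages, worker
# logs, and completion prints while leaving processing unchanged.
processor = DataProcessor(
    input_dir="data/raw",
    output_dir="data/chunks",
    num_workers=2,
    verbose=False,
)
```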

src/litdata/processing/functions.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -411,6 +411,7 @@ def optimize(
     optimize_dns: Optional[bool] = None,
     storage_options: dict[str, Any] = {},
     keep_data_ordered: bool = True,
+    verbose: bool = True,
 ) -> None:
     """This function converts a dataset into chunks, possibly in a distributed way.
@@ -454,6 +455,7 @@ def optimize(
             workload and reduce idle time when some workers finish early. This may lead to unordered
             processing of items. If True, each worker processes a statically assigned subset of items
             in order.
+        verbose: Whether to print the progress of the optimization. Defaults to True.
     """
     _check_version_and_prompt_upgrade(__version__)

@@ -492,7 +494,7 @@ def optimize(
             "Only https://lightning.ai/ supports multiple nodes or selecting a machine.Create an account to try it out."
         )

-    if not _IS_IN_STUDIO:
+    if not _IS_IN_STUDIO and verbose:
         print(
             "Create an account on https://lightning.ai/ to optimize your data faster "
             "using multiple nodes and large machines."
@@ -564,6 +566,7 @@ def optimize(
         start_method=start_method,
         storage_options=storage_options,
         keep_data_ordered=keep_data_ordered,
+        verbose=verbose,
     )

     with optimize_dns_context(optimize_dns if optimize_dns is not None else False):
```
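
Most callers will reach the flag through `optimize` itself, which simply forwards it to the `DataProcessor`. A minimal usage sketch, where the `double` function, input range, and output directory are illustrative rather than taken from the commit:

```python
from litdata import optimize

def double(x):
    # Illustrative user function: its return value is what gets chunked.
    return x * 2

if __name__ == "__main__":  # guard needed for the spawn start method
    optimize(
        fn=double,
        inputs=list(range(100)),
        output_dir="my_optimized_dataset",
        chunk_size=16,
        num_workers=1,
        verbose=False,  # no setup, progress, or worker log output
    )
```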

tests/processing/test_functions.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -828,3 +828,24 @@ def test_optimize_with_streaming_dataloader_on_parquet_data(tmpdir, num_workers)
     # check all the indexes are correct
     indexes = [sample_record["index"].item() for sample_record in ds]
     assert indexes == list(range(num_items)), f"Expected indexes to be {list(range(num_items))}, but got {indexes}"
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
+@pytest.mark.parametrize("verbose", [True, False])
+def test_verbose_optimize(tmpdir, verbose):
+    output_dir = str(tmpdir / "output_dir")
+
+    with mock.patch("builtins.print") as mock_print:
+        optimize(
+            fn=compress,
+            inputs=list(range(5)),
+            num_workers=1,
+            output_dir=output_dir,
+            chunk_size=2,
+            verbose=verbose,
+            mode="overwrite",
+        )
+    if verbose:
+        mock_print.assert_called()
+    else:
+        mock_print.assert_not_called()
```
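
The test patches `builtins.print` for the duration of the call, so any console output from `optimize` lands on the mock: with `verbose=True` at least one message must be printed, and with `verbose=False` none may be, exercising both sides of every gate added above.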
