Skip to content

feat: add verbose option in optimize_fn #654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ def __init__(
start_method: Optional[str] = None,
storage_options: dict[str, Any] = {},
keep_data_ordered: bool = True,
verbose: bool = True,
):
"""Provides an efficient way to process data across multiple machine into chunks to make training faster.

Expand All @@ -1115,6 +1116,7 @@ def __init__(
inside an interactive shell like Ipython.
storage_options: Storage options for the cloud provider.
keep_data_ordered: Whether to use a shared queue for the workers or not.
verbose: Whether to print the progress & logs of the workers. Defaults to True.
"""
# spawn doesn't work in IPython
start_method = start_method or ("fork" if in_notebook() else "spawn")
Expand All @@ -1124,7 +1126,8 @@ def __init__(
msg += "Tip: Libraries relying on lock can hang with `fork`. To use `spawn` in notebooks, "
msg += "move your code to files and import it within the notebook."

print(msg)
if verbose:
print(msg)

multiprocessing.set_start_method(start_method, force=True)

Expand Down Expand Up @@ -1166,9 +1169,13 @@ def __init__(
if self.output_dir:
# Ensure the output dir is the same across all nodes
self.output_dir = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank())
print(f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}")
if verbose:
print(
f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}"
)

self.random_seed = random_seed
self.verbose = verbose

def run(self, data_recipe: DataRecipe) -> None:
"""Triggers the data recipe processing over your dataset."""
Expand All @@ -1179,7 +1186,8 @@ def run(self, data_recipe: DataRecipe) -> None:
self._cleanup_checkpoints()

t0 = time()
print(f"Setup started with fast_dev_run={self.fast_dev_run}.")
if self.verbose:
print(f"Setup started with fast_dev_run={self.fast_dev_run}.")

# Force random seed to be fixed
random.seed(self.random_seed)
Expand Down Expand Up @@ -1231,7 +1239,8 @@ def run(self, data_recipe: DataRecipe) -> None:
if isinstance(user_items, list)
else "Using a Queue to process items on demand."
)
print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}")
if self.verbose:
print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}")

if self.use_checkpoint:
if isinstance(user_items, multiprocessing.queues.Queue):
Expand All @@ -1244,49 +1253,56 @@ def run(self, data_recipe: DataRecipe) -> None:
# Checkpoint feature is not supported for generators for now.
raise ValueError("Checkpoint feature is not supported for generators, yet.")
# get the last checkpoint details
print("Resuming from last saved checkpoint...")
if self.verbose:
print("Resuming from last saved checkpoint...")
self._load_checkpoint_config(workers_user_items)

assert isinstance(self.checkpoint_next_index, list)

if all(self.checkpoint_next_index[i] == 0 for i in range(self.num_workers)):
# save the current configuration in the checkpoints.json file
print("No checkpoints found. Saving current configuration...")
if self.verbose:
print("No checkpoints found. Saving current configuration...")
self._save_current_config(workers_user_items)
else:
# load the last checkpoint details
assert isinstance(self.checkpoint_next_index, list)
workers_user_items = [w[self.checkpoint_next_index[i] :] for i, w in enumerate(workers_user_items)]
print("Checkpoints loaded successfully.")
if self.verbose:
print("Checkpoints loaded successfully.")

if self.fast_dev_run and not isinstance(user_items, multiprocessing.queues.Queue):
assert isinstance(workers_user_items, list)

items_to_keep = self.fast_dev_run if isinstance(self.fast_dev_run, int) else _DEFAULT_FAST_DEV_RUN_ITEMS
workers_user_items = [w[:items_to_keep] for w in workers_user_items]
print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.")
if self.verbose:
print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.")

self._cleanup_cache()

num_items = sum([len(items) for items in workers_user_items]) if workers_user_items is not None else -1

if workers_user_items is not None:
print(
f"Starting {self.num_workers} workers with {num_items} items."
f" The progress bar is only updated when a worker finishes."
)
else:
print(f"Starting {self.num_workers} workers with a Queue to process items on demand.")
if self.verbose:
if workers_user_items is not None:
print(
f"Starting {self.num_workers} workers with {num_items} items."
f" The progress bar is only updated when a worker finishes."
)
else:
print(f"Starting {self.num_workers} workers with a Queue to process items on demand.")

if self.input_dir is None and self.src_resolver is not None and self.input_dir:
self.input_dir = self.src_resolver(self.input_dir)
print(f"The remote_dir is `{self.input_dir}`.")
if self.verbose:
print(f"The remote_dir is `{self.input_dir}`.")

signal.signal(signal.SIGINT, self._signal_handler)

self._create_process_workers(data_recipe, workers_user_items)

print("Workers are ready ! Starting data processing...")
if self.verbose:
print("Workers are ready ! Starting data processing...")

current_total = 0
if _TQDM_AVAILABLE:
Expand All @@ -1306,7 +1322,8 @@ def run(self, data_recipe: DataRecipe) -> None:
total_num_items = len(user_items) if isinstance(user_items, list) else -1

while True:
flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)
if self.verbose:
flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)

# Exit early if all the workers are done.
# This means either there were some kinda of errors, or optimize function was very small.
Expand All @@ -1315,7 +1332,8 @@ def run(self, data_recipe: DataRecipe) -> None:
error = self.error_queue.get(timeout=0.01)
self._exit_on_error(error)
except Empty:
print("All workers are done. Exiting!")
if self.verbose:
print("All workers are done. Exiting!")
break

try:
Expand Down Expand Up @@ -1349,13 +1367,15 @@ def run(self, data_recipe: DataRecipe) -> None:
with open("status.json", "w") as f:
json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)

flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)
if self.verbose:
flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None)

if _TQDM_AVAILABLE:
pbar.clear()
pbar.close()

print("Workers are finished.")
if self.verbose:
print("Workers are finished.")
size = len(workers_user_items) if workers_user_items is not None else None
result = data_recipe._done(size, self.delete_cached_files, self.output_dir)

Expand All @@ -1375,8 +1395,8 @@ def run(self, data_recipe: DataRecipe) -> None:
num_chunks=result.num_chunks,
num_bytes_per_chunk=result.num_bytes_per_chunk,
)

print("Finished data processing!")
if self.verbose:
print("Finished data processing!")
if self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe):
# clean up checkpoints
self._cleanup_checkpoints()
Expand Down
5 changes: 4 additions & 1 deletion src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def optimize(
optimize_dns: Optional[bool] = None,
storage_options: dict[str, Any] = {},
keep_data_ordered: bool = True,
verbose: bool = True,
) -> None:
"""This function converts a dataset into chunks, possibly in a distributed way.

Expand Down Expand Up @@ -454,6 +455,7 @@ def optimize(
workload and reduce idle time when some workers finish early. This may lead to unordered
processing of items. If True, each worker processes a statically assigned subset of items
in order.
verbose: Whether to print the progress of the optimization. Defaults to True.
"""
_check_version_and_prompt_upgrade(__version__)

Expand Down Expand Up @@ -492,7 +494,7 @@ def optimize(
"Only https://lightning.ai/ supports multiple nodes or selecting a machine.Create an account to try it out."
)

if not _IS_IN_STUDIO:
if not _IS_IN_STUDIO and verbose:
print(
"Create an account on https://lightning.ai/ to optimize your data faster "
"using multiple nodes and large machines."
Expand Down Expand Up @@ -564,6 +566,7 @@ def optimize(
start_method=start_method,
storage_options=storage_options,
keep_data_ordered=keep_data_ordered,
verbose=verbose,
)

with optimize_dns_context(optimize_dns if optimize_dns is not None else False):
Expand Down
21 changes: 21 additions & 0 deletions tests/processing/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,24 @@ def test_optimize_with_streaming_dataloader_on_parquet_data(tmpdir, num_workers)
# check all the indexes are correct
indexes = [sample_record["index"].item() for sample_record in ds]
assert indexes == list(range(num_items)), f"Expected indexes to be {list(range(num_items))}, but got {indexes}"


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
@pytest.mark.parametrize("verbose", [True, False])
def test_verbose_optimize(tmpdir, verbose):
    """Verify that ``optimize`` emits console output only when ``verbose`` is enabled."""
    destination = str(tmpdir / "output_dir")

    # Spy on the builtin print so any progress/log message from optimize is recorded.
    with mock.patch("builtins.print") as print_spy:
        optimize(
            fn=compress,
            inputs=list(range(5)),
            num_workers=1,
            output_dir=destination,
            chunk_size=2,
            verbose=verbose,
            mode="overwrite",
        )

    # verbose=True must produce at least one message; verbose=False must stay silent.
    expected_check = print_spy.assert_called if verbose else print_spy.assert_not_called
    expected_check()
Loading