Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FileCreatingProgress class #195

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 71 additions & 27 deletions adaptive_scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
from adaptive.notebook_integration import in_ipynb
from rich.console import Console
from rich.progress import (
MofNCompleteColumn,
Progress,
SpinnerColumn,
TimeElapsedColumn,
get_console,
)
Expand Down Expand Up @@ -1199,15 +1201,19 @@ def _update_progress_for_paths(
progress: Progress,
total_task: TaskID | None,
task_ids: dict[str, TaskID],
) -> int:
) -> bool:
"""Update progress bars for each set of paths."""
n_completed = _remove_completed_paths(paths_dict)
total_completed = sum(n_completed.values())
for key, n_done in n_completed.items():
progress.update(task_ids[key], advance=n_done)
if total_task is not None:
progress.update(total_task, advance=total_completed)
return total_completed
progress.refresh()
if not any(paths_dict.values()):
progress.stop()
return True
return False


def _remove_completed_paths(
Expand Down Expand Up @@ -1241,24 +1247,17 @@ def _remove_completed_paths(
return n_completed


async def _track_file_creation_progress(
def _initialize_progress_for_paths(
paths_dict: dict[str, set[Path | tuple[Path, ...]]],
progress: Progress,
interval: float = 1,
) -> None:
"""Asynchronously track and update the progress of file creation.
) -> tuple[Progress, dict[str, TaskID], TaskID | None, int]:
columns = (
SpinnerColumn(),
*Progress.get_default_columns(),
TimeElapsedColumn(),
MofNCompleteColumn(),
)
progress = Progress(*columns, auto_refresh=False)

Parameters
----------
paths_dict
A dictionary with keys representing categories and values being sets of file paths to monitor.
progress
The Progress object from the rich library for displaying progress.
interval
The time interval (in seconds) at which to update the progress. The interval is dynamically
adjusted to be at least 50 times the time it takes to update the progress. This ensures that
updating the progress does not take up a significant amount of time.
"""
# create total_files and add_total_progress before updating paths_dict
total_files = sum(len(paths) for paths in paths_dict.values())
add_total_progress = len(paths_dict) > 1
Expand All @@ -1283,21 +1282,41 @@ async def _track_file_creation_progress(
total=n_remaining + n_done,
completed=n_done,
)
return progress, task_ids, total_task, total_files


async def _track_file_creation_progress(
paths_dict: dict[str, set[Path | tuple[Path, ...]]],
interval: float = 1,
) -> None:
"""Asynchronously track and update the progress of file creation.

Parameters
----------
paths_dict
A dictionary with keys representing categories and values being sets of file paths to monitor.
progress
The Progress object from the rich library for displaying progress.
interval
The time interval (in seconds) at which to update the progress. The interval is dynamically
adjusted to be at least 50 times the time it takes to update the progress. This ensures that
updating the progress does not take up a significant amount of time.
"""
progress, task_ids, total_task, total_files = _initialize_progress_for_paths(
paths_dict,
)
try:
progress.start() # Start the progress display
total_processed = 0
while True:
t_start = time.time()
total_processed += _update_progress_for_paths(
is_done = _update_progress_for_paths(
paths_dict,
progress,
total_task,
task_ids,
)
if total_processed >= total_files:
progress.refresh() # Final refresh to ensure 100%
if is_done:
break # Exit loop if all files are processed
progress.refresh()
# Sleep for at least 50 times the update time
t_update = time.time() - t_start
await asyncio.sleep(max(interval, 50 * t_update))
Expand All @@ -1306,9 +1325,35 @@ async def _track_file_creation_progress(
progress.stop() # Stop the progress display, regardless of what happens


class FileCreatingProgress:
    """Manually driven progress display for files that are being created.

    This is the synchronous counterpart of the async function
    `track_file_creation_progress`: nothing is polled in the background, so
    the caller has to invoke `update` whenever the display should be brought
    up to date.
    """

    def __init__(self, paths_dict: dict[str, set[Path | tuple[Path, ...]]]) -> None:
        """Create the progress bars for ``paths_dict`` and start displaying them."""
        # Only one live render is allowed per console at a time, so drop any
        # live display that is still registered before starting a new one.
        get_console().clear_live()
        (
            self.progress,
            self.task_ids,
            self.total_task,
            self.total_files,
        ) = _initialize_progress_for_paths(paths_dict)
        # Keep a reference to the caller's dict; `update` hands this same
        # object to `_update_progress_for_paths` on every call.
        self.paths_dict = paths_dict
        self.progress.start()

    def update(self) -> None:
        """Bring the progress bars up to date with the current state of the paths."""
        # The helper's return value is intentionally ignored here.
        _update_progress_for_paths(
            self.paths_dict,
            self.progress,
            self.total_task,
            self.task_ids,
        )


def track_file_creation_progress(
paths_dict: dict[str, set[Path | tuple[Path, ...]]],
interval: int = 1,
interval: float = 1,
) -> asyncio.Task:
"""Initialize and asynchronously track the progress of file creation.

Expand Down Expand Up @@ -1350,8 +1395,7 @@ def track_file_creation_progress(
>>> task = track_file_creation_progress(paths_dict)
"""
get_console().clear_live() # avoid LiveError, only 1 live render allowed at a time
columns = (*Progress.get_default_columns(), TimeElapsedColumn())
progress = Progress(*columns, auto_refresh=False)
coro = _track_file_creation_progress(paths_dict, progress, interval)

coro = _track_file_creation_progress(paths_dict, interval)
ioloop = asyncio.get_event_loop()
return ioloop.create_task(coro)