[wip] feat: Add StreamingRawDataset for cloud storage streaming (early stage) #652

Draft · wants to merge 21 commits into main

Conversation

@bhimrazy (Collaborator) commented Jul 6, 2025

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not needed for typos or docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Overview

Adds a new StreamingRawDataset class that enables efficient streaming of raw files from cloud storage (S3/GCS) without requiring data optimization.

Current State
⚠️ Early Stage & Testing Phase - May change significantly based on feedback and testing.

  • Indexing

Usage Example

from torch.utils.data import DataLoader

from litdata.streaming.raw_dataset import StreamingRawDataset

# Custom dataset implementation
class MyImageDataset(StreamingRawDataset):
    def load_sample(self, local_path, file_path, class_name, index):
        from PIL import Image
        return {"image": Image.open(local_path), "label": class_name}

# Usage
dataset = MyImageDataset("s3://bucket/dataset/", max_preload_size=50)
dataloader = DataLoader(dataset, batch_size=32)

Benchmarks
Initial testing was done using Caltech-101 on a 4-CPU machine; final benchmarking will be done using ImageNet.

Indexing ImageNet takes around 2 minutes on a 4-core CPU machine.
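
The ~2-minute figure comes from indexing the bucket in parallel (the benchmark below passes index_workers=16). Here is a minimal sketch of that pattern, assuming boto3 and one worker per class prefix; the helper names are hypothetical and do not mirror the PR's actual implementation:

from concurrent.futures import ThreadPoolExecutor

import boto3


def _list_prefix(bucket: str, prefix: str) -> list:
    # List every object key under a single prefix (one "class" directory).
    s3 = boto3.client("s3")
    keys = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    return keys


def build_index(bucket: str, class_prefixes: list, index_workers: int = 16) -> list:
    # Index each class prefix in its own worker thread, then flatten the results.
    with ThreadPoolExecutor(max_workers=index_workers) as pool:
        per_class = pool.map(lambda p: _list_prefix(bucket, p), class_prefixes)
    return [key for keys in per_class for key in keys]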

Caltech-101
import logging
import os
import time
from typing import Any

import torch
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from litdata.streaming.raw_dataset import StreamingRawDataset

logging.basicConfig(level=logging.INFO)


class Caltech101Dataset(StreamingRawDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.transform = T.Compose(
            [
                T.ToImage(),  # Convert PIL images to tensor Images (tv_tensors.Image)
                T.Resize((224, 224)),  # Resize images to 224x224
                T.ToDtype(torch.float32, scale=True),
            ]
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def load_sample(
        self, local_path: str, file_path: str, class_name: str, index: int
    ) -> Any:
        from PIL import Image

        try:
            # Load the image from the local path
            img = Image.open(local_path).convert("RGB")
            if self.transform:
                img = self.transform(img)

            # Convert class name to index
            label = self.class_to_idx[class_name]
            return img, label
        except Exception as e:
            print(f"Error loading sample {index}: {e}")
            print(f"Local path: {local_path}")
            print(f"File path: {file_path}")
            print(f"Class name: {class_name}")
            raise


if __name__ == "__main__":
    # Example usage
    # clear cache before running
    os.system("rm -rf caltech_cache/*")
    dataset = Caltech101Dataset(
        input_dir="s3://grid-cloud-litng-ai-03/projects/01jpacd4y2yza88t23wf049m0t/datasets/caltech101/101_ObjectCategories",
        cache_dir="caltech_cache",
        index_workers=16,
        max_preload_size=100,
        download_workers=100,
    )

    dataloader = DataLoader(
        dataset, batch_size=32, shuffle=False, num_workers=0
    )  # Set to 0 for debugging
    # benchmark the dataset loading
    start_time = time.perf_counter()

    total_samples = 0

    for images, labels in tqdm(dataloader, desc="Loading dataset"):
        total_samples += len(images)

    elapsed = time.perf_counter() - start_time
    throughput = total_samples / elapsed if elapsed > 0 else 0

    print("⚡ Performance Results:")
    print(f"   Processed: {total_samples} samples in {elapsed:.2f}s")
    print(f"   Throughput: {throughput:.1f} samples/sec")
    print(f"   Avg time per sample: {(elapsed / total_samples) * 1000:.2f}ms")
Loading dataset: 100%|█████████████████████████████████████████████████| 272/272 [01:10<00:00,  3.87it/s]
⚡ Performance Results:
   Processed: 8677 samples in 70.36s
   Throughput: 123.3 samples/sec
   Avg time per sample: 8.11ms

For comparison, with torchvision's ImageFolder the throughput was around 800-900 samples/sec.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@bhimrazy bhimrazy self-assigned this Jul 6, 2025
@bhimrazy bhimrazy marked this pull request as draft July 6, 2025 18:44
@bhimrazy bhimrazy requested a review from Copilot July 6, 2025 18:45
@Copilot Copilot AI (Contributor) left a comment


Pull Request Overview

Adds a new StreamingRawDataset for streaming raw files directly from cloud storage (S3/GCS) with local caching, multithreaded indexing, and preloading.

  • Introduces CacheManager for directory-structured caching and file downloading (see the sketch after this list)
  • Builds or loads a file index in parallel and saves it to cache
  • Implements adaptive preloading and cache-hit statistics for performance monitoring
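
A minimal sketch of the directory-structured caching idea described in the bullets above; the class and method names here are hypothetical and may not match the PR's actual CacheManager API:

from pathlib import Path
from typing import Callable


class CacheManager:
    def __init__(self, cache_dir: str):
        self.cache_dir = Path(cache_dir)

    def local_path(self, remote_key: str) -> Path:
        # Mirror the remote key's directory structure under cache_dir.
        return self.cache_dir / remote_key.lstrip("/")

    def get(self, remote_key: str, download: Callable[[str, str], None]) -> Path:
        path = self.local_path(remote_key)
        if not path.exists():  # cache miss: fetch the file from cloud storage
            path.parent.mkdir(parents=True, exist_ok=True)
            download(remote_key, str(path))
        return path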
Comments suppressed due to low confidence (2)

src/litdata/streaming/raw_dataset.py:323

  • The fallback return dict uses class_name while the success path uses label for the class key. This inconsistency can confuse consumers; unify on a single key (e.g. always label).
            return {"path": file_path, "class_name": class_name, "index": index}

src/litdata/streaming/raw_dataset.py:94

  • There are no tests accompanying this new streaming dataset. Please add unit tests for index building (fresh and cached), caching behavior, __getitem__, and fallback loading to ensure correct and robust behavior.
class StreamingRawDataset(IterableDataset):
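
A minimal pytest sketch of the requested index-building coverage, assuming the constructor accepts input_dir/cache_dir as in the benchmark script above and that a local directory can stand in for a bucket during tests (both assumptions; adjust to the actual API):

import pytest

from litdata.streaming.raw_dataset import StreamingRawDataset


@pytest.fixture
def image_tree(tmp_path):
    # Tiny class-per-directory tree: two classes, two files each.
    for cls in ("cats", "dogs"):
        d = tmp_path / "data" / cls
        d.mkdir(parents=True)
        (d / "a.jpg").write_bytes(b"fake-image-bytes")
        (d / "b.jpg").write_bytes(b"fake-image-bytes")
    return tmp_path


def test_index_built_fresh_then_loaded_from_cache(image_tree):
    kwargs = dict(input_dir=str(image_tree / "data"), cache_dir=str(image_tree / "cache"))
    ds = StreamingRawDataset(**kwargs)   # first run: index built from scratch
    assert sorted(ds.classes) == ["cats", "dogs"]
    ds2 = StreamingRawDataset(**kwargs)  # second run: index loaded from cache
    assert sorted(ds2.classes) == sorted(ds.classes)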


codecov bot commented Jul 6, 2025

Codecov Report

Attention: Patch coverage is 0% with 222 lines in your changes missing coverage. Please review.

Project coverage is 81%. Comparing base (fc59c8a) to head (d2fcf50).

❌ Your patch check has failed because the patch coverage (0%) is below the target coverage (50%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@         Coverage Diff          @@
##           main   #652    +/-   ##
====================================
- Coverage    83%    81%    -3%     
====================================
  Files        49     50     +1     
  Lines      6785   7007   +222     
====================================
+ Hits       5662   5665     +3     
- Misses     1123   1342   +219     