incorrect dataloader length when drop_last=False #402

Open
grez72 opened this issue Oct 28, 2024 · 1 comment
Labels
bug (Something isn't working) · help wanted (Extra attention is needed)

Comments

@grez72 (Contributor) commented Oct 28, 2024

🐛 Bug

When drop_last=False, len(StreamingDataLoader) returns an incorrect length if batch_size does not divide evenly into len(dataset). It appears to return ceil(len(dataset) / batch_size), but the actual number of batches is greater than this and depends on num_workers: each worker apparently yields a final batch smaller than batch_size. One consequence is that the number of full batches (where actual_batch_size == dataloader.batch_size) is less than len(dataset) // batch_size.
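For intuition, here is a minimal sketch (my reading of the behavior, not litdata's actual implementation) of where the extra batches come from, assuming samples are split across workers as evenly as possible. litdata's actual sharding is chunk-aware, so the individual partial-batch sizes observed below differ from this even split, but the total batch count comes out the same:

import math

def expected_num_batches(num_samples, batch_size, num_workers, drop_last=False):
    # Each worker draws from its own shard, so each shard can end with a
    # partial batch. An as-even-as-possible split is assumed here purely
    # for illustration.
    base, rem = divmod(num_samples, num_workers)
    shard_sizes = [base + (1 if w < rem else 0) for w in range(num_workers)]
    if drop_last:
        return sum(s // batch_size for s in shard_sizes)
    return sum(math.ceil(s / batch_size) for s in shard_sizes)

# 50000 samples, batch_size=256, 12 workers -> 204 batches (matching the
# tqdm run below), versus math.ceil(50000 / 256) == 196 reported by len().
print(expected_num_batches(50000, 256, 12))  # 204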

I noticed this because I use fastprogress.progress_bar instead of tqdm, and progress_bar appears to check len(dataloader) to determine the total number of items to iterate over, so it drops the extra partial batches. I was expecting to iterate over the full ImageNet validation set (50,000 samples), but only iterated over 49,432 samples even though I set drop_last=False.
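The arithmetic of the dropped samples checks out: a wrapper that trusts len(dataloader) stops after that many batches. A hypothetical illustration using the batch sizes from the tqdm run below (assuming the partial batches arrive last, as they appear to here):

from itertools import islice

batch_sizes = [256] * 192 + [70] * 11 + [78]  # sizes observed with tqdm below
assert sum(batch_sizes) == 50000

reported_len = 196  # what len(dataloader) returns
print(sum(islice(iter(batch_sizes), reported_len)))  # 49432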

To Reproduce

Steps to reproduce the behavior...

Code sample

generate a fake dataset for testing

import os, io
import numpy as np
from PIL import Image
import litdata as ld

def random_images_jpeg_encode(index):
    # Create one random RGB image with a random width and a random label.
    fake_image = Image.fromarray(
        np.random.randint(0, 256, (224, np.random.choice([224, 320, 384]), 3), dtype=np.uint8)
    )
    fake_label = np.random.randint(10)

    image_bytes = io.BytesIO()
    fake_image.save(image_bytes, format="JPEG", quality=100, optimize=True)
    image_bytes.seek(0)

    # You can use any key:value pairs. Note that their types must not change
    # between samples, and Python lists must always contain the same number
    # of elements with the same types.
    data = {"index": index, "image": image_bytes.read(), "label": fake_label}

    return data

ld.optimize(
    fn=random_images_jpeg_encode,       # the function applied to each input
    inputs=list(range(50000)),          # the inputs to the function (here it's a list of numbers)
    output_dir="fast_data",             # optimized data is stored here
    num_workers=4,                      # The number of workers on the same machine
    chunk_bytes="64MB"                  # size of each chunk
)

helpers for testing iteration over fake dataset

import os
import torch
from tqdm import tqdm
from fastprogress import progress_bar
from litdata import StreamingDataset, StreamingDataLoader
from litdata.streaming.serializers import JPEGSerializer
import torchvision.transforms.v2 as T2
from pdb import set_trace

serializer = JPEGSerializer()

class ImageNetStreamingDataset(StreamingDataset):

    def __init__(self, *args, **kwargs):
        self.transform = T2.Compose([
            lambda img_bytes: serializer.deserialize(img_bytes),
            T2.RandomResizedCrop(224, antialias=True),
            T2.RandomHorizontalFlip(p=.5),
            T2.ToImage(),            
            T2.ToDtype(torch.float16, scale=True),        
        ])
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        # Note: if torchvision is installed, deserializing yields a tensor
        # image instead of a PIL image, which is much faster.
        sample = super().__getitem__(idx)  # <- whatever random_images_jpeg_encode returned during ld.optimize
        sample['image'] = self.transform(sample['image'])
        return sample
    
def get_dataloader(input_dir, num_workers, batch_size, drop_last):
    dataset = ImageNetStreamingDataset(input_dir, shuffle=False, drop_last=drop_last)
    print(f"Length of dataset: {len(dataset)}")
    dataloader = StreamingDataLoader(dataset, num_workers=num_workers, batch_size=batch_size, 
                                     profile_batches=False, shuffle=False, drop_last=drop_last)
    print(f"Length of dataloader: {len(dataloader)}")
    return dataloader

def iterate_dataloader(dataloader, pbar):
    # Iterate over the dataloader, counting images, full batches, and partial batches.
    image_count = 0
    batch_count = 0
    full_batch_count = 0
    partial_batch_sizes = []
    for sample in pbar(dataloader):
        batch_count += 1
        bs = sample['image'].shape[0]
        image_count += bs
        if bs != dataloader.batch_size:
            partial_batch_sizes.append(bs)
        else:
            full_batch_count += 1
            
    print(f"batch_size: {dataloader.batch_size}")
    print(f"num_workers: {dataloader.num_workers}")
    if len(dataloader) != batch_count:
        print(f"\u274C len(dataloader) = {len(dataloader)}, actual num_batches = {batch_count}")
    else:
        print(f"\u2705 len(dataloader) = {len(dataloader)}, actual num_batches = {batch_count}")
    if image_count != len(dataloader.dataset):
        print(f"\u274C Actual number of images: {image_count}")
    else:
        print(f"\u2705 Actual number of images: {image_count}")        
    print(f"Number of full batches (img_count == {dataloader.batch_size}): {full_batch_count}")
    print(f"Number partial batches (img_count < {dataloader.batch_size}): {len(partial_batch_sizes)}")
    print(f"Sizes of partial batches: {partial_batch_sizes}")

test with tqdm

You'll see that len(dataloader) does not match the actual number of batches, but tqdm still iterates over the full dataset (yielding a bunch of partial batches, one per worker).

dataloader = get_dataloader(input_dir='fast_data', num_workers=12, batch_size=256, drop_last=False)
iterate_dataloader(dataloader, tqdm)

Length of dataset: 50000
Length of dataloader: 196
batch_size: 256
num_workers: 12
❌ len(dataloader) = 196, actual num_batches = 204
✅ Actual number of images: 50000
Number of full batches (img_count == 256): 192
Number partial batches (img_count < 256): 12
Sizes of partial batches: [70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 78]

test with fastprogress.progress_bar

fastprogress.progress_bar stops early (after len(dataloader) batches), dropping most of the partial batches.

dataloader = get_dataloader(input_dir='fast_data', num_workers=12, batch_size=256, drop_last=False)
iterate_dataloader(dataloader, progress_bar)

Length of dataset: 50000
Length of dataloader: 196

batch_size: 256
num_workers: 12
✅ len(dataloader) = 196, actual num_batches = 196
❌ Actual number of images: 49432
Number of full batches (img_count == 256): 192
Number partial batches (img_count < 256): 4
Sizes of partial batches: [70, 70, 70, 70]

Expected behavior

I would expect len(dataloader) to return the actual number of batches that will be yielded when iterating over the dataloader.

I would also have expected only one "partial batch" smaller than batch_size (matching the behavior of the standard torch.utils.data.DataLoader). So for the examples above, I would expect 195 batches of size 256 and a single partial batch of size 80 (195 * 256 + 80 = 50,000).
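For comparison, here is a minimal sketch with a plain torch.utils.data.DataLoader over a synthetic dataset (not the streaming setup above), which yields exactly one partial batch and a len() that matches the number of batches actually produced:

import torch
from torch.utils.data import DataLoader, TensorDataset

# 50000 dummy samples; with drop_last=False there is a single trailing
# partial batch, regardless of num_workers, and len(dl) agrees with it.
# (Wrap in an `if __name__ == "__main__":` guard on platforms that spawn
# worker processes.)
ds = TensorDataset(torch.arange(50000))
dl = DataLoader(ds, batch_size=256, num_workers=2, drop_last=False)

sizes = [batch[0].shape[0] for batch in dl]
print(len(dl), len(sizes))           # 196 196
print(sizes.count(256), sizes[-1])   # 195 80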

Additional context

latest litdata

@grez72 added the bug and help wanted labels Oct 28, 2024

Hi! Thanks for your contribution, great first issue!
