Data Loading
Data loading is the most underestimated component of ML training. A poorly designed data pipeline can make your GPU sit idle while waiting for the next batch, wasting expensive compute. This chapter covers the PyTorch data loading system -- Dataset, DataLoader, and their many configuration options -- with emphasis on the performance implications of each choice.
Dataset
A Dataset maps an integer index to a sample. It defines what data you have and how to access it:
```python
from torch.utils.data import Dataset, DataLoader
import torch

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Called once per sample -- this is where preprocessing happens
        tokens = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return {
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }
```
```python
# Image dataset example
from PIL import Image
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform or transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]
```
| Issue | Symptom | Fix |
|---|---|---|
| Heavy preprocessing in `__getitem__` | CPU-bound, GPU idle | Preprocess and save to disk; load preprocessed data |
| Reading many small files | I/O-bound (many small random reads) | Pack into fewer large files (WebDataset, TFRecord, Arrow/Parquet) |
| PIL/OpenCV image decoding | CPU-bound | Use NVIDIA DALI for GPU-accelerated decoding |
| Tokenization per sample | CPU-bound for long sequences | Pre-tokenize and save; use memory-mapped files |
| Large Python objects in `__init__` | High memory per worker | Use memory-mapped arrays (`np.memmap`) or Arrow tables |
Rule of thumb: If `__getitem__` takes more than 1 ms, it is likely the bottleneck. Profile with `time.perf_counter()` or the PyTorch profiler.
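A quick way to apply this rule of thumb is to time `__getitem__` directly over a handful of indices. The helper below is a minimal sketch (the `profile_getitem` name and the toy `ListDataset` are illustrative, not part of any library):

```python
import time
import statistics

def profile_getitem(dataset, n_samples=100):
    """Time dataset[idx] over the first n_samples indices and
    return the median latency in milliseconds."""
    times = []
    for idx in range(min(n_samples, len(dataset))):
        start = time.perf_counter()
        _ = dataset[idx]
        times.append((time.perf_counter() - start) * 1000)  # ms
    return statistics.median(times)

# Toy in-memory dataset -- real datasets that decode images or
# tokenize text will be orders of magnitude slower per sample.
class ListDataset:
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

median_ms = profile_getitem(ListDataset(list(range(1000))))
print(f"median __getitem__ latency: {median_ms:.4f} ms")
```

Use the median rather than the mean so that one-off stalls (first touch of a memory-mapped file, filesystem cache misses) do not dominate the estimate.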
DataLoader
DataLoader wraps a Dataset and handles batching, shuffling, parallel loading, and memory transfer:
```python
loader = DataLoader(
    dataset,
    batch_size=32,            # Samples per batch
    shuffle=True,             # Randomize order each epoch (for training)
    num_workers=4,            # Number of parallel data loading processes
    pin_memory=True,          # Use page-locked memory for faster GPU transfer
    drop_last=True,           # Drop last incomplete batch (important for DDP)
    prefetch_factor=2,        # Each worker prefetches this many batches
    persistent_workers=True,  # Keep workers alive between epochs (avoids fork overhead)
    collate_fn=None,          # Custom function to merge samples into a batch
)

# Training loop
for batch in loader:
    # batch is on CPU -- move to GPU with non_blocking for async transfer
    input_ids = batch['input_ids'].cuda(non_blocking=True)
    labels = batch['label'].cuda(non_blocking=True)
    output = model(input_ids)
```
| Parameter | Default | Recommendation | Effect |
|---|---|---|---|
| `batch_size` | 1 | Task-dependent (8-256 typical) | Larger = better GPU utilization, but more memory |
| `num_workers` | 0 | 4-8 per GPU | More = faster loading, but more CPU/RAM |
| `pin_memory` | False | True (always for GPU training) | 10-30% faster CPU-to-GPU transfer |
| `prefetch_factor` | 2 | 2-4 | Higher = smoother pipeline, but more memory |
| `persistent_workers` | False | True (if `num_workers > 0`) | Avoids worker restart overhead between epochs |
| `shuffle` | False | True for training, False for eval | Randomization for SGD; deterministic for eval |
| `drop_last` | False | True for DDP training | Ensures all ranks get same-sized batches |
| `collate_fn` | Default | Custom for variable-length sequences | Controls how samples are merged |
```python
# Quick benchmark to find optimal num_workers
import time

for num_workers in [0, 1, 2, 4, 8, 16]:
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers, pin_memory=True)
    start = time.perf_counter()
    for i, batch in enumerate(loader):
        if i >= 50:
            break
    elapsed = time.perf_counter() - start
    print(f"num_workers={num_workers}: {elapsed:.2f}s for 50 batches")
```
Common guidelines:
- Start with `num_workers = 4 * num_gpus`
- If CPU utilization is low and the GPU is idle between batches, increase workers
- If you run out of RAM, reduce workers (each worker gets a copy of the dataset)
- `num_workers=0` means loading happens in the main process (useful for debugging)
- On Windows, `num_workers > 0` requires an `if __name__ == '__main__':` guard
Without `pin_memory`: pageable RAM → pinned staging buffer → GPU HBM (two copies)
With `pin_memory`: pinned RAM → GPU HBM (one copy)
The extra staging copy is what accounts for the 10-30% transfer overhead. Always use `pin_memory=True` when training on a GPU.
`non_blocking=True` on `.cuda()` makes the transfer asynchronous: the CPU enqueues the DMA transfer and immediately continues to the next operation, overlapping data transfer with GPU computation. Note that the transfer is only truly asynchronous when the source tensor lives in pinned memory, which is one more reason to enable `pin_memory`.
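As a sketch, the dict-style batches produced by `TextDataset` above can be moved with one small helper; the `move_batch` name is illustrative, and the helper falls back to CPU when no GPU is present (`non_blocking` is simply a no-op in that case):

```python
import torch

def move_batch(batch, device):
    """Move every tensor in a dict-style batch to the target device.
    With a pinned source and a CUDA device, non_blocking=True
    enqueues an async DMA copy and returns immediately."""
    return {k: v.to(device, non_blocking=True) for k, v in batch.items()}

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {
    "input_ids": torch.randint(0, 30000, (32, 128)),
    "label": torch.randint(0, 2, (32,)),
}
moved = move_batch(batch, device)
```

In a real training loop, `batch` would come from a `DataLoader` with `pin_memory=True`, so the source tensors are already pinned.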
IterableDataset
For streaming data that does not fit in memory, or data that arrives sequentially (logs, real-time feeds):
```python
import glob
import torch
from torch.utils.data import IterableDataset, DataLoader

class StreamingDataset(IterableDataset):
    def __init__(self, file_paths):
        super().__init__()
        self.file_paths = file_paths

    def __iter__(self):
        # IMPORTANT: shard across workers to avoid duplicate data
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading
            files = self.file_paths
        else:
            # Multi-process: each worker processes a subset of files
            per_worker = len(self.file_paths) // worker_info.num_workers
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = (start + per_worker if worker_id < worker_info.num_workers - 1
                   else len(self.file_paths))
            files = self.file_paths[start:end]
        for path in files:
            with open(path, 'r') as f:
                for line in f:
                    yield process_line(line)  # process_line: your per-line parsing

# Usage
dataset = StreamingDataset(glob.glob('/data/train/*.jsonl'))
loader = DataLoader(dataset, batch_size=32, num_workers=4)
```
- No `__len__`: IterableDatasets have no length, so progress bars and epoch-based scheduling do not work without manual tracking.
- Worker duplication: Without manual sharding (as shown above), each worker iterates over the entire dataset, producing duplicate samples. This is the most common IterableDataset bug.
- No random access: You cannot index into an IterableDataset. Shuffling must be done at the file level or with a shuffle buffer:

  ```python
  # Shuffle buffer: maintain a buffer and randomly sample from it
  buffer = []
  for sample in stream:
      buffer.append(sample)
      if len(buffer) >= buffer_size:
          idx = random.randint(0, len(buffer) - 1)
          yield buffer.pop(idx)
  ```

- Distributed training: With DDP, you must shard across both workers and ranks. Use `torch.distributed.get_rank()` and `get_world_size()` for rank-level sharding.
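The two levels of sharding compose into a single strided slice: every (rank, worker) pair gets a disjoint shard whose union covers all files exactly once. The `shard_files` helper below is a sketch of this idea, not a torch API:

```python
def shard_files(file_paths, rank, world_size, worker_id, num_workers):
    """Two-level sharding: first across DDP ranks, then across
    DataLoader workers within each rank. Treating (rank, worker)
    as one flat shard index, a strided slice partitions the file
    list into world_size * num_workers disjoint subsets."""
    shard_id = rank * num_workers + worker_id
    num_shards = world_size * num_workers
    return file_paths[shard_id::num_shards]

# Example: 10 files, 2 ranks, 2 workers per rank -> 4 disjoint shards
files = [f"part-{i:03d}.jsonl" for i in range(10)]
shards = [shard_files(files, r, 2, w, 2) for r in range(2) for w in range(2)]
```

Inside `__iter__`, `rank` and `world_size` would come from `torch.distributed.get_rank()` / `get_world_size()`, and `worker_id` / `num_workers` from `get_worker_info()` as shown above.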
Custom collate_fn
The collate function controls how individual samples are merged into a batch. The default collate stacks tensors and creates batched containers. You need a custom collate for variable-length sequences:
```python
def dynamic_padding_collate(batch):
    """Pad sequences to the longest in the BATCH (not global max).
    This saves computation vs. padding to a fixed max length."""
    input_ids = [item['input_ids'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    # Pad to max length in this specific batch
    max_len = max(ids.size(0) for ids in input_ids)
    padded_ids = torch.zeros(len(input_ids), max_len, dtype=torch.long)
    attention_mask = torch.zeros(len(input_ids), max_len, dtype=torch.long)
    for i, ids in enumerate(input_ids):
        padded_ids[i, :ids.size(0)] = ids
        attention_mask[i, :ids.size(0)] = 1
    return {
        'input_ids': padded_ids,
        'attention_mask': attention_mask,
        'label': labels,
    }

# Alternative: use torch's pad_sequence utility
from torch.nn.utils.rnn import pad_sequence

def collate_with_pad_sequence(batch):
    input_ids = [item['input_ids'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = (input_ids != 0).long()
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label': labels}

loader = DataLoader(dataset, batch_size=32, collate_fn=dynamic_padding_collate)
```
```python
# Sort dataset by length, then create batches from similar-length sequences
import random
from torch.utils.data import Sampler

class BucketSampler(Sampler):
    def __init__(self, lengths, batch_size, shuffle=True):
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Create batches of similar-length sequences
        batches = [self.sorted_indices[i:i + self.batch_size]
                   for i in range(0, len(self.sorted_indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)  # Shuffle batch ORDER, not within-batch order
        for batch in batches:
            yield from batch

    def __len__(self):
        return len(self.sorted_indices)
```
This can reduce wasted padding by roughly 30-50% for NLP workloads with variable-length inputs.
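The saving can be estimated directly from the length distribution. The function below is a back-of-the-envelope sketch (the `padding_waste` name and the toy length list are illustrative) comparing fixed-max padding against per-batch dynamic padding over length-sorted buckets:

```python
def padding_waste(lengths, batch_size, fixed_max=None):
    """Fraction of token slots wasted on padding. By default each
    batch is padded to its own max length (dynamic padding);
    pass fixed_max to model padding to a global maximum instead."""
    lengths = sorted(lengths)  # bucketed: similar lengths share a batch
    real_tokens = sum(lengths)
    total_tokens = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        pad_to = fixed_max if fixed_max is not None else max(batch)
        total_tokens += pad_to * len(batch)
    return 1 - real_tokens / total_tokens

lengths = [32, 48, 64, 96, 128, 256, 300, 512]  # toy length distribution
fixed = padding_waste(lengths, batch_size=4, fixed_max=512)
dynamic = padding_waste(lengths, batch_size=4)
```

For this toy distribution, fixed-max padding wastes about 65% of token slots while bucketed dynamic padding wastes about 41%; the gap widens as the length distribution becomes more skewed.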
Distributed Data Loading
In distributed training (DDP), each GPU must process a different subset of the data:
```python
from torch.utils.data.distributed import DistributedSampler

# Create sampler that shards data across ranks
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,  # Total number of GPUs
    rank=rank,                # This GPU's rank
    shuffle=True,             # Shuffle within each rank's shard
    drop_last=True,           # Ensure all ranks get same number of batches
)

loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,          # Replaces shuffle=True (sampler handles shuffling)
    num_workers=4,
    pin_memory=True,
    drop_last=True,           # Also drop_last in DataLoader for safety
    persistent_workers=True,
)

# CRITICAL: set epoch on sampler for proper shuffling.
# Without this, each epoch sees the same data order!
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # Changes the random seed for shuffling
    for batch in loader:
        input_ids = batch['input_ids'].cuda(non_blocking=True)
        ...
```
- Forgetting `sampler.set_epoch(epoch)`: Without this, every epoch uses the same shuffle order, reducing effective data diversity.
- Using `shuffle=True` with `DistributedSampler`: The sampler handles shuffling. Setting `shuffle=True` in `DataLoader` alongside a sampler raises an error.
- Unequal last batch: If the dataset size is not divisible by `world_size * batch_size`, some ranks may get fewer samples. This causes a hang because AllReduce requires all ranks to participate. Solution: `drop_last=True` in both sampler and DataLoader.
- Data augmentation reproducibility: With `num_workers > 0`, each worker has its own random state. Use a `worker_init_fn` to seed workers deterministically:

  ```python
  def worker_init_fn(worker_id):
      seed = torch.initial_seed() % 2**32
      np.random.seed(seed + worker_id)
      random.seed(seed + worker_id)

  loader = DataLoader(dataset, ..., worker_init_fn=worker_init_fn)
  ```
Data Loading Performance Checklist
| Check | Why | How to Verify |
|---|---|---|
| `pin_memory=True` | Faster CPU-to-GPU transfer | Benchmark with/without |
| `.cuda(non_blocking=True)` | Async transfer, overlaps with compute | Profile timeline |
| `num_workers >= 4` | Parallel data preprocessing | GPU utilization > 90%? |
| `persistent_workers=True` | Avoid worker restart overhead | Measure epoch transition time |
| Preprocessed data on fast storage | Avoid CPU preprocessing bottleneck | Check worker CPU utilization |
| Dynamic padding (not fixed max) | Reduce wasted computation | Compare batch processing time |
| Bucketed sampling | Group similar-length sequences | Compare padding ratio |
| Pre-tokenized data | Avoid tokenizer overhead | Profile `__getitem__` |
| Memory-mapped files | Reduce RAM usage with many workers | Monitor RSS per worker |