dataset_and_loader
torch.utils.data.DataLoader Overview
Dataset
Dataset Types
- Map-style Datasets
  - Implement `__getitem__()` and `__len__()`.
  - Map indices/keys to data samples.
  - Example: `dataset[idx]` loads an image and its label from disk.
  - More details: Dataset
- Iterable-style Datasets
  - Subclass `IterableDataset` and implement `__iter__()`.
  - Suitable for sequential data fetching (e.g., streaming from databases, remote servers, logs).
  - Example: `iter(dataset)` reads a continuous data stream.
  - More details: IterableDataset
- Note on Multi-Process Loading with Iterable Datasets
  - Each worker process replicates the dataset object.
  - Replicas must be uniquely configured to avoid duplicated data.
  - See the IterableDataset documentation for details.
`IterableDataset` in PyTorch is designed for streaming data: it does not store data in memory as a list or any other indexable structure. Instead, it yields samples sequentially via an iterator.
Key Differences Between IterableDataset and Dataset
| Feature | Dataset (Map-Style) | IterableDataset (Streaming) |
|---|---|---|
| Access Method | `__getitem__(index)` | `__iter__()` (no indices) |
| Storage | Typically a list, array, or structured storage | No fixed storage; generates data dynamically |
| Indexable | Yes (`dataset[i]`) | No (data is streamed) |
| Supports `len()` | Yes | No (unless explicitly defined) |
| Shuffle Support | Yes (`shuffle=True`) | No (but manual shuffling is possible) |
| Use Case | Small to medium datasets that fit in memory | Large datasets, real-time data, online learning |
Example of How IterableDataset Streams Data
Unlike map-style datasets that store data in a list or other indexable structure, an IterableDataset typically reads data on-the-fly (e.g., from a file, database, or message queue).
Example 1: Reading a Large File Line-by-Line
```python
from torch.utils.data import IterableDataset, DataLoader

class TextFileDataset(IterableDataset):
    def __init__(self, filename):
        self.filename = filename  # No list, just a file reference

    def __iter__(self):
        with open(self.filename, "r") as file:
            for line in file:
                yield line.strip()  # Data is streamed, not stored

dataset = TextFileDataset("large_text_file.txt")
dataloader = DataLoader(dataset, batch_size=2)  # No shuffle, but still batches

for batch in dataloader:
    print(batch)  # Batches of streamed lines
```
Here:
- The dataset does not store the entire file in memory.
- It reads each line sequentially, making it memory efficient.
Example 2: Simulating an Endless Data Stream
```python
import random
import time

from torch.utils.data import IterableDataset, DataLoader

class RandomNumberStream(IterableDataset):
    def __iter__(self):
        while True:  # Infinite loop (useful for real-time data)
            yield random.randint(0, 100)
            time.sleep(0.5)  # Simulating real-time data arrival

dataset = RandomNumberStream()
dataloader = DataLoader(dataset, batch_size=5)  # Still batches!

for batch in dataloader:
    print(batch)
```
Here:
- The dataset continuously generates data instead of storing it.
- `batch_size=5` groups numbers, but no indexing or shuffling is involved.
How to Shuffle an IterableDataset?
Since `IterableDataset` does not support `shuffle=True`, you can use buffered shuffling:
```python
import random

from torch.utils.data import IterableDataset, DataLoader

class ShuffledStreamDataset(IterableDataset):
    def __init__(self, data_stream, buffer_size=10):
        self.data_stream = data_stream
        self.buffer_size = buffer_size

    def __iter__(self):
        buffer = []
        for item in self.data_stream():
            buffer.append(item)
            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)  # Shuffle in small chunks
                while buffer:
                    yield buffer.pop()  # Yield shuffled items
        random.shuffle(buffer)  # Final shuffle of any leftover items
        while buffer:
            yield buffer.pop()

# Example use case
def generate_numbers():
    for i in range(20):  # Simulating a stream
        yield i

dataset = ShuffledStreamDataset(generate_numbers)
dataloader = DataLoader(dataset, batch_size=4)

for batch in dataloader:
    print(batch)
```
- A buffer (`buffer_size`) temporarily stores data for shuffling.
- This allows some randomness while still supporting streaming.
DataLoader
- Provides an iterable over a dataset.
- Supports:
  - Map-style and iterable-style datasets
  - Custom data loading order
  - Automatic batching
  - Single- and multi-process data loading
  - Automatic memory pinning
- `DataLoader` Constructor Signature

```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
```
- Automatic Batching in `DataLoader`
  - Enabled by default when `batch_size` is set (default: `1`).
  - Combines individual samples into batched tensors, with one dimension representing the batch (usually the first).
  - `batch_size` and `drop_last` control batch construction.
  - For map-style datasets, users can provide a `batch_sampler` to yield lists of keys.
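The batching behavior above can be seen in a minimal sketch (the toy tensors are made up for illustration): samples are stacked along a new first dimension, and `drop_last=True` discards the incomplete final batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 6 samples of shape (2,); batch_size=4 with drop_last=True keeps one full batch
data = torch.arange(12, dtype=torch.float32).view(6, 2)
labels = torch.arange(6)
loader = DataLoader(TensorDataset(data, labels), batch_size=4, drop_last=True)

for x, y in loader:
    print(x.shape, y.shape)  # batch dimension first: torch.Size([4, 2]) torch.Size([4])
```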
- Batching Behavior Based on Dataset Type
  - Map-style datasets:
    - Uses a sampler (provided by the user or created based on `shuffle`).
    - Internally constructs a `batch_sampler` from the sampler.
  - Iterable-style datasets:
    - Uses a dummy infinite sampler.
    - `drop_last` removes incomplete batches for each worker's dataset replica in multi-processing.
- Collation of Data Samples
  - After fetching samples, `collate_fn` combines them into batches.
  - Example equivalences:
    - Map-style dataset batching:

```python
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
```

    - Iterable-style dataset batching:

```python
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
```

  - A custom `collate_fn` can be provided for special collation needs (e.g., padding sequential data to the max batch length).
- Notes
  - `batch_size` and `drop_last` help form a `batch_sampler` from the sampler.
  - More details on samplers and `collate_fn` are available in their respective sections.
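As a small illustration of how the default `collate_fn` preserves data structure, a list of dict samples (a plain Python list acts as a map-style dataset) is batched field by field; the samples here are invented for the example:

```python
import torch
from torch.utils.data import DataLoader

# Dict samples: the default collate_fn batches each field while keeping the dict shape
samples = [{"x": torch.tensor([float(i)]), "y": i} for i in range(4)]
loader = DataLoader(samples, batch_size=2)

batch = next(iter(loader))
print(batch["x"].shape)  # torch.Size([2, 1]) — tensors gain a batch dimension
print(batch["y"])        # tensor([0, 1]) — Python numbers become a tensor
```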
Disabling Automatic Batching
- Useful when:
  - Batching is handled manually in the dataset.
  - Bulk reads (e.g., database queries, continuous memory chunks) are more efficient.
  - Batch size depends on the data.
  - The program operates on individual samples.
- How to disable it:
  - Set both `batch_size=None` and `batch_sampler=None` (the latter is already the default).
  - Each sample is processed with `collate_fn`.
- Behavior when disabled:
  - Default `collate_fn` converts NumPy arrays to PyTorch Tensors but leaves other types unchanged.
  - Equivalent processing:
    - Map-style dataset:

```python
for index in sampler:
    yield collate_fn(dataset[index])
```

    - Iterable-style dataset:

```python
for data in iter(dataset):
    yield collate_fn(data)
```
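A minimal sketch of disabled automatic batching, using a hypothetical list of tensors as the dataset: with `batch_size=None`, the loader yields individual samples without adding a batch dimension.

```python
import torch
from torch.utils.data import DataLoader

data = [torch.tensor([i, i + 1]) for i in range(3)]

# batch_size=None disables automatic batching: samples pass through one by one
loader = DataLoader(data, batch_size=None)
for sample in loader:
    print(sample.shape)  # torch.Size([2]) — no batch dimension added
```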
Working with collate_fn
- When automatic batching is disabled:
  - `collate_fn` is applied to each sample before yielding.
  - Default `collate_fn`:
    - Converts NumPy arrays to PyTorch Tensors.
    - Leaves other data types unchanged.
- When automatic batching is enabled:
  - `collate_fn` is applied to a list of samples at each iteration.
  - Default `collate_fn` (`default_collate()`):
    - Adds a batch dimension.
    - Converts NumPy arrays and Python numerical values to PyTorch Tensors.
    - Maintains data structure (e.g., dictionaries, lists, namedtuples).
    - Outputs Tensors when possible, otherwise keeps lists.
- Custom `collate_fn` Use Cases
  - Custom batching (e.g., using a different batch dimension).
  - Padding sequences of varying lengths.
  - Supporting non-standard data types.
- Debugging Unexpected Data Shapes/Types
  - If `DataLoader` outputs have unexpected dimensions or types, verify the `collate_fn`.
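One common custom `collate_fn` is padding variable-length sequences with `torch.nn.utils.rnn.pad_sequence`; the sequences below are made up for illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length sequences
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

def pad_collate(batch):
    # Pad each sequence in the batch to the length of the longest one
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = DataLoader(sequences, batch_size=3, collate_fn=pad_collate)
batch = next(iter(loader))
print(batch)  # rows padded with zeros to a common length of 3
```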
Single-Process Data Loading (Default)
- Data fetching happens in the same process as `DataLoader`.
- May block computation but avoids the overhead of multiprocessing.
- Suitable when:
  - Resources for process sharing (e.g., shared memory, file descriptors) are limited.
  - The dataset is small and fits in memory.
  - Debugging is needed (produces more readable error traces).
Multi-Process Data Loading
- Enabled by setting `num_workers > 0`.
- Spawns `num_workers` processes, each handling data loading independently.
- Potential memory issue:
  - Worker processes consume as much CPU memory as the parent process for all Python objects accessed.
  - Large datasets (e.g., many filenames stored in memory) combined with multiple workers can lead to high memory usage (roughly `num_workers * parent process size`).
  - Workaround: use non-reference-counted data formats like Pandas, NumPy, or PyArrow.
- How Multi-Process Data Loading Works
  - When iterating over `DataLoader`, `num_workers` worker processes are spawned.
  - `dataset`, `collate_fn`, and `worker_init_fn` are sent to each worker.
  - Map-style datasets:
    - The main process generates indices via the `sampler` and assigns them to workers.
    - Any `shuffle` randomization occurs in the main process.
  - Iterable-style datasets:
    - Each worker gets a replica of the dataset.
    - Naive multi-processing may cause duplicate data unless each replica is configured differently using:
      - `torch.utils.data.get_worker_info()`
      - `worker_init_fn`
    - `drop_last` removes incomplete batches per worker.
  - Workers are terminated after iteration completion or garbage collection.
- Getting Worker Information
  - `torch.utils.data.get_worker_info()` provides:
    - Worker ID
    - Dataset replica
    - Initial seed
  - Helps configure dataset replicas uniquely (e.g., sharding).
- CUDA Tensors in Multi-Process Loading
  - Returning CUDA tensors directly is not recommended due to multiprocessing complexities.
  - Alternative: use `pin_memory=True` for faster GPU transfers.
- Platform-Specific Behavior (Windows & macOS vs. Unix)
  - Unix (Linux): uses `fork()` (default).
    - Workers inherit the address space of the main process.
  - Windows/macOS: uses `spawn()`.
    - Launches a new interpreter and re-runs the main script.
    - To ensure compatibility:
      - Wrap the main script logic in `if __name__ == '__main__':` to prevent re-execution.
      - Define custom `collate_fn`, `worker_init_fn`, and dataset classes outside of `__main__`.
- Randomness in Multi-Process Loading
  - Each worker's PyTorch seed is set to `base_seed + worker_id`.
  - Other libraries (NumPy, random, etc.) may get identical seeds across workers, causing repeated random numbers.
  - Solution: use `torch.utils.data.get_worker_info().seed` or `torch.initial_seed()` inside `worker_init_fn` to manually seed other libraries.
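A sketch of such a `worker_init_fn` (the helper name `seed_worker` is an assumption, not a PyTorch API): it reads the worker-specific PyTorch seed via `torch.initial_seed()` and reuses it to seed NumPy and Python's `random` module.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # torch.initial_seed() is already worker-specific (base_seed + worker_id);
    # fold it into a 32-bit value and seed the other libraries with it
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
loader = DataLoader(dataset, num_workers=2, worker_init_fn=seed_worker)
```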
Memory Pinning for Faster GPU Transfers
- Host-to-GPU copies are faster when using pinned (page-locked) memory.
- Setting `pin_memory=True` in `DataLoader` automatically places fetched Tensors in pinned memory.
- Applies only to recognized data types:
  - Tensors
  - Maps and iterables containing Tensors
- Custom data types are not pinned by default unless they implement a `pin_memory()` method.
Enabling Memory Pinning for Custom Data Types
-
Define a
pin_memory()method in the custom batch class. -
Example:
class SimpleCustomBatch: def __init__(self, data): transposed_data = list(zip(*data)) self.inp = torch.stack(transposed_data[0], 0) self.tgt = torch.stack(transposed_data[1], 0) # Custom memory pinning method def pin_memory(self): self.inp = self.inp.pin_memory() self.tgt = self.tgt.pin_memory() return self def collate_wrapper(batch): return SimpleCustomBatch(batch) inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) dataset = TensorDataset(inps, tgts) loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True) for batch_ndx, sample in enumerate(loader): print(sample.inp.is_pinned()) # True print(sample.tgt.is_pinned()) # True
-
- Key Takeaways
  - Use `pin_memory=True` for efficient GPU data transfers.
  - Custom batch types require a `pin_memory()` method to enable pinning.
  - The `pin_memory()` method should be explicitly applied to all Tensors within the custom type.