Fast HuggingFace Dataset Loader with Stateful Checkpointing¶
Overview¶
fast_hf_loader_simple.py provides instant dataset loading and efficient mid-epoch checkpointing for large HuggingFace datasets.
Multi-node training? See the Dataset Server — a Uvicorn + FastAPI service that serves what
fast_load_iterable_datasetwould load locally, so you don't have to download the same dataset on every node. SetFORGATHER_DATASET_SERVER=http://host:8766and existing call sites route through it transparently.
Performance¶
Loading Speed: - First load: ~20 minutes (indexes Arrow files) - Subsequent loads: <1 second (instant) - Your dataset: 234 Arrow files, ~958GB
Checkpoint Resumption: - Without stateful: Hours to iterate through millions of examples - With stateful: <1 second to resume from any position
Usage¶
Basic Loading¶
from forgather.ml.datasets import fast_load_iterable_dataset
# Load dataset (instant after first time)
ids = fast_load_iterable_dataset(
"HuggingFaceTB/smollm-corpus",
"fineweb-edu-dedup",
split="train"
)
# Shuffle at both shard and example level
ids = ids.shuffle(seed=42, buffer_size=1000)
# Shard-level only (no example buffer)
ids = ids.shuffle(seed=42, buffer_size=0)
# Optional: Shard for DDP
ids = ids.shard(num_shards=world_size, index=rank)
# Lazy transformations
ids = ids.map(tokenize_function)
# Iterate
for example in ids:
# Training code
...
Virtual Splits (Train/Val/Test)¶
Create train/val/test splits without copying data:
from fast_hf_loader_simple import fast_load_iterable_dataset
# Load full dataset
ids = fast_load_iterable_dataset("dataset", "config", split="train")
# Create train/val/test splits (70%/15%/15%)
train_ds = ids.slice(None, 0.7) # First 70%
val_ds = ids.slice(0.7, 0.85) # Next 15%
test_ds = ids.slice(0.85, None) # Last 15%
# Or using percentage strings
train_ds = ids.slice(None, "80%") # First 80%
val_ds = ids.slice("80%", None) # Last 20%
# Or absolute indices
subset = ids.slice(1000, 2000) # Examples 1000-1999
# Or using HuggingFace-style select() with indices
subset = ids.select(range(1000, 2000)) # Examples 1000-1999
first_100 = ids.select(range(100)) # First 100 examples
# Combine with shuffling and sharding
train_ds = ids.shuffle(seed=42).slice(None, 0.8)
train_shard = train_ds.shard(num_shards=world_size, index=rank)
Benefits: - No data duplication - virtual slicing is zero-copy - Works with shuffling and sharding - Compatible with checkpoint resumption - Flexible indexing: percentages, absolute indices, or percentage strings
With StatefulDataLoader Checkpointing¶
from fast_hf_loader_simple import fast_load_iterable_dataset
from torchdata.stateful_dataloader import StatefulDataLoader
# Load and prepare dataset
ids = fast_load_iterable_dataset("dataset", "config", split="train")
ids = ids.shuffle(seed=42)
ids = ids.map(tokenize)
# Create dataloader (works with any num_workers)
dataloader = StatefulDataLoader(
ids,
batch_size=32,
num_workers=4 # Use workers to keep GPU fed
)
# Training loop
for step, batch in enumerate(dataloader):
# Forward/backward pass
loss = model(batch)
loss.backward()
optimizer.step()
# Save checkpoint periodically
if step % 1000 == 0:
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'dataloader': dataloader.state_dict(), # Includes dataset state
'step': step
}
torch.save(checkpoint, f'checkpoint_{step}.pt')
# Resume from checkpoint
checkpoint = torch.load('checkpoint_5000.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Create NEW dataloader and restore
ids = fast_load_iterable_dataset("dataset", "config", split="train")
ids = ids.shuffle(seed=42)
ids = ids.map(tokenize)
dataloader = StatefulDataLoader(ids, batch_size=32, num_workers=4)
dataloader.load_state_dict(checkpoint['dataloader'])
# Continue training from exact position
for step, batch in enumerate(dataloader, start=checkpoint['step']+1):
...
Features¶
✅ Instant Loading - First load indexes Arrow file paths - Subsequent loads are instant (<1 second) - No data copying - memory-maps HF cache
✅ Multi-Level Shuffling - Shard-level: Shuffles Arrow file order (234 files = 234 shards) - Example-level: Shuffle buffer for fine-grained randomization - Configurable buffer size (default: 1000 examples) - Combines both levels for better randomization
✅ Virtual Splits - Create train/val/test splits without copying data - Support percentage (0.8 or "80%") and absolute indexing - Zero-copy slicing for memory efficiency - Compatible with shuffling, sharding, and checkpointing
✅ Flexible Sharding for DDP - File-level sharding: Each rank gets different Arrow files (efficient) - Example-level sharding: Works even with more ranks than files - Auto mode automatically selects the best approach - No data duplication across ranks
✅ Stateful Checkpointing
- Implements state_dict()/load_state_dict() protocol
- Compatible with torchdata.stateful_dataloader.StatefulDataLoader
- Efficient resumption without iteration
- Works with any num_workers value (0, 1, 2, 4, 8, etc.)
✅ Multi-Worker Support - Workers prefetch batches in parallel with training - Keeps GPU fed during data loading - Each worker gets a disjoint subset of Arrow files - No data duplication across workers
How It Works¶
Indexing (First Load Only)¶
- Calls
load_dataset()to load dataset normally - Extracts Arrow file paths from HF cache
- Computes per-file example counts (for fast
len()later) - Saves file paths, lengths, and metadata version to index
- Returns
SimpleArrowIterableDataset
Instant Loading (Subsequent Loads)¶
- Reads index and checks metadata version
- If version mismatch, automatically reindexes (ensures compatibility)
- Creates
SimpleArrowIterableDatasetwith cached file lengths len()uses cached counts - no file access needed!- Memory-maps Arrow files on demand during iteration (zero-copy)
Checkpointing¶
- Tracks position as
(file_index, example_index) state_dict()returns current position (~1 KB)load_state_dict()restores position- Iterator skips directly to saved position (no iteration needed)
Shuffle Buffer¶
The shuffle buffer implements reservoir sampling for fine-grained example shuffling:
- Fill Phase: Fills buffer with first
buffer_sizeexamples - Streaming Phase: For each new example:
- Randomly selects a position in the buffer
- Yields the example at that position
- Replaces it with the new example
- Flush Phase: Shuffles and yields remaining buffer contents
Configuration:
# Default buffer size (1000 examples)
ids = ids.shuffle(seed=42, buffer_size=1000)
# Larger buffer for better randomization (uses more memory)
ids = ids.shuffle(seed=42, buffer_size=10000)
# Disable example-level shuffling (shard-level only)
ids = ids.shuffle(seed=42, buffer_size=0)
Memory Usage: - Buffer size = 1000 → ~few MB (depends on example size) - Buffer size = 10000 → ~tens of MB - Larger buffers provide better randomization at cost of memory
Checkpoint Behavior: When resuming from a checkpoint with shuffle buffer enabled, the shuffle pattern after resumption will differ from a non-interrupted run. However, randomness is maintained and the seed ensures reproducibility of the overall shuffle behavior.
Worker Distribution¶
With 234 Arrow files and num_workers=4:
- Worker 0: files 0, 4, 8, 12, ... (~59 files)
- Worker 1: files 1, 5, 9, 13, ... (~59 files)
- Worker 2: files 2, 6, 10, 14, ... (~58 files)
- Worker 3: files 3, 7, 11, 15, ... (~58 files)
Each worker independently tracks its position for checkpointing.
Flexible Sharding¶
The dataset supports two sharding modes for DDP training:
File-Level Sharding (default when num_shards ≤ num_files): - Most efficient - minimal overhead during iteration - Each shard gets a subset of Arrow files - Shard i gets files at indices i, i+num_shards, i+2*num_shards, ... - Example with 8 files and 2 shards: - Shard 0: files 0, 2, 4, 6 - Shard 1: files 1, 3, 5, 7
Example-Level Sharding (used when num_shards > num_files): - Works with any number of shards, even more than files - Divides total examples evenly across shards - Slightly more overhead during iteration (tracks global position) - Example with 1 file (1000 examples) and 2 shards: - Shard 0: examples 0-499 - Shard 1: examples 500-999
Auto Mode (recommended):
# Automatically selects best mode
ids = ids.shard(num_shards=world_size, index=rank, mode='auto')
# Auto uses file-level if num_shards <= num_files
# Auto uses example-level if num_shards > num_files
Explicit Mode Selection:
# Force file-level sharding
ids = ids.shard(num_shards=2, index=0, mode='file')
# Force example-level sharding
ids = ids.shard(num_shards=4, index=0, mode='example')
Use Cases: - Training on 2 GPUs with 1-file dataset → Example-level required - Training on 8 GPUs with 234-file dataset → File-level more efficient - Auto mode handles both cases automatically
Testing¶
Comprehensive tests verify correctness:
# Test instant loading and basic functionality
python fast_hf_loader_simple.py
# Test checkpoint functionality (num_workers=0, 1, 2)
python test_checkpoint_final.py
# Test specific checkpoint scenarios
python test_correct_comparison.py
All tests pass, confirming: - ✅ Instant loading works - ✅ Shuffling works - ✅ Virtual splits work (percentage, absolute, train/val/test) - ✅ File-level and example-level sharding work - ✅ Auto mode selects appropriate sharding - ✅ Multi-worker works - ✅ Checkpointing works with num_workers=0, 1, 2, etc. - ✅ Checkpointing works with both sharding modes and virtual splits
API Reference¶
fast_load_iterable_dataset(...)¶
Main loading function.
Parameters:
- path: Dataset path (e.g., "HuggingFaceTB/smollm-corpus")
- name: Config name (e.g., "fineweb-edu-dedup")
- split: Split to load (e.g., "train")
- data_files: Optional data files pattern
- revision: Optional git revision
- force_reindex: Force re-indexing (default: False)
- num_proc: Processes for initial indexing
- index_dir: Custom index directory
- **load_dataset_kwargs: Additional args for load_dataset()
Returns: SimpleArrowIterableDataset
SimpleArrowIterableDataset¶
Iterable dataset with checkpointing support.
Methods:
- .shuffle(seed=None, buffer_size=1000): Shuffle at shard and example level
- seed: Random seed for reproducibility
- buffer_size: Size of example shuffle buffer (0 to disable example-level shuffling)
- .select(indices): Select examples by indices (HuggingFace API compatible)
- indices: Range, list, iterable, ndarray, or Series of integer indices
- Supports only contiguous ranges (efficiently translates to .slice())
- Example: ds.select(range(100, 200)) selects examples 100-199
- .slice(start, end): Create virtual split (train/val/test)
- start: Start index - int, float (percentage), string ("80%"), or None
- end: End index - int, float (percentage), string ("80%"), or None
- Returns new dataset with virtual split applied
- .shard(num_shards, index, mode='auto'): Shard for DDP
- mode='auto': Auto-select file or example-level sharding
- mode='file': File-level sharding (efficient, requires num_shards ≤ num_files)
- mode='example': Example-level sharding (works with any num_shards)
- .map(function, batched=False): Lazy transformations
- .__len__(): Total examples (cached, accounts for splits and sharding)
- .state_dict(): Get checkpoint state
- .load_state_dict(state_dict): Restore from checkpoint
Recommendations¶
- Use num_workers≥1 for training
- Workers prefetch batches in parallel
- Prevents GPU starvation
-
Recommended:
num_workers=4for most cases -
Checkpoint frequently
- Every 1000-10000 steps
-
Keep multiple checkpoints
-
For DDP training
- Each rank loads the same dataset
- Use
.shard(num_shards=world_size, index=rank) -
Each rank gets different Arrow files
-
Test checkpoint resume early
- Verify resumption works in your setup
- Ensures you won't lose progress
Example: Full Training Script¶
See CHECKPOINT_GUIDE.md for complete training examples with:
- Basic checkpoint save/restore
- DDP training
- Multi-worker DataLoader
- Production-ready patterns
Architecture¶
- No copying: Memory-maps Arrow files from HF cache
- Cached file lengths: Per-file example counts cached in index for O(1)
len() - Metadata versioning: Automatic reindex when format changes
- No expensive operations: File access only during iteration, not for metadata
- Minimal state: Checkpoint is ~1 KB
- Standard protocol: Compatible with PyTorch ecosystem
Summary¶
This implementation provides: - 20 minutes → <1 second dataset loading - Hours → <1 second checkpoint resumption - Virtual splits for train/val/test without data copying - Flexible sharding (file-level or example-level) - Full DDP and multi-worker support - Production-ready for long-running training jobs - Compatible with standard PyTorch tools
Perfect for training on large datasets (100GB+) where: - Dataset loading is a bottleneck - Long training runs need checkpointing - Mid-epoch resumption is critical