DDP Trainer Example and Integration Tests
This project serves multiple purposes:
- Integration testing for DDP trainer features (distributed training, checkpointing, dataset strategies, gradient accumulation)
- Usage examples for configuring DDP training with different dataset patterns
- Performance benchmarking for data-parallel configurations
Quick Start
```bash
# 2-GPU training (default configuration)
forgather -t 2gpu.yaml train
# 4-GPU training with sharded dataset
forgather -t 4gpu.yaml train
# Single process (baseline comparison)
forgather -t single_processs.yaml train
# View training logs
forgather logs list
forgather logs summary --all --format one-line
forgather logs plot --loss-curves
```
Overview
The DDP (Distributed Data Parallel) trainer implements efficient multi-GPU training using PyTorch's DistributedDataParallel. It provides:
- Automatic gradient synchronization across GPUs
- Flexible dataset loading strategies (dispatch vs. sharding)
- Distributed checkpointing with full state restoration
- Gradient accumulation support
- Transparent scaling from single GPU to multi-node
Trainer Implementation: src/forgather/ml/trainer/ddp/ddp_trainer.py
Configuration Files
All configurations extend project.yaml, which sets up a small transformer model on the Tiny Stories dataset.
Basic DDP Configurations
| Config | Purpose | Devices | Batch Size | Dataset Strategy |
|---|---|---|---|---|
| 2gpu.yaml | Default 2-GPU setup | 2 | 64 (32×2) | Dispatched batches |
| 4gpu.yaml | 4-GPU training | 4 | 128 (32×4) | Sharded dataset |
| single_processs.yaml | Single-process baseline | 1 | 32 | N/A |
| 4cpu.yaml | CPU-based DDP | 4 CPUs | 64 (32×2) | Dispatched batches |
Dataset Strategy Configurations
| Config | Dataset Type | Sharding | Purpose |
|---|---|---|---|
| sharded_dataset.yaml | Standard (Arrow) | Per-rank sharding | Test dataset sharding on standard HF datasets |
| iterable_dataset.yaml | HF IterableDataset | None (dispatched) | Baseline iterable dataset performance |
| sharded_iterable_dataset.yaml | HF IterableDataset | Per-rank sharding | Sharded iterable dataset |
| sharded_fast_dataset.yaml | Fast Iterable | Per-rank sharding | Optimized fast iterable dataset |
Advanced Configurations
| Config | Feature | Description |
|---|---|---|
| grad_accum.yaml | Gradient Accumulation | Tests DDP with gradient_accumulation_steps: 4, smaller batch size |
| checkpoint_train.yaml | Checkpoint Save | Trains 500 steps with full checkpointing (model, optimizer, scheduler, dataset, RNG) |
| checkpoint_resume.yaml | Checkpoint Resume | Resumes from checkpoint and continues to 1000 steps |
Performance Analysis
Performance data from actual training runs on the Tiny Stories dataset (10% subset, ~212k samples):
Scaling Efficiency
| Configuration | GPUs | Steps | Time | Throughput | Speedup | Efficiency |
|---|---|---|---|---|---|---|
| Single Process | 1 | 6600 | 1:57 | 1,804 samples/s | 1.0× | 100% |
| 2-GPU (dispatched) | 2 | 3300 | 1:35 | 2,224 samples/s | 1.23× | 62% |
| 2-GPU (sharded) | 2 | 3300 | 1:15 | 2,818 samples/s | 1.56× | 78% |
| 4-GPU (sharded) | 4 | 1600 | 0:40 | 5,262 samples/s | 2.92× | 73% |
Key Findings:
- Sharded datasets outperform dispatched batches: ~27% higher throughput on 2 GPUs (2,818 vs. 2,224 samples/s)
- Good scaling to 4 GPUs: 2.92× speedup at 73% efficiency
- Scaling overhead: efficiency losses of ~22-38% relative to linear scaling, attributed to communication (gradient synchronization, plus batch broadcasting when batches are dispatched)
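The derived columns in the scaling table follow from two ratios. Speedup is throughput divided by the single-process baseline. Efficiency is that speedup divided by the device count, which is the ideal linear speedup. The sketch below recomputes both from the measured samples/s in the table. The helper names are illustrative only.

```python
# Derived columns of the scaling table, recomputed from measured throughput.
# All inputs are copied from the table above (samples/s).

BASELINE = 1804  # single process, samples/s


def speedup(throughput: float, baseline: float = BASELINE) -> float:
    """Throughput relative to the single-process baseline."""
    return throughput / baseline


def efficiency(throughput: float, devices: int, baseline: float = BASELINE) -> float:
    """Measured speedup divided by the ideal (linear) speedup, i.e. the device count."""
    return speedup(throughput, baseline) / devices


runs = {
    "2-GPU (dispatched)": (2224, 2),
    "2-GPU (sharded)": (2818, 2),
    "4-GPU (sharded)": (5262, 4),
}

for name, (throughput, devices) in runs.items():
    print(f"{name}: {speedup(throughput):.2f}x speedup, "
          f"{efficiency(throughput, devices):.0%} efficiency")

# Relative gain of sharding over dispatching on the same 2 GPUs
print(f"sharded vs. dispatched: {2818 / 2224 - 1:.0%} faster")
```

This prints 1.23x/62%, 1.56x/78% and 2.92x/73%, matching the table, followed by the 27% sharding gain.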
Dataset Loading Strategies
| Strategy | Throughput | Pros | Cons |
|---|---|---|---|
| Dispatched (default) | 2,224 samples/s | Simple, centralized state | Rank-0 bottleneck, higher latency |
| Sharded | 2,818 samples/s | Parallel loading, higher throughput | Requires explicit sharding, complex state |
| Iterable (HF) | 1,498 samples/s | Streaming, memory efficient | Slower due to HF overhead |
| Iterable (Fast) | 1,971 samples/s | Optimized streaming | Still slower than Arrow-based |
| Sharded Iterable | 1,977 samples/s | Streaming + parallel | Complex coordination |
Recommendation: Use sharded datasets for maximum throughput when the dataset supports it. Fall back to dispatched batches for simplicity when performance is not critical.
DDP Features
1. Batch Dispatching (dispatch_batches)
How it works: Rank 0 loads and preprocesses all batches, then dispatches them to other ranks via torch.distributed.
When to use:
- ✅ Dataset doesn't support sharding
- ✅ Simple checkpoint management (single global state)
- ✅ Smaller datasets where rank 0 isn't a bottleneck

Trade-offs:
- ❌ Rank 0 becomes the data-loading bottleneck
- ❌ Higher communication overhead (every batch is sent from rank 0)
- ✅ Easier to debug and reason about
2. Dataset Sharding
How it works: Each rank loads its own shard of the dataset independently.
```yaml
[dataset_project]
shard_dataset: True
[trainer_args]
dispatch_batches: False  # Must disable dispatching
```
When to use:
- ✅ Large datasets where loading is a bottleneck
- ✅ Dataset supports clean sharding (e.g., Arrow-based HF datasets)
- ✅ Maximum throughput is critical

Trade-offs:
- ✅ ~27% higher throughput than dispatched batches on 2 GPUs in the runs above
- ✅ Parallel preprocessing on all ranks
- ❌ Must ensure each rank gets different examples
- ❌ More complex checkpoint state management (per-rank dataset state)
3. Gradient Accumulation
How it works: Accumulate gradients over multiple micro-batches before synchronizing.
When to use:
- ✅ Large models where the full batch doesn't fit in memory
- ✅ An effective batch size larger than what fits on one device
- ✅ Reduce communication frequency (gradients sync once every N micro-batches)

Trade-offs:
- ✅ Enables training larger models at the same effective batch size
- ✅ Reduces gradient sync overhead
- ❌ Slightly slower, since the same samples are processed in more, smaller forward/backward passes
- ❌ Accumulation steps must be tuned together with the per-device batch size
4. Distributed Checkpointing
Full checkpoint support with automatic state management:
```yaml
[trainer_args]
save_strategy: "steps"
save_steps: 200
save_optimizer_state: True
save_scheduler_state: True
save_dataset_state: True  # Critical for reproducibility
save_rng_state: True      # Ensures identical randomness
```
Checkpoint contents:
- Model weights: REPLICATED (synchronized across all ranks via DDP)
- Optimizer state: REPLICATED (identical on all ranks because gradients are synchronized)
- Dataset state: GLOBAL (dispatched) or PER_RANK (sharded)
- RNG state: PER_RANK (each rank needs different random numbers)
Resume behavior:
```bash
# Train with checkpointing
forgather -t checkpoint_train.yaml train   # Saves at step 200, 400
# Resume and continue
forgather -t checkpoint_resume.yaml train  # Resumes from step 400, continues to 1000
```
See docs/checkpointing/user_guide.md for detailed documentation.
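The step arithmetic behind the two commands above can be sketched as follows. The sketch assumes a steps save strategy that writes only at multiples of save_steps, which is consistent with the saves at steps 200 and 400. Any extra final save is not modeled. `save_points` and `resume_plan` are hypothetical helpers. The key consequence is that the newest checkpoint is at step 400, so steps 401-500 of the first run are trained again after resuming. Restoring dataset and RNG state is what lets those steps see the same data and randomness.

```python
from typing import Dict, List, Union


def save_points(max_steps: int, save_steps: int) -> List[int]:
    """Steps at which a 'steps' save strategy writes a checkpoint (no final save modeled)."""
    return list(range(save_steps, max_steps + 1, save_steps))


def resume_plan(first_run_steps: int, save_steps: int,
                target_steps: int) -> Dict[str, Union[int, List[int]]]:
    saved = save_points(first_run_steps, save_steps)
    resume_from = saved[-1] if saved else 0  # newest checkpoint, or a fresh start
    return {
        "saved": saved,
        "resume_from": resume_from,
        "repeated_steps": first_run_steps - resume_from,  # trained, then redone after resume
        "new_steps": target_steps - resume_from,
    }


print(resume_plan(first_run_steps=500, save_steps=200, target_steps=1000))
# {'saved': [200, 400], 'resume_from': 400, 'repeated_steps': 100, 'new_steps': 600}
```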
Usage Examples
Basic Multi-GPU Training
```bash
# 2-GPU training with batch dispatching
forgather -t 2gpu.yaml train
# 4-GPU training with dataset sharding (faster)
forgather -t 4gpu.yaml train -d 0,1,2,3
# View results
forgather logs summary
forgather logs plot --loss-curves -e
```
Compare Dataset Strategies
```bash
# Train with different strategies
forgather -t sharded_dataset.yaml train
forgather -t iterable_dataset.yaml train
forgather -t sharded_fast_dataset.yaml train
# Compare performance
forgather logs summary --all --format one-line
# Visualize comparison
forgather logs plot --compare \
    output_models/default_model/runs/sharded_*/trainer_logs.json \
    output_models/default_model/runs/iterable_*/trainer_logs.json \
    --loss-curves -e
```
Test Checkpoint Functionality
```bash
# Initial training with checkpointing
forgather -t checkpoint_train.yaml train
# Verify checkpoints created
ls -lh output_models/default_model/checkpoint-*
# Resume from checkpoint
forgather -t checkpoint_resume.yaml train
# Verify training continued from correct step
forgather logs summary output_models/default_model/runs/checkpoint_*/trainer_logs.json
```
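When checking a resume by hand, it helps to pick the newest checkpoint by step number rather than by name. This minimal sketch assumes that checkpoints are directories named `checkpoint-<step>`, as the `ls` glob above suggests. `latest_checkpoint` is a hypothetical helper and not part of the forgather CLI.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumed naming convention, matching the checkpoint-* glob above.
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint-<step> directory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best: Optional[Path] = None
    best_step = -1
    for path in root.iterdir():
        match = CHECKPOINT_RE.match(path.name)
        if match and path.is_dir():
            step = int(match.group(1))
            if step > best_step:
                best, best_step = path, step
    return best


# Demo: numeric comparison vs. a naive lexicographic max over directory names.
with tempfile.TemporaryDirectory() as tmp:
    for step in (200, 400, 1000):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    found = latest_checkpoint(tmp).name
    naive = max(p.name for p in Path(tmp).iterdir())
print(found)  # checkpoint-1000
print(naive)  # checkpoint-400: string order ranks "4" above "1"
```

Comparing steps numerically avoids the trap shown in the demo: once a run passes step 999, plain string sorting no longer finds the newest checkpoint.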
Gradient Accumulation
```bash
# Train with gradient accumulation
forgather -t grad_accum.yaml train
# Compare memory usage and throughput
# (Note: requires additional memory profiling)
```
Architecture Details
DDP Trainer Class Hierarchy
```text
Trainer (base)
└── DDPTrainer
    ├── Wraps model with DDP
    ├── Handles gradient synchronization
    ├── Manages distributed checkpoints
    └── Optional: DataloaderDispatcher for batch broadcasting
```
Key Methods
- _init_distributed(): Initialize the device mesh and DDP process group
- _wrap(): Wrap the model in DDP; optionally wrap dataloaders with the dispatcher
- unwrapped_model(): Access the original model (stored in model.module)
- _distributed_loss(): All-reduce the loss across ranks for logging
- _forward_backward_step(): Skip gradient sync during accumulation steps using no_sync()
- get_state_components(): Define checkpoint sharing patterns:
  - Model/Optimizer: REPLICATED (DDP synchronized)
  - Dataset: GLOBAL (dispatched) or PER_RANK (sharded)
  - RNG: PER_RANK (different per rank)
Checkpoint State Patterns
```python
# From get_state_components()
components = [
    StateComponent(
        key="model",
        sharing_pattern=SharingPattern.REPLICATED,  # DDP synced
        validate_replication=True,  # Catch sync bugs
    ),
    StateComponent(
        key="dataset",
        sharing_pattern=(
            SharingPattern.GLOBAL if dispatch_batches
            else SharingPattern.PER_RANK
        ),
    ),
    StateComponent(
        key="rng",
        sharing_pattern=SharingPattern.PER_RANK,  # Different per rank
    ),
]
```
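The patterns above determine which ranks write which state. That effect can be modeled without PyTorch. In this self-contained sketch, `Pattern` and `Component` are hypothetical stand-ins for SharingPattern and StateComponent. The sketch assumes that shared state (REPLICATED or GLOBAL) is written once by rank 0 and that PER_RANK state is written by every rank. This is an assumption about the layout, not Forgather's documented behavior. `check_replicated` shows what a validate_replication-style check amounts to.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Sequence


class Pattern(Enum):
    """Hypothetical stand-in for Forgather's SharingPattern."""
    REPLICATED = "replicated"  # identical on every rank, so one copy suffices
    GLOBAL = "global"          # a single logical state (e.g. rank 0's dispatcher)
    PER_RANK = "per_rank"      # differs per rank; each rank keeps its own copy


@dataclass
class Component:
    """Hypothetical stand-in for StateComponent."""
    key: str
    pattern: Pattern


def components_for(dispatch_batches: bool) -> List[Component]:
    # Mirrors the patterns listed for get_state_components()
    return [
        Component("model", Pattern.REPLICATED),
        Component("optimizer", Pattern.REPLICATED),
        Component("dataset", Pattern.GLOBAL if dispatch_batches else Pattern.PER_RANK),
        Component("rng", Pattern.PER_RANK),
    ]


def save_plan(components: List[Component], world_size: int) -> Dict[str, List[int]]:
    """Ranks that write each component (assumed: shared state is written by rank 0 only)."""
    return {
        c.key: list(range(world_size)) if c.pattern is Pattern.PER_RANK else [0]
        for c in components
    }


def check_replicated(values_by_rank: Sequence[object]) -> bool:
    """The idea behind validate_replication: every rank must hold the same value."""
    return all(v == values_by_rank[0] for v in values_by_rank)


print(save_plan(components_for(dispatch_batches=True), world_size=2))
# {'model': [0], 'optimizer': [0], 'dataset': [0], 'rng': [0, 1]}
print(save_plan(components_for(dispatch_batches=False), world_size=2))
# {'model': [0], 'optimizer': [0], 'dataset': [0, 1], 'rng': [0, 1]}
```

Switching dispatch_batches only changes the dataset entry. This is why a run with sharding has more per-rank state to save and restore.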
Integration Tests
This project validates:
- ✅ DDP initialization and model wrapping
- ✅ Gradient synchronization across ranks
- ✅ Batch dispatching via DataloaderDispatcher
- ✅ Dataset sharding with proper rank assignment
- ✅ Checkpoint save with REPLICATED/GLOBAL/PER_RANK patterns
- ✅ Checkpoint resume with full state restoration
- ✅ Gradient accumulation with selective sync
- ✅ Single-process fallback (world_size == 1)
- ✅ CPU-based distributed training
- ✅ Iterable dataset integration
- ✅ Fast dataset loader performance
Troubleshooting
Common Issues
1. Slow training with dispatched batches
```yaml
# Solution: Use sharded dataset
[dataset_project]
shard_dataset: True
[trainer_args]
dispatch_batches: False
```
2. Different data on each rank
```yaml
# Problem: Sharded dataset but dispatch_batches=True
# Solution: Ensure consistency
dispatch_batches: False  # When using sharding
```
3. Checkpoint resume fails
4. Out of memory on GPUs
```yaml
# Use gradient accumulation
gradient_accumulation_steps: 4
per_device_train_batch_size: 8  # Reduce from 32
```
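Issues 2 and 4 come down to a few settings that are easy to check before launching. The `preflight` helper below is hypothetical, not a Forgather API. It rejects the sharding/dispatching mismatch from issue 2 and computes the effective batch size. The output shows that the issue 4 fix (8 per device × 4 accumulation steps) keeps the same 64-sample global batch as 2gpu.yaml on 2 GPUs.

```python
def preflight(shard_dataset: bool, dispatch_batches: bool, per_device_batch: int,
              accum_steps: int, world_size: int) -> int:
    """Hypothetical pre-launch check; returns the global effective batch size."""
    if shard_dataset and dispatch_batches:
        # Issue 2: sharded data must not also be dispatched from rank 0
        raise ValueError("shard_dataset=True requires dispatch_batches=False")
    if per_device_batch < 1 or accum_steps < 1 or world_size < 1:
        raise ValueError("batch size, accumulation steps and world size must be >= 1")
    # Samples contributing to each optimizer step, summed over all ranks
    return per_device_batch * accum_steps * world_size


print(preflight(False, True, per_device_batch=32, accum_steps=1, world_size=2))  # 64
print(preflight(True, False, per_device_batch=8, accum_steps=4, world_size=2))   # 64
```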
Debugging
```bash
# Check distributed environment
forgather -t 2gpu.yaml train --dry-run
# View detailed logs
tail -f output_models/default_model/runs/*/trainer_logs.json
# Analyze performance
forgather logs summary --all --format one-line
forgather logs plot --metrics "loss,grad_norm,learning_rate" -e
```
Further Reading
- DDP Implementation: src/forgather/ml/trainer/ddp/ddp_trainer.py
- Checkpointing Guide: docs/checkpointing/user_guide.md
- Fast Dataset Loader: docs/datasets/fast-hf-loader.md
- Dataset Sharding: template examples in templates/configs/
- PyTorch DDP: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
Performance Tips
- Use sharded datasets for multi-GPU runs: ~27% faster than dispatched batches on 2 GPUs in the runs above
- Enable gradient checkpointing if memory-bound (in the model config)
- Tune batch size: larger batches generally improve GPU utilization, up to the memory limit
- Monitor gradient norms: track them with forgather logs plot --metrics "grad_norm"
- Profile first: use the single-process baseline to identify bottlenecks
- Use fast iterables when streaming: ~30% faster than standard HF iterables (1,971 vs. 1,498 samples/s)
Contributing
When adding new DDP features:
1. Add configuration in templates/configs/
2. Run training and save logs
3. Update this README with performance data
4. Add integration test validation
Use the log analysis tools to gather metrics: