Migration Guide: Distributed Checkpoint Abstraction

Audience: Developers implementing custom trainers or extending existing trainers

Status: Phase 3 Complete - All built-in trainers migrated

Overview

This guide explains how to implement the new state-centric checkpoint API for custom trainers. All built-in Forgather trainers now use this system. The new API provides better support for hybrid parallelism and makes checkpoint semantics explicit.

Why Use the New API?

The new checkpoint system offers several advantages:

  1. Explicit semantics: No guessing which ranks should save what
  2. Hybrid parallelism support: Easily express complex DP/MP/PP combinations
  3. Dynamic patterns: Runtime determination of sharing (e.g., dataset state)
  4. Validation: Built-in replication validation and manifest checking
  5. Debugging: Complete checkpoint inventory via manifest
  6. Automatic coordination: No manual rank checks needed
  7. Production-ready: All built-in trainers tested and working

When to Implement

You need to implement the new checkpoint API when:

  1. Creating a custom trainer - Inherit from BaseTrainer and override get_state_components()
  2. Adding a new parallelism strategy - Use appropriate SharingPattern for your components
  3. Extending existing trainers - Follow the pattern from similar built-in trainers

Quick Start

Step 1: Import New Types

from forgather.ml.trainer.checkpoint_types import (
    StateComponent,
    SharingPattern,
)
from forgather.ml.trainer.checkpoint_manager import RNGState

Step 2: Implement get_state_components()

Replace the old get_statefuls_for_save() method with get_state_components():

Before (Legacy API):

class MyTrainer:
    def get_statefuls_for_save(self) -> Dict[str, Stateful]:
        statefuls = {}
        if self.args.save_optimizer_state:
            statefuls["optimizer"] = self.optimizer
        if self.args.save_scheduler_state:
            statefuls["scheduler"] = self.lr_scheduler
        # ... etc
        return statefuls

    def get_statefuls_for_load(self) -> Dict[str, Stateful]:
        # Similar to above
        ...

After (New API):

class MyTrainer:
    def get_state_components(self) -> List[StateComponent]:
        return [
            StateComponent(
                key="model",
                stateful=self.model,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,  # Optional - can be skipped by deleting file
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=self._get_dataset_sharing_pattern(),
                required=False,
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            ),
        ]
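Any object passed as stateful must be able to hand over its state on save and accept it back on load. The following is a minimal sketch, assuming the Stateful protocol follows the usual PyTorch convention of state_dict() and load_state_dict(). StepCounter is a hypothetical class used only for illustration and is not part of Forgather:

```python
from typing import Any, Dict


class StepCounter:
    """Hypothetical stateful object that tracks how many steps have run."""

    def __init__(self) -> None:
        self.steps = 0

    def advance(self, n: int = 1) -> None:
        self.steps += n

    def state_dict(self) -> Dict[str, Any]:
        # Return only plain, serializable values.
        return {"steps": self.steps}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.steps = int(state_dict["steps"])


# Round trip: capture state from one instance, restore it into another.
counter = StepCounter()
counter.advance(3)
saved = counter.state_dict()

restored = StepCounter()
restored.load_state_dict(saved)
```

An object like this could then be wrapped in a StateComponent with whichever SharingPattern matches how it is distributed across ranks.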

Step 3: Handle Dynamic Patterns (Optional)

For components with runtime-determined sharing (like datasets):

def _get_dataset_sharing_pattern(self) -> SharingPattern:
    """Determine dataset state sharing pattern based on dataloader type."""
    if isinstance(self.train_dataloader, DataloaderDispatcher):
        # Dispatcher coordinates data loading
        if self.train_dataloader._dp_size == 1:
            # Pure MP mode: all ranks get same batch, rank 0 loads
            return SharingPattern.GLOBAL
        elif self.train_dataloader._mp_size == 1:
            # Pure DP mode: rank 0 loads and dispatches
            return SharingPattern.GLOBAL
        else:
            # Hybrid: each DP group has one loader
            return SharingPattern.PER_GROUP
    else:
        # Each rank has independent dataloader
        return SharingPattern.PER_RANK
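The decision logic above can also be written as a pure function, which makes it easy to unit test without constructing a dataloader. This is only a sketch: dataset_pattern_name and its parameters are hypothetical, and it returns pattern names as strings so that it runs without Forgather installed:

```python
def dataset_pattern_name(uses_dispatcher: bool, dp_size: int, mp_size: int) -> str:
    """Return the name of the sharing pattern chosen by the rules above."""
    if not uses_dispatcher:
        # Each rank owns an independent dataloader.
        return "PER_RANK"
    if dp_size == 1 or mp_size == 1:
        # Pure MP or pure DP: a single loader serves every rank.
        return "GLOBAL"
    # Hybrid: one loader per group.
    return "PER_GROUP"


# Truth table for the four situations described above.
cases = {
    "independent loaders": dataset_pattern_name(False, dp_size=4, mp_size=1),
    "pure MP": dataset_pattern_name(True, dp_size=1, mp_size=4),
    "pure DP": dataset_pattern_name(True, dp_size=4, mp_size=1),
    "hybrid": dataset_pattern_name(True, dp_size=2, mp_size=2),
}
```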

Pattern Selection Guide

Choose the right SharingPattern for each component:

GLOBAL

Use when: Only one copy exists across all ranks

Examples:

  • Training progress when using centralized data dispatch
  • Global metrics/counters

Save behavior: Rank 0 saves

Load behavior: All ranks load same file

StateComponent(
    key="trainer_state",
    stateful=self.state,
    sharing_pattern=SharingPattern.GLOBAL,
)

PER_RANK

Use when: Each rank has unique state

Examples:

  • RNG state (each rank needs different random numbers)
  • Pipeline stage parameters (different stage per rank)
  • Rank-specific optimizer state (when optimizing different parameters)

Save behavior: Every rank saves its own file

Load behavior: Each rank loads its specific file

StateComponent(
    key="rng",
    stateful=RNGState(),
    sharing_pattern=SharingPattern.PER_RANK,
)
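To illustrate why RNG state is restored per rank, here is a small sketch of a stateful wrapper around Python's built-in random module. PythonRNGState is a hypothetical stand-in. Forgather's RNGState is not shown in this guide and presumably covers more generators than this:

```python
import random
from typing import Any, Dict


class PythonRNGState:
    """Hypothetical stateful wrapper around Python's global `random` state."""

    def state_dict(self) -> Dict[str, Any]:
        return {"python": random.getstate()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        random.setstate(state_dict["python"])


random.seed(1234)
rng = PythonRNGState()

# Snapshot the generator, draw some numbers, then rewind and draw again.
snapshot = rng.state_dict()
first = [random.random() for _ in range(3)]

rng.load_state_dict(snapshot)
replayed = [random.random() for _ in range(3)]
```

Restoring the snapshot replays exactly the same sequence. If every rank loaded rank 0's snapshot, all ranks would replay rank 0's sequence, which is why this component uses PER_RANK.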

REPLICATED

Use when: State is identical across all ranks

Examples:

  • DDP model weights (synchronized by DDP)
  • DDP optimizer state (synchronized by DDP)
  • LR scheduler state (same schedule across all ranks)

Save behavior: Rank 0 saves (avoids redundancy)

Load behavior: All ranks load same file

Optional: Validate that all ranks actually have identical state

StateComponent(
    key="model",
    stateful=self.unwrapped_model(),
    sharing_pattern=SharingPattern.REPLICATED,
    validate_replication=True,  # Verify DDP synchronization
)
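This guide does not describe how validate_replication is implemented. One plausible approach, sketched below under that caveat, is to hash each rank's state and compare the digests. The helpers state_digest and mismatched_ranks are hypothetical, and the per-rank states are simulated in a list rather than gathered over a real process group:

```python
import hashlib
import pickle
from typing import Any, Dict, List


def state_digest(state: Dict[str, Any]) -> str:
    """Hash a state dict deterministically by visiting keys in sorted order."""
    hasher = hashlib.sha256()
    for key in sorted(state):
        hasher.update(key.encode("utf-8"))
        hasher.update(pickle.dumps(state[key]))
    return hasher.hexdigest()


def mismatched_ranks(digests: List[str]) -> List[int]:
    """Return the ranks whose digest differs from rank 0's."""
    return [rank for rank, digest in enumerate(digests) if digest != digests[0]]


# Simulated state on three ranks; rank 2 has drifted out of sync.
rank_states = [
    {"weight": [1.0, 2.0], "bias": [0.5]},
    {"weight": [1.0, 2.0], "bias": [0.5]},
    {"weight": [1.0, 2.5], "bias": [0.5]},
]
digests = [state_digest(state) for state in rank_states]
bad = mismatched_ranks(digests)
```

In a real job the digests would have to be exchanged between ranks, for example with a collective such as torch.distributed.all_gather_object. Comparing short digests rather than full tensors keeps the check cheap.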

PER_GROUP

Use when: State is shared within process groups, different across groups

Examples:

  • Model shard shared within DP group but different across PP stages
  • Dataset state shared within DP group
  • Optimizer state for grouped parameters

Save behavior: One rank per group saves

Load behavior: Ranks load based on group membership

StateComponent(
    key="model_shard",
    stateful=self.model_shard,
    sharing_pattern=SharingPattern.PER_GROUP,
    process_group_name="dp_group",
)

Note: Also implement get_process_groups():

def get_process_groups(self) -> Dict[str, ProcessGroup]:
    return {
        "dp_group": self.dp_process_group,
        "pp_group": self.pp_process_group,
    }
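How ranks map onto DP and PP groups depends on the launcher and mesh layout, which this guide does not specify. The sketch below assumes one common layout, rank = dp_index * pp_size + pp_index, purely for illustration. build_groups is a hypothetical helper that only enumerates rank lists and does not create real process groups:

```python
from typing import Dict, List


def build_groups(world_size: int, pp_size: int) -> Dict[str, List[List[int]]]:
    """Enumerate DP and PP rank groups for the assumed layout.

    Assumption: rank = dp_index * pp_size + pp_index.
    """
    if world_size % pp_size != 0:
        raise ValueError("world_size must be divisible by pp_size")
    dp_size = world_size // pp_size

    # One PP group per replica: the consecutive stages of a single pipeline.
    pp_groups = [
        [dp * pp_size + pp for pp in range(pp_size)] for dp in range(dp_size)
    ]
    # One DP group per stage: the ranks holding the same stage in every replica.
    dp_groups = [
        [dp * pp_size + pp for dp in range(dp_size)] for pp in range(pp_size)
    ]
    return {"pp_groups": pp_groups, "dp_groups": dp_groups}


groups = build_groups(world_size=8, pp_size=4)
```

With 8 ranks and 4 pipeline stages, this yields two PP groups of four ranks each and four DP groups of two ranks each. Rank lists like these are what would typically be passed to torch.distributed.new_group(ranks=...) to obtain the ProcessGroup objects returned by get_process_groups().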

PER_NODE

Use when: State is local to each node

Examples:

  • Node-local caches
  • Node-specific resources

Save behavior: Local rank 0 on each node saves

Load behavior: Ranks load based on node membership

StateComponent(
    key="node_cache",
    stateful=self.cache,
    sharing_pattern=SharingPattern.PER_NODE,
)
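The save behaviors of the five patterns above can be summarized in a single predicate. This is a sketch, not the coordinator's actual logic: Pattern is a stand-in enum so the example runs on its own, should_save is a hypothetical helper, and it assumes the saving rank within a group is the one with group rank 0 (the guide only says "one rank per group"):

```python
from enum import Enum


class Pattern(Enum):
    """Stand-in for SharingPattern so the sketch is self-contained."""

    GLOBAL = "global"
    PER_RANK = "per_rank"
    REPLICATED = "replicated"
    PER_GROUP = "per_group"
    PER_NODE = "per_node"


def should_save(pattern: Pattern, rank: int, group_rank: int = 0, local_rank: int = 0) -> bool:
    """Decide whether the calling rank writes a file for a component."""
    if pattern in (Pattern.GLOBAL, Pattern.REPLICATED):
        return rank == 0          # rank 0 saves the single shared copy
    if pattern is Pattern.PER_RANK:
        return True               # every rank saves its own file
    if pattern is Pattern.PER_GROUP:
        return group_rank == 0    # assumed: first rank of each group saves
    if pattern is Pattern.PER_NODE:
        return local_rank == 0    # local rank 0 on each node saves
    raise ValueError(f"unknown pattern: {pattern}")


# Four ranks on two nodes (two ranks per node), in two groups of two.
ranks = [
    {"rank": r, "group_rank": r % 2, "local_rank": r % 2} for r in range(4)
]
savers = {
    p.name: [x["rank"] for x in ranks if should_save(p, **x)] for p in Pattern
}
```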

Complete Migration Examples

Example 1: Single-GPU Trainer

class SimpleTrainer(BaseTrainer):
    def get_state_components(self) -> List[StateComponent]:
        """All state is GLOBAL in single-GPU setting."""
        return [
            StateComponent(
                key="model",
                stateful=self.model,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,  # Optional - can be skipped by deleting file
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            ),
            StateComponent(
                key="trainer",
                stateful=self,  # TrainerState
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            ),
        ]

Example 2: DDP Trainer

class DDPTrainer(BaseTrainer):
    def get_state_components(self) -> List[StateComponent]:
        """DDP synchronizes weights, so use REPLICATED pattern."""
        return [
            StateComponent(
                key="model",
                stateful=self.unwrapped_model(),
                sharing_pattern=SharingPattern.REPLICATED,
                validate_replication=True,  # Catch DDP sync bugs
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,  # Optional - can be skipped by deleting file
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=self._get_dataset_sharing_pattern(),
                required=False,
            ),
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            ),
        ]

Example 3: Pipeline Parallel Trainer

class PipelineTrainer(BaseTrainer):
    def get_state_components(self) -> List[StateComponent]:
        """Each rank has different pipeline stage."""
        return [
            StateComponent(
                key="model",
                stateful=self.pipeline_modules,  # Different per rank
                sharing_pattern=SharingPattern.PER_RANK,
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,  # Different parameters per rank
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,  # Optional - can be skipped by deleting file
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,  # Same schedule across ranks
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,  # Centralized loading
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            ),
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            ),
        ]

Example 4: Hybrid DDP x Pipeline

class HybridDDPPipelineTrainer(BaseTrainer):
    def get_state_components(self) -> List[StateComponent]:
        """Hybrid parallelism: DP groups with PP stages."""
        return [
            StateComponent(
                key="model",
                stateful=self.pipeline_modules,
                sharing_pattern=SharingPattern.PER_GROUP,
                process_group_name="dp_group",  # Same within a DP group, different across PP stages
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.PER_GROUP,
                process_group_name="dp_group",
                required=False,  # Optional - can be skipped by deleting file
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=SharingPattern.PER_GROUP,
                process_group_name="dp_group",  # One per DP group
                required=False,
            ),
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            ),
        ]

    def get_process_groups(self) -> Dict[str, ProcessGroup]:
        return {
            "dp_group": self.dp_process_group,
            "pp_group": self.pp_process_group,
        }

Common Pitfalls

❌ Wrong: Using GLOBAL for DDP weights

# GLOBAL declares that only one copy exists, but DDP keeps a synchronized copy on every rank!
StateComponent(
    key="model",
    stateful=self.model,
    sharing_pattern=SharingPattern.GLOBAL,  # Wrong!
)

✅ Correct:

StateComponent(
    key="model",
    stateful=self.unwrapped_model(),
    sharing_pattern=SharingPattern.REPLICATED,  # DDP synchronizes
)

❌ Wrong: Using REPLICATED for RNG state

# This makes all ranks use the same random numbers!
StateComponent(
    key="rng",
    stateful=RNGState(),
    sharing_pattern=SharingPattern.REPLICATED,  # Wrong!
)

✅ Correct:

StateComponent(
    key="rng",
    stateful=RNGState(),
    sharing_pattern=SharingPattern.PER_RANK,  # Each rank needs unique RNG
)

❌ Wrong: Forgetting to implement get_process_groups()

StateComponent(
    key="model",
    stateful=self.model,
    sharing_pattern=SharingPattern.PER_GROUP,
    process_group_name="dp_group",  # But get_process_groups() not implemented!
)

✅ Correct:

def get_process_groups(self) -> Dict[str, ProcessGroup]:
    return {"dp_group": self.dp_process_group}
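A mismatch between process_group_name and the groups a trainer registers can be caught early with a small consistency check. The sketch below is hypothetical: Component is a stand-in dataclass carrying only the fields the check needs, and missing_group_errors is not a Forgather function:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Component:
    """Stand-in for StateComponent with only the fields this check reads."""

    key: str
    sharing_pattern: str
    process_group_name: Optional[str] = None


def missing_group_errors(components: List[Component], groups: Dict[str, object]) -> List[str]:
    """Report PER_GROUP components whose group is unnamed or unregistered."""
    errors = []
    for c in components:
        if c.sharing_pattern != "PER_GROUP":
            continue
        if c.process_group_name is None:
            errors.append(f"{c.key}: PER_GROUP requires process_group_name")
        elif c.process_group_name not in groups:
            errors.append(f"{c.key}: unknown process group '{c.process_group_name}'")
    return errors


components = [
    Component("model", "PER_GROUP", "dp_group"),
    Component("dataset", "PER_GROUP", "mp_group"),  # not registered below
    Component("rng", "PER_RANK"),
]
registered = {"dp_group": object()}  # placeholder for a real ProcessGroup
errors = missing_group_errors(components, registered)
```

A check like this fits naturally into a unit test that calls both get_state_components() and get_process_groups() on the trainer.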

Testing Your Migration

1. Unit Tests

Verify your get_state_components() implementation:

def test_state_components():
    trainer = MyTrainer(...)
    components = trainer.get_state_components()

    # Check all expected components present
    keys = {c.key for c in components}
    assert "model" in keys
    assert "optimizer" in keys

    # Check sharing patterns are correct
    model_component = next(c for c in components if c.key == "model")
    assert model_component.sharing_pattern == SharingPattern.REPLICATED  # For DDP

2. Integration Tests

Test actual save/load cycles:

import os


def test_checkpoint_save_load():
    trainer = MyTrainer(...)

    # Train for a few steps
    trainer.train()

    # Save checkpoint
    checkpoint_path = trainer.save_checkpoint()

    # Verify manifest exists
    assert os.path.exists(os.path.join(checkpoint_path, "checkpoint_manifest.json"))

    # Create new trainer
    trainer2 = MyTrainer(...)
    trainer2.load_checkpoint(checkpoint_path)

    # Verify state was restored
    # ... assertions
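When an integration test fails, it can help to dump the manifest and see what was actually written. The manifest schema is not documented in this guide, so the sketch below makes no assumptions about its fields: it only loads checkpoint_manifest.json and pretty-prints it. load_manifest is a hypothetical helper, and the demonstration writes a placeholder file rather than a real manifest:

```python
import json
import os
import tempfile
from typing import Any


def load_manifest(checkpoint_path: str) -> Any:
    """Read checkpoint_manifest.json from a checkpoint directory."""
    manifest_path = os.path.join(checkpoint_path, "checkpoint_manifest.json")
    with open(manifest_path, "r", encoding="utf-8") as f:
        return json.load(f)


# Demonstration with placeholder content; a real manifest will look different.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "checkpoint_manifest.json"), "w", encoding="utf-8") as f:
        json.dump({"placeholder": True}, f)

    manifest = load_manifest(tmp)
    summary = json.dumps(manifest, indent=2, sort_keys=True)
    print(summary)
```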

3. Distributed Tests

Test with multiple ranks:

# Test DDP save/load
torchrun --nproc_per_node=4 test_ddp_checkpoint.py

# Test pipeline parallel save/load
torchrun --nproc_per_node=4 test_pipeline_checkpoint.py

Built-in Trainer Reference

All built-in trainers provide reference implementations:

BaseTrainer (Single-Device)

  • Model: GLOBAL
  • Optimizer: GLOBAL
  • Scheduler: GLOBAL
  • Dataset: GLOBAL
  • RNG: PER_RANK
  • Location: src/forgather/ml/trainer/base_trainer.py:get_state_components()

DDPTrainer (Data Parallel)

  • Model: REPLICATED (with validation)
  • Optimizer: REPLICATED (with validation)
  • Scheduler: REPLICATED
  • Dataset: GLOBAL or PER_RANK (dynamic - depends on dispatch_batches)
  • RNG: PER_RANK
  • Location: src/forgather/ml/trainer/ddp/ddp_trainer.py:get_state_components()

AccelTrainer (Accelerate DDP)

  • Model: REPLICATED (with validation)
  • Optimizer: REPLICATED (validation disabled - AcceleratedOptimizer wrapper)
  • Scheduler: REPLICATED
  • Dataset: PER_RANK
  • RNG: PER_RANK
  • Location: src/forgather/ml/trainer/accelerate/accel_trainer.py:get_state_components()

PipelineTrainer (Pipeline Parallel)

  • Model: PER_RANK (different stages)
  • Optimizer: PER_RANK (different parameters)
  • Scheduler: REPLICATED
  • Dataset: GLOBAL (rank 0 loads and broadcasts)
  • RNG: PER_RANK
  • Location: src/forgather/ml/trainer/pipeline/pipeline_trainer.py:get_state_components()

Backward Compatibility

Before v2.0, CheckpointManager detected at initialization which API a trainer implemented:

# CheckpointManager initialization (automatic):
state_components = stateful_provider.get_state_components()
if state_components is not None:
    # NEW API: Use CheckpointCoordinator
    self.coordinator = CheckpointCoordinator(...)
else:
    # OLD API: Use legacy get_statefuls_for_save/load
    self.coordinator = None

For custom trainers:

  • Must implement get_state_components() (required as of v2.0)
  • Old get_statefuls_for_save/load() methods have been removed
  • CheckpointManager requires the new API

Migration Checklist

  • [ ] Import new types (StateComponent, SharingPattern)
  • [ ] Implement get_state_components() method
  • [ ] Choose correct SharingPattern for each component
  • [ ] Implement get_process_groups() if using PER_GROUP
  • [ ] Add dynamic pattern resolution for dataset state (if applicable)
  • [ ] Add validation flags (e.g., validate_replication=True for DDP)
  • [ ] Test with unit tests
  • [ ] Test with integration tests (save/load cycles)
  • [ ] Test with distributed training (if applicable)
  • [ ] Update documentation
  • [ ] Remove legacy get_statefuls_for_save/load() (after testing)

Removed Features (v2.0)

Save/Restore Flags Removed

The following flags have been removed from TrainingArguments:

  • save_optimizer_state / restore_optimizer_state
  • save_scheduler_state / restore_scheduler_state
  • save_dataset_state / restore_dataset_state
  • save_rng_state / restore_rng_state

Rationale: These flags created confusing coupling between save and restore decisions. The new approach is simpler and more flexible.

Migration: All state is now always saved. To skip loading a component, delete its file from the checkpoint directory before resuming training.

Before (old API):

args = TrainingArguments(
    restore_optimizer_state=False,  # Skip optimizer restore
)

After (new API):

# Delete optimizer state file before resuming
rm checkpoint-1000/optimizer_state.pt
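Deleting a file is irreversible, so a more cautious variant is to rename it so it can be restored later. The sketch below assumes the loader treats a renamed file the same as a deleted one, which this guide does not confirm. set_aside and the .skipped suffix are hypothetical, and the demonstration runs against a temporary directory rather than a real checkpoint:

```python
import os
import tempfile


def set_aside(checkpoint_dir: str, filename: str, suffix: str = ".skipped") -> str:
    """Rename a component file so it is no longer found; return the new path."""
    src = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(src):
        raise FileNotFoundError(src)
    dst = src + suffix
    os.replace(src, dst)
    return dst


# Demonstration on a throwaway directory with a placeholder file.
with tempfile.TemporaryDirectory() as checkpoint_dir:
    original = os.path.join(checkpoint_dir, "optimizer_state.pt")
    with open(original, "wb") as f:
        f.write(b"placeholder")

    moved = set_aside(checkpoint_dir, "optimizer_state.pt")
    original_exists = os.path.exists(original)
    moved_exists = os.path.exists(moved)
```

Renaming the file back restores the original checkpoint contents.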

Old Protocol Methods Removed

The deprecated get_statefuls_for_save() and get_statefuls_for_load() methods have been removed from the StatefulProvider protocol.

Migration: All trainers must now implement get_state_components() instead.

Before (old API):

def get_statefuls_for_save(self) -> Dict[str, Stateful]:
    return {
        "optimizer": self.optimizer if self.args.save_optimizer_state else None,
        "scheduler": self.lr_scheduler if self.args.save_scheduler_state else None,
    }

After (new API):

@override
def get_state_components(self) -> List[StateComponent]:
    return [
        StateComponent(
            key="optimizer",
            stateful=self.optimizer,
            sharing_pattern=SharingPattern.GLOBAL,
            required=False,  # Optional - can be skipped by deleting file
        ),
        StateComponent(
            key="scheduler",
            stateful=self.lr_scheduler,
            sharing_pattern=SharingPattern.GLOBAL,
            required=False,
        ),
    ]

Note: Model weights remain required=True and cannot be skipped.
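The effect of the required flag at load time can be pictured as follows. This is a sketch of the behavior described above, not the CheckpointManager's code: plan_load is a hypothetical helper, components are reduced to a key-to-required mapping, and the exception type is an assumption:

```python
from typing import Dict, List, Set, Tuple


def plan_load(required_by_key: Dict[str, bool], available: Set[str]) -> Tuple[List[str], List[str]]:
    """Split components into (to_load, skipped); fail if a required one is missing."""
    to_load: List[str] = []
    skipped: List[str] = []
    for key, required in required_by_key.items():
        if key in available:
            to_load.append(key)
        elif required:
            raise FileNotFoundError(f"required component '{key}' is missing from the checkpoint")
        else:
            # Optional component whose file was deleted: skip it.
            skipped.append(key)
    return to_load, skipped


components = {"model": True, "optimizer": False, "scheduler": False}

# The optimizer file was deleted before resuming, so it is skipped.
to_load, skipped = plan_load(components, available={"model", "scheduler"})

# Deleting the model file is an error because the model is required.
try:
    plan_load(components, available={"optimizer", "scheduler"})
    model_error = None
except FileNotFoundError as exc:
    model_error = str(exc)
```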

Getting Help

  • Main Documentation: docs/checkpointing/distributed_checkpoint_abstraction.md
  • User Guide: docs/checkpointing/user_guide.md - Troubleshooting and best practices
  • Built-in Trainers: Check source code for reference implementations
  • Issues: Report issues at https://github.com/jdinalt/forgather/issues

Current Status

  • v2.0: Checkpoint API cleanup complete
  • New API: Only API supported (old API removed)
  • All state always saved: Simplified approach with manual file deletion for partial loading
  • Production Ready: All built-in trainers tested successfully
  • Breaking Change: Save/restore flags removed - see "Removed Features" section above