Checkpoints
Forgather's checkpointing system saves and restores all trainer state — model weights, optimizer, LR scheduler, dataset position, per-rank RNG, and training progress. Model weights are written as standard HuggingFace Safetensors shards, readable by any HF-compatible tool without a Forgather dependency.
Related documentation:
- Checkpointing Overview — concepts and basic usage
- User Guide — practical patterns and troubleshooting
- Distributed Abstraction — state-sharing patterns (GLOBAL, PER_RANK, REPLICATED, etc.)
- Migration Guide — implementing custom trainers
- Sharded Checkpoint API — low-level shard API reference
Checkpoint Management
forgather.ml.sharded_checkpoint.CheckpointMeta
dataclass
Source code in src/forgather/ml/sharded_checkpoint.py
forgather.ml.sharded_checkpoint.save_checkpoint(output_dir, module, metadata=None, safetensors=False, max_shard_size=2 ** 31, debug=False, include_param_sharing=True, param_sharing_metadata=None)
Save a sharded checkpoint for the whole model or a raw state dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str | Directory to write the checkpoint files into. | required |
| module | Module or Dict[str, Tensor] | An nn.Module or a raw state dictionary to checkpoint. | required |
| metadata | dict or None | Additional metadata to embed in the shard index. | None |
| safetensors | bool | Save in safetensors format when True, PyTorch otherwise. | False |
| max_shard_size | int | Maximum bytes per shard file. | 2 ** 31 |
| debug | bool | Enable debug-level logging of individual weights. | False |
| include_param_sharing | bool | If True and module is an nn.Module, detect and include buffer sharing metadata automatically. | True |
| param_sharing_metadata | list of list of str or None | Explicit sharing metadata. When provided, skips auto-detection even if module is an nn.Module. | None |
Source code in src/forgather/ml/sharded_checkpoint.py
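To make the effect of max_shard_size concrete, here is a minimal sketch of greedy shard planning over tensor byte sizes. It is an illustration under assumptions, not Forgather's actual sharding code: the function name plan_shards and the greedy strategy are hypothetical, and the tensor names and sizes are made up.

```python
from typing import Dict, List


def plan_shards(tensor_sizes: Dict[str, int], max_shard_size: int = 2**31) -> List[List[str]]:
    """Greedily group tensor names into shards of at most max_shard_size bytes.

    Hypothetical helper for illustration only. A single tensor larger than
    the limit still gets a shard of its own.
    """
    shards: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, nbytes in tensor_sizes.items():
        # Close the current shard when adding this tensor would exceed the limit.
        if current and current_bytes + nbytes > max_shard_size:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards


# Made-up tensor sizes in bytes, with a small limit so the split is visible.
sizes = {"embed.weight": 600, "layer0.weight": 300, "layer1.weight": 300, "head.weight": 600}
print(plan_shards(sizes, max_shard_size=1000))
```

With the default of 2 ** 31 bytes (2 GiB), a model whose weights total less than that limit would end up in a single shard under this scheme.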
forgather.ml.sharded_checkpoint.load_checkpoint(model_dir, module=None, device='cpu', strict=True, assign=False, keys=None)
Automatically detects the checkpoint type and loads it accordingly.
Intended to work with both sharded and unsharded checkpoints, in either PyTorch or safetensors format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_dir | str | Directory containing checkpoint files. | required |
| module | Module or None | An nn.Module to load weights into. If None, returns a raw Dict[str, Tensor] instead of loading into a module. | None |
| device | str | Device to map tensors to when loading. | 'cpu' |
| strict | bool | Whether to require all module keys to be present in the checkpoint. | True |
| assign | bool | If True, assign loaded tensors rather than copying data. | False |
| keys | set of str or None | When module is None, optionally restrict which keys to load. Ignored when module is provided. | None |
Notes
See torch.nn.Module.load_state_dict (https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict) for the semantics of the strict and assign flags.
When the checkpoint is torchao-quantized, this function installs the matching quantized linear modules on module before load_state_dict runs and forces assign=True (Tensor.copy_ does not handle quantized-to-quantized copies). In that branch the device argument is silently overridden to the module's existing device, so the assign-rebound tensors don't migrate the model off the caller's compute device. Tied weights are restored post-load by the trainer's retie_parameters() step; eval and inference paths that don't re-tie still produce correct outputs because quantized inference doesn't grad-update tied tensors.
Source code in src/forgather/ml/sharded_checkpoint.py
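The strict and keys parameters both come down to set operations on state-dict key names. The sketch below illustrates that idea in plain Python, without torch. The helpers check_keys and filter_state are hypothetical and are not part of Forgather or PyTorch; the strict check follows the torch.nn.Module.load_state_dict convention of reporting both missing and unexpected keys.

```python
from typing import Dict, Iterable, Optional, Set, Tuple


def check_keys(
    module_keys: Iterable[str], checkpoint_keys: Iterable[str], strict: bool = True
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) key sets; raise when strict and they differ.

    Hypothetical helper that mimics the strict-mode key comparison.
    """
    module_set, ckpt_set = set(module_keys), set(checkpoint_keys)
    missing = module_set - ckpt_set      # module expects these, checkpoint lacks them
    unexpected = ckpt_set - module_set   # checkpoint has these, module does not
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
    return missing, unexpected


def filter_state(state: Dict[str, object], keys: Optional[Set[str]] = None) -> Dict[str, object]:
    """Keep only the requested keys; keys=None keeps everything.

    Hypothetical helper mirroring the documented role of the keys argument
    when no module is given.
    """
    if keys is None:
        return dict(state)
    return {k: v for k, v in state.items() if k in keys}


missing, unexpected = check_keys(["a.weight", "b.weight"], ["a.weight", "c.weight"], strict=False)
print(sorted(missing), sorted(unexpected))
print(filter_state({"a.weight": 1, "c.weight": 2}, keys={"a.weight"}))
```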
forgather.ml.sharded_checkpoint.find_latest_checkpoint(model_dir)
Find the most recent valid checkpoint in the checkpoints directory.
Uses checkpoint_manifest.json timestamp when available, falling back to filesystem modification time for legacy checkpoints.
Source code in src/forgather/ml/sharded_checkpoint.py
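The manifest-first, mtime-fallback ordering described above can be sketched with the standard library alone. This is an assumption-laden illustration, not Forgather's implementation: the "timestamp" field name inside checkpoint_manifest.json, the helper names, and the directory names are hypothetical, and the validity check performed by the real function is omitted.

```python
import json
import os
import tempfile
from typing import Optional

MANIFEST_NAME = "checkpoint_manifest.json"


def checkpoint_time(path: str) -> float:
    """Prefer the manifest timestamp; fall back to the directory's mtime."""
    try:
        with open(os.path.join(path, MANIFEST_NAME)) as f:
            return float(json.load(f)["timestamp"])  # hypothetical field name
    except (OSError, KeyError, ValueError, TypeError):
        # No manifest, or an unreadable one: treat as a legacy checkpoint.
        return os.path.getmtime(path)


def latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
    """Return the newest checkpoint directory, or None if there is none."""
    if not os.path.isdir(checkpoints_dir):
        return None
    paths = [os.path.join(checkpoints_dir, name) for name in os.listdir(checkpoints_dir)]
    paths = [p for p in paths if os.path.isdir(p)]
    return max(paths, key=checkpoint_time, default=None)


with tempfile.TemporaryDirectory() as root:
    with_manifest = os.path.join(root, "checkpoint-a")
    legacy = os.path.join(root, "checkpoint-b")
    os.makedirs(with_manifest)
    os.makedirs(legacy)
    with open(os.path.join(with_manifest, MANIFEST_NAME), "w") as f:
        json.dump({"timestamp": 2_000_000}, f)
    # Give the legacy checkpoint a later modification time than the manifest timestamp.
    os.utime(legacy, (3_000_000, 3_000_000))
    print(os.path.basename(latest_checkpoint(root)))
```

In this demo the manifest timestamp is used for the first directory even though its filesystem modification time is much more recent, which is the point of preferring the manifest.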
forgather.ml.sharded_checkpoint.next_checkpoint_path(model_dir, checkpoint_id)
Return the path at which to save the next checkpoint, given the model directory and a checkpoint ID (e.g., global_step).
Source code in src/forgather/ml/sharded_checkpoint.py
forgather.ml.sharded_checkpoint.validate_checkpoint(checkpoint_path)
Validate that a checkpoint directory contains the necessary files.
Source code in src/forgather/ml/sharded_checkpoint.py
Protocols
Protocols that trainer components implement to participate in the checkpoint system.
forgather.ml.trainer.trainer_types.CheckpointInterface
Bases: Protocol
Protocol for checkpoint management.
Defines interface for saving/loading complete training state (model, optimizer, scheduler, dataset position, RNG state, etc.) and standalone model weights.
Implementations:
- CheckpointManager: Standard implementation in src/forgather/ml/trainer/checkpoint_manager.py
Key responsibilities:
- Save complete training checkpoints with versioning and limits
- Load checkpoints for resuming training
- Track best checkpoint (for load_best_model_at_end)
- Save standalone model weights (HF Trainer compatibility)
Source code in src/forgather/ml/trainer/trainer_types.py
save_checkpoint(checkpoint_path=None, checkpoint_id=None)
abstractmethod
Save complete training checkpoint.
Args:
- checkpoint_path: Specific path for the checkpoint, or None to auto-generate one
- checkpoint_id: Identifier for the checkpoint (e.g., global_step); used if checkpoint_path is None
Returns: Path to the saved checkpoint directory
Source code in src/forgather/ml/trainer/trainer_types.py
load_checkpoint(checkpoint_path=None)
abstractmethod
Load checkpoint to resume training.
Args: checkpoint_path: Path to checkpoint, or None to load latest checkpoint
Source code in src/forgather/ml/trainer/trainer_types.py
save_model(output_dir=None, overwrite_output_dir=False)
abstractmethod
Save only model weights (not full training state).
Args:
- output_dir: Directory to save model, or None for default
- overwrite_output_dir: Whether to overwrite existing model
Source code in src/forgather/ml/trainer/trainer_types.py
set_best_checkpoint(best_checkpoint)
abstractmethod
Mark a checkpoint as the best model.
Args: best_checkpoint: Path to checkpoint to mark as best
resolve_checkpoint_path(checkpoint_path)
abstractmethod
Resolve checkpoint path (e.g., find latest if path is None).
Args: checkpoint_path: Explicit path or None for auto-resolution
Returns: Resolved checkpoint path or None if not found
Source code in src/forgather/ml/trainer/trainer_types.py
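As a rough illustration of what satisfying this protocol looks like, the sketch below defines a local stand-in protocol with the same five method names and a toy implementation that stores a small JSON state per checkpoint. Everything here is hypothetical: CheckpointInterfaceSketch and ToyCheckpointer are not Forgather classes, and the checkpoints/checkpoint-<id> layout and state.json file name are assumptions made for the example.

```python
import json
import os
import tempfile
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class CheckpointInterfaceSketch(Protocol):
    """Local stand-in that mirrors the documented method names."""

    def save_checkpoint(self, checkpoint_path=None, checkpoint_id=None) -> str: ...
    def load_checkpoint(self, checkpoint_path=None) -> None: ...
    def save_model(self, output_dir=None, overwrite_output_dir=False) -> None: ...
    def set_best_checkpoint(self, best_checkpoint) -> None: ...
    def resolve_checkpoint_path(self, checkpoint_path) -> Optional[str]: ...


class ToyCheckpointer:
    """Hypothetical implementation; the on-disk layout is an assumption."""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.state = {"global_step": 0}
        self.best_checkpoint = None

    def _root(self) -> str:
        return os.path.join(self.output_dir, "checkpoints")

    def save_checkpoint(self, checkpoint_path=None, checkpoint_id=None) -> str:
        if checkpoint_path is None:
            if checkpoint_id is None:
                raise ValueError("need checkpoint_path or checkpoint_id")
            # Auto-generate the path from the identifier.
            checkpoint_path = os.path.join(self._root(), f"checkpoint-{checkpoint_id}")
        os.makedirs(checkpoint_path, exist_ok=True)
        with open(os.path.join(checkpoint_path, "state.json"), "w") as f:
            json.dump(self.state, f)
        return checkpoint_path

    def resolve_checkpoint_path(self, checkpoint_path) -> Optional[str]:
        if checkpoint_path is not None:
            return checkpoint_path
        if not os.path.isdir(self._root()):
            return None
        names = [n for n in os.listdir(self._root()) if n.startswith("checkpoint-")]
        # In this toy layout, latest means the highest numeric identifier.
        latest = max(names, key=lambda n: int(n.split("-", 1)[1]), default=None)
        return os.path.join(self._root(), latest) if latest else None

    def load_checkpoint(self, checkpoint_path=None) -> None:
        path = self.resolve_checkpoint_path(checkpoint_path)
        if path is None:
            raise FileNotFoundError("no checkpoint to load")
        with open(os.path.join(path, "state.json")) as f:
            self.state = json.load(f)

    def save_model(self, output_dir=None, overwrite_output_dir=False) -> None:
        target = os.path.join(output_dir or self.output_dir, "model.json")
        if os.path.exists(target) and not overwrite_output_dir:
            raise FileExistsError(target)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "w") as f:
            json.dump({"weights": "placeholder"}, f)

    def set_best_checkpoint(self, best_checkpoint) -> None:
        self.best_checkpoint = best_checkpoint


with tempfile.TemporaryDirectory() as tmp:
    ckpt = ToyCheckpointer(tmp)
    for step in (100, 200):
        ckpt.state["global_step"] = step
        ckpt.save_checkpoint(checkpoint_id=step)
    ckpt.load_checkpoint()  # None resolves to the latest checkpoint
    print(ckpt.state["global_step"])
```

Because the stand-in is a structural protocol, ToyCheckpointer satisfies it without inheriting from it; any object with these five methods would pass the same isinstance check.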
forgather.ml.trainer.trainer_types.StatefulProvider
Bases: Protocol
Protocol for providing stateful objects for checkpointing.
Used by checkpoint managers to collect all components that need to be saved/restored during checkpointing (optimizer, scheduler, dataset, etc.).
The protocol uses StateComponents which declare explicit sharing patterns (GLOBAL, PER_RANK, REPLICATED, etc.) to enable automatic distributed checkpoint coordination for hybrid parallelism strategies.
All implementations must provide:
- get_state_components(): Returns list of StateComponents with sharing patterns
- get_process_groups(): Returns named process groups (only if using PER_GROUP pattern)
Source code in src/forgather/ml/trainer/trainer_types.py
get_state_components()
abstractmethod
Get state components with explicit sharing patterns for distributed checkpointing.
This is the new preferred API for checkpoint coordination. Each StateComponent declares its sharing pattern (GLOBAL, PER_RANK, REPLICATED, etc.), enabling automatic distributed checkpoint coordination without manual rank checks.
Returns: List of StateComponent objects describing all checkpointable state
Example implementation for a single-GPU trainer:

    def get_state_components(self):
        from forgather.ml.trainer.checkpoint_types import StateComponent, SharingPattern

        return [
            StateComponent(
                key="model",
                stateful=self.model,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=self._get_dataset_sharing_pattern(),
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
            ),
        ]

Example for a DDP trainer:

    def get_state_components(self):
        return [
            StateComponent(
                key="model",
                stateful=self.unwrapped_model(),
                sharing_pattern=SharingPattern.REPLICATED,
                validate_replication=True,  # Verify DDP synchronization
            ),
            # ... other components
        ]

See: docs/checkpointing/migration_guide.md for full migration guide
Source code in src/forgather/ml/trainer/trainer_types.py
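To show how declared sharing patterns can replace manual rank checks, the following sketch decides which files each rank would write. It rests on loudly stated assumptions: the SharingPattern and StateComponent classes below are local stand-ins rather than the Forgather types, the rule "GLOBAL and REPLICATED are written once by rank 0, PER_RANK is written by every rank" is a guess at the intended semantics, and the file naming is invented. PER_GROUP is left out.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List


class SharingPattern(Enum):
    """Local stand-in enum; not imported from Forgather."""

    GLOBAL = auto()
    PER_RANK = auto()
    REPLICATED = auto()


@dataclass
class StateComponent:
    """Local stand-in with the three fields used in the examples above."""

    key: str
    stateful: Any
    sharing_pattern: SharingPattern


def should_save(component: StateComponent, rank: int) -> bool:
    """Assumed rule: per-rank state is saved everywhere, everything else on rank 0."""
    if component.sharing_pattern is SharingPattern.PER_RANK:
        return True
    return rank == 0


def files_for_rank(components: List[StateComponent], rank: int) -> List[str]:
    """List the (invented) file names a given rank would write."""
    names = []
    for component in components:
        if not should_save(component, rank):
            continue
        per_rank = component.sharing_pattern is SharingPattern.PER_RANK
        suffix = f"_rank{rank}" if per_rank else ""
        names.append(f"{component.key}{suffix}.pt")
    return names


components = [
    StateComponent("model", object(), SharingPattern.REPLICATED),
    StateComponent("optimizer", object(), SharingPattern.GLOBAL),
    StateComponent("rng", object(), SharingPattern.PER_RANK),
]
for rank in range(2):
    print(rank, files_for_rank(components, rank))
```

The caller only declares a pattern per component; the rank-dependent branching lives in one place, which is the design benefit the protocol description points to.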
get_process_groups()
Get named process groups for PER_GROUP sharing pattern.
Returns dictionary mapping group names to ProcessGroup objects. Only needed if using PER_GROUP sharing pattern in state components.
Returns: Dictionary mapping process group names to ProcessGroup objects (e.g., {"dp_group": dp_pg, "pp_group": pp_pg})
Example:

    def get_process_groups(self):
        return {
            "dp_group": self.dp_process_group,
            "pp_group": self.pp_process_group,
        }