Skip to content

Checkpoints

Forgather's checkpointing system saves and restores all trainer state — model weights, optimizer, LR scheduler, dataset position, per-rank RNG, and training progress. Model weights are written as standard HuggingFace Safetensors shards, readable by any HF-compatible tool without a Forgather dependency.

Related documentation:


Checkpoint Management

forgather.ml.sharded_checkpoint.CheckpointMeta dataclass

Source code in src/forgather/ml/sharded_checkpoint.py
@dataclass
class CheckpointMeta:
    """Record describing how a checkpoint's weights are stored on disk.

    ``file_name`` names an index file when ``is_index`` is True, otherwise
    the single weights file; ``safetensors`` selects the safetensors format
    (True) or the PyTorch format (False).
    """

    file_name: str  # index file name if is_index, else the weights file name
    is_index: bool  # True when file_name refers to a shard index
    safetensors: bool  # True for safetensors weights, False for PyTorch

forgather.ml.sharded_checkpoint.save_checkpoint(output_dir, module, metadata=None, safetensors=False, max_shard_size=2 ** 31, debug=False, include_param_sharing=True, param_sharing_metadata=None)

Save a sharded checkpoint for the whole model or a raw state dict.

Parameters:

Name Type Description Default
output_dir str

Directory to write the checkpoint files into.

required
module Module or Dict[str, Tensor]

An nn.Module or a raw state dictionary to checkpoint.

required
metadata dict or None

Additional metadata to embed in the shard index.

None
safetensors bool

Save in safetensors format when True, PyTorch otherwise.

False
max_shard_size int

Maximum bytes per shard file.

2 ** 31
debug bool

Enable debug-level logging of individual weights.

False
include_param_sharing bool

If True and module is an nn.Module, detect and include buffer sharing metadata automatically.

True
param_sharing_metadata list of list of str or None

Explicit sharing metadata. When provided, skips auto-detection even if module is an nn.Module.

None
Source code in src/forgather/ml/sharded_checkpoint.py
def save_checkpoint(
    output_dir: str,
    module: StateDictLike,
    metadata: Optional[Dict] = None,
    safetensors: bool = False,
    max_shard_size: int = 2**31,
    debug: bool = False,
    include_param_sharing: bool = True,
    param_sharing_metadata: Optional[SharingMetadataT] = None,
) -> None:
    """
    Write ``module`` to ``output_dir`` as a sharded checkpoint.

    A shard index file is written first, followed by the shard files it
    describes.

    Parameters
    ----------
    output_dir : str
        Destination directory for the index and shard files.
    module : nn.Module or Dict[str, Tensor]
        Source of the weights: either a module or an already-built state dict.
    metadata : dict or None, optional
        Extra metadata stored in the shard index.
    safetensors : bool, optional
        If True, write in safetensors format; otherwise PyTorch format.
    max_shard_size : int, optional
        Upper bound, in bytes, on each shard file.
    debug : bool, optional
        Forwarded to the shard writer to enable debug-level logging of
        individual weights.
    include_param_sharing : bool, optional
        When True and ``module`` is an nn.Module, sharing metadata is
        detected automatically via ``create_sharing_metadata``.
    param_sharing_metadata : list of list of str or None, optional
        Caller-supplied sharing metadata; when given, auto-detection is
        skipped even for an nn.Module.
    """
    state_dict = _resolve_state_dict(module)

    # Sharing can only be detected from an nn.Module, and an explicit
    # caller-provided value always wins over auto-detection.
    should_detect_sharing = (
        param_sharing_metadata is None
        and include_param_sharing
        and isinstance(module, Module)
    )
    if should_detect_sharing:
        param_sharing_metadata = create_sharing_metadata(module)
        if param_sharing_metadata:
            logger.debug(f"Detected {len(param_sharing_metadata)} shared buffer groups")

    shard_index = make_shard_index(
        [state_dict],
        metadata=metadata,
        safetensors=safetensors,
        max_shard_size=max_shard_size,
        param_sharing_metadata=param_sharing_metadata,
    )
    index_name = SAFE_WEIGHTS_INDEX_NAME if safetensors else WEIGHTS_INDEX_NAME

    # The index goes to disk before the shards it references.
    save_shard_index(shard_index, output_dir, index_name)
    save_sharded_checkpoint(
        output_dir,
        shard_index,
        state_dict,
        safetensors=safetensors,
        debug=debug,
    )

forgather.ml.sharded_checkpoint.load_checkpoint(model_dir, module=None, device='cpu', strict=True, assign=False, keys=None)

load_checkpoint(model_dir: str, module: Module, device: str, strict: bool = True, assign: bool = False, keys: Optional[Set[str]] = None) -> None
load_checkpoint(model_dir: str, module: None, device: str, strict: bool = True, assign: bool = False, keys: Optional[Set[str]] = None) -> Dict[str, Tensor]

Automatically detects checkpoint type and loads accordingly.

This works for both sharded and single-file checkpoints, in either PyTorch or safetensors format.

Parameters:

Name Type Description Default
model_dir str

Directory containing checkpoint files.

required
module Module or None

An nn.Module to load weights into. If None, returns a raw Dict[str, Tensor] instead of loading into a module.

None
device str

Device to map tensors to when loading.

'cpu'
strict bool

Whether to require all module keys to be present in the checkpoint.

True
assign bool

If True, assign loaded tensors rather than copying data.

False
keys set of str or None

When module is None, optionally restrict which keys to load. Ignored when module is provided.

None
Notes

See [torch.nn.Module.load_state_dict](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict) for the semantics of the strict and assign flags.

When the checkpoint is torchao-quantized, this function installs the matching quantized linear modules on module before load_state_dict runs and forces assign=True (Tensor.copy_ does not handle quantized-to-quantized copies). In that branch the device argument is silently overridden to the module's existing device, so the assign-rebound tensors don't migrate the model off the caller's compute device. Tied weights are restored post-load by the trainer's retie_parameters() step; eval / inference paths that don't re-tie still produce correct outputs because quantized inference doesn't grad-update tied tensors.

Source code in src/forgather/ml/sharded_checkpoint.py
def load_checkpoint(
    model_dir: str,
    module: Optional[Module] = None,
    device: str = "cpu",
    strict: bool = True,
    assign: bool = False,
    keys: Optional[Set[str]] = None,
) -> Union[None, Dict[str, Tensor]]:
    """
    Automatically detects checkpoint type and loads accordingly.

    This works for both sharded (indexed) and single-file checkpoints, in
    either PyTorch or safetensors format.

    Parameters
    ----------
    model_dir : str
        Directory containing checkpoint files.
    module : nn.Module or None, optional
        An nn.Module to load weights into. If None, returns a raw
        Dict[str, Tensor] instead of loading into a module.
    device : str, optional
        Device to map tensors to when loading.
    strict : bool, optional
        Whether to require all module keys to be present in the checkpoint.
    assign : bool, optional
        If True, assign loaded tensors rather than copying data.
    keys : set of str or None, optional
        When module is None, optionally restrict which keys to load.
        Ignored when module is provided.

    Returns
    -------
    None or Dict[str, Tensor]
        None when ``module`` is provided (the weights are loaded into it);
        otherwise the loaded state dict, restricted to ``keys`` when given.

    Raises
    ------
    FileNotFoundError
        If ``get_checkpoint_metadata`` finds no checkpoint in ``model_dir``.

    Notes
    -----
    See `torch.nn.Module.load_state_dict
    <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`_
    for the semantics of the ``strict`` and ``assign`` flags.

    When the checkpoint is torchao-quantized, this function installs the
    matching quantized linear modules on ``module`` before
    ``load_state_dict`` runs and forces ``assign=True`` (``Tensor.copy_``
    does not handle quantized-to-quantized copies). In that branch the
    ``device`` argument is silently overridden to the module's existing
    device, so the ``assign``-rebound tensors don't migrate the model
    off the caller's compute device. Tied weights are restored
    post-load by the trainer's ``retie_parameters()`` step; eval /
    inference paths that don't re-tie still produce correct outputs
    because quantized inference doesn't grad-update tied tensors.
    """
    checkpoint_meta = get_checkpoint_metadata(model_dir)

    if not checkpoint_meta:
        # Same exception type as before, but name the directory so the
        # failure is diagnosable from logs.
        raise FileNotFoundError(f"No checkpoint found in '{model_dir}'")

    if checkpoint_meta.is_index:
        shard_index = load_shard_index(model_dir, checkpoint_meta.file_name)
        if module is not None and _maybe_install_torchao_quantization(
            model_dir, module, shard_index=shard_index,
            safetensors=checkpoint_meta.safetensors, device=device,
        ):
            # Quantized weights are tensor subclasses; ``Tensor.copy_``
            # between two quantized subclasses fails with a metadata
            # mismatch. ``assign=True`` rebinds the Parameter directly,
            # bypassing copy_. The flip-side: assigned tensors keep
            # their map_location device, so we must load to wherever
            # the (already-constructed) module lives — not to the
            # caller-passed ``device`` (which may be a staging area
            # like CPU).
            assign = True
            device = _module_device(module, fallback=device)
        return load_sharded_checkpoint(
            model_dir,
            shard_index,
            module,
            device=device,
            safetensors=checkpoint_meta.safetensors,
            strict=strict,
            assign=assign,
            keys=keys,
        )

    state_dict_path = os.path.join(model_dir, checkpoint_meta.file_name)
    if checkpoint_meta.safetensors:
        # Pass the normalized device *string* (e.g. "cpu", "cuda:1").
        # ``torch.device(device).index`` is None for "cpu" and bare "cuda"
        # and always drops the device type. safetensors' ``device``
        # argument accepts only a device string or a CUDA ordinal.
        # NOTE(review): assumes ``safetensors_load`` is
        # ``safetensors.torch.load_file`` — confirm against the imports.
        state_dict = safetensors_load(
            state_dict_path, device=str(torch.device(device))
        )
    else:
        state_dict = torch.load(
            state_dict_path, map_location=device, weights_only=True, mmap=True
        )

    if module is None:
        if keys is not None:
            return {k: v for k, v in state_dict.items() if k in keys}
        return state_dict

    if _maybe_install_torchao_quantization(model_dir, module, state_dict=state_dict):
        assign = True
        # Move the loaded tensors onto the module's existing device before
        # assigning them in (see the sharded branch above for the reason).
        target = _module_device(module, fallback=device)
        if str(target) != str(device):
            state_dict = {k: v.to(target) for k, v in state_dict.items()}

    # TODO: Properly handle strict, in this case?
    # We wish to ensure that all model weights were loaded, but ignore any other weights, like we do in load_sharded_checkpoint()
    module.load_state_dict(state_dict, strict=strict, assign=assign)
    return None

forgather.ml.sharded_checkpoint.find_latest_checkpoint(model_dir)

Find the most recent valid checkpoint in the checkpoints directory.

Uses checkpoint_manifest.json timestamp when available, falling back to filesystem modification time for legacy checkpoints.

Source code in src/forgather/ml/sharded_checkpoint.py
def find_latest_checkpoint(model_dir: str) -> str | None:
    """Return the newest valid checkpoint for ``model_dir``, or None.

    Candidates are the ``checkpoints/checkpoint-*`` entries that pass
    ``validate_checkpoint``. Recency is decided by ``_checkpoint_sort_key``,
    which prefers the checkpoint_manifest.json timestamp and falls back to
    filesystem modification time for legacy checkpoints. If ``model_dir``
    has no ``checkpoints`` subdirectory, ``model_dir`` itself is returned
    when it holds a valid checkpoint.
    """
    checkpoints_dir = os.path.join(model_dir, "checkpoints")

    if not os.path.exists(checkpoints_dir):
        # Without a checkpoints/ subdirectory, fall back to the model
        # directory itself.
        logger.info(
            "No checkpoint directory found. Defaulting to main model directory."
        )
        return model_dir if validate_checkpoint(model_dir) else None

    candidates = glob.glob(os.path.join(checkpoints_dir, "checkpoint-*"))
    if not candidates:
        return None

    valid = list(filter(validate_checkpoint, candidates))
    if not valid:
        logger.warning("No valid checkpoints found in checkpoint directory")
        return None

    try:
        latest = max(valid, key=_checkpoint_sort_key)
        dir_name = os.path.basename(latest)
        # Names look like "checkpoint-<id>"; report the segment after the
        # first dash (up to any following dash) as the step.
        _, dash, tail = dir_name.partition("-")
        step_num = tail.partition("-")[0] if dash else "unknown"
        timestamp = _get_checkpoint_timestamp(latest)
        ts_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
        logger.debug(
            f"Found latest valid checkpoint: {latest} (step {step_num}, timestamp {ts_str})"
        )
        return latest
    except (OSError, IndexError) as e:
        logger.warning(f"Error finding latest checkpoint: {e}")
        return None

forgather.ml.sharded_checkpoint.next_checkpoint_path(model_dir, checkpoint_id)

Return the path at which to save the next checkpoint, given the model directory and a checkpoint ID (e.g. the global step).

Source code in src/forgather/ml/sharded_checkpoint.py
def next_checkpoint_path(model_dir: str, checkpoint_id: int | str) -> str:
    """Get path to save next checkpoint, given model directory and global_step"""
    checkpoints_dir = os.path.join(model_dir, "checkpoints")
    checkpoint_path = os.path.join(checkpoints_dir, f"checkpoint-{str(checkpoint_id)}")
    return checkpoint_path

forgather.ml.sharded_checkpoint.validate_checkpoint(checkpoint_path)

Validate that a checkpoint directory contains the necessary files.

Source code in src/forgather/ml/sharded_checkpoint.py
def validate_checkpoint(checkpoint_path: str) -> bool:
    """Return True if ``checkpoint_path`` is a directory holding model weights.

    The directory qualifies when it contains at least one of WEIGHTS_NAME,
    SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME or WEIGHTS_INDEX_NAME. Only
    file presence is checked; file contents are not inspected.
    """
    if not os.path.isdir(checkpoint_path):
        return False

    # Any one of these files is enough to treat the directory as a checkpoint.
    for candidate in (
        WEIGHTS_NAME,
        SAFE_WEIGHTS_NAME,
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
    ):
        if os.path.exists(os.path.join(checkpoint_path, candidate)):
            return True
    return False

Protocols

Protocols that trainer components implement to participate in the checkpoint system.

forgather.ml.trainer.trainer_types.CheckpointInterface

Bases: Protocol

Protocol for checkpoint management.

Defines interface for saving/loading complete training state (model, optimizer, scheduler, dataset position, RNG state, etc.) and standalone model weights.

Implementations:

- CheckpointManager: standard implementation in src/forgather/ml/trainer/checkpoint_manager.py

Key responsibilities:

- Save complete training checkpoints with versioning and limits
- Load checkpoints for resuming training
- Track the best checkpoint (for load_best_model_at_end)
- Save standalone model weights (HF Trainer compatibility)

Source code in src/forgather/ml/trainer/trainer_types.py
class CheckpointInterface(Protocol):
    """
    Protocol for checkpoint management.

    Defines interface for saving/loading complete training state (model, optimizer,
    scheduler, dataset position, RNG state, etc.) and standalone model weights.

    Implementations:
    - CheckpointManager: Standard implementation in src/forgather/ml/trainer/checkpoint_manager.py

    Key responsibilities:
    - Save complete training checkpoints with versioning and limits
    - Load checkpoints for resuming training
    - Track best checkpoint (for load_best_model_at_end)
    - Save standalone model weights (HF Trainer compatibility)
    """

    @abstractmethod
    def save_checkpoint(
        self,
        checkpoint_path: str | None = None,
        checkpoint_id: str | None = None,
    ) -> str:
        """
        Save complete training checkpoint.

        Args:
            checkpoint_path: Specific path for checkpoint, or None for auto-generated
            checkpoint_id: Identifier for checkpoint (e.g., global_step), used if path is None

        Returns:
            Path to saved checkpoint directory
        """
        pass

    @abstractmethod
    def load_checkpoint(self, checkpoint_path: str | None = None) -> None:
        """
        Load checkpoint to resume training.

        Args:
            checkpoint_path: Path to checkpoint, or None to load latest checkpoint
        """
        pass

    @abstractmethod
    def save_model(
        self,
        output_dir: str | os.PathLike | None = None,
        overwrite_output_dir: bool = False,
    ) -> None:
        """
        Save only model weights (not full training state).

        Args:
            output_dir: Directory to save model, or None for default
            overwrite_output_dir: Whether to overwrite existing model
        """
        pass

    @abstractmethod
    def set_best_checkpoint(self, best_checkpoint: str) -> None:
        """
        Mark a checkpoint as the best model.

        Args:
            best_checkpoint: Path to checkpoint to mark as best
        """
        pass

    @abstractmethod
    def resolve_checkpoint_path(self, checkpoint_path: str | None) -> str | None:
        """
        Resolve checkpoint path (e.g., find latest if path is None).

        Args:
            checkpoint_path: Explicit path or None for auto-resolution

        Returns:
            Resolved checkpoint path or None if not found
        """
        pass

save_checkpoint(checkpoint_path=None, checkpoint_id=None) abstractmethod

Save complete training checkpoint.

Args:

- checkpoint_path: Specific path for the checkpoint, or None for an auto-generated path
- checkpoint_id: Identifier for the checkpoint (e.g., global_step), used if checkpoint_path is None

Returns: Path to saved checkpoint directory

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def save_checkpoint(
    self,
    checkpoint_path: str | None = None,
    checkpoint_id: str | None = None,
) -> str:
    """
    Save complete training checkpoint.

    Args:
        checkpoint_path: Specific path for checkpoint, or None for auto-generated
        checkpoint_id: Identifier for checkpoint (e.g., global_step), used if path is None

    Returns:
        Path to saved checkpoint directory
    """
    pass

load_checkpoint(checkpoint_path=None) abstractmethod

Load checkpoint to resume training.

Args: checkpoint_path: Path to checkpoint, or None to load latest checkpoint

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def load_checkpoint(self, checkpoint_path: str | None = None) -> None:
    """
    Load checkpoint to resume training.

    Args:
        checkpoint_path: Path to checkpoint, or None to load latest checkpoint
    """
    pass

save_model(output_dir=None, overwrite_output_dir=False) abstractmethod

Save only model weights (not full training state).

Args:

- output_dir: Directory to save the model to, or None for the default
- overwrite_output_dir: Whether to overwrite an existing model

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def save_model(
    self,
    output_dir: str | os.PathLike | None = None,
    overwrite_output_dir: bool = False,
) -> None:
    """
    Save only model weights (not full training state).

    Args:
        output_dir: Directory to save model, or None for default
        overwrite_output_dir: Whether to overwrite existing model
    """
    pass

set_best_checkpoint(best_checkpoint) abstractmethod

Mark a checkpoint as the best model.

Args: best_checkpoint: Path to checkpoint to mark as best

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def set_best_checkpoint(self, best_checkpoint: str) -> None:
    """
    Record which checkpoint is the best model so far.

    Args:
        best_checkpoint: Path of the checkpoint to mark as best
    """
    ...

resolve_checkpoint_path(checkpoint_path) abstractmethod

Resolve checkpoint path (e.g., find latest if path is None).

Args: checkpoint_path: Explicit path or None for auto-resolution

Returns: Resolved checkpoint path or None if not found

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def resolve_checkpoint_path(self, checkpoint_path: str | None) -> str | None:
    """
    Resolve checkpoint path (e.g., find latest if path is None).

    Args:
        checkpoint_path: Explicit path or None for auto-resolution

    Returns:
        Resolved checkpoint path or None if not found
    """
    pass

forgather.ml.trainer.trainer_types.StatefulProvider

Bases: Protocol

Protocol for providing stateful objects for checkpointing.

Used by checkpoint managers to collect all components that need to be saved/restored during checkpointing (optimizer, scheduler, dataset, etc.).

The protocol uses StateComponents which declare explicit sharing patterns (GLOBAL, PER_RANK, REPLICATED, etc.) to enable automatic distributed checkpoint coordination for hybrid parallelism strategies.

All implementations must provide:

- get_state_components(): returns a list of StateComponents with their sharing patterns
- get_process_groups(): returns named process groups (only needed when using the PER_GROUP pattern)

Source code in src/forgather/ml/trainer/trainer_types.py
class StatefulProvider(Protocol):
    """
    Structural interface for objects that expose their checkpointable state.

    Checkpoint managers query a provider to discover every component that
    must be saved and restored (optimizer, scheduler, dataset, etc.).

    Each component is described by a StateComponent carrying an explicit
    sharing pattern (GLOBAL, PER_RANK, REPLICATED, etc.). This lets the
    checkpoint machinery coordinate distributed saves and loads
    automatically for hybrid parallelism strategies.

    Implementations provide:
    - get_state_components(): the StateComponents and their sharing patterns
    - get_process_groups(): named process groups (needed only for PER_GROUP)
    """

    @abstractmethod
    def get_state_components(self) -> List["StateComponent"]:  # type: ignore
        """
        Describe every piece of checkpointable state and how it is shared.

        This is the preferred API for checkpoint coordination. Because each
        StateComponent declares its sharing pattern (GLOBAL, PER_RANK,
        REPLICATED, etc.), distributed checkpointing is coordinated without
        manual rank checks.

        Returns:
            List of StateComponent objects covering all checkpointable state

        Example implementation for single-GPU trainer:
            def get_state_components(self):
                from forgather.ml.trainer.checkpoint_types import StateComponent, SharingPattern

                return [
                    StateComponent(
                        key="model",
                        stateful=self.model,
                        sharing_pattern=SharingPattern.GLOBAL,
                    ),
                    StateComponent(
                        key="optimizer",
                        stateful=self.optimizer,
                        sharing_pattern=SharingPattern.GLOBAL,
                    ),
                    StateComponent(
                        key="scheduler",
                        stateful=self.lr_scheduler,
                        sharing_pattern=SharingPattern.GLOBAL,
                    ),
                    StateComponent(
                        key="dataset",
                        stateful=self.train_dataloader,
                        sharing_pattern=self._get_dataset_sharing_pattern(),
                    ),
                    StateComponent(
                        key="rng",
                        stateful=RNGState(),
                        sharing_pattern=SharingPattern.PER_RANK,
                    ),
                ]

        Example for DDP trainer:
            def get_state_components(self):
                return [
                    StateComponent(
                        key="model",
                        stateful=self.unwrapped_model(),
                        sharing_pattern=SharingPattern.REPLICATED,
                        validate_replication=True,  # Verify DDP synchronization
                    ),
                    # ... other components
                ]

        See: docs/checkpointing/migration_guide.md for full migration guide
        """
        ...

    def get_process_groups(self) -> Dict[str, Any]:
        """
        Map process-group names to ProcessGroup objects for PER_GROUP sharing.

        Override only when some state component uses the PER_GROUP sharing
        pattern; this default returns an empty dict.

        Returns:
            Dict from process group name to ProcessGroup object
            (e.g., {"dp_group": dp_pg, "pp_group": pp_pg})

        Example:
            def get_process_groups(self):
                return {
                    "dp_group": self.dp_process_group,
                    "pp_group": self.pp_process_group,
                }
        """
        return dict()

get_state_components() abstractmethod

Get state components with explicit sharing patterns for distributed checkpointing.

This is the new preferred API for checkpoint coordination. Each StateComponent declares its sharing pattern (GLOBAL, PER_RANK, REPLICATED, etc.), enabling automatic distributed checkpoint coordination without manual rank checks.

Returns: List of StateComponent objects describing all checkpointable state

Example implementation for single-GPU trainer:

    def get_state_components(self):
        from forgather.ml.trainer.checkpoint_types import StateComponent, SharingPattern

        return [
            StateComponent(
                key="model",
                stateful=self.model,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.GLOBAL,
            ),
            StateComponent(
                key="dataset",
                stateful=self.train_dataloader,
                sharing_pattern=self._get_dataset_sharing_pattern(),
            ),
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
            ),
        ]

Example for DDP trainer:

    def get_state_components(self):
        return [
            StateComponent(
                key="model",
                stateful=self.unwrapped_model(),
                sharing_pattern=SharingPattern.REPLICATED,
                validate_replication=True,  # Verify DDP synchronization
            ),
            # ... other components
        ]

See: docs/checkpointing/migration_guide.md for full migration guide

Source code in src/forgather/ml/trainer/trainer_types.py
@abstractmethod
def get_state_components(self) -> List["StateComponent"]:  # type: ignore
    """
    Describe every piece of checkpointable state and how it is shared.

    This is the preferred API for checkpoint coordination. Because each
    StateComponent declares its sharing pattern (GLOBAL, PER_RANK,
    REPLICATED, etc.), distributed checkpointing is coordinated without
    manual rank checks.

    Returns:
        List of StateComponent objects covering all checkpointable state

    Example implementation for single-GPU trainer:
        def get_state_components(self):
            from forgather.ml.trainer.checkpoint_types import StateComponent, SharingPattern

            return [
                StateComponent(
                    key="model",
                    stateful=self.model,
                    sharing_pattern=SharingPattern.GLOBAL,
                ),
                StateComponent(
                    key="optimizer",
                    stateful=self.optimizer,
                    sharing_pattern=SharingPattern.GLOBAL,
                ),
                StateComponent(
                    key="scheduler",
                    stateful=self.lr_scheduler,
                    sharing_pattern=SharingPattern.GLOBAL,
                ),
                StateComponent(
                    key="dataset",
                    stateful=self.train_dataloader,
                    sharing_pattern=self._get_dataset_sharing_pattern(),
                ),
                StateComponent(
                    key="rng",
                    stateful=RNGState(),
                    sharing_pattern=SharingPattern.PER_RANK,
                ),
            ]

    Example for DDP trainer:
        def get_state_components(self):
            return [
                StateComponent(
                    key="model",
                    stateful=self.unwrapped_model(),
                    sharing_pattern=SharingPattern.REPLICATED,
                    validate_replication=True,  # Verify DDP synchronization
                ),
                # ... other components
            ]

    See: docs/checkpointing/migration_guide.md for full migration guide
    """
    ...

get_process_groups()

Get named process groups for PER_GROUP sharing pattern.

Returns dictionary mapping group names to ProcessGroup objects. Only needed if using PER_GROUP sharing pattern in state components.

Returns: Dictionary mapping process group names to ProcessGroup objects (e.g., {"dp_group": dp_pg, "pp_group": pp_pg})

Example:

    def get_process_groups(self):
        return {
            "dp_group": self.dp_process_group,
            "pp_group": self.pp_process_group,
        }

Source code in src/forgather/ml/trainer/trainer_types.py
def get_process_groups(self) -> Dict[str, Any]:
    """
    Map process-group names to ProcessGroup objects for PER_GROUP sharing.

    Override only when some state component uses the PER_GROUP sharing
    pattern; this default returns an empty dict.

    Returns:
        Dict from process group name to ProcessGroup object
        (e.g., {"dp_group": dp_pg, "pp_group": pp_pg})

    Example:
        def get_process_groups(self):
            return {
                "dp_group": self.dp_process_group,
                "pp_group": self.pp_process_group,
            }
    """
    return dict()