
Trainers

Forgather provides a hierarchy of trainer classes for single-GPU through multi-node distributed training.

Trainer          Use case
Trainer          Single-GPU, the fast path for small-model experiments
DDPTrainer       Multi-GPU DistributedDataParallel, with optional PostLocalSGD
FSDP2Trainer     FSDP-2 sharded data parallel with CPU offload support
PipelineTrainer  Pipeline parallelism for bandwidth-limited environments

For a complete reference of every training argument and constructor parameter across all trainers, see also:


Core Types

Shared types and protocols used across all trainers.

forgather.ml.distributed.DistributedEnvironment

Bases: DistributedEnvInterface

Initialize and manage the PyTorch distributed training environment.

This class handles the complete setup of distributed training, including:
- Synchronizing with environment variables set by launchers (torchrun, etc.)
- Setting up the appropriate device (GPU or CPU)
- Initializing the torch.distributed process group

The distributed environment must be initialized before any torch.distributed calls can be made. In forgather configurations, this is typically included as an early dependency to ensure proper initialization order.

Environment Variable Behavior:
- If environment variables are set (e.g., by torchrun), they override the values passed to __init__
- If environment variables are not set, this class exports the __init__ values to the environment for consistency
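The bidirectional sync above is performed by init_from_env(), whose source is not shown on this page. The sketch below restates the rule in plain Python under stated assumptions: the helper sync_with_env, the EnvConfig stand-in and the exact variable list are hypothetical (the names are the ones torchrun conventionally exports), and forgather's real implementation may differ.

```python
import os
from dataclasses import dataclass

# Hypothetical attribute -> environment variable mapping (torchrun's names).
ENV_VARS = {
    "rank": "RANK",
    "local_rank": "LOCAL_RANK",
    "world_size": "WORLD_SIZE",
    "local_world_size": "LOCAL_WORLD_SIZE",
    "master_addr": "MASTER_ADDR",
    "master_port": "MASTER_PORT",
}


@dataclass
class EnvConfig:
    # Stand-in for the DistributedEnvironment attributes, same defaults.
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1
    local_world_size: int = 1
    master_addr: str = "localhost"
    master_port: int = 29501


def sync_with_env(cfg, environ=os.environ):
    """Env vars, when present, override cfg; otherwise cfg values are exported."""
    for attr, var in ENV_VARS.items():
        if var in environ:
            current = getattr(cfg, attr)
            # Convert the env string back to the attribute's type (int or str).
            setattr(cfg, attr, type(current)(environ[var]))
        else:
            environ[var] = str(getattr(cfg, attr))
    return cfg
```

Passing a plain dict instead of os.environ makes the behavior easy to inspect: with {"RANK": "3", "WORLD_SIZE": "8"}, rank and world_size are taken from the dict, and MASTER_PORT="29501" is written back into it.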

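Device Selection:
- With an accelerator available and no_accelerator=False: the accelerator is used; if backend is None, the device's default backend is selected (nccl for CUDA GPUs)
- With no_accelerator=True or no accelerator available: uses the CPU with the gloo backend
- The device index comes from local_rank, or from device_map (keyed by global rank) when one is given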
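The device-selection rules can be restated as a pure function that mirrors the accelerator branch of _init_distributed() in the source below. This is a sketch, not forgather's API: select_device and its parameters are hypothetical, and accelerator_available, accelerator_type and default_backend stand in for what torch.accelerator and dist.get_default_backend_for_device() report at runtime, so the logic can be read and tested without a GPU.

```python
def select_device(use_accelerator, accelerator_available, accelerator_type,
                  rank, local_rank, device_map=None, backend=None,
                  default_backend="nccl"):
    """Return (device, device_type, backend) following the rules above."""
    if use_accelerator and accelerator_available:
        # device_map is indexed by *global* rank; otherwise use the local rank.
        index = device_map[rank] if device_map else local_rank
        if backend is None:
            # Stand-in for dist.get_default_backend_for_device(acc).
            backend = default_backend
        return f"{accelerator_type}:{index}", accelerator_type, backend
    # CPU fallback: the source leaves backend untouched here; with backend=None,
    # torch.distributed falls back to its own default, which is gloo for CPU.
    return "cpu", "cpu", backend
```

Note that device_map is looked up with the global rank, which matters on multi-node jobs where local ranks repeat across hosts.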

Attributes:
- rank: Global rank of this process
- local_rank: Rank within the local node
- world_size: Total number of processes
- local_world_size: Number of processes on this node
- master_addr: Address of rank 0 for rendezvous
- master_port: Port for rendezvous
- backend: Distributed backend ("nccl", "gloo", etc.)
- device: Device string for this rank (e.g., "cuda:0", "cpu")
- device_type: Device type string (e.g., "cuda", "cpu")
- use_accelerator: Whether to use GPU acceleration

Example, in a forgather YAML configuration:

    distributed_env: &distributed_env !singleton:forgather.ml.distributed:DistributedEnvironment
        backend: "nccl"

For CPU-only testing:

    distributed_env: &distributed_env !singleton:forgather.ml.distributed:DistributedEnvironment
        no_accelerator: True

Source code in src/forgather/ml/distributed.py
class DistributedEnvironment(DistributedEnvInterface):
    """
    Initialize and manage the PyTorch distributed training environment.

    This class handles the complete setup of distributed training, including:
    - Synchronizing with environment variables set by launchers (torchrun, etc.)
    - Setting up the appropriate device (GPU or CPU)
    - Initializing the torch.distributed process group

    The distributed environment must be initialized before any torch.distributed
    calls can be made. In forgather configurations, this is typically included
    as an early dependency to ensure proper initialization order.

    Environment Variable Behavior:
        - If environment variables are set (e.g., by torchrun), they override
          the values passed to __init__
        - If environment variables are not set, this class exports the __init__
          values to the environment for consistency

    Device Selection:
        - With GPU available and no_accelerator=False: Uses GPU with nccl backend
        - With no_accelerator=True or no GPU: Uses CPU with gloo backend
        - Device is automatically assigned based on local_rank (or device_map)

    Attributes:
        rank: Global rank of this process
        local_rank: Rank within the local node
        world_size: Total number of processes
        local_world_size: Number of processes on this node
        master_addr: Address of rank 0 for rendezvous
        master_port: Port for rendezvous
        backend: Distributed backend ("nccl", "gloo", etc.)
        device: Device string for this rank (e.g., "cuda:0", "cpu")
        device_type: Device type string (e.g., "cuda", "cpu")
        use_accelerator: Whether to use GPU acceleration

    Example:
        In a forgather YAML configuration::

            distributed_env: &distributed_env !singleton:forgather.ml.distributed:DistributedEnvironment
                backend: "nccl"

        For CPU-only testing::

            distributed_env: &distributed_env !singleton:forgather.ml.distributed:DistributedEnvironment
                no_accelerator: True
    """

    def __init__(
        self,
        rank: int = 0,
        local_rank: int = 0,
        world_size: int = 1,
        local_world_size: int = 1,
        master_addr: str = "localhost",
        master_port: int = 29501,
        backend: str | None = None,
        log_level="INFO",
        device_map=None,
        always: bool = True,
        no_accelerator: bool = False,
    ):
        """
        Initialize the distributed environment.

        Parameters
        ----------
        rank : int, optional
            Global rank. Default 0, typically overridden by launcher environment.
        local_rank : int, optional
            Local rank within node. Default 0.
        world_size : int, optional
            Total number of processes. Default 1.
        local_world_size : int, optional
            Number of processes per node. Default 1.
        master_addr : str, optional
            Rendezvous address for rank 0. Default "localhost".
        master_port : int, optional
            Rendezvous port. Default 29501.
        backend : str or None, optional
            Distributed backend. If None, auto-selected based on device
            (nccl for GPU, gloo for CPU).
        log_level : str, optional
            Logging level for the distributed module. Default "INFO".
        device_map : dict or None, optional
            Mapping from rank to device index for custom device assignment.
            If None, uses local_rank.
        always : bool, optional
            If True, initialize distributed even for single process.
            Useful for consistent behavior across configurations. Default True.
        no_accelerator : bool, optional
            If True, force CPU execution even if GPU is available.
            Useful for testing distributed configurations without GPUs. Default False.
        """
        logger.setLevel(log_level)
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.local_world_size = local_world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.backend = backend
        self.always = always
        self.device_map = device_map
        self.use_accelerator = not no_accelerator
        self._init_distributed()

    def __repr__(self):
        # ``hostname`` makes asymmetric multi-node hangs much easier to
        # diagnose: when one rank's ``DistributedEnvironment(...)`` line
        # shows ``host=muthur`` and the others show ``host=wopr``, the
        # operator can correlate the deadlock site with the topology
        # at a glance instead of cross-referencing local_rank /
        # local_world_size with cluster_jobs.
        import socket

        return (
            f"{type(self).__name__}("
            f"rank={self.rank}, "
            f"local_rank={self.local_rank}, "
            f"world_size={self.world_size}, "
            f"local_world_size={self.local_world_size}, "
            f"host={socket.gethostname()}, "
            f"master_addr={self.master_addr}, "
            f"master_port={self.master_port}, "
            f"backend={self.backend})"
        )

    def _init_distributed(self):
        """
        Initialize distributed training: sync env vars, setup device, init process group.

        This method is called automatically during __init__ and performs:
        1. Bidirectional sync with environment variables
        2. Device selection and configuration
        3. Process group initialization (if not already done)
        """
        init_from_env(self)
        logger.info(str(self))

        # Device setup: use accelerator if available and not disabled
        if self.use_accelerator and accelerator.is_available():
            if self.device_map:
                accelerator.set_device_index(self.device_map[self.rank])
            else:
                accelerator.set_device_index(self.local_rank)
            acc = accelerator.current_accelerator()
            assert (
                acc is not None
            ), "accelerator.current_accelerator() returned None despite is_available() being True"
            if self.backend is None:
                self.backend = dist.get_default_backend_for_device(acc)
            idx = accelerator.current_device_index()
            self.device_type = acc.type
            self.device = f"{self.device_type}:{idx}"
        else:
            # CPU fallback
            self.device = "cpu"
            self.device_type = "cpu"

        # Process group initialization
        if dist.is_available() and (self.world_size > 1 or self.always):
            if not dist.is_initialized():
                self._init_process_group()
            else:
                logger.warning("torch distributed has already been initialized")
        else:
            assert (
                self.world_size == 1
            ), "World size is larger than 1 and torch distributed is not available."

    def _init_process_group(self):
        """Initialize the torch.distributed process group with the configured backend."""
        logger.info(f"init_process_group(backend={self.backend!r}, device={self.device!r})")
        dist.init_process_group(backend=self.backend)

__init__(rank=0, local_rank=0, world_size=1, local_world_size=1, master_addr='localhost', master_port=29501, backend=None, log_level='INFO', device_map=None, always=True, no_accelerator=False)

Initialize the distributed environment.

Parameters:

Name              Type          Default      Description
rank              int           0            Global rank; typically overridden by the launcher environment.
local_rank        int           0            Local rank within the node.
world_size        int           1            Total number of processes.
local_world_size  int           1            Number of processes per node.
master_addr       str           'localhost'  Rendezvous address of rank 0.
master_port       int           29501        Rendezvous port.
backend           str or None   None         Distributed backend. If None, auto-selected based on the device (nccl for GPU, gloo for CPU).
log_level         str           'INFO'       Logging level for the distributed module.
device_map        dict or None  None         Mapping from global rank to device index for custom device assignment. If None, local_rank is used.
always            bool          True         If True, initialize torch.distributed even for a single process, for consistent behavior across configurations.
no_accelerator    bool          False        If True, force CPU execution even if a GPU is available. Useful for testing distributed configurations without GPUs.
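The always flag interacts with world_size and the availability of torch.distributed exactly as in the process-group branch of _init_distributed() shown above. The hypothetical helper below restates that branch as a pure function so the cases are easy to compare; the name is ours, and it raises RuntimeError where the source uses an assert.

```python
def should_init_process_group(dist_available, world_size, always=True,
                              already_initialized=False):
    """Return True when _init_distributed() would call init_process_group()."""
    if dist_available and (world_size > 1 or always):
        # An existing process group is reused; the source only logs a warning.
        return not already_initialized
    # Without torch.distributed (or with always=False on a single process),
    # only a world size of 1 is valid.
    if world_size != 1:
        raise RuntimeError(
            "World size is larger than 1 and torch distributed is not available."
        )
    return False
```

With the default always=True, even a single-process run creates a process group, so collective calls behave the same way whether or not the job was launched with torchrun.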

forgather.ml.trainer.trainer_types.MinimalTrainingArguments dataclass

Minimal training configuration compatible with HuggingFace Trainer.

Provides a subset of transformers.TrainingArguments sufficient for basic training. This is the base configuration class; extend it for additional features rather than adding fields here.

Direct subclasses: BaseTrainingArguments (checkpoint control, PyTorch optimisations) and TrainingArguments (simple single-GPU memory options).

Parameters:

Name                           Type           Default         Description
output_dir                     str            OUTPUTDIR_NAME  Directory where model predictions and checkpoints are written.
logging_dir                    str or None    None            TensorBoard log directory. Defaults to output_dir/runs/TIMESTAMP_HOSTNAME.
per_device_train_batch_size    int            16              Training batch size per device. Effective global batch size is per_device_train_batch_size * num_devices * gradient_accumulation_steps.
per_device_eval_batch_size     int            16              Evaluation batch size per device.
num_train_epochs               int            1               Total training epochs (may be fractional, e.g. 2.5).
max_steps                      int            -1              If > 0, total optimiser steps to perform (overrides num_train_epochs).
device                         Any            None            Device to use ("cuda", "cpu", etc.). Auto-detected if None.
seed                           int            -1              Random seed for reproducibility. -1 disables seeding.
use_cpu                        bool           False           Force CPU usage even when CUDA is available.
epoch_train_steps              int            100000          Fallback epoch length when the dataset does not support len(). Used only for progress bookkeeping. Forgather extension.
dataloader_num_workers         int            0               Subprocesses for data loading. 0 loads in the main process.
dataloader_pin_memory          bool           True            Pin memory in DataLoader for faster GPU transfer.
dataloader_persistent_workers  bool           False           Keep worker processes alive between epochs (faster, uses more RAM).
dataloader_prefetch_factor     int or None    None            Batches prefetched per worker. Defaults to 2 when num_workers > 0.
dataloader_drop_last           bool           False           Drop the last incomplete batch when the dataset is not evenly divisible.
eval_strategy                  str            'no'            When to run evaluation: "no", "steps", or "epoch".
eval_steps                     int            100             Evaluation frequency in steps (when eval_strategy="steps").
eval_delay                     int            0               Epochs or steps to wait before the first evaluation.
logging_strategy               str            'steps'         When to log metrics: "no", "steps", or "epoch".
logging_steps                  int            50              Logging frequency in steps (when logging_strategy="steps").
logging_first_step             bool           False           Log metrics at the very first global step.
torch_compile                  bool           False           Compile the model with torch.compile() for speedup.
torch_compile_backend          str or None    None            Backend for torch.compile (e.g. "inductor", "aot_eager").
torch_compile_mode             str or None    'default'       Compilation mode: "default", "reduce-overhead", or "max-autotune".
torch_compile_dynamic          bool           True            Allow dynamic shapes in the compiled model.
torch_compile_full_graph       bool           False           Force compilation of the entire model as a single graph.
max_grad_norm                  float or None  None            Maximum gradient norm for clipping. None disables clipping.
gradient_accumulation_steps    int            1               Accumulate gradients over this many steps before an optimiser update.
save_strategy                  str            'steps'         Checkpoint save strategy: "no", "steps", or "epoch".
save_steps                     int            1000            Checkpoint save frequency in steps (when save_strategy="steps").
save_total_limit               int            2               Maximum number of checkpoints to keep; oldest are deleted first.
save_safetensors               bool           True            Write weights as Safetensors (safer and HF-compatible) rather than pickle.
save_on_each_node              bool           False           In multi-node training, save on every node rather than only rank 0. Do not enable when using shared storage.
overwrite_output_dir           bool           False           Overwrite output_dir contents on a fresh run.
resume_from_checkpoint         bool or str    True            True (default) automatically finds and loads the latest checkpoint, falling back to fresh initialisation if none exists. A path string loads that specific checkpoint. False forces fresh initialisation.
load_best_model_at_end         bool           False           Restore the best checkpoint at the end of training. Requires save_strategy == eval_strategy.
metric_for_best_model          str            'loss'          Metric used to compare checkpoints when load_best_model_at_end=True.
greater_is_better              bool or None   None            Whether a higher metric value is better. Auto-determined from the metric name when None.
lr_scheduler_type              str            'linear'        LR scheduler type for the built-in AdamW path ("linear", "cosine", etc.).
lr_scheduler_kwargs            dict or None   None            Additional keyword arguments forwarded to the LR scheduler constructor.
warmup_steps                   int            0               Linear warmup steps from 0 to learning_rate.
learning_rate                  float          5e-05           Initial learning rate for the built-in AdamW optimiser.
weight_decay                   float          0.0             Weight decay applied to all parameters except bias and LayerNorm weights.
adam_beta1                     float          0.9             Beta1 for the built-in AdamW optimiser.
adam_beta2                     float          0.999           Beta2 for the built-in AdamW optimiser.
adam_epsilon                   float          1e-08           Epsilon for the built-in AdamW optimiser.
gradient_checkpointing         bool           False           Enable activation checkpointing on models that support the HF API. Pass enable_activation_checkpoint_fn to the Trainer constructor to customise the checkpointing behaviour.
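The batch-size formula for per_device_train_batch_size, and the rule that a positive max_steps overrides num_train_epochs, can be checked with a small sketch. Both helper names are hypothetical, and the epochs-to-steps conversion (one optimiser update per gradient_accumulation_steps batches, rounded up over a possibly fractional epoch count) is an assumption modeled on the HuggingFace Trainer, not something this page documents for forgather.

```python
import math


def effective_batch_size(per_device_train_batch_size, num_devices,
                         gradient_accumulation_steps=1):
    # Samples contributing to a single optimiser update across all devices.
    return per_device_train_batch_size * num_devices * gradient_accumulation_steps


def total_optimizer_steps(max_steps, num_train_epochs, batches_per_epoch,
                          gradient_accumulation_steps=1):
    # A positive max_steps takes precedence over num_train_epochs.
    if max_steps > 0:
        return max_steps
    # Assumption (HF-style): one update per gradient_accumulation_steps batches.
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    return math.ceil(num_train_epochs * updates_per_epoch)
```

For example, the default per-device batch of 16 on 8 GPUs with gradient_accumulation_steps=4 gives 16 * 8 * 4 = 512 samples per update.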
Source code in src/forgather/ml/trainer/trainer_types.py
@dataclass(kw_only=True)
class MinimalTrainingArguments:
    """Minimal training configuration compatible with HuggingFace Trainer.

    Provides a subset of ``transformers.TrainingArguments`` sufficient for basic
    training. This is the base configuration class; extend it for additional
    features rather than adding fields here.

    Direct subclasses: ``BaseTrainingArguments`` (checkpoint control, PyTorch
    optimisations) and ``TrainingArguments`` (simple single-GPU memory options).

    Parameters
    ----------
    output_dir : str, optional
        Directory where model predictions and checkpoints are written.
    logging_dir : str or None, optional
        TensorBoard log directory. Defaults to ``output_dir/runs/TIMESTAMP_HOSTNAME``.
    per_device_train_batch_size : int, optional
        Training batch size per device. Effective global batch size is
        ``per_device_train_batch_size * num_devices * gradient_accumulation_steps``.
    per_device_eval_batch_size : int, optional
        Evaluation batch size per device.
    num_train_epochs : int, optional
        Total training epochs (may be fractional, e.g. ``2.5``).
    max_steps : int, optional
        If > 0, total optimiser steps to perform (overrides ``num_train_epochs``).
    device : Any, optional
        Device to use (``"cuda"``, ``"cpu"``, etc.). Auto-detected if ``None``.
    seed : int, optional
        Random seed for reproducibility. Default ``-1`` disables seeding.
    use_cpu : bool, optional
        Force CPU usage even when CUDA is available.
    epoch_train_steps : int, optional
        Fallback epoch length when the dataset does not support ``len()``.
        Used only for progress bookkeeping. Forgather extension.
    dataloader_num_workers : int, optional
        Subprocesses for data loading. ``0`` loads in the main process.
    dataloader_pin_memory : bool, optional
        Pin memory in DataLoader for faster GPU transfer.
    dataloader_persistent_workers : bool, optional
        Keep worker processes alive between epochs (faster, uses more RAM).
    dataloader_prefetch_factor : int or None, optional
        Batches prefetched per worker. Defaults to ``2`` when ``num_workers > 0``.
    dataloader_drop_last : bool, optional
        Drop the last incomplete batch when the dataset is not evenly divisible.
    eval_strategy : str, optional
        When to run evaluation: ``"no"``, ``"steps"``, or ``"epoch"``.
    eval_steps : int, optional
        Evaluation frequency in steps (when ``eval_strategy="steps"``).
    eval_delay : int, optional
        Epochs or steps to wait before the first evaluation.
    logging_strategy : str, optional
        When to log metrics: ``"no"``, ``"steps"``, or ``"epoch"``.
    logging_steps : int, optional
        Logging frequency in steps (when ``logging_strategy="steps"``).
    logging_first_step : bool, optional
        Log metrics at the very first global step.
    torch_compile : bool, optional
        Compile the model with ``torch.compile()`` for speedup.
    torch_compile_backend : str or None, optional
        Backend for ``torch.compile`` (e.g. ``"inductor"``, ``"aot_eager"``).
    torch_compile_mode : str or None, optional
        Compilation mode: ``"default"``, ``"reduce-overhead"``, or ``"max-autotune"``.
    torch_compile_dynamic : bool, optional
        Allow dynamic shapes in the compiled model.
    torch_compile_full_graph : bool, optional
        Force compilation of the entire model as a single graph.
    max_grad_norm : float or None, optional
        Maximum gradient norm for clipping. ``None`` disables clipping.
    gradient_accumulation_steps : int, optional
        Accumulate gradients over this many steps before an optimiser update.
    save_strategy : str, optional
        Checkpoint save strategy: ``"no"``, ``"steps"``, or ``"epoch"``.
    save_steps : int, optional
        Checkpoint save frequency in steps (when ``save_strategy="steps"``).
    save_total_limit : int, optional
        Maximum number of checkpoints to keep; oldest are deleted first.
    save_safetensors : bool, optional
        Write weights as Safetensors (safer and HF-compatible) rather than pickle.
    save_on_each_node : bool, optional
        In multi-node training, save on every node rather than only rank 0.
        Do not enable when using shared storage.
    overwrite_output_dir : bool, optional
        Overwrite ``output_dir`` contents on a fresh run.
    resume_from_checkpoint : bool or str, optional
        ``True`` (default) automatically finds and loads the latest checkpoint,
        falling back to fresh initialisation if none exists. A path string loads
        that specific checkpoint. ``False`` forces fresh initialisation.
    load_best_model_at_end : bool, optional
        Restore the best checkpoint at the end of training. Requires
        ``save_strategy == eval_strategy``.
    metric_for_best_model : str, optional
        Metric used to compare checkpoints when ``load_best_model_at_end=True``.
    greater_is_better : bool or None, optional
        Whether a higher metric value is better. Auto-determined from the metric
        name when ``None``.
    lr_scheduler_type : str, optional
        LR scheduler type for the built-in AdamW path (``"linear"``, ``"cosine"``, etc.).
    lr_scheduler_kwargs : dict or None, optional
        Additional keyword arguments forwarded to the LR scheduler constructor.
    warmup_steps : int, optional
        Linear warmup steps from 0 to ``learning_rate``.
    learning_rate : float, optional
        Initial learning rate for the built-in AdamW optimiser.
    weight_decay : float, optional
        Weight decay applied to all parameters except bias and LayerNorm weights.
    adam_beta1 : float, optional
        Beta1 for the built-in AdamW optimiser.
    adam_beta2 : float, optional
        Beta2 for the built-in AdamW optimiser.
    adam_epsilon : float, optional
        Epsilon for the built-in AdamW optimiser.
    gradient_checkpointing : bool, optional
        Enable activation checkpointing on models that support the HF API.
        Pass ``enable_activation_checkpoint_fn`` to the Trainer constructor to
        customise the checkpointing behaviour.
    """

    output_dir: str = OUTPUTDIR_NAME
    logging_dir: str | None = None
    per_device_eval_batch_size: int = 16
    per_device_train_batch_size: int = 16
    num_train_epochs: int = 1
    device: Any = None

    seed: int = -1
    use_cpu: bool = False

    # Not in HF trainer; number of train-batches in an epoch, when dataset does not support len()
    # This just becomes a relative value for book-keeping.
    epoch_train_steps: int = 100000
    max_steps: int = -1

    dataloader_num_workers: int = 0
    dataloader_pin_memory: bool = True
    dataloader_persistent_workers: bool = False
    dataloader_prefetch_factor: int | None = None
    dataloader_drop_last: bool = False

    # Strategy may also be: "no" | "steps" | "epoch"
    eval_strategy: str = "no"
    eval_steps: int = 100
    eval_delay: int = 0

    logging_strategy: str = "steps"
    logging_steps: int = 50
    logging_first_step: bool = False

    torch_compile: bool = False
    torch_compile_backend: str | None = None
    torch_compile_mode: str | None = "default"
    torch_compile_dynamic: bool = True
    torch_compile_full_graph: bool = False

    max_grad_norm: float | None = None
    gradient_accumulation_steps: int = 1

    # Checkpointing options
    save_strategy: str = "steps"
    save_steps: int = 1000
    save_total_limit: int = 2
    save_safetensors: bool = True
    save_on_each_node: bool = False
    overwrite_output_dir: bool = False
    resume_from_checkpoint: bool | str = True

    # Best model tracking and loading options
    load_best_model_at_end: bool = False
    metric_for_best_model: str = "loss"
    greater_is_better: bool | None = None  # Auto-determined from metric name

    # Compatibility with HF Trainer -- would be better if they took a factory arg...
    lr_scheduler_type: str = "linear"
    lr_scheduler_kwargs: dict | None = None
    warmup_steps: int = 0
    learning_rate: float = 5e-5
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1.0e-8

    # Enable gradient checkpointing (a.k.a activation checkpointing) on models which support the HF API
    gradient_checkpointing: bool = False

    def __str__(self):
        return pformat(self)
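warmup_steps ramps the learning rate linearly from 0 to learning_rate. The sketch below assumes that lr_scheduler_type="linear" behaves like transformers' get_linear_schedule_with_warmup (linear decay to 0 at the final step after warmup); this page does not confirm which scheduler implementation the built-in AdamW path uses, and the function name is hypothetical.

```python
def linear_warmup_decay_lr(step, learning_rate, warmup_steps, total_steps):
    """Learning rate at an optimiser step under an assumed HF-style 'linear' schedule."""
    if step < warmup_steps:
        # Warmup: scale linearly from 0 up to learning_rate.
        return learning_rate * step / max(1, warmup_steps)
    # Decay: scale linearly from learning_rate down to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return learning_rate * remaining / max(1, total_steps - warmup_steps)
```

With learning_rate=1e-3, warmup_steps=100 and 1000 total steps, the rate peaks at step 100 and is back to half its peak at step 550.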

forgather.ml.trainer.trainer_types.TrainerState dataclass

Trainer state tracking training progress and configuration.

Maintains compatibility with HuggingFace Trainer API for easier porting. Passed to callbacks to allow them to inspect and log training progress.

Key training progress fields:
- global_step: Total optimizer updates since start (0-indexed)
- raw_epoch: Integer epoch counter (increments at the end of each dataset iteration)
- epoch_start_step: Global step at which the current epoch started
- epoch: Continuous epoch value = raw_epoch + fractional progress through the current epoch, computed as:
  epoch = raw_epoch + (global_step - epoch_start_step) / epoch_train_steps
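The relationship between global_step, raw_epoch, epoch_start_step and the continuous epoch value can be traced with a small stand-in. EpochCounter is hypothetical and holds only the three TrainerState progress fields involved; the real trainer updates TrainerState itself.

```python
from dataclasses import dataclass


@dataclass
class EpochCounter:
    # Hypothetical stand-in for the TrainerState progress fields.
    global_step: int = 0
    raw_epoch: int = 0
    epoch_start_step: int = 0

    def step(self):
        # One optimiser update.
        self.global_step += 1

    def end_epoch(self):
        # End of a dataset pass: bump the integer epoch and record where
        # the next epoch starts.
        self.raw_epoch += 1
        self.epoch_start_step = self.global_step

    def epoch(self, epoch_train_steps):
        # Continuous epoch = whole epochs + fraction of the current one.
        return self.raw_epoch + (self.global_step - self.epoch_start_step) / epoch_train_steps
```

After a 500-step epoch followed by 250 more steps, raw_epoch is 1, epoch_start_step is 500, and epoch(500) evaluates to 1.5.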

Best model tracking (for load_best_model_at_end):
- best_metric: Best metric value seen during training
- best_model_checkpoint: Path to the checkpoint with the best metric

See: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py
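best_metric and best_model_checkpoint only make sense together with the direction given by greater_is_better, which is auto-determined from metric_for_best_model when None. The sketch below assumes the HuggingFace rule (metrics whose name ends in "loss" are minimised) and shows one way the two fields might be updated after each evaluation; both function names and the checkpoint paths are hypothetical.

```python
def resolve_greater_is_better(metric_for_best_model, greater_is_better=None):
    # An explicit setting wins; otherwise assume "...loss" metrics are minimised.
    if greater_is_better is not None:
        return greater_is_better
    return not metric_for_best_model.endswith("loss")


def update_best(best_metric, best_model_checkpoint, metric_value,
                checkpoint_path, greater_is_better):
    """Return the (best_metric, best_model_checkpoint) pair after one evaluation."""
    improved = best_metric is None or (
        metric_value > best_metric if greater_is_better else metric_value < best_metric
    )
    if improved:
        return metric_value, checkpoint_path
    return best_metric, best_model_checkpoint
```

With the default metric_for_best_model="loss", lower values win, so a later checkpoint only replaces best_model_checkpoint when its loss is strictly smaller.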

Source code in src/forgather/ml/trainer/trainer_types.py
@dataclass(kw_only=True)
class TrainerState:
    """
    Trainer state tracking training progress and configuration.

    Maintains compatibility with HuggingFace Trainer API for easier porting.
    Passed to callbacks to allow them to inspect and log training progress.

    Key training progress fields:
    - global_step: Total optimizer updates since start (0-indexed)
    - raw_epoch: Integer epoch counter (increments at end of each dataset iteration)
    - epoch_start_step: Global step when current epoch started
    - epoch: Continuous epoch value = raw_epoch + fractional progress through current epoch
              Computed as: epoch = raw_epoch + (global_step - epoch_start_step) / epoch_train_steps

    Best model tracking (for load_best_model_at_end):
    - best_metric: Best metric value seen during training
    - best_model_checkpoint: Path to checkpoint with best metric

    See: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py
    """

    logging_steps: int  # How often to log metrics (in steps)
    eval_steps: int  # How often to run evaluation (in steps)
    train_batch_size: int  # Per-device training batch size
    max_steps: int  # Total optimizer updates planned
    epoch: float = 0.0  # Continuous epoch value (integer + fractional progress)
    global_step: int = 0  # Total optimizer updates completed (0-indexed)
    num_train_epochs: int  # Total epochs to train
    is_local_process_zero: bool = True  # True if rank 0 on this node
    is_world_process_zero: bool = True  # True if global rank 0
    log_history: list[Dict[str, float]] = field(
        default_factory=lambda: []
    )  # All logged metrics
    save_steps: int = 0  # How often to save checkpoints (in steps)
    best_metric: float | None = None  # Best metric value (for load_best_model_at_end)
    best_model_checkpoint: str | None = None  # Path to best checkpoint
    # HF compatibility fields (not fully implemented in all trainers)
    num_input_tokens_seen: int = 0  # Total input tokens processed
    total_flos: float = 0.0  # Total floating point operations
    is_hyper_param_search: bool = False  # Whether in hyperparameter search
    stateful_callbacks: List["TrainerCallback"] = field(default_factory=lambda: [])

    # Forgather extensions (not in HF Trainer)
    max_eval_steps: int  # Maximum eval steps to run (-1 for unlimited)
    epoch_start_step: int = 0  # Global step when current epoch started
    raw_epoch: int = 0  # Integer epoch counter (increments at end of dataset iteration)

forgather.ml.trainer.trainer_types.TrainerControl dataclass

Control flags for trainer execution flow.

Callbacks can return a modified TrainerControl to influence trainer behavior:
- Trigger checkpointing: set should_save = True
- Trigger evaluation: set should_evaluate = True
- Trigger logging: set should_log = True
- Stop training gracefully: set should_training_stop = True
- Stop the current epoch: set should_epoch_stop = True
- Abort without saving: set should_abort_without_save = True

Compatible with HuggingFace Trainer API for easier callback porting.

Example callback usage:

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 1000 == 0:
            control.should_save = True  # Force a checkpoint every 1000 steps
        return control

Source code in src/forgather/ml/trainer/trainer_types.py
@dataclass(slots=True)
class TrainerControl:
    """
    Control flags for trainer execution flow.

    Callbacks can return a modified TrainerControl to influence trainer behavior:
    - Trigger checkpointing: Set should_save = True
    - Trigger evaluation: Set should_evaluate = True
    - Trigger logging: Set should_log = True
    - Stop training gracefully: Set should_training_stop = True
    - Stop current epoch: Set should_epoch_stop = True
    - Abort without saving: Set should_abort_without_save = True

    Compatible with HuggingFace Trainer API for easier callback porting.

    Example callback usage:
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 1000 == 0:
                control.should_save = True  # Force checkpoint every 1000 steps
            return control
    """

    should_training_stop: bool = False  # Stop training loop after current step
    should_epoch_stop: bool = False  # Stop current epoch after current step
    should_save: bool = False  # Trigger checkpoint save
    should_evaluate: bool = False  # Trigger evaluation
    should_log: bool = False  # Trigger metric logging

    # Forgather extension: abort without saving checkpoint
    should_abort_without_save: bool = False  # Abort training immediately without saving
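The flag protocol above can be exercised without a full trainer. The following is a minimal sketch, not Forgather code: it uses a local stand-in that copies the TrainerControl fields listed above, a hypothetical State holding only global_step, a hypothetical LossThresholdCallback, and a toy loop that acts on the flags. The callback mutates the control it receives and returns it, the same pattern as the example above.

```python
from dataclasses import dataclass


@dataclass
class TrainerControl:
    # Local stand-in copying the fields of forgather's TrainerControl.
    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False
    should_abort_without_save: bool = False


@dataclass
class State:
    # Hypothetical minimal state: only the field this sketch reads.
    global_step: int = 0


class LossThresholdCallback:
    """Hypothetical callback: checkpoint periodically, stop once loss is low."""

    def __init__(self, save_every: int, target_loss: float):
        self.save_every = save_every
        self.target_loss = target_loss

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.save_every == 0:
            control.should_save = True  # request a checkpoint at this step
        return control

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and logs.get("loss", float("inf")) < self.target_loss:
            control.should_training_stop = True  # graceful stop, saving allowed
        return control


def toy_loop(losses, callback):
    """Toy loop: fires events, acts on flags, stops when asked."""
    state, control = State(), TrainerControl()
    saved_at = []
    for loss in losses:
        state.global_step += 1
        control = callback.on_step_end(None, state, control)
        control = callback.on_log(None, state, control, logs={"loss": loss})
        if control.should_save:
            saved_at.append(state.global_step)
            control.should_save = False  # this toy loop clears the flag after acting
        if control.should_training_stop:
            break
    return state.global_step, saved_at


steps, saved = toy_loop([2.0, 1.5, 1.0, 0.4, 0.3], LossThresholdCallback(2, 0.5))
print(steps, saved)  # 4 [2, 4]
```

In the real trainer, handlers receive additional keyword arguments (model, optimizer, trainer, and so on); the `**kwargs` in each handler absorbs them.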

forgather.ml.trainer.trainer_types.TrainOutput

Bases: NamedTuple

Source code in src/forgather/ml/trainer/trainer_types.py
class TrainOutput(NamedTuple):
    global_step: int
    metrics: Dict[str, float]
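TrainOutput is a plain NamedTuple with two fields, so the value returned by train() can be read by attribute or unpacked positionally. A minimal sketch using a local copy of the class; the metric name and values are invented for illustration:

```python
from typing import Dict, NamedTuple


class TrainOutput(NamedTuple):
    # Same two fields as forgather's TrainOutput shown above.
    global_step: int
    metrics: Dict[str, float]


# Hypothetical value standing in for the result of trainer.train().
output = TrainOutput(global_step=1200, metrics={"train_loss": 0.42})

step, metrics = output  # positional unpacking
assert output.global_step == step  # attribute access gives the same value
print(output._asdict())  # {'global_step': 1200, 'metrics': {'train_loss': 0.42}}
```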

Base Classes

Abstract base from which all concrete trainers derive. Implement these three methods to build a custom trainer: _prepare, _train_loop, _eval_loop.

forgather.ml.trainer.base_trainer.BaseTrainer

Bases: ExtensibleTrainer, Stateful, StatefulProvider, Generic[TBaseTrainingArguments]

Abstract base class implementing common trainer infrastructure.

Provides callback management, checkpoint coordination, training-state tracking, and the PyTorch Stateful interface. The actual training and evaluation loops are left abstract so that concrete subclasses (Trainer, AccelTrainer, PipelineTrainer) can implement them with their own parallelism strategy.

This class intentionally mirrors the public surface of transformers.Trainer to make porting existing training scripts straightforward.

Parameters (all parameters after model are keyword-only):

  • args (TBaseTrainingArguments, required): Training configuration dataclass. Must be an instance of BaseTrainingArguments or one of its subclasses.
  • model (Module or None, default None): Pre-constructed model. Either model or model_init must be provided.
  • data_collator (DataCollatorT or None, default None): Callable that collates a list of dataset examples into a batch dict.
  • train_dataset (IterableDatasetT or None, default None): Training dataset (torch.utils.data.Dataset or any iterable).
  • eval_dataset (IterableDatasetT or None, default None): Evaluation dataset.
  • processing_class (PreprocessingClassT or None, default None): Tokenizer or feature extractor saved alongside model weights.
  • model_init (Callable[[], Module] or None, default None): Zero-argument factory that constructs the model. Required when model is None (e.g. for pipeline training, where construction must happen inside the trainer).
  • callbacks (list of TrainerCallback or None, default None): Callbacks to install. When None, default_callbacks() is used.
  • compute_loss_func (LossFunctionT or None, default None): Custom loss function. When None, the trainer computes cross-entropy from model logits.

Attributes:

  • state (TrainerState): Mutable training progress state (global step, epoch, log history, etc.).
  • control (TrainerControl): Mutable flags set by callbacks to signal save/eval/stop actions.
  • checkpoint_manager (CheckpointInterface or None): Set by _prepare() before the training loop starts.
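The control attribute is updated through the trainer's event dispatch. The sketch below is a simplified, hypothetical MiniDispatcher that mirrors two ideas from BaseTrainer._dispatch_event in the source listing: a per-event handler index that is built lazily and cleared when callbacks change, and recording the first callback that flips the stop flag. It omits the reconciliation between callbacks that return a new control object and those that mutate it in place, and it identifies callbacks by class name rather than by a name attribute.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Control:
    # Hypothetical stand-in holding only the stop flag.
    should_training_stop: bool = False


class MiniDispatcher:
    """Hypothetical, simplified version of the dispatch pattern."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.control = Control()
        self._event_index: Dict[str, List[object]] = {}
        self.stop_requested_by: Optional[Tuple[str, str]] = None

    def add_callback(self, callback) -> None:
        self.callbacks.append(callback)
        self._event_index.clear()  # handler lists may have changed

    def dispatch(self, event: str, **kwargs) -> Control:
        handlers = self._event_index.get(event)
        if handlers is None:
            # First dispatch of this event: keep only callbacks that define it.
            handlers = [cb for cb in self.callbacks if callable(getattr(cb, event, None))]
            self._event_index[event] = handlers
        for cb in handlers:
            was_stopping = self.control.should_training_stop
            getattr(cb, event)(control=self.control, **kwargs)
            # Attribute the stop to the first callback that flipped the flag.
            if self.control.should_training_stop and not was_stopping:
                if self.stop_requested_by is None:
                    self.stop_requested_by = (type(cb).__name__, event)
        return self.control


class HistoryLogger:
    def __init__(self):
        self.seen = []

    def on_log(self, control, logs=None, **kwargs):
        self.seen.append(logs)


class StopOnNaN:
    def on_log(self, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss", 0.0)
        if loss != loss:  # NaN is the only float not equal to itself
            control.should_training_stop = True


class BeginOnly:
    def on_train_begin(self, control, **kwargs):
        pass


d = MiniDispatcher([HistoryLogger(), BeginOnly(), StopOnNaN()])
d.dispatch("on_log", logs={"loss": 1.0})
d.dispatch("on_log", logs={"loss": float("nan")})
print(len(d._event_index["on_log"]))  # 2: BeginOnly has no on_log handler
print(d.stop_requested_by)  # ('StopOnNaN', 'on_log')
```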

Raises:

  • AssertionError: If neither model nor model_init is provided.
  • AssertionError: If args.gradient_accumulation_steps is not greater than 0.
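The constructor contract described by the parameters and checks above can be summarized in a small sketch. resolve_trainer_inputs and TinyModel are hypothetical names, not part of Forgather. The sketch compares against None rather than relying on truthiness, because container-like modules such as an empty torch.nn.Sequential define __len__ and therefore evaluate as False. It also shows that callbacks=None selects the defaults while an explicit empty list installs nothing, matching the `if callbacks is None` branch in the source listing. As a design choice it raises ValueError instead of using assert, so the checks still run under `python -O`.

```python
from typing import Callable, List, Optional, Tuple


class TinyModel:
    """Hypothetical container-like model: falsy when it has no layers,
    just as an empty torch.nn.Sequential has len() == 0."""

    def __init__(self, layers=None):
        self.layers = list(layers or [])

    def __len__(self) -> int:
        return len(self.layers)


def resolve_trainer_inputs(
    model: Optional[object],
    model_init: Optional[Callable[[], object]],
    callbacks: Optional[List[object]],
    default_callbacks: Callable[[], List[object]],
    gradient_accumulation_steps: int = 1,
) -> Tuple[Optional[object], Optional[Callable[[], object]], List[object]]:
    """Hypothetical sketch of the BaseTrainer constructor contract."""
    # Test identity with None: `model or model_init` would reject an empty
    # TinyModel (or an empty nn.Sequential) even though one was supplied.
    if model is None and model_init is None:
        raise ValueError("Either a model or a model constructor must be specified")
    if gradient_accumulation_steps <= 0:
        raise ValueError("gradient_accumulation_steps must be > 0")
    # None selects the class defaults; an explicit [] installs no callbacks.
    installed = default_callbacks() if callbacks is None else callbacks
    return model, model_init, installed


def defaults() -> List[object]:
    # Placeholder strings stand in for callback instances.
    return ["ProgressCallback"]


print(bool(TinyModel()))  # False, yet it is still accepted below
print(resolve_trainer_inputs(TinyModel(), None, None, defaults)[2])  # ['ProgressCallback']
print(resolve_trainer_inputs(TinyModel(), None, [], defaults)[2])  # []
```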

Notes

Concrete subclasses must implement three abstract methods:

  • _prepare(train_dataset, eval_dataset) — set up dataloaders, model, optimizer, and checkpoint manager.
  • _train_loop() — the main training iteration loop, returning TrainOutput.
  • _eval_loop() — the evaluation loop, returning a metrics dict.
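The three-method contract can be illustrated with a self-contained toy. SketchBaseTrainer below is a hypothetical stand-in for BaseTrainer's template-method flow (it omits callbacks, checkpointing, and the SDPA and activation-offloading context managers), and LineFitTrainer fits y = w * x with plain-Python gradient descent in place of a real model, optimizer, and dataloader.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Tuple


class TrainOutput(NamedTuple):
    global_step: int
    metrics: Dict[str, float]


class SketchBaseTrainer(ABC):
    """Hypothetical stand-in for BaseTrainer's template-method flow."""

    def __init__(self, train_dataset=None, eval_dataset=None):
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

    def train(self) -> TrainOutput:
        self._prepare(train_dataset=self.train_dataset, eval_dataset=self.eval_dataset)
        return self._train_loop()

    def evaluate(self, eval_dataset=None) -> Dict[str, float]:
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        # As in BaseTrainer.evaluate(), no training data is prepared here.
        self._prepare(train_dataset=None, eval_dataset=eval_dataset)
        return self._eval_loop()

    @abstractmethod
    def _prepare(self, train_dataset, eval_dataset) -> None: ...

    @abstractmethod
    def _train_loop(self) -> TrainOutput: ...

    @abstractmethod
    def _eval_loop(self) -> Dict[str, float]: ...


class LineFitTrainer(SketchBaseTrainer):
    """Toy subclass: fits y = w * x by per-sample gradient descent."""

    def __init__(self, train_dataset=None, eval_dataset=None, lr=0.05, epochs=50):
        super().__init__(train_dataset, eval_dataset)
        self.w = 0.0  # the entire "model": one weight
        self.lr = lr
        self.epochs = epochs

    @staticmethod
    def _mse(w: float, data: List[Tuple[float, float]]) -> float:
        return sum((w * x - y) ** 2 for x, y in data) / max(len(data), 1)

    def _prepare(self, train_dataset, eval_dataset) -> None:
        # Stand-in for building dataloaders: materialize the (x, y) pairs.
        self.train_data = list(train_dataset or [])
        self.eval_data = list(eval_dataset or [])

    def _train_loop(self) -> TrainOutput:
        step = 0
        for _ in range(self.epochs):
            for x, y in self.train_data:
                grad = 2.0 * (self.w * x - y) * x  # d/dw of (w*x - y)^2
                self.w -= self.lr * grad
                step += 1
        return TrainOutput(step, {"train_loss": self._mse(self.w, self.train_data)})

    def _eval_loop(self) -> Dict[str, float]:
        return {"eval_loss": self._mse(self.w, self.eval_data)}


trainer = LineFitTrainer(train_dataset=[(1, 2), (2, 4), (3, 6)], eval_dataset=[(4, 8)])
out = trainer.train()
print(out.global_step)  # 150
print(round(trainer.w, 3))  # 2.0
print(trainer.evaluate()["eval_loss"] < 1e-9)  # True
```

Because evaluate() calls _prepare() with train_dataset=None, a subclass's _prepare() must tolerate a missing training dataset, as this one does.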
Source code in src/forgather/ml/trainer/base_trainer.py
class BaseTrainer(
    ExtensibleTrainer, Stateful, StatefulProvider, Generic[TBaseTrainingArguments]
):
    """Abstract base class implementing common trainer infrastructure.

    Provides callback management, checkpoint coordination, training-state tracking,
    and the ``PyTorch Stateful`` interface. The actual training and evaluation loops
    are left abstract so that concrete subclasses (``Trainer``, ``AccelTrainer``,
    ``PipelineTrainer``) can implement them with their own parallelism strategy.

    This class intentionally mirrors the public surface of ``transformers.Trainer``
    to make porting existing training scripts straightforward.

    Parameters
    ----------
    args : TBaseTrainingArguments
        Training configuration dataclass. Must be an instance of
        ``BaseTrainingArguments`` or one of its subclasses.
    model : torch.nn.Module or None, optional
        Pre-constructed model. Either ``model`` or ``model_init`` must be
        provided. Default is ``None``.
    data_collator : DataCollatorT or None, optional
        Callable that collates a list of dataset examples into a batch dict.
        Default is ``None``.
    train_dataset : IterableDatasetT or None, optional
        Training dataset (``torch.utils.data.Dataset`` or any iterable).
        Default is ``None``.
    eval_dataset : IterableDatasetT or None, optional
        Evaluation dataset. Default is ``None``.
    processing_class : PreprocessingClassT or None, optional
        Tokenizer or feature extractor saved alongside model weights.
        Default is ``None``.
    model_init : Callable[[], torch.nn.Module] or None, optional
        Zero-argument factory that constructs the model. Required when ``model``
        is ``None`` (e.g. for pipeline training where construction must happen
        inside the trainer). Default is ``None``.
    callbacks : list of TrainerCallback or None, optional
        Callbacks to install. When ``None``, ``default_callbacks()`` is used.
        Default is ``None``.
    compute_loss_func : LossFunctionT or None, optional
        Custom loss function. When ``None``, the trainer computes cross-entropy
        from model logits. Default is ``None``.

    Attributes
    ----------
    state : TrainerState
        Mutable training progress state (global step, epoch, log history, etc.).
    control : TrainerControl
        Mutable flags set by callbacks to signal save/eval/stop actions.
    checkpoint_manager : CheckpointInterface or None
        Set by ``_prepare()`` before the training loop starts.

    Raises
    ------
    AssertionError
        If neither ``model`` nor ``model_init`` is provided.
    AssertionError
        If ``args.gradient_accumulation_steps`` is not greater than 0.

    Notes
    -----
    Concrete subclasses must implement three abstract methods:

    * ``_prepare(train_dataset, eval_dataset)`` — set up dataloaders, model,
      optimizer, and checkpoint manager.
    * ``_train_loop()`` — the main training iteration loop, returning
      ``TrainOutput``.
    * ``_eval_loop()`` — the evaluation loop, returning a metrics dict.
    """

    args: TBaseTrainingArguments
    model: torch.nn.Module | None
    data_collator: DataCollatorT | None
    train_dataset: IterableDatasetT | None
    eval_dataset: IterableDatasetT | None
    processing_class: PreprocessingClassT | None
    model_init: ModelConstructor | None
    callbacks: List[TrainerCallback]
    loss_fn: LossFunctionT | None
    train_dataloader: Iterable | None
    eval_dataloader: Iterable | None
    optimizer: OptimizerT | None
    lr_scheduler: LRSchedulerT | None
    is_local_process_zero: bool
    is_world_process_zero: bool
    num_processes: int
    checkpoint_manager: CheckpointInterface | None
    state: TrainerState
    control: TrainerControl

    @classmethod
    def default_callbacks(cls):
        """Return the default callbacks for this trainer class.

        Subclasses override this to provide callbacks that are always installed
        (e.g. ``ProgressCallback``, ``InfoCallback``). The base implementation
        returns an empty list.

        Returns
        -------
        list of TrainerCallback
            Default callback instances.
        """
        return []

    def __init__(
        self,
        args: TBaseTrainingArguments,
        model: torch.nn.Module | None = None,
        *,
        data_collator: Optional[DataCollatorT] = None,
        train_dataset: Optional[IterableDatasetT] = None,
        eval_dataset: Optional[IterableDatasetT] = None,
        processing_class: Optional[PreprocessingClassT] = None,
        model_init: Optional[ModelConstructor] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        compute_loss_func: Optional[LossFunctionT] = None,
    ):
        assert (
            model is not None or model_init is not None
        ), "Either a model or a model constructor must be specified"

        assert (
            args.gradient_accumulation_steps > 0
        ), "gradient_accumulation_steps must be > 0"

        self.model = model
        self.args = args
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.processing_class = processing_class
        self.model_init = model_init
        if callbacks is None:
            self.callbacks = self.default_callbacks()
        else:
            self.callbacks = callbacks
        self.loss_fn = compute_loss_func

        # Init attributes
        self.train_dataloader = None
        self.eval_dataloader = None
        self.optimizer = None
        self.lr_scheduler = None
        self.is_local_process_zero = True
        self.is_world_process_zero = True
        self.num_processes = 1
        self.checkpoint_manager: CheckpointInterface | None = None

        self.state = TrainerState(
            logging_steps=args.logging_steps,
            eval_steps=args.eval_steps,
            train_batch_size=args.per_device_train_batch_size,
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            max_eval_steps=args.max_eval_steps,
        )
        self.control = TrainerControl()
        # When a callback flips control.should_training_stop (or
        # should_abort_without_save), `_dispatch_event` records
        # (callback_name, event, reason) here so the trainer can log a
        # clear "stopped by X" message on the next iteration check.
        self._stop_requested_by: tuple[str, str, str] | None = None

        # Silence the Hugging Face fast-tokenizer fork warning. When dataloader
        # workers fork after the tokenizer has been used, the library disables
        # its own parallelism to avoid deadlocks and warns about it; setting
        # TOKENIZERS_PARALLELISM explicitly suppresses that warning.
        if self.args.dataloader_num_workers > 0:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

        if self.args.dynamo_recompile_limit:
            logger.info(
                f"Setting torch._dynamo.config.recompile_limit = {self.args.dynamo_recompile_limit}"
            )
            torch._dynamo.config.recompile_limit = self.args.dynamo_recompile_limit

        if self.args.detect_anomaly:
            logger.warning(
                "Enabling autograd detect anomaly; expect performance degradation"
            )
            torch.autograd.set_detect_anomaly(True)

        if self.args.float32_matmul_precision is not None:
            logger.info(
                f'Setting float32_matmul_precision to "{self.args.float32_matmul_precision}"'
            )
            torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

        # Lazy index: event name -> list of callbacks that define that handler.
        # Built on first dispatch of each event, invalidated on callback add/remove.
        self._event_index: dict[str, list[TrainerCallback]] = {}

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"model={self.model},"
            f"args={self.args},"
            f"data_collator={self.data_collator},"
            f"train_dataset={self.train_dataset},"
            f"eval_dataset={self.eval_dataset},"
            f"processing_class={self.processing_class},"
            f"model_init={self.model_init},"
            f"callbacks={self.callbacks},"
            ")"
        )

    # AbstractBaseTrainer
    @override
    def train(self, **kwargs) -> TrainOutput:
        """Run the full training loop.

        Applies any configured PyTorch context managers (SDPA backend, activation
        offloading), calls ``_prepare()`` to set up all components, then delegates
        to ``_train_loop()``.

        Returns
        -------
        TrainOutput
            Named tuple with ``global_step`` and ``metrics``.
        """
        with ExitStack() as exit_stack:
            backends = self._get_sdpa_backends(self.args.sdpa_backend)
            if backends:
                logger.info(
                    f"sdpa_backends={backends}, set_priority={self.args.sdpa_set_priority}"
                )
                exit_stack.enter_context(
                    sdpa_kernel(backends, set_priority=self.args.sdpa_set_priority)
                )
            if self.args.enable_activation_offloading:
                exit_stack.enter_context(
                    torch.autograd.graph.save_on_cpu(pin_memory=True)
                )
            self._prepare(
                train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
            )
            return self._train_loop()

    # AbstractBaseTrainer
    @override
    def evaluate(
        self, eval_dataset: Optional[BaseDataset] = None, **kwargs
    ) -> dict[str, float]:
        """Run evaluation on the given dataset.

        Applies the configured SDPA backend context, calls ``_prepare()`` with
        ``train_dataset=None``, then delegates to ``_eval_loop()``.

        Parameters
        ----------
        eval_dataset : BaseDataset or None, optional
            Dataset to evaluate on. Falls back to ``self.eval_dataset`` when
            ``None``. Default is ``None``.

        Returns
        -------
        dict of str to float
            Evaluation metrics, e.g. ``{"eval_loss": 1.23}``.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset

        with ExitStack() as exit_stack:
            backends = self._get_sdpa_backends(self.args.sdpa_backend)
            if backends:
                exit_stack.enter_context(
                    sdpa_kernel(backends, set_priority=self.args.sdpa_set_priority)
                )
            self._prepare(train_dataset=None, eval_dataset=eval_dataset)
            return self._eval_loop()

    # AbstractBaseTrainer
    @override
    def add_callback(self, callback: TrainerCallback):
        if isinstance(callback, type):
            callback = callback()
        self.callbacks.append(callback)
        self._event_index.clear()

    # AbstractBaseTrainer
    @override
    def pop_callback(self, callback: TrainerCallback) -> TrainerCallback | None:
        if isinstance(callback, type):
            compare = lambda a, b: type(a) is b
        else:
            compare = lambda a, b: id(a) == id(b)
        for i, cb in enumerate(self.callbacks):
            if compare(cb, callback):
                self._event_index.clear()
                return self.callbacks.pop(i)
        return None

    # AbstractBaseTrainer
    @override
    def remove_callback(self, callback: TrainerCallback):
        self.pop_callback(callback)

    def log(self, logs: Dict[str, float]):
        """Log metrics and dispatch the ``on_log`` event to all callbacks.

        Appends the metrics dict to ``state.log_history``, then fires the
        ``on_log`` callback event. Callbacks use this to write to TensorBoard,
        wandb, or other logging backends.

        Parameters
        ----------
        logs : dict of str to float
            Metrics to record, e.g. ``{"loss": 0.5, "lr": 1e-4}``.

        Returns
        -------
        TrainerControl
            Updated control object (callbacks may have set ``should_save``,
            ``should_evaluate``, etc.).
        """
        self.state.log_history.append(logs)

        return self._dispatch_event(
            "on_log",
            logs=logs,
        )

    @staticmethod
    def _get_sdpa_backends(
        backend: List[str | SDPBackend] | str | SDPBackend | None,  # type: ignore[valid-type]
    ) -> List[SDPBackend] | SDPBackend | None:  # type: ignore[valid-type]
        """Normalize SDPA backend specifications to ``SDPBackend`` enum values.

        Converts string names (``"math"``, ``"flash"``, ``"efficient"``,
        ``"cudnn"``) or already-resolved ``SDPBackend`` enums to the form
        expected by ``torch.nn.attention.sdpa_kernel``.

        Parameters
        ----------
        backend : list of str or SDPBackend, str, SDPBackend, or None
            Backend specification. ``None`` is passed through unchanged.

        Returns
        -------
        list of SDPBackend, SDPBackend, or None
            Resolved enum value(s), or ``None`` when input is ``None``.

        Raises
        ------
        ValueError
            If ``backend`` is not a ``str``, ``list``, ``SDPBackend``, or
            ``None``.
        """
        if backend is None:
            return None

        sdpa_mapping = {
            "math": SDPBackend.MATH,
            "flash": SDPBackend.FLASH_ATTENTION,
            "efficient": SDPBackend.EFFICIENT_ATTENTION,
            "cudnn": SDPBackend.CUDNN_ATTENTION,
        }

        def get_backend(b):
            if isinstance(b, SDPBackend):
                return b
            return sdpa_mapping[b]

        if isinstance(backend, (str, SDPBackend)):
            # A single backend, given by name or already resolved.
            return get_backend(backend)
        elif isinstance(backend, list):
            return [get_backend(i) for i in backend]
        else:
            raise ValueError(
                "sdpa_backend must be a str, an SDPBackend, or a list of them"
            )

    def _dispatch_event(self, event: str, **kwargs):
        """Dispatch a trainer event to all registered callbacks.

        Uses a lazy per-event index (built on first dispatch, invalidated on
        callback add/remove) to skip callbacks that do not implement the
        requested handler method.

        Parameters
        ----------
        event : str
            Name of the callback method to invoke, e.g. ``"on_train_begin"``.
        **kwargs
            Extra arguments forwarded to each callback handler (``logs``,
            ``metrics``, etc.).

        Returns
        -------
        TrainerControl
            The ``self.control`` object after all callbacks have been invoked.
            Callbacks may mutate it in place or return a new one.
        """
        handlers = self._event_index.get(event)
        if handlers is None:
            handlers = [
                cb for cb in self.callbacks if callable(getattr(cb, event, None))
            ]
            self._event_index[event] = handlers

        if not handlers:
            return self.control

        unwrapped_model = self.unwrapped_model()
        for callback in handlers:
            # Snapshot stop-related flags before the callback runs so we
            # can attribute any state change to a specific callback. This
            # makes "training stopped by callback X" logging possible
            # without every callback having to shout about it individually.
            old_control = self.control
            prev_stop = bool(old_control.should_training_stop)
            prev_abort = bool(getattr(old_control, "should_abort_without_save", False))

            new_control = getattr(callback, event)(
                args=self.args,
                state=self.state,
                control=old_control,
                model=unwrapped_model,
                processing_class=self.processing_class,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                trainer=self,
                **kwargs,
            )

            if new_control is not None and new_control is not old_control:
                # Callbacks may either mutate the passed control in place
                # or return a new one. A small number of older callbacks
                # do both -- mutating flags on the received object and
                # then returning a fresh one. If we blindly replace
                # `self.control` with the returned object, those in-place
                # mutations would be lost. Propagate any stop flags that
                # were set on the old object forward to the new one so
                # the trainer's main stop check still sees them.
                if old_control.should_training_stop:
                    new_control.should_training_stop = True
                if getattr(old_control, "should_abort_without_save", False):
                    new_control.should_abort_without_save = True
                self.control = new_control

            # Record the first callback that flips should_training_stop or
            # should_abort_without_save. `_stop_requested_by` is read by
            # the training loop when it logs "Training stopped ..." so the
            # cause is visible even when the callback itself is silent
            # (DivergenceDetector does log, but trainer_control callbacks
            # set the flag via RPC without a clear local log line).
            # Check both the old control (in case the callback mutated it
            # in place) and the new control (in case it was replaced) so
            # we catch the transition regardless of the callback pattern.
            now_stop = bool(self.control.should_training_stop) or bool(
                old_control.should_training_stop
            )
            now_abort = bool(
                getattr(self.control, "should_abort_without_save", False)
            ) or bool(getattr(old_control, "should_abort_without_save", False))
            if (now_stop and not prev_stop) or (now_abort and not prev_abort):
                # Prefer "abort" when both transitions happen in the same
                # call; the trainer's abort path is more restrictive and
                # that's the user-facing distinction that matters.
                reason = "abort" if (now_abort and not prev_abort) else "stop"
                if getattr(self, "_stop_requested_by", None) is None:
                    self._stop_requested_by = (callback.name, event, reason)

        return self.control

    def unwrapped_model(self) -> torch.nn.Module:
        """Return the underlying model, free of any distributed wrappers.

        Subclasses that wrap ``self.model`` in DDP, FSDP, Accelerate, or
        pipeline-parallel containers override this method to strip those wrappers
        before the model is passed to callbacks.

        Returns
        -------
        torch.nn.Module
            The bare model without any framework wrapper.
        """
        assert self.model is not None
        return self.model

    # AbstractBaseTrainer
    @override
    def save_model(self, output_dir: Optional[os.PathLike | str] = None) -> None:
        """Save model weights and the preprocessing class (HF Trainer API compatibility).

        Writes the model weights to ``output_dir`` (or ``args.output_dir``
        when ``output_dir`` is ``None``). The rest of the training state
        (optimizer, scheduler, RNG, etc.) is **not** saved. For resumable
        training, prefer ``save_checkpoint()``.

        Parameters
        ----------
        output_dir : path-like or str, optional
            Destination directory. Defaults to ``args.output_dir``.
        """
        assert self.checkpoint_manager
        self.checkpoint_manager.save_model(
            output_dir=output_dir, overwrite_output_dir=self.args.overwrite_output_dir
        )

    # AbstractBaseTrainer
    @override
    def save_checkpoint(self, checkpoint_path=None) -> None:
        """Save a complete training checkpoint.

        Writes the training state to a checkpoint directory under
        ``args.output_dir``. The following components are saved when present:

        * Model weights (required for resuming)
        * Optimizer state (momentum buffers, adaptive learning rates, etc.)
        * LR scheduler state (current step position)
        * Training progress (``global_step``, epoch counter, etc.)
        * Dataset position (when the dataloader is stateful)
        * Random number generator states (for bit-exact reproducibility)

        Parameters
        ----------
        checkpoint_path : path-like or str, optional
            Explicit checkpoint directory path. When ``None``, a path is
            auto-generated under ``args.output_dir`` based on the current
            step count.
        """
        assert self.checkpoint_manager
        self.checkpoint_manager.save_checkpoint(checkpoint_path)

    # AbstractBaseTrainer
    @override
    def load_checkpoint(self, checkpoint_path=None) -> None:
        """Load a training checkpoint to resume training.

        Restores all available training state from the specified checkpoint
        directory. Each component is loaded only if its file exists:

        * Model weights (always required — raises if missing)
        * Optimizer state
        * LR scheduler state
        * Training progress (``global_step``, epoch, etc.)
        * Dataset position
        * Random number generator states

        When ``checkpoint_path`` is ``None``, the latest checkpoint under
        ``args.output_dir`` is located automatically.

        To intentionally skip reloading a component, delete its file from the
        checkpoint directory before calling this method. The checkpoint system
        logs a warning for each missing file but continues loading the rest.

        Parameters
        ----------
        checkpoint_path : path-like or str, optional
            Path to the checkpoint directory. ``None`` auto-selects the latest
            checkpoint under ``args.output_dir``.
        """
        assert self.checkpoint_manager
        self.checkpoint_manager.load_checkpoint(checkpoint_path)

    # StatefulProvider
    @override
    def get_state_components(self) -> List[StateComponent]:
        """Return state components for checkpoint save/load.

        Describes every piece of training state that should be persisted. The
        checkpoint manager calls this method to determine what to save and how
        state is shared across distributed ranks.

        For the single-GPU base trainer, all components use ``GLOBAL`` sharing
        except RNG, which uses ``PER_RANK``.

        Returned components (in order):

        * ``"model"`` — model weights, **required**
        * ``"optimizer"`` — optimizer state (optional)
        * ``"scheduler"`` — LR scheduler state (optional)
        * ``"trainer"`` — training progress counters (optional)
        * ``"dataset"`` — dataloader position, only when stateful (optional)
        * ``"rng"`` — per-rank RNG state (optional)

        Returns
        -------
        list of StateComponent
            All checkpointable state components with their sharing patterns.
        """
        components = []
        assert self.model is not None
        # Model - REQUIRED (always must be present).
        # cast: nn.Module.load_state_dict returns _IncompatibleKeys, not None,
        # so it doesn't satisfy the Stateful protocol strictly. This is a PyTorch
        # library limitation; at runtime it works correctly.
        components.append(
            StateComponent(
                key="model",
                stateful=cast(Stateful, self.model),
                sharing_pattern=SharingPattern.GLOBAL,
                required=True,  # Model is always required
            )
        )

        # Optimizer - optional (allows changing optimizer type)
        if self.optimizer is not None:
            components.append(
                StateComponent(
                    key="optimizer",
                    stateful=cast(Stateful, self.optimizer),
                    sharing_pattern=SharingPattern.GLOBAL,
                    required=False,
                )
            )

        # LR Scheduler - optional (allows changing scheduler type)
        if self.lr_scheduler is not None:
            components.append(
                StateComponent(
                    key="scheduler",
                    stateful=cast(Stateful, self.lr_scheduler),
                    sharing_pattern=SharingPattern.GLOBAL,
                    required=False,
                )
            )

        # Trainer state - optional (allows fresh training progress)
        components.append(
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            )
        )

        # Dataset state - optional, only if dataloader is stateful
        if self.train_dataloader is not None and hasattr(
            self.train_dataloader, "state_dict"
        ):
            components.append(
                StateComponent(
                    key="dataset",
                    stateful=cast(Stateful, self.train_dataloader),
                    sharing_pattern=self._get_dataset_sharing_pattern(),
                    required=False,
                )
            )

        # RNG state - optional (allows fresh randomization)
        components.append(
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            )
        )

        return components

    def _get_dataset_sharing_pattern(self) -> SharingPattern:
        """Return the dataset state sharing pattern for this trainer.

        The pattern depends on the dataloader strategy:

        * ``GLOBAL`` — rank 0 loads and dispatches data (``DataloaderDispatcher``).
        * ``PER_RANK`` — each rank loads its own shard independently.

        The base implementation always returns ``GLOBAL`` (single-GPU training).
        Subclasses with distributed training should override this method.

        Returns
        -------
        SharingPattern
            ``SharingPattern.GLOBAL`` for the single-GPU base trainer.
        """
        # For single-GPU trainer, dataset is GLOBAL
        # Subclasses with distributed training should override this method
        return SharingPattern.GLOBAL

    def get_process_groups(self) -> Dict[str, Any]:
        """Return named process groups for ``PER_GROUP`` sharing pattern.

        The checkpoint manager uses this to coordinate group-level saves (e.g.
        tensor-parallel replicas). Single-GPU trainers have no process groups.
        Subclasses implementing hybrid parallelism should override this method.

        Returns
        -------
        dict of str to Any
            Empty dict for the single-GPU base trainer.
        """
        return {}

    # Stateful
    @override
    def load_state_dict(self, state_dict):
        """Restore trainer progress state from a checkpoint state dict.

        Implements the ``torch.distributed.checkpoint.stateful.Stateful``
        interface. Restores step counters and progress tracking so training
        resumes at the exact point where it was saved. Also restores the
        ``GradScaler`` state when fp16 AMP is active.

        Parameters
        ----------
        state_dict : dict
            State dictionary previously returned by ``state_dict()``. Expected
            keys: ``global_step``, ``epoch_start_step``, ``raw_epoch``,
            ``num_input_tokens_seen``, ``total_flos``, and optionally
            ``grad_scaler``.
        """
        self.state.global_step = state_dict["global_step"]
        self.state.epoch_start_step = state_dict["epoch_start_step"]
        self.state.raw_epoch = state_dict["raw_epoch"]
        self.state.num_input_tokens_seen = state_dict["num_input_tokens_seen"]
        self.state.total_flos = state_dict["total_flos"]

        # Restore GradScaler state if present and AMP is active
        if "grad_scaler" in state_dict and hasattr(self, "amp_context"):
            self.amp_context.load_state_dict(state_dict["grad_scaler"])

    # Stateful
    @override
    def state_dict(self):
        """Return trainer progress state for checkpointing.

        Implements the ``torch.distributed.checkpoint.stateful.Stateful``
        interface. The returned dict is consumed by ``load_state_dict()`` to
        restore training from the exact saved point.

        Returns
        -------
        dict
            Training state with the following keys:

            * ``global_step`` — total optimizer updates performed.
            * ``epoch_start_step`` — global step at the start of the current epoch.
            * ``raw_epoch`` — integer epoch counter.
            * ``num_input_tokens_seen`` — total tokens processed (for throughput logging).
            * ``total_flos`` — total floating-point operations (for efficiency metrics).
            * ``grad_scaler`` — ``GradScaler`` state dict (only when fp16 AMP is active).
        """
        state = {
            "global_step": self.state.global_step,
            "epoch_start_step": self.state.epoch_start_step,
            "raw_epoch": self.state.raw_epoch,
            "num_input_tokens_seen": self.state.num_input_tokens_seen,
            "total_flos": self.state.total_flos,
        }

        # Save GradScaler state if AMP fp16 is active
        if hasattr(self, "amp_context"):
            scaler_state = self.amp_context.state_dict()
            if scaler_state:
                state["grad_scaler"] = scaler_state

        return state

    @abstractmethod
    def _prepare(
        self,
        train_dataset: Optional[BaseDataset],
        eval_dataset: Optional[BaseDataset],
    ) -> None:
        """Prepare all components for training and/or evaluation.

        Called at the start of ``train()`` or ``evaluate()``. Either dataset
        argument may be ``None`` (eval-only or train-only runs).

        Subclasses implement a typical sequence of:

        1. Create dataloaders from datasets.
        2. Initialize and move the model to the target device(s).
        3. Set up the optimizer and LR scheduler (training only).
        4. Wrap components for distributed training (DDP, Accelerate, etc.).
        5. Initialize the ``TrainerState``.
        6. Create the ``CheckpointManager``.
        7. Load from checkpoint when ``args.resume_from_checkpoint`` is set.

        Parameters
        ----------
        train_dataset : BaseDataset or None
            Training dataset. ``None`` when only evaluation is requested.
        eval_dataset : BaseDataset or None
            Evaluation dataset. ``None`` when only training is requested.
        """
        pass

    @abstractmethod
    def _train_loop(self) -> TrainOutput:
        """Execute the main training loop.

        Concrete implementations iterate over epochs and batches, perform
        forward/backward passes with gradient accumulation, step the optimizer
        and LR scheduler, dispatch callback events, and handle periodic logging,
        evaluation, and checkpointing.

        Returns
        -------
        TrainOutput
            Named tuple with ``global_step``, ``training_loss``, and a metrics
            dict aggregated over the full training run.
        """
        pass

    @abstractmethod
    def _eval_loop(self) -> dict[str, float]:
        """Execute the evaluation loop.

        Concrete implementations iterate over the evaluation dataset, perform
        forward-only passes (no gradient computation), aggregate metrics, and
        dispatch callback events.

        Returns
        -------
        dict of str to float
            Computed evaluation metrics, e.g. ``{"eval_loss": 0.5}``.
        """
        pass
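The abstract hooks above follow a template-method layout: the public train() and evaluate() methods call _prepare() and then delegate to _train_loop() or _eval_loop(). The following is a minimal, dependency-free sketch of that structure. ToyBaseTrainer, MeanTrainer, and this TrainOutput stand-in are hypothetical illustrations, not Forgather classes, and the "training" just fits a running mean.

```python
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional


class TrainOutput(NamedTuple):
    # Hypothetical stand-in mirroring the documented fields.
    global_step: int
    training_loss: float
    metrics: dict


class ToyBaseTrainer(ABC):
    """Hypothetical stand-in for the base trainer's template-method layout."""

    def __init__(self, train_dataset: Optional[list] = None, eval_dataset: Optional[list] = None):
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

    def train(self) -> TrainOutput:
        # Public entry point: prepare everything, then run the concrete loop.
        self._prepare(train_dataset=self.train_dataset, eval_dataset=self.eval_dataset)
        return self._train_loop()

    def evaluate(self, eval_dataset: Optional[list] = None) -> dict:
        # Fall back to the constructor's eval dataset, prepare for eval only.
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        self._prepare(train_dataset=None, eval_dataset=eval_dataset)
        return self._eval_loop()

    @abstractmethod
    def _prepare(self, train_dataset, eval_dataset) -> None: ...

    @abstractmethod
    def _train_loop(self) -> TrainOutput: ...

    @abstractmethod
    def _eval_loop(self) -> dict: ...


class MeanTrainer(ToyBaseTrainer):
    """Toy subclass: 'training' estimates the mean of a list of numbers."""

    def _prepare(self, train_dataset, eval_dataset):
        # Materialize the "dataloaders"; keep any estimate learned by an earlier train().
        self._train = list(train_dataset or [])
        self._eval = list(eval_dataset or [])
        self.estimate = getattr(self, "estimate", 0.0)

    def _train_loop(self):
        step = 0
        for step, x in enumerate(self._train, start=1):
            # A running-mean update plays the role of an optimizer step.
            self.estimate += (x - self.estimate) / step
        loss = sum((x - self.estimate) ** 2 for x in self._train) / max(step, 1)
        return TrainOutput(global_step=step, training_loss=loss, metrics={"train_loss": loss})

    def _eval_loop(self):
        # Forward-only pass: no state is updated here.
        loss = sum((x - self.estimate) ** 2 for x in self._eval) / max(len(self._eval), 1)
        return {"eval_loss": loss}
```

Because the hooks are abstract, the base class cannot be instantiated directly; only subclasses that implement all three can.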

default_callbacks() classmethod

Return the default callbacks for this trainer class.

Subclasses override this to provide callbacks that are always installed (e.g. ProgressCallback, InfoCallback). The base implementation returns an empty list.

Returns:

  • list of TrainerCallback: Default callback instances.

Source code in src/forgather/ml/trainer/base_trainer.py
@classmethod
def default_callbacks(cls):
    """Return the default callbacks for this trainer class.

    Subclasses override this to provide callbacks that are always installed
    (e.g. ``ProgressCallback``, ``InfoCallback``). The base implementation
    returns an empty list.

    Returns
    -------
    list of TrainerCallback
        Default callback instances.
    """
    return []
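A subclass that always wants certain callbacks can override the classmethod and extend its parent's list. The sketch below uses hypothetical stand-in classes (the real ProgressCallback and InfoCallback live in Forgather and are not imported here); it only illustrates the override pattern.

```python
class TrainerCallback:
    """Hypothetical minimal callback base class (stand-in for illustration)."""


class ProgressCallback(TrainerCallback):
    """Stand-in for a progress-bar callback."""


class InfoCallback(TrainerCallback):
    """Stand-in for a run-info callback."""


class BaseTrainer:
    @classmethod
    def default_callbacks(cls):
        # Base implementation: nothing is installed by default.
        return []


class MyTrainer(BaseTrainer):
    @classmethod
    def default_callbacks(cls):
        # Extend the parent's defaults; build fresh instances on every call so
        # two trainers never share mutable callback state.
        return super().default_callbacks() + [ProgressCallback(), InfoCallback()]
```

Returning new instances on each call, rather than a class-level list, keeps per-run callback state isolated.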

train(**kwargs)

Run the full training loop.

Applies any configured PyTorch context managers (SDPA backend, activation offloading), calls _prepare() to set up all components, then delegates to _train_loop().

Returns:

  • TrainOutput: Named tuple with global_step, training_loss, and metrics.

Source code in src/forgather/ml/trainer/base_trainer.py
@override
def train(self, **kwargs) -> TrainOutput:
    """Run the full training loop.

    Applies any configured PyTorch context managers (SDPA backend, activation
    offloading), calls ``_prepare()`` to set up all components, then delegates
    to ``_train_loop()``.

    Returns
    -------
    TrainOutput
        Named tuple with ``global_step``, ``training_loss``, and ``metrics``.
    """
    with ExitStack() as exit_stack:
        backends = self._get_sdpa_backends(self.args.sdpa_backend)
        if backends:
            logger.info(
                f"sdpa_backends={backends}, set_priority={self.args.sdpa_set_priority}"
            )
            exit_stack.enter_context(
                sdpa_kernel(backends, set_priority=self.args.sdpa_set_priority)
            )
        if self.args.enable_activation_offloading:
            exit_stack.enter_context(
                torch.autograd.graph.save_on_cpu(pin_memory=True)
            )
        self._prepare(
            train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
        )
        return self._train_loop()
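train() uses contextlib.ExitStack so that each optional context manager is entered only when its setting is enabled, and all of them are exited in reverse order when the call returns. The sketch below reproduces that pattern with the standard library only; fake_sdpa_kernel and fake_save_on_cpu are hypothetical stand-ins for torch.nn.attention.sdpa_kernel and torch.autograd.graph.save_on_cpu that merely record when they are entered and exited.

```python
from contextlib import ExitStack, contextmanager

events = []  # records enter/exit order for inspection


@contextmanager
def fake_sdpa_kernel(backends, set_priority=False):
    # Stand-in for torch.nn.attention.sdpa_kernel.
    events.append(("enter sdpa", tuple(backends), set_priority))
    try:
        yield
    finally:
        events.append(("exit sdpa",))


@contextmanager
def fake_save_on_cpu(pin_memory=True):
    # Stand-in for torch.autograd.graph.save_on_cpu.
    events.append(("enter offload", pin_memory))
    try:
        yield
    finally:
        events.append(("exit offload",))


def run_with_optional_contexts(backends, set_priority, offload, body):
    with ExitStack() as stack:
        # Each context is entered only if its option is set.
        if backends:
            stack.enter_context(fake_sdpa_kernel(backends, set_priority=set_priority))
        if offload:
            stack.enter_context(fake_save_on_cpu(pin_memory=True))
        # Returning inside the with-block still unwinds the stack (LIFO) first.
        return body()
```

The advantage over nested with-statements is that the number of active contexts is decided at runtime without duplicating the body.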

evaluate(eval_dataset=None, **kwargs)

Run evaluation on the given dataset.

Applies the configured SDPA backend context, calls _prepare() with train_dataset=None, then delegates to _eval_loop().

Parameters:

  • eval_dataset (BaseDataset or None, default None): Dataset to evaluate on. Falls back to self.eval_dataset when None.

Returns:

  • dict of str to float: Evaluation metrics, e.g. {"eval_loss": 1.23}.

Source code in src/forgather/ml/trainer/base_trainer.py
@override
def evaluate(
    self, eval_dataset: Optional[BaseDataset] = None, **kwargs
) -> dict[str, float]:
    """Run evaluation on the given dataset.

    Applies the configured SDPA backend context, calls ``_prepare()`` with
    ``train_dataset=None``, then delegates to ``_eval_loop()``.

    Parameters
    ----------
    eval_dataset : BaseDataset or None, optional
        Dataset to evaluate on. Falls back to ``self.eval_dataset`` when
        ``None``. Default is ``None``.

    Returns
    -------
    dict of str to float
        Evaluation metrics, e.g. ``{"eval_loss": 1.23}``.
    """
    if eval_dataset is None:
        eval_dataset = self.eval_dataset

    with ExitStack() as exit_stack:
        backends = self._get_sdpa_backends(self.args.sdpa_backend)
        if backends:
            exit_stack.enter_context(
                sdpa_kernel(backends, set_priority=self.args.sdpa_set_priority)
            )
        self._prepare(train_dataset=None, eval_dataset=eval_dataset)
        return self._eval_loop()

log(logs)

Log metrics and dispatch the on_log event to all callbacks.

Appends the metrics dict to state.log_history, then fires the on_log callback event. Callbacks use this to write to TensorBoard, wandb, or other logging backends.

Parameters:

  • logs (dict of str to float, required): Metrics to record, e.g. {"loss": 0.5, "lr": 1e-4}.

Returns:

  • TrainerControl: Updated control object (callbacks may have set should_save, should_evaluate, etc.).

Source code in src/forgather/ml/trainer/base_trainer.py
def log(self, logs: Dict[str, float]):
    """Log metrics and dispatch the ``on_log`` event to all callbacks.

    Appends the metrics dict to ``state.log_history``, then fires the
    ``on_log`` callback event. Callbacks use this to write to TensorBoard,
    wandb, or other logging backends.

    Parameters
    ----------
    logs : dict of str to float
        Metrics to record, e.g. ``{"loss": 0.5, "lr": 1e-4}``.

    Returns
    -------
    TrainerControl
        Updated control object (callbacks may have set ``should_save``,
        ``should_evaluate``, etc.).
    """
    self.state.log_history.append(logs)

    return self._dispatch_event(
        "on_log",
        logs=logs,
    )
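The flow in log() is: record the metrics, then give every callback a chance to react and possibly change the control flags. Below is a self-contained sketch of that dispatch loop. TrainerControl, TrainerState, ToyTrainer, and the callback signature used here are hypothetical simplifications, not Forgather's actual classes or its _dispatch_event implementation.

```python
from dataclasses import dataclass, field


@dataclass
class TrainerControl:
    # Hypothetical subset of the control flags mentioned above.
    should_save: bool = False
    should_evaluate: bool = False


@dataclass
class TrainerState:
    log_history: list = field(default_factory=list)


class ToyTrainer:
    def __init__(self, callbacks):
        self.state = TrainerState()
        self.control = TrainerControl()
        self.callbacks = callbacks

    def _dispatch_event(self, event, **kwargs):
        # Call each callback's handler in order; a non-None return replaces the control.
        for cb in self.callbacks:
            handler = getattr(cb, event, None)
            if handler is None:
                continue
            result = handler(self.state, self.control, **kwargs)
            if result is not None:
                self.control = result
        return self.control

    def log(self, logs):
        self.state.log_history.append(logs)
        return self._dispatch_event("on_log", logs=logs)


class SaveOnLowLoss:
    """Hypothetical callback: request a save once loss drops below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_log(self, state, control, logs):
        if logs.get("loss", float("inf")) < self.threshold:
            control.should_save = True
        return control


class TextLogger:
    """Hypothetical logging backend that formats each record as a line of text."""

    def __init__(self):
        self.lines = []

    def on_log(self, state, control, logs):
        self.lines.append(", ".join(f"{k}={v}" for k, v in logs.items()))
        # Returns None: the control object is left untouched.
```

Logging backends such as TensorBoard or wandb writers fit the TextLogger slot, while callbacks like SaveOnLowLoss use the same event to steer the training loop.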

unwrapped_model()

Return the underlying model, free of any distributed wrappers.

Subclasses that wrap self.model in DDP, FSDP, Accelerate, or pipeline-parallel containers override this method to strip those wrappers before the model is passed to callbacks.

Returns:

  • torch.nn.Module: The bare model without any framework wrapper.

Source code in src/forgather/ml/trainer/base_trainer.py
def unwrapped_model(self) -> torch.nn.Module:
    """Return the underlying model, free of any distributed wrappers.

    Subclasses that wrap ``self.model`` in DDP, FSDP, Accelerate, or
    pipeline-parallel containers override this method to strip those wrappers
    before the model is passed to callbacks.

    Returns
    -------
    torch.nn.Module
        The bare model without any framework wrapper.
    """
    assert self.model
    return self.model
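Wrappers such as DistributedDataParallel keep the wrapped model in a .module attribute, and torch.compile's wrapper keeps it in ._orig_mod. The sketch below is a generic, attribute-based unwrapping helper using hypothetical stand-in wrapper classes; it is an illustration under those assumptions, not what any Forgather subclass actually does.

```python
class Wrapper:
    """Stand-in for a DDP-style wrapper that stores the wrapped model as `.module`."""

    def __init__(self, module):
        self.module = module


class CompiledWrapper:
    """Stand-in for a torch.compile-style wrapper that stores it as `._orig_mod`."""

    def __init__(self, mod):
        self._orig_mod = mod


def unwrap(model, attrs=("module", "_orig_mod"), max_depth=10):
    # Peel off nested wrappers until none of the known attributes is present.
    for _ in range(max_depth):
        for attr in attrs:
            inner = getattr(model, attr, None)
            if inner is not None:
                model = inner
                break
        else:
            # No wrapper attribute found: this is the bare model.
            return model
    raise RuntimeError("wrapper nesting deeper than max_depth")
```

A heuristic like this can misfire if a bare model happens to have a submodule named module, which is one reason each trainer overriding unwrapped_model() with knowledge of its own wrapper is more precise.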

save_model(output_dir=None)

Save model weights and the preprocessing class (HF Trainer API compatibility).

Writes the model weights to output_dir (or args.output_dir when output_dir is None) but not the rest of the training state: optimizer, scheduler, RNG, and similar state are not saved. For resumable training, use save_checkpoint() instead.

Parameters:

  • output_dir (path-like or str, default None): Destination directory. When None, args.output_dir is used.
Source code in src/forgather/ml/trainer/base_trainer.py
@override
def save_model(self, output_dir: Optional[os.PathLike | str] = None) -> None:
    """Save model weights and the preprocessing class (HF Trainer API compatibility).

    Writes only the model weights to ``output_dir`` (or ``args.output_dir``
    when ``output_dir`` is ``None``). The full training state (optimizer,
    scheduler, RNG, etc.) is **not** saved. For resumable training, prefer
    ``save_checkpoint()``.

    Parameters
    ----------
    output_dir : path-like or str, optional
        Destination directory. Defaults to ``args.output_dir``.
    """
    assert self.checkpoint_manager
    self.checkpoint_manager.save_model(
        output_dir=output_dir, overwrite_output_dir=self.args.overwrite_output_dir
    )

save_checkpoint(checkpoint_path=None)

Save a complete training checkpoint.

Writes the full training state to a new checkpoint directory under args.output_dir. The following components are saved whenever they are present:

  • Model weights (required for resuming)
  • Optimizer state (momentum buffers, adaptive learning rates, etc.)
  • LR scheduler state (current step position)
  • Training progress (global_step, epoch counter, etc.)
  • Dataset position (when the dataloader is stateful)
  • Random number generator states (for bit-exact reproducibility)

Parameters:

  • checkpoint_path (path-like or str, default None): Explicit checkpoint directory path. When None, a path is auto-generated under args.output_dir based on the current step count.
Source code in src/forgather/ml/trainer/base_trainer.py
@override
def save_checkpoint(self, checkpoint_path=None) -> None:
    """Save a complete training checkpoint.

    Writes all training state to a timestamped directory under
    ``args.output_dir``. The following components are always saved:

    * Model weights (required for resuming)
    * Optimizer state (momentum buffers, adaptive learning rates, etc.)
    * LR scheduler state (current step position)
    * Training progress (``global_step``, epoch counter, etc.)
    * Dataset position (when the dataloader is stateful)
    * Random number generator states (for bit-exact reproducibility)

    Parameters
    ----------
    checkpoint_path : path-like or str, optional
        Explicit checkpoint directory path. When ``None``, a path is
        auto-generated under ``args.output_dir`` based on the current
        step count.
    """
    assert self.checkpoint_manager
    self.checkpoint_manager.save_checkpoint(checkpoint_path)

load_checkpoint(checkpoint_path=None)

Load a training checkpoint to resume training.

Restores all available training state from the specified checkpoint directory. Each component is loaded only if its file exists:

  • Model weights (always required — raises if missing)
  • Optimizer state
  • LR scheduler state
  • Training progress (global_step, epoch, etc.)
  • Dataset position
  • Random number generator states

When checkpoint_path is None, the latest checkpoint under args.output_dir is located automatically.

To intentionally skip reloading a component, delete its file from the checkpoint directory before calling this method. The checkpoint system logs a warning for each missing file but continues loading the rest.

Parameters:

  • checkpoint_path (path-like or str, default None): Path to the checkpoint directory. None auto-selects the latest checkpoint under args.output_dir.
Source code in src/forgather/ml/trainer/base_trainer.py
@override
def load_checkpoint(self, checkpoint_path=None) -> None:
    """Load a training checkpoint to resume training.

    Restores all available training state from the specified checkpoint
    directory. Each component is loaded only if its file exists:

    * Model weights (always required — raises if missing)
    * Optimizer state
    * LR scheduler state
    * Training progress (``global_step``, epoch, etc.)
    * Dataset position
    * Random number generator states

    When ``checkpoint_path`` is ``None``, the latest checkpoint under
    ``args.output_dir`` is located automatically.

    To intentionally skip reloading a component, delete its file from the
    checkpoint directory before calling this method. The checkpoint system
    logs a warning for each missing file but continues loading the rest.

    Parameters
    ----------
    checkpoint_path : path-like or str, optional
        Path to the checkpoint directory. ``None`` auto-selects the latest
        checkpoint under ``args.output_dir``.
    """
    assert self.checkpoint_manager
    self.checkpoint_manager.load_checkpoint(checkpoint_path)
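Two behaviors described above are latest-checkpoint discovery and per-component loading where a missing optional file only produces a warning while a missing required file is an error. The sketch below shows both with the standard library. The checkpoint-<step> directory names and the <key>.pt file layout are hypothetical assumptions for illustration, not Forgather's actual on-disk format.

```python
import logging
import os
import re
import tempfile

logger = logging.getLogger(__name__)

# Hypothetical directory naming: <output_dir>/checkpoint-<global_step>
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir):
    """Return the checkpoint directory with the highest step number, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))  # compare numerically, so 1000 beats 200
            if step > best_step:
                best_step, best_path = step, path
    return best_path


def load_components(checkpoint_dir, components):
    """Load (key, required, load_fn) triples from <checkpoint_dir>/<key>.pt.

    A missing optional file logs a warning and is skipped; a missing
    required file raises FileNotFoundError.
    """
    loaded = []
    for key, required, load_fn in components:
        path = os.path.join(checkpoint_dir, f"{key}.pt")
        if not os.path.exists(path):
            if required:
                raise FileNotFoundError(f"required component {key!r} not found: {path}")
            logger.warning("Skipping %r: %s not found", key, path)
            continue
        load_fn(path)
        loaded.append(key)
    return loaded


def demo():
    """Resume from the newest of three checkpoints; its optimizer file is absent."""
    with tempfile.TemporaryDirectory() as out:
        for step in (200, 1000, 50):
            os.makedirs(os.path.join(out, f"checkpoint-{step}"))
        latest = find_latest_checkpoint(out)
        for key in ("model", "trainer"):
            open(os.path.join(latest, f"{key}.pt"), "w").close()
        components = [
            ("model", True, lambda p: None),
            ("optimizer", False, lambda p: None),  # missing: warned and skipped
            ("trainer", False, lambda p: None),
        ]
        return os.path.basename(latest), load_components(latest, components)
```

Sorting by the parsed step number rather than by directory name avoids the lexical trap where "checkpoint-200" sorts after "checkpoint-1000".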

get_state_components()

Return state components for checkpoint save/load.

Describes every piece of training state that should be persisted. The checkpoint manager calls this method to determine what to save and how state is shared across distributed ranks.

For the single-GPU base trainer, all components use GLOBAL sharing except RNG, which uses PER_RANK.

Returned components (in order):

  • "model" — model weights, required
  • "optimizer" — optimizer state (optional)
  • "scheduler" — LR scheduler state (optional)
  • "trainer" — training progress counters (optional)
  • "dataset" — dataloader position, only when stateful (optional)
  • "rng" — per-rank RNG state (optional)

Returns:

  • list of StateComponent: All checkpointable state components with their sharing patterns.

Source code in src/forgather/ml/trainer/base_trainer.py
@override
def get_state_components(self) -> List[StateComponent]:
    """Return state components for checkpoint save/load.

    Describes every piece of training state that should be persisted. The
    checkpoint manager calls this method to determine what to save and how
    state is shared across distributed ranks.

    For the single-GPU base trainer all components use ``GLOBAL`` sharing
    except RNG which uses ``PER_RANK``.

    Returned components (in order):

    * ``"model"`` — model weights, **required**
    * ``"optimizer"`` — optimizer state (optional)
    * ``"scheduler"`` — LR scheduler state (optional)
    * ``"trainer"`` — training progress counters (optional)
    * ``"dataset"`` — dataloader position, only when stateful (optional)
    * ``"rng"`` — per-rank RNG state (optional)

    Returns
    -------
    list of StateComponent
        All checkpointable state components with their sharing patterns.
    """
    components = []
    assert self.model is not None
    # Model - REQUIRED (always must be present).
    # cast: nn.Module.load_state_dict returns _IncompatibleKeys, not None,
    # so it doesn't satisfy the Stateful protocol strictly. This is a PyTorch
    # library limitation; at runtime it works correctly.
    components.append(
        StateComponent(
            key="model",
            stateful=cast(Stateful, self.model),
            sharing_pattern=SharingPattern.GLOBAL,
            required=True,  # Model is always required
        )
    )

    # Optimizer - optional (allows changing optimizer type)
    if self.optimizer is not None:
        components.append(
            StateComponent(
                key="optimizer",
                stateful=cast(Stateful, self.optimizer),
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            )
        )

    # LR Scheduler - optional (allows changing scheduler type)
    if self.lr_scheduler is not None:
        components.append(
            StateComponent(
                key="scheduler",
                stateful=cast(Stateful, self.lr_scheduler),
                sharing_pattern=SharingPattern.GLOBAL,
                required=False,
            )
        )

    # Trainer state - optional (allows fresh training progress)
    components.append(
        StateComponent(
            key="trainer",
            stateful=self,
            sharing_pattern=SharingPattern.GLOBAL,
            required=False,
        )
    )

    # Dataset state - optional, only if dataloader is stateful
    if self.train_dataloader is not None and hasattr(
        self.train_dataloader, "state_dict"
    ):
        components.append(
            StateComponent(
                key="dataset",
                stateful=cast(Stateful, self.train_dataloader),
                sharing_pattern=self._get_dataset_sharing_pattern(),
                required=False,
            )
        )

    # RNG state - optional (allows fresh randomization)
    components.append(
        StateComponent(
            key="rng",
            stateful=RNGState(),
            sharing_pattern=SharingPattern.PER_RANK,
            required=False,
        )
    )

    return components

get_process_groups()

Return named process groups for PER_GROUP sharing pattern.

The checkpoint manager uses this to coordinate group-level saves (e.g. tensor-parallel replicas). Single-GPU trainers have no process groups. Subclasses implementing hybrid parallelism should override this method.

Returns:

  • dict of str to Any: Empty dict for the single-GPU base trainer.

Source code in src/forgather/ml/trainer/base_trainer.py
def get_process_groups(self) -> Dict[str, Any]:
    """Return named process groups for ``PER_GROUP`` sharing pattern.

    The checkpoint manager uses this to coordinate group-level saves (e.g.
    tensor-parallel replicas). Single-GPU trainers have no process groups.
    Subclasses implementing hybrid parallelism should override this method.

    Returns
    -------
    dict of str to Any
        Empty dict for the single-GPU base trainer.
    """
    return {}
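The sharing patterns determine how many copies of a component end up on disk. The sketch below shows one plausible policy: GLOBAL state is written once by rank 0, PER_RANK state once per rank, and PER_GROUP state once per named group (by its lowest rank). The SharingPattern stand-in, the rank-list representation of groups, and the file-suffix scheme are all hypothetical; the real checkpoint manager works with torch process groups and may use different rules.

```python
from enum import Enum


class SharingPattern(Enum):
    # Stand-in enum: member names follow the docs, values are arbitrary.
    GLOBAL = "global"
    PER_RANK = "per_rank"
    PER_GROUP = "per_group"


def writers_for(pattern, world_size, groups=None):
    """Return {rank: filename_suffix} for the ranks that write a component.

    `groups` maps a group name to the list of global ranks in that group
    (a simplification of the real named process groups).
    """
    if pattern is SharingPattern.GLOBAL:
        # Identical on every rank: one shared file written by rank 0.
        return {0: ""}
    if pattern is SharingPattern.PER_RANK:
        # Different on every rank (e.g. RNG state): one file per rank.
        return {r: f"_rank{r}" for r in range(world_size)}
    if pattern is SharingPattern.PER_GROUP:
        # Identical within a group: one file per group, written by its lowest rank.
        if not groups:
            raise ValueError("PER_GROUP requires named process groups")
        return {min(ranks): f"_{name}" for name, ranks in groups.items()}
    raise ValueError(f"unknown sharing pattern: {pattern!r}")
```

Under this policy, a trainer whose get_process_groups() returns an empty dict, like the single-GPU base trainer, cannot use PER_GROUP, which is why hybrid-parallel subclasses would need to override it.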

load_state_dict(state_dict)

Restore trainer progress state from a checkpoint state dict.

Implements the torch.distributed.checkpoint.stateful.Stateful interface. Restores step counters and progress tracking so training resumes at the exact point where it was saved. Also restores the GradScaler state when fp16 AMP is active.

Parameters:

  • state_dict (dict, required): State dictionary previously returned by state_dict(). Expected keys: global_step, epoch_start_step, raw_epoch, num_input_tokens_seen, total_flos, and optionally grad_scaler.
Source code in src/forgather/ml/trainer/base_trainer.py
@override
def load_state_dict(self, state_dict):
    """Restore trainer progress state from a checkpoint state dict.

    Implements the ``torch.distributed.checkpoint.stateful.Stateful``
    interface. Restores step counters and progress tracking so training
    resumes at the exact point where it was saved. Also restores the
    ``GradScaler`` state when fp16 AMP is active.

    Parameters
    ----------
    state_dict : dict
        State dictionary previously returned by ``state_dict()``. Expected
        keys: ``global_step``, ``epoch_start_step``, ``raw_epoch``,
        ``num_input_tokens_seen``, ``total_flos``, and optionally
        ``grad_scaler``.
    """
    self.state.global_step = state_dict["global_step"]
    self.state.epoch_start_step = state_dict["epoch_start_step"]
    self.state.raw_epoch = state_dict["raw_epoch"]
    self.state.num_input_tokens_seen = state_dict["num_input_tokens_seen"]
    self.state.total_flos = state_dict["total_flos"]

    # Restore GradScaler state if present and AMP is active
    if "grad_scaler" in state_dict and hasattr(self, "amp_context"):
        self.amp_context.load_state_dict(state_dict["grad_scaler"])

state_dict()

Return trainer progress state for checkpointing.

Implements the torch.distributed.checkpoint.stateful.Stateful interface. The returned dict is consumed by load_state_dict() to restore training from the exact saved point.

Returns:

  • dict: Training state with the following keys:
      • global_step: total optimizer updates performed.
      • epoch_start_step: global step at the start of the current epoch.
      • raw_epoch: integer epoch counter.
      • num_input_tokens_seen: total tokens processed (for throughput logging).
      • total_flos: total floating-point operations (for efficiency metrics).
      • grad_scaler: GradScaler state dict (only when fp16 AMP is active).
Source code in src/forgather/ml/trainer/base_trainer.py
@override
def state_dict(self):
    """Return trainer progress state for checkpointing.

    Implements the ``torch.distributed.checkpoint.stateful.Stateful``
    interface. The returned dict is consumed by ``load_state_dict()`` to
    restore training from the exact saved point.

    Returns
    -------
    dict
        Training state with the following keys:

        * ``global_step`` — total optimizer updates performed.
        * ``epoch_start_step`` — global step at the start of the current epoch.
        * ``raw_epoch`` — integer epoch counter.
        * ``num_input_tokens_seen`` — total tokens processed (for throughput logging).
        * ``total_flos`` — total floating-point operations (for efficiency metrics).
        * ``grad_scaler`` — ``GradScaler`` state dict (only when fp16 AMP is active).
    """
    state = {
        "global_step": self.state.global_step,
        "epoch_start_step": self.state.epoch_start_step,
        "raw_epoch": self.state.raw_epoch,
        "num_input_tokens_seen": self.state.num_input_tokens_seen,
        "total_flos": self.state.total_flos,
    }

    # Save GradScaler state if AMP fp16 is active
    if hasattr(self, "amp_context"):
        scaler_state = self.amp_context.state_dict()
        if scaler_state:
            state["grad_scaler"] = scaler_state

    return state
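The state_dict()/load_state_dict() pair is a plain round trip: counters go out as a dict and come back in unchanged, and the scaler entry is included only when the AMP context reports non-empty state (PyTorch's GradScaler returns an empty dict when it is disabled). The sketch below mirrors that logic with hypothetical stand-ins (TrainerState, FakeAmpContext, ProgressStateful) and uses JSON only to show the state is plain data; real checkpoints are written by the checkpoint manager, not as JSON.

```python
import json
from dataclasses import dataclass

_KEYS = ("global_step", "epoch_start_step", "raw_epoch", "num_input_tokens_seen", "total_flos")


@dataclass
class TrainerState:
    # Hypothetical subset of the trainer's progress counters.
    global_step: int = 0
    epoch_start_step: int = 0
    raw_epoch: int = 0
    num_input_tokens_seen: int = 0
    total_flos: float = 0.0


class FakeAmpContext:
    """Stand-in for an AMP context; empty state when fp16 loss scaling is off."""

    def __init__(self, enabled):
        self.enabled = enabled
        self.scale = 65536.0

    def state_dict(self):
        return {"scale": self.scale} if self.enabled else {}

    def load_state_dict(self, sd):
        self.scale = sd["scale"]


class ProgressStateful:
    def __init__(self, amp_context=None):
        self.state = TrainerState()
        if amp_context is not None:
            self.amp_context = amp_context

    def state_dict(self):
        sd = {k: getattr(self.state, k) for k in _KEYS}
        if hasattr(self, "amp_context"):
            scaler_state = self.amp_context.state_dict()
            if scaler_state:  # skip the key entirely when scaling is inactive
                sd["grad_scaler"] = scaler_state
        return sd

    def load_state_dict(self, sd):
        for k in _KEYS:
            setattr(self.state, k, sd[k])
        if "grad_scaler" in sd and hasattr(self, "amp_context"):
            self.amp_context.load_state_dict(sd["grad_scaler"])


def save_and_restore(src, dst):
    # Serialize to text and back to show the progress state is plain data.
    dst.load_state_dict(json.loads(json.dumps(src.state_dict())))
    return dst
```

Because grad_scaler is optional on both sides, a checkpoint saved without fp16 can be resumed with fp16 enabled (the scaler simply starts fresh), and vice versa.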

forgather.ml.trainer.base_trainer.BaseTrainingArguments dataclass

Bases: MinimalTrainingArguments

Extended training arguments with checkpoint management and PyTorch optimizations.

Extends MinimalTrainingArguments with full checkpoint state preservation and a range of PyTorch runtime optimizations (mixed precision, FP8, SDPA backend selection, activation offloading, etc.).

All training state (model, optimizer, scheduler, dataset position, RNG state) is automatically saved in checkpoints. To skip loading a specific component when resuming, manually delete its file from the checkpoint directory before calling train().

Note: The checkpoint-related options in this class are not compatible with the HuggingFace Trainer. Use MinimalTrainingArguments when HF compatibility is required.

Parameters:

Name Type Description Default
default_dtype str or None

Default torch.dtype for model construction (e.g. "float32", "bfloat16", "float16"). None leaves PyTorch's global default unchanged. Default is None.

None
max_eval_steps int

Maximum number of evaluation steps per evaluation call. -1 runs the full evaluation dataset. Default is -1.

-1
preserve_best_model bool

If True, keep the checkpoint with the best value of best_model_metric protected from cleanup rotation. Default is False.

False
best_model_metric str

Name of the metric used to select the best checkpoint when preserve_best_model=True. Default is "loss".

'loss'
best_model_greater_is_better bool or None

Whether higher values of best_model_metric are better. None auto-detects from the metric name (metrics containing "loss" or "perplexity" default to lower-is-better). Default is None.

None
preserve_n_best int

Number of best checkpoints to keep safe from save_total_limit cleanup. Default is 1.

1
eval_on_save bool

Force an evaluation pass before each checkpoint save. Useful for decoupling the save and eval schedules. Default is False.

False
enable_activation_offloading bool

Offload saved activation tensors to CPU during the backward pass to reduce peak GPU memory. Best combined with activation checkpointing. Trades GPU memory for CPU memory bandwidth. Default is False.

False
detect_anomaly bool

Enable torch.autograd anomaly detection for debugging NaN/Inf gradients. Adds significant overhead — use only for debugging. Default is False.

False
sdpa_backend list of str, str, or None

Scaled Dot-Product Attention backend(s). Valid string values are "math", "flash", "efficient", and "cudnn". Pass a list to specify multiple backends; if sdpa_set_priority=True, the list is treated as a priority order. None uses PyTorch's default selection. Default is None.

None
sdpa_set_priority bool

When sdpa_backend is a list, interpret it as a priority order rather than requiring all backends to be available. Default is False.

False
float32_matmul_precision str or None

Float32 matrix-multiplication precision on Ampere+ GPUs. One of "highest" (full IEEE, slowest), "high" (TF32, ~10–20 % speedup), or "medium" (more aggressive, may impact accuracy). None leaves the PyTorch default unchanged. Default is None.

None
dynamo_recompile_limit int or None

Override torch._dynamo.config.recompile_limit. Increase when torch.compile() produces frequent recompilation warnings. None leaves the default unchanged. Default is None.

None
mixed_precision str or None

Automatic Mixed Precision mode. None or "no" disables AMP. "bf16" enables bfloat16 autocast without loss scaling (recommended for Ampere+ GPUs). "fp16" enables float16 autocast with GradScaler loss scaling. Default is None.

None
fp8_recipe str or None

FP8 training recipe via torchao. Converts nn.Linear layers to Float8Linear. One of "tensorwise" (fastest), "rowwise" (more accurate), or "rowwise_with_gw_hp" (most accurate). None disables FP8. Orthogonal to mixed_precision; combine both for FP8 matmuls with bfloat16 non-linear ops. Requires CUDA SM >= 8.9. Default is None.

None
fp8_dim_alignment int

Minimum alignment for FP8 Linear layer dimensions. Layers whose in_features or out_features are not divisible by this value are skipped. Hardware requires 16. Default is 16.

16
qat_recipe str or None

Quantization-aware training recipe via torchao. Inserts FakeQuantizedLinear modules so the forward pass simulates the target low-bit precision while backward stays in full precision. After training, run forgather finalize --quantize <recipe> to produce the real low-bit deployment artifact. Mutually exclusive with fp8_recipe. See docs/trainers/qat-training.md for the recipe list. Default is None.

None
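The fp8_dim_alignment rule can be illustrated with a short, stdlib-only sketch. It reproduces only the divisibility check described above, not the torchao conversion to Float8Linear; the function fp8_eligible_layers and the layer names and shapes are hypothetical.

```python
from typing import Dict, List, Tuple


def fp8_eligible_layers(
    layer_shapes: Dict[str, Tuple[int, int]], alignment: int = 16
) -> List[str]:
    """Return the names of layers whose in_features AND out_features are both
    divisible by `alignment`. Every other layer would be skipped (left as nn.Linear).

    Hypothetical helper: illustrates the documented rule only.
    """
    return [
        name
        for name, (in_features, out_features) in layer_shapes.items()
        if in_features % alignment == 0 and out_features % alignment == 0
    ]


# Hypothetical layer names mapped to (in_features, out_features).
shapes = {
    "attn.q_proj": (768, 768),
    "mlp.up_proj": (768, 3072),
    "lm_head": (768, 50257),  # 50257 % 16 == 1, so this layer is skipped
}
print(fp8_eligible_layers(shapes))  # ['attn.q_proj', 'mlp.up_proj']
```

A common consequence of this rule is that an output projection sized to an odd vocabulary stays in higher precision unless the vocabulary is padded to a multiple of the alignment.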
Source code in src/forgather/ml/trainer/base_trainer.py
@dataclass(kw_only=True)
class BaseTrainingArguments(MinimalTrainingArguments):
    """Extended training arguments with checkpoint management and PyTorch optimizations.

    Extends ``MinimalTrainingArguments`` with full checkpoint state preservation and a
    range of PyTorch runtime optimizations (mixed precision, FP8, SDPA backend
    selection, activation offloading, etc.).

    All training state (model, optimizer, scheduler, dataset position, RNG state) is
    automatically saved in checkpoints. To skip loading a specific component when
    resuming, manually delete its file from the checkpoint directory before calling
    ``train()``.

    .. note::
        The checkpoint-related options in this class are **not** compatible with the
        HuggingFace ``Trainer``. Use ``MinimalTrainingArguments`` when HF compatibility
        is required.

    Parameters
    ----------
    default_dtype : str or None, optional
        Default ``torch.dtype`` for model construction (e.g. ``"float32"``,
        ``"bfloat16"``, ``"float16"``). ``None`` leaves PyTorch's global default
        unchanged. Default is ``None``.
    max_eval_steps : int, optional
        Maximum number of evaluation steps per evaluation call. ``-1`` runs the
        full evaluation dataset. Default is ``-1``.
    preserve_best_model : bool, optional
        If ``True``, keep the checkpoint with the best value of
        ``best_model_metric`` protected from cleanup rotation. Default is ``False``.
    best_model_metric : str, optional
        Name of the metric used to select the best checkpoint when
        ``preserve_best_model=True``. Default is ``"loss"``.
    best_model_greater_is_better : bool or None, optional
        Whether higher values of ``best_model_metric`` are better. ``None``
        auto-detects from the metric name (metrics containing ``"loss"`` or
        ``"perplexity"`` default to lower-is-better). Default is ``None``.
    preserve_n_best : int, optional
        Number of best checkpoints to keep safe from ``save_total_limit``
        cleanup. Default is ``1``.
    eval_on_save : bool, optional
        Force an evaluation pass before each checkpoint save. Useful for
        decoupling the save and eval schedules. Default is ``False``.
    enable_activation_offloading : bool, optional
        Offload saved activation tensors to CPU during the backward pass to
        reduce peak GPU memory. Best combined with activation checkpointing.
        Trades GPU memory for CPU memory bandwidth. Default is ``False``.
    detect_anomaly : bool, optional
        Enable ``torch.autograd`` anomaly detection for debugging NaN/Inf
        gradients. Adds significant overhead — use only for debugging.
        Default is ``False``.
    sdpa_backend : list of str, str, or None, optional
        Scaled Dot-Product Attention backend(s). Valid string values are
        ``"math"``, ``"flash"``, ``"efficient"``, and ``"cudnn"``. Pass a list
        to specify multiple backends; if ``sdpa_set_priority=True``, the list is
        treated as a priority order. ``None`` uses PyTorch's default selection.
        Default is ``None``.
    sdpa_set_priority : bool, optional
        When ``sdpa_backend`` is a list, interpret it as a priority order rather
        than as an unordered set of allowed backends. Default is ``False``.
    float32_matmul_precision : str or None, optional
        Float32 matrix-multiplication precision on Ampere+ GPUs. One of
        ``"highest"`` (full IEEE, slowest), ``"high"`` (TF32, ~10–20% speedup),
        or ``"medium"`` (more aggressive, may impact accuracy). ``None`` leaves
        the PyTorch default unchanged. Default is ``None``.
    dynamo_recompile_limit : int or None, optional
        Override ``torch._dynamo.config.recompile_limit``. Increase when
        ``torch.compile()`` produces frequent recompilation warnings. ``None``
        leaves the default unchanged. Default is ``None``.
    mixed_precision : str or None, optional
        Automatic Mixed Precision mode. ``None`` or ``"no"`` disables AMP.
        ``"bf16"`` enables bfloat16 autocast without loss scaling (recommended
        for Ampere+ GPUs). ``"fp16"`` enables float16 autocast with
        ``GradScaler`` loss scaling. Default is ``None``.
    fp8_recipe : str or None, optional
        FP8 training recipe via ``torchao``. Converts ``nn.Linear`` layers to
        ``Float8Linear``. One of ``"tensorwise"`` (fastest), ``"rowwise"``
        (more accurate), or ``"rowwise_with_gw_hp"`` (most accurate). ``None``
        disables FP8. Orthogonal to ``mixed_precision``; combine both for FP8
        matmuls with bfloat16 non-linear ops. Requires CUDA SM >= 8.9.
        Default is ``None``.
    fp8_dim_alignment : int, optional
        Minimum alignment for FP8 ``Linear`` layer dimensions. Layers whose
        ``in_features`` or ``out_features`` are not divisible by this value are
        skipped. FP8 matmuls require dimensions divisible by 16. Default is ``16``.
    qat_recipe : str or None, optional
        Quantization-aware training recipe via ``torchao``. Inserts
        ``FakeQuantizedLinear`` modules so the forward pass simulates the
        target low-bit precision while backward stays in full precision.
        After training, run ``forgather finalize --quantize <recipe>`` to
        produce the real low-bit deployment artifact. Mutually exclusive with
        ``fp8_recipe``. See ``docs/trainers/qat-training.md`` for the recipe
        list. Default is ``None``.
    """

    # Default torch dtype for model construction (e.g., "float32", "bfloat16", "float16")
    default_dtype: str | None = None

    # Limit maximum validation/eval steps (-1 for unlimited)
    max_eval_steps: int = -1

    # Checkpoint preservation
    preserve_best_model: bool = False
    best_model_metric: str = "loss"
    best_model_greater_is_better: bool | None = None
    preserve_n_best: int = 1  # Keep N best checkpoints safe from cleanup

    # Force evaluation before save (decouples save/eval scheduling)
    eval_on_save: bool = False

    # Offload activation tensors to CPU memory during backward pass to reduce GPU memory usage.
    # Best combined with activation checkpointing. Trade GPU memory for CPU memory and bandwidth.
    # https://docs.pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu
    enable_activation_offloading: bool = False

    # Enable PyTorch anomaly detection for debugging NaN/Inf in gradients.
    # Adds overhead - only use for debugging.
    # https://docs.pytorch.org/docs/stable/autograd.html#debugging-and-anomaly-detection
    detect_anomaly: bool = False

    # Set Scaled Dot-Product Attention (SDPA) backend implementation.
    # Options: "math" (reference), "flash" (Flash Attention), "efficient" (memory-efficient), "cudnn"
    # Can be a single backend or list for priority order (if sdpa_set_priority=True)
    # https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel
    sdpa_backend: List[str] | str | None = (
        None  # "math" | "flash" | "efficient" | "cudnn"
    )
    sdpa_set_priority: bool = (
        False  # If True and sdpa_backend is list, interpret as priority order
    )

    # Set matmul precision for float32 operations on Ampere+ GPUs for speedup.
    # "highest": Full IEEE precision (slowest)
    # "high": TF32 precision (~10-20% speedup, minimal accuracy loss)
    # "medium": More aggressive optimization (faster but may impact accuracy)
    # https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    float32_matmul_precision: str | None = None  # "highest" | "high" | "medium"

    # Override PyTorch dynamo recompilation limit (default is quite low).
    # Increase if seeing frequent recompilations with torch.compile().
    dynamo_recompile_limit: int | None = None

    # Automatic Mixed Precision mode.
    # None or "no": disabled (default)
    # "bf16": bfloat16 autocast, no loss scaling (recommended for Ampere+ GPUs)
    # "fp16": float16 autocast with GradScaler loss scaling
    mixed_precision: str | None = None

    # FP8 training via torchao. Converts nn.Linear to Float8Linear for FP8 matmuls.
    # Recipes: "tensorwise" (fastest), "rowwise" (more accurate), "rowwise_with_gw_hp" (most accurate)
    # None = disabled. Orthogonal to mixed_precision (use both for FP8 matmuls + bf16 non-linear ops).
    # Requires CUDA SM >= 8.9 (RTX 4090, H100, etc.) and matmul dims divisible by 16.
    fp8_recipe: str | None = None

    # Minimum alignment for FP8 Linear layer dimensions. Layers with in_features or
    # out_features not divisible by this value are skipped. Hardware requires 16.
    fp8_dim_alignment: int = 16

    # Quantization-aware training (QAT) via torchao. Inserts FakeQuantizedLinear
    # modules in the prepare phase; convert is done post-training via
    # `forgather finalize --quantize <recipe>`. Mutually exclusive with fp8_recipe.
    # See src/forgather/ml/qat_recipes.py for the recipe table.
    qat_recipe: str | None = None

    def __post_init__(self):
        if self.logging_dir is None:
            self.logging_dir = os.path.join(
                self.output_dir, "runs", f"{time.time_ns()}_{platform.node()}"
            )

        # As per https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
            self.dataloader_prefetch_factor = 2
        if self.torch_compile_backend is None:
            self.torch_compile_backend = "inductor"
        if self.torch_compile_mode is None:
            self.torch_compile_mode = "default"

        if self.lr_scheduler_kwargs is None:
            self.lr_scheduler_kwargs = {}

        # Validate mixed_precision
        if self.mixed_precision is not None:
            if self.mixed_precision == "no":
                self.mixed_precision = None
            elif self.mixed_precision not in ("bf16", "fp16"):
                raise ValueError(
                    f"mixed_precision must be None, 'no', 'bf16', or 'fp16', "
                    f"got '{self.mixed_precision}'"
                )

        # Validate fp8_recipe
        _FP8_RECIPES = ("tensorwise", "rowwise", "rowwise_with_gw_hp")
        if self.fp8_recipe is not None:
            if self.fp8_recipe not in _FP8_RECIPES:
                raise ValueError(
                    f"fp8_recipe must be one of {_FP8_RECIPES}, got '{self.fp8_recipe}'"
                )

        # Validate qat_recipe
        if self.qat_recipe is not None:
            from forgather.ml.qat_recipes import QAT_RECIPES

            if self.qat_recipe not in QAT_RECIPES:
                raise ValueError(
                    f"qat_recipe must be one of {QAT_RECIPES}, got '{self.qat_recipe}'"
                )

        # Linear-swap recipes are mutually exclusive (each replaces nn.Linear
        # with a different specialised class). Add new recipes here to keep
        # the check single-source.
        _LINEAR_SWAP_RECIPES = ("fp8_recipe", "qat_recipe")
        _active = [name for name in _LINEAR_SWAP_RECIPES if getattr(self, name)]
        if len(_active) > 1:
            raise ValueError(
                f"Linear-swap recipes are mutually exclusive "
                f"(each replaces nn.Linear); set at most one. Got: {_active}"
            )
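
The best-checkpoint selection described by preserve_best_model, best_model_metric, best_model_greater_is_better and preserve_n_best can be sketched as follows. This is a minimal illustration, not Forgather's implementation: the helper names are hypothetical, and both the case-insensitive name match and treating every other metric as higher-is-better are assumptions drawn from the docstring's wording.

```python
from typing import Dict, List, Optional


def infer_greater_is_better(
    metric_name: str, greater_is_better: Optional[bool] = None
) -> bool:
    # An explicit best_model_greater_is_better setting always wins.
    if greater_is_better is not None:
        return greater_is_better
    # Assumption: case-insensitive substring match; metrics mentioning
    # "loss" or "perplexity" are lower-is-better, everything else higher-is-better.
    name = metric_name.lower()
    return not ("loss" in name or "perplexity" in name)


def best_checkpoints(
    scores: Dict[str, float],
    metric_name: str,
    n: int = 1,
    greater_is_better: Optional[bool] = None,
) -> List[str]:
    """Return the n checkpoint names with the best metric value, best first.
    These are the ones that would be protected from save_total_limit cleanup."""
    higher_is_better = infer_greater_is_better(metric_name, greater_is_better)
    ranked = sorted(scores, key=scores.__getitem__, reverse=higher_is_better)
    return ranked[:n]


# Hypothetical eval-loss history keyed by checkpoint directory name.
eval_loss = {"checkpoint-100": 2.1, "checkpoint-200": 1.8, "checkpoint-300": 1.9}
print(best_checkpoints(eval_loss, "loss", n=2))  # ['checkpoint-200', 'checkpoint-300']
```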

Single-GPU Trainer

forgather.ml.trainer.trainer.Trainer

Bases: BaseTrainer[TTrainingArguments], Generic[TTrainingArguments]

A lightweight, single-device trainer with an API close to transformers.Trainer.

This trainer provides a simplified, more comprehensible implementation of the HuggingFace Trainer, intended as a drop-in replacement for basic use cases.

Key features:

- Compatible with HF Trainer API for basic training workflows
- Memory optimizations: fused loss, fused optimizer/backward, activation checkpointing
- Flexible model construction: default/meta/device modes for different memory/speed tradeoffs
- Full checkpoint management: saves/restores model, optimizer, scheduler, dataset state
- Best model tracking via load_best_model_at_end

For distributed training, see AccelTrainer (data parallel via Accelerate) and PipelineTrainer (pipeline parallelism).

Basic usage:

    trainer = Trainer(
        model=model,
        args=TrainingArguments(...),
        distributed_env=distributed_env,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizer_factory=optimizer_factory,
        lr_scheduler_factory=lr_scheduler_factory,
    )
    trainer.train()
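The throughput accounting in the Trainer source (_count_batch_tokens counts label positions not equal to -100, and _compute_flops_per_token applies the C = 6N approximation over non-embedding parameters) can be checked by hand with a sketch like this. The helper names and the parameter count are hypothetical, and multiplying the two quantities to get per-batch FLOPs is an assumption about how the trainer combines them.

```python
from typing import Sequence

IGNORE_INDEX = -100  # cross-entropy ignore_index marking padding/special positions


def count_target_tokens(labels: Sequence[Sequence[int]]) -> int:
    """Count label positions that contribute to the loss (label != -100)."""
    return sum(1 for row in labels for label in row if label != IGNORE_INDEX)


def flops_per_token(non_embedding_params: int) -> int:
    # C = 6N: roughly 2N FLOPs forward and 4N backward,
    # counting each multiply-accumulate as 2 FLOPs.
    return 6 * non_embedding_params


def batch_flops(labels: Sequence[Sequence[int]], non_embedding_params: int) -> int:
    # Assumption: per-batch FLOPs = counted tokens x FLOPs per token.
    return count_target_tokens(labels) * flops_per_token(non_embedding_params)


labels = [[5, 17, 9, -100], [3, -100, -100, -100]]
print(count_target_tokens(labels))  # 4
print(batch_flops(labels, non_embedding_params=10_000_000))  # 240000000
```

Counting from labels rather than input_ids means padding and masked positions do not inflate the token rate.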

Source code in src/forgather/ml/trainer/trainer.py
class Trainer(BaseTrainer[TTrainingArguments], Generic[TTrainingArguments]):
    """
    A lightweight, single-device trainer with an API close to transformers.Trainer.

    This trainer provides a simplified, more comprehensible implementation of the
    HuggingFace Trainer, intended as a drop-in replacement for basic use cases.

    Key features:
    - Compatible with HF Trainer API for basic training workflows
    - Memory optimizations: fused loss, fused optimizer/backward, activation checkpointing
    - Flexible model construction: default/meta/device modes for different memory/speed tradeoffs
    - Full checkpoint management: saves/restores model, optimizer, scheduler, dataset state
    - Best model tracking via load_best_model_at_end

    For distributed training, see AccelTrainer (data parallel via Accelerate) and
    PipelineTrainer (pipeline parallelism).

    Basic usage:
        trainer = Trainer(
            model=model,
            args=TrainingArguments(...),
            distributed_env=distributed_env,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizer_factory=optimizer_factory,
            lr_scheduler_factory=lr_scheduler_factory,
        )
        trainer.train()
    """

    args: TTrainingArguments
    dist: DistributedEnvInterface
    optimizer_factory: OptimizerFactoryT | None
    lr_scheduler_factory: LRSchedulerFactoryT | None
    enable_activation_checkpoint_fn: EnableCheckpointFnT | None
    fused_loss_factory: FusedLossFactoryT | None
    optimizer_groups: OptimGroupMap | None

    max_steps: int
    epoch_train_steps: int
    do_train: bool
    do_eval: bool
    use_fused_loss: bool
    gradient_accumulation_step: int

    @classmethod
    def default_callbacks(cls):
        return [ProgressCallback(), InfoCallback()]

    def __init__(
        self,
        *,
        args: TTrainingArguments | dict,
        distributed_env: DistributedEnvInterface,
        optimizer_factory: Optional[OptimizerFactoryT] = None,
        # Alternative, for compatibility with transformers.Trainer
        optimizer_cls_and_kwargs: Optional[
            Tuple[Type[OptimizerT], Dict[str, Any]]
        ] = None,
        lr_scheduler_factory: Optional[LRSchedulerFactoryT] = None,
        enable_activation_checkpoint_fn: Optional[
            EnableCheckpointFnT
        ] = enable_hf_activation_checkpointing,
        fused_loss_factory: Optional[FusedLossFactoryT] = None,
        optimizer_groups: Optional[OptimGroupMap] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        args : TrainingArguments or dict
            Training configuration. Accepts a ``TrainingArguments`` instance or a
            plain dict (converted automatically). See ``TrainingArguments`` for all options.
        distributed_env : DistributedEnvInterface
            Distributed environment object providing rank/device information.
            Use ``DistributedEnvironment`` for real training or ``StaticDistributedEnvironment``
            for single-process runs. The trainer asserts ``world_size == 1`` for non-distributed
            subclasses.
        optimizer_factory : callable, optional
            Callable that accepts ``model.named_parameters()`` and returns an optimizer.
            If not provided and ``optimizer_cls_and_kwargs`` is also absent, defaults to
            AdamW with parameters from ``args``.
        optimizer_cls_and_kwargs : tuple, optional
            HuggingFace Trainer-compatible alternative to ``optimizer_factory``.
            A ``(optimizer_class, kwargs_dict)`` pair. Ignored when ``optimizer_factory``
            is supplied.
        lr_scheduler_factory : callable, optional
            Callable that accepts the optimizer and returns an LR scheduler.
            If not provided and ``args.lr_scheduler_type`` is set, falls back to
            ``transformers.get_scheduler``.
        enable_activation_checkpoint_fn : callable, optional
            Called as ``fn(rank, model)`` to enable activation (gradient) checkpointing.
            Defaults to ``enable_hf_activation_checkpointing``. Set to ``None`` to
            disable activation checkpointing even when ``args.gradient_checkpointing=True``.
        fused_loss_factory : callable, optional
            If provided, enables fused logits-loss computation. Called with the model's
            output embedding layer and returns a loss function. Requires the model to
            support ``get_output_embeddings()`` and ``return_hidden_states``.
        optimizer_groups : OptimGroupMap, optional
            Parameter group configuration for the optimizer. Allows different
            hyperparameters (lr, weight_decay) for different parameter subsets.
        **kwargs
            Passed to ``BaseTrainer``: ``model``, ``model_init``, ``train_dataset``,
            ``eval_dataset``, ``loss_fn``, ``data_collator``, ``processing_class``,
            ``callbacks``, ``optimizer``, ``lr_scheduler``.
        """
        if isinstance(args, dict):
            args = cast(TTrainingArguments, from_dict(TrainingArguments, args))
        super().__init__(args=args, **kwargs)

        # HF Trainer compatibility.
        if not optimizer_factory:
            if not optimizer_cls_and_kwargs:
                optimizer_factory = partial(  # type: ignore[assignment]
                    torch.optim.AdamW,
                    lr=args.learning_rate,
                    betas=(args.adam_beta1, args.adam_beta2),
                    weight_decay=args.weight_decay,
                    eps=args.adam_epsilon,
                )
            else:
                optimizer_factory = partial(
                    optimizer_cls_and_kwargs[0], **optimizer_cls_and_kwargs[1]
                )
        self.optimizer_groups = optimizer_groups
        self.dist = distributed_env
        self.optimizer_factory = optimizer_factory
        self.lr_scheduler_factory = lr_scheduler_factory
        self.enable_activation_checkpoint_fn = enable_activation_checkpoint_fn
        self.fused_loss_factory = fused_loss_factory
        self.use_fused_loss = False
        assert (self.model is not None) or (
            self.model_init is not None
        ), "Either a model or a model constructor must be specified."

        assert (
            self.args.max_grad_norm is None or not self.args.fuse_optim_with_backward
        ), "max_grad_norm is incompatible with fuse_optim_with_backward"

        assert (
            self.args.gradient_accumulation_steps == 1
            or not self.args.fuse_optim_with_backward
        ), f"gradient_accumulation_steps={self.args.gradient_accumulation_steps} is incompatible with fuse_optim_with_backward"

        assert (
            self.args.mixed_precision != "fp16"
            or not self.args.fuse_optim_with_backward
        ), (
            "fp16 mixed precision with GradScaler is incompatible with fuse_optim_with_backward. "
            "Use mixed_precision='bf16' instead (no GradScaler needed)."
        )

        assert (
            self.loss_fn or self.args.gradient_accumulation_steps == 1
        ), f"gradient_accumulation_steps [{self.args.gradient_accumulation_steps}] > 1 requires loss_fn"

        if self.data_collator is None:
            self.data_collator = torch.utils.data.default_collate

        # Compute FLOPs per token for tracking
        # Note: model not initialized yet, will be set in _prepare()

    def _compute_flops_per_token(self) -> float:
        """
        Estimate FLOPs per token for forward + backward pass.

        Uses the standard C = 6N approximation where N = non-embedding parameters.
        Counts each multiply-accumulate as 2 FLOPs, consistent with hardware specs.
        See ModelParameterStats.flops_per_token for references.
        """
        assert self.model is not None
        from forgather.ml.utils import count_parameters

        return count_parameters(self.model).flops_per_token

    def _count_batch_tokens(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """
        Count non-padding tokens using labels tensor.

        Uses labels as the primary source since the cross-entropy ignore_index (-100)
        marks padding and special tokens that should not be counted.

        Returns a GPU tensor to avoid forcing a GPU-CPU synchronization via .item(),
        which would stall the CPU waiting for any pending compiled graph execution.

        Parameters
        ----------
        input_dict : dict
            Batch input dictionary (unused in base, available for overrides).
        labels : Tensor
            Target labels with -100 (ignore_index) marking padding/special positions.

        Returns
        -------
        Tensor
            Count of non-ignored tokens in the batch.
        """
        # Labels have -100 for padding/special tokens; count only real target tokens
        return (labels != -100).sum()

    def _distributed_tokens(self, tokens: Tensor) -> Tensor:
        """
        Aggregate token counts across processes.

        Base implementation for single-device training.
        Distributed trainers override to sum across ranks.

        Parameters
        ----------
        tokens : Tensor
            Token count tensor from the current process.

        Returns
        -------
        Tensor
            Token count tensor (single-device) or aggregated across ranks (distributed).
        """
        return tokens

    def _distributed_peak_mem(self, local_peak: int) -> list[int]:
        """
        Gather per-rank peak CUDA memory into a list on every rank.

        Base implementation is single-device and returns a one-element list.
        Distributed trainers override to all_gather across ranks so logging
        reflects per-rank high-water marks.

        Parameters
        ----------
        local_peak : int
            ``torch.cuda.max_memory_allocated`` for this rank, in bytes.

        Returns
        -------
        list of int
            Peak bytes indexed by rank (length ``world_size``, or 1 for single-device).
        """
        return [int(local_peak)]

    def _init_distributed(self):
        """
        Hook for distributed setup. The base ``Trainer`` supports only a single
        process and asserts ``world_size == 1``.

        Subclasses that support distributed training are expected to override this
        and set the following attributes:
          self.is_local_process_zero
          self.is_world_process_zero
          self.num_processes
        """
        assert (
            self.dist.world_size == 1
        ), "'Trainer' does not support distributed training. See subclasses for implementations which do support it."

    def _init_device(self):
        """Update / init trainer's device"""
        # If unspecified, set a default device
        if self.args.device is None:
            self.args.device = self.dist.device
        # Force CPU when requested (e.g., for debugging); overrides any other device.
        if self.args.use_cpu:
            self.args.device = "cpu"

    def _get_dataloader(self, dataset, batch_size):
        if not isinstance(dataset, tn.BaseNode | DataLoader):
            dataloader_kwargs = {
                "batch_size": batch_size,
                "collate_fn": self.data_collator,
                "drop_last": self.args.dataloader_drop_last,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "prefetch_factor": self.args.dataloader_prefetch_factor,
                "persistent_workers": self.args.dataloader_persistent_workers,
            }

            # Wrap plain datasets in a StatefulDataLoader so iteration state can be
            # checkpointed; torchdata nodes and existing DataLoaders pass through as-is.
            return StatefulDataLoader(dataset, **dataloader_kwargs)
        else:
            return dataset

    @override
    def _prepare(
        self, train_dataset: Optional[BaseDataset], eval_dataset: Optional[BaseDataset]
    ) -> None:
        """
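        Prepare for training and/or evaluation.

        Initializes distributed state, device, AMP, RNG seeds, and the activation
        memory budget; resolves ``resume_from_checkpoint``; builds dataloaders and
        the model (compiling it if requested); creates the optimizer and scheduler
        when training; runs the ``_wrap()`` hook; then initializes trainer state
        and the checkpoint manager and, when resuming, loads the checkpoint.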
        """
        self._init_distributed()
        self._init_device()

        # Initialize AMP context for mixed precision training
        device_type = (
            self.args.device
            if isinstance(self.args.device, str)
            else str(self.args.device)
        )
        if "cuda" in device_type:
            device_type = "cuda"
        self.amp_context = AMPContext(
            mixed_precision=self.args.mixed_precision,
            device_type=device_type,
        )
        if self.amp_context.enabled:
            logger.info(
                f"Mixed precision training enabled: {self.args.mixed_precision}"
            )

        # Set the random seed
        if self.args.seed != -1:
            import random

            torch.manual_seed(self.args.seed)
            random.seed(self.args.seed)

        if self.args.activation_memory_budget:
            logger.info(
                f"Setting memory budget to {self.args.activation_memory_budget}"
            )
            try:
                torch._functorch.config.activation_memory_budget = (  # type: ignore[attr-defined]
                    self.args.activation_memory_budget
                )
            except AttributeError:
                logger.warning(
                    "PyTorch does not appear to support the experimental activation_memory_budget API"
                )

        # Resolve resume_from_checkpoint before model construction, since the
        # construction strategy depends on whether we're resuming (e.g., "meta"
        # requires a checkpoint, "default" skips init when resuming).
        # After this block, resume_from_checkpoint is either a path string or False.
        self._resolve_checkpoint()

        self._init_dataloaders(train_dataset, eval_dataset)
        self._prepare_model()
        if self.args.torch_compile:
            logger.info(
                f"Compiling model: backend={self.args.torch_compile_backend}, mode={self.args.torch_compile_mode}, "
                f"dynamic={self.args.torch_compile_dynamic}, fullgraph={self.args.torch_compile_full_graph}, "
                f"activation_memory_budget={self.args.activation_memory_budget}"
            )

            if os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", 0):
                logger.warning(
                    "Computing partitioner Pareto memory budget -- be patient, this takes time..."
                )
            self._compile_model()

        if self.do_train:
            self._init_optimizer()

        self._wrap()
        self.state = self._init_state()
        self.checkpoint_manager = self._init_checkpoint_manager()

        # Restore from checkpoint (path already resolved by _resolve_checkpoint)
        if self.args.resume_from_checkpoint:
            self.load_checkpoint(self.args.resume_from_checkpoint)
            # Non-persistent buffers (e.g., RotaryEmbedding's inv_freq) are not
            # saved in checkpoints. When the model was constructed on meta device,
            # these buffers were never initialized. Compute them now.
            if self.args.construct_model_on == "meta" and self.model is not None:
                self._initialize_non_persistent_buffers(self.model)

        self._dispatch_event("on_init_end")

    def _resolve_checkpoint(self) -> None:
        """Resolve resume_from_checkpoint to a concrete path or False.

        Called early in ``_prepare()`` before model construction, since the
        construction strategy depends on whether we're resuming.

        After this method:
        - ``self.args.resume_from_checkpoint`` is a path string (checkpoint found)
        - or ``False`` (no checkpoint, fresh initialization)
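
        The resolution rules above can be restated as a standalone sketch.
        ``resolve_checkpoint`` and its ``find_latest`` callback are hypothetical
        stand-ins for this method and ``find_latest_checkpoint``. Unlike this method,
        which returns early and leaves a falsy value untouched, the sketch normalizes
        a falsy input to ``False``.

        ```python
        import os
        from typing import Callable, Optional, Union


        def resolve_checkpoint(
            rfc: Union[bool, str, None],
            output_dir: str,
            find_latest: Callable[[str], Optional[str]],
        ) -> Union[str, bool]:
            if not rfc:
                return False  # resuming disabled
            if isinstance(rfc, bool):
                path = find_latest(output_dir)  # True: search output_dir
            else:
                path = rfc if os.path.exists(rfc) else None  # explicit path must exist
            return path if path is not None else False


        def no_checkpoints(output_dir: str) -> Optional[str]:
            return None  # hypothetical finder that never finds anything
        ```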
        """
        rfc = self.args.resume_from_checkpoint
        if not rfc:
            return

        if isinstance(rfc, bool):
            # True = auto-find latest checkpoint
            checkpoint_path = find_latest_checkpoint(self.args.output_dir)
        else:
            # Explicit path string
            checkpoint_path = rfc if os.path.exists(rfc) else None

        if checkpoint_path is None:
            logger.info(
                "No checkpoint found. Starting with fresh model initialization."
            )
            self.args.resume_from_checkpoint = False
        else:
            self.args.resume_from_checkpoint = checkpoint_path

    def _wrap(self) -> None:
        """
        Hook for wrapping objects after construction in _prepare().

        Called after dataloaders, model, optimizer, and scheduler are initialized
        but before training begins. Subclasses use this to wrap objects for
        distributed training or other runtime modifications.

        Examples:
        - AccelTrainer: Wraps model/optimizer/dataloaders with Accelerate
        - PipelineTrainer: Sets up pipeline parallel stage wrapping

        See src/forgather/ml/trainer/accelerate/accel_trainer.py and
        src/forgather/ml/trainer/pipeline/pipeline_trainer.py for examples.
        """
        pass

    def _compile_model(self):
        """
        Compile model using torch.compile().

        Hook for model compilation. Subclasses may override to customize compilation
        behavior or apply compilation to additional wrapped objects.
        """
        assert self.model is not None
        self.model.compile(
            backend=self.args.torch_compile_backend,
            mode=self.args.torch_compile_mode,
            dynamic=self.args.torch_compile_dynamic,
            fullgraph=self.args.torch_compile_full_graph,
        )

    def _init_checkpoint_manager(self) -> CheckpointManager:  # type: ignore[override]
        """
        Initialize checkpoint manager hook.

        Creates the CheckpointManager responsible for saving/loading complete training
        state including model, optimizer, scheduler, dataset state, and RNG state.

        Subclasses may override to provide custom checkpoint management behavior.

        Returns:
            CheckpointManager: Initialized checkpoint manager
        """
        cp_config = CheckpointConfig(
            output_dir=self.args.output_dir,
            save_total_limit=self.args.save_total_limit,
            save_on_each_node=self.args.save_on_each_node,
            save_safetensors=self.args.save_safetensors,
        )

        checkpoint_manager = CheckpointManager(
            config=cp_config,
            dist=self.dist,
            model=self.unwrapped_model(),
            model_preprocessor=self.processing_class,
            stateful_provider=self,
        )
        # Set trainer reference for callback state save/load
        checkpoint_manager.trainer = self
        # Set preserve_n_best from training args
        if hasattr(self.args, "preserve_n_best"):
            checkpoint_manager.preserve_n_best = self.args.preserve_n_best
        return checkpoint_manager

    def _init_dataloaders(self, train_dataset, eval_dataset) -> None:
        """
        Initialize train and evaluation dataloaders (_prepare() sub-step 1).

        Wraps plain datasets in StatefulDataLoader instances, which can checkpoint
        dataset iteration state; ``tn.BaseNode`` and ``DataLoader`` inputs are used
        as-is. When training, also computes the step counts used to schedule
        logging, evaluation, and checkpointing.

        Parameters
        ----------
        train_dataset : dataset or None
            Training dataset. Pass ``None`` for eval-only runs.
        eval_dataset : dataset or None
            Evaluation dataset. Pass ``None`` for train-only runs.
        """
        # _prepare() sub-step 1
        self.max_steps = 0
        self.epoch_train_steps = self.args.epoch_train_steps

        self.do_train = train_dataset is not None
        self.do_eval = eval_dataset is not None

        if self.do_train:
            if (
                self.args.set_dataset_epoch
                and self.args.num_train_epochs > 1.0
                and not hasattr(train_dataset, "set_epoch")
            ):
                logger.warning(
                    "Train dataset does not support `set_epoch`, but training is configured for more than "
                    "1 epoch; the dataset will not be reshuffled between epochs."
                )
                self.args.set_dataset_epoch = False
            self.train_dataloader = self._get_dataloader(
                train_dataset, self.args.per_device_train_batch_size
            )

            self._update_training_steps()

        if self.do_eval:
            self.eval_dataloader = self._get_dataloader(
                eval_dataset, self.args.per_device_eval_batch_size
            )

    def _prepare_model(self) -> None:
        """
        Construct/initialize model and move to device (_prepare() sub-step 2).

        Handles three model construction strategies, selected by ``construct_model_on``:
        - default: Safe; works everywhere. Constructs on CPU (under no_init_weights when
                   resuming from a checkpoint), then moves the model to the device.
        - meta: Fastest. Constructs on the meta device (no memory allocated), then
                materializes uninitialized storage on the device and re-ties shared
                parameters. Requires a checkpoint; falls back to "default" if none is
                available. Non-persistent buffers are recomputed after the checkpoint
                is loaded (see ``_initialize_non_persistent_buffers``).
        - device: Middle ground. Constructs directly on the device with initialization
                  (skipped when resuming). Faster than default, but may fail with model
                  sharding. Good for models with buffers not saved in checkpoints
                  (e.g., RoPE).
        Also sets up gradient checkpointing if enabled and initializes fused loss if available.
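
        A minimal plain-PyTorch sketch of the "meta" path follows. It constructs the
        model on the meta device, records which parameter names share one
        Parameter, materializes the model with ``to_empty()``, and re-ties the
        shared parameters. ``TinyLM``, ``record_shared_params``, and ``retie`` are
        hypothetical stand-ins, not the implementation of ``create_sharing_metadata``
        or ``retie_parameters``. Real use would then load weights from a checkpoint,
        since ``to_empty()`` leaves storage uninitialized.

        ```python
        import torch
        import torch.nn as nn


        class TinyLM(nn.Module):
            # Hypothetical model whose output head shares the embedding weight.
            def __init__(self, vocab: int = 16, dim: int = 8):
                super().__init__()
                self.embed = nn.Embedding(vocab, dim)
                self.head = nn.Linear(dim, vocab, bias=False)
                self.head.weight = self.embed.weight


        def record_shared_params(model: nn.Module) -> list[list[str]]:
            # Group parameter names by object identity; groups with >1 name are tied.
            groups: dict[int, list[str]] = {}
            for name, p in model.named_parameters(remove_duplicate=False):
                groups.setdefault(id(p), []).append(name)
            return [names for names in groups.values() if len(names) > 1]


        def retie(model: nn.Module, shared: list[list[str]]) -> None:
            # Point every alias back at the first parameter in its group.
            for names in shared:
                src = model.get_parameter(names[0])
                for name in names[1:]:
                    owner, _, attr = name.rpartition(".")
                    setattr(model.get_submodule(owner), attr, src)


        with torch.device("meta"):
            model = TinyLM()  # no real storage is allocated
        shared = record_shared_params(model)
        model.to_empty(device="cpu")  # allocates uninitialized storage; may break ties
        retie(model, shared)
        ```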
        """
        # _prepare() sub-step 2
        # Meta device construction requires a checkpoint. If none is available,
        # fall back to the default strategy.
        if (
            self.args.construct_model_on == "meta"
            and not self.args.resume_from_checkpoint
        ):
            logger.warning(
                "No checkpoint available for meta device construction. "
                "Falling back to 'default' strategy (construct on CPU, "
                "initialize, move to device)."
            )
            self.args.construct_model_on = "default"

        match self.args.construct_model_on:
            case "default":
                if self.model_init:
                    logger.info(
                        f"Constructing model on default device and moving to {self.args.device}"
                    )
                    with ExitStack() as exit_stack:
                        if self.args.default_dtype:
                            exit_stack.enter_context(
                                default_dtype(torch_dtype(self.args.default_dtype))
                            )
                        if self.args.resume_from_checkpoint:
                            exit_stack.enter_context(no_init_weights())
                        self.model = self.model_init()
                else:
                    logger.info(f"Moving model to {self.args.device}")
                    assert self.model is not None
                self.model = self.model.to(self.args.device)
            case "meta":
                assert (
                    self.model_init
                ), "Constructing the model on meta device requires model_init"

                logger.info(
                    f"Constructing model on meta device and materializing on {self.args.device}"
                )
                with ExitStack() as exit_stack:
                    if self.args.default_dtype:
                        exit_stack.enter_context(
                            default_dtype(torch_dtype(self.args.default_dtype))
                        )
                    exit_stack.enter_context(torch.device("meta"))
                    self.model = self.model_init()
                sharing_metadata = create_sharing_metadata(self.model)
                self.model.to_empty(device=self.args.device)
                # to_empty() will break tied parameters. Fix them!
                retie_parameters(self.model, sharing_metadata)
            case "device":
                assert (
                    self.model_init
                ), "Constructing the model on device requires model_init"
                logger.info(
                    f"Constructing and initializing model directly on {self.args.device}"
                )
                with ExitStack() as exit_stack:
                    if self.args.default_dtype:
                        exit_stack.enter_context(
                            default_dtype(torch_dtype(self.args.default_dtype))
                        )
                    exit_stack.enter_context(torch.device(self.args.device))
                    if self.args.resume_from_checkpoint:
                        exit_stack.enter_context(no_init_weights())
                    self.model = self.model_init()
            case _:
                raise ValueError(
                    "construct_model_on must be one of 'default', 'meta', or 'device'; "
                    f"got {self.args.construct_model_on!r}"
                )
        assert self.model is not None
        # Linear-swap recipes (fp8 / qat) are mutually exclusive — see the
        # _LINEAR_SWAP_RECIPES check in BaseTrainingArguments.__post_init__.
        # The if-chain is sequential rather than elif so a future relaxed
        # mutex still surfaces a clear error (the second swap would find
        # no nn.Linear left and report 0/N converted) instead of silently
        # producing a single-recipe model.
        if self.args.fp8_recipe:
            self.model = self._apply_fp8_training(self.model)
        if self.args.qat_recipe:
            self.model = self._apply_qat_training(self.model)
        if self.args.gradient_checkpointing:
            if self.enable_activation_checkpoint_fn is None:
                if self.dist.rank == 0:
                    logger.warning(
                        "Activation checkpointing requested, but no "
                        "enable_activation_checkpoint_fn was provided; skipping."
                    )
            else:
                # Enable activation checkpointing on the model via the supplied function.
                self.enable_activation_checkpoint_fn(self.dist.rank, self.model)
        self.loss_fn = self._maybe_get_fused_loss_fn(self.model, self.loss_fn)
        self._wrap_loss_fn()

        # Compute FLOPs per token for performance tracking
        self._flops_per_token = self._compute_flops_per_token()
        if self.dist.rank == 0:
            logger.info(f"Estimated FLOPs per token: {self._flops_per_token:.2e}")

    @staticmethod
    def _initialize_non_persistent_buffers(model: torch.nn.Module) -> None:
        """Initialize non-persistent buffers that were not loaded from checkpoint.

        Non-persistent buffers (registered with ``persistent=False``) are excluded
        from ``state_dict()`` and therefore not saved in or loaded from checkpoints.
        When the model is constructed on the meta device and then materialized via
        ``to_empty()``, these buffers contain uninitialized data.

        This method finds modules with non-persistent buffers and calls their
        ``reset_parameters()`` method to compute the correct values. For example,
        ``RotaryEmbedding.reset_parameters()`` computes ``inv_freq`` from the
        configured ``rope_theta`` and ``d_head``.

        Should be called after checkpoint load when using ``construct_model_on="meta"``.
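
        The following standalone sketch shows both the failure mode and the fix. It
        uses a hypothetical ``RotaryFreqs`` module (not forgather's
        ``RotaryEmbedding``) whose ``inv_freq`` buffer is registered with
        ``persistent=False``. The final loop mirrors the body of this method.

        ```python
        import torch
        import torch.nn as nn


        class RotaryFreqs(nn.Module):
            # Hypothetical stand-in for a RoPE module with a non-persistent buffer.
            def __init__(self, d_head: int = 8, rope_theta: float = 10000.0):
                super().__init__()
                self.d_head = d_head
                self.rope_theta = rope_theta
                self.register_buffer("inv_freq", torch.empty(d_head // 2), persistent=False)
                self.reset_parameters()

            def reset_parameters(self) -> None:
                exponent = torch.arange(
                    0, self.d_head, 2, dtype=torch.float32, device=self.inv_freq.device
                ) / self.d_head
                with torch.no_grad():
                    self.inv_freq.copy_(1.0 / (self.rope_theta ** exponent))


        reference = RotaryFreqs()  # built on CPU, so inv_freq holds real values

        with torch.device("meta"):
            model = nn.Sequential(RotaryFreqs())
        model.to_empty(device="cpu")  # inv_freq now holds arbitrary memory

        # Recompute every non-persistent buffer, as this method does.
        for module in model.modules():
            if getattr(module, "_non_persistent_buffers_set", None):
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()
        ```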
        """
        for module in model.modules():
            if getattr(module, "_non_persistent_buffers_set", None):
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()

    def _apply_fp8_training(self, model: torch.nn.Module) -> torch.nn.Module:
        """Convert nn.Linear layers to Float8Linear for FP8 training via torchao."""
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        from torchao.float8.float8_linear import Float8Linear

        assert self.args.fp8_recipe is not None
        if self.args.fp8_recipe == "rowwise_with_gw_hp":
            # torchao 0.16.0: matmul_with_hp_or_float8_args reshapes axiswise-scaled
            # inputs in forward(), which trips an assertion in float8_ops. ND inputs
            # (e.g. transformer hidden states of shape (B, S, H)) cannot be used with
            # this recipe. Plain "rowwise" and "tensorwise" are unaffected.
            logger.warning(
                "fp8_recipe='rowwise_with_gw_hp' is currently broken in torchao for "
                "ND inputs (transformer hidden states): reshape on axiswise-scaled "
                "Float8Tensor raises 'aten.reshape.default with axiswise scaling is "
                "not supported yet'. Use 'rowwise' or 'tensorwise' instead."
            )
        config = Float8LinearConfig.from_recipe_name(self.args.fp8_recipe)

        module_filter_fn = None
        pad = self.args.fp8_dim_alignment
        if pad:

            def _filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
                if isinstance(mod, torch.nn.Linear):
                    ok = mod.in_features % pad == 0 and mod.out_features % pad == 0
                    if not ok:
                        logger.info(
                            f"Skipping FP8 for {fqn}: dims ({mod.in_features}, {mod.out_features}) "
                            f"not divisible by {pad}"
                        )
                    return ok
                return True

            module_filter_fn = _filter_fn

        model = convert_to_float8_training(
            model, config=config, module_filter_fn=module_filter_fn
        )

        converted = sum(1 for m in model.modules() if isinstance(m, Float8Linear))
        total_linear = sum(
            1 for m in model.modules() if isinstance(m, (torch.nn.Linear, Float8Linear))
        )
        logger.info(
            f"FP8 training ({self.args.fp8_recipe}): "
            f"converted {converted}/{total_linear} Linear layers"
        )

        return model

    def _apply_qat_training(self, model: torch.nn.Module) -> torch.nn.Module:
        """Install torchao FakeQuantizedLinear modules for quantization-aware training.

        The forward pass simulates the target low-bit precision via fake
        quantizers while the backward pass stays in full precision, letting
        the model learn to be robust to quantization noise. The convert phase
        (real low-bit ops) is run post-training by ``forgather finalize
        --quantize <recipe>``.
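
        The following sketch illustrates only the fake-quantization idea. It is a
        generic symmetric per-tensor int8 fake quantizer with a straight-through
        estimator, written in plain PyTorch. ``fake_quant_int8`` is hypothetical.
        torchao's QAT configs choose their own granularity, dtypes, and scale
        computation, which this sketch does not reproduce.

        ```python
        import torch


        def fake_quant_int8(x: torch.Tensor) -> torch.Tensor:
            # Symmetric per-tensor int8: quantize, clamp, then dequantize in forward.
            scale = x.detach().abs().amax().clamp(min=1e-8) / 127
            q = torch.clamp(torch.round(x / scale), -128, 127) * scale
            # Straight-through estimator: forward returns q, backward acts as identity.
            return x + (q - x).detach()


        w = torch.tensor([0.5, -1.0, 0.25], requires_grad=True)
        y = fake_quant_int8(w)
        y.sum().backward()
        ```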
        """
        from torchao.quantization import quantize_
        from torchao.quantization.qat import FakeQuantizedLinear, QATConfig

        from forgather.ml.qat_recipes import recipe_to_base_config

        assert self.args.qat_recipe is not None
        base_config = recipe_to_base_config(self.args.qat_recipe)
        quantize_(model, QATConfig(base_config, step="prepare"))

        converted = sum(
            1 for m in model.modules() if isinstance(m, FakeQuantizedLinear)
        )
        total_linear = sum(
            1
            for m in model.modules()
            if isinstance(m, (torch.nn.Linear, FakeQuantizedLinear))
        )
        logger.info(
            f"QAT training ({self.args.qat_recipe}): "
            f"converted {converted}/{total_linear} Linear layers to FakeQuantizedLinear. "
            f"Run `forgather finalize --quantize {self.args.qat_recipe}` "
            f"after training to produce a deployable quantized artifact."
        )

        return model

    def _wrap_loss_fn(self):
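        """
        Wrap ``self.loss_fn`` so each micro-batch loss is scaled by
        ``1 / gradient_accumulation_steps``.

        With equal-size micro-batches and a mean-reduced loss, summing the scaled
        micro-batch losses over one accumulation window equals the mean loss over
        the full effective batch, so the accumulated gradients match a single
        large-batch step. The following standalone check uses plain PyTorch. It
        does not use ``RescaleLoss``, and ``mse`` and the tensors are illustrative.

        ```python
        import torch

        torch.manual_seed(0)
        w = torch.randn(4, requires_grad=True)
        x, y = torch.randn(8, 4), torch.randn(8)


        def mse(xb: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
            return ((xb @ w - yb) ** 2).mean()


        # One backward pass over the full batch.
        mse(x, y).backward()
        full_grad = w.grad.clone()
        w.grad = None

        # Four micro-batches, each loss scaled by 1 / accum_steps; grads accumulate.
        accum_steps = 4
        for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
            (mse(xb, yb) * (1 / accum_steps)).backward()

        assert torch.allclose(w.grad, full_grad, atol=1e-6)
        ```
        """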
        # Rescale loss by gradient accumulation steps.
        self.loss_fn = RescaleLoss(
            self.loss_fn, 1 / self.args.gradient_accumulation_steps
        )

    def _maybe_get_fused_loss_fn(
        self, module: torch.nn.Module, default_loss_fn: Optional[LossFunctionT]
    ):
        """
        Attempt to enable fused loss-logits computation for memory optimization.

        Fused loss combines the final linear layer (computing logits from hidden states)
        with the cross-entropy loss computation. This avoids materializing the full logits
        tensor in memory, which is critical for models with large vocabulary sizes where
        logits can be gigabytes in size.

        For example, with vocab_size=50k, batch_size=8, seq_len=2048 and fp32 logits:
        - Unfused: logits tensor is 8 * 2048 * 50000 * 4 bytes = ~3.3 GB
        - Fused: Only computes logits for one token at a time, dramatically less memory

        Requires:
        - fused_loss_factory provided to Trainer constructor
        - Model supports get_output_embeddings() (returns final linear layer)
        - Model supports return_hidden_states=True (returns hidden states instead of logits)

        See src/forgather/ml/loss.py and docs/fused_loss/ for implementations and details.

        If model supports fused-loss, and fused loss function is returned, sets:
            self.use_fused_loss = True

        Parameters
        ----------
        module : torch.nn.Module
            The model to attempt enabling fused loss on.
        default_loss_fn : callable or None
            Fallback loss function if fused loss is not supported.

        Returns
        -------
        callable or None
            The fused loss function, or ``default_loss_fn`` if unsupported.
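
        Examples
        --------
        One common way to realize the idea is chunked cross-entropy with activation
        checkpointing. The sketch below is that swapped-in technique in plain
        PyTorch. It is not necessarily how the factories in src/forgather/ml/loss.py
        work, and ``chunked_ce_loss`` is a hypothetical name.

        ```python
        import torch
        import torch.nn.functional as F
        from torch.utils.checkpoint import checkpoint


        def chunked_ce_loss(
            hidden: torch.Tensor,  # (N, H) final hidden states
            weight: torch.Tensor,  # (V, H) output-embedding weight
            labels: torch.Tensor,  # (N,) targets, -100 = ignore
            chunk_size: int = 1024,
        ) -> torch.Tensor:
            def chunk_loss(h: torch.Tensor, lbl: torch.Tensor) -> torch.Tensor:
                logits = h @ weight.t()  # only a (chunk_size, V) slice exists at once
                return F.cross_entropy(logits.float(), lbl, ignore_index=-100, reduction="sum")

            total = hidden.new_zeros((), dtype=torch.float32)
            for start in range(0, hidden.shape[0], chunk_size):
                end = start + chunk_size
                # Non-reentrant checkpointing drops each chunk's logits after forward
                # and recomputes them during backward.
                total = total + checkpoint(
                    chunk_loss, hidden[start:end], labels[start:end], use_reentrant=False
                )
            # Mean over non-ignored targets, matching F.cross_entropy's default.
            return total / (labels != -100).sum().clamp(min=1)


        torch.manual_seed(0)
        hidden = torch.randn(10, 4, requires_grad=True)
        weight = torch.randn(7, 4, requires_grad=True)
        labels = torch.tensor([1, 2, -100, 3, 4, 5, -100, 6, 0, 1])
        loss = chunked_ce_loss(hidden, weight, labels, chunk_size=3)
        ```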
        """
        if self.fused_loss_factory:
            if not hasattr(module, "get_output_embeddings"):
                logger.warning(
                    "Model does not support get_output_embeddings() for fused_loss_factory()"
                )
                return default_loss_fn
            if not getattr(module, "can_return_hidden_states", False):
                logger.warning(
                    f"Model does not support 'return_hidden_states' API; fused loss will not be used."
                )
                return default_loss_fn
            logger.info("Enabled fused loss-logits function")
            self.use_fused_loss = True
            return self.fused_loss_factory(module.get_output_embeddings())  # type: ignore[operator]
        return default_loss_fn

    def _init_optimizer(self) -> None:
        """
        Initialize optimizer and learning rate scheduler (_prepare() sub-step 3).

        Creates optimizer from factory and optionally sets up:
        - Learning rate scheduler (from factory or HF get_scheduler)
        - Fused backward/optimizer hooks if fuse_optim_with_backward=True
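
        The following minimal sketch shows the fused backward/optimizer pattern from
        the PyTorch tutorial on fusing the optimizer step into the backward pass. It
        creates one optimizer per parameter and steps each one from
        ``register_post_accumulate_grad_hook``. forgather's ``optimizer_hook``
        instead drives the single ``self.optimizer`` and also updates
        ``_total_grad_squared``. Neither detail is reproduced here.

        ```python
        import copy

        import torch
        import torch.nn as nn

        torch.manual_seed(0)
        model = nn.Linear(4, 2)
        reference = copy.deepcopy(model)  # copied before any hooks are registered
        x = torch.randn(3, 4)

        # One optimizer per parameter so each hook can step independently.
        optimizers = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}


        def step_and_free(p: torch.Tensor) -> None:
            # Runs once p.grad is fully accumulated for this backward pass.
            optimizers[p].step()
            optimizers[p].zero_grad(set_to_none=True)  # free the gradient right away


        for p in model.parameters():
            p.register_post_accumulate_grad_hook(step_and_free)

        model(x).sum().backward()  # parameters are updated during backward

        # Reference: ordinary backward followed by a separate SGD step.
        ref_opt = torch.optim.SGD(reference.parameters(), lr=0.1)
        reference(x).sum().backward()
        ref_opt.step()
        ```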
        """
        # _prepare() sub-step 3
        assert self.model is not None
        if self.optimizer is None:
            assert self.optimizer_factory is not None
            if self.optimizer_groups:
                buckets = build_optimizer_buckets(
                    self.model.named_parameters(),
                    self.optimizer_groups,
                    default_factory=self.optimizer_factory,
                    debug=self.args.debug_optimizer_groups,
                )
                if len(buckets) == 1:
                    factory, param_groups = buckets[0]
                    self.optimizer = factory(param_groups)
                else:
                    if self.args.fuse_optim_with_backward:
                        raise ValueError(
                            "fuse_optim_with_backward is incompatible with "
                            "multiple per-group optimizer factories: the "
                            "per-parameter post-accumulate hook can only "
                            "drive a single optimizer instance."
                        )
                    self.optimizer = cast(
                        Any,
                        Multiopt([factory(pg) for factory, pg in buckets]),
                    )
            else:
                self.optimizer = self.optimizer_factory(self.model.named_parameters())

            # Combine backward with optimizer step?
            if self.args.fuse_optim_with_backward:
                self._total_grad_squared = torch.zeros(
                    1, device=self.args.device, dtype=torch.float32
                )

                for name, p in self.model.named_parameters():
                    if p.requires_grad:
                        hook = partial(
                            optimizer_hook,
                            self.optimizer,
                            self._total_grad_squared,
                            name,
                        )
                        p.register_post_accumulate_grad_hook(hook)

        if self.lr_scheduler is None:
            if self.lr_scheduler_factory is not None:
                assert self.optimizer is not None
                self.lr_scheduler = self.lr_scheduler_factory(self.optimizer)
            elif self.args.lr_scheduler_type:
                from transformers import get_scheduler

                self.lr_scheduler = get_scheduler(
                    name=self.args.lr_scheduler_type,
                    optimizer=cast(Any, self.optimizer),  # type: ignore[arg-type]
                    num_warmup_steps=self.args.warmup_steps,
                    num_training_steps=self.max_steps,
                    scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
                )

            # Auto-register scheduler as callback if it derives from TrainerCallback.
            # This allows schedulers like GradientNoiseScheduler to receive
            # training metrics (grad_norm, loss, etc.) via on_train_metrics.
            if isinstance(self.lr_scheduler, TrainerCallback):
                self.add_callback(self.lr_scheduler)

    def _maybe_log_save_evaluate(
        self,
        loss_log,
        total_norm_log,
        tokens_log,
        periodic_log,
        periodic_eval,
        periodic_save,
    ):
        # This diverges slightly from HF, which combines the log conditions with 'and';
        # here either the schedule or a callback request triggers logging. It's unclear
        # whether HF's behavior is a bug, or whether any callbacks depend on it, so until
        # proven otherwise, log whenever either condition holds.
        if periodic_log.step() or self.control.should_log:
            log_steps = periodic_log.reset()
            self.control.should_log = False

            self._log_step(loss_log, total_norm_log, tokens_log)

        # Handle evaluation (normal schedule or control-triggered)
        eval_metrics = None
        should_eval = periodic_eval.step() or self.control.should_evaluate

        # Force eval if saving and eval_on_save enabled
        should_save = periodic_save.step() or self.control.should_save
        if should_save and self.args.eval_on_save and self.eval_dataset is not None:
            should_eval = True

        if should_eval:
            periodic_eval.reset()
            self.control.should_evaluate = False

            # Do eval
            maybe_cleanup_memory(self.args.gc_threshold)
            eval_metrics = self._eval_loop()

        # Handle checkpointing - normal schedule or control-triggered
        if should_save:
            periodic_save.reset()
            self.control.should_save = False
            assert self.checkpoint_manager

            # Determine checkpoint path BEFORE saving
            checkpoint_path = next_checkpoint_path(
                self.args.output_dir, str(self.state.global_step)
            )

            # Update best checkpoints list BEFORE saving (so preserved list is correct)
            if self.args.preserve_best_model and eval_metrics:
                checkpoint_manager = cast(CheckpointManager, self.checkpoint_manager)
                checkpoint_manager.update_best_checkpoints(
                    checkpoint_path=checkpoint_path,
                    metrics=eval_metrics,
                    metric_key=self.args.best_model_metric,
                    greater_is_better=self.args.best_model_greater_is_better,
                    preserve_n_best=self.args.preserve_n_best,
                )
                # Update state for compatibility
                if checkpoint_manager.best_checkpoints:
                    self.state.best_metric = checkpoint_manager.best_checkpoints[0][1]
                    self.state.best_model_checkpoint = (
                        checkpoint_manager.best_checkpoints[0][0]
                    )

            # Now save checkpoint (deletion will use updated preserved list)
            saved_path = self.checkpoint_manager.save_checkpoint(
                checkpoint_id=str(self.state.global_step)
            )
            self._dispatch_event("on_save")

            # Save metrics file
            if eval_metrics:
                save_checkpoint_metrics(saved_path, eval_metrics)

    def load_best_model(self) -> None:
        """
        Load the best model from the best checkpoint.

        Called at end of training when load_best_model_at_end=True to restore
        the checkpoint with the best metric value seen during training.
        """
        if not self.state.best_model_checkpoint:
            logger.warning("No best model checkpoint available to load")
            return

        if not os.path.exists(self.state.best_model_checkpoint):
            logger.warning(
                f"Best model checkpoint path does not exist: {self.state.best_model_checkpoint}"
            )
            return

        logger.info(f"Loading best model from {self.state.best_model_checkpoint}")
        self.load_checkpoint(self.state.best_model_checkpoint)

    def _setup_periodic_functions(
        self,
    ) -> tuple[PeriodicFunction, PeriodicFunction, PeriodicFunction]:
        """Build the three periodic functions driving log/eval/save cadence."""
        periodic_log = PeriodicFunction(
            global_step=self.state.global_step,
            strategy=self.args.logging_strategy,
            period=self.args.logging_steps,
            epoch_period=self.epoch_train_steps,
            first_step=self.args.logging_first_step,
        )
        periodic_eval = PeriodicFunction(
            global_step=self.state.global_step,
            strategy=self.args.eval_strategy,
            period=self.args.eval_steps,
            epoch_period=self.epoch_train_steps,
            first_step=self.args.eval_delay,
        )
        periodic_save = PeriodicFunction(
            global_step=self.state.global_step,
            strategy=self.args.save_strategy,
            period=self.args.save_steps,
            epoch_period=self.epoch_train_steps,
            first_step=0,
        )
        return periodic_log, periodic_eval, periodic_save

    def _log_stop_condition(self) -> None:
        """Log a clear message describing why the training loop is stopping.

        Called from `_train_loop` immediately after the early-stop check
        trips. Distinguishes the three termination paths (abort without
        save, graceful stop, reaching max_steps) and identifies which
        callback -- if any -- requested the stop via
        ``self._stop_requested_by`` (populated by ``_dispatch_event``).
        """
        step = self.state.global_step
        requested_by = self._stop_requested_by

        if self.control.should_abort_without_save:
            if requested_by is not None:
                cb_name, event, _ = requested_by
                logger.warning(
                    f"Training aborted at step {step} by callback "
                    f"'{cb_name}' during '{event}' "
                    f"(no final checkpoint will be saved)"
                )
            else:
                logger.warning(
                    f"Training aborted at step {step} "
                    f"(should_abort_without_save set; "
                    f"no final checkpoint will be saved)"
                )
        elif self.control.should_training_stop:
            if requested_by is not None:
                cb_name, event, _ = requested_by
                logger.info(
                    f"Training stopped at step {step} by callback "
                    f"'{cb_name}' during '{event}'"
                )
            else:
                logger.info(
                    f"Training stopped at step {step} (should_training_stop set)"
                )
        else:
            logger.info(
                f"Training stopped at step {step}: "
                f"reached max_steps ({self.max_steps})"
            )

    def _run_final_checkpoint_if_needed(
        self,
        accumulated_loss: list,
        accumulated_grad_norm: list,
        accumulated_tokens: list,
        periodic_log: PeriodicFunction,
        periodic_eval: PeriodicFunction,
        periodic_save: PeriodicFunction,
    ) -> None:
        """Run the end-of-training log/eval/save sequence.

        Skipped if ``should_abort_without_save`` is set (e.g. the
        divergence detector triggered). Otherwise forces a save on the
        current step if one hasn't already happened, then runs
        ``_maybe_log_save_evaluate`` once more to drain any pending
        accumulated metrics.
        """
        if self.control.should_abort_without_save:
            return

        # Force a save unless one already happened on this step or saving is disabled.
        if periodic_save.rel_step != 0 and self.args.save_strategy != "no":
            logger.info(f"Saving final checkpoint at step {self.state.global_step}")
            # With load_best_model_at_end, the final checkpoint must be evaluated too.
            if self.args.load_best_model_at_end:
                self.control.should_evaluate = True
            self.control.should_save = True
        self._maybe_log_save_evaluate(
            accumulated_loss,
            accumulated_grad_norm,
            accumulated_tokens,
            periodic_log,
            periodic_eval,
            periodic_save,
        )

    def _finalize_training(
        self,
        start_time: Optional[float],
        speed_metrics_steps: int,
    ) -> TrainOutput:
        """Post-training cleanup: load best model, summarise, fire on_train_end.

        Runs after the training loop exits (either normally, via a
        callback-requested stop, or via dataset exhaustion). Returns the
        ``TrainOutput`` that the public ``train()`` entry point
        eventually propagates to callers.
        """
        if self.args.load_best_model_at_end:
            logger.info(f'Loading best model "{self.state.best_model_checkpoint}"')
            self.load_best_model()

        metrics = self._end_train_loop(start_time, speed_metrics_steps)
        self.log(metrics)

        # Log best checkpoints summary at end of training
        if (
            self.args.preserve_best_model
            and self.checkpoint_manager
            and self.state.is_world_process_zero
        ):
            summary = cast(
                CheckpointManager, self.checkpoint_manager
            ).get_best_checkpoints_summary(metric_key=self.args.best_model_metric)
            logger.info(f"\n{'='*60}\nTraining complete!\n{summary}\n{'='*60}")

        self._dispatch_event("on_train_end")
        return TrainOutput(self.state.global_step, metrics)

    @override
    def _train_loop(self) -> TrainOutput:
        periodic_log, periodic_eval, periodic_save = self._setup_periodic_functions()

        assert self.optimizer is not None
        assert self.model is not None
        assert self.train_dataloader is not None

        # Clear any stale gradients before the first step.
        self.optimizer.zero_grad()

        start_time = None
        train_steps = 0
        self._dispatch_event("on_train_begin")

        # Holds loss, grad-norm, and token samples between log steps
        accumulated_grad_norm: list = []
        accumulated_loss: list = []
        accumulated_tokens: list = []

        # Context manager for setting model.train()/eval()
        with set_train(self.model, True):
            # Epoch loop
            while True:
                self.control.should_epoch_stop = False
                if self.args.set_dataset_epoch and self.state.raw_epoch > 0:
                    # If supported, reshuffle dataset at the start of each epoch.
                    # Datasets without set_epoch (e.g. standard HF Dataset) are
                    # silently skipped rather than crashing mid-training.
                    dataset = self.train_dataloader.dataset  # type: ignore[union-attr]
                    if hasattr(dataset, "set_epoch"):
                        logger.debug(f"Setting dataset epoch {self.state.raw_epoch}")
                        dataset.set_epoch(self.state.raw_epoch)
                data_iterator = iter(self.train_dataloader)
                self._dispatch_event("on_epoch_begin")

                while True:
                    self._dispatch_event("on_step_begin")

                    try:
                        loss, total_norm, tokens = self._train_step(data_iterator)
                        self._dispatch_event(
                            "on_train_metrics",
                            loss=loss,
                            grad_norm=total_norm,
                            tokens=tokens,
                        )
                    except StopIteration:
                        self.state.raw_epoch += 1
                        self.state.epoch_start_step = self.state.global_step
                        if (
                            self.args.num_train_epochs >= 0
                            and self.state.raw_epoch >= self.args.num_train_epochs
                        ):
                            self.control.should_epoch_stop = True
                        self._dispatch_event("on_step_end")
                        break

                    accumulated_grad_norm.append(total_norm)
                    accumulated_loss.append(loss)
                    accumulated_tokens.append(tokens)

                    # Increment global step
                    self.state.global_step += 1

                    # Compute epoch as continuous value from global steps
                    self.state.epoch = float(self.state.raw_epoch) + (
                        float(self.state.global_step - self.state.epoch_start_step)
                        / float(self.epoch_train_steps)
                    )

                    self._dispatch_event("on_step_end")
                    self._maybe_log_save_evaluate(
                        accumulated_loss,
                        accumulated_grad_norm,
                        accumulated_tokens,
                        periodic_log,
                        periodic_eval,
                        periodic_save,
                    )

                    train_steps += 1

                    # Check for early stop condition
                    if (
                        self.control.should_training_stop
                        or self.control.should_abort_without_save
                        or self.state.global_step >= self.max_steps
                    ):
                        self._log_stop_condition()
                        self.control.should_epoch_stop = True
                        break

                    # Periodic GC
                    maybe_cleanup_memory(self.args.gc_threshold)

                    # Optionally delay speed-metric timing (e.g. to skip torch.compile() warm-up).
                    if train_steps == self.args.speed_metrics_start_step:
                        start_time = time.time()

                self._dispatch_event("on_epoch_end")
                if self.control.should_epoch_stop:
                    break

        self._run_final_checkpoint_if_needed(
            accumulated_loss,
            accumulated_grad_norm,
            accumulated_tokens,
            periodic_log,
            periodic_eval,
            periodic_save,
        )

        return self._finalize_training(
            start_time,
            train_steps - self.args.speed_metrics_start_step,
        )

    @staticmethod
    def _format_zero_eval_batches_diagnostic(
        header: str,
        settings: list[tuple[str, Any]],
        explanation: str,
    ) -> str:
        """Render a zero-eval-batches diagnostic with a uniform shape.

        Subclasses build ``settings`` and ``explanation`` to add their own
        context (e.g. distributed dispatch flags) without restating the
        framing or the doc reference.
        """
        width = max(len(name) for name, _ in settings)
        settings_block = "\n".join(
            f"  {name.ljust(width)} = {value}" for name, value in settings
        )
        return (
            f"{header}\n"
            "\n"
            f"Effective settings:\n{settings_block}\n"
            "\n"
            f"{explanation}\n"
            "\n"
            "See docs/trainers/distributed-eval-zero-batches.md for the\n"
            "full diagnosis and the available remedies."
        )

    def _zero_eval_batches_message(self) -> str:
        """Build a diagnostic for the eval dataloader yielding zero batches.

        The base implementation surfaces the settings most likely to be
        responsible (eval batch size, drop_last, max_eval_steps). Distributed
        subclasses override this to also report ``dispatch_batches`` /
        ``dispatch_eval_batches`` and the world size, since those interact
        with sharded eval splits.
        """
        return self._format_zero_eval_batches_diagnostic(
            header="The eval dataloader did not yield any examples.",
            settings=[
                ("per_device_eval_batch_size", self.args.per_device_eval_batch_size),
                ("dataloader_drop_last", self.args.dataloader_drop_last),
                ("max_eval_steps", self.args.max_eval_steps),
            ],
            explanation=(
                "The eval split may simply be empty, or the dataloader may\n"
                "have dropped every batch because the dataset has fewer than\n"
                "per_device_eval_batch_size examples and\n"
                "dataloader_drop_last is True."
            ),
        )

    @override
    @torch.no_grad()
    def _eval_loop(self) -> Dict[str, float]:
        """
        The inner evaluation loop
        """
        assert self.model is not None
        assert self.eval_dataloader is not None
        with set_train(self.model, False):
            total_loss = torch.zeros(1, device=self.args.device)
            # Count batches actually evaluated; when max_eval_steps truncates
            # the loop, the final `step` is one past the last processed batch.
            num_batches = 0
            for step, batch in enumerate(self.eval_dataloader):
                if self.args.max_eval_steps > 0 and step >= self.args.max_eval_steps:
                    break
                input_dict, labels = self._prepare_batch(batch)
                outputs = self._prediction_step(input_dict, labels)
                loss = outputs["loss"]
                assert loss is not None
                total_loss += loss
                num_batches += 1
                self._dispatch_event("on_prediction_step")
            if num_batches == 0:
                raise RuntimeError(self._zero_eval_batches_message())

            metrics = {"eval_loss": (total_loss / num_batches).item()}
            if isinstance(self.eval_dataloader, StatefulDataLoader):
                sync_dataset_state_from_dataloader(self.eval_dataloader)
            self._dispatch_event("on_evaluate", metrics=metrics)
            return metrics

    def _clip_grad_norm(
        self, max_grad_norm: float | None, norm_type: float = 2.0
    ) -> Optional[Tensor]:
        """
        Clip gradients by norm.

        Returns
        -------
        Tensor or None
            Total norm of the parameters.

        Raises
        ------
        RuntimeError
            If parameters are not of supported types for ``foreach=True``.
        """
        # With the optimizer step fused into backward, the optimizer hooks
        # accumulate the squared gradient norm. Take its square root, then
        # reset the accumulator in place for the next step.
        if self.args.fuse_optim_with_backward:
            total_norm = self._total_grad_squared.sqrt()
            self._total_grad_squared.zero_()
            return total_norm
            return total_norm

        assert self.model is not None
        # If not clipping, only compute the total norm (for logging) and return it.
        if max_grad_norm is None or max_grad_norm == 0:
            grads = [p.grad for p in self.model.parameters() if p.grad is not None]

            total_norm = torch.nn.utils.get_total_norm(
                grads, norm_type=norm_type, foreach=True
            )
            return total_norm

        # Otherwise clip in place; clip_grad_norm_ returns the total norm before clipping.
        total_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_grad_norm,
            norm_type=norm_type,
            foreach=False if self.args.device == "cpu" else True,
        )

        return total_norm

    def _train_step(
        self, data_iterator: Iterator[dict[str, Tensor]]
    ) -> Tuple[Tensor, Tensor | None, Tensor]:
        """
        Perform a single training step, with optional gradient accumulation.

        Parameters
        ----------
        data_iterator : Iterator
            Iterator over training batches.

        Returns
        -------
        tuple of (Tensor, Tensor or None, Tensor)
            ``(loss, grad_norm, tokens)``. Loss is unscaled for logging consistency.
            ``grad_norm`` is None if not computed on this step.
            ``tokens`` is the total non-padding tokens processed in this step.
        """
        accumulated_losses = []
        accumulated_tokens: Tensor | int = 0

        for gradient_step in range(self.args.gradient_accumulation_steps):
            self.gradient_accumulation_step = gradient_step + 1
            input_dict, labels = self._prepare_batch(next(data_iterator))
            self._dispatch_event("on_forward_backward_begin")
            loss = self._forward_backward_step(input_dict, labels)
            self._dispatch_event("on_forward_backward_end")
            accumulated_losses.append(loss)
            # Count tokens in this micro-batch (local, not yet synchronized across ranks).
            # _count_batch_tokens returns a GPU tensor to avoid forcing GPU-CPU sync.
            accumulated_tokens = accumulated_tokens + self._count_batch_tokens(
                input_dict, labels
            )

        assert self._should_sync_gradients()
        assert self.optimizer is not None

        # Unscale gradients before clipping (no-op when not using fp16 GradScaler)
        self.amp_context.unscale_(self.optimizer)
        total_norm = self._clip_grad_norm(self.args.max_grad_norm)
        self._dispatch_event("on_pre_optimizer_step")
        if not self.args.fuse_optim_with_backward:
            self.amp_context.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
        self._dispatch_event("on_optimizer_step")
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        loss = torch.sum(torch.stack(accumulated_losses))
        # Loss and token synchronization deferred to _log_step() to avoid
        # per-step NCCL all_reduce overhead.
        if isinstance(accumulated_tokens, Tensor):
            tokens_tensor = accumulated_tokens.to(dtype=torch.int64)
        else:
            tokens_tensor = torch.tensor(
                accumulated_tokens, device=self.args.device, dtype=torch.int64
            )
        return loss, total_norm, tokens_tensor

    def _forward_backward_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """
        Execute forward pass followed by backward pass.

        Handles both standard and fused loss computation. When fused loss is enabled,
        requests hidden states instead of logits to avoid materializing the large
        logits tensor.

        Parameters
        ----------
        input_dict : dict
            Model inputs (input_ids, attention_mask, etc.).
        labels : Tensor
            Target labels for loss computation.

        Returns
        -------
        Tensor
            Detached loss tensor for logging.
        """
        assert self.model is not None
        assert self.loss_fn is not None
        if self.use_fused_loss:
            input_dict["return_hidden_states"] = True  # type: ignore[assignment]
        with self.amp_context.autocast():
            outputs = self.model(**input_dict)
            logits = logits_from_outputs(outputs)
            loss = self.loss_fn(logits, labels)
        self._backward(loss)
        return loss.detach()

    def _backward(self, loss: Tensor) -> None:
        """
        Execute backward pass to compute gradients.

        When fp16 mixed precision is active, the loss is scaled by GradScaler
        before backward to prevent gradient underflow.

        Subclasses may override to customize backward behavior (e.g., for pipeline parallelism).

        Parameters
        ----------
        loss : Tensor
            Loss tensor to backpropagate.
        """
        self.amp_context.scale_loss(loss).backward()

    def _should_sync_gradients(self) -> bool:
        """
        Determine if gradients should be synchronized across processes.

        Normally, this just compares the gradient accumulation step against the total
        accumulation steps.

        The Accelerate trainer overrides this; ``_train_step`` asserts that the
        result agrees with the point at which gradient synchronization is expected.

        Returns
        -------
        bool
            True if gradients should be synchronized on the current forward-backward step.
        """
        return self.gradient_accumulation_step == self.args.gradient_accumulation_steps

    def _update_training_steps(self) -> None:
        """
        Compute training step counts from dataloader length.

        Calculates:
        - epoch_train_steps: Number of batches per epoch
        - max_steps: Total optimizer updates across all epochs

        With gradient accumulation: global_steps = batch_count // gradient_accumulation_steps

        This may need to be called more than once if the dataloader is wrapped (e.g., by
        Accelerate) and its length changes. It first synchronizes dataset state so the
        length is accurate.
        """
        if isinstance(self.train_dataloader, StatefulDataLoader):
            sync_dataset_state_from_dataloader(self.train_dataloader)

        # The number of training steps in a single epoch (batch processing steps)
        if isinstance(self.train_dataloader, Sized):
            self.epoch_train_steps = len(self.train_dataloader)

        # Convert to effective global steps (optimizer updates) with gradient accumulation
        effective_epoch_steps = (
            self.epoch_train_steps // self.args.gradient_accumulation_steps
        )

        # The total number of global training steps (optimizer updates) in all epochs.
        # num_train_epochs < 0 means "no epoch cap" -- rely entirely on max_steps.
        if self.args.num_train_epochs >= 0:
            self.max_steps = self.args.num_train_epochs * effective_epoch_steps
        else:
            self.max_steps = sys.maxsize

        # If an explicit max_steps limit is specified, constrain to it.
        if self.args.max_steps >= 0:
            self.max_steps = min(self.args.max_steps, self.max_steps)

        if self.state is not None:
            self.state.max_steps = self.max_steps

    def _init_state(self) -> TrainerState:
        """
        Initialize trainer state for tracking training progress.

        Creates TrainerState with training step counts, batch sizes, and process info.
        State is saved/restored with checkpoints to resume training accurately.

        Key state fields:
        - global_step: Total optimizer updates since training start (0-indexed)
        - raw_epoch: Integer epoch counter, increments at end of each dataset iteration
        - epoch_start_step: Global step when raw_epoch was last incremented
        - epoch: Continuous value = raw_epoch + fractional progress through current epoch
        - best_metric/best_model_checkpoint: Best model tracking for load_best_model_at_end

        Returns:
            TrainerState: Initialized state object
        """

        if self.do_train:
            assert has_batch_size(self.train_dataloader)
            return TrainerState(
                max_steps=self.max_steps,
                logging_steps=self.args.logging_steps,
                eval_steps=self.args.eval_steps,
                num_train_epochs=int(self.args.num_train_epochs),
                train_batch_size=self.train_dataloader.batch_size,
                epoch_train_steps=self.epoch_train_steps,
                is_local_process_zero=self.is_local_process_zero,
                is_world_process_zero=self.is_world_process_zero,
                num_processes=self.num_processes,
                save_steps=self.args.save_steps,
                # Initialize best model tracking
                best_metric=None,
                best_model_checkpoint=None,
                max_eval_steps=self.args.max_eval_steps,
            )
        else:
            return TrainerState(
                max_steps=0,
                logging_steps=0,
                eval_steps=0,
                num_train_epochs=0,
                train_batch_size=0,
                epoch_train_steps=0,
                is_local_process_zero=self.is_local_process_zero,
                is_world_process_zero=self.is_world_process_zero,
                num_processes=self.num_processes,
                save_steps=0,
                # Initialize best model tracking
                best_metric=None,
                best_model_checkpoint=None,
                max_eval_steps=self.args.max_eval_steps,
            )

    def _end_train_loop(
        self, start_time: float | None, train_steps: int
    ) -> dict[str, int | float]:
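        # Clamp to zero when no steps were timed (e.g. training ended before
        # speed_metrics_start_step was reached).
        train_steps = max(train_steps, 0)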
        if start_time:
            runtime = time.time() - start_time
        else:
            runtime = None
        # Calculate effective batch size including gradient accumulation
        effective_batch_size = self._calculate_effective_batch_size()
        total_train_samples = effective_batch_size * train_steps
        metrics = self._speed_metrics(
            "train", runtime, total_train_samples, train_steps
        )
        metrics["epoch"] = self.state.epoch
        metrics["effective_batch_size"] = effective_batch_size

        # Add token and FLOP metrics
        metrics["total_tokens"] = self.state.num_input_tokens_seen
        if runtime and runtime > 0:
            metrics["tokens_per_second"] = round(
                self.state.num_input_tokens_seen / runtime, 2
            )

        if self.state.total_flos > 0:
            metrics["total_flops"] = self.state.total_flos
            if runtime and runtime > 0:
                metrics["flops_per_second"] = round(self.state.total_flos / runtime, 2)

        return metrics

    def _calculate_effective_batch_size(self) -> int:
        """
        Calculate effective batch size accounting for gradient accumulation and parallelism.

        The effective batch size is the total number of examples processed per optimizer update.

        The calculation depends on the parallelism strategy:
        - Data parallelism (DDP): Multiply by num_processes (each process sees a different batch)
        - Pipeline parallelism: Don't multiply (the same batch flows through the pipeline stages)
        - Model parallelism: Don't multiply (the same batch is processed by different model shards)
        - Pipeline and model parallelism combined: Don't multiply (the same batch is shared)

        Formula: base_batch = per_device_batch * gradient_accumulation_steps
                 effective = base_batch * num_processes (only for data parallel)

        Returns:
            Total number of examples per optimizer update
        """
        base_batch_size = (
            self.state.train_batch_size * self.args.gradient_accumulation_steps
        )

        # Check if any form of model parallelism is being used
        # If so, don't multiply by num_processes since the same batch is processed
        # across different parts of the model (stages and/or shards)
        is_pipeline_parallel = (
            hasattr(self, "_is_pipeline_parallel") and self._is_pipeline_parallel()
        )
        is_model_parallel = (
            hasattr(self, "_is_model_parallel") and self._is_model_parallel()
        )

        if is_pipeline_parallel or is_model_parallel:
            # Any form of model parallelism: don't multiply by num_processes
            return base_batch_size
        else:
            # Data parallel (DDP) or single process: multiply by num_processes
            return self.num_processes * base_batch_size

    def _is_pipeline_parallel(self) -> bool:
        """
        Indicate if trainer uses pipeline parallelism.

        Used for effective batch size calculation. Pipeline parallel trainers process
        the same batch across different pipeline stages, so don't multiply by num_processes.

        Returns:
            False for single-device Trainer, True in PipelineTrainer
        """
        return False

    def _is_model_parallel(self) -> bool:
        """
        Indicate if trainer uses model parallelism (tensor/expert parallelism).

        Used for effective batch size calculation. Model parallel trainers process
        the same batch across different model shards, so don't multiply by num_processes.

        Returns:
            False for single-device Trainer, True in model parallel trainers
        """
        return False

    def _prediction_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Dict[str, Tensor | None]:
        """
        Perform a single evaluation batch forward pass.

        Computes the loss without tracking gradients (``_eval_loop`` calls it under ``@torch.no_grad()``).
        Uses unscaled loss (via loss_fn.no_rescale()) for accurate eval metrics.

        Parameters
        ----------
        input_dict : dict
            Model inputs (input_ids, attention_mask, etc.).
        labels : Tensor
            Target labels for loss computation.

        Returns
        -------
        dict
            Dictionary with ``'loss'``, ``'logits'``, and ``'labels'`` tensors.
        """
        assert self.model is not None
        assert isinstance(self.loss_fn, RescaleLoss)
        if self.use_fused_loss:
            input_dict["return_hidden_states"] = True  # type: ignore[assignment]
        with self.loss_fn.no_rescale(), self.amp_context.autocast():
            outputs = self.model(**input_dict)
            logits = logits_from_outputs(outputs)
            loss = self.loss_fn(logits, labels)

        loss = self._distributed_loss(loss.detach())
        return {
            "loss": loss,
            "logits": logits.detach(),
            "labels": labels,
        }

    def _speed_metrics(
        self, prefix: str, runtime: float | None, samples: int, steps: int
    ) -> dict[str, int | float]:
        if runtime is not None and steps > 0:
            samples_per_second = round(samples / runtime, 3)
            steps_per_second = round(steps / runtime, 3)
        else:
            samples_per_second = float("nan")
            steps_per_second = float("nan")
        metrics = {
            f"{prefix}_runtime": runtime,
            f"{prefix}_samples": samples,
            "step": steps,
            f"{prefix}_samples_per_second": samples_per_second,
            f"{prefix}_steps_per_second": steps_per_second,
        }
        return metrics

    def _log_step(
        self,
        loss_log: list[Tensor],
        total_norm_log: list[Tensor],
        tokens_log: list[Tensor],
    ):
        self._update_training_steps()

        logs: dict[str, Any] = {
            "epoch": self.state.epoch,
        }

        if not len(loss_log):
            # No losses to log - this can happen with gradient accumulation
            # when logging is called multiple times in the same step
            return

        # Reduce loss: compute local mean then synchronize across ranks at log step
        # (mirrors the deferred token synchronization pattern)
        mean_loss = torch.stack(loss_log).mean()
        mean_loss = self._distributed_loss(mean_loss)
        logs["loss"] = mean_loss.item()
        loss_log.clear()

        if len(total_norm_log):
            total_norm = torch.stack(total_norm_log)
            logs["grad_norm"] = total_norm.square().mean().sqrt().item()
            logs["max_grad_norm"] = total_norm.max().item()
            total_norm_log.clear()
        else:
            raise ValueError("No grad norm!")

        # Synchronize token counts across ranks at log step (amortizes all_reduce cost)
        if len(tokens_log):
            local_tokens = torch.stack(tokens_log).sum()
            synced_tokens = self._distributed_tokens(local_tokens)
            tokens_count = int(synced_tokens.item())
            self.state.num_input_tokens_seen += tokens_count
            self.state.total_flos += self._flops_per_token * tokens_count
            logs["tokens"] = tokens_count
            logs["total_tokens"] = self.state.num_input_tokens_seen
            logs["total_flos"] = self.state.total_flos
            tokens_log.clear()

        if self.lr_scheduler is not None:
            last_lr = self.lr_scheduler.get_last_lr()[0]
            if last_lr is not None:
                if torch.is_tensor(last_lr):
                    last_lr = last_lr.item()
                logs["learning_rate"] = last_lr

        # Capture peak CUDA memory for this rank, gather across ranks, and reset
        # stats so each interval reflects only the high-water mark since the previous
        # log step.  Doing this here (rather than in individual callbacks) ensures
        # the reset happens exactly once regardless of how many callbacks query
        # memory statistics, and the per-rank list is uniformly available to all
        # logging callbacks (JsonLogger, PeakMemory, ProgressCallback).
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            local_peak = int(torch.cuda.max_memory_allocated(device=device))
            torch.cuda.reset_peak_memory_stats(device=device)
            logs["peak_mem_allocated"] = self._distributed_peak_mem(local_peak)

        # Allow callbacks to add entries to the logs before on_log fires
        self._dispatch_event(
            "on_log_step",
            logs=logs,
        )

        # Dispatch on_log to all logging callbacks
        return self.log(logs)

    def _prepare_batch(
        self, batch: Dict[str, Tensor]
    ) -> tuple[Dict[str, Tensor], Tensor]:
        """
        Move batch tensors to target device and extract labels.

        Parameters
        ----------
        batch : dict
            Dictionary of tensors from the dataloader. Must include a ``'labels'`` key.

        Returns
        -------
        tuple of (dict, Tensor)
            ``(input_dict, labels)`` with labels separated for loss computation.
        """
        batch = {k: v.to(self.args.device, non_blocking=True) for k, v in batch.items()}
        labels = batch.pop("labels")

        return (batch, labels)

    def _distributed_loss(self, loss: Tensor) -> Tensor:
        """
        Reduce loss across all processes for accurate logging.

        Single-device trainer just returns the input. Distributed trainers (AccelTrainer,
        PipelineTrainer) override this to all-reduce loss values so logging reflects
        the average loss across all processes/devices.

        See src/forgather/ml/trainer/accelerate/accel_trainer.py for distributed implementation.

        Parameters
        ----------
        loss : Tensor
            Loss tensor from the current process.

        Returns
        -------
        Tensor
            Loss tensor (single-device) or all-reduced loss (distributed).
        """
        return loss

    @override
    def load_checkpoint(self, *args, **kwargs) -> None:
        super().load_checkpoint(*args, **kwargs)
        self._update_training_steps()

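The comment in the listing above describes per-interval peak tracking. The trainer reads the high-water mark at each log step and then resets it, so the next log step reports only what happened since. Below is a minimal single-rank sketch of those semantics in plain Python. `PeakMemoryTracker` is a hypothetical stand-in for the `torch.cuda.max_memory_allocated` / `torch.cuda.reset_peak_memory_stats` pair. We assume, based on our understanding of the CUDA allocator statistics, that a reset restarts the peak from the current live allocation rather than from zero.

```python
class PeakMemoryTracker:
    """Hypothetical stand-in for one device's allocator statistics."""

    def __init__(self) -> None:
        self.allocated = 0
        self.peak = 0

    def alloc(self, nbytes: int) -> None:
        self.allocated += nbytes
        self.peak = max(self.peak, self.allocated)

    def free(self, nbytes: int) -> None:
        self.allocated -= nbytes

    def max_memory_allocated(self) -> int:
        return self.peak

    def reset_peak_memory_stats(self) -> None:
        # Assumption: the peak restarts from the live allocation, not from zero.
        self.peak = self.allocated


def log_step(tracker: PeakMemoryTracker) -> int:
    """Read this interval's high-water mark, then reset it (read-then-reset)."""
    local_peak = tracker.max_memory_allocated()
    tracker.reset_peak_memory_stats()
    return local_peak


tracker = PeakMemoryTracker()
interval_peaks = []

# Interval 1: a large transient allocation that is mostly freed again.
tracker.alloc(100)
tracker.free(80)
interval_peaks.append(log_step(tracker))

# Interval 2: a smaller working set. Without the reset this would still report 100.
tracker.alloc(30)
interval_peaks.append(log_step(tracker))

print(interval_peaks)  # [100, 50]
```

Resetting in one place keeps repeated reads from different callbacks from seeing different intervals.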
__init__(*, args, distributed_env, optimizer_factory=None, optimizer_cls_and_kwargs=None, lr_scheduler_factory=None, enable_activation_checkpoint_fn=enable_hf_activation_checkpointing, fused_loss_factory=None, optimizer_groups=None, **kwargs)

Parameters:

args (TrainingArguments or dict, required): Training configuration. Accepts a TrainingArguments instance or a plain dict (converted automatically). See TrainingArguments for all options.

distributed_env (DistributedEnvInterface, required): Distributed environment object providing rank/device information. Use DistributedEnvironment for real training or StaticDistributedEnvironment for single-process runs. The trainer asserts world_size == 1 for non-distributed subclasses.

optimizer_factory (callable, default None): Callable that accepts model.named_parameters() and returns an optimizer. If neither this nor optimizer_cls_and_kwargs is provided, defaults to AdamW with parameters from args.

optimizer_cls_and_kwargs (tuple, default None): HuggingFace Trainer-compatible alternative to optimizer_factory. A (optimizer_class, kwargs_dict) pair. Ignored when optimizer_factory is supplied.

lr_scheduler_factory (callable, default None): Callable that accepts the optimizer and returns an LR scheduler. If not provided and args.lr_scheduler_type is set, falls back to transformers.get_scheduler.

enable_activation_checkpoint_fn (callable, default enable_hf_activation_checkpointing): Called as fn(rank, model) to enable activation (gradient) checkpointing. Set to None to disable activation checkpointing even when args.gradient_checkpointing=True.

fused_loss_factory (callable, default None): If provided, enables fused logits-loss computation. Called with the model's output embedding layer and returns a loss function. Requires the model to support get_output_embeddings() and return_hidden_states.

optimizer_groups (OptimGroupMap, default None): Parameter group configuration for the optimizer. Allows different hyperparameters (lr, weight_decay) for different parameter subsets.

**kwargs (default {}): Passed to BaseTrainer: model, model_init, train_dataset, eval_dataset, loss_fn, data_collator, processing_class, callbacks, optimizer, lr_scheduler.
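The precedence among `optimizer_factory`, `optimizer_cls_and_kwargs`, and the AdamW default can be sketched in isolation. `resolve_optimizer_factory`, `FakeAdamW`, and `FakeSGD` are hypothetical names used only for illustration. The real constructor binds `torch.optim.AdamW` and reads the default hyperparameters from `args`.

```python
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Type


class FakeAdamW:
    """Hypothetical stand-in for torch.optim.AdamW that just records its inputs."""

    def __init__(self, params, **kwargs: Any) -> None:
        self.params = list(params)
        self.kwargs = kwargs


class FakeSGD(FakeAdamW):
    """A second hypothetical optimizer class."""


def resolve_optimizer_factory(
    optimizer_factory: Optional[Callable] = None,
    optimizer_cls_and_kwargs: Optional[Tuple[Type, Dict[str, Any]]] = None,
    default_lr: float = 1e-3,
) -> Callable:
    # 1. An explicit factory always wins.
    if optimizer_factory:
        return optimizer_factory
    # 2. HF-style (class, kwargs) pair, bound with functools.partial.
    if optimizer_cls_and_kwargs:
        cls, kwargs = optimizer_cls_and_kwargs
        return partial(cls, **kwargs)
    # 3. Fallback: an AdamW-like default built from the training arguments.
    return partial(FakeAdamW, lr=default_lr)


named_params = [("w", 1.0), ("b", 0.0)]  # stand-in for model.named_parameters()

opt = resolve_optimizer_factory(optimizer_cls_and_kwargs=(FakeSGD, {"lr": 0.1}))(named_params)
print(type(opt).__name__, opt.kwargs)  # FakeSGD {'lr': 0.1}

# optimizer_cls_and_kwargs is ignored when optimizer_factory is supplied.
opt = resolve_optimizer_factory(
    optimizer_factory=partial(FakeAdamW, lr=5e-4),
    optimizer_cls_and_kwargs=(FakeSGD, {"lr": 0.1}),
)(named_params)
print(type(opt).__name__, opt.kwargs)  # FakeAdamW {'lr': 0.0005}
```

Either way the trainer ends up holding a single callable that it later applies to the model's parameters.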
Source code in src/forgather/ml/trainer/trainer.py
def __init__(
    self,
    *,
    args: TTrainingArguments | dict,
    distributed_env: DistributedEnvInterface,
    optimizer_factory: Optional[OptimizerFactoryT] = None,
    # Alternative, for compatibility with transformers.Trainer
    optimizer_cls_and_kwargs: Optional[
        Tuple[Type[OptimizerT], Dict[str, Any]]
    ] = None,
    lr_scheduler_factory: Optional[LRSchedulerFactoryT] = None,
    enable_activation_checkpoint_fn: Optional[
        EnableCheckpointFnT
    ] = enable_hf_activation_checkpointing,
    fused_loss_factory: Optional[FusedLossFactoryT] = None,
    optimizer_groups: Optional[OptimGroupMap] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    args : TrainingArguments or dict
        Training configuration. Accepts a ``TrainingArguments`` instance or a
        plain dict (converted automatically). See ``TrainingArguments`` for all options.
    distributed_env : DistributedEnvInterface
        Distributed environment object providing rank/device information.
        Use ``DistributedEnvironment`` for real training or ``StaticDistributedEnvironment``
        for single-process runs. The trainer asserts ``world_size == 1`` for non-distributed
        subclasses.
    optimizer_factory : callable, optional
        Callable that accepts ``model.named_parameters()`` and returns an optimizer.
        If not provided and ``optimizer_cls_and_kwargs`` is also absent, defaults to
        AdamW with parameters from ``args``.
    optimizer_cls_and_kwargs : tuple, optional
        HuggingFace Trainer-compatible alternative to ``optimizer_factory``.
        A ``(optimizer_class, kwargs_dict)`` pair. Ignored when ``optimizer_factory``
        is supplied.
    lr_scheduler_factory : callable, optional
        Callable that accepts the optimizer and returns an LR scheduler.
        If not provided and ``args.lr_scheduler_type`` is set, falls back to
        ``transformers.get_scheduler``.
    enable_activation_checkpoint_fn : callable, optional
        Called as ``fn(rank, model)`` to enable activation (gradient) checkpointing.
        Defaults to ``enable_hf_activation_checkpointing``. Set to ``None`` to
        disable activation checkpointing even when ``args.gradient_checkpointing=True``.
    fused_loss_factory : callable, optional
        If provided, enables fused logits-loss computation. Called with the model's
        output embedding layer and returns a loss function. Requires the model to
        support ``get_output_embeddings()`` and ``return_hidden_states``.
    optimizer_groups : OptimGroupMap, optional
        Parameter group configuration for the optimizer. Allows different
        hyperparameters (lr, weight_decay) for different parameter subsets.
    **kwargs
        Passed to ``BaseTrainer``: ``model``, ``model_init``, ``train_dataset``,
        ``eval_dataset``, ``loss_fn``, ``data_collator``, ``processing_class``,
        ``callbacks``, ``optimizer``, ``lr_scheduler``.
    """
    if isinstance(args, dict):
        args = cast(TTrainingArguments, from_dict(TrainingArguments, args))
    super().__init__(args=args, **kwargs)

    # HF Trainer compatibility.
    if not optimizer_factory:
        if not optimizer_cls_and_kwargs:
            optimizer_factory = partial(  # type: ignore[assignment]
                torch.optim.AdamW,
                lr=args.learning_rate,
                betas=(args.adam_beta1, args.adam_beta2),
                weight_decay=args.weight_decay,
                eps=args.adam_epsilon,
            )
        else:
            optimizer_factory = partial(
                optimizer_cls_and_kwargs[0], **optimizer_cls_and_kwargs[1]
            )
    self.optimizer_groups = optimizer_groups
    self.dist = distributed_env
    self.optimizer_factory = optimizer_factory
    self.lr_scheduler_factory = lr_scheduler_factory
    self.enable_activation_checkpoint_fn = enable_activation_checkpoint_fn
    self.fused_loss_factory = fused_loss_factory
    self.use_fused_loss = False
    assert (self.model is not None) or (
        self.model_init is not None
    ), "Either a model or a model constructor must be specified."

    assert (
        self.args.max_grad_norm is None or not self.args.fuse_optim_with_backward
    ), "max_grad_norm is incompatible with fuse_optim_with_backward"

    assert (
        self.args.gradient_accumulation_steps == 1
        or not self.args.fuse_optim_with_backward
    ), f"gradient_accumulation_steps={self.args.gradient_accumulation_steps} is incompatible with fuse_optim_with_backward"

    assert (
        self.args.mixed_precision != "fp16"
        or not self.args.fuse_optim_with_backward
    ), (
        "fp16 mixed precision with GradScaler is incompatible with fuse_optim_with_backward. "
        "Use mixed_precision='bf16' instead (no GradScaler needed)."
    )

    assert (
        self.loss_fn or self.args.gradient_accumulation_steps == 1
    ), f"gradient_accumulation_steps [{self.args.gradient_accumulation_steps}] > 1 requires loss_fn"

    if self.data_collator is None:
        self.data_collator = torch.utils.data.default_collate

load_best_model()

Load the best model from the best checkpoint.

Called at end of training when load_best_model_at_end=True to restore the checkpoint with the best metric value seen during training.

Source code in src/forgather/ml/trainer/trainer.py
def load_best_model(self) -> None:
    """
    Load the best model from the best checkpoint.

    Called at end of training when load_best_model_at_end=True to restore
    the checkpoint with the best metric value seen during training.
    """
    if not self.state.best_model_checkpoint:
        logger.warning("No best model checkpoint available to load")
        return

    if not os.path.exists(self.state.best_model_checkpoint):
        logger.warning(
            f"Best model checkpoint path does not exist: {self.state.best_model_checkpoint}"
        )
        return

    logger.info(f"Loading best model from {self.state.best_model_checkpoint}")
    self.load_checkpoint(self.state.best_model_checkpoint)

forgather.ml.trainer.trainer.TrainingArguments dataclass

Bases: BaseTrainingArguments

Training arguments specific to the simple Trainer implementation.

Extends BaseTrainingArguments with memory optimization and model construction options. Maintains compatibility with HuggingFace Trainer API where possible.

Source code in src/forgather/ml/trainer/trainer.py
@dataclass(kw_only=True)
class TrainingArguments(BaseTrainingArguments):
    """
    Training arguments specific to the simple Trainer implementation.

    Extends BaseTrainingArguments with memory optimization and model construction options.
    Maintains compatibility with HuggingFace Trainer API where possible.
    """

    # Ratio of reserved to total GPU memory that triggers garbage collection.
    # If OOM errors are caused by fragmentation, lower this ratio.
    gc_threshold: float = 0.5

    # Construct model on meta-device and materialize directly on device
    # default: Construct model on CPU and move to device. Safest option, works in all cases.
    #          Uses no_init_weights() context when loading checkpoint to skip initialization.
    # device:  Construct model directly on device with initialization. Faster than default
    #          but may fail when model needs sharding across devices. Use when checkpoint
    #          doesn't save all buffers (e.g., RoPE).
    # meta:    Construct on meta device (no memory backing) and materialize as empty tensors
    #          on target device. Fastest option but requires loading checkpoint since model
    #          is uninitialized. May have issues with buffers not saved in checkpoint.
    construct_model_on: str = "default"  # "default" | "meta" | "device"

    # https://pytorch.org/blog/activation-checkpointing-techniques/
    # Requires "torch_compile = True" option
    activation_memory_budget: float | None = None

    # Combine gradient calculation with optimizer step, to save memory.
    # As each gradient is computed during backward(), it's immediately applied by the
    # optimizer and freed. Incompatible with max_grad_norm and gradient_accumulation_steps > 1.
    # Greatest memory savings when combined with gradient checkpointing.
    # https://docs.pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
    fuse_optim_with_backward: bool = False

    # The step at which to start collecting speed metrics.
    # Defaults to 1 to exclude the first step, which absorbs torch.compile() overhead.
    # Set to 0 to include all steps, or to a larger value to skip a longer compile warmup.
    speed_metrics_start_step: int = 1

    # If the train dataset has a `set_epoch(epoch: int)` method, call it at the start of each epoch.
    set_dataset_epoch: bool = True

    # Debug grouped optimizer assignments
    debug_optimizer_groups: bool = False
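The comment on fuse_optim_with_backward lists options that cannot be combined with it. Trainer.__init__ enforces these with separate asserts. The sketch below collects every conflict at once instead. `FusedArgs` and `fused_step_conflicts` are hypothetical names. The explanations in the comments are our reading of why each option conflicts: each needs all gradients to exist at the same time, but a fused step applies and frees each gradient as soon as it is computed.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FusedArgs:
    """Hypothetical subset of TrainingArguments relevant to fused optimizer steps."""

    fuse_optim_with_backward: bool = False
    max_grad_norm: Optional[float] = None
    gradient_accumulation_steps: int = 1
    mixed_precision: Optional[str] = None


def fused_step_conflicts(args: FusedArgs) -> List[str]:
    """Return every option that conflicts with fuse_optim_with_backward."""
    if not args.fuse_optim_with_backward:
        return []
    conflicts = []
    # Clipping by global norm needs every gradient before any step is taken.
    if args.max_grad_norm is not None:
        conflicts.append("max_grad_norm")
    # Accumulation needs gradients to persist across micro-batches.
    if args.gradient_accumulation_steps != 1:
        conflicts.append(f"gradient_accumulation_steps={args.gradient_accumulation_steps}")
    # GradScaler must inspect and unscale all gradients before stepping.
    if args.mixed_precision == "fp16":
        conflicts.append("mixed_precision='fp16' (use 'bf16')")
    return conflicts


print(fused_step_conflicts(
    FusedArgs(fuse_optim_with_backward=True, max_grad_norm=1.0, mixed_precision="bf16")
))  # ['max_grad_norm']
```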

Distributed Data Parallel (DDP) Trainer

forgather.ml.trainer.ddp.ddp_trainer.DDPTrainer

Bases: Trainer[TDDPTrainingArguments], Generic[TDDPTrainingArguments]

Multi-GPU trainer using DistributedDataParallel (DDP).

Wraps the base Trainer with DDP for data-parallel training across multiple GPUs or nodes. Each rank receives a different batch; gradients are all-reduced automatically after each backward pass. Optionally uses PostLocalSGD for bandwidth-limited environments.
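In the DDPTrainer source, PostLocalSGD is configured with PostLocalSGDState(start_localSGD_iter=start_step) and a PeriodicModelAverager(period, warmup_steps=start_step). Before start_step, DDP all-reduces gradients across all ranks on every step. After start_step, gradient all-reduce no longer spans every rank, and parameters are averaged across ranks only periodically. The sketch below computes which optimizer steps perform that averaging. `averaging_steps` is a hypothetical helper, and its condition reflects our reading of PyTorch's PeriodicModelAverager rather than Forgather code.

```python
from typing import List


def averaging_steps(period: int, warmup_steps: int, total_steps: int) -> List[int]:
    """0-based optimizer steps at which parameters are averaged across ranks.

    Assumption: no averaging during warmup, then one average every `period`
    steps, starting at `warmup_steps`.
    """
    return [
        step
        for step in range(total_steps)
        if step >= warmup_steps and (step - warmup_steps) % period == 0
    ]


# With start_step=10 and period=4, ranks train locally between syncs at 10, 14, 18.
print(averaging_steps(period=4, warmup_steps=10, total_steps=20))  # [10, 14, 18]
```

A larger period means fewer global synchronizations. That is where the bandwidth savings come from, at the cost of ranks drifting apart between averages.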

Launch with torchrun (or the forgather train -d ... shortcut):

torchrun --nproc_per_node=4 train.py

Key differences from single-GPU Trainer:

  • Model wrapped in torch.nn.parallel.DistributedDataParallel
  • Gradient accumulation uses DDP's no_sync() context to skip reductions on intermediate steps
  • Dataset loading via DataloaderDispatcher (dispatch_batches=True, default) or SynchronizedDataLoader (dispatch_batches=False)
  • Optional PostLocalSGD communication hook for reduced all-reduce frequency
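The no_sync() bullet above describes the standard DDP gradient-accumulation pattern. Intermediate micro-steps run inside no_sync(), so gradients only accumulate locally. Only the last micro-step before the optimizer step triggers the all-reduce. The sketch below shows the control flow in plain Python. `FakeDDP` and its `forward_backward` method are hypothetical stand-ins that count reductions. With a real DDP model, both the forward and backward of each intermediate micro-step belong inside the no_sync() block.

```python
from contextlib import contextmanager, nullcontext
from typing import Iterator


class FakeDDP:
    """Hypothetical stand-in for a DDP-wrapped model that counts gradient all-reduces."""

    def __init__(self) -> None:
        self.require_grad_sync = True
        self.all_reduce_calls = 0

    @contextmanager
    def no_sync(self) -> Iterator[None]:
        # Like DDP.no_sync(): suppress gradient reduction inside the block.
        previous = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = previous

    def forward_backward(self) -> None:
        # Real DDP reduces gradient buckets during backward unless syncing is disabled.
        if self.require_grad_sync:
            self.all_reduce_calls += 1


def train_micro_batches(model: FakeDDP, num_micro_batches: int, accum_steps: int) -> int:
    optimizer_steps = 0
    for i in range(num_micro_batches):
        is_last_micro_step = (i + 1) % accum_steps == 0
        # Skip the reduction on intermediate micro-steps; sync on the last one.
        ctx = nullcontext() if is_last_micro_step else model.no_sync()
        with ctx:
            model.forward_backward()
        if is_last_micro_step:
            optimizer_steps += 1  # optimizer.step() and zero_grad() would run here
    return optimizer_steps


model = FakeDDP()
steps = train_micro_batches(model, num_micro_batches=8, accum_steps=4)
print(steps, model.all_reduce_calls)  # 2 2
```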
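When dispatch_eval_batches is effectively False, the SynchronizedDataLoader's per-step MIN sync would truncate every rank to the shortest shard. The source below handles evaluation differently, in _check_eval_shards_nonempty and _eval_loop_all_shards. The following plain-Python simulation models that loop. `simulate_eval` is a hypothetical helper: lists stand in for tensors, and the SUM and MAX collectives are computed directly instead of through torch.distributed. It shows that every rank makes the same number of forward calls while every real batch is still counted exactly once.

```python
from typing import Dict, List


def simulate_eval(shard_lengths: List[int], max_eval_steps: int = 0) -> Dict[str, object]:
    """Simulate the symmetric all-shards eval loop for uneven per-rank shards."""
    world_size = len(shard_lengths)

    # Pre-flight: each rank writes a has-first-batch flag into its own slot.
    # Summing the vectors (all_reduce SUM) tells every rank which ranks are empty.
    per_rank_flags = [[0] * world_size for _ in range(world_size)]
    for rank, n in enumerate(shard_lengths):
        per_rank_flags[rank][rank] = 1 if n > 0 else 0
    reduced = [sum(col) for col in zip(*per_rank_flags)]
    empty_ranks = [r for r, has in enumerate(reduced) if not has]
    if empty_ranks:
        raise RuntimeError(f"empty eval shards on ranks {empty_ranks}")

    real_steps = [0] * world_size
    forward_calls = [0] * world_size
    while True:
        # Each rank decides whether it still has real data (honoring max_eval_steps).
        has_real = [
            int(
                real_steps[r] < shard_lengths[r]
                and (max_eval_steps <= 0 or real_steps[r] < max_eval_steps)
            )
            for r in range(world_size)
        ]
        # all_reduce(MAX) over the flags acts as a logical OR across ranks.
        if max(has_real) == 0:
            break
        for r in range(world_size):
            forward_calls[r] += 1  # every rank runs forward: real batch or dummy
            real_steps[r] += has_real[r]  # only real batches count toward the loss
    return {"forward_calls": forward_calls, "total_real_steps": sum(real_steps)}


result = simulate_eval([85, 167, 239, 247])
print(result)  # {'forward_calls': [247, 247, 247, 247], 'total_real_steps': 738}
```

For comparison, a per-step MIN sync would stop all four ranks after 85 steps and evaluate only 4 × 85 = 340 batches. The balanced forward-call counts are what keep the DDP collectives from deadlocking.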
Source code in src/forgather/ml/trainer/ddp/ddp_trainer.py
class DDPTrainer(Trainer[TDDPTrainingArguments], Generic[TDDPTrainingArguments]):
    """
    Multi-GPU trainer using DistributedDataParallel (DDP).

    Wraps the base ``Trainer`` with DDP for data-parallel training across multiple GPUs
    or nodes. Each rank receives a different batch; gradients are all-reduced automatically
    after each backward pass. Optionally uses PostLocalSGD for bandwidth-limited environments.

    Launch with ``torchrun`` (or the ``forgather train -d ...`` shortcut)::

        torchrun --nproc_per_node=4 train.py

    Key differences from single-GPU ``Trainer``:

    - Model wrapped in ``torch.nn.parallel.DistributedDataParallel``
    - Gradient accumulation uses DDP's ``no_sync()`` context to skip reductions on
      intermediate steps
    - Dataset loading via ``DataloaderDispatcher`` (``dispatch_batches=True``, default)
      or ``SynchronizedDataLoader`` (``dispatch_batches=False``)
    - Optional PostLocalSGD communication hook for reduced all-reduce frequency
    """

    args: TDDPTrainingArguments
    gradient_accumulation_step: int

    def __init__(
        self,
        *,
        args: TDDPTrainingArguments | dict,
        fused_loss_factory: Optional[FusedLossFactoryT] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        args : DDPTrainingArguments or dict
            Training configuration including DDP-specific options (``ddp``,
            ``post_local_sgd``, ``dispatch_batches``). Accepts a dict for
            programmatic construction.
        fused_loss_factory : callable, optional
            Factory for fused logits-loss computation. See ``Trainer`` for details.
        **kwargs
            Forwarded to ``Trainer.__init__``: ``distributed_env``, ``model``,
            ``model_init``, ``train_dataset``, ``eval_dataset``, ``optimizer_factory``,
            ``lr_scheduler_factory``, ``callbacks``, etc.
        """
        if isinstance(args, dict):
            args = cast(TDDPTrainingArguments, from_dict(DDPTrainingArguments, args))

        super().__init__(args=args, fused_loss_factory=fused_loss_factory, **kwargs)

        assert (
            not self.args.fuse_optim_with_backward
        ), "DDPTrainer does not support option fuse_optim_with_backward"

    @override
    def _init_distributed(self):
        assert (
            dist.is_available() and dist.is_initialized()
        ) or self.dist.world_size == 1, (
            "DDP trainer requires that torch.distributed has been initialized."
        )

        if self.dist.world_size == 1:
            return super()._init_distributed()

        self.is_local_process_zero = self.dist.local_rank == 0
        self.is_world_process_zero = self.dist.rank == 0
        self.num_processes = self.dist.world_size

        self.mesh = init_device_mesh(
            self.dist.device_type,
            (self.dist.world_size,),
            mesh_dim_names=("data_parallel",),
        )
        self.ddp_group = self.mesh.get_group(0)  # data-parallel group

    @override
    def _wrap(
        self,
    ) -> None:
        """
        Wrap the model and dataloaders for DDP and, when enabled, install PostLocalSGD.
        """
        if self.dist.world_size == 1:
            return super()._wrap()

        self.model = DDP(
            self.model,
            device_ids=[self.args.device] if self.dist.device_type != "cpu" else None,
            process_group=self.ddp_group,
            broadcast_buffers=self.args.ddp.broadcast_buffers,
            init_sync=self.args.ddp.init_sync,
            bucket_cap_mb=self.args.ddp.bucket_cap_mb,
            find_unused_parameters=self.args.ddp.find_unused_parameters,
            gradient_as_bucket_view=self.args.ddp.gradient_as_bucket_view,
            static_graph=self.args.ddp.static_graph,
            skip_all_reduce_unused_params=self.args.ddp.skip_all_reduce_unused_params,
        )

        dispatch_eval = self._dispatch_eval_batches()

        if self.train_dataloader:
            if self.args.dispatch_batches:
                # Use DataloaderDispatcher for centralized batch loading
                self.train_dataloader = DataloaderDispatcher(
                    cast(DataLoader, self.train_dataloader),
                    self.mesh,
                    self.args.device,
                )
            else:
                # Use SynchronizedDataLoader for sharded datasets
                # Ensures all ranks agree on when to stop iterating
                self.train_dataloader = SynchronizedDataLoader(
                    self.train_dataloader,
                    device=self.args.device,
                    process_group=self.ddp_group,
                )

        if self.eval_dataloader:
            if dispatch_eval:
                self.eval_dataloader = DataloaderDispatcher(
                    cast(DataLoader, self.eval_dataloader),
                    self.mesh,
                    self.args.device,
                )
            else:
                self.eval_dataloader = SynchronizedDataLoader(
                    self.eval_dataloader,
                    device=self.args.device,
                    process_group=self.ddp_group,
                )

        # Optimizer is only required for training. When running under
        # ``trainer.evaluate()`` alone, ``do_train`` is False and no optimizer
        # is constructed, which is fine — we just need the DDP wrapping above.
        if self.do_train:
            assert self.optimizer is not None
        if self.do_train and self.args.post_local_sgd.enabled:
            logger.info(f"Enabling post-local-SGD: {self.args.post_local_sgd}")
            self.post_local_sgd_state = PostLocalSGDState(
                process_group=self.ddp_group,
                subgroup=None,
                start_localSGD_iter=self.args.post_local_sgd.start_step,
                post_local_gradient_allreduce=self.args.post_local_sgd.post_local_gradient_allreduce,
            )
            self.model.register_comm_hook(self.post_local_sgd_state, post_localSGD_hook)
            self.optimizer = PostLocalSGDOptimizer(
                optim=cast(Optimizer, self.optimizer),
                averager=averagers.PeriodicModelAverager(
                    period=self.args.post_local_sgd.period,
                    warmup_steps=self.args.post_local_sgd.start_step,
                ),
            )

    @override
    def unwrapped_model(self) -> torch.nn.Module:
        """
        Return the unwrapped model.

        In the case of DDP, the original model is stored in the model's "module" attribute.
        """
        if self.dist.world_size == 1:
            return super().unwrapped_model()

        assert self.model
        return cast(Module, self.model.module)

    def _dispatch_eval_batches(self) -> bool:
        """Resolve the effective eval-time dispatch_batches setting.

        ``dispatch_eval_batches`` overrides ``dispatch_batches`` for eval only.
        ``None`` (the default) means "follow ``dispatch_batches``".
        """
        if self.args.dispatch_eval_batches is not None:
            return self.args.dispatch_eval_batches
        return self.args.dispatch_batches

    @override
    @torch.no_grad()
    def _eval_loop(self) -> Dict[str, float]:
        """
        Evaluation loop for DDP training.

        For dispatch_eval_batches=True or single-GPU, delegates to the base
        Trainer._eval_loop().

        For dispatch_eval_batches=False, bypasses SynchronizedDataLoader's
        MIN-based synchronization to let each rank process ALL its local
        validation data independently. This prevents data loss when token
        packing creates highly uneven shard lengths across ranks (e.g., shard
        lengths [85, 167, 239, 247] would otherwise be truncated to the
        shortest rank's count).
        """
        if self.dist.world_size == 1 or self._dispatch_eval_batches():
            return super()._eval_loop()
        return self._eval_loop_all_shards()

    def _check_eval_shards_nonempty(self, iterator) -> Any:
        """Confirm every rank's eval iterator yields at least one batch.

        Fetches one batch from ``iterator`` on each rank, then runs an
        ``all_reduce(SUM)`` over per-rank "has-first-batch" flags. If any
        rank failed to fetch, raises ``RuntimeError`` *on every rank*
        with the empty-rank set listed in the diagnostic.

        Returns the first batch on success so the caller can prepend it
        back into the eval loop's iterator.

        This is the contract that lets ``_eval_loop_all_shards`` rely on
        every rank having a valid "last seen" batch to reuse as a dummy
        once its real shard exhausts.
        """
        try:
            first_batch = next(iterator)
            local_has_first = 1
        except StopIteration:
            first_batch = None
            local_has_first = 0
        per_rank_has_first = torch.zeros(
            self.dist.world_size,
            dtype=torch.int32,
            device=self.args.device,
        )
        per_rank_has_first[self.dist.rank] = local_has_first
        dist.all_reduce(per_rank_has_first, op=dist.ReduceOp.SUM)
        empty_ranks = [
            r for r, has in enumerate(per_rank_has_first.tolist()) if not has
        ]
        if empty_ranks:
            raise RuntimeError(self._zero_eval_batches_message(empty_ranks=empty_ranks))
        return first_batch

    @torch.no_grad()
    def _eval_loop_all_shards(self) -> Dict[str, float]:
        """
        Eval loop that lets each rank process all its local validation data.

        Two design constraints shape this loop:

        1. **No data loss across uneven shards.** ``SynchronizedDataLoader``'s
           per-step MIN sync would stop every rank as soon as the *shortest*
           shard exhausted, truncating the longer shards (e.g. shard lengths
           ``[85, 167, 239, 247]`` would all drop to 85). For an eval pass we
           want every example evaluated exactly once, not just the first 85.

        2. **Symmetric DDP collectives.** The wrapped DDP module participates
           in collectives (buffer broadcast, parameter-sync hooks) on every
           ``self.model(...)`` call, even under ``torch.no_grad()``. If one
           rank skips a forward call while peers run it, the process group
           deadlocks. So every rank must call ``self.model(...)`` the same
           number of times across the entire loop.

        Strategy:

        - **Pre-flight** (``_check_eval_shards_nonempty``): every rank fetches
          its first batch and gathers per-rank "has-first-batch" flags via
          ``all_reduce(SUM)``. If any rank is empty, every rank raises with
          the empty-rank set and a remediation pointer. This guarantees the
          rest of the loop has a usable per-rank "last-seen" batch shape.
        - **Symmetric loop**: every iteration, every rank calls
          ``self.model(...)``. Ranks with a real next batch use it and
          accumulate loss; ranks that have exhausted their shard reuse
          their last-seen real batch as a *dummy* (same shape, no
          recompile, ignored result) so the DDP collectives stay
          balanced. A per-step ``all_reduce(MAX)`` over a "rank still has
          real data" flag terminates the loop once every shard is
          exhausted.
        - **Aggregation**: ``total_loss`` and ``step_count`` are
          ``all_reduce(SUM)`` across ranks, so only real batches
          contribute to ``eval_loss``.
        """
        assert self.model is not None
        assert self.eval_dataloader is not None
        assert isinstance(self.loss_fn, RescaleLoss)

        # Access the underlying dataloader, bypassing SynchronizedDataLoader
        assert isinstance(self.eval_dataloader, SynchronizedDataLoader)
        raw_dataloader = self.eval_dataloader._dataloader

        with set_train(self.model, False):
            total_loss = torch.zeros(1, device=self.args.device)
            local_real_steps = 0
            iterator = iter(raw_dataloader)

            # Pre-flight: detect empty shards before any model forward call.
            first_batch = self._check_eval_shards_nonempty(iterator)
            assert first_batch is not None  # post-pre-flight invariant

            # Re-prepend the first batch so the loop can yield it on iter 0.
            iterator = itertools.chain([first_batch], iterator)
            # Last real batch on this rank — reused as a shape-matched dummy
            # once the local iterator exhausts. Always populated post-pre-flight.
            last_real_batch = first_batch

            # Reuse a single tensor for the per-step "any rank has real data"
            # synchronization (MAX as logical OR over int32 flags).
            any_real_tensor = torch.zeros(1, dtype=torch.int32, device=self.args.device)

            while True:
                # Try to get the next real batch on this rank. If none, fall
                # back to the dummy (last_real_batch) so the model.forward()
                # call below still runs with a valid shape; mark this iteration
                # as having no real data for this rank so its loss is discarded.
                try:
                    batch = next(iterator)
                    last_real_batch = batch
                    local_has_real = 1
                except StopIteration:
                    batch = last_real_batch
                    local_has_real = 0

                # Respect max_eval_steps on this rank's own real data only.
                if (
                    local_has_real
                    and self.args.max_eval_steps > 0
                    and local_real_steps >= self.args.max_eval_steps
                ):
                    local_has_real = 0
                    batch = last_real_batch

                # Continue while ANY rank still has real data.
                any_real_tensor.fill_(local_has_real)
                dist.all_reduce(any_real_tensor, op=dist.ReduceOp.MAX)
                if any_real_tensor.item() == 0:
                    break

                # Symmetric forward on every rank. Inline the work from
                # _prediction_step but skip _distributed_loss — that would
                # add an extra per-step all_reduce that we don't need here.
                input_dict, labels = self._prepare_batch(batch)
                if self.use_fused_loss:
                    input_dict["return_hidden_states"] = True  # type: ignore[assignment]
                with self.loss_fn.no_rescale(), self.amp_context.autocast():
                    outputs = self.model(**input_dict)
                    logits = logits_from_outputs(outputs)
                    loss = self.loss_fn(logits, labels)

                # Only real batches contribute to the eval metric.
                if local_has_real:
                    total_loss += loss.detach()
                    local_real_steps += 1

                # Dispatch on every synchronized step so rank 0's progress
                # indicator keeps advancing even after its own shard is done.
                self._dispatch_event("on_prediction_step")

            # Aggregate loss across all ranks. Only real-batch contributions
            # are counted; dummy iterations on exhausted ranks were skipped
            # in the accumulator above.
            step_count = torch.tensor(
                local_real_steps, device=self.args.device, dtype=torch.int64
            )
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(step_count, op=dist.ReduceOp.SUM)

            total_steps = step_count.item()
            if total_steps == 0:
                # Defence-in-depth: the pre-flight should have caught this,
                # but if max_eval_steps=0 or some other edge case leaves
                # everyone with zero real steps, surface the diagnostic.
                raise RuntimeError(self._zero_eval_batches_message())
            eval_loss = (total_loss / total_steps).item()

            # Sync dataset state on the underlying StatefulDataLoader
            if isinstance(raw_dataloader, StatefulDataLoader):
                sync_dataset_state_from_dataloader(raw_dataloader)

            metrics = {"eval_loss": eval_loss}
            self._dispatch_event("on_evaluate", metrics=metrics)
            return metrics

    @override
    def _zero_eval_batches_message(
        self, empty_ranks: Optional[List[int]] = None
    ) -> str:
        """Build a diagnostic for the zero-eval-batches failure mode.

        Fires from three places:

        - ``_eval_loop_all_shards`` pre-flight when *some* ranks have an
          empty shard (``empty_ranks`` is the list of those ranks); this
          would otherwise deadlock on asymmetric ``self.model(...)`` calls.
        - ``_eval_loop_all_shards`` post-loop when *every* rank produced
          zero steps (``empty_ranks`` is None or covers every rank).
        - The inherited base ``_eval_loop`` when ``dispatch_eval_batches``
          is effectively True and the rank-0 dispatcher terminated before
          yielding a batch (e.g. it could not assemble one full batch per
          rank from the available eval data); ``empty_ranks`` is None.
        """
        effective_dispatch_eval = self._dispatch_eval_batches()
        partial = (
            empty_ranks is not None and 0 < len(empty_ranks) < self.dist.world_size
        )
        if effective_dispatch_eval:
            header = (
                f"Distributed evaluation produced zero batches across all "
                f"{self.dist.world_size} ranks."
            )
            explanation = (
                "With dispatch_eval_batches=True (effective) rank 0 loads\n"
                "the eval dataloader and broadcasts batches to the other\n"
                "ranks. The dispatcher needs to assemble at least one batch\n"
                "per rank before any rank can step; if rank 0 exhausts the\n"
                "dataloader before that point (small eval split, or\n"
                "dataloader_drop_last=True dropping the only partial batch),\n"
                "every rank reports zero steps."
            )
        elif partial:
            assert empty_ranks is not None
            ranks_str = ", ".join(str(r) for r in empty_ranks)
            plural = "s" if len(empty_ranks) > 1 else ""
            header = (
                f"Distributed evaluation produced zero batches on "
                f"{len(empty_ranks)} of {self.dist.world_size} ranks "
                f"(empty rank{plural}: {ranks_str}).\n"
                "Refusing to enter the eval loop because asymmetric "
                "self.model(...) calls across the DDP process group "
                "would deadlock."
            )
            explanation = (
                "With dispatch_eval_batches=False (effective) each rank\n"
                "evaluates its own shard of the eval dataset; if a shard\n"
                "contains fewer than per_device_eval_batch_size examples\n"
                "and dataloader_drop_last is True, that rank yields no\n"
                "batches. The other ranks would still try to step, and\n"
                "their model.forward() calls participate in DDP collectives\n"
                "that the empty ranks must mirror — so the eval loop would\n"
                "hang. This pre-flight check fails fast on every rank."
            )
        else:
            header = (
                f"Distributed evaluation produced zero batches across all "
                f"{self.dist.world_size} ranks."
            )
            explanation = (
                "With dispatch_eval_batches=False (effective) each rank\n"
                "evaluates its own shard of the eval dataset; if every\n"
                "shard contains fewer than per_device_eval_batch_size\n"
                "examples and dataloader_drop_last is True, every shard\n"
                "is dropped and total_steps collapses to zero."
            )
        return self._format_zero_eval_batches_diagnostic(
            header=header,
            settings=[
                ("dispatch_batches", self.args.dispatch_batches),
                (
                    "dispatch_eval_batches",
                    f"{self.args.dispatch_eval_batches}"
                    f" (effective: {effective_dispatch_eval})",
                ),
                ("dataloader_drop_last", self.args.dataloader_drop_last),
                ("per_device_eval_batch_size", self.args.per_device_eval_batch_size),
                ("max_eval_steps", self.args.max_eval_steps),
                ("world_size", self.dist.world_size),
            ],
            explanation=explanation,
        )

    @override
    def _distributed_loss(self, loss: Tensor) -> Tensor:
        """
        Average the loss across all DDP ranks (all_reduce with ReduceOp.AVG).
        """
        if self.dist.world_size == 1:
            return super()._distributed_loss(loss)

        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
        return loss

    @override
    def _distributed_tokens(self, tokens: Tensor) -> Tensor:
        """
        Sum token counts across all DDP ranks.

        In DDP, each rank processes different batches, so token counts must be
        summed across all ranks to get the total tokens processed per step.

        Parameters
        ----------
        tokens : Tensor
            Token count from the current rank.

        Returns
        -------
        Tensor
            Total token count summed across all DDP ranks.
        """
        if self.dist.world_size == 1:
            return super()._distributed_tokens(tokens)

        dist.all_reduce(tokens, op=dist.ReduceOp.SUM)
        return tokens

    @override
    def _distributed_peak_mem(self, local_peak: int) -> list[int]:
        """
        All-gather per-rank peak CUDA memory across DDP ranks.
        """
        if self.dist.world_size == 1:
            return super()._distributed_peak_mem(local_peak)

        value = torch.tensor(
            [int(local_peak)], dtype=torch.long, device=self.args.device
        )
        gathered = [torch.zeros_like(value) for _ in range(self.dist.world_size)]
        dist.all_gather(gathered, value)
        return [int(t.item()) for t in gathered]

    @override
    def _forward_backward_step(
        self,
        *args,
        **kwargs,
    ) -> Tensor:
        """
        Skip gradient reduction when not a gradient sync step.

        This is achieved with DDP's "no_sync" context manager.
        """
        if self.dist.world_size == 1:
            return super()._forward_backward_step(*args, **kwargs)

        with (
            nullcontext()
            if self._should_sync_gradients()
            else cast(DDP, self.model).no_sync()
        ):
            return super()._forward_backward_step(*args, **kwargs)

    @override
    def get_state_components(self) -> List[StateComponent]:
        """
        Get state components for DDP training.

        All training state is always saved to checkpoints. To skip loading a component,
        delete its file from the checkpoint directory.

        DDP uses data parallelism where model and optimizer state are replicated
        across all ranks. DDP automatically synchronizes model weights and gradients,
        so these components use REPLICATED pattern with validation enabled to catch
        synchronization bugs.

        Dataset pattern depends on dispatch_batches setting:
        - dispatch_batches=True: GLOBAL (rank 0 loads and dispatches)
        - dispatch_batches=False: PER_RANK (each rank has independent dataloader)

        Returns
        -------
        list of StateComponent
            State components with REPLICATED sharing patterns for DDP.
        """
        if self.dist.world_size == 1:
            return super().get_state_components()

        components = []

        # Model - REQUIRED, REPLICATED in DDP
        # DDP synchronizes model weights across all ranks
        components.append(
            StateComponent(
                key="model",
                stateful=cast(Stateful, self.unwrapped_model()),
                sharing_pattern=SharingPattern.REPLICATED,
                validate_replication=True,  # Verify DDP synchronization
                validation_level="tensor",  # Good balance of speed vs accuracy
                required=True,  # Model is always required
            )
        )

        # Optimizer - optional, REPLICATED in DDP
        # DDP synchronizes gradients, so optimizer state should be identical
        if self.optimizer is not None:
            components.append(
                StateComponent(
                    key="optimizer",
                    stateful=cast(Stateful, self.optimizer),
                    sharing_pattern=SharingPattern.REPLICATED,
                    validate_replication=True,
                    validation_level="tensor",  # Per-tensor checksums for accurate validation
                    required=False,
                )
            )

        # LR Scheduler - optional, REPLICATED
        # Same schedule across all ranks
        if self.lr_scheduler is not None:
            components.append(
                StateComponent(
                    key="scheduler",
                    stateful=cast(Stateful, self.lr_scheduler),
                    sharing_pattern=SharingPattern.REPLICATED,
                    required=False,
                )
            )

        # Trainer state - optional, REPLICATED
        # Training progress is synchronized across all ranks
        components.append(
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            )
        )

        # Dataset state - optional, depends on dispatch_batches setting
        if hasattr(self.train_dataloader, "state_dict"):
            components.append(
                StateComponent(
                    key="dataset",
                    stateful=cast(Stateful, self.train_dataloader),
                    sharing_pattern=self._get_dataset_sharing_pattern(),
                    required=False,
                )
            )

        # RNG state - optional, PER_RANK
        # Each rank needs different random numbers for data augmentation, dropout, etc.
        components.append(
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            )
        )

        return components

    @override
    def _get_dataset_sharing_pattern(self) -> SharingPattern:
        """
        Determine dataset sharing pattern for DDP training.

        The pattern depends on the dispatch_batches setting:
        - dispatch_batches=True: Uses DataloaderDispatcher where rank 0 loads
          data and broadcasts to all ranks (GLOBAL pattern)
        - dispatch_batches=False: Each rank has independent dataloader iteration
          (PER_RANK pattern)

        Returns
        -------
        SharingPattern
            GLOBAL when ``dispatch_batches=True``, PER_RANK otherwise.
        """
        if self.dist.world_size == 1:
            return super()._get_dataset_sharing_pattern()

        if self.args.dispatch_batches:
            # DataloaderDispatcher: rank 0 loads and broadcasts
            return SharingPattern.GLOBAL
        else:
            # Independent dataloaders per rank
            return SharingPattern.PER_RANK

    @override
    def get_process_groups(self) -> Dict[str, Any]:
        """
        Get named process groups for checkpoint coordination.

        Returns
        -------
        dict
            Mapping of group names to ProcessGroup objects.
            For DDP, returns the data parallel group.
        """
        if self.dist.world_size == 1:
            return super().get_process_groups()

        return {
            "ddp_group": self.ddp_group,
        }
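The DDP overrides in the listing above reduce different quantities in different ways. _distributed_loss averages with ReduceOp.AVG, _distributed_tokens sums with ReduceOp.SUM, and the evaluation loop sums both its accumulated loss and its count of real (non-dummy) steps before dividing. The plain-Python sketch below simulates those reductions over per-rank lists, without torch.distributed, to show why the eval loss is weighted by real steps rather than computed as a mean of per-rank means. The helpers simulate_all_reduce and eval_loss_from_ranks are hypothetical and not part of Forgather.

```python
from typing import List


def simulate_all_reduce(values: List[float], op: str) -> float:
    """Hypothetical stand-in for dist.all_reduce on one scalar per rank."""
    if op == "sum":
        return sum(values)
    if op == "avg":
        return sum(values) / len(values)
    raise ValueError(f"unsupported op: {op}")


def eval_loss_from_ranks(per_rank_losses: List[List[float]]) -> float:
    """Mirror the eval aggregation: each rank accumulates loss only on its
    real batches, then the loss totals and the real-step counts are both
    summed across ranks before dividing."""
    totals = [sum(losses) for losses in per_rank_losses]  # total_loss per rank
    steps = [len(losses) for losses in per_rank_losses]  # local_real_steps per rank
    total_loss = simulate_all_reduce(totals, "sum")
    total_steps = simulate_all_reduce(steps, "sum")
    if total_steps == 0:
        # The trainer raises a diagnostic RuntimeError in this case.
        raise RuntimeError("zero eval batches across all ranks")
    return total_loss / total_steps


# Rank 0 had three real eval batches; rank 1 ran out after one, so its
# remaining synchronized iterations were dummy steps that add nothing.
per_rank = [[2.0, 2.0, 2.0], [4.0]]

weighted = eval_loss_from_ranks(per_rank)
naive = simulate_all_reduce([sum(l) / len(l) for l in per_rank], "avg")
print(weighted)  # 2.5  (10.0 / 4 real steps)
print(naive)  # 3.0  (mean of per-rank means over-weights rank 1)

# Token counts are summed across ranks, as in _distributed_tokens.
print(simulate_all_reduce([512, 384], "sum"))  # 896
```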

__init__(*, args, fused_loss_factory=None, **kwargs)

Parameters:

args : DDPTrainingArguments or dict, required
    Training configuration including DDP-specific options (ddp, post_local_sgd, dispatch_batches). Accepts a dict for programmatic construction.
fused_loss_factory : callable, default None
    Factory for fused logits-loss computation. See Trainer for details.
**kwargs
    Forwarded to Trainer.__init__: distributed_env, model, model_init, train_dataset, eval_dataset, optimizer_factory, lr_scheduler_factory, callbacks, etc.
Source code in src/forgather/ml/trainer/ddp/ddp_trainer.py
def __init__(
    self,
    *,
    args: TDDPTrainingArguments | dict,
    fused_loss_factory: Optional[FusedLossFactoryT] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    args : DDPTrainingArguments or dict
        Training configuration including DDP-specific options (``ddp``,
        ``post_local_sgd``, ``dispatch_batches``). Accepts a dict for
        programmatic construction.
    fused_loss_factory : callable, optional
        Factory for fused logits-loss computation. See ``Trainer`` for details.
    **kwargs
        Forwarded to ``Trainer.__init__``: ``distributed_env``, ``model``,
        ``model_init``, ``train_dataset``, ``eval_dataset``, ``optimizer_factory``,
        ``lr_scheduler_factory``, ``callbacks``, etc.
    """
    if isinstance(args, dict):
        args = cast(TDDPTrainingArguments, from_dict(DDPTrainingArguments, args))

    super().__init__(args=args, fused_loss_factory=fused_loss_factory, **kwargs)

    assert (
        not self.args.fuse_optim_with_backward
    ), "DDPTrainer does not support option fuse_optim_with_backward"

unwrapped_model()

Get and return the unwrapped model.

In the case of DDP, the original model is stored in the wrapper's "module" attribute.

Source code in src/forgather/ml/trainer/ddp/ddp_trainer.py
@override
def unwrapped_model(self) -> torch.nn.Module:
    """
    Get and return the unwrapped model.

    In the case of DDP, the original model is stored in the model's "module" attribute.
    """
    if self.dist.world_size == 1:
        return super().unwrapped_model()

    assert self.model
    return cast(Module, self.model.module)
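Unwrapping matters for checkpoints. PyTorch's DistributedDataParallel registers the original model as a submodule named module, so calling state_dict() on the wrapper prefixes every key with "module.", while the unwrapped model yields the original key names. If you ever need to load a state dict that was saved from the wrapper rather than from unwrapped_model(), a key-renaming helper like the hypothetical strip_module_prefix below (plain Python, operating on a dict) is one way to recover the original names.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_module_prefix(
    state_dict: Dict[str, V], prefix: str = "module."
) -> Dict[str, V]:
    """Return a copy of state_dict with a leading wrapper prefix removed.

    Keys without the prefix are kept unchanged, so the helper is also safe
    to apply to a state dict saved from the unwrapped model.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they would appear from a DDP-wrapped model (tensor values elided).
wrapped = {"module.embed.weight": "...", "module.lm_head.weight": "..."}
print(strip_module_prefix(wrapped))
# {'embed.weight': '...', 'lm_head.weight': '...'}
```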

get_state_components()

Get state components for DDP training.

All training state is always saved to checkpoints. To skip loading a component, delete its file from the checkpoint directory.

DDP uses data parallelism where model and optimizer state are replicated across all ranks. DDP automatically synchronizes model weights and gradients, so these components use the REPLICATED pattern with validation enabled to catch synchronization bugs.

The dataset pattern depends on the dispatch_batches setting:

  • dispatch_batches=True: GLOBAL (rank 0 loads and dispatches)
  • dispatch_batches=False: PER_RANK (each rank has an independent dataloader)

Returns:

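list of StateComponent
    Model and trainer state (plus optimizer and scheduler state, when present) use the REPLICATED pattern; dataset state, when the train dataloader is stateful, is GLOBAL or PER_RANK depending on dispatch_batches; RNG state is PER_RANK.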
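The component list built by get_state_components (for world_size > 1) can be summarised as a small table whose only variable entry is the dataset pattern. The sketch below models it with a local SharingPattern enum and a (key, pattern, required) tuple per component; both are hypothetical plain-Python stand-ins for Forgather's SharingPattern and StateComponent, so the rules can be checked without a process group.

```python
from enum import Enum
from typing import List, Tuple


class SharingPattern(Enum):
    # Hypothetical local mirror of the patterns used in get_state_components.
    GLOBAL = "global"  # one copy, driven by rank 0
    REPLICATED = "replicated"  # identical on every rank
    PER_RANK = "per_rank"  # distinct state on each rank


def ddp_state_layout(
    dispatch_batches: bool,
    has_optimizer: bool = True,
    has_scheduler: bool = True,
    stateful_dataloader: bool = True,
) -> List[Tuple[str, SharingPattern, bool]]:
    """Return (key, sharing pattern, required) for each checkpointed component."""
    layout = [("model", SharingPattern.REPLICATED, True)]
    if has_optimizer:
        layout.append(("optimizer", SharingPattern.REPLICATED, False))
    if has_scheduler:
        layout.append(("scheduler", SharingPattern.REPLICATED, False))
    layout.append(("trainer", SharingPattern.REPLICATED, False))
    if stateful_dataloader:
        # Rank 0 drives a single dispatcher when dispatch_batches=True;
        # otherwise every rank iterates its own dataloader.
        dataset_pattern = (
            SharingPattern.GLOBAL if dispatch_batches else SharingPattern.PER_RANK
        )
        layout.append(("dataset", dataset_pattern, False))
    # RNG state always differs between ranks (dropout, augmentation).
    layout.append(("rng", SharingPattern.PER_RANK, False))
    return layout


for key, pattern, required in ddp_state_layout(dispatch_batches=False):
    print(f"{key:<10} {pattern.value:<11} required={required}")
```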


get_process_groups()

Get named process groups for checkpoint coordination.

Returns:

dict
    Mapping of group names to ProcessGroup objects. For DDP, returns the data parallel group.


forgather.ml.trainer.ddp.ddp_trainer.DDPTrainingArguments dataclass

Bases: TrainingArguments

Source code in src/forgather/ml/trainer/ddp/ddp_trainer.py
@dataclass(kw_only=True)
class DDPTrainingArguments(TrainingArguments):
    # Load and preprocess all batches on rank-0, then dispatch to other ranks
    # All ranks are sent full batches, as specified by `per_device_train_batch_size`, where
    # the total effective batch size is per_device_train_batch_size * world_size
    #
    # This avoids the need to manually specify how to shard the dataset, at the expense of
    # adding some non-zero amount of processing latency. This also greatly simplifies dataset
    # checkpointing, as there is only one global state to keep track of.
    #
    # When set to False, care must be taken to ensure that each rank receives different examples.
    dispatch_batches: bool = True

    # Optional eval-only override for `dispatch_batches`. ``None`` (default) means
    # eval follows whatever `dispatch_batches` is set to. Set to ``True`` to make
    # eval centralise on rank 0 even when training shards across ranks; this is
    # useful when the eval split is small enough that per-rank shards may not
    # contain enough examples to fill a single batch (especially with
    # ``dataloader_drop_last=True``). See
    # ``docs/trainers/distributed-eval-zero-batches.md`` for the failure mode
    # this guards against.
    dispatch_eval_batches: Optional[bool] = None

    ddp: DDPArguments = field(default_factory=DDPArguments)
    post_local_sgd: PostLocalSGDArguments = field(default_factory=PostLocalSGDArguments)
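The comments above imply a few rules of thumb: in dispatch mode the effective batch is per_device_train_batch_size * world_size (ignoring gradient accumulation), dispatch_eval_batches falls back to dispatch_batches when left as None, and with per-rank sharding a small eval split can leave some ranks without a single batch. The sketch below encodes those rules in plain Python. The fallback mirrors the comment on dispatch_eval_batches, but whether Forgather's _dispatch_eval_batches() does exactly this is an assumption; the round-robin split is also an assumption, since the real sharding depends on your sampler or dataset. All helper names are hypothetical.

```python
import math
from typing import List, Optional


def effective_dispatch_eval(
    dispatch_batches: bool, dispatch_eval_batches: Optional[bool]
) -> bool:
    # None means "follow dispatch_batches"; an explicit bool overrides it.
    return dispatch_batches if dispatch_eval_batches is None else dispatch_eval_batches


def global_batch_size(per_device_train_batch_size: int, world_size: int) -> int:
    # As stated in the dispatch_batches comment: every rank gets a full batch.
    return per_device_train_batch_size * world_size


def shard_sizes(num_examples: int, world_size: int) -> List[int]:
    # Assumed round-robin split: rank r gets examples r, r + world_size, ...
    return [len(range(r, num_examples, world_size)) for r in range(world_size)]


def batches_per_rank(
    num_examples: int, world_size: int, batch_size: int, drop_last: bool
) -> List[int]:
    sizes = shard_sizes(num_examples, world_size)
    if drop_last:
        return [n // batch_size for n in sizes]  # partial batches are dropped
    return [math.ceil(n / batch_size) for n in sizes]


def classify(counts: List[int]) -> str:
    # Mirrors the partial / all-empty distinction in _zero_eval_batches_message.
    empty = [rank for rank, n in enumerate(counts) if n == 0]
    if not empty:
        return "ok"
    return "partial" if len(empty) < len(counts) else "all_empty"


counts = batches_per_rank(30, world_size=4, batch_size=8, drop_last=True)
print(counts, classify(counts))  # [1, 1, 0, 0] partial
print(batches_per_rank(30, world_size=4, batch_size=8, drop_last=False))  # [1, 1, 1, 1]
print(effective_dispatch_eval(False, None), effective_dispatch_eval(False, True))  # False True
print(global_batch_size(8, world_size=4))  # 32
```

With 30 examples on 4 ranks, ranks 2 and 3 hold only 7 examples each, so drop_last=True leaves them with no batch: the "partial" case that the pre-flight check in _eval_loop_all_shards rejects to avoid a deadlock.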

Fully Sharded Data Parallel (FSDP2) Trainer


forgather.ml.trainer.fsdp2.fsdp2_trainer.FSDP2Trainer

Bases: Trainer[TFSDP2TrainingArguments], Generic[TFSDP2TrainingArguments]

Trainer that shards model, gradients, and optimizer state via FSDP2.

Uses torch.distributed.fsdp.fully_shard (PyTorch's FSDP2 API) to distribute parameters, gradients, and optimizer state across all ranks. Provides ZeRO-3-style memory savings, making it suitable for models that don't fit in a single GPU's memory.
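As a rough worked example of the ZeRO-3-style savings, assume plain fp32 training with Adam: 4 bytes per parameter for the weight, 4 for its gradient, and 8 for Adam's two moment buffers, 16 bytes in total. Full sharding divides that persistent state by the world size. The sketch ignores activations, the temporarily all-gathered parameters of the layer currently executing, mixed-precision copies, and communication buffers, so treat the result as a lower bound rather than a prediction. The 7B model size and the helper name sharded_state_gib are illustrative.

```python
def sharded_state_gib(
    num_params: int, world_size: int, bytes_per_param: int = 16
) -> float:
    """Persistent per-rank model state in GiB under full sharding.

    bytes_per_param=16 assumes fp32 weights (4) + fp32 grads (4) + two fp32
    Adam moments (8). Activations and all-gathered layers are not counted.
    """
    total_bytes = num_params * bytes_per_param
    return total_bytes / world_size / 2**30


params = 7_000_000_000  # illustrative 7B-parameter model
print(round(sharded_state_gib(params, world_size=1), 1))  # 104.3
print(round(sharded_state_gib(params, world_size=8), 1))  # 13.0
```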

Launch with torchrun (or the forgather train -d ... shortcut):

torchrun --nproc_per_node=4 train.py

Key differences from DDP:

  • Each rank stores only a shard of parameters, gradients, and optimizer state
  • Parameters are all-gathered before each forward/backward and re-sharded after (controlled by fsdp2.reshard_after_forward)
  • Model checkpoints are saved as full HuggingFace safetensors gathered on rank 0, making them loadable by from_pretrained without special tooling
  • Optimizer state is saved per-rank (sharded) and tied to the world size

See FSDP2Arguments for sharding configuration options (mixed precision policy, CPU offload, transformer-layer-wise sharding).
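The point that optimizer state is tied to the world size follows from how fully_shard lays out tensors. As we understand FSDP2, each parameter is split along dimension 0 into world_size chunks with torch.chunk-style sizes (ceil division, with trailing ranks possibly short or empty), and the optimizer state for a parameter lives next to the local shard. The plain-Python sketch below computes each rank's row range under that assumption; it shows why per-rank optimizer files written at one world size do not line up with the shards at another. dim0_shard_ranges is a hypothetical helper, not a Forgather or PyTorch API.

```python
import math
from typing import List, Tuple


def dim0_shard_ranges(num_rows: int, world_size: int) -> List[Tuple[int, int]]:
    """Row range [start, stop) owned by each rank, torch.chunk-style.

    Each rank gets ceil(num_rows / world_size) rows, except that trailing
    ranks get fewer (or none) when the size does not divide evenly.
    """
    chunk = math.ceil(num_rows / world_size)
    ranges = []
    for rank in range(world_size):
        start = min(rank * chunk, num_rows)
        stop = min(start + chunk, num_rows)
        ranges.append((start, stop))
    return ranges


# The same 10-row parameter sharded at two different world sizes.
print(dim0_shard_ranges(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
print(dim0_shard_ranges(10, 3))  # [(0, 4), (4, 8), (8, 10)]
```

Rank 1 owns rows 3 to 5 at world size 4 but rows 4 to 7 at world size 3, so its saved optimizer shard cannot simply be reloaded after changing the number of ranks.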

Source code in src/forgather/ml/trainer/fsdp2/fsdp2_trainer.py
class FSDP2Trainer(Trainer[TFSDP2TrainingArguments], Generic[TFSDP2TrainingArguments]):
    """
    Trainer that shards model, gradients, and optimizer state via FSDP2.

    Uses ``torch.distributed.fsdp.fully_shard`` (PyTorch's FSDP2 API) to distribute
    parameters, gradients, and optimizer state across all ranks. Provides ZeRO-3-style
    memory savings, making it suitable for models that don't fit in a single GPU's memory.

    Launch with ``torchrun`` (or the ``forgather train -d ...`` shortcut)::

        torchrun --nproc_per_node=4 train.py

    Key differences from DDP:

    - Each rank stores only a shard of parameters, gradients, and optimizer state
    - Parameters are all-gathered before each forward/backward and re-sharded after
      (controlled by ``fsdp2.reshard_after_forward``)
    - Model checkpoints are saved as full HuggingFace safetensors gathered on rank 0,
      making them loadable by ``from_pretrained`` without special tooling
    - Optimizer state is saved per-rank (sharded) and tied to the world size

    See ``FSDP2Arguments`` for sharding configuration options (mixed precision policy,
    CPU offload, transformer-layer-wise sharding).
    """

    args: TFSDP2TrainingArguments

    def __init__(
        self,
        *,
        args: TFSDP2TrainingArguments | dict,
        fused_loss_factory: Optional[FusedLossFactoryT] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        args : FSDP2TrainingArguments or dict
            Training configuration including FSDP2-specific options under the ``fsdp2``
            field (``reshard_after_forward``, ``param_dtype``, ``cpu_offload``, etc.).
            Accepts a dict for programmatic construction.
        fused_loss_factory : callable, optional
            Factory for fused logits-loss computation. See ``Trainer`` for details.
        **kwargs
            Forwarded to ``Trainer.__init__``: ``distributed_env``, ``model``,
            ``model_init``, ``train_dataset``, ``eval_dataset``, ``optimizer_factory``,
            ``lr_scheduler_factory``, ``callbacks``, etc.
        """
        if isinstance(args, dict):
            args = cast(
                TFSDP2TrainingArguments, from_dict(FSDP2TrainingArguments, args)
            )

        super().__init__(args=args, fused_loss_factory=fused_loss_factory, **kwargs)

        assert (
            not self.args.fuse_optim_with_backward
        ), "FSDP2Trainer does not support option fuse_optim_with_backward"

    @override
    def _init_distributed(self):
        assert (
            dist.is_available() and dist.is_initialized()
        ) or self.dist.world_size == 1, (
            "FSDP2 trainer requires that torch.distributed has been initialized."
        )

        if self.dist.world_size == 1:
            return super()._init_distributed()

        self.is_local_process_zero = self.dist.local_rank == 0
        self.is_world_process_zero = self.dist.rank == 0
        self.num_processes = self.dist.world_size

        self.mesh = init_device_mesh(
            self.dist.device_type,
            (self.dist.world_size,),
            mesh_dim_names=("data_parallel",),
        )
        self.fsdp_group = self.mesh.get_group(0)

    def _build_mp_policy(self) -> MixedPrecisionPolicy:
        fsdp_args = self.args.fsdp2
        return MixedPrecisionPolicy(
            param_dtype=_parse_dtype(fsdp_args.param_dtype),
            reduce_dtype=_parse_dtype(fsdp_args.reduce_dtype),
            output_dtype=None,
            cast_forward_inputs=True,
        )

    @override
    def _prepare_model(self) -> None:
        # Stash the base (unfused) loss_fn before super() wraps it. The
        # base trainer's _prepare_model calls _maybe_get_fused_loss_fn BEFORE
        # we apply fully_shard, so any fused loss (LinearCrossEntropyLoss)
        # gets constructed against a plain-tensor lm_head and happily
        # selects Apple CCE / Liger. Those kernels silently produce zero
        # gradients once lm_head becomes a DTensor after fully_shard, so
        # we have to rebuild the fused loss after sharding. See the bottom
        # of this method for the rebuild.
        base_loss_fn = self.loss_fn

        super()._prepare_model()
        if self.dist.world_size == 1:
            return

        assert self.model is not None

        mp_policy = self._build_mp_policy()
        offload_policy: OffloadPolicy = (
            CPUOffloadPolicy() if self.args.fsdp2.cpu_offload else OffloadPolicy()
        )

        layer_iter: List[torch.nn.Module] = []
        if (
            self.args.fsdp2.shard_transformer_layers
            and self.args.fsdp2.transformer_layers_path
        ):
            container = _resolve_attr_path(
                self.model, self.args.fsdp2.transformer_layers_path
            )
            if container is not None:
                layer_iter = _iter_layer_modules(container)

        if layer_iter:
            for layer in layer_iter:
                fully_shard(
                    layer,
                    mesh=self.mesh,
                    reshard_after_forward=self.args.fsdp2.reshard_after_forward,
                    mp_policy=mp_policy,
                    offload_policy=offload_policy,
                )
            logger.info(
                f"Applied fully_shard to {len(layer_iter)} transformer blocks at "
                f"'{self.args.fsdp2.transformer_layers_path}'"
            )
        else:
            logger.warning(
                f"FSDP2 layer-wise sharding skipped: could not resolve "
                f"'{self.args.fsdp2.transformer_layers_path}' on model as an "
                "iterable of nn.Modules. Falling back to root-level "
                "fully_shard only; memory savings will be limited."
            )

        fully_shard(
            self.model,
            mesh=self.mesh,
            reshard_after_forward=self.args.fsdp2.reshard_after_forward,
            mp_policy=mp_policy,
            offload_policy=offload_policy,
        )

        # Fused loss is incompatible with DTensor lm_head under FSDP2.
        # None of CCE / Liger / pytorch-chunked handle a sharded weight:
        # CCE and Liger silently produce zero gradients (loss stuck at
        # ln(vocab_size)), and the pytorch path raises "mixed torch.Tensor
        # and DTensor" on its hidden_states @ weight.T matmul. The base
        # trainer's _prepare_model already built a fused-loss wrapper
        # against the pre-shard plain-tensor lm_head; undo that and
        # re-wrap the base unfused loss_fn. The model's normal lm_head
        # forward pass handles DTensor correctly via FSDP2's all_gather
        # forward hooks, at the cost of materializing the full logits
        # tensor that the fused path would have avoided. A future DTensor-
        # native fused loss can revisit this.
        if self.use_fused_loss:
            logger.warning(
                "FSDP2: fused LinearCrossEntropyLoss is incompatible with a "
                "DTensor lm_head weight; falling back to the standard "
                "logits path. Expect higher activation memory from "
                "materialized logits."
            )
            self.use_fused_loss = False
            self.loss_fn = base_loss_fn
            self._wrap_loss_fn()

    @override
    def _wrap(self) -> None:
        """Wrap dataloaders for DP; the model is already sharded."""
        if self.dist.world_size == 1:
            return super()._wrap()

        if self.args.dispatch_batches:
            if self.train_dataloader:
                self.train_dataloader = DataloaderDispatcher(
                    cast(DataLoader, self.train_dataloader),
                    self.mesh,
                    self.args.device,
                )
            if self.eval_dataloader:
                self.eval_dataloader = DataloaderDispatcher(
                    cast(DataLoader, self.eval_dataloader),
                    self.mesh,
                    self.args.device,
                )
        else:
            if self.train_dataloader:
                self.train_dataloader = SynchronizedDataLoader(
                    self.train_dataloader,
                    device=self.args.device,
                    process_group=self.fsdp_group,
                )
            if self.eval_dataloader:
                self.eval_dataloader = SynchronizedDataLoader(
                    self.eval_dataloader,
                    device=self.args.device,
                    process_group=self.fsdp_group,
                )

    @override
    def _init_checkpoint_manager(self) -> CheckpointManager:
        """
        Wire CheckpointManager with FSDP2-aware model save/load hooks.

        FSDP2 parameters are DTensors on a device mesh, so CheckpointManager's
        default shard-index path (which calls ``save_sharded_checkpoint`` on
        plain ``nn.Module.state_dict()``) can't be used directly: save needs
        a full-state-dict gather to rank 0, load needs a broadcast-and-
        reshard from rank 0. The hooks below handle both, and the output is
        standard HuggingFace safetensors layout, so:

        - ``transformers.AutoModel.from_pretrained`` can load checkpoints
          this trainer saves, and
        - this trainer can resume from any plain HF checkpoint it did not
          create.

        Because the hooks themselves are collectives, CheckpointManager
        calls them on every rank (rank gating for file writes lives inside
        the hooks). ``model_parts=[]`` and ``shard_index={}`` make the
        legacy shard-index path a no-op.
        """
        if self.dist.world_size == 1:
            return super()._init_checkpoint_manager()

        cp_config = CheckpointConfig(
            output_dir=self.args.output_dir,
            save_total_limit=self.args.save_total_limit,
            save_on_each_node=self.args.save_on_each_node,
            save_safetensors=self.args.save_safetensors,
        )

        fsdp2_model = self.unwrapped_model()
        safetensors = self.args.save_safetensors

        def _save_model_hook(output_dir: str) -> None:
            save_fsdp2_model_as_hf(
                fsdp2_model,
                output_dir,
                dist=self.dist,
                safetensors=safetensors,
            )

        def _load_model_hook(checkpoint_path: str) -> None:
            load_fsdp2_model_from_hf(
                fsdp2_model,
                checkpoint_path,
                dist=self.dist,
            )

        checkpoint_manager = CheckpointManager(
            config=cp_config,
            dist=self.dist,
            model=fsdp2_model,
            model_parts=[],
            model_preprocessor=self.processing_class,
            stateful_provider=self,
            shard_index={},
            model_save_fn=_save_model_hook,
            model_load_fn=_load_model_hook,
        )
        checkpoint_manager.trainer = self
        if hasattr(self.args, "preserve_n_best"):
            checkpoint_manager.preserve_n_best = self.args.preserve_n_best
        return checkpoint_manager

    @override
    def unwrapped_model(self) -> torch.nn.Module:
        # fully_shard mutates the module class in place; there's no .module
        # wrapper to peel off.
        assert self.model is not None
        return self.model

    @override
    def _distributed_loss(self, loss: Tensor) -> Tensor:
        if self.dist.world_size == 1:
            return super()._distributed_loss(loss)
        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
        return loss

    @override
    def _distributed_tokens(self, tokens: Tensor) -> Tensor:
        if self.dist.world_size == 1:
            return super()._distributed_tokens(tokens)
        dist.all_reduce(tokens, op=dist.ReduceOp.SUM)
        return tokens

    @override
    def _distributed_peak_mem(self, local_peak: int) -> list[int]:
        if self.dist.world_size == 1:
            return super()._distributed_peak_mem(local_peak)
        value = torch.tensor(
            [int(local_peak)], dtype=torch.long, device=self.args.device
        )
        gathered = [torch.zeros_like(value) for _ in range(self.dist.world_size)]
        dist.all_gather(gathered, value)
        return [int(t.item()) for t in gathered]

    @override
    def _forward_backward_step(self, *args, **kwargs) -> Tensor:
        """
        Gate FSDP2 gradient reduction at accumulation boundaries.

        FSDP2 uses set_requires_gradient_sync() rather than DDP's no_sync()
        context manager. Setting False suppresses the reduce-scatter in the
        backward hook for the current backward; setting True re-enables it
        for the final micro-batch of the accumulation window.
        """
        if self.dist.world_size == 1:
            return super()._forward_backward_step(*args, **kwargs)

        sync = self._should_sync_gradients()
        cast(FSDPModule, self.model).set_requires_gradient_sync(sync)
        return super()._forward_backward_step(*args, **kwargs)

    def pipeline_generate(self, input_ids: Tensor, **kwargs) -> Tensor:
        """All-rank generate: FSDP2 forward pass needs every rank in the
        all_gather, so generation must run collectively. The
        ``TextgenCallback`` detects this method and uses a coordinated
        broadcast-then-generate flow (same path as PipelineTrainer)."""
        assert self.model is not None
        return self.model.generate(input_ids=input_ids, **kwargs)

    @override
    def get_state_components(self) -> List[StateComponent]:
        """
        State components for FSDP2.

        The model is saved/loaded as HuggingFace safetensors via the model
        hooks wired in ``_init_checkpoint_manager``; it is NOT registered as
        a StateComponent. Optimizer state stays sharded per rank because
        the DTensor layout of the optimizer moments cannot cheaply round-
        trip through a gather/broadcast. Scheduler, trainer progress,
        dataset and RNG mirror DDPTrainer.
        """
        if self.dist.world_size == 1:
            return super().get_state_components()

        assert self.model is not None

        components: List[StateComponent] = []

        if self.optimizer is not None:
            components.append(
                StateComponent(
                    key="optimizer",
                    stateful=_FSDP2OptimStateful(
                        self.model, cast(torch.optim.Optimizer, self.optimizer)
                    ),
                    sharing_pattern=SharingPattern.PER_RANK,
                    validate_replication=False,
                    required=False,
                )
            )

        if self.lr_scheduler is not None:
            components.append(
                StateComponent(
                    key="scheduler",
                    stateful=cast(Stateful, self.lr_scheduler),
                    sharing_pattern=SharingPattern.REPLICATED,
                    required=False,
                )
            )

        components.append(
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            )
        )

        if hasattr(self.train_dataloader, "state_dict"):
            components.append(
                StateComponent(
                    key="dataset",
                    stateful=cast(Stateful, self.train_dataloader),
                    sharing_pattern=self._get_dataset_sharing_pattern(),
                    required=False,
                )
            )

        components.append(
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            )
        )

        return components

    @override
    def _get_dataset_sharing_pattern(self) -> SharingPattern:
        if self.dist.world_size == 1:
            return super()._get_dataset_sharing_pattern()
        if self.args.dispatch_batches:
            return SharingPattern.GLOBAL
        return SharingPattern.PER_RANK

    @override
    def get_process_groups(self) -> Dict[str, Any]:
        if self.dist.world_size == 1:
            return super().get_process_groups()
        return {"fsdp_group": self.fsdp_group}
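
The accumulation gating in `_forward_backward_step` comes down to choosing one boolean per micro-batch and passing it to `set_requires_gradient_sync()`. The sketch below reproduces that choice in plain Python. It assumes a zero-based micro-batch counter and a window of `accum_steps` micro-batches. `should_sync_gradients` and `sync_schedule` are hypothetical stand-ins, because the body of `_should_sync_gradients` is not shown in this section.

```python
def should_sync_gradients(micro_step: int, accum_steps: int) -> bool:
    """Return True only for the last micro-batch of each accumulation window.

    Hypothetical stand-in for FSDP2Trainer._should_sync_gradients; micro_step
    is assumed to be a zero-based global micro-batch counter.
    """
    if accum_steps < 1:
        raise ValueError("accum_steps must be >= 1")
    return (micro_step + 1) % accum_steps == 0


def sync_schedule(num_micro_steps: int, accum_steps: int) -> list[bool]:
    # One flag per micro-batch, as it would be passed to
    # set_requires_gradient_sync(): False skips the reduce-scatter,
    # True performs it.
    return [should_sync_gradients(i, accum_steps) for i in range(num_micro_steps)]


if __name__ == "__main__":
    for step, sync in enumerate(sync_schedule(8, accum_steps=4)):
        print(step, "reduce-scatter" if sync else "accumulate locally")
```

Only the final micro-batch in each window pays for the reduce-scatter. Earlier micro-batches accumulate gradients locally without communication. With `accum_steps=1`, every micro-batch syncs, which matches training without accumulation.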

__init__(*, args, fused_loss_factory=None, **kwargs)

Parameters:

Name | Type | Description | Default
args | FSDP2TrainingArguments or dict | Training configuration, including FSDP2-specific options under the fsdp2 field (reshard_after_forward, param_dtype, cpu_offload, etc.). Accepts a dict for programmatic construction. | required
fused_loss_factory | callable | Factory for fused logits-loss computation. See Trainer for details. | None
**kwargs | | Forwarded to Trainer.__init__: distributed_env, model, model_init, train_dataset, eval_dataset, optimizer_factory, lr_scheduler_factory, callbacks, etc. | {}
Source code in src/forgather/ml/trainer/fsdp2/fsdp2_trainer.py
def __init__(
    self,
    *,
    args: TFSDP2TrainingArguments | dict,
    fused_loss_factory: Optional[FusedLossFactoryT] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    args : FSDP2TrainingArguments or dict
        Training configuration including FSDP2-specific options under the ``fsdp2``
        field (``reshard_after_forward``, ``param_dtype``, ``cpu_offload``, etc.).
        Accepts a dict for programmatic construction.
    fused_loss_factory : callable, optional
        Factory for fused logits-loss computation. See ``Trainer`` for details.
    **kwargs
        Forwarded to ``Trainer.__init__``: ``distributed_env``, ``model``,
        ``model_init``, ``train_dataset``, ``eval_dataset``, ``optimizer_factory``,
        ``lr_scheduler_factory``, ``callbacks``, etc.
    """
    if isinstance(args, dict):
        args = cast(
            TFSDP2TrainingArguments, from_dict(FSDP2TrainingArguments, args)
        )

    super().__init__(args=args, fused_loss_factory=fused_loss_factory, **kwargs)

    assert (
        not self.args.fuse_optim_with_backward
    ), "FSDP2Trainer does not support option fuse_optim_with_backward"

pipeline_generate(input_ids, **kwargs)

All-rank generation: the FSDP2 forward pass needs every rank to participate in the all_gather, so generation must run collectively. The TextgenCallback detects this method and uses a coordinated broadcast-then-generate flow (the same path used for PipelineTrainer).


get_state_components()

State components for FSDP2.

The model is saved and loaded as HuggingFace safetensors through the model hooks wired in _init_checkpoint_manager; it is NOT registered as a StateComponent. Optimizer state stays sharded per rank because the DTensor layout of the optimizer moments cannot cheaply round-trip through a gather/broadcast. Scheduler, trainer progress, dataset, and RNG state are handled the same way as in DDPTrainer.
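
The multi-rank branch of get_state_components amounts to a fixed mapping from checkpoint key to sharing pattern. The sketch below restates that mapping with plain booleans so the resulting layout is visible at a glance. `SharingPattern` here is a local stand-in whose member values are invented. `fsdp2_state_layout` is a hypothetical helper and is not part of Forgather.

```python
from enum import Enum
from typing import Dict


class SharingPattern(Enum):
    # Local stand-in for Forgather's SharingPattern; values are invented.
    GLOBAL = "global"
    PER_RANK = "per_rank"
    REPLICATED = "replicated"


def fsdp2_state_layout(
    has_optimizer: bool,
    has_scheduler: bool,
    dataloader_is_stateful: bool,
    dispatch_batches: bool,
) -> Dict[str, SharingPattern]:
    """Mirror the key -> sharing-pattern choices made when world_size > 1."""
    layout: Dict[str, SharingPattern] = {}
    if has_optimizer:
        layout["optimizer"] = SharingPattern.PER_RANK  # sharded DTensor moments
    if has_scheduler:
        layout["scheduler"] = SharingPattern.REPLICATED
    layout["trainer"] = SharingPattern.REPLICATED
    if dataloader_is_stateful:
        # Rank 0 dispatches batches -> one global dataset state; otherwise per rank.
        layout["dataset"] = (
            SharingPattern.GLOBAL if dispatch_batches else SharingPattern.PER_RANK
        )
    layout["rng"] = SharingPattern.PER_RANK
    # No "model" key: weights go through the HF safetensors save/load hooks.
    return layout


if __name__ == "__main__":
    for key, pattern in fsdp2_state_layout(True, True, True, False).items():
        print(f"{key:>9}: {pattern.name}")
```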


forgather.ml.trainer.fsdp2.fsdp2_trainer.FSDP2Arguments dataclass

Source code in src/forgather/ml/trainer/fsdp2/fsdp2_trainer.py
@dataclass(kw_only=True)
class FSDP2Arguments:
    # Whether to re-shard parameters after forward. True behaves like ZeRO-3
    # (minimum memory); False behaves like ZeRO-2 (keeps params unsharded
    # between forward and backward, lower comm, higher memory). An int N
    # enables hybrid resharding across N ranks (ZeRO++ hpZ).
    reshard_after_forward: bool | int = True

    # MixedPrecisionPolicy dtypes. None disables FSDP-level mixed precision;
    # the trainer's existing AMP autocast still applies.
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None
    buffer_dtype: Optional[str] = None

    # Offload parameters (and gradients) to CPU between uses.
    cpu_offload: bool = False

    # Apply fully_shard layer-by-layer on the transformer blocks before the
    # root module. Disabling this falls back to a single root-level wrap,
    # which severely limits memory savings (no per-layer reshard).
    shard_transformer_layers: bool = True

    # Dotted attribute path on the unwrapped model that yields an iterable
    # of transformer blocks. The default matches Forgather's standard
    # causal-LM structure (an HF wrapper with `base_model_prefix="causal_lm"`
    # around a `CausalLM` whose `layer_stack.layers` is a `ModuleDict` /
    # `ModuleList`; see modelsrc/transformer/). Override for non-standard
    # models; set to "" or a path that does not resolve to disable
    # layer-wise sharding gracefully.
    transformer_layers_path: str = "causal_lm.layer_stack.layers"
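
The comment above says an empty or unresolvable transformer_layers_path should disable layer-wise sharding gracefully. The sketch below shows one way such a lookup could work. `resolve_layers` is hypothetical, since the trainer's actual lookup code is not part of this section, and the SimpleNamespace model only imitates the attribute layout. Because iterating an nn.ModuleDict yields its keys, the sketch takes `.values()` whenever they are available.

```python
from collections.abc import Iterable
from functools import reduce
from types import SimpleNamespace
from typing import Any, List, Optional


def resolve_layers(model: Any, path: str) -> Optional[List[Any]]:
    """Return the blocks at a dotted path, or None to skip per-layer sharding."""
    if not path:
        return None  # "" disables layer-wise sharding
    try:
        target = reduce(getattr, path.split("."), model)
    except AttributeError:
        return None  # path does not resolve: fall back to a root-level wrap
    # Iterating a ModuleDict (or dict) yields keys, so take its values instead.
    if hasattr(target, "values"):
        return list(target.values())
    if isinstance(target, Iterable):
        return list(target)  # e.g. a ModuleList
    return None


if __name__ == "__main__":
    model = SimpleNamespace(
        causal_lm=SimpleNamespace(
            layer_stack=SimpleNamespace(layers={"0": "block0", "1": "block1"})
        )
    )
    print(resolve_layers(model, "causal_lm.layer_stack.layers"))
    print(resolve_layers(model, "model.layers"))
```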

Pipeline Parallel Trainer

forgather.ml.trainer.pipeline.pipeline_trainer.PipelineTrainer

Bases: Trainer[TPipelineTrainingArguments], Generic[TPipelineTrainingArguments]

Trainer for pipeline parallel training using PyTorch distributed pipelining.

Partitions a model across multiple GPUs, with each GPU hosting one or more sequential pipeline stages. Each input batch is split into microbatches that flow through the stages, and several microbatches are in flight at once so that GPUs are not left idle while a single microbatch traverses the pipeline.
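
The number of microbatches matters because of the pipeline bubble. A common back-of-the-envelope estimate from the GPipe analysis puts the idle fraction of a GPipe step at (p - 1) / (m + p - 1) for p stages and m microbatches. The estimate assumes every stage takes the same time per microbatch and treats communication as free. The helper below is only an illustration of that formula. Real stages are rarely perfectly balanced, and zero-bubble schedules such as ScheduleZBVZeroBubble are designed to shrink this term.

```python
def gpipe_bubble_fraction(n_stages: int, n_microbatches: int) -> float:
    """Idle fraction of an idealised GPipe step (equal stage times, free comms)."""
    if n_stages < 1 or n_microbatches < 1:
        raise ValueError("n_stages and n_microbatches must be >= 1")
    # Each stage is busy for m of the (m + p - 1) time slots the step takes.
    return (n_stages - 1) / (n_microbatches + n_stages - 1)


if __name__ == "__main__":
    for m in (1, 4, 8, 32):
        print(f"4 stages, {m:>2} microbatches: bubble = {gpipe_bubble_fraction(4, m):.2f}")
```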

This trainer is designed for environments with limited inter-GPU bandwidth (consumer GPUs over PCIe, multi-node over Ethernet), where the collectives behind DDP (all-reduce) or FSDP (all-gather and reduce-scatter) would make training communication-bound.

Key differences from the single-device Trainer:

  • Model is constructed on the meta device, then each stage is materialised on its assigned GPU, so the full model never resides on a single GPU.
  • Rank 0 constructs a fully initialised CPU model and distributes parameters to other ranks via point-to-point sends, avoiding N redundant initialisations.
  • All ranks receive the same batch (pure model parallelism); rank 0 loads data via DataloaderDispatcher and broadcasts it.
  • Gradient norm is all-reduced across ranks because each rank holds only a subset of the model's parameters.
  • Effective batch size does not scale with num_processes: every rank works on the same batch as it flows through the stages, whereas under DDP each rank consumes a separate batch.
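
Each rank sees only the gradients of its own stages, so the global L2 norm has to be assembled from partial results. One standard approach is to all-reduce the squared local norms with SUM and then take the square root. The sketch below uses plain Python lists in place of tensors and a local sum in place of the collective. The function names are hypothetical, and the trainer's own reduction code is not shown here. Parameters shared between stages, such as tied embeddings, would have to be counted only once; the sketch ignores that case.

```python
import math
from typing import List, Sequence


def local_sq_norm(grads: Sequence[Sequence[float]]) -> float:
    # Sum of squared gradient entries over the parameters this rank owns.
    return sum(g * g for grad in grads for g in grad)


def global_grad_norm(per_rank_sq_norms: Sequence[float]) -> float:
    # Stand-in for all_reduce(sq_norm, op=SUM) followed by sqrt on every rank.
    return math.sqrt(sum(per_rank_sq_norms))


if __name__ == "__main__":
    stage0: List[List[float]] = [[3.0, 0.0]]    # gradients held by rank 0
    stage1: List[List[float]] = [[0.0], [4.0]]  # gradients held by rank 1
    partials = [local_sq_norm(stage0), local_sq_norm(stage1)]
    print(global_grad_norm(partials))
```

Reducing the per-rank norms themselves (3 + 4) would give 7 instead of the true norm of 5. That is why the squares are what get summed across ranks.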

Parameters:

Name | Type | Description | Default
args | PipelineTrainingArguments or dict | Pipeline training configuration. Dicts are converted via dacite.from_dict. | required
model_splitter | ModelSplitter | Callable that splits the model into pipeline stages and returns PipelineStage objects. See src/forgather/ml/trainer/pipeline/model_splitter.py for the expected signature. | required
pipe_schedule_factory | callable | Factory for the pipeline scheduler (e.g. ScheduleGPipe, ScheduleZBVZeroBubble). | ScheduleGPipe
**kwargs | | Additional arguments forwarded to the base Trainer (model_init, train_dataset, optimizer_factory, etc.). | {}

Raises:

Type | Description
AssertionError | If model is provided (pipeline training requires model_init).
AssertionError | If model_init is not provided.
AssertionError | If batch size is not divisible by n_microbatches.
AssertionError | If stages_per_rank > 1 but is_multistage=False.
AssertionError | If mixed_precision="fp16" (incompatible with pipeline schedulers).
AssertionError | If a zero-bubble schedule is used with torch_compile=True.
AssertionError | If world_size == 1 (pipeline parallelism requires multiple ranks).
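
The preconditions in the Raises table can be gathered into a single pre-flight check that reports every problem at once. The trainer itself stops at the first failing assertion. The sketch below uses a hypothetical flattened config: n_microbatches, stages_per_rank, is_multistage, mixed_precision, torch_compile and world_size come from the table, while the remaining field names are invented for illustration. "Batch size" is taken to mean per_device_train_batch_size, as in the example that follows, which is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PipelineConfigSketch:
    # Hypothetical flattened view of the settings the documented checks use.
    batch_size: int
    n_microbatches: int
    stages_per_rank: int = 1
    is_multistage: bool = False
    mixed_precision: Optional[str] = None
    zero_bubble_schedule: bool = False  # invented flag
    torch_compile: bool = False
    world_size: int = 2
    model_given: bool = False           # invented flag
    model_init_given: bool = True       # invented flag


def config_problems(cfg: PipelineConfigSketch) -> List[str]:
    """Collect every documented precondition that cfg violates."""
    problems: List[str] = []
    if cfg.model_given:
        problems.append("pass model_init, not model")
    if not cfg.model_init_given:
        problems.append("model_init is required")
    if cfg.batch_size % cfg.n_microbatches != 0:
        problems.append("batch size must be divisible by n_microbatches")
    if cfg.stages_per_rank > 1 and not cfg.is_multistage:
        problems.append("stages_per_rank > 1 needs a multi-stage schedule")
    if cfg.mixed_precision == "fp16":
        problems.append("fp16 is unsupported; use bf16")
    if cfg.zero_bubble_schedule and cfg.torch_compile:
        problems.append("zero-bubble schedules cannot be combined with torch_compile")
    if cfg.world_size == 1:
        problems.append("pipeline parallelism needs more than one rank")
    return problems


def microbatch_size(cfg: PipelineConfigSketch) -> int:
    # Meaningful only once config_problems(cfg) is empty.
    return cfg.batch_size // cfg.n_microbatches


if __name__ == "__main__":
    cfg = PipelineConfigSketch(batch_size=64, n_microbatches=8)
    print(config_problems(cfg), microbatch_size(cfg))
```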

Examples:

>>> from torch.distributed.pipelining import ScheduleGPipe
>>> args = PipelineTrainingArguments(
...     n_microbatches=8,
...     per_device_train_batch_size=64,
...     stages_per_rank=1,
... )
>>> trainer = PipelineTrainer(
...     args=args,
...     model_init=model_factory,
...     model_splitter=my_splitter_fn,
...     pipe_schedule_factory=ScheduleGPipe,
...     train_dataset=train_dataset,
...     optimizer_factory=optimizer_factory,
... )
>>> trainer.train()

See Also

ModelSplitter : Protocol for the model-splitting callable.

References

PyTorch pipeline parallelism: https://docs.pytorch.org/docs/stable/distributed.pipelining.html

Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
class PipelineTrainer(
    Trainer[TPipelineTrainingArguments], Generic[TPipelineTrainingArguments]
):
    """Trainer for pipeline parallel training using PyTorch distributed pipelining.

    Partitions a model across multiple GPUs — each GPU hosts one or more
    sequential pipeline stages. Input batches are split into microbatches that
    flow through the stages with multiple microbatches in flight simultaneously,
    keeping all GPUs busy.

    This trainer is designed for environments where inter-GPU bandwidth is limited
    (consumer GPUs over PCIe, multi-node over Ethernet) where all-reduce–based DDP
    or FSDP would be communication-bound.

    Key differences from the single-device ``Trainer``:

    * Model is constructed on the meta device, then each stage is materialised
      on its assigned GPU — no full model ever lives on one GPU.
    * Rank 0 constructs a fully-initialised CPU model and distributes parameters
      to other ranks via point-to-point sends, avoiding N redundant initialisations.
    * All ranks receive the same batch (pure model parallelism); rank 0 loads data
      via ``DataloaderDispatcher`` and broadcasts it.
    * Gradient norm is all-reduced across ranks because each rank holds only a
      subset of the model's parameters.
    * Effective batch size does **not** scale with ``num_processes`` (the same
      batch flows through all stages; unlike DDP, there is no data replication).

    Parameters
    ----------
    args : PipelineTrainingArguments or dict
        Pipeline training configuration. Dicts are converted via
        ``dacite.from_dict``.
    model_splitter : ModelSplitter
        Callable that splits the model into pipeline stages and returns
        ``PipelineStage`` objects. See
        ``src/forgather/ml/trainer/pipeline/model_splitter.py`` for the
        expected signature.
    pipe_schedule_factory : callable, optional
        Factory for the pipeline scheduler (e.g. ``ScheduleGPipe``,
        ``ScheduleZBVZeroBubble``). Default is ``ScheduleGPipe``.
    **kwargs
        Additional arguments forwarded to the base ``Trainer``
        (``model_init``, ``train_dataset``, ``optimizer_factory``, etc.).

    Raises
    ------
    AssertionError
        If ``model`` is provided (pipeline training requires ``model_init``).
    AssertionError
        If ``model_init`` is not provided.
    AssertionError
        If batch size is not divisible by ``n_microbatches``.
    AssertionError
        If ``stages_per_rank > 1`` but ``is_multistage=False``.
    AssertionError
        If ``mixed_precision="fp16"`` (incompatible with pipeline schedulers).
    AssertionError
        If a zero-bubble schedule is used with ``torch_compile=True``.
    AssertionError
        If ``world_size == 1`` (pipeline parallelism requires multiple ranks).

    Examples
    --------
    >>> from torch.distributed.pipelining import ScheduleGPipe
    >>> args = PipelineTrainingArguments(
    ...     n_microbatches=8,
    ...     per_device_train_batch_size=64,
    ...     stages_per_rank=1,
    ... )
    >>> trainer = PipelineTrainer(
    ...     args=args,
    ...     model_init=model_factory,
    ...     model_splitter=my_splitter_fn,
    ...     pipe_schedule_factory=ScheduleGPipe,
    ...     train_dataset=train_dataset,
    ...     optimizer_factory=optimizer_factory,
    ... )
    >>> trainer.train()

    See Also
    --------
    ModelSplitter : Protocol for the model-splitting callable.

    References
    ----------
    PyTorch pipeline parallelism:
    https://docs.pytorch.org/docs/stable/distributed.pipelining.html
    """

    args: TPipelineTrainingArguments
    model_splitter: ModelSplitter
    pipe_schedule_factory: PipelineSchedulerFactorT
    pp_group: Any
    n_pipeline_stages: int
    scheduler: PipelineSchedulerT | None
    pipeline_modules: List[Module] | None
    sharing_metadata: SharingMetadataT | None
    shard_index: ShardIndex | None
    stage_indices: Tuple[int, ...] | None
    pp_has_last_stage: bool
    pp_has_first_stage: bool
    attention_mask_creator: Callable

    def __init__(
        self,
        *,
        args: TPipelineTrainingArguments | dict,
        model_splitter: ModelSplitter,  # Required: function to split model into pipeline stages
        pipe_schedule_factory: PipelineSchedulerFactorT = ScheduleGPipe,  # type: ignore[assignment]
        **kwargs,
    ):
        """Initialise the pipeline parallel trainer.

        Parameters
        ----------
        args : PipelineTrainingArguments or dict
            Pipeline training configuration. Dicts are converted via
            ``dacite.from_dict(PipelineTrainingArguments, args)``.
        model_splitter : ModelSplitter
            Callable that accepts the model on the meta device and returns
            all pipeline stage modules, the rank-local stage modules, and
            ``PipelineStage`` objects. See
            ``src/forgather/ml/trainer/pipeline/model_splitter.py`` for the
            full signature.
        pipe_schedule_factory : callable, optional
            Pipeline scheduler factory. ``ScheduleGPipe`` (default) uses simple
            GPipe scheduling. Pass ``ScheduleZBVZeroBubble`` or similar for
            zero-bubble schedules.
        **kwargs
            Forwarded to the base ``Trainer`` constructor (``model_init``,
            ``train_dataset``, ``optimizer_factory``, ``callbacks``, etc.).
        """
        if isinstance(args, dict):
            args = cast(
                TPipelineTrainingArguments, from_dict(PipelineTrainingArguments, args)
            )
        super().__init__(args=args, **kwargs)
        self.model_splitter = model_splitter
        self.pipe_schedule_factory = pipe_schedule_factory

        # Zero-bubble schedules split backward into input-grad and weight-grad
        # steps and call torch.autograd.grad(..., retain_graph=True). That
        # conflicts with donated buffers in torch.compile'd backward
        # (e.g. flex_attention). Disable the optimization before any compiled
        # backward has been captured.
        if is_zero_bubble_schedule(pipe_schedule_factory):
            disable_compiled_backward_donated_buffers()

            # torch.compile applied at stage granularity (see _compile_model)
            # collapses the stage interior into a single Python autograd.Function
            # (CompiledFunctionBackward). Zero-bubble's stage_backward_weight
            # then constructs GradientEdge(intermediate, 0) over those nodes and
            # passes them to torch.autograd.grad, which calls _make_grads ->
            # out.node._input_metadata. That attribute is unavailable on Python
            # autograd.Function nodes and the C++ binding raises:
            #   "Attribute '_input_metadata' is invalid for this instance of
            #    _C._FunctionBase. ... legacy access pattern that is no longer
            #    supported."
            # The crash is structural: AOTAutograd flattens the stage interior
            # so the I/W split has nothing to walk between intermediates and
            # parameters. Refuse the combination at init time with a clear
            # diagnostic instead of letting the user hit it mid-training.
            assert not self.args.torch_compile, (
                "PipelineTrainer does not support torch.compile with zero-bubble "
                "schedules (ScheduleZBVZeroBubble, ScheduleInterleavedZeroBubble). "
                "AOTAutograd wraps each compiled stage in a Python autograd.Function "
                "whose internal nodes do not expose _input_metadata, which the "
                "split-backward weight step requires. Use a non-zero-bubble schedule "
                "(e.g. ScheduleInterleaved1F1B) or disable torch_compile."
            )

        assert self.args.mixed_precision != "fp16", (
            "PipelineTrainer does not support fp16 mixed precision (GradScaler is incompatible "
            "with the pipeline scheduler's internal backward). Use mixed_precision='bf16' instead."
        )

        if self.args.debug_pipeline:
            logger.setLevel(logging.DEBUG)

        assert (
            self.model is None
        ), "PipelineTrainer only supports model_init=fn, where fn is a zero-argument callable returning a model"
        assert self.model_init, "Pipeline trainer requires a model_init function"

        for batch_size in (
            self.args.per_device_train_batch_size,
            self.args.per_device_eval_batch_size,
        ):
            assert (
                batch_size % self.args.n_microbatches == 0
            ), f"Batch size ({batch_size}) must be evenly divisible by n_microbatches ({self.args.n_microbatches})"
        assert (
            self.args.is_multistage or self.args.stages_per_rank == 1
        ), "Only multistage schedulers support stages_per_rank > 1"

        # The pipeline requires a fixed shape for the inputs
        self.args.dataloader_drop_last = True

    @override
    def _is_pipeline_parallel(self) -> bool:
        """Indicate that this trainer uses pipeline parallelism.

        Pipeline parallelism does **not** increase the effective batch size
        (unlike DDP): all microbatches belong to the same original batch and
        flow through the stages sequentially. The base trainer uses this flag
        to skip the DDP-style effective-batch-size scaling.

        Returns
        -------
        bool
            Always ``True``.
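
        Examples
        --------
        A minimal sketch of the distinction, assuming the usual definition of
        effective batch size (samples consumed per optimizer step). The helper
        ``effective_batch_size`` is hypothetical and does not reproduce the
        base trainer's actual code:

        ```python
        def effective_batch_size(
            per_device_batch_size: int,
            gradient_accumulation_steps: int,
            world_size: int,
            pipeline_parallel: bool,
        ) -> int:
            # Samples consumed per optimizer step on one "replica".
            per_step = per_device_batch_size * gradient_accumulation_steps
            if pipeline_parallel:
                # Every rank works on the same batch (split into microbatches),
                # so adding ranks adds stages, not samples.
                return per_step
            # Data parallel: each rank consumes its own batch.
            return per_step * world_size


        print(effective_batch_size(8, 2, 4, pipeline_parallel=False))  # 64 (DDP)
        print(effective_batch_size(8, 2, 4, pipeline_parallel=True))   # 16 (pipeline)
        ```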
        """
        return True

    @override
    def _init_distributed(self):
        self.is_local_process_zero = self.dist.local_rank == 0
        self.is_world_process_zero = self.dist.rank == 0
        self.num_processes = self.dist.world_size

        # Pipeline parallelism is meaningless with a single rank, and downstream
        # code (DataloaderDispatcher pure-MP, pipeline scheduler) will fail in
        # confusing ways if we let this through. Catch it here with a clear
        # diagnostic.
        assert self.dist.world_size > 1, (
            f"PipelineTrainer requires world_size > 1, got {self.dist.world_size}. "
            "Launch with torchrun --nproc-per-node N (N > 1), or verify that any "
            "dynamic-arg override of nproc_per_node is being honored by the "
            "`forgather train` command."
        )

        # Calculate total number of pipeline stages
        self.n_pipeline_stages = self.args.stages_per_rank * self.dist.world_size

        # Create device mesh for pipeline parallel (pure MP - all ranks get same batch)
        # This mesh is used for batch distribution via DataloaderDispatcher
        self.mesh = init_device_mesh(
            self.dist.device_type,
            (self.dist.world_size,),
            mesh_dim_names=("pipeline_parallel",),
        )

        # Create pipeline parallel process group
        # For now, includes all ranks, but this allows future support for
        # hybrid parallelism where PP is a subset of ranks
        self.pp_group = self.mesh.get_group(0)

    def _print_modules(self, modules):
        if self.args.debug_model_params:
            for mod in modules:
                for name, p in mod.named_parameters(remove_duplicate=False):
                    logger.debug(
                        f"P {self.dist.rank} {name} : device {p.device}, dtype {p.dtype}"
                    )
                for name, p in mod.named_buffers(remove_duplicate=False):
                    logger.debug(
                        f"B {self.dist.rank} {name} : device {p.device}, dtype {p.dtype}"
                    )

    @override
    def _wrap(self) -> None:
        """Wrap dataloaders with ``DataloaderDispatcher`` for pipeline batch distribution.

        Pipeline parallelism requires all ranks to process the same batch (pure
        model-parallel mode). Rank 0 loads data from the underlying
        ``DataLoader`` and broadcasts it to all other ranks.

        ``DataloaderDispatcher`` is created with ``dp_mesh_dim=None`` to
        signal pure model-parallelism — no data-parallel dimension exists, so
        rank 0 broadcasts the full batch to every participant.
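
        Examples
        --------
        The toy function below contrasts the two batch assignments. It is a
        hypothetical, pure-Python illustration (it does not use
        ``DataloaderDispatcher`` or ``torch.distributed``). Its data-parallel
        branch uses a simple round-robin scheme chosen only for contrast, not
        necessarily what ``DataloaderDispatcher`` does when a data-parallel
        dimension is configured:

        ```python
        def assign_batches(batches: list[list[int]], world_size: int, data_parallel: bool):
            # Returns, for each rank, the list of batches that rank processes.
            per_rank = [[] for _ in range(world_size)]
            for step, batch in enumerate(batches):
                if data_parallel:
                    # Round-robin data parallel: consecutive batches go to different ranks.
                    per_rank[step % world_size].append(batch)
                else:
                    # Pure model parallel: rank 0 loads the batch and broadcasts it,
                    # so every rank holds an identical copy.
                    for rank in range(world_size):
                        per_rank[rank].append(list(batch))
            return per_rank


        batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
        print(assign_batches(batches, 2, data_parallel=False))
        print(assign_batches(batches, 2, data_parallel=True))
        ```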
        """
        if self.train_dataloader:
            self.train_dataloader = DataloaderDispatcher(
                cast(DataLoader, self.train_dataloader),
                self.mesh,
                torch.device(self.dist.device),
                dp_mesh_dim=None,  # Pure MP: all ranks get same batch
            )

        if self.eval_dataloader:
            self.eval_dataloader = DataloaderDispatcher(
                cast(DataLoader, self.eval_dataloader),
                self.mesh,
                torch.device(self.dist.device),
                dp_mesh_dim=None,  # Pure MP: all ranks get same batch
            )

    @override
    def _prepare_model(self):
        """Construct and distribute the model across pipeline stages.

        This is the central setup method for pipeline parallel training. It
        performs the following steps in order:

        1. Construct the full model on the meta device (no memory allocation).
        2. Capture parameter-sharing metadata (tied weights, etc.).
        3. Split the model into pipeline stages via ``model_splitter``.
        4. Materialise each stage's parameters on its assigned device.
        5. Initialise parameters: when not resuming from a checkpoint, rank 0
           builds a full CPU model and sends each rank's stage parameters via
           point-to-point sends; when resuming, only non-persistent buffers
           are initialised here, locally on each rank.
        6. Apply FP8 training to each stage when ``fp8_recipe`` is set.
        7. Build the shard index for distributed checkpoint save/load.
        8. Set up the loss function (only the last stage computes loss) and
           rescale it by ``1 / n_microbatches`` and
           ``1 / gradient_accumulation_steps``.
        9. Construct the pipeline scheduler with the configured microbatch count.
        10. Enable gradient checkpointing on each stage when requested.

        After this method returns, ``self.pipeline_modules`` contains the
        materialised stage modules for this rank, and ``self.model`` holds the
        original meta-device model for shape/config queries.
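
        Examples
        --------
        Step 3 depends on ``pipeline_stage_indices``, whose implementation is
        not shown here. The sketch below is a hypothetical stand-in,
        ``stage_indices_sketch``, modelled on the ``"loop"`` and ``"v"``
        assignment styles used by TorchTitan's pipeline utilities (an
        assumption about what ``pipeline_stage_indices`` computes). It also
        derives the lookup tables this method builds afterwards
        (``_stage_to_rank``, ``_stage_to_local_mod`` and ``pp_last_stage_rank``):

        ```python
        def stage_indices_sketch(world_size: int, n_stages: int, style: str = "loop"):
            # Hypothetical stand-in for pipeline_stage_indices(): one tuple of
            # global stage indices per rank.
            assert n_stages % world_size == 0
            stages_per_rank = n_stages // world_size
            if style == "loop":
                # Round-robin: rank r owns stages r, r + W, r + 2W, ...
                return [
                    tuple(rank + k * world_size for k in range(stages_per_rank))
                    for rank in range(world_size)
                ]
            if style == "v":
                # V-shape: rank r owns stage r on the way down and its mirror
                # n_stages - 1 - r on the way back up (two stages per rank).
                assert stages_per_rank == 2
                return [(rank, n_stages - 1 - rank) for rank in range(world_size)]
            raise ValueError(f"unknown style {style!r}")


        stage_indices = stage_indices_sketch(world_size=2, n_stages=4, style="v")
        last_stage = 4 - 1

        # Global stage -> owning rank (mirrors _stage_to_rank).
        stage_to_rank = {s: r for r, stages in enumerate(stage_indices) for s in stages}
        # The rank that owns the last stage computes the loss (pp_last_stage_rank).
        last_stage_rank = stage_to_rank[last_stage]
        # For rank 1: global stage -> index into its local pipeline_modules list.
        stage_to_local_mod = {s: i for i, s in enumerate(stage_indices[1])}

        print(stage_indices)       # [(0, 3), (1, 2)]
        print(last_stage_rank)     # 0
        print(stage_to_local_mod)  # {1: 0, 2: 1}
        ```

        With the V-style assignment the last stage lands on rank 0, so in that
        configuration rank 0 both feeds the first stage and computes the loss.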
        """
        # Reset -- this trainer always resets everything.
        self.scheduler = None
        self.model = None
        self.pipeline_modules = None
        self.optimizer = None
        self.lr_scheduler = None
        self.sharing_metadata = None

        assert self.train_dataloader or self.eval_dataloader

        # Construct model instance on the "meta" device; parameters have meta-data, but no actual data.
        # This allows us to construct a "huge" model, without having to have the memory for it.
        model = self._construct_model(device="meta")
        if self.dist.rank == 0:
            self._print_modules([model])

        # Get parameter sharing metadata
        self.sharing_metadata = create_sharing_metadata(model)

        # Get a micro-batch from the train dataloader (falling back to the eval dataloader) to use for tracing.
        dataloader = (
            self.train_dataloader if self.train_dataloader else self.eval_dataloader
        )
        example_args, example_kwargs = self._get_example(dataloader)

        # stage_indices : A List[Tuple[int]] with the assigned stage indices for each rank
        #   e.g. stage_indices[rank] would have the stage indices for "rank"
        stage_indices = pipeline_stage_indices(
            self.dist.world_size, self.n_pipeline_stages, style=self.args.pp_stage_type
        )

        last_stage_index = self.n_pipeline_stages - 1
        self.stage_indices = stage_indices[self.dist.rank]
        self.pp_has_last_stage = last_stage_index in self.stage_indices
        self.pp_has_first_stage = 0 in self.stage_indices

        # Split model into pipeline segments.
        if self.dist.rank == 0:
            logger.debug(f"All assigned pipeline indices {stage_indices}")
            logger.info("Splitting model...")
        all_pipeline_modules, pipeline_modules, pipeline_stages = self._split_model(
            model, example_args, example_kwargs, stage_indices, train=True
        )
        # all_pipeline_modules : A list of all modules in the pipeline
        # pipeline_modules : A list of modules assigned to this rank
        # pipeline_stages : A list of pipeline stages assigned to this rank

        # Convert meta tensors to real tensors on the assigned device, then re-tie shared parameters.
        for mod in pipeline_modules:
            mod.to_empty(device=self.dist.device)
            retie_parameters(mod, self.sharing_metadata)

        # Load from checkpoint?
        if self.args.resume_from_checkpoint:
            missing_buffer_set = missing_buffers(model)
            if len(missing_buffer_set):
                # Non-persistent buffers are not saved in checkpoints. Initialize
                # them locally on each rank via reset_parameters(). This is safe
                # because buffer computation is deterministic and local to each
                # module (e.g., RotaryEmbedding computes inv_freq from its own
                # rope_theta and d_head -- no cross-module dependencies).
                if self.dist.rank == 0:
                    logger.info(
                        f"Initializing non-persistent buffers locally on each rank: "
                        f"{missing_buffer_set}"
                    )
                for mod in pipeline_modules:
                    self._initialize_non_persistent_buffers(mod)
        else:
            if self.dist.rank == 0:
                # If this results in OOM (really large model), you will have to initialize the model from a checkpoint
                # which will likely entail some amount of work.
                logger.info(
                    "Constructing full model on CPU and distributing initialized parameters from rank0."
                )
            self._initialize_params(
                all_pipeline_modules, pipeline_modules, stage_indices, False
            )

        self._print_modules(pipeline_modules)

        if self.args.fp8_recipe:
            for i, mod in enumerate(pipeline_modules):
                pipeline_modules[i] = self._apply_fp8_training(mod)

        # Construct the pipeline scheduler.
        # Depending on the class, it takes either a single stage (PipelineScheduleSingle) or a list of stages
        # (PipelineScheduleMulti). See: https://docs.pytorch.org/docs/stable/distributed.pipelining.html#torch.distributed.pipelining.schedules.PipelineScheduleSingle
        if self.args.is_multistage:
            stages_arg = pipeline_stages
        else:
            assert len(pipeline_stages) == 1
            stages_arg = pipeline_stages[0]

        # Make the shard index, which we will need for saving the distributed model.
        # First, assert no duplicate FQNs exist across pipeline modules
        state_dicts = [mod.state_dict() for mod in all_pipeline_modules]
        assert_no_duplicate_fqns(state_dicts)

        self.shard_index = make_shard_index(
            state_dicts,
            safetensors=self.args.save_safetensors,
            param_sharing_metadata=self.sharing_metadata,
        )

        # Only the last stage needs to compute loss
        # TODO: Placing this here is not going to work for the auto-splitter. That will require more work...
        # Zero-bubble schedules are incompatible with Python autograd.Function-based
        # fused loss kernels (e.g. Apple CCE's LinearCrossEntropyFunction). Zero-bubble's
        # stage_backward_weight constructs GradientEdge over intermediate nodes and
        # calls torch.autograd.grad -> _make_grads -> out.node._input_metadata, which
        # is not exposed on Python autograd.Function nodes and raises the same legacy
        # accessor error documented for torch.compile above. Disable fused loss in this
        # case so the pipeline falls back to the standard (unfused) loss path.
        if self.pp_has_last_stage and not is_zero_bubble_schedule(
            self.pipe_schedule_factory
        ):
            self.loss_fn = self._maybe_get_fused_loss_fn(
                pipeline_modules[-1], self.loss_fn
            )
        elif self.pp_has_last_stage and self.fused_loss_factory is not None:
            logger.warning(
                "Zero-bubble pipeline schedules are incompatible with Python "
                "autograd.Function-based fused loss kernels. Falling back to the "
                "standard (unfused) loss path."
            )

        # Loss needs to be scaled by the number of micro-batches
        self.loss_fn = RescaleLoss(self.loss_fn, 1 / self.args.n_microbatches)

        # Outer wrapper: scale by 1/gradient_accumulation_steps. Only this outer wrapper is disabled for eval.
        self.loss_fn = RescaleLoss(
            self.loss_fn, 1 / self.args.gradient_accumulation_steps
        )

        # Note: scale_grads=True (default) rescales gradients in-place during each microbatch step.
        # This breaks gradient accumulation because it rescales the cumulative gradient repeatedly.
        # Instead, we manually scale the loss by 1/n_microbatches above, achieving correct scaling
        # without interfering with gradient accumulation. We therefore pass scale_grads=False to disable
        # the scheduler's built-in scaling, which is broken when gradients are accumulated.
        # This mirrors the fix applied to TorchTitan: https://github.com/pytorch/torchtitan/pull/XXX
        self.scheduler = self.pipe_schedule_factory(
            stages_arg,
            self.args.n_microbatches,
            loss_fn=self.loss_fn,  # type: ignore[call-arg]
            scale_grads=False,
        )

        if self.args.gradient_checkpointing:
            if self.enable_activation_checkpoint_fn is None:
                if self.dist.rank == 0:
                    logger.warning(
                        "Activation checkpointing requested, but no function defined!"
            else:
                # Enable activation checkpointing for all modules in the pipeline.
                for mod in pipeline_modules:
                    self.enable_activation_checkpoint_fn(self.dist.rank, mod)

        self.pipeline_modules = pipeline_modules

        for rank in range(len(stage_indices)):
            if last_stage_index in stage_indices[rank]:
                self.pp_last_stage_rank = rank
                break

        # Map every global stage index to the rank that owns it.
        self._stage_to_rank: dict[int, int] = {
            stage: rank_idx
            for rank_idx, stages in enumerate(stage_indices)
            for stage in stages
        }

        # Map global stage index → local pipeline_modules index (for this rank only).
        # pipeline_modules[j] corresponds to stage_indices[rank][j].
        self._stage_to_local_mod: dict[int, int] = {
            stage: idx for idx, stage in enumerate(self.stage_indices)
        }

        # Dtype for activations exchanged between ranks during generation.
        if self.args.mixed_precision == "bf16":
            self._generation_dtype = torch.bfloat16
        elif self.args.mixed_precision == "fp16":
            self._generation_dtype = torch.float16
        else:
            param = next((p for m in pipeline_modules for p in m.parameters()), None)
            self._generation_dtype = param.dtype if param is not None else torch.float32

        # We keep the original model on the meta-device. This model is obviously not functional, but
        # some trainer callbacks may wish to dump the layout.
        self.model = model

        # Compute FLOPs per token using the meta-device model.
        # p.numel() works correctly for meta tensors (shape is defined, data is not).
        self._flops_per_token = self._compute_flops_per_token()
        if self.dist.rank == 0:
            logger.info(f"Estimated FLOPs per token: {self._flops_per_token:.2e}")

    def _construct_model(self, device):
        # Construct model on device
        assert self.model_init
        with ExitStack() as exit_stack:
            exit_stack.enter_context(torch.device(device))
            if self.args.default_dtype:
                exit_stack.enter_context(
                    default_dtype(torch_dtype(self.args.default_dtype))
                )
            model = self.model_init()
        return model

    @override
    def _compile_model(self):
        """Compile each pipeline stage module assigned to this rank with ``torch.compile``.

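        Unlike single-device training (where the entire model is compiled as one
        unit), each stage module held by this rank is compiled independently.
        This is necessary because the full model exists only on the meta device:
        each rank holds only its own stage modules, and stages are generally
        placed on different GPUs.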
        """
        assert self.pipeline_modules
        for mod in self.pipeline_modules:
            mod.compile(
                backend=self.args.torch_compile_backend,
                mode=self.args.torch_compile_mode,
                dynamic=self.args.torch_compile_dynamic,
                fullgraph=self.args.torch_compile_full_graph,
            )

    def _get_example(self, example_dataloader):
        """Build an example microbatch for pipeline stage tracing.

        Pipeline parallel requires all batches to have identical shapes. This
        method draws one real batch from the dataloader, creates a meta-device
        tensor with the same shape as ``"input_ids"``, splits it into
        ``n_microbatches`` chunks, and returns the first chunk for use during
        model splitting (tracing).

        Parameters
        ----------
        example_dataloader : iterable
            Dataloader from which to extract the batch shape.

        Returns
        -------
        tuple of (tuple, dict)
            ``(example_args, example_kwargs)`` representing a single microbatch
            on the meta device, suitable for passing to ``model_splitter``.

        Notes
        -----
        Currently hardcoded to use ``"input_ids"`` as the primary input tensor.
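
        Examples
        --------
        For an ``input_ids`` batch of shape ``[batch, seq]``, the traced example
        is the first of ``n_microbatches`` equal chunks along the batch
        dimension. A pure-Python sketch of that shape arithmetic follows; the
        helper ``example_microbatch_shape`` is hypothetical and assumes the
        even split that ``__init__`` enforces:

        ```python
        def example_microbatch_shape(batch_shape: tuple[int, ...], n_microbatches: int):
            # Chunking happens along dim 0 (the batch dimension); the sequence
            # length and any trailing dims are unchanged.
            batch, *rest = batch_shape
            if batch % n_microbatches != 0:
                raise ValueError("batch size must be divisible by n_microbatches")
            return (batch // n_microbatches, *rest)


        print(example_microbatch_shape((8, 512), 4))  # (2, 512)
        ```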
        """
        # Note that pipeline parallel requires all batches to have the same shape!
        # TODO: We have hard-coded "input_ids" This should be more flexible, as this is not always the case.
        example_batch = next(iter(example_dataloader))
        example_args = (torch.empty_like(example_batch["input_ids"], device="meta"),)
        example_kwargs: dict = {}

        # Split into microbatches
        split_args, split_kwargs = split_args_kwargs_into_chunks(
            example_args, example_kwargs, chunks=self.args.n_microbatches
        )

        # Return example micro-batches
        return split_args[0], split_kwargs[0]

    def _split_model(
        self, model, example_args, example_kwargs, stage_indices, train
    ) -> Tuple[List[Module], List[Module], List[_PipelineStageBase]]:
        """
        Split model into pipeline stages using the injected splitter.

        Delegates to self.model_splitter and captures the attention_mask_creator
        for later use in forward/backward passes.

        Returns modules on meta device - caller must materialize them.
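
        Examples
        --------
        The splitter is called with the model, the example microbatch, the
        per-rank stage assignment and the ``train`` flag, plus the keyword
        arguments ``device``, ``rank`` and ``pp_group``, and must return a
        4-tuple. The toy splitter below is hypothetical: it treats the model
        as a list of layer names, uses plain lists and strings in place of
        ``nn.Module`` and ``PipelineStage`` objects, and ignores the example
        inputs and process group. It only illustrates the shape of the contract:

        ```python
        def toy_splitter(model, example_args, example_kwargs, stage_indices, train,
                         *, device, rank, pp_group):
            # Hypothetical: "model" is a list of layer names; a stage is a contiguous slice.
            n_stages = sum(len(s) for s in stage_indices)
            per_stage = len(model) // n_stages
            all_modules = [
                model[i * per_stage:(i + 1) * per_stage] for i in range(n_stages)
            ]
            # Modules and stages owned by this rank, in stage_indices[rank] order.
            local = [all_modules[s] for s in stage_indices[rank]]
            stages = [f"stage{s}@{device}" for s in stage_indices[rank]]
            attention_mask_creator = None  # this toy does not support external masks
            return all_modules, local, stages, attention_mask_creator


        layers = [f"layer{i}" for i in range(8)]
        all_mods, mods, stages, mask_fn = toy_splitter(
            layers, (), {}, [(0, 2), (1, 3)], True, device="cuda:1", rank=1, pp_group=None
        )
        print(mods)    # [['layer2', 'layer3'], ['layer6', 'layer7']]
        print(stages)  # ['stage1@cuda:1', 'stage3@cuda:1']
        ```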
        """
        rank = self.dist.rank

        # Call the injected splitter
        (
            all_pipeline_modules,
            pipeline_modules,
            pipeline_stages,
            attention_mask_creator,
        ) = self.model_splitter(
            model,
            example_args,
            example_kwargs,
            stage_indices,
            train,
            device=self.dist.device,  # type: ignore[call-arg]
            rank=rank,
            pp_group=self.pp_group,
        )

        # Store attention mask creator for use in forward/backward steps
        # Will be None if splitter doesn't support external masks
        self.attention_mask_creator = attention_mask_creator

        if rank == 0 and self.args.debug_split_model:
            logger.debug("Pipeline modules created:")
            for i, mod in enumerate(all_pipeline_modules):
                logger.debug(f"  Stage {i}: {mod}")

        return all_pipeline_modules, pipeline_modules, list(pipeline_stages)

    @torch.no_grad()
    def _initialize_params(
        self, all_pipeline_modules, pipeline_modules, stage_indices, missing_buf_only
    ):
        """Distribute initialised parameters from rank 0 to all other ranks.

        Initialising each rank independently would require building N full models
        in CPU memory (one per rank). Instead, rank 0 constructs a single
        fully-initialised CPU model, copies its own stage parameters locally,
        and sends each other rank's parameters via NCCL point-to-point.

        Process
        -------
        1. Rank 0 constructs the full initialised model on CPU.
        2. Rank 0 copies parameters for its own stages directly to its GPU.
        3. Rank 0 streams each other rank's stage parameters to that rank's GPU
           via ``dist.send`` / ``dist.recv``.
        4. Non-rank-0 ranks call ``dist.recv`` to receive their stage parameters.

        Parameters
        ----------
        all_pipeline_modules : list of torch.nn.Module
            All pipeline stage modules across all ranks (on meta device).
        pipeline_modules : list of torch.nn.Module
            Stage modules assigned to the current rank (on this rank's device
            after ``to_empty``).
        stage_indices : list of tuple of int
            Stage index assignments per rank; ``stage_indices[r]`` is the tuple
            of global stage indices assigned to rank ``r``.
        missing_buf_only : bool
            When ``True``, transfer only non-persistent buffers (used when
            resuming from a checkpoint that omitted those buffers). When
            ``False``, transfer all parameters and buffers.

        Notes
        -----
        Each parameter is temporarily moved to the sender's GPU before calling
        ``dist.send``, because NCCL requires device tensors for transfers.
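
        Examples
        --------
        Point-to-point sends and receives between a pair of ranks are matched
        in the order they are issued, so rank 0 must walk each destination's
        stages and tensor names in exactly the order the receiving rank walks
        its own modules. The pure-Python simulation below is hypothetical: a
        per-rank queue stands in for ``dist.send`` / ``dist.recv``, and a dict
        of stage to tensor names stands in for ``make_state_dict``:

        ```python
        from collections import deque


        def simulate_param_distribution(stage_names, stage_indices, init_values):
            # stage_names: global stage -> tensor names in make_state_dict order.
            # stage_indices: rank -> tuple of global stages owned by that rank.
            # init_values: tensor name -> value held by rank 0's CPU model.
            world_size = len(stage_indices)
            channels = {dst: deque() for dst in range(1, world_size)}
            received = {rank: {} for rank in range(world_size)}

            # Rank 0: copy its own stages locally ...
            for stage in stage_indices[0]:
                for name in stage_names[stage]:
                    received[0][name] = init_values[name]
            # ... then "send" every other rank's tensors, in that rank's order.
            for dst in range(1, world_size):
                for stage in stage_indices[dst]:
                    for name in stage_names[stage]:
                        channels[dst].append(init_values[name])

            # Other ranks: "receive" into their own tensors, in the same order.
            for rank in range(1, world_size):
                for stage in stage_indices[rank]:
                    for name in stage_names[stage]:
                        received[rank][name] = channels[rank].popleft()
                assert not channels[rank], "sender and receiver disagree on tensor count"
            return received


        names = {0: ["embed.weight"], 1: ["layers.0.w"], 2: ["layers.1.w"], 3: ["head.weight"]}
        values = {"embed.weight": 1.0, "layers.0.w": 2.0, "layers.1.w": 3.0, "head.weight": 4.0}
        result = simulate_param_distribution(names, [(0, 2), (1, 3)], values)
        print(result[1])  # {'layers.0.w': 2.0, 'head.weight': 4.0}
        ```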
        """

        def make_state_dict(mod, missing_buf_only):
            """
            Build a state dictionary with /all/ the params/buffers,
            as non-persistent buffers are normally excluded from state_dict()

            missing_buf_only: When True, only include the buffers which are missing from
            the "actual" state_dict.
            """
            output_state_dict = {}
            if missing_buf_only:
                state_dict = mod.state_dict()
                for name, p in mod.named_buffers():
                    if name not in state_dict:
                        output_state_dict[name] = p.data
            else:
                # Include parameter alias names for shared parameters
                for name, p in mod.named_parameters(remove_duplicate=False):
                    output_state_dict[name] = p.data
                for name, p in mod.named_buffers(remove_duplicate=False):
                    output_state_dict[name] = p.data
            return output_state_dict

        if self.dist.rank == 0:
            # Construct a fully initialized model on the CPU, which we will use to distribute
            # initialized parameters.
            logger.debug("Constructing model on CPU")
            initialized_model = self._construct_model(device="cpu")
            init_state_dict = make_state_dict(initialized_model, missing_buf_only)
            # Initialize our own parameters first
            for mod in pipeline_modules:
                for name, p in make_state_dict(mod, missing_buf_only).items():
                    p.copy_(init_state_dict[name])

            logger.debug("Distributing params")
            # Send the initialized parameters for the other stages to their
            # respective processes.
            for dst_rank in range(1, self.dist.world_size):
                # Modules owned by dst_rank
                rank_indices = stage_indices[dst_rank]
                logger.debug(
                    f"rank0: Sending initialized params for stages {rank_indices} to rank{dst_rank}"
                )

                for stage_index in rank_indices:
                    mod = all_pipeline_modules[stage_index]

                    # All params and buffers in destination module
                    for name, _ in make_state_dict(mod, missing_buf_only).items():
                        # NCCL only transfers device (GPU) tensors, so copy each CPU parameter to
                        # our GPU, send it, then free it. Kind of hack'ish, but it works.
                        # See: https://docs.pytorch.org/docs/stable/distributed.html
                        # I believe this will also work CPU to CPU with gloo, but have yet
                        # to try it.
                        p = init_state_dict[name].to(self.dist.device)
                        if self.args.debug_model_init:
                            logger.debug(f"rank0: Sending {name} to rank{dst_rank}")
                        distributed.send(p, dst=dst_rank)
                        p = None
        else:
            # Load the parameters from rank 0
            rank_indices = stage_indices[self.dist.rank]

            logger.debug(
                f"rank{self.dist.rank}: Receiving initialized params for stages {rank_indices} from rank0"
            )
            for mod in pipeline_modules:
                for name, p in make_state_dict(mod, missing_buf_only).items():
                    if self.args.debug_model_init:
                        logger.debug(f"rank{self.dist.rank}: Receiving {name}")
                    distributed.recv(p, src=0)

    @override
    def _forward_backward_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """Execute a combined forward and backward pass via the pipeline scheduler.

        The scheduler handles activation forwarding between stages and gradient
        backpropagation. Each rank's role depends on which stage(s) it hosts:

        * **First stage** — consumes ``input_ids`` and sends activations downstream.
        * **Middle stages** — receive activations, compute, and send downstream.
        * **Last stage** — receives activations, computes the loss, and initiates
          the backward pass.

        Attention masks and ``position_ids`` are constructed outside the pipeline
        and passed as kwargs rather than through the inter-stage activation stream.
        This is necessary because PyTorch's pipeline transport only handles tensors
        that require gradients; Python objects and non-differentiable tensors would
        raise errors if piped between stages.

        Parameters
        ----------
        input_dict : dict of str to Tensor
            Batch inputs. Must contain ``"input_ids"``; may contain
            ``"position_ids"`` when the splitter supports explicit positions.
        labels : Tensor
            Target token ids for loss computation (only consumed by the last stage).

        Returns
        -------
        Tensor
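            Sum of the per-microbatch losses on the last stage; ``0.0`` on all
            other stages. Each microbatch loss has already been multiplied by
            ``1 / n_microbatches`` and ``1 / gradient_accumulation_steps`` by the
            ``RescaleLoss`` wrappers installed in ``_prepare_model``, so the sum is
            the mean microbatch loss divided by ``gradient_accumulation_steps``.
            The caller broadcasts this to all ranks via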
            ``_distributed_loss()``.
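
        Examples
        --------
        The scaling arithmetic can be checked with scalar stand-ins. In the
        hypothetical, pure-Python sketch below, each microbatch loss is a
        number, its "gradient" is modelled as equal to its (scaled) value, and
        gradients accumulate across ``gradient_accumulation_steps`` calls to
        ``step()``. Scaling the loss by both factors yields the true mean,
        whereas dividing the accumulated gradient by ``n_microbatches`` after
        every ``step()`` (an illustrative model of the ``scale_grads=True``
        behaviour described in ``_prepare_model``) shrinks earlier contributions
        more than once:

        ```python
        def accumulated_grad(step_losses, n_microbatches, scale_grads_in_scheduler):
            # step_losses: one list of per-microbatch losses per step() call; the
            # number of calls plays the role of gradient_accumulation_steps.
            grad_accum_steps = len(step_losses)
            grad = 0.0
            for losses in step_losses:
                assert len(losses) == n_microbatches
                if scale_grads_in_scheduler:
                    # Illustrative model of scale_grads=True: only the accumulation
                    # factor is applied to the loss, and the accumulated gradient
                    # is divided by n_microbatches after every step().
                    grad += sum(loss / grad_accum_steps for loss in losses)
                    grad /= n_microbatches
                else:
                    # This trainer: both factors are applied to the loss
                    # (RescaleLoss), and the scheduler uses scale_grads=False.
                    scale = 1 / (n_microbatches * grad_accum_steps)
                    grad += sum(loss * scale for loss in losses)
            return grad


        steps = [[2.0, 4.0], [6.0, 8.0]]  # 2 accumulation steps x 2 microbatches
        print(accumulated_grad(steps, 2, scale_grads_in_scheduler=False))  # 5.0 (true mean)
        print(accumulated_grad(steps, 2, scale_grads_in_scheduler=True))   # 4.25
        ```

        In the loss-scaled case, the two ``step()`` calls would return 1.5 and
        3.5 respectively, which sum to the mean loss of 5.0.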
        """
        inputs = (input_dict["input_ids"],)

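        # Extra kwargs passed to each stage outside the inter-stage activation stream.
        # The attention mask is created externally when the splitter supports it; this
        # follows the TorchTitan pattern and avoids pipeline transport issues.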
        extra_kwargs = {}
        if self.use_fused_loss:
            extra_kwargs["return_hidden_states"] = True

        if self.attention_mask_creator is not None:
            attention_mask = self.attention_mask_creator(**input_dict)
            extra_kwargs["attention_mask"] = attention_mask
            if (position_ids := input_dict.get("position_ids", None)) is not None:
                extra_kwargs["position_ids"] = position_ids

        # See: https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L377
        targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
        assert self.scheduler
        with self.amp_context.autocast():
            if self.pp_has_first_stage:
                self.scheduler.step(
                    *inputs, **extra_kwargs, target=targets, losses=cast(list, losses)
                )
            else:
                self.scheduler.step(
                    **extra_kwargs, target=targets, losses=cast(list, losses)
                )

        if self.pp_has_last_stage:
            assert losses
            mean_loss = torch.stack([x.detach().float() for x in losses]).sum()
        else:
            mean_loss = torch.tensor(0.0, device=self.dist.device, dtype=torch.float32)
        return mean_loss

    @torch.no_grad()
    def _pipeline_step_for_generation(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor],
    ) -> Optional[Tensor]:
        """Execute one forward pass through all pipeline stages for text generation.

        All ranks must call this method simultaneously. Works for any value of
        ``stages_per_rank`` and both ``"loop"`` and ``"v"`` stage assignment styles.

        Cross-rank activation transfers use ``dist.batch_isend_irecv`` with
        ``dist.P2POp``, the same primitive the pipeline scheduler uses internally.
        Using the batched async API avoids lazy creation of new per-pair NCCL
        sub-communicators on every call and ensures the two code paths share NCCL
        state cleanly.

        Parameters
        ----------
        input_ids : Tensor
            Token ids of shape ``[batch, seq]``, identical on all ranks.
        attention_mask : Tensor or None
            Optional attention mask created externally (bypasses pipeline transport
            limitations). ``None`` when the splitter does not support external masks.

        Returns
        -------
        Tensor or None
            Logits of shape ``[batch, seq, vocab]`` on the last-stage rank;
            ``None`` on all other ranks.

        Notes
        -----
        Inter-stage activations are assumed to have shape
        ``[batch, seq, model.config.hidden_size]``.
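
        Examples
        --------
        Activations only cross ranks where consecutive stages have different
        owners, so the number of transfers per forward pass depends on the
        stage assignment style. The hypothetical helper below (pure Python,
        not part of the trainer) lists the ``(stage, src_rank, dst_rank)``
        transfers that the loop in this method performs for a given
        ``_stage_to_rank`` map:

        ```python
        def boundary_transfers(stage_to_rank: dict[int, int]) -> list[tuple[int, int, int]]:
            # One entry per stage whose input must be received from another rank:
            # (receiving stage, sending rank, receiving rank).
            transfers = []
            for stage in range(1, len(stage_to_rank)):
                src, dst = stage_to_rank[stage - 1], stage_to_rank[stage]
                if src != dst:
                    transfers.append((stage, src, dst))
            return transfers


        # 2 ranks, 4 stages. "loop": rank 0 owns (0, 2), rank 1 owns (1, 3).
        loop = {0: 0, 1: 1, 2: 0, 3: 1}
        # "v": rank 0 owns (0, 3), rank 1 owns (1, 2).
        v_style = {0: 0, 1: 1, 2: 1, 3: 0}
        print(boundary_transfers(loop))     # [(1, 0, 1), (2, 1, 0), (3, 0, 1)]
        print(boundary_transfers(v_style))  # [(1, 0, 1), (3, 1, 0)]
        ```

        With the V-style assignment, stages 1 and 2 share a rank, so one
        transfer is saved, and the logits are returned on rank 0 (the owner of
        the last stage).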
        """
        assert self.pipeline_modules

        forward_kwargs: dict = {}
        if attention_mask is not None:
            forward_kwargs["attention_mask"] = attention_mask

        hidden_states: Optional[Tensor] = None
        batch, seq = input_ids.shape

        with self.amp_context.autocast():
            for stage_k in range(self.n_pipeline_stages):
                rank_k = self._stage_to_rank[stage_k]

                # At each stage boundary, transfer activations if stages are on different ranks.
                if stage_k > 0:
                    rank_prev = self._stage_to_rank[stage_k - 1]
                    if rank_prev != rank_k:
                        if self.dist.rank == rank_prev:
                            assert hidden_states is not None
                            op = distributed.P2POp(
                                distributed.isend,
                                hidden_states.contiguous(),
                                peer=rank_k,
                                group=self.pp_group,
                            )
                            for w in distributed.batch_isend_irecv([op]):
                                w.wait()
                        elif self.dist.rank == rank_k:
                            recv_buf = torch.empty(
                                batch,
                                seq,
                                self.model.config.hidden_size,
                                dtype=self._generation_dtype,
                                device=self.dist.device,
                            )
                            op = distributed.P2POp(
                                distributed.irecv,
                                recv_buf,
                                peer=rank_prev,
                                group=self.pp_group,
                            )
                            for w in distributed.batch_isend_irecv([op]):
                                w.wait()
                            hidden_states = recv_buf
                        # Other ranks: not involved in this boundary, continue loop.

                # Run the module on its owning rank.
                if self.dist.rank == rank_k:
                    local_idx = self._stage_to_local_mod[stage_k]
                    mod = self.pipeline_modules[local_idx]
                    inp = input_ids if stage_k == 0 else hidden_states
                    hidden_states = mod(inp, **forward_kwargs)

        if self.pp_has_last_stage:
            return hidden_states  # logits tensor from last stage
        return None

    @torch.no_grad()
    def pipeline_generate(
        self,
        input_ids: Tensor,
        max_new_tokens: int,
        eos_token_id: int,
        pad_token_id: int,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_k: int = 0,
        repetition_penalty: float = 1.0,
    ) -> Tensor:
        """Generate text autoregressively through all pipeline stages.

        Bypasses the pipeline scheduler so input shapes are not constrained to
        the fixed training batch dimensions. All ranks must call this method
        simultaneously. The full generated sequence (prompt + new tokens) is
        returned on every rank.

        No KV caching is used; each decoding step reprocesses the entire
        sequence. This is acceptable for infrequent, qualitative generation
        checks (e.g. during a callback).

        Parameters
        ----------
        input_ids : Tensor
            Prompt token ids of shape ``[batch, prompt_len]``, same on all
            ranks.
        max_new_tokens : int
            Maximum number of new tokens to generate.
        eos_token_id : int
            Token id that signals end of sequence. Once all sequences in the
            batch have emitted this token, generation stops early.
        pad_token_id : int
            Token id used to pad sequences that have already finished.
        do_sample : bool, optional
            If ``True``, sample from the probability distribution; if
            ``False``, use greedy (argmax) decoding. Default is ``True``.
        temperature : float, optional
            Softmax temperature applied before top-k filtering. Values ``< 1``
            sharpen the distribution; values ``> 1`` flatten it.
            Default is ``1.0``.
        top_k : int, optional
            When ``> 0``, restrict sampling to the top-k logits. ``0`` uses
            the full vocabulary. Default is ``0``.
        repetition_penalty : float, optional
            Multiplicative penalty applied to logits of tokens already present
            in the sequence. ``1.0`` disables the penalty. Default is ``1.0``.

        Returns
        -------
        Tensor
            Generated token ids of shape ``[batch, prompt_len + n_new_tokens]``
            as a ``LongTensor`` on the current device, identical on all ranks.
        """
        # Ensure all ranks reach this point before issuing any generation collectives.
        # This forces a clean synchronization fence after the trainer's eval phase, so
        # any pending scheduler ops on the same NCCL communicators are guaranteed to be
        # drained before our textgen ops start. Without this, our hand-rolled p2p can
        # get interleaved with leftover scheduler state and deadlock.
        distributed.barrier(group=self.pp_group)

        batch_size = input_ids.shape[0]
        generated_ids = input_ids.clone()
        done = torch.zeros(batch_size, dtype=torch.bool, device=self.dist.device)

        # Temporarily bypass torch.compile on the pipeline stage modules for the
        # duration of generation. Compiled modules (e.g. flex_attention + max-autotune)
        # are specialized on the training shapes and fail when called with the varying
        # shapes used during autoregressive decoding. Mirrors the single-rank workaround
        # in TextgenCallback.generate().
        assert self.pipeline_modules
        saved_compiled_calls: list = []
        for mod in self.pipeline_modules:
            compiled_call = getattr(mod, "_compiled_call_impl", None)
            saved_compiled_calls.append(compiled_call)
            if compiled_call is not None:
                mod._compiled_call_impl = mod._call_impl

        try:
            for _ in range(max_new_tokens):
                attention_mask = None
                if self.attention_mask_creator is not None:
                    attention_mask = self.attention_mask_creator(
                        input_ids=generated_ids
                    )

                logits = self._pipeline_step_for_generation(
                    generated_ids, attention_mask
                )

                if self.pp_has_last_stage:
                    assert logits is not None
                    next_logits = logits[:, -1, :].float()  # [batch, vocab]

                    # Sanitize logits before sampling. Unstable model output (common
                    # early in training or with mixed precision) can produce NaN/Inf,
                    # which makes torch.multinomial fail the "probability tensor
                    # contains inf, nan or element < 0" assertion. Replace NaN with 0
                    # and clamp ±Inf to finite values. Top-k masking below uses its
                    # own float("-inf") after this step, so -inf handling here does
                    # not interfere with top-k.
                    next_logits = torch.nan_to_num(
                        next_logits, nan=0.0, posinf=1e4, neginf=-1e4
                    )

                    if repetition_penalty != 1.0:
                        for b in range(batch_size):
                            for tok in set(generated_ids[b].tolist()):
                                if next_logits[b, tok] > 0:
                                    next_logits[b, tok] /= repetition_penalty
                                else:
                                    next_logits[b, tok] *= repetition_penalty

                    if temperature != 1.0:
                        next_logits = next_logits / temperature

                    if top_k > 0:
                        top_values, _ = torch.topk(next_logits, top_k, dim=-1)
                        threshold = top_values[:, -1, None]
                        next_logits = next_logits.masked_fill(
                            next_logits < threshold, float("-inf")
                        )

                    if do_sample:
                        probs = torch.softmax(next_logits, dim=-1)
                        next_tokens = torch.multinomial(probs, 1).squeeze(1)
                    else:
                        next_tokens = next_logits.argmax(dim=-1)
                else:
                    next_tokens = torch.zeros(
                        batch_size, dtype=torch.long, device=self.dist.device
                    )

                distributed.broadcast(next_tokens, src=self.pp_last_stage_rank)

                done = done | (next_tokens == eos_token_id)
                next_tokens = next_tokens.masked_fill(done, pad_token_id)
                generated_ids = torch.cat(
                    [generated_ids, next_tokens.unsqueeze(1)], dim=1
                )

                # Broadcast early-exit flag from last-stage rank to all ranks.
                if self.pp_has_last_stage:
                    stop = torch.tensor(
                        [int(done.all())], dtype=torch.long, device=self.dist.device
                    )
                else:
                    stop = torch.zeros(1, dtype=torch.long, device=self.dist.device)
                distributed.broadcast(stop, src=self.pp_last_stage_rank)
                if stop.item():
                    break
        finally:
            # Restore compiled forwards for training, even if generation raised.
            for mod, compiled_call in zip(self.pipeline_modules, saved_compiled_calls):
                if compiled_call is not None:
                    mod._compiled_call_impl = compiled_call

        return generated_ids

    @override
    def _init_optimizer(self):
        """Initialise the optimizer over all pipeline stage parameters on this rank.

        Collects parameters from every module in ``self.pipeline_modules`` (i.e.
        all stages assigned to this rank) and passes them to
        ``self.optimizer_factory`` to create a single optimizer instance. Also
        registers post-accumulate gradient hooks when
        ``args.fuse_optim_with_backward`` is enabled.
        """
        if self.optimizer is None:
            # Build a named-parameter generator for all of our modules
            def named_parameters(modules):
                for mod in modules:
                    for param in mod.named_parameters():
                        yield param

            assert self.pipeline_modules
            assert self.optimizer_factory
            if self.optimizer_groups is not None:
                opt_params = build_parameter_groups(
                    named_parameters(self.pipeline_modules),
                    self.optimizer_groups,
                    debug=self.args.debug_optimizer_groups,
                )
            else:
                opt_params = named_parameters(self.pipeline_modules)
            self.optimizer = self.optimizer_factory(opt_params)

            if self.args.fuse_optim_with_backward:
                self._total_grad_squared = torch.zeros(
                    1, device=self.args.device, dtype=torch.float32
                )

                for name, p in named_parameters(self.pipeline_modules):
                    if p.requires_grad:
                        hook = partial(
                            optimizer_hook,
                            self.optimizer,
                            self._total_grad_squared,
                            name,
                        )
                        p.register_post_accumulate_grad_hook(hook)

        if self.lr_scheduler is None and self.lr_scheduler_factory is not None:
            self.lr_scheduler = self.lr_scheduler_factory(self.optimizer)

    @contextmanager
    def _eval_schedule_context(self):
        """Temporarily strip backward actions from a pre-computed pipeline schedule.

        Workaround for PyTorch issue: schedules that inherit from
        _PipelineScheduleRuntime (e.g. ScheduleZBVZeroBubble) pre-compute
        pipeline_order_with_comms at __init__ time, including backward actions.
        During eval(), _has_backward is set to False and stage.backward_one_chunk
        early-returns, but the V-schedule special-case code still calls
        stage.get_local_bwd_output() which asserts has_backward, causing a crash.

        This context manager removes backward-related actions (BACKWARD_INPUT,
        BACKWARD_WEIGHT, FULL_BACKWARD, SEND_B, RECV_B, REDUCE_GRAD) from the
        schedule for the duration of eval, then restores the original.
        """
        from torch.distributed.pipelining.schedules import _ComputationType

        _BACKWARD_ACTIONS = frozenset(
            {
                _ComputationType.BACKWARD_INPUT,
                _ComputationType.BACKWARD_WEIGHT,
                _ComputationType.FULL_BACKWARD,
                _ComputationType.SEND_B,
                _ComputationType.RECV_B,
                _ComputationType.REDUCE_GRAD,
            }
        )

        scheduler = self.scheduler
        original = getattr(scheduler, "pipeline_order_with_comms", None)
        if original is not None:
            setattr(
                scheduler,
                "pipeline_order_with_comms",
                {
                    rank: [
                        a
                        for a in actions
                        if a.computation_type not in _BACKWARD_ACTIONS
                    ]
                    for rank, actions in original.items()
                },
            )
        try:
            yield
        finally:
            if original is not None:
                setattr(scheduler, "pipeline_order_with_comms", original)

    @override
    def _prediction_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Dict[str, Tensor | None]:
        """
        Execute an evaluation forward pass through the pipeline scheduler.

        Uses scheduler.eval() instead of scheduler.step() to disable gradient computation
        and skip the backward pass. Loss is still computed on the last stage for metrics.

        Similar to _forward_backward_step but without gradients. Attention masks are
        created externally for the same reason (PyTorch pipeline transport limitations).

        Parameters
        ----------
        input_dict : dict of str to Tensor
            Batch inputs with ``"input_ids"``.
        labels : Tensor
            Target labels for loss computation (only used by last stage).

        Returns
        -------
        dict
            Dictionary with ``"loss"`` (mean over microbatches), ``"logits"``
            (``None``), and ``"labels"`` (``None``).
        """
        inputs = (input_dict["input_ids"],)

        # Create attention mask externally if supported by splitter
        # This follows TorchTitan pattern to avoid pipeline transport issues
        extra_kwargs = {}
        if self.use_fused_loss:
            extra_kwargs["return_hidden_states"] = True
        if self.attention_mask_creator is not None:
            attention_mask = self.attention_mask_creator(
                input_ids=input_dict["input_ids"]
            )
            extra_kwargs["attention_mask"] = attention_mask

        targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
        assert self.scheduler
        loss_fn = self.loss_fn
        assert isinstance(loss_fn, RescaleLoss)
        with loss_fn.no_rescale(), self.amp_context.autocast():
            with self._eval_schedule_context():
                if self.pp_has_first_stage:
                    self.scheduler.eval(
                        *inputs,
                        **extra_kwargs,
                        target=targets,
                        losses=cast(list, losses),
                    )
                else:
                    self.scheduler.eval(
                        **extra_kwargs,
                        target=targets,
                        losses=cast(list, losses),
                    )

        # Compute loss on last stage
        if self.pp_has_last_stage:
            assert losses
            mean_loss = torch.stack([x.detach().float() for x in losses]).sum()
        else:
            mean_loss = torch.tensor(0.0, device=self.dist.device, dtype=torch.float32)

        mean_loss = self._distributed_loss(mean_loss)
        return {
            "loss": mean_loss,
            "logits": None,
            "labels": None,
        }

    @override
    def _init_checkpoint_manager(self) -> CheckpointManager:
        """
        Initialize checkpoint manager for distributed pipeline parallel model.

        Unlike the single-device trainer, the pipeline trainer must save model
        shards on every rank, since each rank holds only a portion of the model.
        The shard_index tracks which parameters belong to which rank for
        coordinated save/load operations.

        Sets save_on_all_ranks=True so all ranks participate in checkpointing,
        each saving its own pipeline stages.

        Returns
        -------
        CheckpointManager
            Configured for distributed pipeline model saving.
        """
        cp_config = CheckpointConfig(
            output_dir=self.args.output_dir,
            save_total_limit=self.args.save_total_limit,
            save_on_each_node=self.args.save_on_each_node,
            save_safetensors=self.args.save_safetensors,
            save_on_all_ranks=True,
        )

        assert self.model
        assert self.shard_index
        checkpoint_manager = CheckpointManager(
            config=cp_config,
            dist=self.dist,
            model=self.model,
            model_parts=self.pipeline_modules,
            model_preprocessor=self.processing_class,
            stateful_provider=self,
            shard_index=self.shard_index,
        )
        return checkpoint_manager

    @staticmethod
    def _all_reduce_norm(total_norm, norm_type):
        """All-reduce local gradient norms to compute the global norm.

        Each rank has computed a local norm over its pipeline stage parameters.
        This method combines those local norms into the true global gradient norm
        using the appropriate reduction for the requested norm type:

        * **L2** (``norm_type=2``) — sum squared local norms, then take sqrt.
        * **Lp** — sum ``p``-th powers, then take the ``1/p`` root.
        * **L-inf** — max across ranks.

        Parameters
        ----------
        total_norm : Tensor
            Local norm tensor on this rank (scalar).
        norm_type : float
            Exponent of the norm. Use ``2.0`` for L2, ``float("inf")`` for
            L-inf.

        Returns
        -------
        Tensor
            Global gradient norm after all-reduce, same value on all ranks.
        """
        if math.isinf(norm_type):
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX)
        else:
            total_norm **= norm_type
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM)
            total_norm **= 1.0 / norm_type
        return total_norm

    @override
    def _clip_grad_norm(self, max_grad_norm, norm_type=2.0) -> Tensor:
        """Compute and optionally clip the gradient norm across all pipeline stages.

        Unlike the single-device trainer, each rank holds gradients only for its
        own pipeline stages, so the local norms must be all-reduced to obtain the
        true global norm before clipping.

        Parameters
        ----------
        max_grad_norm : float or None
            Maximum allowed gradient norm. When ``None`` or ``0``, the norm is
            computed but no clipping is applied.
        norm_type : float, optional
            Type of norm. Default is ``2.0`` (L2 norm).

        Returns
        -------
        Tensor
            Global gradient norm (after all-reduce across pipeline ranks).
        """
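        # With the optimizer fused into backward, gradients cannot be clipped here,
        # but the per-parameter hooks have accumulated the squared norm, so report it.
        if self.args.fuse_optim_with_backward:
            # Apply sqrt, as we accumulate the sum of the squares
            total_norm = self._total_grad_squared.sqrt()
            # Reset the accumulator for the next step (zero_() also clears inf/NaN,
            # which self-subtraction would turn into a persistent NaN)
            self._total_grad_squared.zero_()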
            # Collective all-reduce with other ranks
            return self._all_reduce_norm(total_norm, norm_type)

        # Compute norm over all local trainable parameters
        assert self.pipeline_modules

        # Debug aid (disabled): log the per-parameter gradient norm on each rank.
        if False:
            for i, mod in enumerate(self.pipeline_modules):
                for name, p in mod.named_parameters():
                    if p.grad is not None:
                        norm = p.grad.square().sum().sqrt()
                        logger.info(f"r{self.dist.rank} m{i} {name} {norm}")

        parameters = [
            p
            for mod in self.pipeline_modules
            for p in mod.parameters()
            if p.grad is not None
        ]

        grads = [p.grad for p in parameters if p.grad is not None]

        total_norm = torch.nn.utils.get_total_norm(
            grads, norm_type=norm_type, foreach=True
        )

        # All-reduce over all ranks
        total_norm = self._all_reduce_norm(total_norm, norm_type)

        if max_grad_norm is None or max_grad_norm == 0:
            return total_norm

        torch.nn.utils.clip_grads_with_norm_(
            parameters,
            max_grad_norm,
            total_norm,
            foreach=True,
        )

        return total_norm

    @override
    def _count_batch_tokens(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """Count the number of tokens in the current batch for the first pipeline stage.

        Only the first stage receives ``input_ids`` and meaningful labels; all
        other stages return a zero tensor. ``_distributed_tokens()`` then sums
        across ranks to obtain the true batch token count.

        Parameters
        ----------
        input_dict : dict of str to Tensor
            Batch dictionary; ``"input_ids"`` is only populated on the first
            stage.
        labels : Tensor
            Target labels (meaningful only on the first stage).

        Returns
        -------
        Tensor
            Scalar token count on the first stage; zero tensor on other stages.
        """
        if not self.pp_has_first_stage:
            return torch.tensor(0, device=self.args.device, dtype=torch.int64)
        return super()._count_batch_tokens(input_dict, labels)

    @override
    def _distributed_tokens(self, tokens: Tensor) -> Tensor:
        """All-reduce token counts across pipeline stages.

        The first stage contributes the real count; all other stages contribute
        zero. Summing across ranks yields the total token count for the batch.

        Parameters
        ----------
        tokens : Tensor
            Local token count (non-zero only on the first stage).

        Returns
        -------
        Tensor
            Total token count summed across all pipeline ranks.
        """
        dist.all_reduce(tokens, op=dist.ReduceOp.SUM)
        return tokens

    @override
    def _distributed_loss(self, loss: Tensor):
        """Broadcast the loss from the last pipeline stage to all ranks.

        Only the last stage computes a meaningful loss (it alone has the labels).
        All other stages hold ``0.0``. Broadcasting from the last-stage rank
        ensures every rank can log the same loss value.

        Parameters
        ----------
        loss : Tensor
            Scalar loss tensor — meaningful only on the last stage, ``0.0``
            on all other stages.

        Returns
        -------
        Tensor
            Loss value from the last stage, same on all ranks after broadcast.
        """
        distributed.broadcast(loss, src=self.pp_last_stage_rank)
        return loss

    @override
    def _distributed_peak_mem(self, local_peak: int) -> list[int]:
        """All-gather per-rank peak CUDA memory across all pipeline ranks.

        Because each pipeline stage hosts a different subset of the model's
        layers, the memory footprint per rank can differ significantly. The
        per-rank peaks are therefore genuinely informative for capacity
        planning and imbalance detection.

        Parameters
        ----------
        local_peak : int
            Peak CUDA memory allocated on this rank, in bytes.

        Returns
        -------
        list of int
            Peak memory in bytes for each rank, indexed by rank.
        """
        value = torch.tensor(
            [int(local_peak)], dtype=torch.long, device=self.args.device
        )
        gathered = [torch.zeros_like(value) for _ in range(self.dist.world_size)]
        dist.all_gather(gathered, value)
        return [int(t.item()) for t in gathered]

    @override
    def get_state_components(self) -> List[StateComponent]:
        """Return state components for pipeline parallel training.

        Because the model is split across ranks, each rank saves only its own
        stage parameters. The sharing patterns reflect this:

        * ``"model"`` — PER_RANK (each rank holds different stages), required.
        * ``"optimizer"`` — PER_RANK (optimises different parameters), optional.
        * ``"scheduler"`` — REPLICATED (same LR schedule on all ranks), optional.
        * ``"trainer"`` — REPLICATED (same global step on all ranks), optional.
        * ``"dataset"`` — GLOBAL (``DataloaderDispatcher`` with
          ``dp_mesh_dim=None``; rank 0 loads and broadcasts), optional.
        * ``"rng"`` — PER_RANK (each stage may have different dropout), optional.

        Returns
        -------
        list of StateComponent
            All checkpointable state components with their sharing patterns.
        """
        components = []

        # Model - REQUIRED, PER_RANK (different pipeline stages per rank)
        # pipeline_modules contains the stage modules assigned to this rank
        if self.pipeline_modules:
            components.append(
                StateComponent(
                    key="model",
                    stateful=cast(Stateful, self.pipeline_modules),
                    sharing_pattern=SharingPattern.PER_RANK,
                    required=True,  # Model is always required
                )
            )

        # Optimizer - optional, PER_RANK (optimizes different parameters per rank)
        # Each rank's optimizer only contains parameters for its pipeline stages
        if self.optimizer:
            components.append(
                StateComponent(
                    key="optimizer",
                    stateful=self.optimizer,
                    sharing_pattern=SharingPattern.PER_RANK,
                    required=False,
                )
            )

        # LR Scheduler - optional, REPLICATED (same schedule across all ranks)
        # All ranks follow the same learning rate schedule
        if self.lr_scheduler:
            components.append(
                StateComponent(
                    key="scheduler",
                    stateful=self.lr_scheduler,
                    sharing_pattern=SharingPattern.REPLICATED,
                    required=False,
                )
            )

        # Trainer state - optional, REPLICATED (same global step across all ranks)
        # Training progress is synchronized across all pipeline stages
        components.append(
            StateComponent(
                key="trainer",
                stateful=self,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            )
        )

        # Dataset state - optional, GLOBAL (DataloaderDispatcher with pure MP mode)
        # Rank 0 loads data and broadcasts to all ranks (dp_mesh_dim=None)
        if hasattr(self.train_dataloader, "state_dict"):
            components.append(
                StateComponent(
                    key="dataset",
                    stateful=cast(Stateful, self.train_dataloader),
                    sharing_pattern=self._get_dataset_sharing_pattern(),
                    required=False,
                )
            )

        # RNG state - optional, PER_RANK (each rank needs different random numbers)
        # Different dropout patterns, etc. for different pipeline stages
        components.append(
            StateComponent(
                key="rng",
                stateful=RNGState(),
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            )
        )

        return components

    @override
    def _get_dataset_sharing_pattern(self) -> SharingPattern:
        """Return the dataset sharing pattern for pipeline parallel training.

        ``PipelineTrainer`` uses ``DataloaderDispatcher`` with
        ``dp_mesh_dim=None`` (pure model-parallel mode). Rank 0 loads the data
        and broadcasts it to all other ranks, so only one dataloader state
        exists globally.

        Returns
        -------
        SharingPattern
            ``SharingPattern.GLOBAL``.
        """
        # Pure model parallelism: all ranks get same batch from rank 0
        # See _wrap() method where DataloaderDispatcher is created with dp_mesh_dim=None
        return SharingPattern.GLOBAL

    @override
    def get_process_groups(self) -> Dict[str, Any]:
        """Return named process groups for checkpoint coordination.

        The checkpoint manager uses this mapping to implement ``PER_GROUP``
        sharing patterns (e.g. saving one copy per pipeline-parallel group).

        Returns
        -------
        dict of str to ProcessGroup
            ``{"pp_group": self.pp_group}`` for pure pipeline parallelism.
        """
        return {
            "pp_group": self.pp_group,
        }
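
For reference, the per-step token selection in `pipeline_generate` (run only on the last-stage rank, before the result is broadcast) can be reproduced in plain Python for a single sequence. This is an illustrative sketch, not Forgather code: `select_next_token` is a hypothetical helper that works on a list of floats instead of a `[batch, vocab]` tensor, but it applies the same steps in the same order as the listing: NaN/Inf sanitisation, repetition penalty, temperature, top-k masking, then greedy argmax or sampling.

```python
import math
import random
from typing import Optional, Sequence


def select_next_token(
    logits: Sequence[float],
    history: Sequence[int],
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 0,
    repetition_penalty: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: pick one token id from one row of last-position logits."""
    # 1. Sanitise: NaN -> 0, +/-inf -> +/-1e4 (same constants as pipeline_generate).
    out = []
    for x in logits:
        x = float(x)
        if math.isnan(x):
            x = 0.0
        elif math.isinf(x):
            x = 1e4 if x > 0 else -1e4
        out.append(x)

    # 2. Repetition penalty: divide positive logits, multiply negative ones,
    #    so a penalty > 1.0 always makes an already-seen token less likely.
    if repetition_penalty != 1.0:
        for tok in set(history):
            if out[tok] > 0:
                out[tok] /= repetition_penalty
            else:
                out[tok] *= repetition_penalty

    # 3. Temperature: < 1 sharpens, > 1 flattens the distribution.
    if temperature != 1.0:
        out = [x / temperature for x in out]

    # 4. Top-k: mask every logit strictly below the k-th largest one.
    #    min() avoids an IndexError when top_k exceeds the vocabulary
    #    (torch.topk would raise in that case instead).
    if top_k > 0:
        threshold = sorted(out, reverse=True)[min(top_k, len(out)) - 1]
        out = [x if x >= threshold else -math.inf for x in out]

    # 5. Greedy: first index of the maximum (matches torch.argmax tie-breaking).
    if not do_sample:
        return max(range(len(out)), key=out.__getitem__)

    # Sampling: softmax weights, shifted by the max for numerical stability.
    m = max(out)
    weights = [math.exp(x - m) for x in out]  # exp(-inf) == 0.0, so masked tokens never win
    if rng is None:
        rng = random.Random()
    return rng.choices(range(len(out)), weights=weights, k=1)[0]
```

For example, `select_next_token([1.0, 3.0, 2.0], [1], do_sample=False, repetition_penalty=2.0)` halves the logit of the previously seen token 1 to 1.5, so greedy decoding picks token 2 instead.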

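`_all_reduce_norm` relies on the fact that a p-norm over the concatenation of every rank's gradients can be rebuilt from per-rank norms: sum the p-th powers and take the 1/p root, or take the maximum for L-inf. The pure-Python sketch below simulates that reduction without a process group. `local_norm`, `combine_norms`, and `clip_scale` are hypothetical names; `clip_scale` follows the clipping coefficient PyTorch uses, `max_norm / (total_norm + 1e-6)` clamped to 1.0, to show why every rank needs the same global norm.

```python
import math
from typing import Iterable, List


def local_norm(grads: Iterable[float], p: float) -> float:
    """p-norm of one rank's flattened gradient entries (simulated)."""
    vals = [abs(g) for g in grads]
    if math.isinf(p):
        return max(vals, default=0.0)
    return sum(v ** p for v in vals) ** (1.0 / p)


def combine_norms(local_norms: List[float], p: float) -> float:
    """Mimic _all_reduce_norm: MAX for L-inf, else sum p-th powers, then 1/p root."""
    if math.isinf(p):
        return max(local_norms)
    return sum(n ** p for n in local_norms) ** (1.0 / p)


def clip_scale(total_norm: float, max_norm: float) -> float:
    """Factor applied to every gradient: max_norm / (total_norm + 1e-6), at most 1.0."""
    return min(1.0, max_norm / (total_norm + 1e-6))


if __name__ == "__main__":
    # Two simulated pipeline ranks, each holding one gradient entry.
    shards = [[3.0], [4.0]]
    norms = [local_norm(s, 2.0) for s in shards]
    print(norms, combine_norms(norms, 2.0))  # [3.0, 4.0] 5.0
```

Summing the local L2 norms instead (3 + 4 = 7 in this example) would overstate the global norm, and skipping the reduction entirely would give each rank a different clipping factor, so stages would be scaled inconsistently.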
__init__(*, args, model_splitter, pipe_schedule_factory=ScheduleGPipe, **kwargs)

Initialise the pipeline parallel trainer.

Parameters:

args (PipelineTrainingArguments or dict, required)
    Pipeline training configuration. Dicts are converted via dacite.from_dict(PipelineTrainingArguments, args).

model_splitter (ModelSplitter, required)
    Callable that accepts the model on the meta device and returns all pipeline stage modules, the rank-local stage modules, and PipelineStage objects. See src/forgather/ml/trainer/pipeline/model_splitter.py for the full signature.

pipe_schedule_factory (callable, default: ScheduleGPipe)
    Pipeline scheduler factory. ScheduleGPipe (default) uses simple GPipe scheduling. Pass ScheduleZBVZeroBubble or similar for zero-bubble schedules.

**kwargs (default: {})
    Forwarded to the base Trainer constructor (model_init, train_dataset, optimizer_factory, callbacks, etc.).
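
When a zero-bubble schedule such as ScheduleZBVZeroBubble is passed as pipe_schedule_factory, evaluation goes through `_eval_schedule_context` in the class source above, which temporarily filters backward actions out of the pre-computed `pipeline_order_with_comms` and restores the original afterwards. The sketch below reproduces that filter-and-restore pattern without torch. `Op`, `Action`, `FakeSchedule`, and `forward_only` are hypothetical stand-ins; `Op` only mirrors the member names the listing uses (plus assumed forward-side names) and is not PyTorch's private `_ComputationType`.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, Iterator, List


class Op(Enum):
    # Hypothetical stand-in for torch's private _ComputationType enum.
    FORWARD = auto()
    SEND_F = auto()
    RECV_F = auto()
    BACKWARD_INPUT = auto()
    BACKWARD_WEIGHT = auto()
    FULL_BACKWARD = auto()
    SEND_B = auto()
    RECV_B = auto()
    REDUCE_GRAD = auto()


# The same six backward-related actions stripped by _eval_schedule_context.
BACKWARD_OPS = frozenset(
    {Op.BACKWARD_INPUT, Op.BACKWARD_WEIGHT, Op.FULL_BACKWARD,
     Op.SEND_B, Op.RECV_B, Op.REDUCE_GRAD}
)


@dataclass(frozen=True)
class Action:
    stage: int
    computation_type: Op
    microbatch: int


class FakeSchedule:
    """Hypothetical schedule holding a pre-computed per-rank action list."""

    def __init__(self, order: Dict[int, List[Action]]):
        self.pipeline_order_with_comms = order


@contextmanager
def forward_only(schedule) -> Iterator[None]:
    original = getattr(schedule, "pipeline_order_with_comms", None)
    if original is not None:
        # Build a filtered copy; the original dict and lists are left untouched.
        schedule.pipeline_order_with_comms = {
            rank: [a for a in actions if a.computation_type not in BACKWARD_OPS]
            for rank, actions in original.items()
        }
    try:
        yield
    finally:
        # Restore the full schedule even if evaluation raised.
        if original is not None:
            schedule.pipeline_order_with_comms = original
```

Because the filtered order is a new dict, restoring is a single attribute assignment, and the training schedule is never mutated in place.
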
Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
def __init__(
    self,
    *,
    args: TPipelineTrainingArguments | dict,
    model_splitter: ModelSplitter,  # Required: function to split model into pipeline stages
    pipe_schedule_factory: PipelineSchedulerFactorT = ScheduleGPipe,  # type: ignore[assignment]
    **kwargs,
):
    """Initialise the pipeline parallel trainer.

    Parameters
    ----------
    args : PipelineTrainingArguments or dict
        Pipeline training configuration. Dicts are converted via
        ``dacite.from_dict(PipelineTrainingArguments, args)``.
    model_splitter : ModelSplitter
        Callable that accepts the model on the meta device and returns
        all pipeline stage modules, the rank-local stage modules, and
        ``PipelineStage`` objects. See
        ``src/forgather/ml/trainer/pipeline/model_splitter.py`` for the
        full signature.
    pipe_schedule_factory : callable, optional
        Pipeline scheduler factory. ``ScheduleGPipe`` (default) uses simple
        GPipe scheduling. Pass ``ScheduleZBVZeroBubble`` or similar for
        zero-bubble schedules.
    **kwargs
        Forwarded to the base ``Trainer`` constructor (``model_init``,
        ``train_dataset``, ``optimizer_factory``, ``callbacks``, etc.).
    """
    if isinstance(args, dict):
        args = cast(
            TPipelineTrainingArguments, from_dict(PipelineTrainingArguments, args)
        )
    super().__init__(args=args, **kwargs)
    self.model_splitter = model_splitter
    self.pipe_schedule_factory = pipe_schedule_factory

    # Zero-bubble schedules split backward into input-grad and weight-grad
    # steps and call torch.autograd.grad(..., retain_graph=True). That
    # conflicts with donated buffers in torch.compile'd backward
    # (e.g. flex_attention). Disable the optimization before any compiled
    # backward has been captured.
    if is_zero_bubble_schedule(pipe_schedule_factory):
        disable_compiled_backward_donated_buffers()

        # torch.compile applied at stage granularity (see _compile_model)
        # collapses the stage interior into a single Python autograd.Function
        # (CompiledFunctionBackward). Zero-bubble's stage_backward_weight
        # then constructs GradientEdge(intermediate, 0) over those nodes and
        # passes them to torch.autograd.grad, which calls _make_grads ->
        # out.node._input_metadata. That attribute is unavailable on Python
        # autograd.Function nodes and the C++ binding raises:
        #   "Attribute '_input_metadata' is invalid for this instance of
        #    _C._FunctionBase. ... legacy access pattern that is no longer
        #    supported."
        # The crash is structural: AOTAutograd flattens the stage interior
        # so the I/W split has nothing to walk between intermediates and
        # parameters. Refuse the combination at init time with a clear
        # diagnostic instead of letting the user hit it mid-training.
        assert not self.args.torch_compile, (
            "PipelineTrainer does not support torch.compile with zero-bubble "
            "schedules (ScheduleZBVZeroBubble, ScheduleInterleavedZeroBubble). "
            "AOTAutograd wraps each compiled stage in a Python autograd.Function "
            "whose internal nodes do not expose _input_metadata, which the "
            "split-backward weight step requires. Use a non-zero-bubble schedule "
            "(e.g. ScheduleInterleaved1F1B) or disable torch_compile."
        )

    assert self.args.mixed_precision != "fp16", (
        "PipelineTrainer does not support fp16 mixed precision (GradScaler is incompatible "
        "with pipeline scheduler's internal backward). Use mixed_precision='bf16' instead."
    )

    if self.args.debug_pipeline:
        logger.setLevel(logging.DEBUG)

    assert self.model is None, (
        "Pipeline trainer only supports model_init=fn, where fn is a zero-argument "
        "callable that returns a model"
    )
    assert self.model_init, "Pipeline trainer requires a model_init function"

    for batch_size in (
        self.args.per_device_train_batch_size,
        self.args.per_device_eval_batch_size,
    ):
        assert (
            batch_size % self.args.n_microbatches == 0
        ), f"Batch size ({batch_size}) must be evenly divisible by n_microbatches ({self.args.n_microbatches})"
    assert (
        self.args.is_multistage or self.args.stages_per_rank == 1
    ), "Only multistage schedulers may host more than one stage per rank (stages_per_rank > 1)"

    # The pipeline requires a fixed shape for the inputs
    self.args.dataloader_drop_last = True
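Because most of these constraints are enforced with assert at construction time, it can help to check a configuration before launching a multi-rank job. The following is a minimal pure-Python sketch that mirrors the checks in __init__ above. PipelineConfigSketch, its zero_bubble_schedule flag, and check_pipeline_config are hypothetical names for illustration only; they are not part of Forgather.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PipelineConfigSketch:
    # Hypothetical stand-in for the fields PipelineTrainer.__init__ inspects.
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    n_microbatches: int = 4
    stages_per_rank: int = 1
    is_multistage: bool = False
    mixed_precision: Optional[str] = None
    torch_compile: bool = False
    zero_bubble_schedule: bool = False  # hypothetical: True for ZB schedules


def check_pipeline_config(cfg: PipelineConfigSketch) -> List[str]:
    """Return the problems __init__ would reject; an empty list means OK."""
    problems: List[str] = []
    if cfg.zero_bubble_schedule and cfg.torch_compile:
        problems.append("torch_compile cannot be used with zero-bubble schedules")
    if cfg.mixed_precision == "fp16":
        problems.append("fp16 is not supported; use mixed_precision='bf16'")
    for name in ("per_device_train_batch_size", "per_device_eval_batch_size"):
        size = getattr(cfg, name)
        if size % cfg.n_microbatches != 0:
            problems.append(
                f"{name}={size} is not divisible by "
                f"n_microbatches={cfg.n_microbatches}"
            )
    if not cfg.is_multistage and cfg.stages_per_rank != 1:
        problems.append("stages_per_rank > 1 requires is_multistage=True")
    return problems


cfg = PipelineConfigSketch(per_device_train_batch_size=8, per_device_eval_batch_size=6)
print(check_pipeline_config(cfg))
# ['per_device_eval_batch_size=6 is not divisible by n_microbatches=4']
```

Collecting every problem into a list, rather than stopping at the first failing assert, reports all misconfigurations in one pass.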

pipeline_generate(input_ids, max_new_tokens, eos_token_id, pad_token_id, do_sample=True, temperature=1.0, top_k=0, repetition_penalty=1.0)

Generate text autoregressively through all pipeline stages.

Bypasses the pipeline scheduler so input shapes are not constrained to the fixed training batch dimensions. All ranks must call this method simultaneously, because it issues barrier and broadcast collectives; a rank that skips the call leaves the others blocked. The full generated sequence (prompt + new tokens) is returned on every rank.

No KV caching is used; each decoding step reprocesses the entire sequence. This is acceptable for infrequent, qualitative generation checks (e.g. during a callback).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_ids | Tensor | Prompt token ids of shape [batch, prompt_len], same on all ranks. | required |
| max_new_tokens | int | Maximum number of new tokens to generate. | required |
| eos_token_id | int | Token id that signals end of sequence. Once all sequences in the batch have emitted this token, generation stops early. | required |
| pad_token_id | int | Token id used to pad sequences that have already finished. | required |
| do_sample | bool | If True, sample from the probability distribution; if False, use greedy (argmax) decoding. | True |
| temperature | float | Softmax temperature applied before top-k filtering. Values < 1 sharpen the distribution; values > 1 flatten it. | 1.0 |
| top_k | int | When > 0, restrict sampling to the top-k logits. 0 uses the full vocabulary. | 0 |
| repetition_penalty | float | Penalty applied to logits of tokens already present in the sequence: positive logits are divided by it and non-positive logits are multiplied by it, so values > 1 discourage repeats. 1.0 disables the penalty. | 1.0 |

Returns:

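| Type | Description |
| --- | --- |
| Tensor | Generated token ids of shape [batch, prompt_len + n_new_tokens] as a LongTensor on the current device, identical on all ranks. n_new_tokens is at most max_new_tokens, and smaller if every sequence in the batch emits eos_token_id before that limit. |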
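The per-step token selection that the last-stage rank performs can be followed more easily in isolation. Below is a pure-Python sketch for a single sequence that applies the same steps in the same order as the source: sanitize, repetition penalty, temperature, top-k, then sample or argmax. select_next_token is a hypothetical helper, not part of Forgather; it uses random.choices in place of torch.multinomial and clamps top_k to the vocabulary size, whereas torch.topk would raise.

```python
import math
import random
from typing import List, Optional, Sequence


def select_next_token(
    logits: Sequence[float],
    previous_tokens: Sequence[int],
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 0,
    repetition_penalty: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    # 1. Sanitize: NaN -> 0.0, +/-inf -> +/-1e4 (like the torch.nan_to_num call).
    scores: List[float] = []
    for x in logits:
        if math.isnan(x):
            x = 0.0
        elif math.isinf(x):
            x = 1e4 if x > 0 else -1e4
        scores.append(float(x))

    # 2. Repetition penalty: divide positive scores, multiply non-positive ones.
    if repetition_penalty != 1.0:
        for tok in set(previous_tokens):
            if scores[tok] > 0:
                scores[tok] /= repetition_penalty
            else:
                scores[tok] *= repetition_penalty

    # 3. Temperature scaling.
    if temperature != 1.0:
        scores = [s / temperature for s in scores]

    # 4. Top-k: mask scores below the k-th largest (k clamped to vocab size).
    if top_k > 0:
        k = min(top_k, len(scores))
        threshold = sorted(scores, reverse=True)[k - 1]
        scores = [s if s >= threshold else -math.inf for s in scores]

    # 5. Greedy: first index of the maximum, like argmax.
    if not do_sample:
        return max(range(len(scores)), key=lambda i: scores[i])

    # 6. Sample from softmax(scores); subtracting the max keeps exp() stable.
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    rng = rng or random.Random()
    return rng.choices(range(len(scores)), weights=weights, k=1)[0]


# Greedy decoding with a strong penalty on token 0, which was already generated.
print(select_next_token([2.0, 1.0, float("nan"), 0.5], previous_tokens=[0],
                        do_sample=False, repetition_penalty=4.0))  # prints 1
```

Here the penalty lowers token 0 from 2.0 to 0.5, so token 1 wins. Note also that top_k=1 with sampling always picks the argmax, because every other weight becomes zero.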

Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
@torch.no_grad()
def pipeline_generate(
    self,
    input_ids: Tensor,
    max_new_tokens: int,
    eos_token_id: int,
    pad_token_id: int,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 0,
    repetition_penalty: float = 1.0,
) -> Tensor:
    """Generate text autoregressively through all pipeline stages.

    Bypasses the pipeline scheduler so input shapes are not constrained to
    the fixed training batch dimensions. All ranks must call this method
    simultaneously. The full generated sequence (prompt + new tokens) is
    returned on every rank.

    No KV caching is used; each decoding step reprocesses the entire
    sequence. This is acceptable for infrequent, qualitative generation
    checks (e.g. during a callback).

    Parameters
    ----------
    input_ids : Tensor
        Prompt token ids of shape ``[batch, prompt_len]``, same on all
        ranks.
    max_new_tokens : int
        Maximum number of new tokens to generate.
    eos_token_id : int
        Token id that signals end of sequence. Once all sequences in the
        batch have emitted this token, generation stops early.
    pad_token_id : int
        Token id used to pad sequences that have already finished.
    do_sample : bool, optional
        If ``True``, sample from the probability distribution; if
        ``False``, use greedy (argmax) decoding. Default is ``True``.
    temperature : float, optional
        Softmax temperature applied before top-k filtering. Values ``< 1``
        sharpen the distribution; values ``> 1`` flatten it.
        Default is ``1.0``.
    top_k : int, optional
        When ``> 0``, restrict sampling to the top-k logits. ``0`` uses
        the full vocabulary. Default is ``0``.
    repetition_penalty : float, optional
        Multiplicative penalty applied to logits of tokens already present
        in the sequence. ``1.0`` disables the penalty. Default is ``1.0``.

    Returns
    -------
    Tensor
        Generated token ids of shape ``[batch, prompt_len + n_new_tokens]``
        as a ``LongTensor`` on the current device, identical on all ranks.
    """
    # Ensure all ranks reach this point before issuing any generation collectives.
    # This forces a clean synchronization fence after the trainer's eval phase, so
    # any pending scheduler ops on the same NCCL communicators are guaranteed to be
    # drained before our textgen ops start. Without this, our hand-rolled p2p can
    # get interleaved with leftover scheduler state and deadlock.
    distributed.barrier(group=self.pp_group)

    batch_size = input_ids.shape[0]
    generated_ids = input_ids.clone()
    done = torch.zeros(batch_size, dtype=torch.bool, device=self.dist.device)

    # Temporarily bypass torch.compile on the pipeline stage modules for the
    # duration of generation. Compiled modules (e.g. flex_attention + max-autotune)
    # are specialized on the training shapes and fail when called with the varying
    # shapes used during autoregressive decoding. Mirrors the single-rank workaround
    # in TextgenCallback.generate().
    assert self.pipeline_modules
    saved_compiled_calls: list = []
    for mod in self.pipeline_modules:
        compiled_call = getattr(mod, "_compiled_call_impl", None)
        saved_compiled_calls.append(compiled_call)
        if compiled_call is not None:
            mod._compiled_call_impl = mod._call_impl

    try:
        for _ in range(max_new_tokens):
            attention_mask = None
            if self.attention_mask_creator is not None:
                attention_mask = self.attention_mask_creator(
                    input_ids=generated_ids
                )

            logits = self._pipeline_step_for_generation(
                generated_ids, attention_mask
            )

            if self.pp_has_last_stage:
                assert logits is not None
                next_logits = logits[:, -1, :].float()  # [batch, vocab]

                # Sanitize logits before sampling. Unstable model output (common
                # early in training or with mixed precision) can produce NaN/Inf,
                # which makes torch.multinomial fail the "probability tensor
                # contains inf, nan or element < 0" assertion. Replace NaN with 0
                # and clamp ±Inf to finite values. Top-k masking below uses its
                # own float("-inf") after this step, so -inf handling here does
                # not interfere with top-k.
                next_logits = torch.nan_to_num(
                    next_logits, nan=0.0, posinf=1e4, neginf=-1e4
                )

                if repetition_penalty != 1.0:
                    for b in range(batch_size):
                        for tok in set(generated_ids[b].tolist()):
                            if next_logits[b, tok] > 0:
                                next_logits[b, tok] /= repetition_penalty
                            else:
                                next_logits[b, tok] *= repetition_penalty

                if temperature != 1.0:
                    next_logits = next_logits / temperature

                if top_k > 0:
                    top_values, _ = torch.topk(next_logits, top_k, dim=-1)
                    threshold = top_values[:, -1, None]
                    next_logits = next_logits.masked_fill(
                        next_logits < threshold, float("-inf")
                    )

                if do_sample:
                    probs = torch.softmax(next_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, 1).squeeze(1)
                else:
                    next_tokens = next_logits.argmax(dim=-1)
            else:
                next_tokens = torch.zeros(
                    batch_size, dtype=torch.long, device=self.dist.device
                )

            distributed.broadcast(next_tokens, src=self.pp_last_stage_rank)

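            # Pad sequences that finished on an earlier step, then mark newly
            # finished ones, so each sequence keeps its own EOS token.
            next_tokens = next_tokens.masked_fill(done, pad_token_id)
            done = done | (next_tokens == eos_token_id)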
            generated_ids = torch.cat(
                [generated_ids, next_tokens.unsqueeze(1)], dim=1
            )

            # Broadcast early-exit flag from last-stage rank to all ranks.
            if self.pp_has_last_stage:
                stop = torch.tensor(
                    [int(done.all())], dtype=torch.long, device=self.dist.device
                )
            else:
                stop = torch.zeros(1, dtype=torch.long, device=self.dist.device)
            distributed.broadcast(stop, src=self.pp_last_stage_rank)
            if stop.item():
                break
    finally:
        # Restore compiled forwards for training, even if generation raised.
        for mod, compiled_call in zip(self.pipeline_modules, saved_compiled_calls):
            if compiled_call is not None:
                mod._compiled_call_impl = compiled_call

    return generated_ids
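The save, bypass, and restore sequence around the decoding loop could be factored into a reusable context manager. This is a minimal sketch, and bypass_compiled_call is a hypothetical helper. It relies on _compiled_call_impl and _call_impl, which are private torch.nn.Module attributes (the former is populated by Module.compile()), so it depends on PyTorch internals that may change between releases. The demo uses a stand-in object so it runs without torch.

```python
from contextlib import contextmanager
from typing import Iterable, Iterator


@contextmanager
def bypass_compiled_call(modules: Iterable[object]) -> Iterator[None]:
    """Route calls through the eager _call_impl for the duration of the block."""
    mods = list(modules)
    saved = [getattr(m, "_compiled_call_impl", None) for m in mods]
    for m, compiled in zip(mods, saved):
        if compiled is not None:
            m._compiled_call_impl = m._call_impl  # eager path
    try:
        yield
    finally:
        # Restore compiled forwards even if the body raised.
        for m, compiled in zip(mods, saved):
            if compiled is not None:
                m._compiled_call_impl = compiled


# Stand-in object for illustration; real code would pass nn.Module stages.
class FakeStage:
    def __init__(self, compiled):
        self._compiled_call_impl = compiled

    def _call_impl(self):
        return "eager"


stage = FakeStage(compiled=lambda: "compiled")
with bypass_compiled_call([stage]):
    print(stage._compiled_call_impl())  # prints eager
print(stage._compiled_call_impl())  # prints compiled
```

Putting the restore step in a finally block has the same effect as the try/finally in pipeline_generate: training resumes with compiled stages even if generation fails.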

get_state_components()

Return state components for pipeline parallel training.

Because the model is split across ranks, each rank saves only its own stage parameters. The sharing patterns reflect this:

  • "model" — PER_RANK (each rank holds different stages), required.
  • "optimizer" — PER_RANK (optimises different parameters), optional.
  • "scheduler" — REPLICATED (same LR schedule on all ranks), optional.
  • "trainer" — REPLICATED (same global step on all ranks), optional.
  • "dataset" — GLOBAL (DataloaderDispatcher with dp_mesh_dim=None; rank 0 loads and broadcasts), optional.
  • "rng" — PER_RANK (each stage may have different dropout), optional.

Returns:

| Type | Description |
| --- | --- |
| list of StateComponent | All checkpointable state components with their sharing patterns. |

Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
@override
def get_state_components(self) -> List[StateComponent]:
    """Return state components for pipeline parallel training.

    Because the model is split across ranks, each rank saves only its own
    stage parameters. The sharing patterns reflect this:

    * ``"model"`` — PER_RANK (each rank holds different stages), required.
    * ``"optimizer"`` — PER_RANK (optimises different parameters), optional.
    * ``"scheduler"`` — REPLICATED (same LR schedule on all ranks), optional.
    * ``"trainer"`` — REPLICATED (same global step on all ranks), optional.
    * ``"dataset"`` — GLOBAL (``DataloaderDispatcher`` with
      ``dp_mesh_dim=None``; rank 0 loads and broadcasts), optional.
    * ``"rng"`` — PER_RANK (each stage may have different dropout), optional.

    Returns
    -------
    list of StateComponent
        All checkpointable state components with their sharing patterns.
    """
    components = []

    # Model - REQUIRED, PER_RANK (different pipeline stages per rank)
    # pipeline_modules contains the stage modules assigned to this rank
    if self.pipeline_modules:
        components.append(
            StateComponent(
                key="model",
                stateful=cast(Stateful, self.pipeline_modules),
                sharing_pattern=SharingPattern.PER_RANK,
                required=True,  # Model is always required
            )
        )

    # Optimizer - optional, PER_RANK (optimizes different parameters per rank)
    # Each rank's optimizer only contains parameters for its pipeline stages
    if self.optimizer:
        components.append(
            StateComponent(
                key="optimizer",
                stateful=self.optimizer,
                sharing_pattern=SharingPattern.PER_RANK,
                required=False,
            )
        )

    # LR Scheduler - optional, REPLICATED (same schedule across all ranks)
    # All ranks follow the same learning rate schedule
    if self.lr_scheduler:
        components.append(
            StateComponent(
                key="scheduler",
                stateful=self.lr_scheduler,
                sharing_pattern=SharingPattern.REPLICATED,
                required=False,
            )
        )

    # Trainer state - optional, REPLICATED (same global step across all ranks)
    # Training progress is synchronized across all pipeline stages
    components.append(
        StateComponent(
            key="trainer",
            stateful=self,
            sharing_pattern=SharingPattern.REPLICATED,
            required=False,
        )
    )

    # Dataset state - optional, GLOBAL (DataloaderDispatcher with pure MP mode)
    # Rank 0 loads data and broadcasts to all ranks (dp_mesh_dim=None)
    if hasattr(self.train_dataloader, "state_dict"):
        components.append(
            StateComponent(
                key="dataset",
                stateful=cast(Stateful, self.train_dataloader),
                sharing_pattern=self._get_dataset_sharing_pattern(),
                required=False,
            )
        )

    # RNG state - optional, PER_RANK (each rank needs different random numbers)
    # Different dropout patterns, etc. for different pipeline stages
    components.append(
        StateComponent(
            key="rng",
            stateful=RNGState(),
            sharing_pattern=SharingPattern.PER_RANK,
            required=False,
        )
    )

    return components
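To make the sharing patterns concrete, the sketch below counts how many distinct copies of each component exist across a pure pipeline of world_size ranks. It is a hypothetical model. Sharing, ComponentSpec, and distinct_copies are illustrative names that only mirror the docstring. The sketch assumes the dataset component resolves to GLOBAL (the code actually asks _get_dataset_sharing_pattern()), and it says nothing about how the checkpoint manager physically writes files.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List


class Sharing(Enum):
    # Hypothetical mirror of the SharingPattern values used above.
    PER_RANK = auto()
    REPLICATED = auto()
    GLOBAL = auto()


@dataclass(frozen=True)
class ComponentSpec:
    key: str
    sharing: Sharing
    required: bool


# The list from the docstring, assuming the dataset pattern is GLOBAL.
PIPELINE_COMPONENTS: List[ComponentSpec] = [
    ComponentSpec("model", Sharing.PER_RANK, True),
    ComponentSpec("optimizer", Sharing.PER_RANK, False),
    ComponentSpec("scheduler", Sharing.REPLICATED, False),
    ComponentSpec("trainer", Sharing.REPLICATED, False),
    ComponentSpec("dataset", Sharing.GLOBAL, False),
    ComponentSpec("rng", Sharing.PER_RANK, False),
]


def distinct_copies(sharing: Sharing, world_size: int) -> int:
    """Number of distinct states that exist across the pipeline ranks."""
    # PER_RANK: each rank holds different state (different stages, dropout RNG).
    # REPLICATED: identical on every rank, so one copy describes all of them.
    # GLOBAL: a single state owned by rank 0, which loads and broadcasts data.
    return world_size if sharing is Sharing.PER_RANK else 1


counts: Dict[str, int] = {
    c.key: distinct_copies(c.sharing, world_size=4) for c in PIPELINE_COMPONENTS
}
print(counts)
# {'model': 4, 'optimizer': 4, 'scheduler': 1, 'trainer': 1, 'dataset': 1, 'rng': 4}
```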

get_process_groups()

Return named process groups for checkpoint coordination.

The checkpoint manager uses this mapping to implement PER_GROUP sharing patterns (e.g. saving one copy per pipeline-parallel group).

Returns:

| Type | Description |
| --- | --- |
| dict of str to ProcessGroup | {"pp_group": self.pp_group} for pure pipeline parallelism. |

Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
@override
def get_process_groups(self) -> Dict[str, Any]:
    """Return named process groups for checkpoint coordination.

    The checkpoint manager uses this mapping to implement ``PER_GROUP``
    sharing patterns (e.g. saving one copy per pipeline-parallel group).

    Returns
    -------
    dict of str to ProcessGroup
        ``{"pp_group": self.pp_group}`` for pure pipeline parallelism.
    """
    return {
        "pp_group": self.pp_group,
    }

forgather.ml.trainer.pipeline.pipeline_trainer.PipelineTrainingArguments dataclass

Bases: TrainingArguments

Training arguments for pipeline parallel training.

Pipeline parallelism partitions a model across multiple GPUs, each handling one or more sequential stages. Input batches are split into microbatches that flow through the stages, allowing overlapped computation to keep all GPUs busy.

See the PyTorch pipeline parallelism documentation for background: https://docs.pytorch.org/docs/stable/distributed.pipelining.html

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_microbatches | int | Number of microbatches to split each batch into. More microbatches improve pipeline efficiency (fewer bubbles) but increase memory usage. The batch size must be evenly divisible by n_microbatches. Typical values are 4–16 depending on pipeline depth and memory constraints. | 4 |
| stages_per_rank | int | Number of pipeline stages hosted on each GPU. Most schedulers use 1. Multi-stage schedulers (e.g. ScheduleZBVZeroBubble) assign multiple stages per rank to reduce pipeline bubbles. Only set > 1 together with is_multistage=True. | 1 |
| pp_stage_type | str | Stage-to-rank assignment pattern. "loop" uses round-robin (e.g. 4 stages on 2 ranks: rank0=[0,2], rank1=[1,3]). "v" uses the V-pattern required by V-shaped zero-bubble schedulers such as ScheduleZBVZeroBubble (e.g. 4 stages on 2 ranks: rank0=[0,3], rank1=[1,2]; see https://arxiv.org/pdf/2401.10241). | 'loop' |
| is_multistage | bool | Set True when the scheduler inherits from PipelineScheduleMulti (e.g. ScheduleZBVZeroBubble). Leave False for single-stage schedulers such as ScheduleGPipe. | False |
| debug_pipeline | bool | Enable debug-level logging for the pipeline scheduler. Internal development flag. | False |
| debug_split_model | bool | Log pipeline module details after splitting. Internal development flag. | False |
| debug_model_params | bool | Log all parameter and buffer devices/dtypes after model construction. Internal development flag. | False |
| debug_model_init | bool | Log every send/recv during parameter distribution from rank 0. Internal development flag. | False |
Notes

model_splitter is passed to PipelineTrainer.__init__() rather than stored here because it is a callable, not a primitive serialisable type.
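To see why more microbatches reduce bubbles, consider the standard idealized estimate for a GPipe-style schedule with p stages and m microbatches. In that estimate, each device idles for a fraction (p - 1) / (m + p - 1) of the step. The following is a minimal sketch. It assumes equal per-stage compute time and ignores communication, and it does not apply to zero-bubble schedules. microbatch_size and gpipe_bubble_fraction are hypothetical helper names.

```python
def microbatch_size(batch_size: int, n_microbatches: int) -> int:
    """Per-microbatch size; the batch must split evenly, as the trainer asserts."""
    if batch_size % n_microbatches != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by {n_microbatches}"
        )
    return batch_size // n_microbatches


def gpipe_bubble_fraction(n_stages: int, n_microbatches: int) -> float:
    """Idealized idle fraction (p - 1) / (m + p - 1) of a GPipe-style step."""
    return (n_stages - 1) / (n_microbatches + n_stages - 1)


print(microbatch_size(32, 4))  # prints 8
for m in (4, 8, 16):
    print(m, round(gpipe_bubble_fraction(n_stages=4, n_microbatches=m), 3))
# 4 0.429
# 8 0.273
# 16 0.158
```

With four stages, going from 4 to 16 microbatches cuts the idealized bubble from about 43% to about 16%. The trade-off is the higher memory use described for n_microbatches.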

Source code in src/forgather/ml/trainer/pipeline/pipeline_trainer.py
@dataclass(kw_only=True)
class PipelineTrainingArguments(TrainingArguments):
    """Training arguments for pipeline parallel training.

    Pipeline parallelism partitions a model across multiple GPUs, each handling
    one or more sequential stages. Input batches are split into microbatches that
    flow through the stages, allowing overlapped computation to keep all GPUs busy.

    See the PyTorch pipeline parallelism documentation for background:
    https://docs.pytorch.org/docs/stable/distributed.pipelining.html

    Parameters
    ----------
    n_microbatches : int, optional
        Number of microbatches to split each batch into. More microbatches
        improve pipeline efficiency (fewer bubbles) but increase memory usage.
        The batch size must be evenly divisible by ``n_microbatches``. Typical
        values are 4–16 depending on pipeline depth and memory constraints.
        Default is ``4``.
    stages_per_rank : int, optional
        Number of pipeline stages hosted on each GPU. Most schedulers use
        ``1``. Multi-stage schedulers (e.g. ``ScheduleZBVZeroBubble``) assign
        multiple stages per rank to reduce pipeline bubbles. Only set ``> 1``
        together with ``is_multistage=True``. Default is ``1``.
    pp_stage_type : str, optional
        Stage-to-rank assignment pattern. ``"loop"`` uses round-robin (e.g. 4
        stages on 2 ranks: rank0=[0,2], rank1=[1,3]). ``"v"`` uses the
        V-pattern required by ZeroBubble schedulers (see
        https://arxiv.org/pdf/2401.10241). Default is ``"loop"``.
    is_multistage : bool, optional
        Set ``True`` when the scheduler inherits from
        ``PipelineScheduleMulti`` (e.g. ``ScheduleZBVZeroBubble``). Leave
        ``False`` for single-stage schedulers such as ``ScheduleGPipe``.
        Default is ``False``.
    debug_pipeline : bool, optional
        Enable debug-level logging for the pipeline scheduler. Internal
        development flag. Default is ``False``.
    debug_split_model : bool, optional
        Log pipeline module details after splitting. Internal development
        flag. Default is ``False``.
    debug_model_params : bool, optional
        Log all parameter and buffer devices/dtypes after model construction.
        Internal development flag. Default is ``False``.
    debug_model_init : bool, optional
        Log every send/recv during parameter distribution from rank 0.
        Internal development flag. Default is ``False``.

    Notes
    -----
    ``model_splitter`` is passed to ``PipelineTrainer.__init__()`` rather than
    stored here because it is a callable, not a primitive serialisable type.
    """

    debug_pipeline: bool = False
    debug_split_model: bool = False
    debug_model_params: bool = False
    debug_model_init: bool = False
    n_microbatches: int = 4
    stages_per_rank: int = 1
    pp_stage_type: str = "loop"
    is_multistage: bool = False
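The two pp_stage_type layouts can be reproduced with a few lines of Python. This is a minimal sketch, and assign_stages is a hypothetical helper, not Forgather's implementation. The "v" branch assumes exactly two stages per rank, which is the V shape used by the ZB-V schedule in the linked paper. The first half of the stages runs down the ranks and the second half comes back up.

```python
from typing import Dict, List


def assign_stages(pp_size: int, stages_per_rank: int, style: str = "loop") -> Dict[int, List[int]]:
    """Map each pipeline rank to the stage indices it hosts."""
    n_stages = pp_size * stages_per_rank
    if style == "loop":
        # Round-robin: stage s lives on rank s % pp_size.
        return {r: list(range(r, n_stages, pp_size)) for r in range(pp_size)}
    if style == "v":
        if stages_per_rank != 2:
            raise ValueError("this sketch only handles the V layout with 2 stages per rank")
        # Rank r hosts stage r on the way down and stage n_stages - 1 - r on the way up.
        return {r: [r, n_stages - 1 - r] for r in range(pp_size)}
    raise ValueError(f"unknown style {style!r}")


print(assign_stages(2, 2, "loop"))  # {0: [0, 2], 1: [1, 3]}
print(assign_stages(2, 2, "v"))     # {0: [0, 3], 1: [1, 2]}
```

In the V layout, rank 0 holds both the first and the last stage. This is what lets a zero-bubble schedule turn the end of the forward pass straight into the start of the backward pass on the same device.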