# A light-weight Trainer with an API close enough to "transformers.Trainer"
# to act as a stand-in for basic use-cases.
import gc
import logging
import os
import sys
import time
from collections.abc import Sized
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Dict,
    Generic,
    Iterator,
    Optional,
    Protocol,
    Tuple,
    Type,
    TypeGuard,
    TypeVar,
    cast,
    override,
)

import torch
import torchdata.nodes as tn
from dacite import from_dict
from torch import Tensor
from torch import distributed as dist
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader

from forgather.ml.construct import torch_dtype
from forgather.ml.datasets import sync_dataset_state_from_dataloader
from forgather.ml.utils import default_dtype

from ..distributed import DistributedEnvInterface, prefix_logger_rank
from ..loss import RescaleLoss
from ..no_init_weights import no_init_weights
from ..optim.multiopt import Multiopt
from ..optim.opt_utils import OptimGroupMap, build_optimizer_buckets
from ..sharded_checkpoint import (
    create_sharing_metadata,
    find_latest_checkpoint,
    next_checkpoint_path,
    retie_parameters,
    save_checkpoint_metrics,
)
from .amp import AMPContext
from .base_trainer import BaseTrainer, BaseTrainingArguments, logits_from_outputs
from .callbacks.default_callbacks import InfoCallback, ProgressCallback
from .checkpoint_manager import CheckpointConfig, CheckpointManager
from .periodic_function import PeriodicFunction
from .trainer_types import (
    BaseDataset,
    CheckpointInterface,
    EnableCheckpointFnT,
    FusedLossFactoryT,
    IntervalStrategy,
    LossFunctionT,
    LRSchedulerFactoryT,
    OptimizerFactoryT,
    OptimizerT,
    TrainerCallback,
)
from .trainer_types import TrainerState as BaseTrainerState
from .trainer_types import (
    TrainOutput,
)

logger = logging.getLogger(__name__)
# NOTE(review): prefix_logger_rank (from ..distributed) presumably tags this
# module's log records with the distributed rank -- confirm in that helper.
prefix_logger_rank(logger)


# Type checking protocols
class ModelWithCheckpointing(Protocol):
    """Structural type for models exposing ``gradient_checkpointing_enable``."""

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Any):
        ...


class HasBatchSize(Protocol):
    """Structural type for objects carrying an integer ``batch_size`` attribute."""

    batch_size: int


class HasMainInputName(Protocol):
    """Structural type for objects carrying a string ``main_input_name`` attribute."""

    main_input_name: str


def has_gradient_checkpointing_enable(obj: object) -> TypeGuard[ModelWithCheckpointing]:
    """Return True when ``obj`` has a ``gradient_checkpointing_enable`` attribute."""
    try:
        getattr(obj, "gradient_checkpointing_enable")
    except AttributeError:
        return False
    return True


def has_batch_size(obj: object) -> TypeGuard[HasBatchSize]:
    """Return True when ``obj`` has a ``batch_size`` attribute."""
    try:
        getattr(obj, "batch_size")
    except AttributeError:
        return False
    return True


def has_main_input_name(obj: object) -> TypeGuard[HasMainInputName]:
    """Return True when ``obj`` has a ``main_input_name`` attribute."""
    try:
        getattr(obj, "main_input_name")
    except AttributeError:
        return False
    return True


def enable_hf_activation_checkpointing(
    rank, module, gradient_checkpointing_kwargs=None
):
    """
    Enable activation checkpointing via the Huggingface protocol.

    Calls ``module.gradient_checkpointing_enable(gradient_checkpointing_kwargs)``
    when the module exposes that method; otherwise logs a warning and returns.

    Parameters
    ----------
    rank : int
        Process rank (presumably an int). Only rank 0 logs the "enabling"
        message; the "unsupported" warning is logged on every rank.
    module : object
        The model to configure.
    gradient_checkpointing_kwargs : dict, optional
        Forwarded to ``gradient_checkpointing_enable``. Defaults to
        ``{"use_reentrant": False}``.
    """
    if not has_gradient_checkpointing_enable(module):
        # f-string so the actual rank appears in the message.
        logger.warning(
            f"rank{rank}: Gradient HF checkpointing requested, but model does not support it!"
        )
        return

    if rank == 0:
        logger.info("rank0: Enabling HF gradient checkpointing")

    if gradient_checkpointing_kwargs is None:
        # Non-reentrant checkpointing, as recommended by torch.utils.checkpoint.
        gradient_checkpointing_kwargs = dict(
            use_reentrant=False,
        )
    module.gradient_checkpointing_enable(gradient_checkpointing_kwargs)


@dataclass(kw_only=True)
class TrainerState(BaseTrainerState):
    """
    Trainer state extended with fields that are not part of the
    HuggingFace ``TrainerState``.
    """

    # --- Not in HF TrainerState ---
    # The total number of processes used for training
    num_processes: int = 1  # Non-standard
    # The number of batches (training steps) in an epoch.
    # NOTE(review): 0 presumably means "unknown / not yet computed" -- confirm.
    epoch_train_steps: int = 0


@dataclass(kw_only=True)
class TrainingArguments(BaseTrainingArguments):
    """
    Training arguments specific to the simple Trainer implementation.

    Extends BaseTrainingArguments with memory optimization and model construction options.
    Maintains compatibility with HuggingFace Trainer API where possible.
    """

    # Fraction (of total GPU memory) above which cached CUDA memory is released
    # and Python GC may run.
    # NOTE(review): presumably forwarded to maybe_cleanup_memory(), which compares
    # *allocated* (not reserved) memory against this ratio -- confirm at the call site.
    # If OOM from fragmentation, lower ratio
    gc_threshold: float = 0.5

    # How the model is constructed and placed on args.device (see Trainer._prepare_model()):
    # default: Construct model on CPU and move to device. Safest option, works in all cases.
    #          Uses no_init_weights() context when loading checkpoint to skip initialization.
    # device:  Construct model directly on device with initialization. Faster than default
    #          but may fail when model needs sharding across devices. Use when checkpoint
    #          doesn't save all buffers (e.g., RoPE).
    # meta:    Construct on meta device (no memory backing) and materialize as empty tensors
    #          on target device. Fastest option but requires loading checkpoint since model
    #          is uninitialized. May have issues with buffers not saved in checkpoint.
    #          When no checkpoint is being resumed, the trainer logs a warning and falls
    #          back to "default".
    construct_model_on: str = "default"  # "default" | "meta" | "device"

    # Activation memory budget, assigned to
    # torch._functorch.config.activation_memory_budget in Trainer._prepare().
    # None leaves the PyTorch default untouched.
    # https://pytorch.org/blog/activation-checkpointing-techniques/
    # Requires "torch_compile = True" option
    activation_memory_budget: float | None = None

    # Combine gradient calculation with optimizer step, to save memory.
    # As each gradient is computed during backward(), it's immediately applied by the
    # optimizer and freed. Incompatible with max_grad_norm, gradient_accumulation_steps > 1,
    # and mixed_precision="fp16" (GradScaler); use "bf16" instead.
    # Greatest memory savings when combined with gradient checkpointing.
    # https://docs.pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
    fuse_optim_with_backward: bool = False

    # The step at which to start collecting speed metrics
    # We default to 1, to remove the effects from torch.compile().
    # Set this to 0 to include all steps or > 0 for compile warmup time.
    speed_metrics_start_step: int = 1

    # If the train dataset has a `set_epoch(epoch: int)` method, call it at the start of each epoch.
    # Automatically disabled (with a warning) when training for > 1 epoch on a dataset
    # without `set_epoch` (see Trainer._init_dataloaders()).
    set_dataset_epoch: bool = True

    # Debug grouped optimizer assignments
    debug_optimizer_groups: bool = False


@contextmanager
def set_train(model: torch.nn.Module, mode: bool):
    """
    Temporarily put ``model`` into train (``mode=True``) or eval (``mode=False``)
    mode, restoring whatever mode it had on entry when the block exits,
    including when the block raises.
    """
    was_training = model.training
    try:
        model.train(mode)
        yield
    finally:
        # Always restore the caller's mode.
        model.train(was_training)


def maybe_cleanup_memory(alloc_threshold):
    """
    Release unused CUDA cached memory when allocation pressure is high.

    The CUDA caching allocator reserves blocks for reuse. This function
    releases those cached-but-unused blocks only when two conditions hold
    on the current CUDA device:

    1. Allocated memory exceeds ``alloc_threshold`` fraction of total GPU memory.
    2. There is meaningful reclaimable memory (reserved - allocated > 10% of total).

    Without condition 2, ``empty_cache()`` would call ``cudaDeviceSynchronize``
    for nothing -- all reserved memory is actually in use.

    Python ``gc.collect()`` runs only when ``empty_cache`` actually released
    memory, since reference-cycle collection is rarely needed in training
    loops and adds ~100-300 ms of overhead.

    Does nothing when CUDA is unavailable.

    Parameters
    ----------
    alloc_threshold : float
        Ratio of allocated to total GPU memory (0.0–1.0).
        Cleanup is considered when usage exceeds this.
    """
    if not torch.cuda.is_available():
        return

    # Query every statistic for the same (current) device so the ratios are
    # consistent on multi-GPU hosts where the current device is not GPU 0.
    device = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(device)
    max_memory = torch.cuda.get_device_properties(device).total_memory

    if allocated / max_memory <= alloc_threshold:
        return

    # Only release cache when there is a meaningful amount of reclaimable
    # memory (reserved blocks that are not currently allocated).
    reserved = torch.cuda.memory_reserved(device)
    reclaimable = reserved - allocated
    if reclaimable / max_memory < 0.10:
        return

    torch.cuda.empty_cache()

    # Run Python GC only if empty_cache actually freed something -- meaning
    # there might be dead Python objects whose reference cycles prevented
    # their GPU tensors from being deallocated.
    if torch.cuda.memory_reserved(device) < reserved:
        gc.collect()


def optimizer_hook(optimizer, total_grad_squared, name, parameter):
    """
    Per-parameter hook that applies the optimizer during backward().

    Registered via register_post_accumulate_grad_hook() when
    fuse_optim_with_backward=True: once a parameter's gradient has been
    accumulated, the optimizer steps immediately and then zeroes the
    gradient, so the gradient does not have to be kept until the end of
    backward. This reduces peak memory, especially alongside gradient
    checkpointing.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer that owns ``parameter``.
    total_grad_squared : Tensor or None
        Running float32 sum of squared gradient entries, updated in place
        (for computing the global gradient norm). Skipped when None.
    name : str
        Parameter name; unused, kept for debugging / hook binding.
    parameter : torch.nn.Parameter
        Parameter whose gradient has just been accumulated.
    """
    if total_grad_squared is not None:
        squared_sum = parameter.grad.square().sum()
        # In-place add so the caller's accumulator tensor sees the update.
        total_grad_squared += squared_sum.to(dtype=torch.float32)
    optimizer.step()
    optimizer.zero_grad()


# Type variable for the concrete TrainingArguments (sub)class a Trainer is
# parameterized with, so static type checkers infer the precise type of
# ``Trainer.args`` (and of ``args`` in subclasses).
TTrainingArguments = TypeVar("TTrainingArguments", bound=TrainingArguments)


class Trainer(BaseTrainer[TTrainingArguments], Generic[TTrainingArguments]):
    """
    A lightweight, single-device trainer with API close to transformers.Trainer.

    This trainer provides a simplified, more comprehensible implementation of the
    HuggingFace Trainer, intended as a drop-in replacement for basic use cases.

    Key features:
    - Compatible with HF Trainer API for basic training workflows
    - Memory optimizations: fused loss, fused optimizer/backward, activation checkpointing
    - Flexible model construction: default/meta/device modes for different memory/speed tradeoffs
    - Full checkpoint management: saves/restores model, optimizer, scheduler, dataset state
    - Best model tracking via load_best_model_at_end

    For distributed training, see AccelTrainer (data parallel via Accelerate) and
    PipelineTrainer (pipeline parallelism). This class asserts that
    ``distributed_env.world_size == 1`` (see ``_init_distributed``).

    Basic usage:
        trainer = Trainer(
            model=model,
            args=TrainingArguments(...),
            distributed_env=distributed_env,  # required, keyword-only
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizer_factory=optimizer_factory,
            lr_scheduler_factory=lr_scheduler_factory,
        )
        trainer.train()
    """

    args: TTrainingArguments
    # Supplies world_size and the default device (see _init_distributed / _init_device).
    dist: DistributedEnvInterface
    optimizer_factory: OptimizerFactoryT | None
    lr_scheduler_factory: LRSchedulerFactoryT | None
    enable_activation_checkpoint_fn: EnableCheckpointFnT | None
    fused_loss_factory: FusedLossFactoryT | None
    optimizer_groups: OptimGroupMap | None

    # Initialized in _init_dataloaders().
    max_steps: int
    epoch_train_steps: int
    do_train: bool
    do_eval: bool
    # Initialized to False in __init__.
    use_fused_loss: bool
    # NOTE(review): not assigned in the visible code; confirm where it is set.
    gradient_accumulation_step: int

    @classmethod
    def default_callbacks(cls):
        return [ProgressCallback(), InfoCallback()]

    def __init__(
        self,
        *,
        args: TTrainingArguments | dict,
        distributed_env: DistributedEnvInterface,
        optimizer_factory: Optional[OptimizerFactoryT] = None,
        # Alternative, for compatibility with transformers.Trainer
        optimizer_cls_and_kwargs: Optional[
            Tuple[Type[OptimizerT], Dict[str, Any]]
        ] = None,
        lr_scheduler_factory: Optional[LRSchedulerFactoryT] = None,
        enable_activation_checkpoint_fn: Optional[
            EnableCheckpointFnT
        ] = enable_hf_activation_checkpointing,
        fused_loss_factory: Optional[FusedLossFactoryT] = None,
        optimizer_groups: Optional[OptimGroupMap] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        args : TrainingArguments or dict
            Training configuration. Accepts a ``TrainingArguments`` instance or a
            plain dict (converted automatically). See ``TrainingArguments`` for all options.
        distributed_env : DistributedEnvInterface
            Distributed environment object providing rank/device information.
            Use ``DistributedEnvironment`` for real training or ``StaticDistributedEnvironment``
            for single-process runs. The trainer asserts ``world_size == 1`` for non-distributed
            subclasses.
        optimizer_factory : callable, optional
            Callable that accepts ``model.named_parameters()`` and returns an optimizer.
            If not provided and ``optimizer_cls_and_kwargs`` is also absent, defaults to
            AdamW with parameters from ``args``.
        optimizer_cls_and_kwargs : tuple, optional
            HuggingFace Trainer-compatible alternative to ``optimizer_factory``.
            A ``(optimizer_class, kwargs_dict)`` pair. Ignored when ``optimizer_factory``
            is supplied.
        lr_scheduler_factory : callable, optional
            Callable that accepts the optimizer and returns an LR scheduler.
            If not provided and ``args.lr_scheduler_type`` is set, falls back to
            ``transformers.get_scheduler``.
        enable_activation_checkpoint_fn : callable, optional
            Called as ``fn(rank, model)`` to enable activation (gradient) checkpointing.
            Defaults to ``enable_hf_activation_checkpointing``. Set to ``None`` to
            disable activation checkpointing even when ``args.gradient_checkpointing=True``.
        fused_loss_factory : callable, optional
            If provided, enables fused logits-loss computation. Called with the model's
            output embedding layer and returns a loss function. Requires the model to
            support ``get_output_embeddings()`` and ``return_hidden_states``.
        optimizer_groups : OptimGroupMap, optional
            Parameter group configuration for the optimizer. Allows different
            hyperparameters (lr, weight_decay) for different parameter subsets.
        **kwargs
            Passed to ``BaseTrainer``: ``model``, ``model_init``, ``train_dataset``,
            ``eval_dataset``, ``loss_fn``, ``data_collator``, ``processing_class``,
            ``callbacks``, ``optimizer``, ``lr_scheduler``. When no ``data_collator``
            is set, ``torch.utils.data.default_collate`` is used.

        Raises
        ------
        AssertionError
            If neither ``model`` nor ``model_init`` is given; if
            ``fuse_optim_with_backward`` is combined with a non-None ``max_grad_norm``,
            ``gradient_accumulation_steps > 1`` or ``mixed_precision="fp16"``; or if
            ``gradient_accumulation_steps > 1`` without a ``loss_fn``. These are
            ``assert`` checks and are skipped under ``python -O``.
        """
        if isinstance(args, dict):
            args = cast(TTrainingArguments, from_dict(TrainingArguments, args))
        super().__init__(args=args, **kwargs)

        # HF Trainer compatibility: when no factory is given, build one from
        # optimizer_cls_and_kwargs, or fall back to AdamW configured from args.
        if not optimizer_factory:
            if not optimizer_cls_and_kwargs:
                optimizer_factory = partial(  # type: ignore[assignment]
                    torch.optim.AdamW,
                    lr=args.learning_rate,
                    betas=(args.adam_beta1, args.adam_beta2),
                    weight_decay=args.weight_decay,
                    eps=args.adam_epsilon,
                )
            else:
                optimizer_factory = partial(
                    optimizer_cls_and_kwargs[0], **optimizer_cls_and_kwargs[1]
                )
        self.optimizer_groups = optimizer_groups
        self.dist = distributed_env
        self.optimizer_factory = optimizer_factory
        self.lr_scheduler_factory = lr_scheduler_factory
        self.enable_activation_checkpoint_fn = enable_activation_checkpoint_fn
        self.fused_loss_factory = fused_loss_factory
        self.use_fused_loss = False
        assert (self.model is not None) or (
            self.model_init is not None
        ), "Either a model or a model constructor must be specified."

        assert (
            self.args.max_grad_norm is None or not self.args.fuse_optim_with_backward
        ), "max_grad_norm is incompatible with fuse_optim_with_backward"

        # f-string so the offending value is actually shown in the message.
        assert (
            self.args.gradient_accumulation_steps == 1
            or not self.args.fuse_optim_with_backward
        ), f"gradient_accumulation_steps={self.args.gradient_accumulation_steps} is incompatible with fuse_optim_with_backward"

        assert (
            self.args.mixed_precision != "fp16"
            or not self.args.fuse_optim_with_backward
        ), (
            "fp16 mixed precision with GradScaler is incompatible with fuse_optim_with_backward. "
            "Use mixed_precision='bf16' instead (no GradScaler needed)."
        )

        assert (
            self.loss_fn or self.args.gradient_accumulation_steps == 1
        ), f"gradient_accumulation_steps [{self.args.gradient_accumulation_steps}] > 1 requires loss_fn"

        if self.data_collator is None:
            # Fall back to PyTorch's default collate function.
            self.data_collator = torch.utils.data.default_collate

    def _compute_flops_per_token(self) -> float:
        """
        Estimate FLOPs per token for forward + backward pass.

        Uses the standard C = 6N approximation where N = non-embedding parameters.
        Counts each multiply-accumulate as 2 FLOPs, consistent with hardware specs.
        See ModelParameterStats.flops_per_token for references.
        """
        assert self.model is not None
        from forgather.ml.utils import count_parameters

        return count_parameters(self.model).flops_per_token

    def _count_batch_tokens(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """
        Count non-padding tokens using labels tensor.

        Uses labels as the primary source since the cross-entropy ignore_index (-100)
        marks padding and special tokens that should not be counted.

        Returns a GPU tensor to avoid forcing a GPU-CPU synchronization via .item(),
        which would stall the CPU waiting for any pending compiled graph execution.

        Parameters
        ----------
        input_dict : dict
            Batch input dictionary (unused in base, available for overrides).
        labels : Tensor
            Target labels with -100 (ignore_index) marking padding/special positions.

        Returns
        -------
        Tensor
            Count of non-ignored tokens in the batch.
        """
        # Labels have -100 for padding/special tokens; count only real target tokens
        return (labels != -100).sum()

    def _distributed_tokens(self, tokens: Tensor) -> Tensor:
        """
        Aggregate token counts across processes.

        Base implementation for single-device training.
        Distributed trainers override to sum across ranks.

        Parameters
        ----------
        tokens : Tensor
            Token count tensor from the current process.

        Returns
        -------
        Tensor
            Token count tensor (single-device) or aggregated across ranks (distributed).
        """
        return tokens

    def _distributed_peak_mem(self, local_peak: int) -> list[int]:
        """
        Gather per-rank peak CUDA memory into a list on every rank.

        Base implementation is single-device and returns a one-element list.
        Distributed trainers override to all_gather across ranks so logging
        reflects per-rank high-water marks.

        Parameters
        ----------
        local_peak : int
            ``torch.cuda.max_memory_allocated`` for this rank, in bytes.

        Returns
        -------
        list of int
            Peak bytes indexed by rank (length ``world_size``, or 1 for single-device).
        """
        return [int(local_peak)]

    def _init_distributed(self):
        """
        Subclasses are expected to override, if they support distributed training.
        If distributed training is supported, set the following variables:
          self.is_local_process_zero
          self.is_world_process_zero
          self.num_processes
        """
        assert (
            self.dist.world_size == 1
        ), "'Trainer' does not support distributed training. See subclasses for implementations which do support it."

    def _init_device(self):
        """Update / init trainer's device"""
        # If unspecified, set a default device
        if self.args.device is None:
            self.args.device = self.dist.device
        # Override for debug.
        if self.args.use_cpu:
            self.args.device = "cpu"

    def _get_dataloader(self, dataset, batch_size):
        if not isinstance(dataset, tn.BaseNode | DataLoader):
            dataloader_kwargs = {
                "batch_size": batch_size,
                "collate_fn": self.data_collator,
                "drop_last": self.args.dataloader_drop_last,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "prefetch_factor": self.args.dataloader_prefetch_factor,
                "persistent_workers": self.args.dataloader_persistent_workers,
            }

            # Use StatefulDataLoader for datasets with state if available and requested
            return StatefulDataLoader(dataset, **dataloader_kwargs)
        else:
            return dataset

    @override
    def _prepare(
        self, train_dataset: Optional[BaseDataset], eval_dataset: Optional[BaseDataset]
    ) -> None:
        """
        Prepare for training and/or evaluation
        """
        self._init_distributed()
        self._init_device()

        # Initialize AMP context for mixed precision training
        device_type = (
            self.args.device
            if isinstance(self.args.device, str)
            else str(self.args.device)
        )
        if "cuda" in device_type:
            device_type = "cuda"
        self.amp_context = AMPContext(
            mixed_precision=self.args.mixed_precision,
            device_type=device_type,
        )
        if self.amp_context.enabled:
            logger.info(
                f"Mixed precision training enabled: {self.args.mixed_precision}"
            )

        # Set the random seed
        if self.args.seed != -1:
            import random

            torch.manual_seed(self.args.seed)
            random.seed(self.args.seed)

        if self.args.activation_memory_budget:
            logger.info(
                f"Setting memory budget to {self.args.activation_memory_budget}"
            )
            try:
                torch._functorch.config.activation_memory_budget = (  # type: ignore[attr-defined]
                    self.args.activation_memory_budget
                )
            except AttributeError:
                logger.warning(
                    "PyTorch does not appear to support the experimental activation_memory_budget API"
                )

        # Resolve resume_from_checkpoint before model construction, since the
        # construction strategy depends on whether we're resuming (e.g., "meta"
        # requires a checkpoint, "default" skips init when resuming).
        # After this block, resume_from_checkpoint is either a path string or False.
        self._resolve_checkpoint()

        self._init_dataloaders(train_dataset, eval_dataset)
        self._prepare_model()
        if self.args.torch_compile:
            logger.info(
                f"Compiling model: backend={self.args.torch_compile_backend}, mode={self.args.torch_compile_mode}, "
                f"dynamic={self.args.torch_compile_dynamic}, fullgraph={self.args.torch_compile_full_graph}, "
                f"activation_memory_budget={self.args.activation_memory_budget}"
            )

            if os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", 0):
                logger.warning(
                    "Computing paritioner Pareto memory budget -- be patient, this takes time..."
                )
            self._compile_model()

        if self.do_train:
            self._init_optimizer()

        self._wrap()
        self.state = self._init_state()
        self.checkpoint_manager = self._init_checkpoint_manager()

        # Restore from checkpoint (path already resolved by _resolve_checkpoint)
        if self.args.resume_from_checkpoint:
            self.load_checkpoint(self.args.resume_from_checkpoint)
            # Non-persistent buffers (e.g., RotaryEmbedding's inv_freq) are not
            # saved in checkpoints. When the model was constructed on meta device,
            # these buffers were never initialized. Compute them now.
            if self.args.construct_model_on == "meta" and self.model is not None:
                self._initialize_non_persistent_buffers(self.model)

        self._dispatch_event("on_init_end")

    def _resolve_checkpoint(self) -> None:
        """Resolve resume_from_checkpoint to a concrete path or False.

        Called early in ``_prepare()`` before model construction, since the
        construction strategy depends on whether we're resuming.

        After this method:
        - ``self.args.resume_from_checkpoint`` is a path string (checkpoint found)
        - or ``False`` (no checkpoint, fresh initialization)
        """
        rfc = self.args.resume_from_checkpoint
        if not rfc:
            return

        if isinstance(rfc, bool):
            # True = auto-find latest checkpoint
            checkpoint_path = find_latest_checkpoint(self.args.output_dir)
        else:
            # Explicit path string
            checkpoint_path = rfc if os.path.exists(rfc) else None

        if checkpoint_path is None:
            logger.info(
                "No checkpoint found. Starting with fresh model initialization."
            )
            self.args.resume_from_checkpoint = False
        else:
            self.args.resume_from_checkpoint = checkpoint_path

    def _wrap(self) -> None:
        """
        Hook for wrapping objects after construction in _prepare().

        Called after dataloaders, model, optimizer, and scheduler are initialized
        but before training begins. Subclasses use this to wrap objects for
        distributed training or other runtime modifications.

        Examples:
        - AccelTrainer: Wraps model/optimizer/dataloaders with Accelerate
        - PipelineTrainer: Sets up pipeline parallel stage wrapping

        See src/forgather/ml/trainer/accelerate/accel_trainer.py and
        src/forgather/ml/trainer/pipeline/pipeline_trainer.py for examples.
        """
        pass

    def _compile_model(self):
        """
        Compile model using torch.compile().

        Hook for model compilation. Subclasses may override to customize compilation
        behavior or apply compilation to additional wrapped objects.
        """
        assert self.model is not None
        self.model.compile(
            backend=self.args.torch_compile_backend,
            mode=self.args.torch_compile_mode,
            dynamic=self.args.torch_compile_dynamic,
            fullgraph=self.args.torch_compile_full_graph,
        )

    def _init_checkpoint_manager(self) -> CheckpointManager:  # type: ignore[override]
        """
        Initialize checkpoint manager hook.

        Creates the CheckpointManager responsible for saving/loading complete training
        state including model, optimizer, scheduler, dataset state, and RNG state.

        Subclasses may override to provide custom checkpoint management behavior.

        Returns:
            CheckpointInterface: Initialized checkpoint manager
        """
        cp_config = CheckpointConfig(
            output_dir=self.args.output_dir,
            save_total_limit=self.args.save_total_limit,
            save_on_each_node=self.args.save_on_each_node,
            save_safetensors=self.args.save_safetensors,
        )

        checkpoint_manager = CheckpointManager(
            config=cp_config,
            dist=self.dist,
            model=self.unwrapped_model(),
            model_preprocessor=self.processing_class,
            stateful_provider=self,
        )
        # Set trainer reference for callback state save/load
        checkpoint_manager.trainer = self
        # Set preserve_n_best from training args
        if hasattr(self.args, "preserve_n_best"):
            checkpoint_manager.preserve_n_best = self.args.preserve_n_best
        return checkpoint_manager

    def _init_dataloaders(self, train_dataset, eval_dataset) -> None:
        """
        Initialize train and evaluation dataloaders (_prepare() sub-step 1).

        Creates StatefulDataLoader instances that support checkpointing dataset iteration state.
        Also computes training step counts for scheduling logging/evaluation/checkpointing.

        Parameters
        ----------
        train_dataset : dataset or None
            Training dataset. Pass ``None`` for eval-only runs.
        eval_dataset : dataset or None
            Evaluation dataset. Pass ``None`` for train-only runs.
        """
        # _prepare() sub-step 1
        self.max_steps = 0
        self.epoch_train_steps = self.args.epoch_train_steps

        self.do_train = train_dataset is not None
        self.do_eval = eval_dataset is not None

        if self.do_train:
            if (
                self.args.set_dataset_epoch
                and self.args.num_train_epochs > 1.0
                and not hasattr(train_dataset, "set_epoch")
            ):
                logger.warning(
                    "Train dataset does not support `set_epoch` and training for > 1 epoch. Dataset will not be reshuffled after each epoch"
                )
                self.args.set_dataset_epoch = False
            self.train_dataloader = self._get_dataloader(
                train_dataset, self.args.per_device_train_batch_size
            )

            self._update_training_steps()

        if self.do_eval:
            self.eval_dataloader = self._get_dataloader(
                eval_dataset, self.args.per_device_eval_batch_size
            )

    def _prepare_model(self) -> None:
        """
        Construct/initialize model and move to device (_prepare() sub-step 2).

        Handles three model construction strategies based on construct_model_on:
        - default: Safe, works everywhere. Constructs on CPU (with no_init_weights if loading
                  checkpoint), then moves to device.
        - meta: Fastest. Constructs on meta device (no memory), materializes empty on device.
                Requires loading checkpoint. May have issues with non-persistent buffers.
        - device: Middle ground. Constructs directly on device with initialization. Faster
                 than default but may fail with model sharding. Good for models with buffers
                 not saved in checkpoint (e.g., RoPE).

        Also sets up gradient checkpointing if enabled and initializes fused loss if available.
        """
        # _prepare() sub-step 2
        # Meta device construction requires a checkpoint. If none is available,
        # fall back to the default strategy.
        if (
            self.args.construct_model_on == "meta"
            and not self.args.resume_from_checkpoint
        ):
            logger.warning(
                "No checkpoint available for meta device construction. "
                "Falling back to 'default' strategy (construct on CPU, "
                "initialize, move to device)."
            )
            self.args.construct_model_on = "default"

        match self.args.construct_model_on:
            case "default":
                if self.model_init:
                    logger.info(
                        f"Constructing model on default device and moving to {self.args.device}"
                    )
                    with ExitStack() as exit_stack:
                        if self.args.default_dtype:
                            exit_stack.enter_context(
                                default_dtype(torch_dtype(self.args.default_dtype))
                            )
                        if self.args.resume_from_checkpoint:
                            exit_stack.enter_context(no_init_weights())
                        self.model = self.model_init()
                else:
                    logger.info(f"Moving model to {self.args.device}")
                    assert self.model is not None
                self.model = self.model.to(self.args.device)
            case "meta":
                assert (
                    self.model_init
                ), "Constructing the model on meta device requires model_init"

                logger.info(
                    f"Constructing model on meta device and materializing on {self.args.device}"
                )
                with ExitStack() as exit_stack:
                    if self.args.default_dtype:
                        exit_stack.enter_context(
                            default_dtype(torch_dtype(self.args.default_dtype))
                        )
                    exit_stack.enter_context(torch.device("meta"))
                    self.model = self.model_init()
                sharing_metadata = create_sharing_metadata(self.model)
                self.model.to_empty(device=self.args.device)
                # to_empty() will break tied parameters. Fix them!
                retie_parameters(self.model, sharing_metadata)
            case "device":
                assert (
                    self.model_init
                ), "Constructing the model on device requires model_init"
                logger.info(
                    f"Constructing and initializing model directly on {self.args.device}"
                )
                with ExitStack() as exit_stack:
                    if self.args.default_dtype:
                        exit_stack.enter_context(
                            default_dtype(torch_dtype(self.args.default_dtype))
                        )
                    exit_stack.enter_context(torch.device(self.args.device))
                    if self.args.resume_from_checkpoint:
                        exit_stack.enter_context(no_init_weights())
                    self.model = self.model_init()
            case _:
                raise ValueError("Requires one of: default|meta|device")
        assert self.model is not None
        # Linear-swap recipes (fp8 / qat) are mutually exclusive — see the
        # _LINEAR_SWAP_RECIPES check in BaseTrainingArguments.__post_init__.
        # The if-chain is sequential rather than elif so a future relaxed
        # mutex still surfaces a clear error (the second swap would find
        # no nn.Linear left and report 0/N converted) instead of silently
        # producing a single-recipe model.
        if self.args.fp8_recipe:
            self.model = self._apply_fp8_training(self.model)
        if self.args.qat_recipe:
            self.model = self._apply_qat_training(self.model)
        if self.args.gradient_checkpointing:
            if self.enable_activation_checkpoint_fn is None:
                if self.dist.rank == 0:
                    logger.warning(
                        f"Activation checkpointing requested, but no function defined!"
                    )
            else:
                # Enable activation checkpointing for all modules in the pipeline.
                self.enable_activation_checkpoint_fn(self.dist.rank, self.model)
        self.loss_fn = self._maybe_get_fused_loss_fn(self.model, self.loss_fn)
        self._wrap_loss_fn()

        # Compute FLOPs per token for performance tracking
        self._flops_per_token = self._compute_flops_per_token()
        if self.dist.rank == 0:
            logger.info(f"Estimated FLOPs per token: {self._flops_per_token:.2e}")

    @staticmethod
    def _initialize_non_persistent_buffers(model: torch.nn.Module) -> None:
        """Initialize non-persistent buffers that were not loaded from checkpoint.

        Non-persistent buffers (registered with ``persistent=False``) are excluded
        from ``state_dict()`` and therefore not saved in or loaded from checkpoints.
        When the model is constructed on the meta device and then materialized via
        ``to_empty()``, these buffers contain uninitialized data.

        This method finds modules with non-persistent buffers and calls their
        ``reset_parameters()`` method to compute the correct values. For example,
        ``RotaryEmbedding.reset_parameters()`` computes ``inv_freq`` from the
        configured ``rope_theta`` and ``d_head``.

        Should be called after checkpoint load when using ``construct_model_on="meta"``.
        """
        for module in model.modules():
            if getattr(module, "_non_persistent_buffers_set", None):
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()

    def _apply_fp8_training(self, model: torch.nn.Module) -> torch.nn.Module:
        """Swap eligible ``nn.Linear`` layers for torchao ``Float8Linear``.

        The conversion recipe is ``args.fp8_recipe``. When
        ``args.fp8_dim_alignment`` is truthy, any Linear whose
        ``in_features`` or ``out_features`` is not a multiple of it is
        skipped (with an info log). Non-Linear modules always pass the filter.

        Parameters
        ----------
        model : torch.nn.Module
            Model to convert.

        Returns
        -------
        torch.nn.Module
            The module returned by ``convert_to_float8_training``.
        """
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        from torchao.float8.float8_linear import Float8Linear

        recipe_name = self.args.fp8_recipe
        assert recipe_name is not None
        if recipe_name == "rowwise_with_gw_hp":
            # Known torchao 0.16.0 problem: matmul_with_hp_or_float8_args
            # reshapes axiswise-scaled inputs in forward(), tripping an
            # assertion in float8_ops for ND inputs such as (B, S, H) hidden
            # states. "rowwise" and "tensorwise" are not affected.
            logger.warning(
                "fp8_recipe='rowwise_with_gw_hp' is currently broken in torchao for "
                "ND inputs (transformer hidden states): reshape on axiswise-scaled "
                "Float8Tensor raises 'aten.reshape.default with axiswise scaling is "
                "not supported yet'. Use 'rowwise' or 'tensorwise' instead."
            )
        fp8_config = Float8LinearConfig.from_recipe_name(recipe_name)

        alignment = self.args.fp8_dim_alignment

        def _aligned_only(mod: torch.nn.Module, fqn: str) -> bool:
            # Only Linear layers are subject to the alignment requirement.
            if not isinstance(mod, torch.nn.Linear):
                return True
            in_dim, out_dim = mod.in_features, mod.out_features
            if in_dim % alignment == 0 and out_dim % alignment == 0:
                return True
            logger.info(
                f"Skipping FP8 for {fqn}: dims ({in_dim}, {out_dim}) "
                f"not divisible by {alignment}"
            )
            return False

        model = convert_to_float8_training(
            model,
            config=fp8_config,
            module_filter_fn=_aligned_only if alignment else None,
        )

        # Single pass: count converted layers and all Linear-like layers.
        converted = total_linear = 0
        for submodule in model.modules():
            if isinstance(submodule, Float8Linear):
                converted += 1
            if isinstance(submodule, (torch.nn.Linear, Float8Linear)):
                total_linear += 1
        logger.info(
            f"FP8 training ({recipe_name}): "
            f"converted {converted}/{total_linear} Linear layers"
        )

        return model

    def _apply_qat_training(self, model: torch.nn.Module) -> torch.nn.Module:
        """Prepare ``model`` for quantization-aware training via torchao.

        Runs torchao ``quantize_`` with a ``QATConfig`` in its "prepare" step,
        built from ``args.qat_recipe``. This installs ``FakeQuantizedLinear``
        modules: their forward pass simulates the target low-bit precision
        with fake quantizers, while the backward pass stays in full precision.
        The real low-bit conversion is a separate post-training step
        (``forgather finalize --quantize <recipe>``).

        Parameters
        ----------
        model : torch.nn.Module
            Model to prepare; ``quantize_`` modifies it in place.

        Returns
        -------
        torch.nn.Module
            The same ``model`` object.
        """
        from torchao.quantization import quantize_
        from torchao.quantization.qat import FakeQuantizedLinear, QATConfig

        from forgather.ml.qat_recipes import recipe_to_base_config

        recipe_name = self.args.qat_recipe
        assert recipe_name is not None
        prepare_config = QATConfig(recipe_to_base_config(recipe_name), step="prepare")
        quantize_(model, prepare_config)

        # sum() over booleans counts the matching modules.
        fake_quantized = sum(
            isinstance(submodule, FakeQuantizedLinear) for submodule in model.modules()
        )
        linear_like = sum(
            isinstance(submodule, (torch.nn.Linear, FakeQuantizedLinear))
            for submodule in model.modules()
        )
        logger.info(
            f"QAT training ({recipe_name}): "
            f"converted {fake_quantized}/{linear_like} Linear layers to FakeQuantizedLinear. "
            f"Run `forgather finalize --quantize {recipe_name}` "
            f"after training to produce a deployable quantized artifact."
        )

        return model

    def _wrap_loss_fn(self):
        """Wrap ``self.loss_fn`` in ``RescaleLoss`` for gradient accumulation.

        The scale factor is ``1 / args.gradient_accumulation_steps``, so that
        summing micro-batch losses over an accumulation window matches the
        intended per-step loss magnitude.
        """
        accumulation_scale = 1 / self.args.gradient_accumulation_steps
        self.loss_fn = RescaleLoss(self.loss_fn, accumulation_scale)

    def _maybe_get_fused_loss_fn(
        self, module: torch.nn.Module, default_loss_fn: Optional[LossFunctionT]
    ):
        """
        Attempt to enable fused loss-logits computation for memory optimization.

        Fused loss combines the final linear layer (computing logits from hidden states)
        with the cross-entropy loss computation. This avoids materializing the full logits
        tensor in memory, which is critical for models with large vocabulary sizes where
        logits can be gigabytes in size.

        For example, with vocab_size=50k, batch_size=8, seq_len=2048:
        - Unfused: logits tensor is 8 * 2048 * 50000 * 4 bytes = ~3.2 GB
        - Fused: the full logits tensor is never materialized at once

        Requires:
        - fused_loss_factory provided to Trainer constructor
        - Model supports get_output_embeddings() returning a non-None module
          (the final linear layer)
        - Model advertises can_return_hidden_states=True (returns hidden states
          instead of logits)

        See src/forgather/ml/loss.py and docs/fused_loss/ for implementations and details.

        If the fused loss function is successfully built, also sets:
            self.use_fused_loss = True

        Parameters
        ----------
        module : torch.nn.Module
            The model to attempt enabling fused loss on.
        default_loss_fn : callable or None
            Fallback loss function if fused loss is not supported.

        Returns
        -------
        callable or None
            The fused loss function, or ``default_loss_fn`` if unsupported.
        """
        if not self.fused_loss_factory:
            return default_loss_fn
        if not hasattr(module, "get_output_embeddings"):
            logger.warning(
                "Model does not support get_output_embeddings() for fused_loss_factory()"
            )
            return default_loss_fn
        if not getattr(module, "can_return_hidden_states", False):
            logger.warning(
                "Model does not support 'return_hidden_states' API; fused loss will not be used."
            )
            return default_loss_fn

        output_embeddings = module.get_output_embeddings()  # type: ignore[operator]
        if output_embeddings is None:
            # HF-style models return None when they have no output projection;
            # there is nothing to fuse with, so keep the default loss.
            logger.warning(
                "Model get_output_embeddings() returned None; fused loss will not be used."
            )
            return default_loss_fn

        fused_loss_fn = self.fused_loss_factory(output_embeddings)
        # Flag fused mode only once the factory has produced a loss function.
        self.use_fused_loss = True
        logger.info("Enabled fused loss-logits function")
        return fused_loss_fn

    def _init_optimizer(self) -> None:
        """
        Build the optimizer and learning-rate scheduler (_prepare() sub-step 3).

        Optimizer (only when ``self.optimizer`` is still None):

        - With ``optimizer_groups``, parameters are bucketed by
          ``build_optimizer_buckets``. A single bucket yields one optimizer;
          several buckets are combined into a ``Multiopt``.
        - Otherwise ``optimizer_factory`` receives ``model.named_parameters()``.
        - With ``args.fuse_optim_with_backward``, a post-accumulate-grad hook
          (``optimizer_hook``) is registered on every parameter that requires
          grad. The squared grad norm is accumulated into
          ``self._total_grad_squared``.

        Scheduler (only when ``self.lr_scheduler`` is still None): built from
        ``lr_scheduler_factory`` if given, else from HF ``get_scheduler`` when
        ``args.lr_scheduler_type`` is set. A scheduler built here that is a
        ``TrainerCallback`` is also registered as a callback.

        Raises
        ------
        ValueError
            If ``fuse_optim_with_backward`` is combined with more than one
            optimizer bucket.
        """
        # _prepare() sub-step 3
        assert self.model is not None
        if self.optimizer is None:
            assert self.optimizer_factory is not None
            if not self.optimizer_groups:
                self.optimizer = self.optimizer_factory(self.model.named_parameters())
            else:
                buckets = build_optimizer_buckets(
                    self.model.named_parameters(),
                    self.optimizer_groups,
                    default_factory=self.optimizer_factory,
                    debug=self.args.debug_optimizer_groups,
                )
                single_bucket = len(buckets) == 1
                if not single_bucket and self.args.fuse_optim_with_backward:
                    raise ValueError(
                        "fuse_optim_with_backward is incompatible with "
                        "multiple per-group optimizer factories: the "
                        "per-parameter post-accumulate hook can only "
                        "drive a single optimizer instance."
                    )
                if single_bucket:
                    bucket_factory, bucket_params = buckets[0]
                    self.optimizer = bucket_factory(bucket_params)
                else:
                    self.optimizer = cast(
                        Any,
                        Multiopt([make(groups) for make, groups in buckets]),
                    )

            # Step the optimizer from inside backward, parameter by parameter.
            if self.args.fuse_optim_with_backward:
                self._total_grad_squared = torch.zeros(
                    1, device=self.args.device, dtype=torch.float32
                )
                trainable = (
                    (param_name, param)
                    for param_name, param in self.model.named_parameters()
                    if param.requires_grad
                )
                for param_name, param in trainable:
                    param.register_post_accumulate_grad_hook(
                        partial(
                            optimizer_hook,
                            self.optimizer,
                            self._total_grad_squared,
                            param_name,
                        )
                    )

        # A caller-supplied scheduler is used as-is.
        if self.lr_scheduler is not None:
            return

        if self.lr_scheduler_factory is not None:
            assert self.optimizer is not None
            self.lr_scheduler = self.lr_scheduler_factory(self.optimizer)
        elif self.args.lr_scheduler_type:
            from transformers import get_scheduler

            self.lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler_type,
                optimizer=cast(Any, self.optimizer),  # type: ignore[arg-type]
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=self.max_steps,
                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
            )

        # Schedulers deriving from TrainerCallback (e.g. GradientNoiseScheduler)
        # are registered so they receive training metrics via on_train_metrics.
        if isinstance(self.lr_scheduler, TrainerCallback):
            self.add_callback(self.lr_scheduler)

    def _maybe_log_save_evaluate(
        self,
        loss_log,
        total_norm_log,
        tokens_log,
        periodic_log,
        periodic_eval,
        periodic_save,
    ):
        """Run whichever of log / evaluate / save is due on this step.

        Each ``periodic_*`` function is stepped exactly once per call. The
        matching ``self.control.should_*`` flag can also force the action, and
        the flag is cleared once the action runs.

        A save also forces an eval when ``args.eval_on_save`` is set and an
        eval dataset exists. When saving with ``args.preserve_best_model``,
        the checkpoint manager's best-checkpoint list (plus
        ``state.best_metric`` / ``state.best_model_checkpoint``) is updated
        before the save, so that checkpoint deletion sees the current ranking.

        Parameters
        ----------
        loss_log, total_norm_log, tokens_log : list
            Samples accumulated since the last log; forwarded to ``_log_step``.
        periodic_log, periodic_eval, periodic_save : PeriodicFunction
            Cadence trackers built by ``_setup_periodic_functions``.
        """
        # HF's Trainer gates logging with an 'and' (should_log AND a step
        # check). Here *either* the periodic schedule or a callback-set
        # should_log triggers a log. It is unclear whether any callback relies
        # on the HF behaviour; until proven otherwise, log on either signal.
        if periodic_log.step() or self.control.should_log:
            periodic_log.reset()
            self.control.should_log = False

            self._log_step(loss_log, total_norm_log, tokens_log)

        # Handle evaluation (normal schedule or control-triggered)
        eval_metrics = None
        should_eval = periodic_eval.step() or self.control.should_evaluate

        # Force eval if saving and eval_on_save enabled
        should_save = periodic_save.step() or self.control.should_save
        if should_save and self.args.eval_on_save and self.eval_dataset is not None:
            should_eval = True

        if should_eval:
            periodic_eval.reset()
            self.control.should_evaluate = False

            # Do eval
            maybe_cleanup_memory(self.args.gc_threshold)
            eval_metrics = self._eval_loop()

        # Handle checkpointing - normal schedule or control-triggered
        if should_save:
            periodic_save.reset()
            self.control.should_save = False
            assert self.checkpoint_manager

            # Determine checkpoint path BEFORE saving
            checkpoint_path = next_checkpoint_path(
                self.args.output_dir, str(self.state.global_step)
            )

            # Update best checkpoints list BEFORE saving (so preserved list is correct)
            if self.args.preserve_best_model and eval_metrics:
                checkpoint_manager = cast(CheckpointManager, self.checkpoint_manager)
                checkpoint_manager.update_best_checkpoints(
                    checkpoint_path=checkpoint_path,
                    metrics=eval_metrics,
                    metric_key=self.args.best_model_metric,
                    greater_is_better=self.args.best_model_greater_is_better,
                    preserve_n_best=self.args.preserve_n_best,
                )
                # Update state for compatibility
                if checkpoint_manager.best_checkpoints:
                    self.state.best_metric = checkpoint_manager.best_checkpoints[0][1]
                    self.state.best_model_checkpoint = (
                        checkpoint_manager.best_checkpoints[0][0]
                    )

            # Now save checkpoint (deletion will use updated preserved list)
            saved_path = self.checkpoint_manager.save_checkpoint(
                checkpoint_id=str(self.state.global_step)
            )
            self._dispatch_event("on_save")

            # Save metrics file
            if eval_metrics:
                save_checkpoint_metrics(saved_path, eval_metrics)

    def load_best_model(self) -> None:
        """
        Restore the checkpoint recorded in ``state.best_model_checkpoint``.

        Used at the end of training when ``load_best_model_at_end=True``. If
        no best checkpoint was recorded, or the recorded path is missing on
        disk, a warning is logged and nothing is loaded.
        """
        best_checkpoint = self.state.best_model_checkpoint
        if not best_checkpoint:
            logger.warning("No best model checkpoint available to load")
            return

        if not os.path.exists(best_checkpoint):
            logger.warning(
                f"Best model checkpoint path does not exist: {best_checkpoint}"
            )
            return

        logger.info(f"Loading best model from {best_checkpoint}")
        self.load_checkpoint(best_checkpoint)

    def _setup_periodic_functions(
        self,
    ) -> tuple[PeriodicFunction, PeriodicFunction, PeriodicFunction]:
        """Create the (log, eval, save) cadence trackers for the train loop.

        All three start from the current ``state.global_step`` and share
        ``epoch_train_steps`` as their epoch period. Strategy and period come
        from the matching logging / eval / save arguments. The first step
        comes from ``logging_first_step`` for logging and ``eval_delay`` for
        eval; saving always uses 0.
        """
        args = self.args

        def _periodic(strategy, period, first_step) -> PeriodicFunction:
            return PeriodicFunction(
                global_step=self.state.global_step,
                strategy=strategy,
                period=period,
                epoch_period=self.epoch_train_steps,
                first_step=first_step,
            )

        return (
            _periodic(args.logging_strategy, args.logging_steps, args.logging_first_step),
            _periodic(args.eval_strategy, args.eval_steps, args.eval_delay),
            _periodic(args.save_strategy, args.save_steps, 0),
        )

    def _log_stop_condition(self) -> None:
        """Log why the training loop is exiting.

        Called from ``_train_loop`` right after its stop check trips. It
        covers three paths, in priority order:

        1. ``should_abort_without_save`` -- logged as a warning.
        2. ``should_training_stop`` -- logged as info.
        3. Otherwise, ``max_steps`` was reached -- logged as info.

        For the first two paths, the callback and event that requested the
        stop are named when ``self._stop_requested_by`` is set (it is
        populated by ``_dispatch_event``).
        """
        step = self.state.global_step
        requested_by = self._stop_requested_by
        control = self.control

        if control.should_abort_without_save:
            if requested_by is None:
                logger.warning(
                    f"Training aborted at step {step} "
                    f"(should_abort_without_save set; "
                    f"no final checkpoint will be saved)"
                )
                return
            cb_name, event, _ = requested_by
            logger.warning(
                f"Training aborted at step {step} by callback "
                f"'{cb_name}' during '{event}' "
                f"(no final checkpoint will be saved)"
            )
            return

        if control.should_training_stop:
            if requested_by is None:
                logger.info(f"Training stopped at step {step} (should_training_stop set)")
                return
            cb_name, event, _ = requested_by
            logger.info(
                f"Training stopped at step {step} by callback "
                f"'{cb_name}' during '{event}'"
            )
            return

        logger.info(
            f"Training stopped at step {step}: "
            f"reached max_steps ({self.max_steps})"
        )

    def _run_final_checkpoint_if_needed(
        self,
        accumulated_loss: list,
        accumulated_grad_norm: list,
        accumulated_tokens: list,
        periodic_log: PeriodicFunction,
        periodic_eval: PeriodicFunction,
        periodic_save: PeriodicFunction,
    ) -> None:
        """Run the end-of-training log/eval/save pass.

        Does nothing when ``should_abort_without_save`` is set. Otherwise, if
        saving is enabled and the save tracker has moved since its last reset
        (``rel_step != 0``), a final save is requested, plus an eval when
        ``load_best_model_at_end`` is set. ``_maybe_log_save_evaluate`` then
        runs once more, which also drains the accumulated metrics.
        """
        if self.control.should_abort_without_save:
            return

        save_due = periodic_save.rel_step != 0 and self.args.save_strategy != "no"
        if save_due:
            logger.info(f"Saving final checkpoint at step {self.state.global_step}")
            # Picking the best model at the end requires metrics for this save.
            if self.args.load_best_model_at_end:
                self.control.should_evaluate = True
            self.control.should_save = True

        self._maybe_log_save_evaluate(
            accumulated_loss,
            accumulated_grad_norm,
            accumulated_tokens,
            periodic_log,
            periodic_eval,
            periodic_save,
        )

    def _finalize_training(
        self,
        start_time: Optional[float],
        speed_metrics_steps: int,
    ) -> TrainOutput:
        """Post-training cleanup: load best model, summarise, fire on_train_end.

        Runs after the training loop exits (either normally, via a
        callback-requested stop, or via dataset exhaustion).

        Parameters
        ----------
        start_time : float or None
            Wall-clock start of the speed-metrics window (None if timing never
            started); forwarded to ``_end_train_loop``.
        speed_metrics_steps : int
            Step count for that window; forwarded to ``_end_train_loop``.

        Returns
        -------
        TrainOutput
            ``(global_step, metrics)``, which the public ``train()`` entry
            point eventually propagates to callers.
        """
        if self.args.load_best_model_at_end:
            # load_best_model() logs the checkpoint path it loads, or warns
            # when no usable best checkpoint exists. A separate log line here
            # would print 'Loading best model "None"' ahead of that warning.
            self.load_best_model()

        metrics = self._end_train_loop(start_time, speed_metrics_steps)
        self.log(metrics)

        # Log best checkpoints summary at end of training
        if (
            self.args.preserve_best_model
            and self.checkpoint_manager
            and self.state.is_world_process_zero
        ):
            summary = cast(
                CheckpointManager, self.checkpoint_manager
            ).get_best_checkpoints_summary(metric_key=self.args.best_model_metric)
            logger.info(f"\n{'='*60}\nTraining complete!\n{summary}\n{'='*60}")

        self._dispatch_event("on_train_end")
        return TrainOutput(self.state.global_step, metrics)

    @override
    def _train_loop(self) -> TrainOutput:
        """Run the epoch/step training loop, then the final save and cleanup.

        Each step calls ``_train_step`` on the current epoch's iterator,
        records the loss / grad-norm / token samples, advances
        ``state.global_step`` and the fractional ``state.epoch``, and hands
        off to ``_maybe_log_save_evaluate``.

        ``StopIteration`` from the data iterator ends the epoch. Training ends
        when any of these happens:

        - a callback sets ``should_training_stop`` or
          ``should_abort_without_save``;
        - ``global_step`` reaches ``max_steps``;
        - ``num_train_epochs`` epochs have been consumed (when it is >= 0).

        Speed metrics start after ``args.speed_metrics_start_step`` steps have
        run, which leaves room for e.g. torch.compile warm-up.

        Returns
        -------
        TrainOutput
            As produced by ``_finalize_training``.
        """
        periodic_log, periodic_eval, periodic_save = self._setup_periodic_functions()

        assert self.optimizer is not None
        assert self.model is not None
        assert self.train_dataloader is not None

        # Just to be sure...
        self.optimizer.zero_grad()

        start_time = None
        train_steps = 0
        self._dispatch_event("on_train_begin")

        # Holds loss, grad-norm, and token samples between log steps
        accumulated_grad_norm: list = []
        accumulated_loss: list = []
        accumulated_tokens: list = []

        # The in-loop timer check runs after `train_steps` has already been
        # incremented (>= 1), so it can never match a start step of 0. Start
        # the clock here instead, so that 0 means "time every step".
        if self.args.speed_metrics_start_step == 0:
            start_time = time.time()

        # Context manager for setting model.train()/eval()
        with set_train(self.model, True):
            # Epoch loop
            while True:
                self.control.should_epoch_stop = False
                if self.args.set_dataset_epoch and self.state.raw_epoch > 0:
                    # If supported, reshuffle dataset at the start of each epoch.
                    # Datasets without set_epoch (e.g. standard HF Dataset) are
                    # silently skipped rather than crashing mid-training.
                    dataset = self.train_dataloader.dataset  # type: ignore[union-attr]
                    if hasattr(dataset, "set_epoch"):
                        logger.debug(f"Setting dataset epoch {self.state.raw_epoch}")
                        dataset.set_epoch(self.state.raw_epoch)
                data_iterator = iter(self.train_dataloader)
                self._dispatch_event("on_epoch_begin")

                while True:
                    self._dispatch_event("on_step_begin")

                    try:
                        loss, total_norm, tokens = self._train_step(data_iterator)
                        self._dispatch_event(
                            "on_train_metrics",
                            loss=loss,
                            grad_norm=total_norm,
                            tokens=tokens,
                        )
                    except StopIteration:
                        # Epoch exhausted: advance the epoch counter and
                        # decide whether any epochs remain.
                        self.state.raw_epoch += 1
                        self.state.epoch_start_step = self.state.global_step
                        if (
                            self.args.num_train_epochs >= 0
                            and self.state.raw_epoch >= self.args.num_train_epochs
                        ):
                            self.control.should_epoch_stop = True
                        # Keep on_step_begin / on_step_end paired.
                        self._dispatch_event("on_step_end")
                        break

                    accumulated_grad_norm.append(total_norm)
                    accumulated_loss.append(loss)
                    accumulated_tokens.append(tokens)

                    # Increment global step
                    self.state.global_step += 1

                    # Compute epoch as continuous value from global steps
                    self.state.epoch = float(self.state.raw_epoch) + (
                        float(self.state.global_step - self.state.epoch_start_step)
                        / float(self.epoch_train_steps)
                    )

                    self._dispatch_event("on_step_end")
                    self._maybe_log_save_evaluate(
                        accumulated_loss,
                        accumulated_grad_norm,
                        accumulated_tokens,
                        periodic_log,
                        periodic_eval,
                        periodic_save,
                    )

                    train_steps += 1

                    # Check for early stop condition
                    if (
                        self.control.should_training_stop
                        or self.control.should_abort_without_save
                        or self.state.global_step >= self.max_steps
                    ):
                        self._log_stop_condition()
                        self.control.should_epoch_stop = True
                        break

                    # Periodic GC
                    maybe_cleanup_memory(self.args.gc_threshold)

                    # maybe delay start of metrics recording for torch.compile()
                    if train_steps == self.args.speed_metrics_start_step:
                        start_time = time.time()

                self._dispatch_event("on_epoch_end")
                if self.control.should_epoch_stop:
                    break

        self._run_final_checkpoint_if_needed(
            accumulated_loss,
            accumulated_grad_norm,
            accumulated_tokens,
            periodic_log,
            periodic_eval,
            periodic_save,
        )

        return self._finalize_training(
            start_time,
            train_steps - self.args.speed_metrics_start_step,
        )

    @staticmethod
    def _format_zero_eval_batches_diagnostic(
        header: str,
        settings: list[tuple[str, Any]],
        explanation: str,
    ) -> str:
        """Render a zero-eval-batches diagnostic with a uniform shape.

        Subclasses build ``settings`` and ``explanation`` to add their own
        context (e.g. distributed dispatch flags) without restating the
        framing or the doc reference.
        """
        width = max(len(name) for name, _ in settings)
        settings_block = "\n".join(
            f"  {name.ljust(width)} = {value}" for name, value in settings
        )
        return (
            f"{header}\n"
            "\n"
            f"Effective settings:\n{settings_block}\n"
            "\n"
            f"{explanation}\n"
            "\n"
            "See docs/trainers/distributed-eval-zero-batches.md for the\n"
            "full diagnosis and the available remedies."
        )

    def _zero_eval_batches_message(self) -> str:
        """Explain why the eval dataloader produced no batches.

        Reports the settings most likely to be responsible: the eval batch
        size, ``dataloader_drop_last`` and ``max_eval_steps``. Distributed
        subclasses override this to add ``dispatch_batches`` /
        ``dispatch_eval_batches`` and the world size, which interact with
        sharded eval splits.
        """
        effective_settings = [
            ("per_device_eval_batch_size", self.args.per_device_eval_batch_size),
            ("dataloader_drop_last", self.args.dataloader_drop_last),
            ("max_eval_steps", self.args.max_eval_steps),
        ]
        likely_cause = (
            "The eval split may simply be empty, or the dataloader may\n"
            "have dropped every batch because the dataset has fewer than\n"
            "per_device_eval_batch_size examples and\n"
            "dataloader_drop_last is True."
        )
        return self._format_zero_eval_batches_diagnostic(
            header="The eval dataloader did not yield any examples.",
            settings=effective_settings,
            explanation=likely_cause,
        )

    @override
    @torch.no_grad()
    def _eval_loop(self) -> Dict[str, float]:
        """
        The inner evaluation loop.

        Runs the model in eval mode over ``self.eval_dataloader``. When
        ``args.max_eval_steps > 0``, at most that many batches are processed.
        The per-batch ``"loss"`` outputs are averaged over the batches that
        were actually processed. Afterwards, dataset state is synced for
        ``StatefulDataLoader`` and ``on_evaluate`` is dispatched.

        Returns
        -------
        dict
            ``{"eval_loss": mean_batch_loss}``.

        Raises
        ------
        RuntimeError
            If the eval dataloader yields no batches.
        """
        assert self.model is not None
        assert self.eval_dataloader is not None
        max_eval_steps = self.args.max_eval_steps
        with set_train(self.model, False):
            total_loss = torch.zeros(1, device=self.args.device)
            # Count processed batches explicitly. Deriving the count from the
            # enumerate index over-counted by one when max_eval_steps cut the
            # loop short, which biased eval_loss low.
            num_batches = 0
            for batch in self.eval_dataloader:
                input_dict, labels = self._prepare_batch(batch)
                outputs = self._prediction_step(input_dict, labels)
                loss = outputs["loss"]
                assert loss is not None
                total_loss += loss
                num_batches += 1
                self._dispatch_event("on_prediction_step")
                # Stop as soon as the cap is hit, rather than fetching (and
                # discarding) one more batch.
                if max_eval_steps > 0 and num_batches >= max_eval_steps:
                    break
            if num_batches == 0:
                raise RuntimeError(self._zero_eval_batches_message())

            metrics = {"eval_loss": (total_loss / num_batches).item()}
            if isinstance(self.eval_dataloader, StatefulDataLoader):
                sync_dataset_state_from_dataloader(self.eval_dataloader)
            self._dispatch_event("on_evaluate", metrics=metrics)
            return metrics

    def _clip_grad_norm(
        self, max_grad_norm: float | None, norm_type: float = 2.0
    ) -> Optional[Tensor]:
        """
        Clip gradients by norm.

        Returns
        -------
        Tensor or None
            Total norm of the parameters.

        Raises
        ------
        RuntimeError
            If parameters are not of supported types for ``foreach=True``.
        """
        # In the case of fused backward / optimizer step, we accumulate squared norm
        # in the optimizer hook. Compute norm via sqrt() and reset accumulator
        if self.args.fuse_optim_with_backward:
            total_norm = self._total_grad_squared.sqrt()
            self._total_grad_squared -= self._total_grad_squared
            return total_norm

        assert self.model is not None
        # If not clipping, just compute and return it
        if max_grad_norm is None or max_grad_norm == 0:
            grads = [p.grad for p in self.model.parameters() if p.grad is not None]

            total_norm = torch.nn.utils.get_total_norm(
                grads, norm_type=norm_type, foreach=True
            )
            return total_norm

        # Otherwise, use fused clip_grad_norm_
        total_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_grad_norm,
            norm_type=norm_type,
            foreach=False if self.args.device == "cpu" else True,
        )

        return total_norm

    def _train_step(
        self, data_iterator: Iterator[dict[str, Tensor]]
    ) -> Tuple[Tensor, Tensor | None, Tensor]:
        """
        Run one optimizer update, accumulating gradients over micro-batches.

        For each of ``args.gradient_accumulation_steps`` micro-batches, a batch
        is pulled from ``data_iterator``, forward/backward is run (bracketed by
        the ``on_forward_backward_begin`` / ``on_forward_backward_end`` events)
        and its tokens are tallied. Afterwards gradients are unscaled (a no-op
        without an fp16 GradScaler), their norm is computed/clipped, the
        optimizer is stepped and zeroed (unless the step is fused into
        backward), and the LR scheduler, if any, is advanced.

        Parameters
        ----------
        data_iterator : Iterator
            Source of training batches; one batch is consumed per micro-step.

        Returns
        -------
        tuple of (Tensor, Tensor or None, Tensor)
            ``(loss, grad_norm, tokens)``: the sum of the detached micro-batch
            losses, the norm returned by ``_clip_grad_norm``, and the local
            (not yet rank-synchronized) int64 token count for this step.
        """
        micro_losses: list[Tensor] = []
        token_total: Tensor | int = 0

        for micro_step in range(1, self.args.gradient_accumulation_steps + 1):
            self.gradient_accumulation_step = micro_step
            input_dict, labels = self._prepare_batch(next(data_iterator))
            self._dispatch_event("on_forward_backward_begin")
            micro_loss = self._forward_backward_step(input_dict, labels)
            self._dispatch_event("on_forward_backward_end")
            micro_losses.append(micro_loss)
            # Local count only; _count_batch_tokens returns a tensor so no
            # device->host sync is forced here.
            token_total = token_total + self._count_batch_tokens(input_dict, labels)

        assert self._should_sync_gradients()
        assert self.optimizer is not None

        # Unscale before measuring/clipping so the norm is in true units.
        self.amp_context.unscale_(self.optimizer)
        grad_norm = self._clip_grad_norm(self.args.max_grad_norm)
        self._dispatch_event("on_pre_optimizer_step")
        if not self.args.fuse_optim_with_backward:
            # With fused backward the update already happened in backward hooks.
            self.amp_context.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
        self._dispatch_event("on_optimizer_step")
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        # NOTE(review): summing presumably yields a mean because the loss
        # function rescales each micro-loss by the accumulation count -- confirm.
        step_loss = torch.stack(micro_losses).sum()
        # Cross-rank loss/token reduction is deferred to _log_step().
        tokens = (
            token_total.to(dtype=torch.int64)
            if isinstance(token_total, Tensor)
            else torch.tensor(token_total, device=self.args.device, dtype=torch.int64)
        )
        return step_loss, grad_norm, tokens

    def _forward_backward_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Tensor:
        """
        Run a forward pass, compute the loss, and backpropagate it.

        When ``self.use_fused_loss`` is set, ``return_hidden_states=True`` is
        written into ``input_dict`` (mutating the caller's dict) so the model
        returns hidden states instead of materializing the full logits tensor.
        The forward pass and loss run under ``amp_context.autocast()``; the
        backward pass goes through ``_backward``.

        Parameters
        ----------
        input_dict : dict
            Model inputs (input_ids, attention_mask, etc.).
        labels : Tensor
            Targets passed to the loss function.

        Returns
        -------
        Tensor
            The loss, detached from the autograd graph (for logging).
        """
        model, loss_fn = self.model, self.loss_fn
        assert model is not None
        assert loss_fn is not None

        if self.use_fused_loss:
            input_dict["return_hidden_states"] = True  # type: ignore[assignment]

        with self.amp_context.autocast():
            model_out = model(**input_dict)
            loss = loss_fn(logits_from_outputs(model_out), labels)

        self._backward(loss)
        return loss.detach()

    def _backward(self, loss: Tensor) -> None:
        """
        Execute backward pass to compute gradients.

        When fp16 mixed precision is active, the loss is scaled by GradScaler
        before backward to prevent gradient underflow.

        Subclasses may override to customize backward behavior (e.g., for pipeline parallelism).

        Parameters
        ----------
        loss : Tensor
            Loss tensor to backpropagate.
        """
        self.amp_context.scale_loss(loss).backward()

    def _should_sync_gradients(self) -> bool:
        """
        Determine if gradients should be synchronized across processes.

        Normally, this just compares the gradient accumulation step against the total
        accumulation steps.

        If overridden, as is the case for the Accelerate trainer, an assert verifies that the logic
        agrees with when gradient synchronization is expected.

        Returns
        -------
        bool
            True if gradients should be synchronized on the current forward-backward step.
        """
        return self.gradient_accumulation_step == self.args.gradient_accumulation_steps

    def _update_training_steps(self) -> None:
        """
        Recompute training step counts from the train dataloader's length.

        Sets:

        - ``epoch_train_steps``: number of batches per epoch. It is refreshed
          only when the dataloader is ``Sized``; otherwise the previous value
          is kept.
        - ``max_steps``: total optimizer updates across all epochs, i.e.
          ``ceil(num_train_epochs * (epoch_train_steps // gradient_accumulation_steps))``.
          A negative ``num_train_epochs`` means "no epoch cap" (``sys.maxsize``).
          A non-negative ``args.max_steps`` further caps the result.

        ``num_train_epochs`` may be fractional (e.g. ``0.5``). The product is
        rounded up so ``max_steps`` is always an ``int`` rather than a float
        step budget, and a partial final update is still counted.

        May be called multiple times, e.g. when a wrapper such as Accelerate
        changes the dataloader's length. For a ``StatefulDataLoader`` the
        dataset state is synced first so ``len()`` is accurate. The result is
        mirrored into ``self.state.max_steps`` when a state exists.
        """
        import math

        if isinstance(self.train_dataloader, StatefulDataLoader):
            sync_dataset_state_from_dataloader(self.train_dataloader)

        # The number of training steps in a single epoch (batch processing steps)
        if isinstance(self.train_dataloader, Sized):
            self.epoch_train_steps = len(self.train_dataloader)

        # Convert to effective global steps (optimizer updates) with gradient accumulation
        effective_epoch_steps = (
            self.epoch_train_steps // self.args.gradient_accumulation_steps
        )

        # The total number of global training steps (optimizer updates) in all epochs.
        # num_train_epochs < 0 means "no epoch cap" -- rely entirely on max_steps.
        if self.args.num_train_epochs >= 0:
            # ceil() keeps max_steps integral even for fractional epochs.
            self.max_steps = math.ceil(
                self.args.num_train_epochs * effective_epoch_steps
            )
        else:
            self.max_steps = sys.maxsize

        # If an explicit max_steps limit is specified, constrain to it.
        if self.args.max_steps >= 0:
            self.max_steps = min(self.args.max_steps, self.max_steps)

        if self.state is not None:
            self.state.max_steps = self.max_steps

    def _init_state(self) -> TrainerState:
        """
        Build the initial TrainerState used to track training progress.

        When ``self.do_train`` is set, the schedule fields come from the
        computed ``max_steps``/``epoch_train_steps``, the training arguments,
        and the train dataloader's batch size. Otherwise (evaluation only) those
        fields are zeroed. The process-topology fields, best-model tracking
        (reset to None) and ``max_eval_steps`` are filled identically in both
        cases.

        The state is saved/restored with checkpoints to resume training accurately.

        Key state fields:
        - global_step: Total optimizer updates since training start (0-indexed)
        - raw_epoch: Integer epoch counter, increments at end of each dataset iteration
        - epoch_start_step: Global step when raw_epoch was last incremented
        - epoch: Continuous value = raw_epoch + fractional progress through current epoch
        - best_metric/best_model_checkpoint: Best model tracking for load_best_model_at_end

        Returns:
            TrainerState: Initialized state object
        """
        shared_fields: dict[str, Any] = {
            "is_local_process_zero": self.is_local_process_zero,
            "is_world_process_zero": self.is_world_process_zero,
            "num_processes": self.num_processes,
            # No best model has been recorded yet.
            "best_metric": None,
            "best_model_checkpoint": None,
            "max_eval_steps": self.args.max_eval_steps,
        }

        if not self.do_train:
            # Evaluation-only run: there is no training schedule to describe.
            return TrainerState(
                max_steps=0,
                logging_steps=0,
                eval_steps=0,
                num_train_epochs=0,
                train_batch_size=0,
                epoch_train_steps=0,
                save_steps=0,
                **shared_fields,
            )

        assert has_batch_size(self.train_dataloader)
        return TrainerState(
            max_steps=self.max_steps,
            logging_steps=self.args.logging_steps,
            eval_steps=self.args.eval_steps,
            num_train_epochs=int(self.args.num_train_epochs),
            train_batch_size=self.train_dataloader.batch_size,
            epoch_train_steps=self.epoch_train_steps,
            save_steps=self.args.save_steps,
            **shared_fields,
        )

    def _end_train_loop(
        self, start_time: float | None, train_steps: int
    ) -> dict[str, int | float]:
        if train_steps == -1:
            train_steps = 0
        if start_time:
            runtime = time.time() - start_time
        else:
            runtime = None
        # Calculate effective batch size including gradient accumulation
        effective_batch_size = self._calculate_effective_batch_size()
        total_train_samples = effective_batch_size * train_steps
        metrics = self._speed_metrics(
            "train", runtime, total_train_samples, train_steps
        )
        metrics["epoch"] = self.state.epoch
        metrics["effective_batch_size"] = effective_batch_size

        # Add token and FLOP metrics
        metrics["total_tokens"] = self.state.num_input_tokens_seen
        if runtime and runtime > 0:
            metrics["tokens_per_second"] = round(
                self.state.num_input_tokens_seen / runtime, 2
            )

        if self.state.total_flos > 0:
            metrics["total_flops"] = self.state.total_flos
            if runtime and runtime > 0:
                metrics["flops_per_second"] = round(self.state.total_flos / runtime, 2)

        return metrics

    def _calculate_effective_batch_size(self) -> int:
        """
        Calculate effective batch size accounting for gradient accumulation and parallelism.

        The effective batch size is the total number of examples processed per optimizer update.

        Calculation depends on parallelism strategy:
        - Data parallelism (DDP): Multiply by num_processes (each process sees different batch)
        - Pipeline parallelism: Don't multiply (same batch flows through pipeline stages)
        - Model parallelism: Don't multiply (same batch processed by different model shards)
        - Combined: Don't multiply (same batch)

        Formula: base_batch = per_device_batch * gradient_accumulation_steps
                 effective = base_batch * num_processes (only for data parallel)

        Returns:
            Total number of examples per optimizer update
        """
        base_batch_size = (
            self.state.train_batch_size * self.args.gradient_accumulation_steps
        )

        # Check if any form of model parallelism is being used
        # If so, don't multiply by num_processes since the same batch is processed
        # across different parts of the model (stages and/or shards)
        is_pipeline_parallel = (
            hasattr(self, "_is_pipeline_parallel") and self._is_pipeline_parallel()
        )
        is_model_parallel = (
            hasattr(self, "_is_model_parallel") and self._is_model_parallel()
        )

        if is_pipeline_parallel or is_model_parallel:
            # Any form of model parallelism: don't multiply by num_processes
            return base_batch_size
        else:
            # Data parallel (DDP) or single process: multiply by num_processes
            return self.num_processes * base_batch_size

    def _is_pipeline_parallel(self) -> bool:
        """
        Indicate if trainer uses pipeline parallelism.

        Used for effective batch size calculation. Pipeline parallel trainers process
        the same batch across different pipeline stages, so don't multiply by num_processes.

        Returns:
            False for single-device Trainer, True in PipelineTrainer
        """
        return False

    def _is_model_parallel(self) -> bool:
        """
        Indicate if trainer uses model parallelism (tensor/expert parallelism).

        Used for effective batch size calculation. Model parallel trainers process
        the same batch across different model shards, so don't multiply by num_processes.

        Returns:
            False for single-device Trainer, True in model parallel trainers
        """
        return False

    def _prediction_step(
        self, input_dict: dict[str, Tensor], labels: Tensor
    ) -> Dict[str, Tensor | None]:
        """
        Forward one evaluation batch and compute its loss.

        Intended to run without gradient tracking (``_eval_loop`` is decorated
        with ``@torch.no_grad()``). The loss function must be a ``RescaleLoss``;
        its ``no_rescale()`` context is entered so eval loss is reported
        without the training-time rescaling. With ``self.use_fused_loss``,
        ``return_hidden_states=True`` is written into ``input_dict``. The loss
        is passed through ``_distributed_loss`` before being returned.

        Parameters
        ----------
        input_dict : dict
            Model inputs (input_ids, attention_mask, etc.).
        labels : Tensor
            Targets passed to the loss function.

        Returns
        -------
        dict
            ``'loss'`` (detached, possibly cross-rank reduced), ``'logits'``
            (detached model output as returned by ``logits_from_outputs``) and
            ``'labels'`` (the input labels).
        """
        assert self.model is not None
        loss_fn = self.loss_fn
        assert isinstance(loss_fn, RescaleLoss)

        if self.use_fused_loss:
            input_dict["return_hidden_states"] = True  # type: ignore[assignment]

        with loss_fn.no_rescale(), self.amp_context.autocast():
            logits = logits_from_outputs(self.model(**input_dict))
            raw_loss = loss_fn(logits, labels)

        reduced_loss = self._distributed_loss(raw_loss.detach())
        result: Dict[str, Tensor | None] = {"loss": reduced_loss}
        result["logits"] = logits.detach()
        result["labels"] = labels
        return result

    def _speed_metrics(
        self, prefix: str, runtime: float | None, samples: int, steps: int
    ) -> dict[str, int | float]:
        if runtime is not None and steps > 0:
            samples_per_second = round(samples / runtime, 3)
            steps_per_second = round(steps / runtime, 3)
        else:
            samples_per_second = float("nan")
            steps_per_second = float("nan")
        metrics = {
            f"{prefix}_runtime": runtime,
            f"{prefix}_samples": samples,
            "step": steps,
            f"{prefix}_samples_per_second": samples_per_second,
            f"{prefix}_steps_per_second": steps_per_second,
        }
        return metrics

    def _log_step(
        self,
        loss_log: list[Tensor],
        total_norm_log: list[Tensor],
        tokens_log: list[Tensor],
    ):
        """
        Reduce the metrics buffered since the last log step and emit a log record.

        Refreshes step counts, then builds a ``logs`` dict containing:

        - ``epoch``
        - ``loss``: mean of ``loss_log``, passed through ``_distributed_loss``
        - ``grad_norm`` / ``max_grad_norm``
        - token and FLOP totals, after ``_distributed_tokens``
        - ``learning_rate``
        - ``peak_mem_allocated``, when CUDA is available

        Token counts are also added to ``state.num_input_tokens_seen`` and
        ``state.total_flos``. Each input list is cleared in place once
        consumed. Finally ``on_log_step`` is dispatched (callbacks may add
        entries) and ``self.log(logs)`` is called.

        Parameters
        ----------
        loss_log : list of Tensor
            Per-step losses recorded since the previous log step.
        total_norm_log : list of Tensor
            Per-step gradient norms; must be non-empty whenever ``loss_log`` is.
        tokens_log : list of Tensor
            Per-step local (not yet rank-synchronized) token counts.

        Returns
        -------
        The return value of ``self.log(logs)``, or None when ``loss_log`` is
        empty (nothing is logged in that case).

        Raises
        ------
        ValueError
            If ``loss_log`` is non-empty but ``total_norm_log`` is empty.
        """
        # Refresh epoch_train_steps / max_steps (dataloader length may change).
        self._update_training_steps()

        logs: dict[str, Any] = {
            "epoch": self.state.epoch,
        }

        if not len(loss_log):
            # No losses to log - this can happen with gradient accumulation
            # when logging is called multiple times in the same step
            # NOTE(review): this return precedes the _distributed_* calls below;
            # presumably all ranks must agree on it to avoid mismatched collectives.
            return

        # NOTE(review): _distributed_loss / _distributed_tokens /
        # _distributed_peak_mem presumably issue cross-rank collectives in
        # distributed subclasses -- keep their order and conditions identical
        # on every rank.

        # Reduce loss: compute local mean then synchronize across ranks at log step
        # (mirrors the deferred token synchronization pattern)
        mean_loss = torch.stack(loss_log).mean()
        mean_loss = self._distributed_loss(mean_loss)
        logs["loss"] = mean_loss.item()
        loss_log.clear()

        if len(total_norm_log):
            total_norm = torch.stack(total_norm_log)
            # grad_norm: root-mean-square of the per-step norms.
            # max_grad_norm: largest per-step norm in this interval (not the
            # clipping threshold args.max_grad_norm).
            logs["grad_norm"] = total_norm.square().mean().sqrt().item()
            logs["max_grad_norm"] = total_norm.max().item()
            total_norm_log.clear()
        else:
            raise ValueError("No grad norm!")

        # Synchronize token counts across ranks at log step (amortizes all_reduce cost)
        if len(tokens_log):
            local_tokens = torch.stack(tokens_log).sum()
            synced_tokens = self._distributed_tokens(local_tokens)
            tokens_count = int(synced_tokens.item())
            self.state.num_input_tokens_seen += tokens_count
            # FLOPs are estimated as a fixed per-token cost times tokens seen.
            self.state.total_flos += self._flops_per_token * tokens_count
            logs["tokens"] = tokens_count
            logs["total_tokens"] = self.state.num_input_tokens_seen
            logs["total_flos"] = self.state.total_flos
            tokens_log.clear()

        if self.lr_scheduler is not None:
            # Only the first parameter group's learning rate is logged.
            last_lr = self.lr_scheduler.get_last_lr()[0]
            if last_lr is not None:
                if torch.is_tensor(last_lr):
                    last_lr = last_lr.item()
                logs["learning_rate"] = last_lr

        # Capture peak CUDA memory for this rank, gather across ranks, and reset
        # stats so each interval reflects only the high-water mark since the previous
        # log step.  Doing this here (rather than in individual callbacks) ensures
        # the reset happens exactly once regardless of how many callbacks query
        # memory statistics, and the per-rank list is uniformly available to all
        # logging callbacks (JsonLogger, PeakMemory, ProgressCallback).
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            local_peak = int(torch.cuda.max_memory_allocated(device=device))
            torch.cuda.reset_peak_memory_stats(device=device)
            logs["peak_mem_allocated"] = self._distributed_peak_mem(local_peak)

        # Allow callbacks to add entries to the logs before on_log fires
        self._dispatch_event(
            "on_log_step",
            logs=logs,
        )

        # Dispatch on_log to all logging callbacks
        return self.log(logs)

    def _prepare_batch(
        self, batch: Dict[str, Tensor]
    ) -> tuple[Dict[str, Tensor], Tensor]:
        """
        Move batch tensors to target device and extract labels.

        Parameters
        ----------
        batch : dict
            Dictionary of tensors from the dataloader. Must include a ``'labels'`` key.

        Returns
        -------
        tuple of (dict, Tensor)
            ``(input_dict, labels)`` with labels separated for loss computation.
        """
        batch = {k: v.to(self.args.device, non_blocking=True) for k, v in batch.items()}
        labels = batch.pop("labels")

        return (batch, labels)

    def _distributed_loss(self, loss: Tensor) -> Tensor:
        """
        Reduce loss across all processes for accurate logging.

        Single-device trainer just returns the input. Distributed trainers (AccelTrainer,
        PipelineTrainer) override this to all-reduce loss values so logging reflects
        the average loss across all processes/devices.

        See src/forgather/ml/trainer/accelerate/accel_trainer.py for distributed implementation.

        Parameters
        ----------
        loss : Tensor
            Loss tensor from the current process.

        Returns
        -------
        Tensor
            Loss tensor (single-device) or all-reduced loss (distributed).
        """
        return loss

    @override
    def load_checkpoint(self, *ckpt_args, **ckpt_kwargs) -> None:
        """
        Load a checkpoint via the parent implementation, then refresh step counts.

        All arguments are forwarded unchanged to the parent class's
        ``load_checkpoint``. Afterwards ``_update_training_steps`` is run so
        ``epoch_train_steps``/``max_steps`` are recomputed against the
        restored dataloader/dataset state.
        """
        super().load_checkpoint(*ckpt_args, **ckpt_kwargs)
        self._update_training_steps()
