
Optimizers

Forgather ships several optimizers and learning rate schedulers, available as configuration templates or directly via the Python API.

Related documentation:

Optimizers

forgather.ml.optim.adamw.AdamW

Bases: Optimizer

AdamW optimizer with optional stochastic rounding for pure-bf16 training.

Implements decoupled weight-decay regularization (Loshchilov & Hutter, arXiv:1711.05101) on top of the Adam update rule (Kingma & Ba, arXiv:1412.6980). The distinguishing feature of this implementation is first-class support for pure bf16 training — parameters, gradients, and optimizer states all stay in bf16, with stochastic rounding (SR) used for every write-back to avoid systematic truncation bias. This eliminates the need for fp32 master-weight copies while retaining most of the numerical quality of mixed-precision training.
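For reference, the decoupled weight-decay update can be sketched in a few lines of plain PyTorch. This is a minimal sketch of the standard algorithm as implemented by torch.optim.AdamW, not forgather's `_adam` kernel, whose body isn't shown on this page. The helper name `adamw_reference_step` is hypothetical, and stochastic rounding is omitted.

```python
import torch


def adamw_reference_step(p, grad, m, v, step, lr, beta1, beta2, eps, weight_decay):
    """One AdamW update applied in place to p, m and v (hypothetical helper).

    `step` is the 1-based step count used for bias correction.
    """
    # Decoupled weight decay: shrink the weights directly instead of
    # adding weight_decay * p to the gradient.
    p.mul_(1.0 - lr * weight_decay)
    # Exponential moving averages of the gradient and its square.
    m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    # Bias-corrected moment estimates, then the Adam step.
    m_hat = m / (1.0 - beta1**step)
    v_hat = v / (1.0 - beta2**step)
    p.sub_(lr * m_hat / (v_hat.sqrt() + eps))


# Usage: run a few steps side by side with torch.optim.AdamW.
torch.manual_seed(0)
w0 = torch.randn(4, 3, dtype=torch.float64)
p_ref = w0.clone()
m = torch.zeros_like(p_ref)
v = torch.zeros_like(p_ref)

p_torch = torch.nn.Parameter(w0.clone())
opt = torch.optim.AdamW(
    [p_torch], lr=1e-2, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01
)

for t in range(1, 6):
    g = torch.randn(4, 3, dtype=torch.float64)
    adamw_reference_step(p_ref, g, m, v, t, 1e-2, 0.9, 0.999, 1e-6, 0.01)
    p_torch.grad = g.clone()
    opt.step()

print(torch.allclose(p_ref, p_torch.detach(), atol=1e-10))  # True
```

The weight-decay term is applied to the parameter itself rather than folded into the gradient, so it is not rescaled by the adaptive denominator; that is what "decoupled" refers to.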

Prefer this optimizer over standard torch.optim.AdamW when:

  • Training on hardware with fast bf16 throughput and limited memory.
  • Running pure-bf16 experiments where fp32 master weights are undesirable.
  • Using FSDP2 (DTensor-backed parameters are handled transparently).

Notes

Stochastic rounding is seeded from a dedicated torch.Generator initialised with a fixed seed (5489) so that all DDP ranks make identical rounding decisions and parameters stay in sync without extra communication.
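The synchronization argument can be checked in isolation. Generators created with the same seed yield the same sequence of per-parameter SR seeds, and the get_state()/set_state() pair used by state_dict() and load_state_dict() lets a resumed run continue that sequence. This is a minimal sketch using only torch.Generator; the helper names are hypothetical, and two DDP ranks are simulated by two generators in one process.

```python
import torch

SR_SEED = 5489  # fixed seed quoted in the Notes above


def make_sr_generator() -> torch.Generator:
    # Each rank builds its own CPU generator with the same fixed seed.
    g = torch.Generator()
    g.manual_seed(SR_SEED)
    return g


def draw_sr_seed(g: torch.Generator) -> int:
    # One draw per parameter per step, mirroring the pattern in step().
    return int(torch.randint(0, 2**31, (1,), generator=g).item())


# Two simulated "ranks" draw identical per-parameter seeds.
rank0, rank1 = make_sr_generator(), make_sr_generator()
seeds0 = [draw_sr_seed(rank0) for _ in range(8)]
seeds1 = [draw_sr_seed(rank1) for _ in range(8)]
print(seeds0 == seeds1)  # True

# Checkpoint/resume: restoring the saved state replays the same sequence.
saved = rank0.get_state()
expected = [draw_sr_seed(rank0) for _ in range(4)]
resumed = make_sr_generator()
resumed.set_state(saved)
replayed = [draw_sr_seed(resumed) for _ in range(4)]
print(replayed == expected)  # True
```

The sequence only stays aligned if every rank makes the same number of draws in the same order, which is why the generator is kept separate from dropout and data-loading RNGs.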

When torch_compile=True (the default), the inner _adam kernel is compiled with torch.compile(..., fullgraph=True) for improved throughput.

References

Kingma, D. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv:1412.6980.

Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization. arXiv:1711.05101.

Source code in src/forgather/ml/optim/adamw.py
class AdamW(Optimizer):
    """AdamW optimizer with optional stochastic rounding for pure-bf16 training.

    Implements decoupled weight-decay regularization (Loshchilov & Hutter,
    arXiv:1711.05101) on top of the Adam update rule (Kingma & Ba,
    arXiv:1412.6980).  The distinguishing feature of this implementation is
    first-class support for *pure bf16 training* — parameters, gradients, and
    optimizer states all stay in bf16, with stochastic rounding (SR) used for
    every write-back to avoid systematic truncation bias.  This eliminates the
    need for fp32 master-weight copies while retaining most of the numerical
    quality of mixed-precision training.

    Prefer this optimizer over standard ``torch.optim.AdamW`` when:

    * Training on hardware with fast bf16 throughput and limited memory.
    * Running pure-bf16 experiments where fp32 master weights are undesirable.
    * Using FSDP2 (DTensor-backed parameters are handled transparently).

    Notes
    -----
    Stochastic rounding is seeded from a dedicated ``torch.Generator``
    initialised with a fixed seed (5489) so that all DDP ranks make identical
    rounding decisions and parameters stay in sync without extra communication.

    The inner ``_adam`` kernel is optionally compiled with
    ``torch.compile(..., fullgraph=True)`` for improved throughput.

    References
    ----------
    Kingma, D. & Ba, J. (2014). Adam: A Method for Stochastic Optimization.
    arXiv:1412.6980.

    Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization.
    arXiv:1711.05101.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        torch_compile: bool = True,
        bf16_stochastic_round: bool = True,
    ):
        """
        Parameters
        ----------
        params : iterable of Parameter
            Model parameters to optimize.
        lr : float, optional
            Learning rate. Default is 1e-3.
        betas : tuple of (float, float), optional
            Exponential decay rates for the first and second moment estimates.
            Default is (0.9, 0.999).
        eps : float, optional
            Term added to the denominator to improve numerical stability.
            Default is 1e-6.
        weight_decay : float, optional
            Decoupled weight-decay coefficient. Default is 0.01.
        torch_compile : bool, optional
            If ``True``, the inner ``_adam`` kernel is compiled with
            ``torch.compile`` for improved throughput. Default is ``True``.
        bf16_stochastic_round : bool, optional
            If ``True``, write-backs of moment buffers and parameter updates
            to bf16 storage use stochastic rounding rather than
            round-to-nearest.  Eliminates systematic truncation bias in
            pure-bf16 training.  Has no effect when parameters are fp32.
            Default is ``True``.
        """
        self.compile = torch_compile
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            bf16_stochastic_round=bf16_stochastic_round,
        )
        super().__init__(params, defaults)

        # Dedicated generator for stochastic rounding. Using a fixed seed
        # ensures all DDP ranks produce identical rounding decisions,
        # preventing parameter divergence across ranks. The generator is
        # only advanced by SR draws (not shared with dropout, data loading,
        # etc.) so it stays in sync as long as all ranks process the same
        # parameters in the same order -- which DDP guarantees.
        self._sr_generator = torch.Generator()
        self._sr_generator.manual_seed(5489)
        self._sr_cuda_generators = {}  # device -> Generator, lazily created

    def add_param_group(self, param_group: dict):
        super().add_param_group(param_group)
        group = self.param_groups[-1]
        if not isinstance(group["lr"], Tensor):
            group["lr"] = torch.tensor(group["lr"], dtype=torch.float32)

    def _init_state(self, state, group, p, grad):
        state["step"] = torch.tensor(0.0, dtype=torch.float32)
        state["m"] = torch.zeros_like(grad)
        state["v"] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: Callable = None):
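        loss = None
        if closure is not None:
            # The closure recomputes the loss and calls backward(), which
            # needs autograd enabled despite the @torch.no_grad() decorator.
            with torch.enable_grad():
                loss = closure()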

        with torch._dynamo.utils.disable_cache_limit():
            for group in self.param_groups:
                for p in group["params"]:
                    grad = p.grad
                    if grad is None:
                        continue
                    # Key state by the DTensor-backed Parameter before
                    # unsharding so save/load round-trips correctly.
                    state = self.state[p]
                    # Under FSDP2, p and p.grad are DTensors. Unshard to
                    # local shards for the compiled kernel and for scratch-
                    # buffer sizing (SR random bits). In-place ops on the
                    # local views flow back into the DTensor storage.
                    grad = _local_shard(grad)
                    p = _local_shard(p)

                    # Init state
                    if "step" not in state:
                        self._init_state(state, group, p, grad)

                    lr = group["lr"]
                    assert isinstance(
                        lr, Tensor
                    ), "Someone changed our lr to a non-Tensor!?"

                    state["step"] += 1
                    betas = group["betas"]

                    # Draw SR seed from dedicated generator (same across DDP ranks)
                    bf16_sr = group["bf16_stochastic_round"]
                    if bf16_sr:
                        sr_seed = int(
                            torch.randint(
                                0,
                                2**31,
                                (1,),
                                generator=self._sr_generator,
                            ).item()
                        )
                    else:
                        sr_seed = 0

                    # Pre-generate stochastic rounding noise.
                    # torch.Generator can't be traced by dynamo, so we
                    # generate the random bits here (outside compile) and
                    # pass the resulting tensor into the compiled function.
                    # _adam may call SR up to 3 times (m, v, update), so we
                    # pre-generate 3 * numel() worth of random bits as a
                    # flat buffer that gets sliced inside.
                    sr_rand_bits = None
                    if bf16_sr:
                        device = p.device
                        if device.type == "cuda":
                            if device not in self._sr_cuda_generators:
                                self._sr_cuda_generators[device] = torch.Generator(
                                    device=device
                                )
                            sr_gen = self._sr_cuda_generators[device]
                            sr_gen.manual_seed(sr_seed)
                        else:
                            sr_gen = torch.Generator(device=device)
                            sr_gen.manual_seed(sr_seed)
                        sr_rand_bits = torch.randint(
                            0,
                            1 << 16,
                            (3 * p.numel(),),
                            device=device,
                            dtype=torch.int32,
                            generator=sr_gen,
                        )

                    args = [
                        p,
                        grad,
                        state["step"],
                        state["m"],
                        state["v"],
                        lr,
                        betas[0],
                        betas[1],
                        group["eps"],
                        group["weight_decay"],
                        bf16_sr,
                        sr_rand_bits,
                    ]
                    if self.compile:
                        torch.compile(_adam, fullgraph=True, dynamic=False)(*args)
                    else:
                        _adam(*args)

        return loss

    def state_dict(self):
        """Return optimizer state with structure validation."""
        state_dict = super().state_dict()

        # Validate state structure for debugging
        for param_id, param_state in state_dict["state"].items():
            expected_keys = {"step", "m", "v"}
            if not expected_keys.issubset(param_state.keys()):
                missing = expected_keys - param_state.keys()
                raise ValueError(
                    f"AdamW state missing keys for param {param_id}: {missing}"
                )

        # Save SR generator state for deterministic resume
        state_dict["sr_generator_state"] = self._sr_generator.get_state()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load optimizer state with validation."""
        # Shallow copy to avoid mutating caller's dict
        state_dict = dict(state_dict)
        # Extract SR generator state before super() processes the dict
        sr_gen_state = state_dict.pop("sr_generator_state", None)

        # Validate before loading
        for param_id, param_state in state_dict["state"].items():
            expected_keys = {"step", "m", "v"}
            if not expected_keys.issubset(param_state.keys()):
                missing = expected_keys - param_state.keys()
                raise ValueError(
                    f"Cannot load AdamW: missing keys for param {param_id}: {missing}"
                )

        super().load_state_dict(state_dict)

        # Restore SR generator state for deterministic resume
        if sr_gen_state is not None:
            self._sr_generator.set_state(sr_gen_state)

__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.01, torch_compile=True, bf16_stochastic_round=True)

Parameters:

  • params (iterable of Parameter, required): Model parameters to optimize.
  • lr (float, default 0.001): Learning rate.
  • betas (tuple of (float, float), default (0.9, 0.999)): Exponential decay rates for the first and second moment estimates.
  • eps (float, default 1e-06): Term added to the denominator to improve numerical stability.
  • weight_decay (float, default 0.01): Decoupled weight-decay coefficient.
  • torch_compile (bool, default True): If True, the inner _adam kernel is compiled with torch.compile for improved throughput.
  • bf16_stochastic_round (bool, default True): If True, write-backs of moment buffers and parameter updates to bf16 storage use stochastic rounding rather than round-to-nearest. Eliminates systematic truncation bias in pure-bf16 training. Has no effect when parameters are fp32.
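A common way to implement the stochastic rounding enabled by bf16_stochastic_round is to add 16 random bits to the low half of the fp32 bit pattern and then truncate it. The sketch below assumes that technique. It is consistent with step() drawing int32 values in [0, 1 << 16), but the internals of `_adam` aren't shown on this page, so treat the helper `stochastic_round_to_bf16` as hypothetical.

```python
import torch


def stochastic_round_to_bf16(x: torch.Tensor, rand_bits: torch.Tensor) -> torch.Tensor:
    """Round fp32 `x` to bf16, rounding up with probability equal to the
    discarded fraction (hypothetical helper).

    `rand_bits` must be int32 values in [0, 2**16) with the same shape as `x`.
    """
    bits = x.contiguous().view(torch.int32)  # reinterpret fp32 bits as int32
    # bf16 keeps the top 16 bits of an fp32 value. Adding uniform noise to the
    # 16 bits that will be discarded carries into the kept bits with
    # probability proportional to their value; masking then truncates.
    rounded = (bits + rand_bits) & -65536  # -65536 == 0xFFFF0000 as int32
    # The low 16 bits are now zero, so the bf16 conversion is exact.
    return rounded.view(torch.float32).to(torch.bfloat16)


g = torch.Generator().manual_seed(0)
n = 200_000
# 1 + 2**-10 sits 1/8 of the way from 1.0 to the next bf16 value (1.0078125).
x = torch.full((n,), 1.0 + 2**-10, dtype=torch.float32)
noise = torch.randint(0, 1 << 16, (n,), dtype=torch.int32, generator=g)
y = stochastic_round_to_bf16(x, noise)

nearest = torch.tensor(1.0 + 2**-10).to(torch.bfloat16).item()
print(nearest)                   # 1.0: round-to-nearest drops the increment
print(y.double().mean().item())  # close to 1.0009765625 on average
```

Round-to-nearest maps every copy of x to 1.0, so repeated small updates would be lost; stochastic rounding returns 1.0078125 about one time in eight, keeping the expected value equal to x.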

state_dict()

Return optimizer state with structure validation.

load_state_dict(state_dict)

Load optimizer state with validation.

forgather.ml.optim.adafactor.Adafactor

Bases: Optimizer

Memory-efficient adaptive optimizer with factored second-moment estimation.

Implements the Adafactor algorithm (Shazeer & Stern, arXiv:1804.04235). For matrices, the second-moment accumulator is factored into outer-product row and column vectors, reducing per-parameter memory from O(n*m) to O(n+m). For vectors and scalars the full accumulator is retained.
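The factoring can be sketched with the row-sum/column-sum formulation from Shazeer & Stern: keep an EMA of the row sums and of the column sums of the squared gradient, and rebuild the full estimate as their outer product divided by the total. This is a sketch under that assumption; forgather's `_adafactor` kernel isn't shown here and may use means instead of sums, which only changes the normalization. The helper name is hypothetical. For tensors with more than two dimensions, `_init_state` sizes `row` as `grad[..., 0].numel()`, i.e. the tensor is treated as a matrix of shape (prod(shape[:-1]), shape[-1]).

```python
import torch


def factored_second_moment(g2: torch.Tensor, row: torch.Tensor, col: torch.Tensor,
                           beta2t: float, eps1: float) -> torch.Tensor:
    """Update row/col accumulators in place and return the rank-1 estimate of
    the full second-moment matrix (hypothetical helper).

    g2: squared gradient with shape (n, m); row: shape (n,); col: shape (m,).
    """
    sq = g2 + eps1
    # EMAs of the row sums and column sums of the squared gradient.
    row.mul_(beta2t).add_(sq.sum(dim=-1), alpha=1.0 - beta2t)
    col.mul_(beta2t).add_(sq.sum(dim=-2), alpha=1.0 - beta2t)
    # Outer product, normalized by the total so the scale matches g2.
    return torch.outer(row, col) / row.sum()


n, m = 4096, 1024
print("full accumulator:", n * m, "values")       # 4194304
print("factored accumulators:", n + m, "values")  # 5120

# When g^2 is itself rank-1, one step with beta2t = 0 recovers it exactly.
a = torch.rand(6, dtype=torch.float64) + 0.1
b = torch.rand(5, dtype=torch.float64) + 0.1
g2 = torch.outer(a, b)
row = torch.zeros(6, dtype=torch.float64)
col = torch.zeros(5, dtype=torch.float64)
v_hat = factored_second_moment(g2, row, col, beta2t=0.0, eps1=1e-30)
print(torch.allclose(v_hat, g2))  # True
```

For a general squared gradient the outer product is only an approximation, which is why the docstring recommends Adafactor where the factored form is a good fit.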

Like AdamW, this implementation supports pure bf16 training via stochastic rounding on all write-backs, and handles FSDP2 DTensor parameters transparently. An optional Triton kernel path is available for higher GPU throughput on CUDA devices.

Prefer Adafactor over AdamW when:

  • Memory is the primary constraint (large models, small accelerators).
  • Training transformers with large embedding or projection matrices where the factored approximation is a good fit.

Notes

decay_rate controls how the effective beta2 grows with step count: beta2t = clamp(1 - step^decay_rate, max=beta2). The default of -0.8 replicates the schedule from the paper.
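To make the schedule concrete, the formula can be evaluated for a few step counts with the default decay_rate = -0.8 and beta2 = 0.999. This sketch uses Python floats; step() computes the same expression on float32 tensors, so values near the cap may differ in the last digits. The helper name is hypothetical.

```python
def beta2t(step: int, decay_rate: float = -0.8, beta2: float = 0.999) -> float:
    # beta2t = clamp(1 - step**decay_rate, max=beta2)
    return min(1.0 - step**decay_rate, beta2)


for step in (1, 10, 100, 1000, 10000):
    print(step, round(beta2t(step), 4))
# 1 0.0
# 10 0.8415
# 100 0.9749
# 1000 0.996
# 10000 0.999
```

At step 1 the accumulator is simply overwritten with the current squared gradient (beta2t = 0), and the cap is reached once step**-0.8 <= 1 - beta2, after roughly 5,600 steps with these defaults.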

The Triton kernel path (use_triton=True) does not support relative_step=True; the constructor rejects that combination with an assertion.
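The constructor docstring defines the relative step size as lr = max(eps2, rms(p)) * lr and states that the update is scaled down when its RMS exceeds clip_threshold. The sketch below shows both rules, assuming the paper's clipping form u / max(1, rms(u) / clip_threshold); the helper names are hypothetical and the body of `_adafactor` isn't shown on this page.

```python
import torch


def rms(t: torch.Tensor) -> torch.Tensor:
    # Root mean square over all elements.
    return t.pow(2).mean().sqrt()


def relative_lr(p: torch.Tensor, lr: float, eps2: float = 1e-3) -> torch.Tensor:
    # relative_step=True: lr = max(eps2, rms(p)) * lr
    return torch.clamp(rms(p), min=eps2) * lr


def clip_update(u: torch.Tensor, clip_threshold: float = 1.0) -> torch.Tensor:
    # Assumed paper form: divide by rms(u) / d only when that ratio exceeds 1.
    return u / torch.clamp(rms(u) / clip_threshold, min=1.0)


p = torch.full((3, 3), 0.5)
lr_p = relative_lr(p, 1e-2)                 # rms(p) = 0.5, so about 0.005
lr_zero = relative_lr(torch.zeros(3), 1e-2)  # eps2 floor, so about 1e-5

u = torch.full((4,), 4.0)
print(rms(clip_update(u)))                         # tensor(1.)
print(torch.equal(clip_update(0.1 * u), 0.1 * u))  # True: below threshold
```

The eps2 floor keeps freshly initialized or zero parameters from receiving a zero learning rate under relative_step.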

References

Shazeer, N. & Stern, M. (2018). Adafactor: Adaptive Learning Rates with Sublinear Memory Cost. arXiv:1804.04235.

Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization. arXiv:1711.05101.

Source code in src/forgather/ml/optim/adafactor.py
class Adafactor(Optimizer):
    """Memory-efficient adaptive optimizer with factored second-moment estimation.

    Implements the Adafactor algorithm (Shazeer & Stern, arXiv:1804.04235).
    For matrices, the second-moment accumulator is factored into outer-product
    row and column vectors, reducing per-parameter memory from O(n*m) to
    O(n+m).  For vectors and scalars the full accumulator is retained.

    Like `AdamW`, this implementation supports *pure bf16 training* via
    stochastic rounding on all write-backs, and handles FSDP2 DTensor
    parameters transparently.  An optional Triton kernel path is available
    for higher GPU throughput on CUDA devices.

    Prefer Adafactor over AdamW when:

    * Memory is the primary constraint (large models, small accelerators).
    * Training transformers with large embedding or projection matrices where
      the factored approximation is a good fit.

    Notes
    -----
    ``decay_rate`` controls how the effective ``beta2`` grows with step count:
    ``beta2t = clamp(1 - step^decay_rate, max=beta2)``.  The default of
    ``-0.8`` replicates the schedule from the paper.

    The Triton kernel path (``use_triton=True``) does not support
    ``relative_step=True``.

    References
    ----------
    Shazeer, N. & Stern, M. (2018). Adafactor: Adaptive Learning Rates with
    Sublinear Memory Cost. arXiv:1804.04235.

    Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization.
    arXiv:1711.05101.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        decay_rate: float = -0.8,
        clip_threshold: float = 1.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: Tuple[float, float] = (1e-30, 1e-3),
        weight_decay: float = 0.01,
        relative_step: bool = False,
        torch_compile: bool = True,
        bf16_stochastic_round: bool = True,
        use_triton: bool = False,
    ):
        """
        Parameters
        ----------
        params : iterable of Parameter
            Model parameters to optimize.
        lr : float, optional
            Learning rate (or relative step size when ``relative_step=True``).
            Default is 1e-3.
        decay_rate : float, optional
            Exponent controlling how the effective ``beta2`` grows with step
            count: ``beta2t = clamp(1 - step^decay_rate, max=beta2)``.
            Negative values cause ``beta2t`` to grow toward ``beta2`` over
            time. Default is -0.8.
        clip_threshold : float, optional
            Root-mean-square threshold for gradient clipping.  The update is
            scaled down when its RMS exceeds this value.  Default is 1.0.
        betas : tuple of (float, float), optional
            Upper bounds for the first and second moment decay rates.
            ``beta1`` is the EMA decay for the optional first moment;
            ``beta2`` caps the adaptive ``beta2t``.  Default is (0.9, 0.999).
        eps : tuple of (float, float), optional
            ``(eps1, eps2)``.  ``eps1`` is added to the squared gradient
            before factoring to improve numerical stability.  ``eps2`` is the
            minimum absolute learning rate used with ``relative_step=True``.
            Default is (1e-30, 1e-3).
        weight_decay : float, optional
            Decoupled weight-decay coefficient.  Default is 0.01.
        relative_step : bool, optional
            If ``True``, the effective learning rate is scaled by the RMS of
            the parameter: ``lr = max(eps2, rms(p)) * lr``, following the
            paper's relative-step-size formulation.  Default is ``False``.
        torch_compile : bool, optional
            If ``True``, the inner ``_adafactor`` kernel is compiled with
            ``torch.compile``.  Mutually exclusive with ``use_triton``.
            Default is ``True``.
        bf16_stochastic_round : bool, optional
            Enable stochastic rounding for bf16 write-backs.  Has no effect
            when parameters are fp32.  Default is ``True``.
        use_triton : bool, optional
            If ``True``, use Triton-compiled CUDA kernels instead of the
            PyTorch implementation.  Requires ``triton`` to be installed and
            ``relative_step=False``.  Default is ``False``.
        """
        self.compile = torch_compile
        self.use_triton = use_triton

        # Import Triton kernels if needed
        if use_triton:
            assert (
                relative_step == False
            ), "relative_step is not supported by Adafactor Triton kernel. Set use_triton = False"
            try:
                from . import adafactor_triton

                self.triton_module = adafactor_triton
            except ImportError as e:
                raise ImportError(
                    "Triton is required for use_triton=True. "
                    "Please install it with: pip install triton"
                ) from e

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            decay_rate=decay_rate,
            clip_threshold=clip_threshold,
            weight_decay=weight_decay,
            relative_step=relative_step,
            bf16_stochastic_round=bf16_stochastic_round,
        )
        super().__init__(params, defaults)

        # Dedicated generator for stochastic rounding. Using a fixed seed
        # ensures all DDP ranks produce identical rounding decisions,
        # preventing parameter divergence across ranks. The generator is
        # only advanced by SR draws (not shared with dropout, data loading,
        # etc.) so it stays in sync as long as all ranks process the same
        # parameters in the same order -- which DDP guarantees.
        self._sr_generator = torch.Generator()
        self._sr_generator.manual_seed(5489)
        self._sr_cuda_generators = {}  # device -> Generator, lazily created

    def add_param_group(self, param_group: dict):
        super().add_param_group(param_group)
        if not self.use_triton:
            group = self.param_groups[-1]
            if not isinstance(group["lr"], Tensor):
                group["lr"] = torch.tensor(group["lr"], dtype=torch.float32)

    def _init_state(self, state, group, p, grad):
        state["step"] = torch.tensor(0.0, dtype=torch.float32)
        if grad.dim() <= 1:
            state["row"] = torch.zeros_like(grad, dtype=torch.float32)
            state["col"] = None
        else:
            state["row"] = torch.zeros(
                grad[..., 0].numel(), dtype=torch.float32, device=grad.device
            )
            state["col"] = torch.zeros(
                grad.shape[-1], dtype=torch.float32, device=grad.device
            )

    @torch.no_grad()
    def step(self, closure: Callable = None):
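        loss = None
        if closure is not None:
            # The closure recomputes the loss and calls backward(), which
            # needs autograd enabled despite the @torch.no_grad() decorator.
            with torch.enable_grad():
                loss = closure()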

        with torch._dynamo.utils.disable_cache_limit():
            for group in self.param_groups:
                for p in group["params"]:
                    grad = p.grad
                    if grad is None:
                        continue
                    # Optimizer state must be keyed by the DTensor-backed
                    # Parameter so save/load via get_optimizer_state_dict
                    # round-trips correctly; do the state lookup before
                    # unsharding.
                    state = self.state[p]
                    # Under FSDP2, p and p.grad are DTensors. The kernel
                    # below operates entirely on local shards; see
                    # _local_shard docstring. The returned tensors are
                    # views over the DTensor's local storage, so in-place
                    # ops flow back into the live parameter.
                    grad = _local_shard(grad)
                    p = _local_shard(p)

                    # Init state
                    if "step" not in state:
                        self._init_state(state, group, p, grad)

                    lr = group["lr"]

                    state["step"] += 1
                    beta1, beta2 = group["betas"]
                    eps1, eps2 = group["eps"]

                    # Compute decay parameter for beta2
                    beta2t = (1.0 - state["step"] ** group["decay_rate"]).clamp(
                        max=beta2
                    )

                    # Draw SR seed from dedicated generator (same across DDP ranks)
                    bf16_sr = group["bf16_stochastic_round"]
                    if bf16_sr:
                        sr_seed = int(
                            torch.randint(
                                0,
                                2**31,
                                (1,),
                                generator=self._sr_generator,
                            ).item()
                        )
                    else:
                        sr_seed = 0

                    # Route to Triton or PyTorch implementation
                    if self.use_triton and grad.is_cuda:
                        # Use Triton kernels
                        if state["col"] is None:
                            # 1D case
                            self.triton_module.adafactor_step_1d_triton(
                                p,
                                grad,
                                state["row"],
                                beta2t,
                                eps1,
                                lr,
                                group["weight_decay"],
                                group["clip_threshold"],
                                bf16_sr,
                                sr_seed,
                            )
                        else:
                            # 2D case
                            self.triton_module.adafactor_step_2d_triton(
                                p,
                                grad,
                                state["row"],
                                state["col"],
                                beta2t,
                                eps1,
                                lr,
                                group["weight_decay"],
                                group["clip_threshold"],
                                bf16_sr,
                                sr_seed,
                            )
                    else:
                        assert isinstance(
                            lr, Tensor
                        ), "Someone changed our lr to a non-Tensor!?"

                        # Pre-generate stochastic rounding noise.
                        # torch.Generator can't be traced by dynamo, so we
                        # generate the random bits here (outside compile) and
                        # pass the resulting tensor into the compiled function.
                        sr_rand_bits = None
                        if bf16_sr:
                            device = p.device
                            if device.type == "cuda":
                                if device not in self._sr_cuda_generators:
                                    self._sr_cuda_generators[device] = torch.Generator(
                                        device=device
                                    )
                                sr_gen = self._sr_cuda_generators[device]
                                sr_gen.manual_seed(sr_seed)
                            else:
                                sr_gen = torch.Generator(device=device)
                                sr_gen.manual_seed(sr_seed)
                            sr_rand_bits = torch.randint(
                                0,
                                1 << 16,
                                p.shape,
                                device=device,
                                dtype=torch.int32,
                                generator=sr_gen,
                            )

                        # Use standard PyTorch implementation
                        args = [
                            p,
                            grad,
                            state["step"],
                            state["row"],
                            state["col"],
                            lr,
                            beta1,
                            beta2,
                            group["decay_rate"],
                            group["clip_threshold"],
                            eps1,
                            eps2,
                            group["weight_decay"],
                            group["relative_step"],
                            bf16_sr,
                            sr_rand_bits,
                        ]
                        if self.compile:
                            torch.compile(_adafactor, fullgraph=True, dynamic=False)(
                                *args
                            )
                        else:
                            _adafactor(*args)

        return loss

    def state_dict(self):
        """Return optimizer state handling conditional col=None."""
        state_dict = super().state_dict()

        # Validate state structure
        for param_id, param_state in state_dict["state"].items():
            expected_keys = {"step", "row", "col"}
            if not expected_keys.issubset(param_state.keys()):
                missing = expected_keys - param_state.keys()
                raise ValueError(
                    f"Adafactor state missing keys for param {param_id}: {missing}"
                )

            # Ensure col=None is handled correctly (not converted to tensor)
            if param_state["col"] is not None and not torch.is_tensor(
                param_state["col"]
            ):
                raise ValueError(
                    f"Adafactor col must be tensor or None, got {type(param_state['col'])}"
                )

        # Save SR generator state for deterministic resume
        state_dict["sr_generator_state"] = self._sr_generator.get_state()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load optimizer state handling conditional col=None."""
        # Shallow copy to avoid mutating caller's dict
        state_dict = dict(state_dict)
        # Extract SR generator state before super() processes the dict
        sr_gen_state = state_dict.pop("sr_generator_state", None)

        # Validate structure
        for param_id, param_state in state_dict["state"].items():
            expected_keys = {"step", "row", "col"}
            if not expected_keys.issubset(param_state.keys()):
                missing = expected_keys - param_state.keys()
                raise ValueError(
                    f"Cannot load Adafactor: missing keys for param {param_id}: {missing}"
                )

        super().load_state_dict(state_dict)

        # Restore SR generator state for deterministic resume
        if sr_gen_state is not None:
            self._sr_generator.set_state(sr_gen_state)

__init__(params, lr=0.001, decay_rate=-0.8, clip_threshold=1.0, betas=(0.9, 0.999), eps=(1e-30, 0.001), weight_decay=0.01, relative_step=False, torch_compile=True, bf16_stochastic_round=True, use_triton=False)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate (or relative step size when relative_step=True). | 0.001 |
| decay_rate | float | Exponent controlling how the effective beta2 grows with step count: beta2t = clamp(1 - step^decay_rate, max=beta2). Negative values make beta2t increase toward beta2 as training proceeds. | -0.8 |
| clip_threshold | float | Root-mean-square threshold for update clipping: the update is scaled down when its RMS exceeds this value. | 1.0 |
| betas | tuple of (float, float) | Upper bounds for the first and second moment decay rates. beta1 is the EMA decay for the optional first moment; beta2 caps the adaptive beta2t. | (0.9, 0.999) |
| eps | tuple of (float, float) | (eps1, eps2). eps1 is added to the squared gradient before factoring, for numerical stability. eps2 is the lower bound on the parameter RMS used to scale the step size when relative_step=True. | (1e-30, 0.001) |
| weight_decay | float | Decoupled weight-decay coefficient. | 0.01 |
| relative_step | bool | If True, the effective learning rate is scaled by the RMS of the parameter: lr = max(eps2, rms(p)) * lr, following the paper's relative-step-size formulation. | False |
| torch_compile | bool | If True, the inner _adafactor kernel is compiled with torch.compile. Mutually exclusive with use_triton. | True |
| bf16_stochastic_round | bool | Enable stochastic rounding for bf16 write-backs. Has no effect when parameters are fp32. | True |
| use_triton | bool | If True, use Triton-compiled CUDA kernels instead of the PyTorch implementation. Requires triton to be installed and relative_step=False. | False |
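The decay_rate, clip_threshold, eps, and relative_step entries describe a handful of per-step scalars. The pure-Python sketch below evaluates those formulas as written above. The clipping function assumes the standard Adafactor form update / max(1, rms(update) / clip_threshold), since the _adafactor kernel itself is not shown on this page; treat it as an illustration, not the library code.

```python
import math


def beta2_t(step: int, decay_rate: float = -0.8, beta2: float = 0.999) -> float:
    """Effective second-moment decay: clamp(1 - step**decay_rate, max=beta2)."""
    return min(1.0 - step ** decay_rate, beta2)


def rms(values: list) -> float:
    """Root-mean-square of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def effective_lr(lr: float, param: list, eps2: float = 1e-3,
                 relative_step: bool = False) -> float:
    """With relative_step=True the step size is scaled by max(eps2, rms(param))."""
    return max(eps2, rms(param)) * lr if relative_step else lr


def clip_update(update: list, clip_threshold: float = 1.0) -> list:
    """Assumed Adafactor-style update clipping: cap the update's RMS at clip_threshold."""
    denom = max(1.0, rms(update) / clip_threshold)
    return [u / denom for u in update]


# beta2t starts at 0, rises quickly, and is eventually capped at beta2.
for step in (1, 10, 100, 10_000):
    print(step, round(beta2_t(step), 4))
```

Note how eps2 only matters for parameters whose RMS is tiny (for example, zero-initialized ones): it keeps the relative step from collapsing to zero.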
Source code in src/forgather/ml/optim/adafactor.py
def __init__(
    self,
    params: Iterable[nn.parameter.Parameter],
    lr: float = 1e-3,
    decay_rate: float = -0.8,
    clip_threshold: float = 1.0,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: Tuple[float, float] = (1e-30, 1e-3),
    weight_decay: float = 0.01,
    relative_step: bool = False,
    torch_compile: bool = True,
    bf16_stochastic_round: bool = True,
    use_triton: bool = False,
):
    """
    Parameters
    ----------
    params : iterable of Parameter
        Model parameters to optimize.
    lr : float, optional
        Learning rate (or relative step size when ``relative_step=True``).
        Default is 1e-3.
    decay_rate : float, optional
        Exponent controlling how the effective ``beta2`` grows with step
        count: ``beta2t = clamp(1 - step^decay_rate, max=beta2)``.
        Negative values cause ``beta2t`` to grow toward ``beta2`` over
        time. Default is -0.8.
    clip_threshold : float, optional
        Root-mean-square threshold for update clipping.  The update is
        scaled down when its RMS exceeds this value.  Default is 1.0.
    betas : tuple of (float, float), optional
        Upper bounds for the first and second moment decay rates.
        ``beta1`` is the EMA decay for the optional first moment;
        ``beta2`` caps the adaptive ``beta2t``.  Default is (0.9, 0.999).
    eps : tuple of (float, float), optional
        ``(eps1, eps2)``.  ``eps1`` is added to the squared gradient
        before factoring to improve numerical stability.  ``eps2`` is the
        lower bound on the parameter RMS used to scale the step size when
        ``relative_step=True``.
        Default is (1e-30, 1e-3).
    weight_decay : float, optional
        Decoupled weight-decay coefficient.  Default is 0.01.
    relative_step : bool, optional
        If ``True``, the effective learning rate is scaled by the RMS of
        the parameter: ``lr = max(eps2, rms(p)) * lr``, following the
        paper's relative-step-size formulation.  Default is ``False``.
    torch_compile : bool, optional
        If ``True``, the inner ``_adafactor`` kernel is compiled with
        ``torch.compile``.  Mutually exclusive with ``use_triton``.
        Default is ``True``.
    bf16_stochastic_round : bool, optional
        Enable stochastic rounding for bf16 write-backs.  Has no effect
        when parameters are fp32.  Default is ``True``.
    use_triton : bool, optional
        If ``True``, use Triton-compiled CUDA kernels instead of the
        PyTorch implementation.  Requires ``triton`` to be installed and
        ``relative_step=False``.  Default is ``False``.
    """
    self.compile = torch_compile
    self.use_triton = use_triton

    # Import Triton kernels if needed
    if use_triton:
        assert (
            relative_step == False
        ), "relative_step is not supported by Adafactor Triton kernel. Set use_triton = False"
        try:
            from . import adafactor_triton

            self.triton_module = adafactor_triton
        except ImportError as e:
            raise ImportError(
                "Triton is required for use_triton=True. "
                "Please install it with: pip install triton"
            ) from e

    defaults = dict(
        lr=lr,
        betas=betas,
        eps=eps,
        decay_rate=decay_rate,
        clip_threshold=clip_threshold,
        weight_decay=weight_decay,
        relative_step=relative_step,
        bf16_stochastic_round=bf16_stochastic_round,
    )
    super().__init__(params, defaults)

    # Dedicated generator for stochastic rounding. Using a fixed seed
    # ensures all DDP ranks produce identical rounding decisions,
    # preventing parameter divergence across ranks. The generator is
    # only advanced by SR draws (not shared with dropout, data loading,
    # etc.) so it stays in sync as long as all ranks process the same
    # parameters in the same order -- which DDP guarantees.
    self._sr_generator = torch.Generator()
    self._sr_generator.manual_seed(5489)
    self._sr_cuda_generators = {}  # device -> Generator, lazily created
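The comment about the fixed seed can be illustrated without PyTorch. This sketch uses Python's random.Random seeded with 5489 as a stand-in for the dedicated torch.Generator, and assumes the common "add random low-order bits, then truncate" form of bf16 stochastic rounding (the step code draws int32 random bits with torch.randint, but the rounding kernel itself is not shown here). Two identically seeded generators play the role of two DDP ranks.

```python
import random
import struct


def f32_bits(x: float) -> int:
    """Reinterpret a value as its IEEE-754 single-precision bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a single-precision float."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def bf16_stochastic_round(x: float, rng: random.Random) -> float:
    """Round x to a bf16-representable value (returned as a float).

    bf16 keeps the upper 16 bits of an fp32 value. Adding a uniform random
    integer in [0, 2**16) to the 16 discarded bits before truncating rounds
    up with probability equal to the discarded fraction, so E[result] == x.
    """
    bits = (f32_bits(x) + rng.getrandbits(16)) & 0xFFFF0000
    return bits_f32(bits)


x = 1.0 + 2 ** -10                 # lies between bf16 neighbours 1.0 and 1.0 + 2**-7
rank0 = random.Random(5489)        # same fixed seed on every simulated "rank"
rank1 = random.Random(5489)
samples0 = [bf16_stochastic_round(x, rank0) for _ in range(20_000)]
samples1 = [bf16_stochastic_round(x, rank1) for _ in range(20_000)]
```

Both simulated ranks make identical rounding decisions, and the sample mean stays close to x, whereas plain truncation would always return 1.0 and lose the 2**-10 increment.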

state_dict()

Return optimizer state handling conditional col=None.

Source code in src/forgather/ml/optim/adafactor.py
def state_dict(self):
    """Return optimizer state handling conditional col=None."""
    state_dict = super().state_dict()

    # Validate state structure
    for param_id, param_state in state_dict["state"].items():
        expected_keys = {"step", "row", "col"}
        if not expected_keys.issubset(param_state.keys()):
            missing = expected_keys - param_state.keys()
            raise ValueError(
                f"Adafactor state missing keys for param {param_id}: {missing}"
            )

        # Ensure col=None is handled correctly (not converted to tensor)
        if param_state["col"] is not None and not torch.is_tensor(
            param_state["col"]
        ):
            raise ValueError(
                f"Adafactor col must be tensor or None, got {type(param_state['col'])}"
            )

    # Save SR generator state for deterministic resume
    state_dict["sr_generator_state"] = self._sr_generator.get_state()

    return state_dict
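The row and col entries validated above hold Adafactor's factored second-moment statistics. The sketch below shows the factored estimate in the form given in the original Adafactor paper, for a single step and without the beta2t moving average. It is an illustration using nested lists, not the forgather _adafactor kernel, and this page does not show which parameters are left unfactored (col=None).

```python
def factored_second_moment(grad: list, eps1: float = 1e-30):
    """Return (row, col, v_hat) for a 2-D gradient given as nested lists.

    Only n + m numbers (row and col) are stored; v_hat is rebuilt on demand
    as row[i] * col[j] / mean(row), a rank-1 approximation of grad**2.
    """
    sq = [[g * g + eps1 for g in r] for r in grad]
    n, m = len(sq), len(sq[0])
    row = [sum(r) / m for r in sq]                                 # n row means
    col = [sum(sq[i][j] for i in range(n)) / n for j in range(m)]  # m column means
    row_mean = sum(row) / n
    v_hat = [[row[i] * col[j] / row_mean for j in range(m)] for i in range(n)]
    return row, col, v_hat


# A rank-1 gradient has rank-1 squares, so the factored estimate is exact here.
row, col, v_hat = factored_second_moment([[1.0, 2.0], [3.0, 6.0]])
```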

load_state_dict(state_dict)

Load optimizer state handling conditional col=None.

Source code in src/forgather/ml/optim/adafactor.py
def load_state_dict(self, state_dict):
    """Load optimizer state handling conditional col=None."""
    # Shallow copy to avoid mutating caller's dict
    state_dict = dict(state_dict)
    # Extract SR generator state before super() processes the dict
    sr_gen_state = state_dict.pop("sr_generator_state", None)

    # Validate structure
    for param_id, param_state in state_dict["state"].items():
        expected_keys = {"step", "row", "col"}
        if not expected_keys.issubset(param_state.keys()):
            missing = expected_keys - param_state.keys()
            raise ValueError(
                f"Cannot load Adafactor: missing keys for param {param_id}: {missing}"
            )

    super().load_state_dict(state_dict)

    # Restore SR generator state for deterministic resume
    if sr_gen_state is not None:
        self._sr_generator.set_state(sr_gen_state)
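Saving and restoring the generator state is what lets a resumed run draw the same rounding bits as an uninterrupted one. This sketch uses random.Random.getstate/setstate as a stand-in for torch.Generator.get_state/set_state; the 16-bit draws are illustrative.

```python
import random


def draw(rng: random.Random, n: int) -> list:
    """Draw n random 16-bit integers, as a stochastic-rounding step might."""
    return [rng.getrandbits(16) for _ in range(n)]


# Uninterrupted run: two "optimizer steps".
gen = random.Random(5489)
step1 = draw(gen, 4)
saved = gen.getstate()          # analogous to state_dict()["sr_generator_state"]
step2 = draw(gen, 4)

# Resumed run: the constructor re-seeds with 5489, then the checkpoint is loaded.
resumed = random.Random(5489)
resumed.setstate(saved)         # analogous to load_state_dict() calling set_state
step2_resumed = draw(resumed, 4)
```

Without the restore, the resumed run would replay step 1's bits at step 2, so its rounding decisions would no longer match those of the original run.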

forgather.ml.optim.apollo.Apollo

Bases: Optimizer

Low-rank gradient-projection optimizer with AdamW-level performance.

Implements the Apollo algorithm (Zhu et al., arXiv:2412.05270). Rather than maintaining full-size first and second moment buffers, Apollo projects gradients into a low-rank subspace (controlled by rank), runs the Adam update there, and uses the resulting per-column scaling signal to scale the full-rank gradient. Moment buffer memory scales as O(rank * max(n, m)) instead of O(n * m).
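To make the O(rank * max(n, m)) figure concrete, the arithmetic below counts moment-buffer elements for a hypothetical 4096 x 11008 weight. It counts only the m and v buffers allocated in _init_state (shape (rank, max(n, m))) and ignores the projector's own matrix, whose size depends on the projector implementation.

```python
def adamw_moment_elems(n: int, m: int) -> int:
    """AdamW keeps two full-size moment buffers per n x m weight."""
    return 2 * n * m


def apollo_moment_elems(n: int, m: int, rank: int) -> int:
    """Apollo keeps two (rank, max(n, m)) buffers, matching _init_state."""
    return 2 * rank * max(n, m)


n, m = 4096, 11008  # hypothetical weight shape
for rank in (1, 256):
    ratio = adamw_moment_elems(n, m) / apollo_moment_elems(n, m, rank)
    print(f"rank={rank}: {apollo_moment_elems(n, m, rank):,} elements, {ratio:.0f}x fewer")
```

The ratio simplifies to min(n, m) / rank, so the savings grow with the smaller dimension of the weight.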

Also applies the Norm-Growth Limiter from Fira (arXiv:2410.01623) to prevent destructive gradient updates.
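The limiter in step() caps how fast the update norm can grow between steps. If the norm grows by more than a factor of 1.01 over the previously stored (post-limit) norm, the update is rescaled so that the growth is exactly 1.01. This scalar sketch mirrors the constants in the step() listing (1.01 and 1e-8), but operates on plain floats rather than tensors.

```python
from typing import Optional, Tuple


def norm_growth_limit(update_norm: float, prev_norm: Optional[float],
                      gamma: float = 1.01) -> Tuple[float, float]:
    """Return (factor to multiply the update by, norm to store for next step)."""
    if prev_norm is None:  # first step: nothing to compare against, just record
        return 1.0, update_norm
    limiter = max(update_norm / (prev_norm + 1e-8), gamma) / gamma
    return 1.0 / limiter, update_norm / limiter


# A 10x spike is cut back to 1% growth over the stored norm.
factor, stored = norm_growth_limit(10.0, 1.0)
```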

Prefer Apollo over AdamW when:

  • Memory is constrained, but Adafactor's factored approximation is too aggressive (Apollo retains the full gradient for the parameter update).
  • rank=1 (Apollo-Mini) is desired for maximum memory savings while still outperforming SGD.
Notes

The projector_factory callable is not serialisable and is therefore stripped from checkpoints. It must be supplied again via the constructor when resuming from a checkpoint.

References

Zhu, W. et al. (2024). APOLLO: SGD-like Memory, AdamW-level Performance. arXiv:2412.05270.

Chen, Y. et al. (2024). Fira: Can We Achieve Full-Rank Training of LLMs Under Low-Rank Constraint? arXiv:2410.01623.

Source code in src/forgather/ml/optim/apollo.py
class Apollo(Optimizer):
    """Low-rank gradient-projection optimizer with AdamW-level performance.

    Implements the Apollo algorithm (Zhu et al., arXiv:2412.05270).  Rather
    than maintaining full-size first and second moment buffers, Apollo projects
    gradients into a low-rank subspace (controlled by ``rank``), runs the
    Adam update there, and uses the resulting per-column scaling signal to
    scale the full-rank gradient.  Moment buffer memory scales as
    ``O(rank * max(n, m))`` instead of ``O(n * m)``.

    Also applies the Norm-Growth Limiter from Fira (arXiv:2410.01623) to
    prevent destructive gradient updates.

    Prefer Apollo over AdamW when:

    * Memory is constrained and Adafactor's factored approximation is too
      aggressive (Apollo retains the full gradient for the parameter update).
    * ``rank=1`` (Apollo-Mini) is desired for maximum memory savings while
      still outperforming SGD.

    Notes
    -----
    The ``projector_factory`` callable is not serialisable and is therefore
    stripped from checkpoints.  It must be supplied again via the constructor
    when resuming from a checkpoint.

    References
    ----------
    Zhu, W. et al. (2024). APOLLO: SGD-like Memory, AdamW-level Performance.
    arXiv:2412.05270.

    Chen, Y. et al. (2024). Fira: Can We Achieve Full-Rank Training of LLMs
    Under Low-Rank Constraint? arXiv:2410.01623.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        rank: int = 1,
        scale: float = 1.0,
        scale_front: bool = False,
        update_steps: int = 10,
        mini: bool = False,
        projector_factory: Callable = None,
    ):
        """
        Parameters
        ----------
        params : iterable of Parameter
            Model parameters to optimize.
        lr : float, optional
            Learning rate.  Default is 1e-3.
        betas : tuple of (float, float), optional
            Exponential decay rates for the low-rank first and second moment
            estimates.  Default is (0.9, 0.999).
        eps : float, optional
            Term added to the denominator of the Adam update in the low-rank
            subspace.  Default is 1e-6.
        weight_decay : float, optional
            Decoupled weight-decay coefficient.  Default is 0.0.
        rank : int, optional
            Rank of the gradient projection subspace.  Lower rank saves more
            memory; ``rank=1`` corresponds to Apollo-Mini.  Default is 1.
        scale : float, optional
            Additional update scaling: the update is multiplied by
            ``sqrt(scale)``.  Applied before the Norm-Growth Limiter when
            ``scale_front=True``, or after it when ``scale_front=False``.
            Default is 1.0.
        scale_front : bool, optional
            If ``True``, the ``sqrt(scale)`` factor is applied before the
            Norm-Growth Limiter so that the limiter sees the already-scaled
            update norm.  Default is ``False``.
        update_steps : int, optional
            Intended interval (in optimizer steps) between projection-matrix
            refreshes.  It is stored in the param group, but ``_init_state``
            calls ``projector_factory`` with only ``rank``, ``dim``, and
            ``proj_type``, so the refresh interval must be bound on the
            factory itself.  Default is 10.
        mini : bool, optional
            If ``True``, compute a single global scaling scalar instead of
            per-column scalars.  Corresponds to the Apollo-Mini variant.
            Default is ``False``.
        projector_factory : callable, optional
            Factory that constructs the gradient projector given keyword
            arguments ``rank``, ``dim``, and ``proj_type``.  Must be provided;
            the optimizer will raise at the first step if it is ``None``.
        """
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            rank=rank,
            scale=scale,
            scale_front=scale_front,
            update_steps=update_steps,
            mini=mini,
            projector_factory=projector_factory,
        )
        super().__init__(params, defaults)

    def _init_state(self, state, group, p, grad):
        rank = group["rank"]

        if grad.shape[0] < grad.shape[1]:
            dim = grad.shape[0]
            proj_shape = (rank, grad.shape[1])
            proj_type = "left"
        else:
            dim = grad.shape[1]
            proj_shape = (rank, grad.shape[0])
            proj_type = "right"

        state["projector"] = group["projector_factory"](
            rank=rank,
            dim=dim,
            proj_type=proj_type,
        )

        state["step"] = torch.tensor(0.0, dtype=torch.float32)
        state["m"] = torch.zeros(*proj_shape, device=grad.device, dtype=grad.dtype)
        state["v"] = torch.zeros(*proj_shape, device=grad.device, dtype=grad.dtype)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Init state
                if "step" not in state:
                    self._init_state(state, group, p, grad)

                projector = state["projector"]
                state["step"] += 1
                step = state["step"]
                beta1, beta2 = group["betas"]
                M, V = state["m"], state["v"]
                lr = group["lr"]
                eps = group["eps"]
                weight_decay = group["weight_decay"]
                scale = group["scale"]
                scale_front = group["scale_front"]

                # Weight decay
                if weight_decay > 0.0:
                    p.add_(p, alpha=(-lr * weight_decay))

                # Apply bias correction to lr
                lr = lr * torch.sqrt(1.0 - beta2**step) / (1.0 - beta1**step)

                # Update projector
                projector.step(grad)

                # Project gradient into low-rank sub-space
                R = projector.down(grad) * projector.scale

                # Update EMA of 1st and 2nd moments
                M.lerp_(R, 1.0 - beta1)
                V.lerp_(R.square(), 1.0 - beta2)
                R_tilde = M / (V.sqrt() + eps)

                if group["mini"]:
                    S = torch.linalg.norm(R_tilde) / (torch.linalg.norm(R) + 1e-8)
                else:
                    S = torch.linalg.norm(R_tilde, dim=0) / (
                        torch.linalg.norm(R, dim=0) + 1e-8
                    )
                    if grad.shape[0] >= grad.shape[1]:
                        S = S.view(-1, 1)

                update = grad * S

                if scale_front and scale != 1.0:
                    update *= math.sqrt(scale)

                # Apply Norm-Growth Limiter in Fira (https://arxiv.org/abs/2410.01623) to avoid destructive gradient updates.
                if "scaled_grad" in state:
                    scaled_grad_norm = torch.linalg.norm(update)
                    limiter = (
                        max(
                            scaled_grad_norm / (state["scaled_grad"] + 1e-8),
                            1.01,
                        )
                        / 1.01
                    )
                    update = update / limiter
                    state["scaled_grad"] = scaled_grad_norm / limiter
                else:
                    state["scaled_grad"] = torch.norm(update)

                if not scale_front and scale != 1.0:
                    update *= math.sqrt(scale)

                p.add_(update, alpha=-lr)

        return loss

    def state_dict(self):
        """Return optimizer state with serialized projector objects.

        Projector objects are converted to dicts containing only tensors and primitives
        to ensure proper checkpoint serialization.

        Note: The projector_factory in param_groups is removed since it's a
        non-serializable function. On load_state_dict, it must be provided
        via the optimizer constructor.
        """
        from forgather.ml.optim.subspace_proj import OnlinePCAProjector, RandProjector

        state_dict = super().state_dict()

        # Remove projector_factory from param_groups (can't pickle functions)
        for group in state_dict["param_groups"]:
            if "projector_factory" in group:
                del group["projector_factory"]

        # Serialize projector objects
        for param_id, param_state in state_dict["state"].items():
            if "projector" in param_state:
                proj = param_state["projector"]

                # Serialize based on projector type
                proj_dict = {
                    "_class": type(proj).__name__,
                    "rank": proj.rank,
                    "dim": proj.dim,
                    "proj_type": proj.proj_type,
                    "update_steps": proj.update_steps,
                    "_step": proj._step,
                    "scale": proj.scale,
                }

                # Add type-specific state
                if isinstance(proj, OnlinePCAProjector):
                    proj_dict["A"] = proj.A
                    # orthonormalize function is reconstructed from defaults

                elif isinstance(proj, RandProjector):
                    proj_dict["A"] = proj.A
                    proj_dict["init"] = proj.init
                    proj_dict["lazy"] = proj.lazy
                    proj_dict["seed"] = proj.seed
                    if hasattr(proj, "gen") and proj.gen is not None:
                        proj_dict["gen_state"] = proj.gen.get_state()
                    if hasattr(proj, "saved_gen_state"):
                        proj_dict["saved_gen_state"] = proj.saved_gen_state
                    if hasattr(proj, "device"):
                        proj_dict["device"] = str(proj.device)  # Serialize as string
                    if hasattr(proj, "dtype"):
                        proj_dict["dtype"] = str(proj.dtype)  # Serialize as string

                param_state["projector"] = proj_dict

        return state_dict

    def load_state_dict(self, state_dict):
        """Load optimizer state and reconstruct projector objects.

        Deserializes projector dicts back into projector objects.
        """
        from forgather.ml.optim.subspace_proj import OnlinePCAProjector, RandProjector

        # Reconstruct projector objects from serialized dicts
        for param_id, param_state in state_dict["state"].items():
            if "projector" in param_state:
                proj_dict = param_state["projector"]

                if not isinstance(proj_dict, dict):
                    raise ValueError(
                        f"Apollo projector state must be dict, got {type(proj_dict)}"
                    )

                proj_class_name = proj_dict.get("_class")

                # Reconstruct based on class type
                if proj_class_name == "OnlinePCAProjector":
                    # Note: orthag defaults to "none" in current implementation
                    proj = OnlinePCAProjector(
                        rank=proj_dict["rank"],
                        dim=proj_dict["dim"],
                        proj_type=proj_dict["proj_type"],
                        update_steps=proj_dict["update_steps"],
                    )
                    # Restore all attributes
                    proj.A = proj_dict["A"]
                    proj._step = proj_dict["_step"]
                    proj.scale = proj_dict["scale"]
                    # Note: proj_shape, einsum_* are set by __init__, orthonormalize defaults to identity

                elif proj_class_name == "RandProjector":
                    proj = RandProjector(
                        rank=proj_dict["rank"],
                        dim=proj_dict["dim"],
                        proj_type=proj_dict["proj_type"],
                        update_steps=proj_dict["update_steps"],
                        init=proj_dict["init"],
                        lazy=proj_dict["lazy"],
                        seed=proj_dict["seed"],
                    )
                    proj.A = proj_dict["A"]
                    proj._step = proj_dict["_step"]
                    proj.scale = proj_dict["scale"]

                    # Restore generator state if present
                    if "gen_state" in proj_dict:
                        # Need to create generator on correct device
                        device_str = proj_dict.get("device", "cpu")
                        device = torch.device(device_str)
                        proj.gen = torch.Generator(device=device)
                        proj.gen.set_state(proj_dict["gen_state"])

                    if "saved_gen_state" in proj_dict:
                        proj.saved_gen_state = proj_dict["saved_gen_state"]
                    if "device" in proj_dict:
                        proj.device = torch.device(proj_dict["device"])
                    if "dtype" in proj_dict:
                        # Convert string like "torch.float32" to dtype
                        dtype_str = proj_dict["dtype"].replace("torch.", "")
                        proj.dtype = getattr(torch, dtype_str)

                else:
                    raise ValueError(f"Unknown projector class: {proj_class_name}")

                param_state["projector"] = proj

        super().load_state_dict(state_dict)
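The core of step() is the scaling factor S. It is the ratio between the norm of the Adam-normalized low-rank update R_tilde = M / (sqrt(V) + eps) and the norm of the projected gradient R, taken per column, or once per tensor when mini=True. The full-rank gradient is then multiplied by S, and Adam's bias corrections are folded into the learning rate. This dependency-free sketch uses nested lists in place of tensors, takes M and V as already updated, and covers only the proj_type="left" layout. For "right", the listing reshapes S to a column vector so that it scales rows instead.

```python
import math


def column_norms(mat: list) -> list:
    """L2 norm of each column of a (rows x cols) nested-list matrix."""
    return [math.sqrt(sum(row[j] ** 2 for row in mat)) for j in range(len(mat[0]))]


def apollo_scaling(R, M, V, eps=1e-6, mini=False):
    """Scaling factors S as computed in Apollo.step() (left projection)."""
    R_tilde = [
        [mv / (math.sqrt(vv) + eps) for mv, vv in zip(m_row, v_row)]
        for m_row, v_row in zip(M, V)
    ]
    if mini:  # single tensor-wide scalar (Apollo-Mini)
        num = math.sqrt(sum(x * x for row in R_tilde for x in row))
        den = math.sqrt(sum(x * x for row in R for x in row))
        return num / (den + 1e-8)
    return [t / (r + 1e-8) for t, r in zip(column_norms(R_tilde), column_norms(R))]


def scale_gradient(grad, S):
    """Left projection: multiply column j of the full-rank gradient by S[j]."""
    return [[g * s for g, s in zip(row, S)] for row in grad]


def bias_corrected_lr(lr, beta1, beta2, step):
    """lr * sqrt(1 - beta2**t) / (1 - beta1**t), as in step()."""
    return lr * math.sqrt(1.0 - beta2 ** step) / (1.0 - beta1 ** step)
```

Because S depends only on column norms, a low-rank projection suffices to estimate it, while the parameter update itself still uses every entry of the gradient.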

__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, rank=1, scale=1.0, scale_front=False, update_steps=10, mini=False, projector_factory=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate. | 0.001 |
| betas | tuple of (float, float) | Exponential decay rates for the low-rank first and second moment estimates. | (0.9, 0.999) |
| eps | float | Term added to the denominator of the Adam update in the low-rank subspace. | 1e-06 |
| weight_decay | float | Decoupled weight-decay coefficient. | 0.0 |
| rank | int | Rank of the gradient projection subspace. Lower rank saves more memory; rank=1 together with mini=True corresponds to Apollo-Mini. | 1 |
| scale | float | Additional update scaling: the update is multiplied by sqrt(scale). Applied before the Norm-Growth Limiter when scale_front=True, or after it when scale_front=False. | 1.0 |
| scale_front | bool | If True, the sqrt(scale) factor is applied before the Norm-Growth Limiter, so the limiter sees the already-scaled update norm. | False |
| update_steps | int | Intended interval (in optimizer steps) between projection-matrix refreshes. It is stored in the param group, but _init_state calls projector_factory with only rank, dim, and proj_type, so the refresh interval must be bound on the factory itself. | 10 |
| mini | bool | If True, compute a single tensor-wide scaling scalar instead of per-column scalars (the Apollo-Mini variant). | False |
| projector_factory | callable | Factory that constructs the gradient projector from keyword arguments rank, dim, and proj_type. Must be provided; if it is None, the first step fails when _init_state tries to call it. | None |
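Because _init_state (in the class listing above) calls projector_factory(rank=..., dim=..., proj_type=...) and nothing else, any other projector setting has to be bound on the factory in advance. A minimal sketch with functools.partial and a hypothetical ToyProjector stand-in:

```python
import functools
from dataclasses import dataclass


@dataclass
class ToyProjector:
    """Hypothetical stand-in accepting the keywords Apollo passes, plus one extra."""
    rank: int
    dim: int
    proj_type: str
    update_steps: int = 10


# Bind the refresh interval up front; Apollo supplies rank, dim, and proj_type.
factory = functools.partial(ToyProjector, update_steps=200)
proj = factory(rank=1, dim=4096, proj_type="left")
print(proj)
```

The load_state_dict listing shows that RandProjector and OnlinePCAProjector both accept update_steps, so the same pattern should carry over to them. Check any other keyword names against forgather.ml.optim.subspace_proj before relying on them.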
Source code in src/forgather/ml/optim/apollo.py
def __init__(
    self,
    params: Iterable[nn.parameter.Parameter],
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-6,
    weight_decay: float = 0.0,
    rank: int = 1,
    scale: float = 1.0,
    scale_front: bool = False,
    update_steps: int = 10,
    mini: bool = False,
    projector_factory: Callable = None,
):
    """
    Parameters
    ----------
    params : iterable of Parameter
        Model parameters to optimize.
    lr : float, optional
        Learning rate.  Default is 1e-3.
    betas : tuple of (float, float), optional
        Exponential decay rates for the low-rank first and second moment
        estimates.  Default is (0.9, 0.999).
    eps : float, optional
        Term added to the denominator of the Adam update in the low-rank
        subspace.  Default is 1e-6.
    weight_decay : float, optional
        Decoupled weight-decay coefficient.  Default is 0.0.
    rank : int, optional
        Rank of the gradient projection subspace.  Lower rank saves more
        memory; ``rank=1`` corresponds to Apollo-Mini.  Default is 1.
    scale : float, optional
        Additional update scaling: the update is multiplied by
        ``sqrt(scale)``.  Applied before the Norm-Growth Limiter when
        ``scale_front=True``, or after it when ``scale_front=False``.
        Default is 1.0.
    scale_front : bool, optional
        If ``True``, the ``sqrt(scale)`` factor is applied before the
        Norm-Growth Limiter so that the limiter sees the already-scaled
        update norm.  Default is ``False``.
    update_steps : int, optional
        Intended interval (in optimizer steps) between projection-matrix
        refreshes.  It is stored in the param group, but ``_init_state``
        calls ``projector_factory`` with only ``rank``, ``dim``, and
        ``proj_type``, so the refresh interval must be bound on the
        factory itself.  Default is 10.
    mini : bool, optional
        If ``True``, compute a single global scaling scalar instead of
        per-column scalars.  Corresponds to the Apollo-Mini variant.
        Default is ``False``.
    projector_factory : callable, optional
        Factory that constructs the gradient projector given keyword
        arguments ``rank``, ``dim``, and ``proj_type``.  Must be provided;
        the optimizer will raise at the first step if it is ``None``.
    """
    defaults = dict(
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        rank=rank,
        scale=scale,
        scale_front=scale_front,
        update_steps=update_steps,
        mini=mini,
        projector_factory=projector_factory,
    )
    super().__init__(params, defaults)

state_dict()

Return optimizer state with serialized projector objects.

Projector objects are converted to dicts containing only tensors and primitives to ensure proper checkpoint serialization.

Note: The projector_factory in param_groups is removed since it's a non-serializable function. On load_state_dict, it must be provided via the optimizer constructor.

Source code in src/forgather/ml/optim/apollo.py
def state_dict(self):
    """Return optimizer state with serialized projector objects.

    Projector objects are converted to dicts containing only tensors and primitives
    to ensure proper checkpoint serialization.

    Note: The projector_factory in param_groups is removed since it's a
    non-serializable function. On load_state_dict, it must be provided
    via the optimizer constructor.
    """
    from forgather.ml.optim.subspace_proj import OnlinePCAProjector, RandProjector

    state_dict = super().state_dict()

    # Remove projector_factory from param_groups (can't pickle functions)
    for group in state_dict["param_groups"]:
        if "projector_factory" in group:
            del group["projector_factory"]

    # Serialize projector objects
    for param_id, param_state in state_dict["state"].items():
        if "projector" in param_state:
            proj = param_state["projector"]

            # Serialize based on projector type
            proj_dict = {
                "_class": type(proj).__name__,
                "rank": proj.rank,
                "dim": proj.dim,
                "proj_type": proj.proj_type,
                "update_steps": proj.update_steps,
                "_step": proj._step,
                "scale": proj.scale,
            }

            # Add type-specific state
            if isinstance(proj, OnlinePCAProjector):
                proj_dict["A"] = proj.A
                # orthonormalize function is reconstructed from defaults

            elif isinstance(proj, RandProjector):
                proj_dict["A"] = proj.A
                proj_dict["init"] = proj.init
                proj_dict["lazy"] = proj.lazy
                proj_dict["seed"] = proj.seed
                if hasattr(proj, "gen") and proj.gen is not None:
                    proj_dict["gen_state"] = proj.gen.get_state()
                if hasattr(proj, "saved_gen_state"):
                    proj_dict["saved_gen_state"] = proj.saved_gen_state
                if hasattr(proj, "device"):
                    proj_dict["device"] = str(proj.device)  # Serialize as string
                if hasattr(proj, "dtype"):
                    proj_dict["dtype"] = str(proj.dtype)  # Serialize as string

            param_state["projector"] = proj_dict

    return state_dict
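The "_class"-tagged dict produced by state_dict() is a general pattern for checkpointing helper objects as plain data and rebuilding them on load. A minimal sketch with a hypothetical ToyProjector and a name-to-class registry (the listing uses an if/elif chain over OnlinePCAProjector and RandProjector instead):

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyProjector:  # hypothetical, for illustration only
    rank: int
    dim: int
    proj_type: str
    _step: int = 0


REGISTRY = {"ToyProjector": ToyProjector}


def serialize(proj) -> dict:
    """Flatten a projector into primitives plus a '_class' tag."""
    return {"_class": type(proj).__name__, **asdict(proj)}


def deserialize(d: dict):
    """Look up the class by its tag and rebuild the object."""
    d = dict(d)  # copy so the caller's checkpoint dict is not mutated
    cls = REGISTRY.get(d.pop("_class", None))
    if cls is None:
        raise ValueError("Unknown projector class")
    return cls(**d)
```

The listing stores torch.device values as str(device). torch.device accepts that string directly, so no further string manipulation is needed when restoring it.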

load_state_dict(state_dict)

Load optimizer state and reconstruct projector objects.

Deserializes projector dicts back into projector objects.

Source code in src/forgather/ml/optim/apollo.py
def load_state_dict(self, state_dict):
    """Load optimizer state and reconstruct projector objects.

    Deserializes projector dicts back into projector objects.
    """
    from forgather.ml.optim.subspace_proj import OnlinePCAProjector, RandProjector

    # Reconstruct projector objects from serialized dicts
    for param_id, param_state in state_dict["state"].items():
        if "projector" in param_state:
            proj_dict = param_state["projector"]

            if not isinstance(proj_dict, dict):
                raise ValueError(
                    f"Apollo projector state must be dict, got {type(proj_dict)}"
                )

            proj_class_name = proj_dict.get("_class")

            # Reconstruct based on class type
            if proj_class_name == "OnlinePCAProjector":
                # Note: orthag defaults to "none" in current implementation
                proj = OnlinePCAProjector(
                    rank=proj_dict["rank"],
                    dim=proj_dict["dim"],
                    proj_type=proj_dict["proj_type"],
                    update_steps=proj_dict["update_steps"],
                )
                # Restore all attributes
                proj.A = proj_dict["A"]
                proj._step = proj_dict["_step"]
                proj.scale = proj_dict["scale"]
                # Note: proj_shape, einsum_* are set by __init__, orthonormalize defaults to identity

            elif proj_class_name == "RandProjector":
                proj = RandProjector(
                    rank=proj_dict["rank"],
                    dim=proj_dict["dim"],
                    proj_type=proj_dict["proj_type"],
                    update_steps=proj_dict["update_steps"],
                    init=proj_dict["init"],
                    lazy=proj_dict["lazy"],
                    seed=proj_dict["seed"],
                )
                proj.A = proj_dict["A"]
                proj._step = proj_dict["_step"]
                proj.scale = proj_dict["scale"]

                # Restore generator state if present
                if "gen_state" in proj_dict:
                    # Need to create generator on correct device
                    device_str = proj_dict.get("device", "cpu")
                    device = torch.device(device_str)
                    proj.gen = torch.Generator(device=device)
                    proj.gen.set_state(proj_dict["gen_state"])

                if "saved_gen_state" in proj_dict:
                    proj.saved_gen_state = proj_dict["saved_gen_state"]
                if "device" in proj_dict:
                    proj.device = torch.device(proj_dict["device"])
                if "dtype" in proj_dict:
                    # Convert string like "torch.float32" to dtype
                    dtype_str = proj_dict["dtype"].replace("torch.", "")
                    proj.dtype = getattr(torch, dtype_str)

            else:
                raise ValueError(f"Unknown projector class: {proj_class_name}")

            param_state["projector"] = proj

    super().load_state_dict(state_dict)

forgather.ml.optim.sgd.SGD

Bases: Optimizer

Minimal vanilla SGD optimizer.

Applies the plain stochastic gradient descent update rule:

p = p - lr * grad

No momentum, weight decay, or gradient clipping. Intended as a minimal reference implementation and starting point for custom optimizers. For production training, prefer AdamW or Adafactor.
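The update rule is easy to check by hand. This plain-Python sketch (no PyTorch; sgd_step is a hypothetical helper, not part of Forgather) applies p = p - lr * grad to a list of scalars and, like the class, skips entries whose gradient is None.

```python
def sgd_step(params, grads, lr=1e-3):
    """Hypothetical helper: p = p - lr * grad for each scalar parameter."""
    # Parameters without a gradient are left untouched, as in SGD.step().
    return [p if g is None else p - lr * g for p, g in zip(params, grads)]


# Minimise f(x) = x**2, whose gradient is 2x, starting from x = 1.0.
x = [1.0]
for _ in range(3):
    x = sgd_step(x, [2.0 * x[0]], lr=0.25)
# Each step multiplies x by (1 - 2 * lr) = 0.5: 1.0 -> 0.5 -> 0.25 -> 0.125
print(x)
```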

Parameters:

  • params (iterable of Parameter, required): Model parameters to optimize.
  • lr (float, default 0.001): Learning rate.
Source code in src/forgather/ml/optim/sgd.py
class SGD(Optimizer):
    """Minimal vanilla SGD optimizer.

    Applies the plain stochastic gradient descent update rule:

        ``p = p - lr * grad``

    No momentum, weight decay, or gradient clipping.  Intended as a minimal
    reference implementation and starting point for custom optimizers.  For
    production training, prefer `AdamW` or `Adafactor`.

    Parameters
    ----------
    params : iterable of Parameter
        Model parameters to optimize.
    lr : float, optional
        Learning rate.  Default is 1e-3.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
    ):
        defaults = dict(
            lr=lr,
        )
        super().__init__(params, defaults)

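    @torch.no_grad()
    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            # step() runs under no_grad, but a closure recomputes the loss and
            # calls backward(), so autograd must be re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                # Per-parameter state; unused by plain SGD, but available to
                # optimizers derived from this template (e.g. for momentum).
                state = self.state[p]

                # p <- p - lr * grad
                p.add_(grad, alpha=-group["lr"])

        return loss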

Schedulers

forgather.ml.optim.infinite_lr_scheduler.InfiniteLRScheduler

Bases: LRScheduler

Learning rate scheduler for continual pre-training without a fixed budget.

Implements the Infinite Cosine Schedule (arXiv:2503.02844). The key idea is a permanent constant phase that can run indefinitely, enabling continual pre-training without committing to a total step count up front. Annealing is triggered on demand — typically by resuming from a checkpoint with start_annealing=True — so multiple annealed checkpoints can be derived from a single long training run.

The schedule has four sequential phases:

  1. Warmup: linear ramp from 0 to base_lr over warmup_steps.
  2. Cooldown: cosine decay from base_lr to constant_lr over cooldown_steps.
  3. Constant: holds constant_lr indefinitely (the "infinite" phase). If cooldown_steps is 0, this phase holds base_lr instead, although annealing still starts from constant_lr.
  4. Annealing: decays from constant_lr toward min_lr, triggered at checkpoint_step. Two decay curves are supported:
       • "exponential" (default): the original paper formula; exponential decay controlled by tau.
       • "rsqrt": harmonic/rational decay from the WSD-S paper (arXiv:2410.05192); drops quickly at first, then slows.

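To make the phase boundaries concrete, here is a plain-Python sketch that mirrors the branch structure of InfiniteLRScheduler.get_lr() for a single parameter group. The function name infinite_lr and the example settings are hypothetical; the sketch uses no PyTorch and does not wrap an optimizer.

```python
import math


def infinite_lr(step, base_lr, warmup_steps=0, cooldown_steps=0,
                constant_lr=3.75e-5, min_lr=1e-8, tau=1e4,
                checkpoint_step=-1, annealing_type="exponential",
                annealing_steps=0):
    """Hypothetical single-group sketch of the four phases."""
    if step < warmup_steps:
        # Phase 1: linear warmup from 0 to base_lr.
        return base_lr * step / warmup_steps
    if step < warmup_steps + cooldown_steps:
        # Phase 2: cosine decay from base_lr down to constant_lr.
        progress = (step - warmup_steps) / cooldown_steps
        return constant_lr + (base_lr - constant_lr) / 2 * (
            1.0 + math.cos(math.pi * progress))
    if checkpoint_step >= 0 and step >= checkpoint_step:
        # Phase 4: annealing from constant_lr toward min_lr.
        n = step - checkpoint_step
        if annealing_type == "rsqrt":
            t = min(n / annealing_steps, 1.0)
            return 1.0 / (t / min_lr + (1.0 - t) / constant_lr)
        return constant_lr * (min_lr / constant_lr) ** (n / (tau + checkpoint_step))
    # Phase 3: constant phase (base_lr when there is no cooldown phase).
    return constant_lr if cooldown_steps > 0 else base_lr


# Illustrative settings, not recommendations.
cfg = dict(base_lr=1e-3, warmup_steps=100, cooldown_steps=100,
           constant_lr=1e-4, min_lr=1e-6, tau=1000, checkpoint_step=1000)
for step in (0, 50, 100, 150, 500, 1000, 3000):
    print(step, infinite_lr(step, **cfg))
```

The last branch is worth noticing: with cooldown_steps=0 the constant phase returns base_lr, yet annealing still starts from constant_lr, so such a run sees a jump at checkpoint_step unless base_lr equals constant_lr.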
Notes

start_annealing, annealing_type, annealing_steps, and min_lr are config-only keys: they are taken from the constructor arguments and are not saved to or loaded from checkpoints. This ensures backward compatibility and allows the annealing policy to be changed when resuming.

References

Zhu, Y. et al. (2025). Beyond Cosine Decay: On the effectiveness of Infinite Learning Rate Schedule for Continual Pre-training. arXiv:2503.02844.

Hu, S. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective. arXiv:2410.05192.

Source code in src/forgather/ml/optim/infinite_lr_scheduler.py
class InfiniteLRScheduler(LRScheduler):
    """Learning rate scheduler for continual pre-training without a fixed budget.

    Implements the Infinite Cosine Schedule (arXiv:2503.02844).  The key idea
    is a permanent *constant phase* that can run indefinitely, enabling
    continual pre-training without committing to a total step count up front.
    Annealing is triggered on demand — typically by resuming from a checkpoint
    with ``start_annealing=True`` — so multiple annealed checkpoints can be
    derived from a single long training run.

    The schedule has four sequential phases:

    1. **Warmup** — linear ramp from 0 to ``base_lr`` over ``warmup_steps``.
    2. **Cooldown** — cosine decay from ``base_lr`` to ``constant_lr`` over
       ``cooldown_steps``.
    3. **Constant** — holds ``constant_lr`` indefinitely (the "infinite"
       phase).
    4. **Annealing** — decays from ``constant_lr`` toward ``min_lr``,
       triggered at ``checkpoint_step``.  Two decay curves are supported:

       * ``"exponential"`` (default) — original paper formula; exponential
         decay controlled by ``tau``.
       * ``"rsqrt"`` — harmonic/rational decay from the WSD-S paper
         (arXiv:2410.05192); drops quickly at first then slows.

    Notes
    -----
    ``start_annealing``, ``annealing_type``, ``annealing_steps``, and
    ``min_lr`` are *config-only* keys: they are taken from the constructor
    arguments and are not saved to or loaded from checkpoints.  This ensures
    backward compatibility and allows the annealing policy to be changed when
    resuming.

    References
    ----------
    Zhu, Y. et al. (2025). Beyond Cosine Decay: On the effectiveness of
    Infinite Learning Rate Schedule for Continual Pre-training.
    arXiv:2503.02844.

    Hu, S. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates:
    A River Valley Loss Landscape Perspective. arXiv:2410.05192.
    """

    # Config-only keys: set from constructor config, not saved/loaded
    # from checkpoints.
    _CONFIG_ONLY_KEYS = frozenset(
        (
            "start_annealing",
            "annealing_type",
            "annealing_steps",
            "min_lr",
        )
    )

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int = 0,
        cooldown_steps: int = 0,
        constant_lr: float = 3.75e-5,
        min_lr: float = 1e-8,
        tau: float = 1e4,
        checkpoint_step: int = -1,
        start_annealing: bool = False,
        annealing_type: str = "exponential",
        annealing_steps: int = 0,
        last_epoch: int = -1,
    ):
        """
        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer whose ``param_groups`` LRs will be managed.
        warmup_steps : int, optional
            Number of steps for linear warmup (phase 1).  Default is 0.
        cooldown_steps : int, optional
            Number of steps for cosine decay from ``base_lr`` to
            ``constant_lr`` (phase 2).  Default is 0.
        constant_lr : float, optional
            Learning rate held during the constant phase (phase 3) and the
            starting point for annealing (phase 4).  Corresponds to
            ``eta_const`` in the paper.  Default is 3.75e-5.
        min_lr : float, optional
            Target minimum learning rate reached at the end of annealing.
            Must be > 0.  Corresponds to ``eta_min`` in the paper.
            Default is 1e-8.
        tau : float, optional
            Annealing step budget for exponential annealing.  The LR reaches
            ``min_lr`` after ``tau + checkpoint_step`` steps past
            ``checkpoint_step``.  Corresponds to ``t_a`` in the paper.
            Ignored when ``annealing_type="rsqrt"``.  Default is 1e4.
        checkpoint_step : int, optional
            Step at which to begin annealing (phase 4).  Set to ``-1`` to
            disable annealing.  When enabled, must be >=
            ``warmup_steps + cooldown_steps``.  Corresponds to ``N_d`` in the
            paper.  Default is -1.
        start_annealing : bool, optional
            When ``True``, annealing begins at the current step upon loading a
            checkpoint (provided ``checkpoint_step`` is still ``-1``).  This
            allows triggering annealing retroactively from any saved
            checkpoint.  When ``False``, ``checkpoint_step`` is restored from
            the constructor value.  Config-only: not saved in checkpoints.
            Default is ``False``.
        annealing_type : str, optional
            Decay curve for the annealing phase.  ``"exponential"`` (default)
            uses the original paper formula.  ``"rsqrt"`` uses
            harmonic/rational decay (WSD-S paper), which drops quickly at
            first then slows.  Config-only: not saved in checkpoints.
        annealing_steps : int, optional
            Total annealing steps for ``annealing_type="rsqrt"``.  The LR
            reaches ``min_lr`` after exactly this many steps past
            ``checkpoint_step``.  Must be > 0 for rsqrt annealing; ignored
            for exponential.  Config-only: not saved in checkpoints.
            Default is 0.
        last_epoch : int, optional
            Index of the last epoch, used when resuming.  Default is -1.
        """
        assert warmup_steps >= 0
        assert cooldown_steps >= 0
        assert checkpoint_step < 0 or checkpoint_step >= warmup_steps + cooldown_steps
        assert tau > 0
        assert min_lr > 0.0
        assert constant_lr > 0.0
        assert annealing_type in ("exponential", "rsqrt")
        assert annealing_steps >= 0
        if annealing_type == "rsqrt":
            assert (
                annealing_steps > 0
            ), "annealing_steps must be > 0 for rsqrt annealing"

        self.warmup_steps = warmup_steps
        self.cooldown_steps = cooldown_steps
        self.constant_lr = constant_lr
        self.checkpoint_step = checkpoint_step
        self.min_lr = min_lr
        self.tau = tau
        self.start_annealing = start_annealing
        self.annealing_type = annealing_type
        self.annealing_steps = annealing_steps

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute learning rate for the current step."""
        if self.last_epoch < self.warmup_steps:
            return self._warmup_lr()
        elif self.last_epoch < self.warmup_steps + self.cooldown_steps:
            return self._cooldown_lr()
        elif self.checkpoint_step >= 0 and self.last_epoch >= self.checkpoint_step:
            return self._annealing_lr()
        else:
            return self._constant_lr()

    def _warmup_lr(self):
        """Phase 1: Linear warmup from 0 to base_lr."""
        return [
            base_lr * self.last_epoch / self.warmup_steps
            for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs)
        ]

    def _cooldown_lr(self):
        """Phase 2: Cosine decay from base_lr to constant_lr."""
        return [
            self.constant_lr
            + ((base_lr - self.constant_lr) / 2)
            * (
                1.0
                + math.cos(
                    math.pi
                    * (self.last_epoch - self.warmup_steps)
                    / self.cooldown_steps
                )
            )
            for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs)
        ]

    def _constant_lr(self):
        """Phase 3: Constant learning rate (the "infinite" phase)."""
        if self.cooldown_steps > 0:
            return [self.constant_lr for _ in self.optimizer.param_groups]
        else:
            return [base_lr for base_lr in self.base_lrs]

    def _annealing_lr(self):
        """Phase 4: Decay from constant_lr toward min_lr."""
        if self.annealing_type == "rsqrt":
            return self._annealing_lr_rsqrt()
        return self._annealing_lr_exponential()

    def _annealing_lr_exponential(self):
        """Exponential annealing (original formula).

        From Eq. 1 of the paper:
            eta(n) = eta_const * (eta_min / eta_const) ^ ((n - N_d) / (t_a + N_d))

        where n is the current step, N_d is checkpoint_step, and t_a is
        tau. The LR equals min_lr when n - N_d = t_a + N_d (i.e., after
        tau + checkpoint_step annealing steps).
        """
        steps_since_anneal = self.last_epoch - self.checkpoint_step
        exponent = steps_since_anneal / (self.tau + self.checkpoint_step)
        return [
            self.constant_lr * (self.min_lr / self.constant_lr) ** exponent
            for group in self.optimizer.param_groups
        ]

    def _annealing_lr_rsqrt(self):
        """Harmonic/rational annealing (WSD-S paper).

        Uses linear interpolation of inverse LR:
            t = step / annealing_steps   (progress 0..1)
            lr = 1 / (t / min_lr + (1 - t) / constant_lr)

        Drops quickly at first, then slows — convex decay curve.
        At step=0: lr = constant_lr (smooth transition from constant phase).
        At step=annealing_steps: lr = min_lr (exact target).
        Past annealing_steps: clamped at min_lr.
        """
        step = self.last_epoch - self.checkpoint_step
        t = min(step / self.annealing_steps, 1.0)
        lr = 1.0 / (t / self.min_lr + (1.0 - t) / self.constant_lr)
        return [lr for _ in self.optimizer.param_groups]

    def state_dict(self):
        """Return state dict excluding config-only parameters.

        Config-only parameters (start_annealing, annealing_type,
        annealing_steps, min_lr) are always determined by the constructor
        arguments, not by checkpoint state. This ensures backward
        compatibility with checkpoints saved before these parameters
        existed.
        """
        return {
            key: value
            for key, value in super().state_dict().items()
            if key not in self._CONFIG_ONLY_KEYS
        }

    def load_state_dict(self, state_dict):
        """Load state dict, preserving config-only parameters.

        Config-only parameters are always taken from the constructor,
        never from the checkpoint. The start_annealing flag controls
        how checkpoint_step is resolved after loading:

        - start_annealing=True with loaded checkpoint_step < 0:
          Begin annealing at the current step (last_epoch).
        - start_annealing=True with loaded checkpoint_step >= 0:
          Resume annealing from where it left off.
        - start_annealing=False:
          Restore checkpoint_step from the constructor value,
          ignoring whatever was saved in the checkpoint.
        """
        saved_config = {key: getattr(self, key) for key in self._CONFIG_ONLY_KEYS}
        saved_checkpoint_step = self.checkpoint_step

        super().load_state_dict(state_dict)

        # Restore config-only keys
        for key, value in saved_config.items():
            setattr(self, key, value)

        # Handle checkpoint_step based on start_annealing flag
        if self.start_annealing:
            # If loaded checkpoint_step is negative, start annealing now
            if self.checkpoint_step < 0:
                self.checkpoint_step = self.last_epoch
            # else: keep the loaded checkpoint_step (resume existing annealing)
        else:
            # Restore constructor's checkpoint_step value
            self.checkpoint_step = saved_checkpoint_step

__init__(optimizer, warmup_steps=0, cooldown_steps=0, constant_lr=3.75e-05, min_lr=1e-08, tau=10000.0, checkpoint_step=-1, start_annealing=False, annealing_type='exponential', annealing_steps=0, last_epoch=-1)

Parameters:

  • optimizer (Optimizer, required): Wrapped optimizer whose param_groups LRs will be managed.
  • warmup_steps (int, default 0): Number of steps for linear warmup (phase 1).
  • cooldown_steps (int, default 0): Number of steps for cosine decay from base_lr to constant_lr (phase 2).
  • constant_lr (float, default 3.75e-05): Learning rate held during the constant phase (phase 3) and the starting point for annealing (phase 4). Corresponds to eta_const in the paper.
  • min_lr (float, default 1e-08): Target minimum learning rate reached at the end of annealing. Must be > 0. Corresponds to eta_min in the paper. Config-only: not saved in checkpoints.
  • tau (float, default 10000.0): Annealing step budget for exponential annealing. The LR reaches min_lr after tau + checkpoint_step steps past checkpoint_step, so the budget grows with checkpoint_step. Corresponds to t_a in the paper. Ignored when annealing_type="rsqrt".
  • checkpoint_step (int, default -1): Step at which to begin annealing (phase 4). Set to -1 to disable annealing. When enabled, must be >= warmup_steps + cooldown_steps. Corresponds to N_d in the paper.
  • start_annealing (bool, default False): When True, annealing begins at the current step upon loading a checkpoint, provided the loaded checkpoint_step is still negative (annealing has not started). This allows annealing to be triggered retroactively from any saved checkpoint. When False, checkpoint_step is restored from the constructor value. Config-only: not saved in checkpoints.
  • annealing_type (str, default 'exponential'): Decay curve for the annealing phase. "exponential" uses the original paper formula; "rsqrt" uses harmonic/rational decay (WSD-S paper), which drops quickly at first, then slows. Config-only: not saved in checkpoints.
  • annealing_steps (int, default 0): Total annealing steps for annealing_type="rsqrt". The LR reaches min_lr after exactly this many steps past checkpoint_step and is held there afterwards. Must be > 0 for rsqrt annealing; ignored for exponential. Config-only: not saved in checkpoints.
  • last_epoch (int, default -1): Index of the last epoch, used when resuming.
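Because the exponential budget depends on checkpoint_step, it is easy to misread. Setting the exponent in the formula from the _annealing_lr_exponential docstring to 1 gives the number of annealing steps needed to reach min_lr, and the two curves can then be compared at their midpoints. The figures use the default constant_lr and min_lr and an illustrative checkpoint_step of 50,000; they are worked arithmetic, not measured results.

```latex
\begin{aligned}
\eta(n) &= \eta_{\mathrm{const}}\left(\frac{\eta_{\min}}{\eta_{\mathrm{const}}}\right)^{\frac{n - N_d}{t_a + N_d}},
  \qquad \eta(n) = \eta_{\min} \iff n - N_d = t_a + N_d \\
N_d = 50{,}000,\; t_a = 10^{4} &\;\Rightarrow\; n - N_d = 60{,}000 \text{ annealing steps} \\
\eta_{\mathrm{exp}}\big|_{\text{halfway}} &= \sqrt{\eta_{\mathrm{const}}\,\eta_{\min}}
  = \sqrt{3.75\times10^{-5}\cdot 10^{-8}} \approx 6.1\times10^{-7} \\
\eta_{\mathrm{rsqrt}}\big|_{t = 1/2} &= \frac{2}{1/\eta_{\min} + 1/\eta_{\mathrm{const}}} \approx 2.0\times10^{-8}
\end{aligned}
```

The exponential curve is not clamped: past n - N_d = t_a + N_d the exponent exceeds 1 and the LR keeps falling below min_lr, whereas the rsqrt curve holds min_lr once annealing_steps have elapsed.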

get_lr()

Compute learning rate for the current step.


state_dict()

Return state dict excluding config-only parameters.

Config-only parameters (start_annealing, annealing_type, annealing_steps, min_lr) are always determined by the constructor arguments, not by checkpoint state. This ensures backward compatibility with checkpoints saved before these parameters existed.
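PyTorch's LRScheduler keeps its state in the instance __dict__, and its load_state_dict() updates __dict__ from the saved mapping. The toy class below (hypothetical, neither the Forgather nor the PyTorch class) imitates that behaviour to show the effect of leaving config-only keys out of the saved state: a resumed run keeps whatever its constructor was given.

```python
CONFIG_ONLY_KEYS = frozenset(
    ("start_annealing", "annealing_type", "annealing_steps", "min_lr")
)


class ToyScheduler:
    """Hypothetical stand-in that mimics __dict__-based scheduler state."""

    def __init__(self, min_lr, last_epoch=0):
        self.min_lr = min_lr
        self.last_epoch = last_epoch

    def state_dict(self):
        # Everything except config-only keys goes into the checkpoint.
        return {k: v for k, v in self.__dict__.items() if k not in CONFIG_ONLY_KEYS}

    def load_state_dict(self, state):
        # A plain __dict__ update, as in PyTorch's LRScheduler.
        self.__dict__.update(state)


old = ToyScheduler(min_lr=1e-8, last_epoch=5000)
saved = old.state_dict()          # min_lr is not saved

new = ToyScheduler(min_lr=1e-6)   # resume with a different annealing target
new.load_state_dict(saved)
print(saved, new.last_epoch, new.min_lr)
```

Filtering alone does not help with checkpoints written before a key became config-only, since those still carry the old value; that is why load_state_dict re-applies the constructor values after calling super().load_state_dict().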


load_state_dict(state_dict)

Load state dict, preserving config-only parameters.

Config-only parameters are always taken from the constructor, never from the checkpoint. The start_annealing flag controls how checkpoint_step is resolved after loading:

  • start_annealing=True with loaded checkpoint_step < 0: Begin annealing at the current step (last_epoch).
  • start_annealing=True with loaded checkpoint_step >= 0: Resume annealing from where it left off.
  • start_annealing=False: Restore checkpoint_step from the constructor value, ignoring whatever was saved in the checkpoint.
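The three cases can be restated as a small pure function. resolve_checkpoint_step is a hypothetical helper that mirrors the branch at the end of load_state_dict; the step numbers in the example are illustrative.

```python
def resolve_checkpoint_step(loaded_step, constructor_step, start_annealing, last_epoch):
    """Hypothetical helper mirroring the checkpoint_step rules on load."""
    if start_annealing:
        if loaded_step < 0:
            # No annealing recorded in the checkpoint: start it now.
            return last_epoch
        # Annealing already under way: keep its original start step.
        return loaded_step
    # start_annealing=False: the constructor value wins.
    return constructor_step


print(resolve_checkpoint_step(-1, -1, True, 80000))      # start annealing now
print(resolve_checkpoint_step(80000, -1, True, 85000))   # resume annealing
print(resolve_checkpoint_step(80000, -1, False, 85000))  # constructor value
```

The third case has a practical consequence: resuming an annealing run with start_annealing=False and the default checkpoint_step=-1 puts the scheduler back in the constant phase.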

forgather.ml.optim.cosine_lr_scheduler.CosineLRScheduler

Bases: LRScheduler

Cosine decay learning rate scheduler with optional linear warmup.

Linearly warms the learning rate from 0 to base_lr over warmup_steps, then applies a half-cosine decay from base_lr to min_lr over the remaining total_steps - warmup_steps steps.

This is the standard schedule for fixed-budget training runs. For continual pre-training without a predetermined budget, prefer InfiniteLRScheduler or WSDScheduler.
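A plain-Python sketch of the same schedule for one parameter group follows; cosine_lr and the example settings are hypothetical. One deliberate difference: the sketch clamps progress at 1.0. CosineLRScheduler.get_lr() as listed does not clamp, so if training continues past total_steps the cosine term starts rising and the LR climbs back toward base_lr.

```python
import math


def cosine_lr(step, base_lr, total_steps, warmup_steps=0, min_lr=0.0):
    """Hypothetical single-group sketch: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr.
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed; clamped so steps past
    # total_steps hold min_lr instead of climbing back up the cosine.
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    scale = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * scale


kw = dict(base_lr=1e-3, total_steps=1100, warmup_steps=100, min_lr=1e-5)
for step in (0, 50, 100, 600, 1100, 1300):
    print(step, cosine_lr(step, **kw))
```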

Parameters:

  • optimizer (Optimizer, required): Wrapped optimizer whose param_groups LRs will be managed.
  • total_steps (int, required): Total number of training steps (warmup + decay combined).
  • warmup_steps (int, default 0): Number of linear warmup steps before cosine decay begins.
  • min_lr (float, default 0.0): Minimum learning rate at the end of cosine decay.
  • last_epoch (int, default -1): Index of the last epoch, used when resuming.
Source code in src/forgather/ml/optim/cosine_lr_scheduler.py
class CosineLRScheduler(LRScheduler):
    """Cosine decay learning rate scheduler with optional linear warmup.

    Linearly warms the learning rate from 0 to ``base_lr`` over
    ``warmup_steps``, then applies a half-cosine decay from ``base_lr`` to
    ``min_lr`` over the remaining ``total_steps - warmup_steps`` steps.

    This is the standard schedule for fixed-budget training runs.  For
    continual pre-training without a predetermined budget, prefer
    `InfiniteLRScheduler` or `WSDScheduler`.

    Parameters
    ----------
    optimizer : Optimizer
        Wrapped optimizer whose ``param_groups`` LRs will be managed.
    total_steps : int
        Total number of training steps (warmup + decay combined).
    warmup_steps : int, optional
        Number of linear warmup steps before cosine decay begins.
        Default is 0.
    min_lr : float, optional
        Minimum learning rate at the end of cosine decay.  Default is 0.0.
    last_epoch : int, optional
        Index of the last epoch, used when resuming.  Default is -1.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int = 0,
        min_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        assert total_steps > 0
        assert 0 <= warmup_steps < total_steps
        assert min_lr >= 0.0

        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.decay_steps = total_steps - warmup_steps
        self.min_lr = min_lr

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch

        if step < self.warmup_steps:
            scale = step / self.warmup_steps
            return [base_lr * scale for base_lr in self.base_lrs]
        else:
            progress = (step - self.warmup_steps) / self.decay_steps
            scale = 0.5 * (1.0 + math.cos(math.pi * progress))
            return [
                self.min_lr + (base_lr - self.min_lr) * scale
                for base_lr in self.base_lrs
            ]

forgather.ml.optim.wsd_scheduler.WSDScheduler

Bases: LRScheduler

Warmup-Stable-Decay learning rate scheduler.

Implements the WSD-S protocol (Hu et al., arXiv:2410.05192). The stable phase holds base_lr indefinitely, enabling training without a fixed step budget. Decay is triggered on demand — by setting decay_start_step ahead of time or retroactively via start_decay=True when resuming from a checkpoint — so multiple decayed checkpoints can be produced from a single stable-phase run.

The schedule has three sequential phases:

  1. Warmup — linear ramp from 0 to base_lr over warmup_steps.
  2. Stable — holds base_lr indefinitely until decay is triggered.
  3. Decay — harmonic/rational decay from base_lr to min_lr over decay_steps using linear interpolation of inverse LR. The curve drops quickly at first then slows (convex shape).
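The following plain-Python sketch illustrates the three phases as described. The body of WSDScheduler.get_lr() is not reproduced in this section, so two details are assumptions: the warmup formula (taken to match InfiniteLRScheduler) and holding min_lr once decay_steps have elapsed (taken to match the rsqrt annealing in InfiniteLRScheduler). wsd_lr is a hypothetical helper.

```python
def wsd_lr(step, base_lr, warmup_steps=0, min_lr=1e-8, decay_steps=1,
           decay_start_step=-1):
    """Hypothetical helper sketching the three WSD phases for one group."""
    if step < warmup_steps:
        # Warmup (assumed linear, as in InfiniteLRScheduler).
        return base_lr * step / warmup_steps
    if decay_start_step < 0 or step < decay_start_step:
        # Stable: hold base_lr until decay is triggered.
        return base_lr
    # Decay: interpolate 1/lr linearly from 1/base_lr to 1/min_lr,
    # then hold min_lr (the clamp is an assumption).
    t = min((step - decay_start_step) / decay_steps, 1.0)
    return 1.0 / (t / min_lr + (1.0 - t) / base_lr)


for step in (5, 50, 100, 110, 150, 200, 300):
    print(step, wsd_lr(step, base_lr=1e-3, warmup_steps=10, min_lr=1e-5,
                       decay_steps=100, decay_start_step=100))
```

With base_lr=1e-3 and min_lr=1e-5 the LR is already about 9.2e-5 after 10% of the decay and about 2.0e-5 at the halfway point, which is what "drops quickly at first then slows" means in practice.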
Notes

start_decay, min_lr, and decay_steps are config-only keys: they are taken from the constructor arguments and are not saved to or loaded from checkpoints. This ensures backward compatibility and allows the decay policy to be changed when resuming.

References

Hu, S. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective. arXiv:2410.05192.

Source code in src/forgather/ml/optim/wsd_scheduler.py
class WSDScheduler(LRScheduler):
    """Warmup-Stable-Decay learning rate scheduler.

    Implements the WSD-S protocol (Hu et al., arXiv:2410.05192).  The stable
    phase holds ``base_lr`` indefinitely, enabling training without a fixed
    step budget.  Decay is triggered on demand — by setting
    ``decay_start_step`` ahead of time or retroactively via ``start_decay=True``
    when resuming from a checkpoint — so multiple decayed checkpoints can be
    produced from a single stable-phase run.

    The schedule has three sequential phases:

    1. **Warmup** — linear ramp from 0 to ``base_lr`` over ``warmup_steps``.
    2. **Stable** — holds ``base_lr`` indefinitely until decay is triggered.
    3. **Decay** — harmonic/rational decay from ``base_lr`` to ``min_lr``
       over ``decay_steps`` using linear interpolation of inverse LR.  The
       curve drops quickly at first then slows (convex shape).

    Notes
    -----
    ``start_decay``, ``min_lr``, and ``decay_steps`` are *config-only* keys:
    they are taken from the constructor arguments and are not saved to or
    loaded from checkpoints.  This ensures backward compatibility and allows
    the decay policy to be changed when resuming.

    References
    ----------
    Wen, K. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates:
    A River Valley Loss Landscape Perspective. arXiv:2410.05192.
    """

    # Config-only keys: set from constructor config, not saved/loaded
    # from checkpoints.
    _CONFIG_ONLY_KEYS = frozenset(("start_decay", "min_lr", "decay_steps"))

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int = 0,
        min_lr: float = 1e-8,
        decay_steps: int = 1,
        decay_start_step: int = -1,
        start_decay: bool = False,
        last_epoch: int = -1,
    ):
        """
        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer whose ``param_groups`` LRs will be managed.
        warmup_steps : int, optional
            Number of steps for linear warmup (phase 1).  Default is 0.
        min_lr : float, optional
            Target minimum learning rate reached at the end of decay.  Must
            be > 0.  Config-only: not saved in checkpoints.  Default is 1e-8.
        decay_steps : int, optional
            Total number of steps in the decay phase.  The LR reaches
            ``min_lr`` after exactly this many steps past
            ``decay_start_step``.  Must be > 0.  Config-only: not saved in
            checkpoints.  Default is 1.
        decay_start_step : int, optional
            Step at which to begin decay (phase 3).  Set to ``-1`` to disable
            decay.  When enabled, must be >= ``warmup_steps``.  Default is -1.
        start_decay : bool, optional
            When ``True``, decay begins at the current step upon loading a
            checkpoint (provided ``decay_start_step`` is still ``-1``).  This
            allows triggering decay retroactively from any saved checkpoint.
            When ``False``, ``decay_start_step`` is restored from the
            constructor value.  Config-only: not saved in checkpoints.
            Default is ``False``.
        last_epoch : int, optional
            Index of the last epoch, used when resuming.  Default is -1.
        """
        assert warmup_steps >= 0
        assert min_lr > 0.0
        assert decay_steps > 0
        assert decay_start_step < 0 or decay_start_step >= warmup_steps

        self.warmup_steps = warmup_steps
        self.min_lr = min_lr
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.start_decay = start_decay

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute learning rate for the current step."""
        if self.last_epoch < self.warmup_steps:
            return self._warmup_lr()
        elif self.decay_start_step >= 0 and self.last_epoch >= self.decay_start_step:
            return self._decay_lr()
        else:
            return self._stable_lr()

    def _warmup_lr(self):
        """Phase 1: Linear warmup from 0 to base_lr."""
        return [
            base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs
        ]

    def _stable_lr(self):
        """Phase 2: Constant learning rate at base_lr."""
        return list(self.base_lrs)

    def _decay_lr(self):
        """Phase 3: Harmonic/rational decay from base_lr to min_lr.

        Uses linear interpolation of inverse LR:
            t = step / decay_steps   (progress 0..1)
            lr = 1 / (t / min_lr + (1 - t) / base_lr)

        At step=0: lr = base_lr (smooth transition from stable phase).
        At step=decay_steps: lr = min_lr (exact target).
        Past decay_steps: clamped at min_lr.
        """
        step = self.last_epoch - self.decay_start_step
        t = min(step / self.decay_steps, 1.0)
        return [
            1.0 / (t / self.min_lr + (1.0 - t) / base_lr) for base_lr in self.base_lrs
        ]

    def state_dict(self):
        """Return state dict excluding config-only parameters."""
        return {
            key: value
            for key, value in super().state_dict().items()
            if key not in self._CONFIG_ONLY_KEYS
        }

    def load_state_dict(self, state_dict):
        """Load state dict, preserving config-only parameters.

        The start_decay flag controls how decay_start_step is resolved:

        - start_decay=True with loaded decay_start_step < 0:
          Begin decay at the current step (last_epoch).
        - start_decay=True with loaded decay_start_step >= 0:
          Resume decay from where it left off.
        - start_decay=False:
          Restore decay_start_step from the constructor value,
          ignoring whatever was saved in the checkpoint.
        """
        saved_config = {key: getattr(self, key) for key in self._CONFIG_ONLY_KEYS}
        saved_decay_start_step = self.decay_start_step

        super().load_state_dict(state_dict)

        for key, value in saved_config.items():
            setattr(self, key, value)

        if self.start_decay:
            if self.decay_start_step < 0:
                self.decay_start_step = self.last_epoch
        else:
            self.decay_start_step = saved_decay_start_step
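The decay branch in _decay_lr has a simple closed form. Let η₀ be one group's base_lr and η_min be min_lr. Let t = min((last_epoch - decay_start_step) / decay_steps, 1) be the decay progress. Assuming η_min < η₀ (the constructor only checks min_lr > 0, so this is not enforced), the following derivation gives the endpoints and the midpoint, and shows why the curve is convex.

```latex
\begin{aligned}
\frac{1}{\eta(t)} &= \frac{1-t}{\eta_0} + \frac{t}{\eta_{\min}}
  \quad\Longrightarrow\quad
  \eta(t) = \frac{\eta_0\,\eta_{\min}}{(1-t)\,\eta_{\min} + t\,\eta_0}, \\
\eta(0) &= \eta_0, \qquad \eta(1) = \eta_{\min}, \qquad
  \eta\!\left(\tfrac12\right) = \frac{2\,\eta_0\,\eta_{\min}}{\eta_0 + \eta_{\min}}, \\
\frac{d\eta}{dt} &= -\left(\frac{1}{\eta_{\min}} - \frac{1}{\eta_0}\right)\eta(t)^2 < 0,
\qquad
\frac{d^2\eta}{dt^2} = 2\left(\frac{1}{\eta_{\min}} - \frac{1}{\eta_0}\right)^{2}\eta(t)^3 > 0.
\end{aligned}
```

The midpoint LR is the harmonic mean of base_lr and min_lr. Because the slope magnitude is proportional to η(t)², the LR falls fastest right after decay starts. Every parameter group decays toward the same scalar min_lr, regardless of its own base_lr.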

__init__(optimizer, warmup_steps=0, min_lr=1e-08, decay_steps=1, decay_start_step=-1, start_decay=False, last_epoch=-1)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Optimizer | Wrapped optimizer whose param_groups LRs will be managed. | required |
| warmup_steps | int | Number of steps for linear warmup (phase 1). | 0 |
| min_lr | float | Target minimum learning rate reached at the end of decay. Must be > 0. Config-only: not saved in checkpoints. | 1e-08 |
| decay_steps | int | Total number of steps in the decay phase. The LR reaches min_lr after exactly this many steps past decay_start_step. Must be > 0. Config-only: not saved in checkpoints. | 1 |
| decay_start_step | int | Step at which to begin decay (phase 3). Set to -1 to disable decay. When enabled, must be >= warmup_steps. | -1 |
| start_decay | bool | When True, decay begins at the current step upon loading a checkpoint (provided decay_start_step is still -1). This allows triggering decay retroactively from any saved checkpoint. When False, decay_start_step is restored from the constructor value. Config-only: not saved in checkpoints. | False |
| last_epoch | int | Index of the last epoch, used when resuming. | -1 |


get_lr()

Compute the learning rate of each parameter group for the current step (last_epoch), using the warmup, decay, or stable phase as appropriate.


state_dict()

Return the scheduler state dict, excluding the config-only keys (start_decay, min_lr, decay_steps).


load_state_dict(state_dict)

Load a scheduler state dict while keeping the config-only parameters set by the constructor.

The start_decay flag controls how decay_start_step is resolved:

  • start_decay=True with loaded decay_start_step < 0: Begin decay at the current step (last_epoch).
  • start_decay=True with loaded decay_start_step >= 0: Resume decay from where it left off.
  • start_decay=False: Restore decay_start_step from the constructor value, ignoring whatever was saved in the checkpoint.
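The three cases can be written as a small pure function. resolve_decay_start_step is a hypothetical helper that mirrors the branch at the end of WSDScheduler.load_state_dict. It is not part of Forgather, and its argument names are assumptions made for illustration.

```python
def resolve_decay_start_step(start_decay, checkpoint_value, constructor_value,
                             last_epoch):
    """decay_start_step after loading, per WSDScheduler.load_state_dict.

    checkpoint_value:  decay_start_step stored in the checkpoint (-1 = no decay)
    constructor_value: decay_start_step passed to the constructor
    last_epoch:        step restored from the checkpoint
    """
    if start_decay:
        if checkpoint_value < 0:
            return last_epoch       # begin decaying at the resumed step
        return checkpoint_value     # decay already under way: keep its start
    return constructor_value        # checkpoint value ignored; config wins


# Stable-phase checkpoint at step 20000, resumed with start_decay=True:
print(resolve_decay_start_step(True, -1, -1, 20000))     # 20000
# Resuming a run that began decaying at step 20000 (now at step 20500):
print(resolve_decay_start_step(True, 20000, -1, 20500))  # 20000
# Same run resumed with start_decay=False and no decay_start_step configured:
print(resolve_decay_start_step(False, 20000, -1, 20500))  # -1
```

The last case deserves attention. Because start_decay=False restores the constructor value, resuming an already-decaying run without either start_decay=True or the original decay_start_step disables decay. In that case, get_lr falls back to the stable phase and the LR jumps back to base_lr.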