Optimizers
Forgather ships several optimizers and learning rate schedulers, available as configuration templates or directly via the Python API.
Related documentation:
- Adafactor Triton Performance — performance analysis and benchmarks for the Triton-optimized Adafactor kernel
Optimizers
forgather.ml.optim.adamw.AdamW
Bases: Optimizer
AdamW optimizer with optional stochastic rounding for pure-bf16 training.
Implements decoupled weight-decay regularization (Loshchilov & Hutter, arXiv:1711.05101) on top of the Adam update rule (Kingma & Ba, arXiv:1412.6980). The distinguishing feature of this implementation is first-class support for pure bf16 training — parameters, gradients, and optimizer states all stay in bf16, with stochastic rounding (SR) used for every write-back to avoid systematic truncation bias. This eliminates the need for fp32 master-weight copies while retaining most of the numerical quality of mixed-precision training.
Prefer this optimizer over standard torch.optim.AdamW when:
- Training on hardware with fast bf16 throughput and limited memory.
- Running pure-bf16 experiments where fp32 master weights are undesirable.
- Using FSDP2 (DTensor-backed parameters are handled transparently).
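For reference, the update AdamW applies can be sketched for a single scalar parameter in plain Python. This is a textbook AdamW step (bias-corrected Adam plus decoupled weight decay), not the forgather kernel: it ignores bf16 storage, stochastic rounding and torch.compile. The helper name adamw_step is hypothetical.

```python
import math


def adamw_step(p, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
    """One textbook AdamW update for a scalar parameter (hypothetical helper).

    `step` is 1-based. Returns the new (p, m, v).
    """
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    p = p - lr * weight_decay * p
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for the zero-initialised moments.
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# Three steps with a constant gradient of 1.0.
p, m, v = 1.0, 0.0, 0.0
for step in range(1, 4):
    p, m, v = adamw_step(p, grad=1.0, m=m, v=v, step=step)
```

Because the decay term multiplies p rather than being added to the gradient, it is not rescaled by the adaptive denominator. That is the "decoupled" part of AdamW.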
Notes
Stochastic rounding is seeded from a dedicated torch.Generator
initialised with a fixed seed (5489) so that all DDP ranks make identical
rounding decisions and parameters stay in sync without extra communication.
The inner _adam kernel is optionally compiled with
torch.compile(..., fullgraph=True) for improved throughput.
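The rounding scheme in the Notes can be illustrated with the standard bit-level trick for float32-to-bf16 stochastic rounding: add uniform random noise to the 16 low-order bits that bf16 discards, then truncate. This is a stdlib-only sketch, not forgather's implementation. The helper names are hypothetical, and Python's random.Random stands in for torch.Generator. With the same seed (5489), every process sees an identical sequence, although it is not the sequence torch would produce.

```python
import random
import struct


def f32_bits(x: float) -> int:
    # Reinterpret a float32 value as its 32-bit unsigned integer pattern.
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(bits: int) -> float:
    # Reinterpret a 32-bit pattern as float32.
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_truncate(x: float) -> float:
    # Round toward zero: keep the sign, exponent, and top 7 mantissa bits.
    return bits_f32(f32_bits(x) & 0xFFFF0000)


def bf16_stochastic_round(x: float, rng: random.Random) -> float:
    # Noise in [0, 2**16) carries into the kept bits with probability equal to
    # the discarded fraction, so the rounded value is unbiased in expectation.
    noise = rng.getrandbits(16)
    return bits_f32((f32_bits(x) + noise) & 0xFFFF0000)


x = 1.0 + 2 ** -9  # a quarter of the way between bf16 neighbours 1.0 and 1.0 + 2**-7
rank0, rank1 = random.Random(5489), random.Random(5489)  # same seed on every "rank"
samples0 = [bf16_stochastic_round(x, rank0) for _ in range(20000)]
samples1 = [bf16_stochastic_round(x, rank1) for _ in range(20000)]
mean = sum(samples0) / len(samples0)
```

With plain truncation every sample would be 1.0, an error of 2**-9 (about 0.002) that never averages out. With stochastic rounding roughly a quarter of the samples round up, so the mean stays close to x. Because both generators share a seed, the two "ranks" make identical choices.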
References
Kingma, D. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv:1412.6980.
Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization. arXiv:1711.05101.
Source code in src/forgather/ml/optim/adamw.py
__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.01, torch_compile=True, bf16_stochastic_round=True)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate. Default is 1e-3. | 0.001 |
| betas | tuple of (float, float) | Exponential decay rates for the first and second moment estimates. Default is (0.9, 0.999). | (0.9, 0.999) |
| eps | float | Term added to the denominator to improve numerical stability. Default is 1e-6. | 1e-06 |
| weight_decay | float | Decoupled weight-decay coefficient. Default is 0.01. | 0.01 |
| torch_compile | bool | If True, compile the inner _adam kernel with torch.compile(..., fullgraph=True). | True |
| bf16_stochastic_round | bool | If True, use stochastic rounding when writing bf16 parameters and optimizer states back. | True |
state_dict()
Return optimizer state with structure validation.
load_state_dict(state_dict)
Load optimizer state with validation.
forgather.ml.optim.adafactor.Adafactor
Bases: Optimizer
Memory-efficient adaptive optimizer with factored second-moment estimation.
Implements the Adafactor algorithm (Shazeer & Stern, arXiv:1804.04235). For matrices, the second-moment accumulator is factored into outer-product row and column vectors, reducing per-parameter memory from O(n*m) to O(n+m). For vectors and scalars the full accumulator is retained.
Like AdamW, this implementation supports pure bf16 training via
stochastic rounding on all write-backs, and handles FSDP2 DTensor
parameters transparently. An optional Triton kernel path is available
for higher GPU throughput on CUDA devices.
Prefer Adafactor over AdamW when:
- Memory is the primary constraint (large models, small accelerators).
- Training transformers with large embedding or projection matrices where the factored approximation is a good fit.
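The factored approximation can be sketched in plain Python. Following the Adafactor paper, the n x m matrix of squared gradients is summarised by its row sums and column sums, and the full estimate is rebuilt as their outer product divided by the total. This sketch omits the running averages over steps and the eps regularisation, and factored_second_moment is a hypothetical name.

```python
def factored_second_moment(grad_sq):
    """Rebuild an n x m second-moment estimate from its row and column sums.

    grad_sq: n x m nested list of squared gradients (non-negative).
    Returns (row, col, v_hat); only row (n values) and col (m values) need storing.
    """
    row = [sum(r) for r in grad_sq]        # per-row sums, length n
    col = [sum(c) for c in zip(*grad_sq)]  # per-column sums, length m
    total = sum(row)
    # Rank-1 reconstruction: v_hat[i][j] = row[i] * col[j] / total.
    v_hat = [[r * c / total for c in col] for r in row]
    return row, col, v_hat


# Exact for a rank-1 matrix (the outer product of [1, 2] with itself) ...
row, col, v_exact = factored_second_moment([[1.0, 2.0], [2.0, 4.0]])
# ... and only an approximation otherwise, although row and column sums are preserved.
_, _, v_approx = factored_second_moment([[1.0, 0.0], [0.0, 1.0]])
```

The second case shows the trade-off. A diagonal pattern cannot be represented by an outer product, so the estimate is spread evenly across the matrix.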
Notes
decay_rate controls how the effective beta2 grows with step count:
beta2t = clamp(1 - step^decay_rate, max=beta2). The default of
-0.8 replicates the schedule from the paper.
The Triton kernel path (use_triton=True) does not support
relative_step=True.
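Written out as code, the beta2t rule from the Notes is a one-liner. This minimal sketch assumes steps are counted from 1, because in Python 0 raised to a negative exponent raises ZeroDivisionError. The function name is hypothetical.

```python
def beta2t(step: int, decay_rate: float = -0.8, beta2: float = 0.999) -> float:
    """Effective second-moment decay at a 1-based step: clamp(1 - step**decay_rate, max=beta2)."""
    return min(1.0 - step ** decay_rate, beta2)


# Early steps average over a short window; the clamp takes over after roughly 5,600 steps.
schedule = {s: beta2t(s) for s in (1, 2, 10, 100, 1000, 5000, 6000)}
```

At step 1 the rule gives 0.0, so the first second-moment estimate is just the current squared gradient. The effective averaging window then widens until it reaches the beta2 cap.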
References
Shazeer, N. & Stern, M. (2018). Adafactor: Adaptive Learning Rates with Sublinear Memory Cost. arXiv:1804.04235.
Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization. arXiv:1711.05101.
Source code in src/forgather/ml/optim/adafactor.py
__init__(params, lr=0.001, decay_rate=-0.8, clip_threshold=1.0, betas=(0.9, 0.999), eps=(1e-30, 0.001), weight_decay=0.01, relative_step=False, torch_compile=True, bf16_stochastic_round=True, use_triton=False)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate (or relative step size when relative_step=True). | 0.001 |
| decay_rate | float | Exponent controlling how the effective beta2 grows with step count (see Notes). | -0.8 |
| clip_threshold | float | Root-mean-square threshold for update clipping. The update is scaled down when its RMS exceeds this value. Default is 1.0. | 1.0 |
| betas | tuple of (float, float) | Upper bounds for the first and second moment decay rates. | (0.9, 0.999) |
| eps | tuple of (float, float) |  | (1e-30, 0.001) |
| weight_decay | float | Decoupled weight-decay coefficient. Default is 0.01. | 0.01 |
| relative_step | bool | If True, lr is interpreted as a relative step size. Not supported with use_triton=True. | False |
| torch_compile | bool | If True, compile the update with torch.compile. | True |
| bf16_stochastic_round | bool | Enable stochastic rounding for bf16 write-backs. Has no effect when parameters are fp32. Default is True. | True |
| use_triton | bool | If True, use the Triton kernel path for higher throughput on CUDA devices. Not supported with relative_step=True. | False |
state_dict()
Return optimizer state, handling entries whose col accumulator is None.
load_state_dict(state_dict)
Load optimizer state, handling entries whose col accumulator is None.
forgather.ml.optim.apollo.Apollo
Bases: Optimizer
Low-rank gradient-projection optimizer with AdamW-level performance.
Implements the Apollo algorithm (Zhu et al., arXiv:2412.05270). Rather
than maintaining full-size first and second moment buffers, Apollo projects
gradients into a low-rank subspace (controlled by rank), runs the
Adam update there, and uses the resulting per-column scaling signal to
scale the full-rank gradient. Moment buffer memory scales as
O(rank * max(n, m)) instead of O(n * m).
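The projection-and-scaling idea can be sketched with nested lists. This is a minimal illustration under stated assumptions. The projection is a fixed random Gaussian matrix (rank x n), and each entry of the projected gradient gets an Adam step. The ratio of column norms after and before that step then scales the matching column of the full gradient, following the channel-wise scaling described in the Apollo paper. The class and helper names are hypothetical, and the sketch omits the scale factor, weight decay and projection refresh.

```python
import math
import random


def matmul(a, b):
    # (p x q) @ (q x r) for nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def col_norms(mat):
    return [math.sqrt(sum(x * x for x in col)) for col in zip(*mat)]


class ApolloScalingSketch:
    """Channel-wise scaling from a rank-r projection; a hypothetical illustration."""

    def __init__(self, n, m, rank, seed=0, betas=(0.9, 0.999), eps=1e-6):
        rng = random.Random(seed)
        # Random Gaussian projection (rank x n). Kept fixed here; the optimizer
        # refreshes its projection every update_steps steps.
        self.P = [[rng.gauss(0.0, 1.0 / math.sqrt(rank)) for _ in range(n)] for _ in range(rank)]
        # Moment buffers live in the projected space: rank x m, not n x m.
        self.m = [[0.0] * m for _ in range(rank)]
        self.v = [[0.0] * m for _ in range(rank)]
        self.betas, self.eps, self.t = betas, eps, 0

    def scaled_gradient(self, G):
        b1, b2 = self.betas
        self.t += 1
        R = matmul(self.P, G)  # projected gradient, rank x m
        R_adam = []
        for i, row in enumerate(R):
            out = []
            for j, r in enumerate(row):
                self.m[i][j] = b1 * self.m[i][j] + (1 - b1) * r
                self.v[i][j] = b2 * self.v[i][j] + (1 - b2) * r * r
                m_hat = self.m[i][j] / (1 - b1 ** self.t)
                v_hat = self.v[i][j] / (1 - b2 ** self.t)
                out.append(m_hat / (math.sqrt(v_hat) + self.eps))
            R_adam.append(out)
        # Per-column factor: how much the low-rank Adam step rescaled each column.
        scale = [a / (r + 1e-12) for a, r in zip(col_norms(R_adam), col_norms(R))]
        # Apply the factors to the full-rank gradient, column by column.
        return [[g * s for g, s in zip(row, scale)] for row in G], scale


G = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.5], [-2.0, 1.0, 0.75], [0.1, -0.3, 1.2]]
opt = ApolloScalingSketch(n=4, m=3, rank=2)
update, scale = opt.scaled_gradient(G)
```

The full 4 x 3 gradient reaches the update, while the optimizer state is only 2 x 3. That is the memory saving the paragraph describes.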
Also applies the Norm-Growth Limiter from Fira (arXiv:2410.01623) to prevent destructive gradient updates.
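The Norm-Growth Limiter can be sketched as a rule that caps how much the update norm may grow between steps. If the new norm exceeds gamma times the previous one, the update is rescaled to exactly that bound and keeps its direction. This follows the description in the Fira paper. gamma=1.01 is an assumed value here, not a documented forgather default, and the function name is hypothetical.

```python
import math


def limit_norm_growth(update, prev_norm, gamma=1.01):
    """Cap the step-to-step growth of the update norm at a factor of gamma.

    Returns (possibly rescaled update, norm to remember for the next step).
    prev_norm is None on the first step, when no limit applies.
    """
    norm = math.sqrt(sum(x * x for x in update))
    if prev_norm is not None and prev_norm > 0 and norm > gamma * prev_norm:
        factor = gamma * prev_norm / norm
        return [x * factor for x in update], gamma * prev_norm
    return update, norm


u1, n1 = limit_norm_growth([1.0, 0.0], prev_norm=None)  # first step: passed through
u2, n2 = limit_norm_growth([3.0, 4.0], prev_norm=n1)    # norm 5.0 exceeds 1.01 * 1.0
```

A sudden gradient spike therefore only grows the step by 1% per iteration, instead of producing one destructive jump.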
Prefer Apollo over AdamW when:
- Memory is constrained and Adafactor's factored approximation is too aggressive (Apollo retains the full gradient for the parameter update).
- rank=1 (Apollo-Mini) is desired for maximum memory savings while still outperforming SGD.
Notes
The projector_factory callable is not serialisable and is therefore
stripped from checkpoints. It must be supplied again via the constructor
when resuming from a checkpoint.
References
Zhu, W. et al. (2024). APOLLO: SGD-like Memory, AdamW-level Performance. arXiv:2412.05270.
Chen, Y. et al. (2024). Fira: Can We Achieve Full-Rank Training of LLMs Under Low-Rank Constraint? arXiv:2410.01623.
Source code in src/forgather/ml/optim/apollo.py
__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, rank=1, scale=1.0, scale_front=False, update_steps=10, mini=False, projector_factory=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate. Default is 1e-3. | 0.001 |
| betas | tuple of (float, float) | Exponential decay rates for the low-rank first and second moment estimates. Default is (0.9, 0.999). | (0.9, 0.999) |
| eps | float | Term added to the denominator of the Adam update in the low-rank subspace. Default is 1e-6. | 1e-06 |
| weight_decay | float | Decoupled weight-decay coefficient. Default is 0.0. | 0.0 |
| rank | int | Rank of the gradient projection subspace. Lower rank saves more memory; rank=1 is the Apollo-Mini setting. | 1 |
| scale | float | Additional scaling factor applied to the update. Applied before the Norm-Growth Limiter when scale_front=True. | 1.0 |
| scale_front | bool | If True, apply scale before the Norm-Growth Limiter rather than after it. | False |
| update_steps | int | How often (in optimizer steps) the projection matrix is refreshed. Passed through to the projector created by projector_factory. | 10 |
| mini | bool | If True, use the Apollo-Mini variant. | False |
| projector_factory | callable | Factory that constructs the gradient projector from keyword arguments. Not serialized in checkpoints; supply it again when resuming. | None |
state_dict()
Return optimizer state with serialized projector objects.
Projector objects are converted to dicts containing only tensors and primitives to ensure proper checkpoint serialization.
Note: The projector_factory in param_groups is removed since it's a non-serializable function. On load_state_dict, it must be provided via the optimizer constructor.
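The save/restore contract for projector_factory can be sketched with plain dicts and pickle. The callable is dropped from each param group before saving, and on load it is taken from the freshly constructed optimizer instead. The helper names are hypothetical, and the state layout is simplified to the torch.optim.Optimizer convention of "state" plus "param_groups".

```python
import pickle


def strip_factory(state_dict, key="projector_factory"):
    # Return a copy whose param groups omit the non-serializable callable.
    groups = [{k: v for k, v in g.items() if k != key} for g in state_dict["param_groups"]]
    return {**state_dict, "param_groups": groups}


def restore_factory(saved, live_groups, key="projector_factory"):
    # Re-attach the factory held by the optimizer that was just constructed.
    groups = [{**g, key: live[key]} for g, live in zip(saved["param_groups"], live_groups)]
    return {**saved, "param_groups": groups}


def make_projector(**kwargs):  # stand-in for a real projector factory
    return ("projector", kwargs)


live_groups = [{"lr": 1e-3, "rank": 1, "projector_factory": lambda **kw: make_projector(**kw)}]
state = {"state": {}, "param_groups": live_groups}

try:
    pickle.dumps(state)  # the lambda cannot be pickled
    raw_pickle_ok = True
except (pickle.PicklingError, AttributeError, TypeError):
    raw_pickle_ok = False

blob = pickle.dumps(strip_factory(state))
resumed = restore_factory(pickle.loads(blob), live_groups)
```

Only the stripped copy is pickled. The live optimizer's param groups keep their factory, which is why the constructor must receive it again when resuming.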
load_state_dict(state_dict)
Load optimizer state and reconstruct projector objects.
Deserializes projector dicts back into projector objects.
forgather.ml.optim.sgd.SGD
Bases: Optimizer
Minimal vanilla SGD optimizer.
Applies the plain stochastic gradient descent update rule:
`p = p - lr * grad`
No momentum, weight decay, or gradient clipping. Intended as a minimal
reference implementation and starting point for custom optimizers. For
production training, prefer AdamW or Adafactor.
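The update rule is simple enough to show end to end in plain Python. This minimal sketch uses a hypothetical helper name and lists standing in for tensors. It minimises a two-variable quadratic with p = p - lr * grad:

```python
def sgd_step(params, grads, lr=1e-3):
    # Vanilla SGD: p = p - lr * grad for each element; no momentum, decay, or clipping.
    return [p - lr * g for p, g in zip(params, grads)]


# Minimise f(x, y) = (x - 3)**2 + (y + 1)**2, with gradient (2 * (x - 3), 2 * (y + 1)).
params = [0.0, 0.0]
for _ in range(200):
    x, y = params
    params = sgd_step(params, [2 * (x - 3), 2 * (y + 1)], lr=0.1)
```

With lr=0.1, each step shrinks the distance to the minimum at (3, -1) by a factor of 0.8. Above lr=1.0, the same loop diverges.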
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable of Parameter | Model parameters to optimize. | required |
| lr | float | Learning rate. Default is 1e-3. | 0.001 |
Source code in src/forgather/ml/optim/sgd.py
Schedulers
forgather.ml.optim.infinite_lr_scheduler.InfiniteLRScheduler
Bases: LRScheduler
Learning rate scheduler for continual pre-training without a fixed budget.
Implements the Infinite Cosine Schedule (arXiv:2503.02844). The key idea
is a permanent constant phase that can run indefinitely, enabling
continual pre-training without committing to a total step count up front.
Annealing is triggered on demand — typically by resuming from a checkpoint
with start_annealing=True — so multiple annealed checkpoints can be
derived from a single long training run.
The schedule has four sequential phases:
- Warmup: linear ramp from 0 to base_lr over warmup_steps.
- Cooldown: cosine decay from base_lr to constant_lr over cooldown_steps.
- Constant: holds constant_lr indefinitely (the "infinite" phase).
- Annealing: decays from constant_lr toward min_lr, triggered at checkpoint_step.
Two decay curves are supported for annealing:
- "exponential" (default): original paper formula; exponential decay controlled by tau.
- "rsqrt": harmonic/rational decay from the WSD-S paper (arXiv:2410.05192); drops quickly at first, then slows.
Notes
start_annealing, annealing_type, annealing_steps, and
min_lr are config-only keys: they are taken from the constructor
arguments and are not saved to or loaded from checkpoints. This ensures
backward compatibility and allows the annealing policy to be changed when
resuming.
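The four phases can be sketched as a pure function of the step. This is an illustration, not forgather's get_lr. It assumes linear warmup from 0 and a half-cosine cooldown. It models "exponential" annealing as a geometric interpolation from constant_lr to min_lr over tau steps, and "rsqrt" annealing as linear interpolation of 1/LR over annealing_steps steps. The exact role of tau in the paper's formula may differ, and infinite_lr is a hypothetical name.

```python
import math


def infinite_lr(step, base_lr, warmup_steps=0, cooldown_steps=0, constant_lr=3.75e-5,
                min_lr=1e-8, tau=10000.0, checkpoint_step=-1,
                annealing_type="exponential", annealing_steps=0):
    # Phase 1: linear warmup from 0 to base_lr.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Phase 2: half-cosine cooldown from base_lr to constant_lr.
    t = step - warmup_steps
    if t < cooldown_steps:
        cos = 0.5 * (1 + math.cos(math.pi * t / cooldown_steps))
        return constant_lr + (base_lr - constant_lr) * cos
    # Phase 3: constant indefinitely, until annealing is scheduled.
    if checkpoint_step < 0 or step < checkpoint_step:
        return constant_lr
    # Phase 4: anneal from constant_lr toward min_lr.
    elapsed = step - checkpoint_step
    if annealing_type == "exponential":
        # Assumed form: log-LR moves linearly to log(min_lr) over tau steps.
        frac = min(elapsed / tau, 1.0)
        return constant_lr * (min_lr / constant_lr) ** frac
    # "rsqrt": 1/LR moves linearly to 1/min_lr over annealing_steps steps.
    frac = min(elapsed / max(annealing_steps, 1), 1.0)
    return 1.0 / ((1 - frac) / constant_lr + frac / min_lr)


kw = dict(base_lr=1e-3, warmup_steps=10, cooldown_steps=20, constant_lr=1e-4,
          min_lr=1e-8, tau=100.0, checkpoint_step=100, annealing_steps=100)
curve = [infinite_lr(s, **kw) for s in range(300)]
```

Setting checkpoint_step to -1 removes phase 4 entirely, so the same function describes both the open-ended run and any annealed branch taken from it.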
References
Zhu, Y. et al. (2025). Beyond Cosine Decay: On the effectiveness of Infinite Learning Rate Schedule for Continual Pre-training. arXiv:2503.02844.
Hu, S. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective. arXiv:2410.05192.
Source code in src/forgather/ml/optim/infinite_lr_scheduler.py
__init__(optimizer, warmup_steps=0, cooldown_steps=0, constant_lr=3.75e-05, min_lr=1e-08, tau=10000.0, checkpoint_step=-1, start_annealing=False, annealing_type='exponential', annealing_steps=0, last_epoch=-1)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Optimizer | Wrapped optimizer whose learning rate is scheduled. | required |
| warmup_steps | int | Number of steps for linear warmup (phase 1). Default is 0. | 0 |
| cooldown_steps | int | Number of steps for cosine decay from base_lr to constant_lr (phase 2). | 0 |
| constant_lr | float | Learning rate held during the constant phase (phase 3) and the starting point for annealing (phase 4). | 3.75e-05 |
| min_lr | float | Target minimum learning rate reached at the end of annealing. Must be > 0. Config-only: not saved in checkpoints. | 1e-08 |
| tau | float | Annealing step budget for exponential annealing. | 10000.0 |
| checkpoint_step | int | Step at which to begin annealing (phase 4). A negative value means annealing has not been scheduled. | -1 |
| start_annealing | bool | When True, annealing is triggered when a checkpoint is loaded (see load_state_dict). Config-only. | False |
| annealing_type | str | Decay curve for the annealing phase: "exponential" or "rsqrt". Config-only. | 'exponential' |
| annealing_steps | int | Total annealing steps for the "rsqrt" curve. Config-only. | 0 |
| last_epoch | int | Index of the last epoch, used when resuming. Default is -1. | -1 |
get_lr()
Compute learning rate for the current step.
state_dict()
Return state dict excluding config-only parameters.
Config-only parameters (start_annealing, annealing_type, annealing_steps, min_lr) are always determined by the constructor arguments, not by checkpoint state. This ensures backward compatibility with checkpoints saved before these parameters existed.
load_state_dict(state_dict)
Load state dict, preserving config-only parameters.
Config-only parameters are always taken from the constructor, never from the checkpoint. The start_annealing flag controls how checkpoint_step is resolved after loading:
- start_annealing=True with loaded checkpoint_step < 0: Begin annealing at the current step (last_epoch).
- start_annealing=True with loaded checkpoint_step >= 0: Resume annealing from where it left off.
- start_annealing=False: Restore checkpoint_step from the constructor value, ignoring whatever was saved in the checkpoint.
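The three rules above amount to a small decision function. The following sketch mirrors the documented behaviour (the function and argument names are hypothetical). The usage lines show the typical "derive an annealed branch" resume:

```python
def resolve_checkpoint_step(start_annealing, loaded_step, constructor_step, last_epoch):
    """Return the checkpoint_step in effect after load_state_dict, per the documented rules."""
    if start_annealing:
        if loaded_step < 0:
            return last_epoch  # not annealing yet: start now, at the resumed step
        return loaded_step     # already annealing: keep the original start
    return constructor_step    # ignore the saved value entirely


# A long constant-phase run saved at step 50_000 with no annealing scheduled.
branch = resolve_checkpoint_step(True, loaded_step=-1, constructor_step=-1, last_epoch=50_000)
# Resuming that annealed branch later keeps the original start point.
resumed = resolve_checkpoint_step(True, loaded_step=50_000, constructor_step=-1, last_epoch=52_000)
# Without the flag, the constructor value wins even if the checkpoint was annealing.
plain = resolve_checkpoint_step(False, loaded_step=50_000, constructor_step=-1, last_epoch=52_000)
```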
forgather.ml.optim.cosine_lr_scheduler.CosineLRScheduler
Bases: LRScheduler
Cosine decay learning rate scheduler with optional linear warmup.
Linearly warms the learning rate from 0 to base_lr over
warmup_steps, then applies a half-cosine decay from base_lr to
min_lr over the remaining total_steps - warmup_steps steps.
This is the standard schedule for fixed-budget training runs. For
continual pre-training without a predetermined budget, prefer
InfiniteLRScheduler or WSDScheduler.
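As a formula, the schedule can be sketched as follows. This sketch clamps progress so that the LR stays at min_lr after total_steps. That past-the-end behaviour is an assumption, and cosine_lr is a hypothetical name.

```python
import math


def cosine_lr(step, base_lr, total_steps, warmup_steps=0, min_lr=0.0):
    # Linear warmup from 0 to base_lr.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Half-cosine from base_lr to min_lr over the remaining steps; held at min_lr afterwards.
    progress = min((step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


curve = [cosine_lr(s, 1.0, total_steps=100, warmup_steps=10) for s in range(101)]
```

Halfway through the decay phase (step 55 here), the LR is exactly halfway between base_lr and min_lr.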
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
optimizer
|
Optimizer
|
Wrapped optimizer whose |
required |
total_steps
|
int
|
Total number of training steps (warmup + decay combined). |
required |
warmup_steps
|
int
|
Number of linear warmup steps before cosine decay begins. Default is 0. |
0
|
min_lr
|
float
|
Minimum learning rate at the end of cosine decay. Default is 0.0. |
0.0
|
last_epoch
|
int
|
Index of the last epoch, used when resuming. Default is -1. |
-1
|
Source code in src/forgather/ml/optim/cosine_lr_scheduler.py
forgather.ml.optim.wsd_scheduler.WSDScheduler
Bases: LRScheduler
Warmup-Stable-Decay learning rate scheduler.
Implements the WSD-S protocol (Hu et al., arXiv:2410.05192). The stable
phase holds base_lr indefinitely, enabling training without a fixed
step budget. Decay is triggered on demand — by setting
decay_start_step ahead of time or retroactively via start_decay=True
when resuming from a checkpoint — so multiple decayed checkpoints can be
produced from a single stable-phase run.
The schedule has three sequential phases:
- Warmup: linear ramp from 0 to base_lr over warmup_steps.
- Stable: holds base_lr indefinitely until decay is triggered.
- Decay: harmonic/rational decay from base_lr to min_lr over decay_steps, using linear interpolation of the inverse LR. The curve drops quickly at first, then slows (convex shape).
Notes
start_decay, min_lr, and decay_steps are config-only keys:
they are taken from the constructor arguments and are not saved to or
loaded from checkpoints. This ensures backward compatibility and allows
the decay policy to be changed when resuming.
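The three phases translate into a short function of the step. This sketch assumes that a negative decay_start_step means "not yet scheduled" and that the LR stays at min_lr once decay_steps have elapsed. wsd_lr is a hypothetical name. The usage lines show the point of the stable phase: two branches share every step up to the earlier decay start.

```python
def wsd_lr(step, base_lr, warmup_steps=0, min_lr=1e-8, decay_steps=1, decay_start_step=-1):
    # Phase 1: linear warmup from 0 to base_lr.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Phase 2: stable at base_lr until decay is scheduled.
    if decay_start_step < 0 or step < decay_start_step:
        return base_lr
    # Phase 3: interpolate 1/LR linearly from 1/base_lr to 1/min_lr.
    frac = min((step - decay_start_step) / decay_steps, 1.0)
    return 1.0 / ((1 - frac) / base_lr + frac / min_lr)


# One stable run, two branches that start decaying at different steps.
early = [wsd_lr(s, 1e-3, warmup_steps=100, decay_steps=1000, decay_start_step=10_000) for s in range(12_000)]
late = [wsd_lr(s, 1e-3, warmup_steps=100, decay_steps=1000, decay_start_step=20_000) for s in range(12_000)]
```

Halfway through the decay, the inverse-LR interpolation has already brought the LR close to min_lr, far below the linear midpoint. This is the convex shape described above.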
References
Hu, S. et al. (2024). Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective. arXiv:2410.05192.
Source code in src/forgather/ml/optim/wsd_scheduler.py
__init__(optimizer, warmup_steps=0, min_lr=1e-08, decay_steps=1, decay_start_step=-1, start_decay=False, last_epoch=-1)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Optimizer | Wrapped optimizer whose learning rate is scheduled. | required |
| warmup_steps | int | Number of steps for linear warmup (phase 1). Default is 0. | 0 |
| min_lr | float | Target minimum learning rate reached at the end of decay. Must be > 0. Config-only: not saved in checkpoints. Default is 1e-8. | 1e-08 |
| decay_steps | int | Total number of steps in the decay phase. The LR reaches min_lr at the end of this phase. Config-only. | 1 |
| decay_start_step | int | Step at which to begin decay (phase 3). A negative value means decay has not been scheduled. | -1 |
| start_decay | bool | When True, decay is triggered when a checkpoint is loaded (see load_state_dict). Config-only. | False |
| last_epoch | int | Index of the last epoch, used when resuming. Default is -1. | -1 |
get_lr()
Compute learning rate for the current step.
state_dict()
Return state dict excluding config-only parameters.
load_state_dict(state_dict)
Load state dict, preserving config-only parameters.
The start_decay flag controls how decay_start_step is resolved:
- start_decay=True with loaded decay_start_step < 0: Begin decay at the current step (last_epoch).
- start_decay=True with loaded decay_start_step >= 0: Resume decay from where it left off.
- start_decay=False: Restore decay_start_step from the constructor value, ignoring whatever was saved in the checkpoint.