Trainer Callbacks
Callbacks extend trainer behaviour at well-defined lifecycle events (step start/end,
epoch start/end, evaluation, checkpoint save/load, etc.) without modifying trainer
source. Pass a list of callback instances to any trainer's callbacks argument.
Related documentation:
- Trainer Control: saving, stopping, and aborting running jobs
- Checkpointing: stateful callbacks and checkpoint resume
- Divergence Detection: detecting and recovering from training instability
- Log Analysis: working with logs produced by JsonLogger
Base Class
forgather.ml.trainer.trainer_types.TrainerCallback
Base class for trainer event callbacks.
Subclasses implement only the event methods they need. Any method not defined is simply never called for that callback. The trainer maintains a lazy index mapping event names to the callbacks that define handlers, so only relevant callbacks are invoked per event.
Available events (each receives args, state, control, **kwargs and may return None or an updated TrainerControl):
- on_init_end: After trainer initialization
- on_train_begin: Before training loop starts
- on_train_end: After training loop ends
- on_epoch_begin: Before each epoch
- on_epoch_end: After each epoch
- on_step_begin: Before each training step
- on_step_end: After each training step
- on_substep_end: After each gradient-accumulation sub-step
- on_forward_backward_begin: Before each forward+backward micro-step (inside gradient accumulation loop, after data loading; fires once per micro-batch)
- on_forward_backward_end: After each forward+backward micro-step (before optimizer, grad clipping, LR scheduler; fires once per micro-batch)
- on_optimizer_step: After optimizer.step()
- on_pre_optimizer_step: Before optimizer.step()
- on_evaluate: After evaluation
- on_predict: After prediction (also receives metrics)
- on_prediction_step: After each prediction batch
- on_save: After checkpoint save
- on_log: After metric logging (receives logs kwarg)
Forgather extensions (not in HF Trainer):
- on_log_step: Called before on_log; receives the mutable logs dict so callbacks can inject custom metrics before logging
- on_train_metrics: Called each training step with per-step metrics (loss, grad_norm, tokens, etc.) for fine-grained monitoring or adaptive control
kwargs always include: model, processing_class, optimizer, lr_scheduler, train_dataloader, eval_dataloader, trainer
Compatible with HuggingFace TrainerCallback for easier porting. See: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py
Identification:
Each callback has a name property used for logging (e.g. when
the trainer reports which callback requested an early stop).
Subclasses inherit the default, which returns the class name, or
can override name to provide a more descriptive label (useful
when multiple instances of the same class are registered with
different configurations).
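The event model described above can be illustrated with a small, self-contained sketch. It assumes a stand-in TrainerCallback base class and a toy dispatcher; ToyDispatcher and StepCounter are hypothetical names for this example, and the real dispatch logic in forgather.ml.trainer.trainer_types may differ.

```python
class TrainerCallback:
    """Stand-in base class (assumption): defines no event handlers itself."""

    @property
    def name(self):
        # Default identifier: the class name, as described above.
        return type(self).__name__


class StepCounter(TrainerCallback):
    """Hypothetical callback that implements only the two events it needs."""

    def __init__(self):
        self.steps = 0

    def on_step_end(self, args, state, control, **kwargs):
        self.steps += 1

    def on_log_step(self, args, state, control, logs=None, **kwargs):
        # on_log_step receives the mutable logs dict, so a metric can be injected.
        if logs is not None:
            logs["steps_seen"] = self.steps


class ToyDispatcher:
    """Toy event dispatcher (assumption, not the library's implementation)."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self._index = {}  # event name -> callbacks that define a handler

    def fire(self, event, args=None, state=None, control=None, **kwargs):
        # Build the index entry lazily, the first time an event fires.
        if event not in self._index:
            self._index[event] = [cb for cb in self.callbacks if hasattr(cb, event)]
        for cb in self._index[event]:
            result = getattr(cb, event)(args, state, control, **kwargs)
            if result is not None:
                control = result  # a handler may return an updated control object
        return control


counter = StepCounter()
dispatcher = ToyDispatcher([counter])
for _ in range(3):
    dispatcher.fire("on_step_end")
logs = {"loss": 2.5}
dispatcher.fire("on_log_step", logs=logs)
dispatcher.fire("on_epoch_end")  # no callback defines this handler
print(counter.name, logs)
```

Because the index is keyed by event name, firing an event that no callback handles costs one dictionary lookup and an empty loop.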
Source code in src/forgather/ml/trainer/trainer_types.py
name (property)
Human-readable identifier for this callback, used in log messages.
Built-in Callbacks
Default callbacks included automatically by all trainers.
forgather.ml.trainer.callbacks.DefaultMetrics
Bases: TrainerCallback
Compute derived performance metrics and inject them into logs.
Runs via on_log_step (before on_log), so computed values are
available to all downstream loggers (ProgressCallback, TBLogger, etc.).
Computed metrics:
- tok_per_sec: tokens processed per wall-clock second between log steps.
- mfu: Model FLOPs Utilization (requires peak_hardware_flops).
- peak_mem: per-rank peak CUDA memory allocated (list of bytes), aliased from peak_mem_allocated for display formatting (default reduction: max across ranks).
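The relationship between tok_per_sec, mfu, and peak_hardware_flops can be sketched as follows. The estimate of 6 FLOPs per parameter per token is a common approximation for dense transformer training and is an assumption here; the source does not state how DefaultMetrics counts model FLOPs. Both function names are hypothetical.

```python
def tokens_per_second(tokens_since_last_log: int, elapsed_seconds: float) -> float:
    """Throughput over the wall-clock interval between two log steps."""
    return tokens_since_last_log / elapsed_seconds


def model_flops_utilization(
    tok_per_sec: float, num_params: int, peak_hardware_flops: float
) -> float:
    """Achieved model FLOP/s divided by the aggregate hardware peak.

    Assumption: roughly 6 FLOPs per parameter per token (forward + backward).
    """
    flops_per_token = 6 * num_params
    return tok_per_sec * flops_per_token / peak_hardware_flops


tok_per_sec = tokens_per_second(tokens_since_last_log=1_000_000, elapsed_seconds=10.0)
mfu = model_flops_utilization(
    tok_per_sec,
    num_params=1_000_000_000,
    peak_hardware_flops=2.0e15,  # total across all GPUs, not per device
)
print(f"tok/s={tok_per_sec:.0f} mfu={mfu:.2f}")
```

Since throughput is measured for the whole job, the denominator has to be the aggregate peak across all ranks; using a per-device peak would overstate utilization by the number of devices.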
Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
__init__(peak_hardware_flops=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| peak_hardware_flops | float | Aggregate peak BF16 FLOP/s across all GPUs used in training, used to compute MFU. Must be the total across all ranks since … | None |
Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
forgather.ml.trainer.callbacks.ProgressCallback
Bases: TrainerCallback
A TQDM progress-bar callback class based upon: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py
Controls which metrics are displayed in console logs during training via configurable column specifications. All metrics are still logged to JsonLogger regardless of display settings.
Derived performance metrics (tok/s, MFU, peak_mem) are computed by
DefaultMetrics via on_log_step and are available in the logs
dict by the time on_log fires.
Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
__init__(use_tqdm=None, output_stream=None, step_columns=None, final_metrics=None, header_interval=20)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| use_tqdm | bool | If … | None |
| output_stream | OutputStream | The output stream to use when not using TQDM. | None |
| step_columns | dict | Column spec overrides merged with … | None |
| final_metrics | dict | Final metric spec overrides merged with … | None |
| header_interval | int | Print a column header row every this many log steps, and also whenever the set of active columns changes. Default is … | 20 |
Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
forgather.ml.trainer.callbacks.InfoCallback
Bases: TrainerCallback
Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
Job Control
forgather.ml.trainer.callbacks.TrainerControlCallback
Bases: TrainerCallback
Callback that enables external control of training jobs via HTTP API.
Features:
- Graceful stop: Stop training cleanly after current step
- Save checkpoint: Trigger checkpoint save (with evaluation if needed)
- Save and stop: Save checkpoint then stop training
- Status queries: Get current training status
Only rank 0 runs the HTTP server. Commands are broadcast to all ranks via torch.distributed for coordination.
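As a rough illustration of a status query from a client, the sketch below builds an authenticated request with the standard library. The /status path and bearer-token scheme come from this page; the GET method, host, port, token value, and the build_status_request helper are assumptions for illustration, and the response format is not specified here.

```python
import urllib.request


def build_status_request(host: str, port: int, auth_token: str) -> urllib.request.Request:
    """Build (but do not send) a bearer-authenticated request for /status.

    Assumption: the status endpoint accepts a plain GET.
    """
    url = f"http://{host}:{port}/status"
    headers = {"Authorization": f"Bearer {auth_token}"}
    return urllib.request.Request(url, headers=headers, method="GET")


# Hypothetical host, port, and token values.
req = build_status_request("127.0.0.1", 8080, "example-token")
print(req.get_method(), req.full_url)

# Against a running job, the request could then be sent with:
#   with urllib.request.urlopen(req, timeout=5) as resp:
#       body = resp.read()
```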
Source code in src/forgather/ml/trainer/callbacks/control_callback.py
__init__(job_id=None, port=None, enable_http=None, host=None, auth_token=None, disable_auth=False)
Initialize the control callback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| job_id | str | Unique identifier for this training job. Auto-generated if … | None |
| port | int | HTTP server port. Auto-selected if … | None |
| enable_http | bool | Whether to enable HTTP server. Auto-detected based on … | None |
| host | str | Bind address. Defaults to … | None |
| auth_token | str | Pre-shared bearer token. Generated via … | None |
| disable_auth | bool | Skip bearer-token enforcement on /control, /status, /jobs. Off by default; the trainer logs a warning when this is set. | False |
Source code in src/forgather/ml/trainer/callbacks/control_callback.py
on_train_begin(args, state, control, **kwargs)
Initialize control system when training begins.
Source code in src/forgather/ml/trainer/callbacks/control_callback.py
on_log(args, state, control, logs=None, **kwargs)
Check for control commands on each log event.
Source code in src/forgather/ml/trainer/callbacks/control_callback.py
on_train_end(args, state, control, **kwargs)
Clean up when training ends.
Source code in src/forgather/ml/trainer/callbacks/control_callback.py
Divergence Detection
forgather.ml.trainer.callbacks.DivergenceDetector
Bases: TrainerCallback, Stateful
Detects training divergence by comparing smoothed loss against its best observed value.
Maintains a smoothed loss (EMA) and tracks its running minimum. Triggers when
the smoothed loss exceeds the baseline minimum by a configurable threshold,
sustained for patience consecutive observations.
Supports absolute threshold (smoothed - best >= threshold), relative threshold (smoothed >= best * factor), or both simultaneously (triggers on whichever fires first).
Also detects NaN/Inf loss values immediately (no patience required).
Defaults are calibrated against real training runs where loss decreases from ~10 to ~3.8 then spikes to ~9.7 on divergence. With default settings (smoothing=0.3, threshold=1.0, patience=3), divergence is detected within 3 log entries (~96 training steps at 32-step log intervals) of the spike, with zero false positives on healthy runs.
Examples:
>>> from forgather.ml.trainer.callbacks import DivergenceDetector
>>> detector = DivergenceDetector(
...     smoothing=0.3,   # EMA alpha (higher = more responsive)
...     threshold=1.0,   # Absolute: stop if smoothed - best >= 1.0
...     patience=3,      # Require 3 consecutive observations
...     action="abort",
... )
>>> trainer = Trainer(..., callbacks=[detector])
>>> trainer.train()
Using relative threshold (e.g., 50% increase from best):
>>> detector = DivergenceDetector(
...     smoothing=0.3,
...     relative_threshold=1.5,  # Stop if smoothed >= 1.5 * best
...     patience=3,
...     action="stop",
... )
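The detection rule described above (EMA smoothing, running minimum, absolute and relative thresholds, patience, immediate NaN/Inf detection) can be sketched in plain Python. ToyDivergenceRule is a hypothetical re-implementation for illustration only; in particular, how the real detector treats the EMA and baseline during warmup is an assumption here.

```python
import math


class ToyDivergenceRule:
    """Illustrative sketch of the rule; not the library's implementation."""

    def __init__(self, smoothing=0.3, threshold=1.0, relative_threshold=None,
                 patience=3, warmup=10):
        self.smoothing = smoothing
        self.threshold = threshold
        self.relative_threshold = relative_threshold
        self.patience = patience
        self.warmup = warmup
        self.smoothed = None   # EMA of the raw loss
        self.best = math.inf   # running minimum of the EMA
        self.bad_count = 0     # consecutive observations over threshold
        self.seen = 0          # finite observations so far

    def observe(self, loss: float) -> bool:
        """Return True when divergence should be reported."""
        if math.isnan(loss) or math.isinf(loss):
            return True  # NaN/Inf triggers immediately, no patience required
        self.seen += 1
        if self.smoothed is None:
            self.smoothed = loss
        else:
            a = self.smoothing
            self.smoothed = a * loss + (1.0 - a) * self.smoothed
        if self.seen <= self.warmup:
            # Assumption: warmup observations update the EMA but are not
            # checked and do not set the baseline.
            return False
        over = False
        if self.threshold is not None:
            over = over or (self.smoothed - self.best >= self.threshold)
        if self.relative_threshold is not None:
            over = over or (self.smoothed >= self.best * self.relative_threshold)
        self.bad_count = self.bad_count + 1 if over else 0
        self.best = min(self.best, self.smoothed)
        return self.bad_count >= self.patience


# Synthetic loss curve: steady improvement followed by a sustained spike.
losses = [10.0, 6.0, 4.5, 4.0, 3.8, 3.8, 9.7, 9.7, 9.7, 9.7]
rule = ToyDivergenceRule(warmup=2)  # short warmup so the demo stays small
results = [rule.observe(x) for x in losses]
first_trigger = results.index(True)
print("first trigger at observation index", first_trigger)
```

A lower smoothing value makes the EMA lag the raw loss more, which suppresses transient spikes at the cost of slower detection; patience plays a similar role at the level of whole observations.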
Source code in src/forgather/ml/trainer/callbacks/divergence_detector.py
__init__(smoothing=0.3, threshold=1.0, relative_threshold=None, patience=3, warmup=10, action='stop', use_eval_loss=False, metric_key=None)
Initialize divergence detector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| smoothing | float | EMA alpha for smoothing raw loss (0–1). Higher = more responsive to recent values. Effective window ~ 1/alpha observations. | 0.3 |
| threshold | float or None | Absolute divergence threshold. Triggers when … | 1.0 |
| relative_threshold | float or None | Relative divergence threshold. Triggers when … | None |
| patience | int | Number of consecutive observations above threshold required before triggering. Higher values reduce false positives from transient spikes. Set to … | 3 |
| warmup | int | Number of initial observations to skip before checking divergence. Avoids false positives from the high-loss early training phase. | 10 |
| action | (stop, abort) | What to do when divergence is detected. | "stop" |
| use_eval_loss | bool | If … | False |
| metric_key | str | Custom metric key to monitor (overrides … | None |
Source code in src/forgather/ml/trainer/callbacks/divergence_detector.py
on_log(args, state, control, logs=None, **kwargs)
Check for divergence when training metrics are logged.
on_evaluate(args, state, control, metrics=None, **kwargs)
Check for divergence when evaluation metrics are available.
state_dict()
Return callback state to save with checkpoint.
Source code in src/forgather/ml/trainer/callbacks/divergence_detector.py
load_state_dict(state_dict)
Restore callback state from checkpoint.
Source code in src/forgather/ml/trainer/callbacks/divergence_detector.py
Logging
forgather.ml.trainer.callbacks.JsonLogger
Bases: TrainerCallback, Stateful
A JSON logger callback that writes training metrics to a JSON file.
Writes a JSON record (with UTC timestamp, global_step, epoch, and all
reported metrics) each time on_log or on_evaluate is called.
Implements the Stateful protocol so that the log file path and last
written step are saved with checkpoints. When training resumes from a
checkpoint, the logger reopens the original file, truncates any entries
recorded after the checkpoint step, and continues appending.
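The resume behaviour described above amounts to discarding records newer than the checkpoint before appending again. A minimal sketch follows, assuming the log is stored as a JSON array of records that each carry a global_step field; the truncate_log helper, the file name, and the on-disk layout are assumptions for illustration.

```python
import json
import os
import tempfile


def truncate_log(path: str, checkpoint_step: int) -> list:
    """Drop records written after checkpoint_step and rewrite the file."""
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    kept = [r for r in records if r["global_step"] <= checkpoint_step]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(kept, f)
    return kept


with tempfile.TemporaryDirectory() as tmp:
    log_path = os.path.join(tmp, "trainer_logs.json")  # hypothetical file name
    # Simulate a run that logged up to step 128 but was checkpointed at step 64.
    original = [{"global_step": s, "loss": 5.0 - 0.01 * s} for s in (32, 64, 96, 128)]
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(original, f)
    kept = truncate_log(log_path, checkpoint_step=64)
    kept_steps = [r["global_step"] for r in kept]
    print(kept_steps)
```

Without this truncation, a resumed run would write a second set of records for steps 96 and 128, and downstream log analysis would see duplicate steps.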
Source code in src/forgather/ml/trainer/callbacks/json_logger.py
__init__(**kwargs)
The contents of kwargs will be recorded when training starts.
Source code in src/forgather/ml/trainer/callbacks/json_logger.py
forgather.ml.trainer.callbacks.TBLogger
Bases: TrainerCallback
A Trainer callback that logs scalars to TensorBoard.
Scalars are configured as a dict mapping TensorBoard tags to spec
dicts with optional source and transform fields. The dict
is merged with default_tb_scalars() so only deltas need to be
specified. Set a key to None to erase a default scalar.
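The merge rule described above (overrides layered on defaults, with None erasing an entry) can be sketched as below. The default tags, the merge_specs helper, and the idea that transform is a callable are assumptions for illustration; the actual contents of default_tb_scalars() are not shown on this page, and the real merge may be deeper than this per-tag replacement.

```python
import math


def merge_specs(defaults: dict, overrides=None) -> dict:
    """Return defaults updated by overrides; a None value removes the tag."""
    merged = dict(defaults)  # copy so the defaults are not mutated
    for tag, spec in (overrides or {}).items():
        if spec is None:
            merged.pop(tag, None)
        else:
            merged[tag] = spec
    return merged


# Hypothetical defaults standing in for default_tb_scalars().
hypothetical_defaults = {
    "train/loss": {"source": "loss"},
    "train/grad_norm": {"source": "grad_norm"},
}

scalars = merge_specs(
    hypothetical_defaults,
    {
        "train/grad_norm": None,  # erase a default scalar
        "train/perplexity": {"source": "loss", "transform": math.exp},  # add one
    },
)
print(sorted(scalars))
```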
Source code in src/forgather/ml/trainer/callbacks/tb_logger.py
mapping_as_markdown(mapping) (staticmethod)
Format a dictionary as Markdown.
TensorBoard expects text to be in Markdown format...
Source code in src/forgather/ml/trainer/callbacks/tb_logger.py
forgather.ml.trainer.callbacks.GradNormLogger
Bases: TrainerCallback, Stateful
Logs per-parameter gradient L2 norms to a JSON file.
Gradient norms are captured in on_pre_optimizer_step (after gradient
clipping, before optimizer step and zero_grad) and written to the log
file in on_evaluate. This means gradient data is logged at eval
frequency, keeping overhead minimal.
The log file uses JSON array format with checkpoint resume support via the Stateful protocol.
When fuse_optim_with_backward is enabled, gradients are consumed
during the backward pass and are not available for capture. The callback
detects this and disables itself with a warning.
Source code in src/forgather/ml/trainer/callbacks/grad_logger.py
on_pre_optimizer_step(args, state, control, **kwargs)
Capture per-parameter gradient norms before optimizer step.
Source code in src/forgather/ml/trainer/callbacks/grad_logger.py
on_evaluate(args, state, control, **kwargs)
Write buffered gradient norms to log file.
Source code in src/forgather/ml/trainer/callbacks/grad_logger.py
forgather.ml.trainer.callbacks.ParameterNormLogger
Bases: TrainerCallback, Stateful
Logs per-parameter L2 norms and/or spectral norms to a JSON file.
Data is written on each evaluation step. The log file uses JSON array format with checkpoint resume support via the Stateful protocol.
The existing WeightNormLogger continues to handle the total
parameter norm for TensorBoard/console logging. This callback provides
the per-parameter breakdown for diagnostic analysis and heatmap
visualization.
In pipeline-parallel training the model shell passed to callbacks contains only meta-device tensors. This callback detects that case, warns once, and skips logging for the remainder of training.
Source code in src/forgather/ml/trainer/callbacks/parameter_norm_logger.py
__init__(log_norms=True, log_spectral_norms=True, power_iter_steps=10)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_norms | bool | Whether to log per-parameter L2 norms. | True |
| log_spectral_norms | bool | Whether to log per-parameter spectral norms. | True |
| power_iter_steps | int | Number of power iteration steps for spectral norm estimation. First evaluation uses 2x this value for cold-start convergence. | 10 |
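For readers unfamiliar with power iteration, the sketch below estimates a matrix's spectral norm (largest singular value) in pure Python and shows why a warm start needs fewer steps than a cold start. The helper names, the list-of-lists matrix representation, and the deterministic start vector are assumptions for illustration; the callback itself operates on model parameters and may differ in detail.

```python
import math


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def spectral_norm(w, steps=10, v=None):
    """Estimate the largest singular value of w by power iteration on w^T w.

    Returns (sigma, v) so that v can be reused as a warm start next time.
    """
    wt = transpose(w)
    if v is None:
        n = len(w[0])
        v = [1.0 / math.sqrt(n)] * n  # deterministic start; often random in practice
    for _ in range(steps):
        v = matvec(wt, matvec(w, v))
        scale = norm(v)
        if scale == 0.0:
            return 0.0, v
        v = [x / scale for x in v]
    return norm(matvec(w, v)), v


w = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
sigma_cold, v = spectral_norm(w, steps=20)       # cold start: twice the usual steps
sigma_warm, _ = spectral_norm(w, steps=10, v=v)  # later calls reuse the vector
print(sigma_cold, sigma_warm)
```

Reusing the previous vector works because weights change slowly between evaluations, so the dominant singular direction from the last estimate is already close to the current one.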
Source code in src/forgather/ml/trainer/callbacks/parameter_norm_logger.py
forgather.ml.trainer.callbacks.WeightNormLogger
Bases: TrainerCallback
Logs the total L2 norm of all model parameters to logs after each evaluation step.
Computed identically to the gradient norm but using the weight tensors themselves. A growing value over training indicates that weights are increasing in magnitude, which usually means weight decay is too weak. A stable or shrinking value while gradient norms rise points to a different cause.
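The total norm referred to above is the L2 norm of all parameters treated as one long vector, which equals the square root of the sum of squared per-tensor norms. A minimal sketch, assuming flat lists of floats stand in for weight tensors (total_l2_norm is a hypothetical helper):

```python
import math


def total_l2_norm(tensors) -> float:
    """L2 norm over every element of every tensor in the collection."""
    return math.sqrt(sum(x * x for t in tensors for x in t))


# Two stand-in "parameters" with per-tensor norms 3 and 4.
params = [[3.0, 0.0], [0.0, 4.0]]
weight_norm = total_l2_norm(params)
print(weight_norm)
```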
In pipeline-parallel training the model shell passed to callbacks contains only meta-device tensors. This callback detects that case, warns once, and skips logging for the remainder of training.
Source code in src/forgather/ml/trainer/callbacks/weight_norm_logger.py
forgather.ml.trainer.callbacks.PeakMemory
Bases: TrainerCallback
PeakMemory is a TrainerCallback for monitoring and logging the peak CUDA memory usage during model training. This callback is designed to help diagnose and optimize GPU memory consumption in PyTorch-based training loops, especially when using distributed training. It records the maximum memory allocated on each GPU device throughout the training process, and can optionally log detailed memory statistics and write them to TensorBoard for visualization.
IMPORTANT: Memory history recording is disabled by default to prevent memory leaks. The torch.cuda.memory._record_memory_history feature can consume 1GB+ of memory during training.
Key Features:
- Tracks the peak CUDA memory allocated on each GPU during training.
- Supports both single-GPU and multi-GPU (distributed) training environments.
- Optionally logs detailed CUDA memory statistics for further analysis.
- Can write memory usage metrics to a TensorBoard SummaryWriter for visualization.
- Provides configurable logging frequency and verbosity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| summary_writer | SummaryWriter | TensorBoard … | None |
| show_details | bool | If … | False |
| do_log | bool | If … | False |
| enable_memory_snapshot | bool | If … | False |
| file_prefix | str | Filename prefix for the per-rank memory snapshot pickle. Defaults to … | 'memory_snapshot' |
Attributes:
| Name | Type | Description |
|---|---|---|
| rank | int | The process rank in distributed training. |
| world_size | int | The total number of processes in distributed training. |
| summary_writer | SummaryWriter or None | The TensorBoard … |
| enabled | bool | Whether CUDA is available and memory tracking is enabled. |
| show_details | bool | Whether to log detailed memory statistics. |
| do_log | bool | Whether to log memory usage on each log step. |
| enable_memory_snapshot | bool | Whether memory history recording and snapshot dumping is enabled. |
| max_allocated | int | The maximum CUDA memory allocated during training (in bytes). |
Source code in src/forgather/ml/trainer/callbacks/peak_memory.py
__init__(summary_writer=None, show_details=False, do_log=False, enable_memory_snapshot=False, file_prefix='memory_snapshot')
- summary_writer: Optional TensorBoard SummaryWriter to log peak memory
- show_details: Whether to log detailed memory stats
- do_log: Whether to log on each log step
- enable_memory_snapshot: Whether to enable CUDA memory history recording and snapshot
Source code in src/forgather/ml/trainer/callbacks/peak_memory.py
Text Generation
forgather.ml.trainer.callbacks.TextgenCallback
Bases: TrainerCallback
Periodically generate and log text from a set of prompts for subjective model evaluation.
Automatically dispatches between single-rank generation (via model.generate()) and pipeline-parallel generation (via trainer.pipeline_generate()) based on whether the trainer exposes a pipeline_generate method. The same callback works unchanged with SimpleTrainer, AccelTrainer/DDPTrainer, and PipelineTrainer.
Source code in src/forgather/ml/trainer/callbacks/textgen_callback.py
__init__(summary_writer, prompts, generation_config=None, generation_steps=None, max_new_tokens=200)
Periodically generates and logs text from a set of prompts for subjective model evaluation.
This may only trigger on model evaluation steps, which establishes the minimum interval between generations.
Args:
- summary_writer: The TensorBoard SummaryWriter to log to.
- prompts: Either a list of prompts (List[str]) or a path to a YAML file defining a list of prompts.
- generation_config: A dictionary of arguments for HF GenerationConfig.
- generation_steps: The number of steps between generations. If None, it defaults to eval_steps.
- max_new_tokens: The maximum number of new tokens to generate for each prompt.
Source code in src/forgather/ml/trainer/callbacks/textgen_callback.py
Advanced
forgather.ml.trainer.callbacks.ProfilerCallback
Profiles training steps and exports Chrome traces + summary tables.
Source code in src/forgather/ml/trainer/callbacks/profiler_callback.py
forgather.ml.trainer.callbacks.DiLoCoCallback
Bases: TrainerCallback
Trainer callback that manages a DiLoCoWorker for distributed local-SGD training.
Implements the Stateful protocol for checkpoint persistence. The checkpoint manager auto-discovers Stateful callbacks and saves/restores their state.
When server_addr is empty (and DILOCO_SERVER is unset), all methods are
no-ops. This allows a single training configuration to work both with and
without a DiLoCo server.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| server_addr | str | DiLoCo server address ( … | None |
| sync_every | int | Local optimizer steps between syncs. Falls back to … | None |
| worker_id | str | Unique worker ID. Falls back to … | None |
| bf16_comm | bool | Cast pseudo-gradients to bfloat16. Falls back to … | None |
| dylu | bool | Enable Dynamic Local Updates. Falls back to … | None |
| heartbeat_interval | float | Seconds between heartbeats. Falls back to … | None |
| num_fragments | int | Number of streaming fragments. Falls back to … | None |
| timeout | float | Client timeout in seconds. Default … | 600 |
| max_sync_retries | int | Max retries for sync failures. Default … | 3 |
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
active (property)
Whether DiLoCo integration is configured (server_addr is set).
on_train_begin(args, state, control, **kwargs)
Create and start the DiLoCoWorker.
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
on_log(args, state, control, logs=None, **kwargs)
Inject DiLoCo sync metrics into the logs dict.
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
on_train_end(args, state, control, **kwargs)
Stop the DiLoCoWorker.
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
state_dict()
Save DiLoCo state for checkpointing.
Does NOT save global_params snapshot -- the server provides fresh params when the worker re-registers on resume.
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
load_state_dict(state_dict)
Defer state restoration until on_train_begin.
Checkpoint loading happens during _prepare() before on_train_begin, so the worker doesn't exist yet. We store the state and apply it once the worker is created.
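The deferral pattern described above can be sketched independently of DiLoCo. ToyWorker and DeferredRestoreCallback are hypothetical stand-ins, and the sync_count field is an invented piece of state for the demo; only the ordering (load_state_dict first, worker creation in on_train_begin) comes from the text.

```python
class ToyWorker:
    """Hypothetical stand-in for a worker that only exists during training."""

    def __init__(self):
        self.sync_count = 0

    def state_dict(self):
        return {"sync_count": self.sync_count}

    def load_state_dict(self, state):
        self.sync_count = state["sync_count"]


class DeferredRestoreCallback:
    """Stores checkpoint state until the worker has been created."""

    def __init__(self):
        self.worker = None
        self._pending_state = None

    def load_state_dict(self, state_dict):
        # Called during checkpoint loading, before the worker exists.
        self._pending_state = state_dict

    def on_train_begin(self, args=None, state=None, control=None, **kwargs):
        self.worker = ToyWorker()
        if self._pending_state is not None:
            self.worker.load_state_dict(self._pending_state)
            self._pending_state = None  # apply at most once

    def state_dict(self):
        return self.worker.state_dict() if self.worker is not None else {}


callback = DeferredRestoreCallback()
callback.load_state_dict({"sync_count": 7})  # resume: state arrives first
worker_before = callback.worker              # still None at this point
callback.on_train_begin()                    # worker is created, state applied
print(callback.worker.sync_count)
```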
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
forgather.ml.trainer.callbacks.ResumableSummaryWriter
Bases: TrainerCallback, Stateful
A lazy, resumable wrapper around TensorBoard SummaryWriter.
When registered as a callback, it persists the active logging directory
in checkpoint metadata via the Stateful protocol. On resume from a
checkpoint, it redirects logging to the original directory and uses
SummaryWriter's purge_step to discard stale events recorded after
the checkpoint step.
When used as a SummaryWriter (passed to TBLogger, GradLogger, etc.), it proxies method calls to the underlying writer, constructing it lazily on first use.
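The lazy proxying behaviour can be sketched with __getattr__. LazyProxy and FakeWriter are hypothetical names; FakeWriter stands in for TensorBoard's SummaryWriter so the example runs without extra dependencies, and the real class additionally handles checkpoint state and purge_step.

```python
class FakeWriter:
    """Stand-in for a SummaryWriter; counts how many times it is constructed."""

    created = 0

    def __init__(self, log_dir):
        FakeWriter.created += 1
        self.log_dir = log_dir
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))


class LazyProxy:
    """Forwards attribute access to a target that is built on first use."""

    def __init__(self, factory):
        self._factory = factory
        self._target = None

    def __getattr__(self, attr):
        # __getattr__ runs only when normal lookup fails, i.e. for proxied names.
        if attr.startswith("_"):
            raise AttributeError(attr)  # avoid recursion on private attributes
        if self._target is None:
            self._target = self._factory()
        return getattr(self._target, attr)


writer = LazyProxy(lambda: FakeWriter("runs/demo"))  # hypothetical log directory
created_before_use = FakeWriter.created              # nothing constructed yet
writer.add_scalar("train/loss", 2.5, 1)              # first use builds the writer
writer.add_scalar("train/loss", 2.4, 2)              # later calls reuse it
print(created_before_use, FakeWriter.created, len(writer.scalars))
```

Deferring construction matters for resume: the logging directory recorded in the checkpoint can be restored before any event file is opened.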
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_dir | str | Logging directory path (typically … | required |
Source code in src/forgather/ml/trainer/callbacks/resumable_summary_writer.py