Training Performance Metrics
Forgather tracks token throughput and estimated FLOPs during training, reporting
both per-interval speed metrics in the console and cumulative totals in the final
training output and trainer_logs.json.
Overview
The trainer automatically:
- Counts non-padding tokens processed each step, using the cross-entropy ignore_index=-100 in labels as the mask, so padding and special tokens are excluded.
- Estimates FLOPs per token from the model's trainable parameter count using the standard transformer approximation: 18 × num_params per token (6N forward + 12N backward).
- Accumulates both counts into state.num_input_tokens_seen and state.total_flos.
- Synchronizes counts across distributed processes at each log step (not every step, to minimize communication overhead).
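The per-step accounting above can be sketched in plain Python. This is a minimal illustration, not Forgather's implementation: count_step_tokens and estimate_step_flos are hypothetical helper names, labels are nested lists rather than tensors (with PyTorch the count would be (labels != -100).sum()), and the cross-rank synchronization at log steps is omitted.

```python
# Hypothetical helpers illustrating the per-step token and FLOP accounting.
IGNORE_INDEX = -100  # cross-entropy ignore_index used to mask label positions


def count_step_tokens(labels):
    """Count label positions that contribute to the loss (label != -100)."""
    return sum(1 for row in labels for tok in row if tok != IGNORE_INDEX)


def estimate_step_flos(num_tokens, num_trainable_params):
    """18 x num_params FLOPs per token: 6N forward + 12N backward."""
    return 18 * num_trainable_params * num_tokens


# Two sequences padded to length 5; -100 marks padding / masked positions.
labels = [
    [11, 42, 7, -100, -100],
    [5, 9, 13, 21, -100],
]
tokens = count_step_tokens(labels)            # 7
flos = estimate_step_flos(tokens, 1_000_000)  # 126_000_000

# Running totals, analogous to state.num_input_tokens_seen / state.total_flos.
num_input_tokens_seen = 0
total_flos = 0
num_input_tokens_seen += tokens
total_flos += flos
```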
Final training metrics
At the end of training, the following metrics are added to the output dict and logged:
| Metric | Description |
|---|---|
| total_tokens | Total non-padding tokens processed (from state.num_input_tokens_seen) |
| tokens_per_second | Tokens / total runtime (after warmup) |
| total_flops | Estimated total FLOPs (from state.total_flos) |
| flops_per_second | Estimated total FLOPs / total runtime |
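The final metrics are simple ratios of the cumulative totals. A minimal sketch, assuming runtime_seconds is whichever runtime the trainer divides by (the table notes that tokens_per_second uses the runtime after warmup); final_speed_metrics is a hypothetical helper, not a Forgather function:

```python
def final_speed_metrics(num_input_tokens_seen: int, total_flos: float,
                        runtime_seconds: float) -> dict:
    """Derive final throughput metrics from cumulative totals (hypothetical helper)."""
    if runtime_seconds <= 0:
        raise ValueError("runtime_seconds must be positive")
    return {
        "total_tokens": num_input_tokens_seen,
        "tokens_per_second": num_input_tokens_seen / runtime_seconds,
        "total_flops": total_flos,
        "flops_per_second": total_flos / runtime_seconds,
    }


# Example: 1M tokens on a 125M-parameter model (18 x N FLOPs per token) in 50 s.
metrics = final_speed_metrics(1_000_000, 18 * 125_000_000 * 1_000_000, 50.0)
```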
Per-interval metrics (ProgressCallback)
Two per-interval speed metrics appear in the ProgressCallback display. They are computed by DefaultMetrics (see Callback configuration below) and differ in how they measure time:
- tok/s (token throughput): Uses wall-clock time between log steps, capturing real end-to-end throughput including optimizer updates, data loading, gradient synchronization, and all other overhead. This gives an accurate picture of actual training speed and is useful for comparing different optimizers or configurations.
- MFU (Model FLOPs Utilization): Uses accumulated pure training step time (forward + backward pass only, from on_step_begin to on_step_end), excluding evaluation, optimizer, and data loading time. This measures how efficiently the hardware is utilized during the compute-bound portion of training.
Both are display-only; they are not written to trainer_logs.json. The underlying
token and FLOP values in trainer_logs.json can be used to reproduce these
calculations offline.
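The two interval metrics can be reproduced offline from consecutive cumulative counts. This is a sketch under stated assumptions: the helper names are hypothetical, the cumulative token and FLOP counts are taken from two consecutive log entries, and the timing inputs (wall-clock seconds between the entries, plus accumulated step time for MFU) are assumed to be recorded separately, since they may not be present in trainer_logs.json.

```python
def interval_tok_per_sec(tokens_prev, tokens_now, wall_prev, wall_now):
    """tok/s: token delta divided by wall-clock time between two log steps."""
    return (tokens_now - tokens_prev) / (wall_now - wall_prev)


def interval_mfu(flos_prev, flos_now, step_time_seconds, peak_hardware_flops):
    """MFU: achieved FLOP/s over pure step time, divided by aggregate peak FLOP/s."""
    achieved = (flos_now - flos_prev) / step_time_seconds
    return achieved / peak_hardware_flops


# Hypothetical interval: 262,144 tokens on a 100M-parameter model.
num_params = 100_000_000
tokens_prev, tokens_now = 1_000_000, 1_262_144
flos_prev = 18 * num_params * tokens_prev
flos_now = 18 * num_params * tokens_now

tok_s = interval_tok_per_sec(tokens_prev, tokens_now, wall_prev=100.0, wall_now=110.0)
mfu = interval_mfu(flos_prev, flos_now, step_time_seconds=8.0,
                   peak_hardware_flops=4 * 165.2e12)  # 4x RTX 4090
```

Because the MFU denominator excludes optimizer and data-loading time, dividing the same FLOP delta by wall-clock time would give a lower figure; the two metrics answer different questions.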
Callback configuration
Performance metrics are split across two callbacks that work together:
- DefaultMetrics computes derived metrics (tok_per_sec, mfu, peak_mem) during on_log_step, before other callbacks see the log entry.
- ProgressCallback formats and displays the console output, using column specifications to control which metrics appear.
Both are included by default in Forgather trainers.
DefaultMetrics
from forgather.ml.trainer.callbacks import DefaultMetrics
callbacks = [
DefaultMetrics(
peak_hardware_flops=4 * 165.2e12, # 4× RTX 4090, for MFU display
),
]
| Parameter | Default | Description |
|---|---|---|
| peak_hardware_flops | None | Aggregate peak BF16 FLOP/s across all GPUs; enables MFU display |
When peak_hardware_flops is set, DefaultMetrics computes MFU each log step. Token
throughput (tok_per_sec) is always computed when token counts are available in the logs.
ProgressCallback
ProgressCallback controls the console display. It does not compute metrics itself;
it renders whatever metrics are present in the log entry (including those injected by
DefaultMetrics).
from forgather.ml.trainer.callbacks import ProgressCallback
callbacks = [
ProgressCallback(
use_tqdm=False, # Use line-based logging instead of TQDM
header_interval=20, # Print column headers every 20 log steps
step_columns={"peak_mem": {"width": 32, "reduce": "all"}},  # Override default columns (see below)
final_metrics=None,  # Dict of final-summary overrides; None keeps the defaults
),
]
| Parameter | Default | Description |
|---|---|---|
| use_tqdm | None (auto) | True for TQDM progress bar, False for line-based logging, None to auto-detect |
| output_stream | None | Output stream for line-based logging ("stdout", "stderr", or a TextIOBase) |
| step_columns | None | Dict of column spec overrides, merged with defaults. Set a key to None to remove it |
| final_metrics | None | Dict of final metric spec overrides, merged with defaults |
| header_interval | 20 | Print column headers every N log steps |
Columns are displayed only when the corresponding metric key appears in the current log
entry. The default columns include loss, learning_rate, grad_norm, tok_per_sec,
mfu, and peak_mem. Override step_columns to customize which metrics are shown and
their formatting.
Customizing the progress display
Each column in the step-log table is described by a ColumnSpec with five fields:
| Field | Type | Description |
|---|---|---|
| key | str | Metric key in the log entry (e.g. "loss", "tok_per_sec"). |
| label | str | Column header text. Defaults to key when empty. |
| width | int | Fixed column width in characters. |
| fmt | str or callable | How to format a scalar value (see below). |
| reduce | str, callable, or None | How to render list-valued metrics (see below). |
step_columns accepts a dict of {key: spec_overrides}; values are shallow-merged with
default_step_columns(). Setting a key to None removes that column from the defaults.
The merged result is turned into a list of ColumnSpec objects; column order follows
insertion order.
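The merge rules can be illustrated with a short sketch. It is not Forgather's code: the ColumnSpec dataclass below only mirrors the five documented fields, the default columns are made-up values rather than the real output of default_step_columns(), and falling back to the dict key when a spec has no key field is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union


@dataclass
class ColumnSpec:  # stand-in mirroring the documented fields, not the real class
    key: str
    label: str = ""
    width: int = 10
    fmt: Union[str, Callable[[Any], str]] = ""
    reduce: Union[str, Callable[[list], Any], None] = None


def merge_step_columns(defaults: dict, overrides: Optional[dict]) -> list:
    merged = {name: dict(spec) for name, spec in defaults.items()}  # copy defaults
    for name, spec in (overrides or {}).items():
        if spec is None:
            merged.pop(name, None)       # None removes a default column
        elif name in merged:
            merged[name].update(spec)    # shallow merge; column keeps its position
        else:
            merged[name] = dict(spec)    # new columns are appended at the end
    # Assumption: a spec without "key" looks up the metric named by its dict key.
    return [ColumnSpec(**{"key": name, **spec}) for name, spec in merged.items()]


# Illustrative defaults only; not the real default_step_columns() output.
defaults = {
    "loss": {"key": "loss", "width": 10, "fmt": ".5f"},
    "peak_mem": {"key": "peak_mem_allocated", "label": "peak_mem", "width": 11,
                 "fmt": "gib", "reduce": "max"},
}
columns = merge_step_columns(defaults, {
    "loss": {"width": 12},
    "peak_mem": None,
    "peak_mem_max": {"key": "peak_mem_allocated", "label": "max_mem",
                     "fmt": "gib", "reduce": "max"},
})
```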
The fmt field
fmt controls scalar formatting. It accepts:
- A Python format-spec string, applied via format(value, spec). Integer presentation types (d, o, x, …) auto-convert the value to int first.
"loss": {"fmt": ".5f"} # 2.34567
"learning_rate": {"fmt": ".2e"} # 1.00e-04
"tokens": {"fmt": ",d"} # 8,192
"mfu": {"fmt": ".1%"} # 42.0%
- A named formatter alias: shorthand for common unit-aware formatters.

| Alias | Behavior | Example |
|---|---|---|
| "si" | SI prefixes (K, M, G, …) | 8.19M |
| "gib" | Binary gibibytes | 1.863 GiB |
- A callable: any Callable[[Any], str] is invoked directly with the value.
- An empty string: type-based fallback (float → .4g, int → comma-separated, otherwise str()).
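A rough sketch of these fmt rules as a standalone function. It is not forgather.ml.trainer.logging.format_value(); format_scalar is a hypothetical name, and the precision used for the "si" and "gib" aliases (two and three decimals) is an assumption inferred from the examples above.

```python
from typing import Any, Callable, Union

_INT_PRESENTATION = set("bcdoxX")  # format-spec types that require an int


def _si(value: Any) -> str:
    """SI prefixes, e.g. 8_192_000 -> '8.19M' (two decimals is an assumption)."""
    value = float(value)
    for prefix in ("", "K", "M", "G", "T"):
        if abs(value) < 1000:
            return f"{value:.2f}{prefix}"
        value /= 1000
    return f"{value:.2f}P"


def _gib(value: Any) -> str:
    """Binary gibibytes, e.g. 2_000_000_000 -> '1.863 GiB'."""
    return f"{value / 2**30:.3f} GiB"


_ALIASES = {"si": _si, "gib": _gib}


def format_scalar(value: Any, fmt: Union[str, Callable[[Any], str]] = "") -> str:
    if callable(fmt):
        return fmt(value)                  # callable: used as-is
    if fmt in _ALIASES:
        return _ALIASES[fmt](value)        # named alias
    if fmt == "":                          # type-based fallback
        if isinstance(value, float):
            return format(value, ".4g")
        if isinstance(value, int) and not isinstance(value, bool):
            return f"{value:,}"
        return str(value)
    if fmt[-1] in _INT_PRESENTATION:
        value = int(value)                 # e.g. ",d" applied to 8192.0
    return format(value, fmt)
```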
The reduce field
Some metrics arrive as a per-rank list rather than a scalar. peak_mem_allocated, for
example, is captured on every rank inside Trainer._log_step and stored in the log
entry as a list of per-rank bytes (length equal to world size; length 1 for single-GPU
runs). reduce controls how such a list is rendered in a fixed-width column:
| Value | Behavior |
|---|---|
| None (default) | Scalars pass through unchanged; lists fall back to an implicit max reduction. |
| "max" / "min" / "mean" / "sum" | Reduce the list to a scalar, then format via fmt. |
| "all" | Format each element with fmt and join with "/" (per-rank display; may overflow width for large world sizes). |
| Callable[[list], Any] | Apply the callable to the list, then format the result via fmt. |
reduce is silently ignored for scalar values unless it is a callable, in which case
the callable is still applied.
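The reduce rules can be sketched the same way. This is a hypothetical standalone render function, not Forgather's implementation; it takes the scalar formatter as a plain callable (here a three-decimal GiB formatter matching the examples on this page) so the reduction logic stands on its own.

```python
import statistics
from typing import Any, Callable, Union

_REDUCERS = {"max": max, "min": min, "mean": statistics.fmean, "sum": sum}


def render(value: Any, fmt: Callable[[Any], str],
           reduce: Union[str, Callable[[list], Any], None] = None) -> str:
    if not isinstance(value, (list, tuple)):
        # Scalars: string reductions are ignored, callables are still applied.
        return fmt(reduce(value) if callable(reduce) else value)
    if reduce is None:
        return fmt(max(value))                  # implicit max for lists
    if callable(reduce):
        return fmt(reduce(value))               # custom reduction
    if reduce == "all":
        return "/".join(fmt(v) for v in value)  # per-rank display
    return fmt(_REDUCERS[reduce](value))        # "max", "min", "mean", "sum"


def gib(num_bytes):
    return f"{num_bytes / 2**30:.3f} GiB"


peaks = [2_000_000_000, 2_147_483_648]  # per-rank bytes for a 2-rank run
```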
Example: show peak memory from all ranks
The default peak_mem column reduces the per-rank list with max, so the progress
display stays compact.
To show every rank's peak instead, override reduce to "all":
from forgather.ml.trainer.callbacks import ProgressCallback
callbacks = [
ProgressCallback(
step_columns={
"peak_mem": {"width": 32, "reduce": "all"},
},
),
]
Using the configuration syntax:
[step_columns]
.define: &step_columns !dict
peak_mem: {"width": 32, "reduce": "all"}
[callback_list]
trainer_callbacks: &trainer_callbacks !dlist:@trainer_callbacks
progress_callback: !singleton:forgather.ml.trainer.callbacks:ProgressCallback
step_columns: *step_columns
With a 2-GPU DDP run, the column shows both ranks' peaks, each formatted with the column's fmt and joined by "/".
Widen width to fit your world size; per-rank display scales linearly with the number
of ranks, so plan on roughly len("X.XXX GiB/") × world_size characters.
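A quick way to size the column is to measure the joined string directly. A minimal sketch, assuming the "gib" formatter renders three decimals as in the examples above; the helper names are hypothetical. Peaks of 10 GiB or more take one extra character per rank, so the rule of thumb above assumes single-digit GiB values.

```python
GiB = 2**30


def rendered_width(peaks_bytes):
    """Length of the per-rank string that reduce="all" would produce."""
    return len("/".join(f"{b / GiB:.3f} GiB" for b in peaks_bytes))


def rule_of_thumb_width(world_size):
    return len("X.XXX GiB/") * world_size


two_ranks = rendered_width([1.9 * GiB, 2.0 * GiB])    # "1.900 GiB/2.000 GiB"
eight_ranks_large = rendered_width([23.5 * GiB] * 8)  # "23.500 GiB" repeated 8 times
```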
Example: show min and max peak memory in separate columns
Point two columns at the same log key and give each its own reduction:
callbacks = [
ProgressCallback(
step_columns={
# Remove the default peak_mem entry.
"peak_mem": None,
# Add two new columns keyed off the raw trainer metric.
"peak_mem_min": {
"key": "peak_mem_allocated",
"label": "min_mem",
"width": 11,
"fmt": "gib",
"reduce": "min",
},
"peak_mem_max": {
"key": "peak_mem_allocated",
"label": "max_mem",
"width": 11,
"fmt": "gib",
"reduce": "max",
},
},
),
]
Note that the top-level dict key (peak_mem_min) is only used for merging and ordering;
the actual lookup into the log entry uses the nested key field.
Example: gap between max and min as a custom reduction
A callable reduce can return any scalar, which is then formatted via fmt:
callbacks = [
ProgressCallback(
step_columns={
"peak_mem_gap": {
"key": "peak_mem_allocated",
"label": "mem_gap",
"width": 11,
"fmt": "gib",
"reduce": lambda xs: max(xs) - min(xs),
},
},
),
]
This surfaces per-rank memory imbalance — useful for spotting a straggler stage in pipeline-parallel training.
What lands in trainer_logs.json
JsonLogger serializes the log entry with json.dumps, so list-valued metrics land in
trainer_logs.json as native JSON arrays.
This format is used for single-GPU runs as well (a 1-element array), so any downstream
analysis tool that reads peak_mem_allocated must handle lists. The
forgather.ml.trainer.logging.format_value() helper is a convenient way to reproduce
the progress-display formatting offline — it accepts the same (value, fmt, reduce)
arguments used by ColumnSpec.
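Downstream code can normalize the field before analysis. A minimal sketch: the entries below are illustrative JSON objects containing only peak_mem_allocated (real entries carry other fields), the helper names are hypothetical, and accepting a bare scalar is a defensive assumption rather than documented behavior.

```python
import json


def peak_mem_per_rank(entry: dict) -> list:
    """Return peak_mem_allocated as a list of per-rank byte counts."""
    value = entry.get("peak_mem_allocated")
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]  # defensive: tolerate a scalar value


def mem_imbalance_gib(entry: dict) -> float:
    """Gap between the largest and smallest per-rank peak, in GiB."""
    ranks = peak_mem_per_rank(entry)
    return (max(ranks) - min(ranks)) / 2**30 if ranks else 0.0


entry = json.loads('{"peak_mem_allocated": [2147483648, 3221225472]}')
single = json.loads('{"peak_mem_allocated": [2147483648]}')
```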
Setting peak_hardware_flops
peak_hardware_flops must be the aggregate peak FLOP/s across all GPUs used in
the training job. The trainer accumulates total_flos by counting tokens across all
processes (via all_reduce), so achieved_flops = delta_flos / elapsed is the total
rate for the entire job, not per-GPU.
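For a 4-GPU job on RTX 4090s, set peak_hardware_flops to 4 × 165.2e12 = 660.8e12 (the per-GPU figure comes from the reference table below).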
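The aggregation rule is easy to get wrong, so here is a small sketch. The per-GPU figures are copied from the reference table on this page; the dictionary and helper names are hypothetical. Because delta_flos is already summed over all ranks, dividing by a single GPU's peak overstates MFU by a factor of world_size.

```python
PER_GPU_PEAK_BF16 = {  # dense BF16, FP32 accumulate (from the reference table)
    "RTX 4090": 165.2e12,
    "RTX 5090": 209.5e12,
    "A100 SXM": 312e12,
    "H100 SXM": 989e12,
}


def aggregate_peak_flops(gpu: str, world_size: int) -> float:
    return PER_GPU_PEAK_BF16[gpu] * world_size


def mfu(delta_flos: float, elapsed_seconds: float, peak_hardware_flops: float) -> float:
    return (delta_flos / elapsed_seconds) / peak_hardware_flops


peak = aggregate_peak_flops("RTX 4090", 4)   # 660.8e12
job_flos = 2.0e14                            # FLOPs summed across all 4 ranks in 1 s
correct = mfu(job_flos, 1.0, peak)           # about 0.30
inflated = mfu(job_flos, 1.0, PER_GPU_PEAK_BF16["RTX 4090"])  # 4x too high
```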
Automatic GPU detection
The project templates (lm_training_project.yaml, tiny.yaml, etc.) use the
get_peak_hardware_flops() preprocessor function to auto-detect the per-GPU
peak BF16 FLOP/s. The function follows this resolution order:
- Read ~/.config/forgather/hardware.yaml: if the file exists and contains a peak_hardware_flops value, use it immediately.
- Detect the current GPU via torch.cuda.get_device_name() and look it up in the built-in reference table (the same table shown below).
- Cache the result: write the detected value to ~/.config/forgather/hardware.yaml so subsequent runs skip detection.
The multi-GPU templates (e.g. lm_training_project.yaml) multiply the per-GPU
value by world_size automatically, so the auto-detected value is always
per-device.
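The resolution order can be sketched as below. This is a simplified stand-in for get_peak_hardware_flops(), not its real code: the device name is passed in instead of calling torch.cuda.get_device_name(), the cache is read and written as a single "peak_hardware_flops: value" line rather than parsed with a YAML library, the lookup table is a small subset of the reference table, and longest-substring matching is an assumption.

```python
import tempfile
from pathlib import Path
from typing import Optional

# Subset of the per-GPU reference table (dense BF16, FP32 accumulate).
REFERENCE_TABLE = {
    "RTX 4090": 165.2e12,
    "RTX 4080": 97.0e12,
    "RTX 4080 SUPER": 104.4e12,
}


def read_cached(path: Path) -> Optional[float]:
    """Return the cached per-GPU value, or None if absent."""
    if not path.exists():
        return None
    for line in path.read_text().splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "peak_hardware_flops" and value.strip() not in ("", "null"):
            return float(value)
    return None


def resolve_peak_flops(device_name: str, cache_path: Path) -> Optional[float]:
    cached = read_cached(cache_path)
    if cached is not None:
        return cached                                    # 1. cached value wins
    # 2. Longest matching name first, so "RTX 4080 SUPER" beats "RTX 4080".
    for gpu in sorted(REFERENCE_TABLE, key=len, reverse=True):
        if gpu in device_name:
            peak = REFERENCE_TABLE[gpu]
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            cache_path.write_text(f"peak_hardware_flops: {peak!r}\n")  # 3. cache it
            return peak
    return None                                          # unknown GPU: MFU disabled


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "forgather" / "hardware.yaml"
    first = resolve_peak_flops("NVIDIA GeForce RTX 4080 SUPER", cache)
    second = resolve_peak_flops("Some Unknown GPU", cache)  # served from the cache
    unknown = resolve_peak_flops("Some Unknown GPU", Path(tmp) / "empty.yaml")
```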
If the GPU is not in the reference table, the function returns null and MFU
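is disabled. In that case, create ~/.config/forgather/hardware.yaml manually with a per-GPU peak_hardware_flops value (for example, the figure for your GPU from the reference table below).
To re-trigger auto-detection, delete ~/.config/forgather/hardware.yaml.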
The --peak-hardware-flops CLI argument still overrides everything -- the
auto-detected value is only used as a default when neither the CLI argument
nor a template override is specified.
Peak BF16 FLOP/s reference table
The figures below are the dense BF16 Tensor Core numbers with FP32 accumulation, which is what PyTorch uses in mixed-precision (autocast BF16) training. This is the correct figure for MFU calculations.
Note that NVIDIA spec sheets for consumer GPUs (Ada and Blackwell) often advertise the higher FP16-with-FP16-accumulation figure (approximately 2x the values below). BF16 on these architectures always accumulates in FP32, so the half-rate figure is the correct one for standard training workloads.
These values are also used by the get_peak_hardware_flops() auto-detection
function (see Automatic GPU detection above).
NVIDIA Data Center GPUs
| GPU | Architecture | BF16 dense (FP32 accum) |
|---|---|---|
| B200 | Blackwell | 2250 TFLOPS |
| B100 | Blackwell | 1750 TFLOPS |
| H200 SXM | Hopper | 989 TFLOPS |
| H100 SXM | Hopper | 989 TFLOPS |
| H100 PCIe | Hopper | 756 TFLOPS |
| H800 SXM | Hopper | 989 TFLOPS |
| H800 PCIe | Hopper | 756 TFLOPS |
| H20 | Hopper | 148 TFLOPS |
| L40S | Ada Lovelace | 362 TFLOPS |
| L40 | Ada Lovelace | 181 TFLOPS |
| L4 | Ada Lovelace | 121 TFLOPS |
| A100 SXM 80GB | Ampere | 312 TFLOPS |
| A100 PCIe 80GB | Ampere | 312 TFLOPS |
| A100 SXM 40GB | Ampere | 312 TFLOPS |
| A800 80GB | Ampere | 312 TFLOPS |
| A40 | Ampere | 149.7 TFLOPS |
| A30 | Ampere | 165 TFLOPS |
| A10 | Ampere | 125 TFLOPS |
Note on H200, H800, A800: These are variants of the H100 and A100 with different memory configurations or reduced NVLink bandwidth (for export compliance). The BF16 compute throughput is identical to the base model.
NVIDIA Professional / Workstation GPUs
| GPU | Architecture | BF16 dense (FP32 accum) |
|---|---|---|
| RTX PRO 6000 | Blackwell | 251.9 TFLOPS |
| RTX 6000 Ada | Ada Lovelace | 181 TFLOPS |
| RTX A6000 | Ampere | 154.8 TFLOPS |
| RTX A5000 | Ampere | 111.1 TFLOPS |
| RTX A4000 | Ampere | 76.7 TFLOPS |
Note on professional Ada cards: Professional Ada GPUs (RTX 6000 Ada, L40, L40S) run BF16 tensor ops with FP32 accumulation at full speed -- unlike GeForce Ada cards, which run at half speed.
NVIDIA Consumer GPUs
| GPU | Architecture | BF16 dense (FP32 accum) |
|---|---|---|
| RTX 5090 | Blackwell | 209.5 TFLOPS |
| RTX 5080 | Blackwell | 112.6 TFLOPS |
| RTX 5070 Ti | Blackwell | 87.8 TFLOPS |
| RTX 5070 | Blackwell | 61.8 TFLOPS |
| RTX 5060 Ti | Blackwell | 47.4 TFLOPS |
| RTX 5060 | Blackwell | 38.4 TFLOPS |
| RTX 4090 | Ada Lovelace | 165.2 TFLOPS |
| RTX 4080 SUPER | Ada Lovelace | 104.4 TFLOPS |
| RTX 4080 | Ada Lovelace | 97.0 TFLOPS |
| RTX 4070 Ti SUPER | Ada Lovelace | 88.2 TFLOPS |
| RTX 4070 Ti | Ada Lovelace | 80.2 TFLOPS |
| RTX 4070 SUPER | Ada Lovelace | 71.0 TFLOPS |
| RTX 4070 | Ada Lovelace | 58.3 TFLOPS |
| RTX 4060 Ti | Ada Lovelace | 44.1 TFLOPS |
| RTX 4060 | Ada Lovelace | 30.2 TFLOPS |
| RTX 3090 Ti | Ampere | 79.8 TFLOPS |
| RTX 3090 | Ampere | 71.2 TFLOPS |
| RTX 3080 Ti | Ampere | 68.2 TFLOPS |
| RTX 3080 | Ampere | 59.5 TFLOPS |
| RTX 3070 Ti | Ampere | 43.5 TFLOPS |
| RTX 3070 | Ampere | 40.6 TFLOPS |
| RTX 3060 Ti | Ampere | 32.4 TFLOPS |
| RTX 3060 | Ampere | 25.5 TFLOPS |
Note on FP32 accumulation and consumer GPUs: On GeForce Ada (RTX 40xx) and Blackwell (RTX 50xx) cards, BF16 tensor ops with FP32 accumulation run at half the FP16-with-FP16-accumulation rate. NVIDIA's published specs for these cards often cite the higher FP16-accum figure. The values above are the correct half-rate numbers. GeForce Ampere cards (RTX 30xx) are also limited to half rate with FP32 accumulation, and their values above already reflect this: compare the RTX 3090 Ti (79.8 TFLOPS) with the similarly configured professional RTX A6000 (154.8 TFLOPS).
Example: multi-GPU configurations
| Configuration | peak_hardware_flops |
|---|---|
| 1× RTX 5090 | 209.5e12 |
| 4× RTX 5090 | 838e12 |
| 1× RTX 4090 | 165.2e12 |
| 4× RTX 4090 | 660.8e12 |
| 1× RTX 3090 | 71.2e12 |
| 4× A100 SXM | 1248e12 |
| 8× A100 SXM | 2496e12 |
| 8× H100 SXM | 7912e12 |
| 8× B200 | 18000e12 |
Notes on FLOP estimation accuracy
The 18 × num_params formula is a standard approximation for decoder-only transformer
models. It assumes:
- Forward pass: 6 × num_params FLOPs per token (2 multiply-adds per weight, times 3 for Q, K, V projections and attention being rolled into the parameter count)
- Backward pass: 12 × num_params FLOPs per token (approximately 2× forward)
Real FLOPs will differ from this estimate due to:
- Attention FLOPs: The quadratic attention term (roughly 2 × seq_len × model_dim FLOPs per token per layer in the forward pass) is not included. For short sequences this is negligible; at very long sequence lengths it can be significant.
- Non-transformer architectures: The formula assumes a standard transformer with weight matrices dominating the compute. Models with unusual architectures (MoE, state-space models, etc.) may diverge substantially.
- Gradient checkpointing: Recomputes activations during backward, adding approximately one extra forward pass. The true FLOPs are closer to 24 × num_params per token when gradient checkpointing is enabled, though the 18× estimate is still commonly used.
For comparing runs on the same model and hardware, the absolute accuracy of the estimate does not matter — the MFU and FLOP/s values are consistent relative to each other.
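To see when the omitted attention term matters, the estimate can be compared with a rough attention cost. A sketch under stated assumptions: the attention term uses the per-layer forward cost quoted above (2 × seq_len × model_dim FLOPs per token) and scales it by the same 3× factor (4× with gradient checkpointing) that turns 6N into 18N (or 24N); the model dimensions are illustrative, roughly GPT-2-small sized, and the helper names are hypothetical.

```python
def flops_per_token(num_params, n_layers, seq_len, model_dim, grad_checkpointing=False):
    """Return (parameter term, attention term) in FLOPs per training token."""
    passes = 4 if grad_checkpointing else 3      # multiples of the 6N forward term
    param_term = passes * 6 * num_params         # 18N, or 24N with checkpointing
    attn_term = passes * 2 * seq_len * model_dim * n_layers
    return param_term, attn_term


def attention_share(seq_len, **model):
    """Attention term as a fraction of the parameter-based estimate."""
    param_term, attn_term = flops_per_token(seq_len=seq_len, **model)
    return attn_term / param_term


model = dict(num_params=124_000_000, n_layers=12, model_dim=768)
short = attention_share(1024, **model)   # roughly 2.5% of the 18N estimate
long = attention_share(32768, **model)   # roughly 81%
```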