Skip to content

Analysis

Tools for parsing, summarizing, and visualizing training logs produced by Forgather's JSON logger. For CLI usage, plots, and the full forgather logs command reference, see Log Analysis.

Quick Example

from forgather.ml.analysis import TrainingLog, compute_summary_statistics

log = TrainingLog.from_file("output_models/my_model/runs/run_id/trainer_logs.json")
summary = compute_summary_statistics(log)
print(f"Best loss: {summary['best_loss']} at step {summary['best_loss_step']}")

forgather.ml.analysis.TrainingLog dataclass

Container for a parsed Forgather training log.

Holds all JSON records emitted by Forgather's JSON logger (trainer_logs.json) together with metadata inferred from the file-system path. Typically created via from_file() or from_run_dir() rather than constructed directly.

Parameters:

Name Type Description Default
log_path Path

Absolute path to the trainer_logs.json file.

required
records list of dict

Raw JSON records as loaded from the log file. Each record is a dictionary that may contain keys such as global_step, loss, eval_loss, learning_rate, grad_norm, epoch, timestamp, and train_runtime.

required
run_name str

Human-readable name of the training run, usually the timestamped directory name under runs/. Inferred from log_path when not provided.

None
model_name str

Name of the model, usually the directory immediately above runs/ in the output path. Inferred from log_path when not provided.

None
label str

Explicit display label used when plotting. When set, this takes priority over model_name and run_name.

None

Examples:

>>> from forgather.ml.analysis import TrainingLog
>>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
>>> train_records = log.get_training_records()
>>> losses = log.get_metric_values("loss", train_records)
Source code in src/forgather/ml/analysis/log_parser.py
@dataclass
class TrainingLog:
    """Container for a parsed Forgather training log.

    Holds all JSON records emitted by Forgather's JSON logger
    (``trainer_logs.json``) together with metadata inferred from the
    file-system path.  Typically created via :meth:`from_file` or
    :meth:`from_run_dir` rather than constructed directly.

    Parameters
    ----------
    log_path : Path
        Absolute path to the ``trainer_logs.json`` file.
    records : list of dict
        Raw JSON records as loaded from the log file.  Each record is a
        dictionary that may contain keys such as ``global_step``, ``loss``,
        ``eval_loss``, ``learning_rate``, ``grad_norm``, ``epoch``,
        ``timestamp``, and ``train_runtime``.
    run_name : str, optional
        Human-readable name of the training run, usually the timestamped
        directory name under ``runs/``.  Inferred from *log_path* when not
        provided.
    model_name : str, optional
        Name of the model, usually the directory immediately above ``runs/``
        in the output path.  Inferred from *log_path* when not provided.
    label : str, optional
        Explicit display label used when plotting.  When set, this takes
        priority over *model_name* and *run_name*.

    Examples
    --------
    >>> from forgather.ml.analysis import TrainingLog
    >>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
    >>> train_records = log.get_training_records()
    >>> losses = log.get_metric_values("loss", train_records)
    """

    log_path: Path
    records: List[Dict[str, Any]]
    run_name: Optional[str] = None
    model_name: Optional[str] = None
    label: Optional[str] = None

    def __post_init__(self):
        """Extract run name and model name from path if not provided."""
        parts = self.log_path.parts
        if "runs" in parts:
            runs_idx = parts.index("runs")
            if self.run_name is None and runs_idx + 1 < len(parts):
                self.run_name = parts[runs_idx + 1]
            if self.model_name is None and runs_idx > 0:
                self.model_name = parts[runs_idx - 1]

    def get_label(self, index: int = 0) -> str:
        """Return a human-readable label for this log.

        Selection priority: explicit :attr:`label` > :attr:`model_name` >
        :attr:`run_name` > ``'Run N'`` (where *N* is *index* + 1).

        Parameters
        ----------
        index : int, optional
            Zero-based position of this log in a collection, used as a
            fallback label suffix.  Default is 0.

        Returns
        -------
        str
            Display label suitable for plot legends and summary output.
        """
        if self.label:
            return self.label
        if self.model_name:
            return self.model_name
        if self.run_name:
            return self.run_name
        return f"Run {index + 1}"

    @classmethod
    def from_file(cls, log_path: str | Path) -> "TrainingLog":
        """Load a training log from a ``trainer_logs.json`` file.

        Handles truncated files (e.g. from a crash or a still-running job) by
        attempting to recover all complete JSON records before the truncation
        point.

        Parameters
        ----------
        log_path : str or Path
            Path to a ``trainer_logs.json`` file produced by Forgather's JSON
            logger.

        Returns
        -------
        TrainingLog
            Populated instance with all recoverable records.

        Raises
        ------
        FileNotFoundError
            If *log_path* does not exist on disk.
        ValueError
            If the file is not a valid JSON array and recovery fails.

        Examples
        --------
        >>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
        >>> print(f"Loaded {len(log.records)} records")
        """
        log_path = Path(log_path)
        if not log_path.exists():
            raise FileNotFoundError(f"Log file not found: {log_path}")

        with open(log_path, "r") as f:
            text = f.read()
        try:
            records = json.loads(text)
        except json.JSONDecodeError as e:
            # Attempt to recover truncated JSON (e.g., crash, still-running job)
            records = _try_recover_truncated_json(text, log_path, e)

        if not isinstance(records, list):
            raise ValueError("Log file must contain a JSON array")

        return cls(log_path=log_path, records=records)

    @classmethod
    def from_run_dir(cls, run_dir: str | Path) -> "TrainingLog":
        """Load a training log from a run directory.

        Convenience wrapper around :meth:`from_file` that automatically appends
        ``trainer_logs.json`` to the supplied directory path.

        Parameters
        ----------
        run_dir : str or Path
            Path to a run directory (e.g.
            ``output_models/my_model/runs/run_001/``) that contains a
            ``trainer_logs.json`` file.

        Returns
        -------
        TrainingLog
            Populated instance with all recoverable records.

        Raises
        ------
        FileNotFoundError
            If ``trainer_logs.json`` is not found inside *run_dir*.
        ValueError
            If the log file is not a valid JSON array and recovery fails.
        """
        run_dir = Path(run_dir)
        log_path = run_dir / "trainer_logs.json"
        return cls.from_file(log_path)

    def get_training_records(self) -> List[Dict[str, Any]]:
        """Return records that contain training-step metrics.

        Training records are identified by the presence of a ``loss`` key and
        the absence of an ``eval_loss`` key.  They typically also carry
        ``grad_norm``, ``learning_rate``, ``global_step``, ``epoch``, and
        ``timestamp``.

        Returns
        -------
        list of dict
            Subset of :attr:`records` corresponding to training steps.
        """
        return [r for r in self.records if "loss" in r and "eval_loss" not in r]

    def get_eval_records(self) -> List[Dict[str, Any]]:
        """Return records that contain evaluation metrics.

        Evaluation records are identified by the presence of an ``eval_loss``
        key.  They typically also carry ``global_step`` and ``epoch``.

        Returns
        -------
        list of dict
            Subset of :attr:`records` corresponding to evaluation checkpoints.
        """
        return [r for r in self.records if "eval_loss" in r]

    def get_final_record(self) -> Optional[Dict[str, Any]]:
        """Return the final summary record emitted at the end of training.

        The final record is identified by the presence of a ``train_runtime``
        key and may also contain ``train_samples``,
        ``train_samples_per_second``, ``train_steps_per_second``, and
        ``effective_batch_size``.

        Returns
        -------
        dict or None
            The last record containing ``train_runtime``, or ``None`` if no
            such record exists (e.g. training was interrupted).
        """
        for r in reversed(self.records):
            if "train_runtime" in r:
                return r
        return None

    def get_metric_values(
        self, metric: str, records: Optional[List[Dict[str, Any]]] = None
    ) -> List[float]:
        """Extract the values for a named metric from a set of records.

        Records that do not contain *metric* are silently skipped, so the
        returned list may be shorter than *records*.

        Parameters
        ----------
        metric : str
            Key to extract (e.g. ``'loss'``, ``'learning_rate'``,
            ``'grad_norm'``, ``'eval_loss'``, ``'global_step'``).
        records : list of dict, optional
            Records to search.  When ``None``, all :attr:`records` are used.

        Returns
        -------
        list of float
            Ordered values for *metric* drawn from the matching records.
        """
        if records is None:
            records = self.records
        return [r[metric] for r in records if metric in r]

    def get_steps(self, records: Optional[List[Dict[str, Any]]] = None) -> List[int]:
        """Extract ``global_step`` values from records.

        Parameters
        ----------
        records : list of dict, optional
            Records to search.  When ``None``, all :attr:`records` are used.

        Returns
        -------
        list of int
            Ordered global step numbers.
        """
        return self.get_metric_values("global_step", records)

    def get_epochs(self, records: Optional[List[Dict[str, Any]]] = None) -> List[float]:
        """Extract ``epoch`` values from records.

        Parameters
        ----------
        records : list of dict, optional
            Records to search.  When ``None``, all :attr:`records` are used.

        Returns
        -------
        list of float
            Ordered fractional epoch numbers.
        """
        return self.get_metric_values("epoch", records)

    def get_timestamps(
        self, records: Optional[List[Dict[str, Any]]] = None
    ) -> List[float]:
        """Extract ``timestamp`` values from records.

        Timestamps are Unix epoch seconds recorded when each log entry was
        written.  They can be used to build a wall-clock x-axis for plots.

        Parameters
        ----------
        records : list of dict, optional
            Records to search.  When ``None``, all :attr:`records` are used.

        Returns
        -------
        list of float
            Ordered Unix timestamps (seconds since epoch).
        """
        return self.get_metric_values("timestamp", records)

    def find_best_step(
        self, metric: str, mode: str = "min"
    ) -> Optional[tuple[int, float]]:
        """Find the training step at which a metric reaches its best value.

        Parameters
        ----------
        metric : str
            Metric key to search for (e.g. ``'loss'``, ``'eval_loss'``).
        mode : {'min', 'max'}, optional
            Whether to look for the minimum (``'min'``) or maximum (``'max'``)
            value.  Default is ``'min'``.

        Returns
        -------
        tuple of (int, float) or None
            ``(global_step, value)`` at the best record, or ``None`` if no
            record contains *metric*.
        """
        records = [r for r in self.records if metric in r]
        if not records:
            return None

        if mode == "min":
            best_record = min(records, key=lambda r: r[metric])
        else:
            best_record = max(records, key=lambda r: r[metric])

        return best_record["global_step"], best_record[metric]

__post_init__()

Extract run name and model name from path if not provided.

Source code in src/forgather/ml/analysis/log_parser.py
def __post_init__(self):
    """Normalise ``log_path`` and infer missing names from it.

    ``run_name`` and ``model_name`` are only filled in when they were left
    as ``None``; explicitly supplied values are never overwritten.
    """
    # Direct construction may hand us a plain string; ``.parts`` below
    # requires a Path object.
    self.log_path = Path(self.log_path)
    parts = self.log_path.parts
    runs_positions = [i for i, part in enumerate(parts) if part == "runs"]
    if not runs_positions:
        return

    # Use the *last* "runs" component: the log lives at
    # <model>/runs/<run>/trainer_logs.json, so an unrelated parent directory
    # that happens to be called "runs" must not win.
    runs_idx = runs_positions[-1]
    if self.run_name is None and runs_idx + 1 < len(parts):
        self.run_name = parts[runs_idx + 1]
    if self.model_name is None and runs_idx > 0:
        self.model_name = parts[runs_idx - 1]

get_label(index=0)

Return a human-readable label for this log.

Selection priority: explicit label > model_name > run_name > 'Run N' (where N is index + 1).

Parameters:

Name Type Description Default
index int

Zero-based position of this log in a collection, used as a fallback label suffix. Default is 0.

0

Returns:

Type Description
str

Display label suitable for plot legends and summary output.

Source code in src/forgather/ml/analysis/log_parser.py
def get_label(self, index: int = 0) -> str:
    """Pick the display label for this log.

    The first non-empty value wins, in this order: :attr:`label`,
    :attr:`model_name`, :attr:`run_name`.  If all three are empty the
    generic label ``'Run N'`` is produced, with *N* equal to *index* + 1.

    Parameters
    ----------
    index : int, optional
        Zero-based position of this log within a collection; only used to
        build the generic fallback label.  Default is 0.

    Returns
    -------
    str
        Label text for plot legends and summary output.
    """
    # ``or`` short-circuits, so the fallback string is only built when needed.
    return self.label or self.model_name or self.run_name or f"Run {index + 1}"

from_file(log_path) classmethod

Load a training log from a trainer_logs.json file.

Handles truncated files (e.g. from a crash or a still-running job) by attempting to recover all complete JSON records before the truncation point.

Parameters:

Name Type Description Default
log_path str or Path

Path to a trainer_logs.json file produced by Forgather's JSON logger.

required

Returns:

Type Description
TrainingLog

Populated instance with all recoverable records.

Raises:

Type Description
FileNotFoundError

If log_path does not exist on disk.

ValueError

If the file is not a valid JSON array and recovery fails.

Examples:

>>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
>>> print(f"Loaded {len(log.records)} records")
Source code in src/forgather/ml/analysis/log_parser.py
@classmethod
def from_file(cls, log_path: str | Path) -> "TrainingLog":
    """Load a training log from a ``trainer_logs.json`` file.

    Handles truncated files (e.g. from a crash or a still-running job) by
    attempting to recover all complete JSON records before the truncation
    point.

    Parameters
    ----------
    log_path : str or Path
        Path to a ``trainer_logs.json`` file produced by Forgather's JSON
        logger.

    Returns
    -------
    TrainingLog
        Populated instance with all recoverable records.

    Raises
    ------
    FileNotFoundError
        If *log_path* does not exist on disk.
    ValueError
        If the file is not a valid JSON array and recovery fails.

    Examples
    --------
    >>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
    >>> print(f"Loaded {len(log.records)} records")
    """
    log_path = Path(log_path)
    if not log_path.exists():
        raise FileNotFoundError(f"Log file not found: {log_path}")

    with open(log_path, "r") as f:
        text = f.read()
    try:
        records = json.loads(text)
    except json.JSONDecodeError as e:
        # Attempt to recover truncated JSON (e.g., crash, still-running job)
        records = _try_recover_truncated_json(text, log_path, e)

    if not isinstance(records, list):
        raise ValueError("Log file must contain a JSON array")

    return cls(log_path=log_path, records=records)

from_run_dir(run_dir) classmethod

Load a training log from a run directory.

Convenience wrapper around from_file() that automatically appends trainer_logs.json to the supplied directory path.

Parameters:

Name Type Description Default
run_dir str or Path

Path to a run directory (e.g. output_models/my_model/runs/run_001/) that contains a trainer_logs.json file.

required

Returns:

Type Description
TrainingLog

Populated instance with all recoverable records.

Raises:

Type Description
FileNotFoundError

If trainer_logs.json is not found inside run_dir.

ValueError

If the log file is not a valid JSON array and recovery fails.

Source code in src/forgather/ml/analysis/log_parser.py
@classmethod
def from_run_dir(cls, run_dir: str | Path) -> "TrainingLog":
    """Load a training log from a run directory.

    Convenience wrapper around :meth:`from_file` that automatically appends
    ``trainer_logs.json`` to the supplied directory path.

    Parameters
    ----------
    run_dir : str or Path
        Path to a run directory (e.g.
        ``output_models/my_model/runs/run_001/``) that contains a
        ``trainer_logs.json`` file.

    Returns
    -------
    TrainingLog
        Populated instance with all recoverable records.

    Raises
    ------
    FileNotFoundError
        If ``trainer_logs.json`` is not found inside *run_dir*.
    ValueError
        If the log file is not a valid JSON array and recovery fails.
    """
    run_dir = Path(run_dir)
    log_path = run_dir / "trainer_logs.json"
    return cls.from_file(log_path)

get_training_records()

Return records that contain training-step metrics.

Training records are identified by the presence of a loss key and the absence of an eval_loss key. They typically also carry grad_norm, learning_rate, global_step, epoch, and timestamp.

Returns:

Type Description
list of dict

Subset of records corresponding to training steps.

Source code in src/forgather/ml/analysis/log_parser.py
def get_training_records(self) -> List[Dict[str, Any]]:
    """Select the training-step records from :attr:`records`.

    A record counts as a training record when it has a ``loss`` key and
    does not have an ``eval_loss`` key.

    Returns
    -------
    list of dict
        Matching records, in their original order.
    """
    selected = []
    for entry in self.records:
        # Anything without a loss, or carrying eval_loss, is not a training step.
        if "loss" not in entry or "eval_loss" in entry:
            continue
        selected.append(entry)
    return selected

get_eval_records()

Return records that contain evaluation metrics.

Evaluation records are identified by the presence of an eval_loss key. They typically also carry global_step and epoch.

Returns:

Type Description
list of dict

Subset of records corresponding to evaluation checkpoints.

Source code in src/forgather/ml/analysis/log_parser.py
def get_eval_records(self) -> List[Dict[str, Any]]:
    """Select the evaluation records from :attr:`records`.

    A record counts as an evaluation record when it has an ``eval_loss``
    key.

    Returns
    -------
    list of dict
        Matching records, in their original order.
    """
    selected = []
    for entry in self.records:
        if "eval_loss" in entry:
            selected.append(entry)
    return selected

get_final_record()

Return the final summary record emitted at the end of training.

The final record is identified by the presence of a train_runtime key and may also contain train_samples, train_samples_per_second, train_steps_per_second, and effective_batch_size.

Returns:

Type Description
dict or None

The last record containing train_runtime, or None if no such record exists (e.g. training was interrupted).

Source code in src/forgather/ml/analysis/log_parser.py
def get_final_record(self) -> Optional[Dict[str, Any]]:
    """Locate the end-of-training summary record.

    The summary record is recognised by its ``train_runtime`` key.  The
    search runs from the end of :attr:`records` backwards, so the most
    recent matching record is the one returned.

    Returns
    -------
    dict or None
        The last record containing ``train_runtime``, or ``None`` when no
        record has that key.
    """
    newest_first = reversed(self.records)
    return next((entry for entry in newest_first if "train_runtime" in entry), None)

get_metric_values(metric, records=None)

Extract the values for a named metric from a set of records.

Records that do not contain metric are silently skipped, so the returned list may be shorter than records.

Parameters:

Name Type Description Default
metric str

Key to extract (e.g. 'loss', 'learning_rate', 'grad_norm', 'eval_loss', 'global_step').

required
records list of dict

Records to search. When None, all records of the log are used.

None

Returns:

Type Description
list of float

Ordered values for metric drawn from the matching records.

Source code in src/forgather/ml/analysis/log_parser.py
def get_metric_values(
    self, metric: str, records: Optional[List[Dict[str, Any]]] = None
) -> List[float]:
    """Collect the values stored under *metric* in a set of records.

    Records without the *metric* key contribute nothing, so the result can
    be shorter than the input.

    Parameters
    ----------
    metric : str
        Key to read from each record (e.g. ``'loss'``, ``'learning_rate'``,
        ``'grad_norm'``, ``'eval_loss'``, ``'global_step'``).
    records : list of dict, optional
        Records to scan.  ``None`` means scan :attr:`records`.

    Returns
    -------
    list of float
        Values for *metric*, in record order.
    """
    source = self.records if records is None else records
    values = []
    for entry in source:
        if metric in entry:
            values.append(entry[metric])
    return values

get_steps(records=None)

Extract global_step values from records.

Parameters:

Name Type Description Default
records list of dict

Records to search. When None, all records of the log are used.

None

Returns:

Type Description
list of int

Ordered global step numbers.

Source code in src/forgather/ml/analysis/log_parser.py
def get_steps(self, records: Optional[List[Dict[str, Any]]] = None) -> List[int]:
    """Collect the ``global_step`` value of every record that has one.

    Thin wrapper that asks :meth:`get_metric_values` for the
    ``'global_step'`` key.

    Parameters
    ----------
    records : list of dict, optional
        Records to scan; passed through unchanged to
        :meth:`get_metric_values`.

    Returns
    -------
    list of int
        Global step numbers, in record order.
    """
    steps = self.get_metric_values("global_step", records)
    return steps

get_epochs(records=None)

Extract epoch values from records.

Parameters:

Name Type Description Default
records list of dict

Records to search. When None, all records of the log are used.

None

Returns:

Type Description
list of float

Ordered fractional epoch numbers.

Source code in src/forgather/ml/analysis/log_parser.py
def get_epochs(self, records: Optional[List[Dict[str, Any]]] = None) -> List[float]:
    """Collect the ``epoch`` value of every record that has one.

    Thin wrapper that asks :meth:`get_metric_values` for the ``'epoch'``
    key.

    Parameters
    ----------
    records : list of dict, optional
        Records to scan; passed through unchanged to
        :meth:`get_metric_values`.

    Returns
    -------
    list of float
        Epoch values, in record order.
    """
    epochs = self.get_metric_values("epoch", records)
    return epochs

get_timestamps(records=None)

Extract timestamp values from records.

Timestamps are Unix epoch seconds recorded when each log entry was written. They can be used to build a wall-clock x-axis for plots.

Parameters:

Name Type Description Default
records list of dict

Records to search. When None, all records of the log are used.

None

Returns:

Type Description
list of float

Ordered Unix timestamps (seconds since epoch).

Source code in src/forgather/ml/analysis/log_parser.py
def get_timestamps(
    self, records: Optional[List[Dict[str, Any]]] = None
) -> List[float]:
    """Collect the ``timestamp`` value of every record that has one.

    Thin wrapper that asks :meth:`get_metric_values` for the
    ``'timestamp'`` key.

    Parameters
    ----------
    records : list of dict, optional
        Records to scan; passed through unchanged to
        :meth:`get_metric_values`.

    Returns
    -------
    list of float
        Timestamp values, in record order.
    """
    stamps = self.get_metric_values("timestamp", records)
    return stamps

find_best_step(metric, mode='min')

Find the training step at which a metric reaches its best value.

Parameters:

Name Type Description Default
metric str

Metric key to search for (e.g. 'loss', 'eval_loss').

required
mode (min, max)

Whether to look for the minimum ('min') or maximum ('max') value. Default is 'min'.

'min'

Returns:

Type Description
tuple of (int, float) or None

(global_step, value) at the best record, or None if no record contains metric.

Source code in src/forgather/ml/analysis/log_parser.py
def find_best_step(
    self, metric: str, mode: str = "min"
) -> Optional[tuple[int, float]]:
    """Find the training step at which a metric reaches its best value.

    Records are only considered when they carry both *metric* and
    ``global_step``.  Records whose metric value is ``None`` or NaN are
    ignored, because such values cannot be ordered meaningfully and would
    otherwise make the result depend on record order.

    Parameters
    ----------
    metric : str
        Metric key to search for (e.g. ``'loss'``, ``'eval_loss'``).
    mode : {'min', 'max'}, optional
        Whether to look for the minimum (``'min'``) or maximum (``'max'``)
        value.  Default is ``'min'``.  Any value other than ``'min'`` is
        treated as ``'max'``.

    Returns
    -------
    tuple of (int, float) or None
        ``(global_step, value)`` at the best record, or ``None`` if no
        record provides a usable value for *metric*.
    """
    candidates = []
    for record in self.records:
        if metric not in record or "global_step" not in record:
            continue
        value = record[metric]
        # NaN is the only value that compares unequal to itself.
        if value is None or value != value:
            continue
        candidates.append(record)
    if not candidates:
        return None

    pick = min if mode == "min" else max
    best_record = pick(candidates, key=lambda r: r[metric])
    return best_record["global_step"], best_record[metric]

forgather.ml.analysis.compute_summary_statistics(log)

Compute summary statistics from a training log.

Aggregates training-step records, evaluation records, and the final summary record into a flat dictionary of key metrics. Keys are only present when the underlying data exists; callers should use summary.get(key) rather than direct indexing.

Parameters:

Name Type Description Default
log TrainingLog

Parsed training log to summarise.

required

Returns:

Type Description
dict

Dictionary with a subset of the following keys, depending on what data is available in log:

- run_name (str or None): Name of the training run.
- log_path (str): String representation of the log file path.
- total_steps (int): Global step number of the last training record.
- final_epoch (float): Epoch number at the last training step.
- final_loss (float): Training loss at the last recorded step.
- avg_loss (float): Mean training loss over all recorded steps.
- min_loss (float): Minimum training loss observed during the run.
- best_loss (float): Training loss at the step where it was lowest (same as min_loss but paired with best_loss_step).
- best_loss_step (int): Global step at which best_loss was achieved.
- avg_grad_norm (float): Mean gradient norm over all training steps that recorded it.
- max_grad_norm_value (float): Peak gradient norm observed during training.
- max_grad_norm_step (int): Global step at which max_grad_norm_value was observed.
- initial_lr (float): Learning rate at the first training step.
- final_lr (float): Learning rate at the last training step.
- final_eval_loss (float): Evaluation loss from the most recent evaluation checkpoint.
- best_eval_loss (float): Lowest evaluation loss observed.
- best_eval_loss_step (int): Global step at which best_eval_loss was achieved.
- train_runtime (float): Total training wall-clock time in seconds.
- train_samples (int): Total number of training samples processed.
- train_samples_per_second (float): Average throughput in samples per second.
- train_steps_per_second (float): Average throughput in optimizer steps per second.
- effective_batch_size (int): Effective batch size (local batch x gradient accumulation x world size).

Examples:

>>> from forgather.ml.analysis import TrainingLog, compute_summary_statistics
>>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
>>> summary = compute_summary_statistics(log)
>>> print(f"Best loss: {summary['best_loss']:.4f} at step {summary['best_loss_step']}")
Source code in src/forgather/ml/analysis/metrics.py
def compute_summary_statistics(log: TrainingLog) -> Dict[str, Any]:
    """Compute summary statistics from a training log.

    Aggregates training-step records, evaluation records, and the final
    summary record into a flat dictionary of key metrics.  Training and
    evaluation keys are only present when the underlying data exists, so
    callers should use ``summary.get(key)`` rather than direct indexing.

    Parameters
    ----------
    log : TrainingLog
        Parsed training log to summarise.

    Returns
    -------
    dict
        Dictionary with a subset of the following keys, depending on what
        data is available in *log*:

        ``run_name`` : str or None
            Name of the training run.
        ``log_path`` : str
            String representation of the log file path.
        ``total_steps`` : int
            Global step number of the last training record (``0`` when that
            record has no ``global_step``).
        ``final_epoch`` : float
            Epoch number at the last training step (``0`` when that record
            has no ``epoch``).
        ``final_loss`` : float
            Training loss at the last recorded step.
        ``avg_loss`` : float
            Mean training loss over all recorded steps.
        ``min_loss`` : float
            Minimum training loss observed during the run.
        ``best_loss`` : float
            Training loss reported by ``log.find_best_step("loss")``
            (paired with ``best_loss_step``).
        ``best_loss_step`` : int
            Global step at which ``best_loss`` was achieved.
        ``avg_grad_norm`` : float
            Mean gradient norm over all training steps that recorded it.
        ``max_grad_norm_value`` : float
            Peak gradient norm observed during training.
        ``max_grad_norm_step`` : int
            Global step of the first training record whose ``grad_norm``
            equals ``max_grad_norm_value``.  Omitted when no such record
            with a ``global_step`` can be found.
        ``initial_lr`` : float
            Learning rate at the first training step.
        ``final_lr`` : float
            Learning rate at the last training step.
        ``final_eval_loss`` : float
            Evaluation loss from the most recent evaluation checkpoint.
        ``best_eval_loss`` : float
            Lowest evaluation loss observed.
        ``best_eval_loss_step`` : int
            Global step at which ``best_eval_loss`` was achieved.
        ``train_runtime`` : float or None
            Total training wall-clock time in seconds.
        ``train_samples`` : int or None
            Total number of training samples processed.
        ``train_samples_per_second`` : float or None
            Average throughput in samples per second.
        ``train_steps_per_second`` : float or None
            Average throughput in optimizer steps per second.
        ``effective_batch_size`` : int or None
            Effective batch size (local batch x gradient accumulation x
            world size).

        The five throughput keys are all set whenever the log has a final
        summary record; a field missing from that record is stored as
        ``None``.

    Examples
    --------
    >>> from forgather.ml.analysis import TrainingLog, compute_summary_statistics
    >>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
    >>> summary = compute_summary_statistics(log)
    >>> print(f"Best loss: {summary['best_loss']:.4f} at step {summary['best_loss_step']}")
    """
    train_records = log.get_training_records()
    eval_records = log.get_eval_records()
    final_record = log.get_final_record()

    summary: Dict[str, Any] = {
        "run_name": log.run_name,
        "log_path": str(log.log_path),
    }

    if train_records:
        # Training progress
        last_record = train_records[-1]
        summary["total_steps"] = last_record.get("global_step", 0)
        summary["final_epoch"] = last_record.get("epoch", 0)

        # Training metrics
        losses = log.get_metric_values("loss", train_records)
        if losses:
            summary["final_loss"] = losses[-1]
            summary["avg_loss"] = sum(losses) / len(losses)
            summary["min_loss"] = min(losses)
            best_loss_step, best_loss = log.find_best_step("loss", mode="min")
            summary["best_loss"] = best_loss
            summary["best_loss_step"] = best_loss_step

        # Gradient statistics
        grad_norms = log.get_metric_values("grad_norm", train_records)
        if grad_norms:
            peak_grad_norm = max(grad_norms)
            summary["avg_grad_norm"] = sum(grad_norms) / len(grad_norms)
            summary["max_grad_norm_value"] = peak_grad_norm
            # Look the step up on the record that actually holds the peak
            # value.  Indexing train_records by the position of the peak in
            # grad_norms is only valid when every training record carries
            # "grad_norm"; otherwise the two sequences are not aligned and
            # the wrong step (or an IndexError) results.
            peak_record = next(
                (
                    record
                    for record in train_records
                    if record.get("grad_norm") == peak_grad_norm
                ),
                None,
            )
            if peak_record is not None and "global_step" in peak_record:
                summary["max_grad_norm_step"] = peak_record["global_step"]

        # Learning rate
        learning_rates = log.get_metric_values("learning_rate", train_records)
        if learning_rates:
            summary["initial_lr"] = learning_rates[0]
            summary["final_lr"] = learning_rates[-1]

    # Evaluation metrics
    if eval_records:
        eval_losses = log.get_metric_values("eval_loss", eval_records)
        if eval_losses:
            summary["final_eval_loss"] = eval_losses[-1]
            best_eval_step, best_eval_loss = log.find_best_step("eval_loss", mode="min")
            summary["best_eval_loss"] = best_eval_loss
            summary["best_eval_loss_step"] = best_eval_step

    # Training performance: copied verbatim from the final summary record;
    # fields absent from that record come through as None.
    if final_record:
        for key in (
            "train_runtime",
            "train_samples",
            "train_samples_per_second",
            "train_steps_per_second",
            "effective_batch_size",
        ):
            summary[key] = final_record.get(key)

    return summary

forgather.ml.analysis.plot_training_metrics(logs, metrics=None, x_axis='step', smooth_window=None, log_scale=False, output_path=None, figsize=(12, 8), show=False, title=None, ignore_outliers=True, perplexity=False, x_min=None, x_max=None, y_min=None, y_max=None)

Plot one or more training metrics from one or more training logs.

Creates a grid of subplots (up to two columns) with one panel per metric. When multiple logs are supplied each run is drawn in a distinct colour with a legend entry. For loss-like metrics (loss, eval_loss, grad_norm) the y-axis is automatically clipped to the 5th–95th percentile window to suppress early-training outliers; pass ignore_outliers=False to disable this.

Parameters:

Name Type Description Default
logs list of TrainingLog

One or more parsed training logs to plot.

required
metrics list of str

Metric keys to plot. Each element must be a key present in at least some log records (e.g. 'loss', 'eval_loss', 'learning_rate', 'grad_norm'). Default is ['loss', 'eval_loss', 'learning_rate'].

None
x_axis (step, epoch, time)

X-axis variable. 'step' uses global_step, 'epoch' uses epoch, and 'time' converts timestamps to elapsed minutes. Default is 'step'.

'step'
smooth_window int

When greater than 1, draws the raw series at low opacity and overlays a centred moving-average with the given window size. Default is None (no smoothing).

None
log_scale bool

Use a logarithmic y-axis. Outlier-aware auto-scaling is suppressed on log axes. Default is False.

False
output_path str or Path

If provided, the figure is saved to this path at 300 dpi. Parent directories are created automatically.

None
figsize tuple of int

(width, height) in inches passed to plt.subplots. Default is (12, 8).

(12, 8)
show bool

Call plt.show() after rendering. Default is False.

False
title str

Figure-level suptitle. When None no title is added.

None
ignore_outliers bool

Apply percentile-based y-axis clipping for loss-like metrics. Default is True.

True
perplexity bool

Convert loss values to perplexity (exp(loss)) for loss, train_loss, and eval_loss metrics. Default is False.

False
x_min float

Clip data and set the left x-axis limit to this value.

None
x_max float

Clip data and set the right x-axis limit to this value.

None
y_min float

Override the bottom y-axis limit. Takes priority over auto-scaling.

None
y_max float

Override the top y-axis limit. Takes priority over auto-scaling.

None

Returns:

Type Description
Figure

The rendered figure. The caller is responsible for closing it when no longer needed (plt.close(fig)).

Examples:

>>> from forgather.ml.analysis import TrainingLog
>>> from forgather.ml.analysis.plotting import plot_training_metrics
>>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
>>> fig = plot_training_metrics([log], metrics=["loss", "eval_loss"], smooth_window=20)
>>> fig.savefig("training.png", dpi=150)
Source code in src/forgather/ml/analysis/plotting.py
def plot_training_metrics(
    logs: List[TrainingLog],
    metrics: Optional[List[str]] = None,
    x_axis: str = "step",
    smooth_window: Optional[int] = None,
    log_scale: bool = False,
    output_path: Union[str, Path, None] = None,
    figsize: Tuple[int, int] = (12, 8),
    show: bool = False,
    title: Optional[str] = None,
    ignore_outliers: bool = True,
    perplexity: bool = False,
    x_min: Optional[float] = None,
    x_max: Optional[float] = None,
    y_min: Optional[float] = None,
    y_max: Optional[float] = None,
) -> Figure:
    """Draw training metrics from one or more logs, one subplot per metric.

    Panels are laid out in a grid that is two columns wide (a single panel
    when only one metric is requested); unused grid cells are hidden.  Every
    log is drawn in its own colour and contributes a legend entry.  For
    loss-like metrics (``loss``, ``eval_loss``, ``grad_norm``) the y-axis is
    automatically clipped to the 5th–95th percentile window to suppress
    early-training outliers; pass ``ignore_outliers=False`` to disable this.

    Parameters
    ----------
    logs : list of TrainingLog
        Parsed training logs to draw.
    metrics : list of str, optional
        Metric keys to plot (e.g. ``'loss'``, ``'eval_loss'``,
        ``'learning_rate'``, ``'grad_norm'``).  ``'eval_loss'`` is read from
        evaluation records; ``'loss'``, ``'grad_norm'``, ``'learning_rate'``
        and ``'max_grad_norm'`` from training records; any other key from
        every record that contains it.  Default is
        ``['loss', 'eval_loss', 'learning_rate']``.
    x_axis : {'step', 'epoch', 'time'}, optional
        X-axis variable.  ``'step'`` uses ``global_step``, ``'epoch'`` uses
        ``epoch``, and ``'time'`` converts timestamps to elapsed minutes.
        Default is ``'step'``.
    smooth_window : int, optional
        When greater than 1, the raw series is drawn faintly and a smoothed
        series (window of this size) is drawn on top of it.  Default is
        ``None`` (no smoothing).
    log_scale : bool, optional
        Use a logarithmic y-axis.  Default is ``False``.
    output_path : str or Path, optional
        When given, the figure is written to this path at 300 dpi; missing
        parent directories are created first.
    figsize : tuple of int, optional
        ``(width, height)`` in inches passed to ``plt.subplots``.  Default is
        ``(12, 8)``.
    show : bool, optional
        Call ``plt.show()`` after rendering.  Default is ``False``.
    title : str, optional
        Figure-level suptitle.  Nothing is added when empty or ``None``.
    ignore_outliers : bool, optional
        Forwarded to the y-limit helper for loss-like metrics.  Default is
        ``True``.
    perplexity : bool, optional
        Convert loss values to perplexity (``exp(loss)``) for the metrics
        that support it.  Default is ``False``.
    x_min : float, optional
        Clip data and set the left x-axis limit to this value.
    x_max : float, optional
        Clip data and set the right x-axis limit to this value.
    y_min : float, optional
        Override the bottom y-axis limit.
    y_max : float, optional
        Override the top y-axis limit.

    Returns
    -------
    matplotlib.figure.Figure
        The rendered figure.  It is not closed here; call
        ``plt.close(fig)`` when it is no longer needed.

    Examples
    --------
    >>> from forgather.ml.analysis import TrainingLog
    >>> from forgather.ml.analysis.plotting import plot_training_metrics
    >>> log = TrainingLog.from_file("output_models/my_model/runs/run_001/trainer_logs.json")
    >>> fig = plot_training_metrics([log], metrics=["loss", "eval_loss"], smooth_window=20)
    >>> fig.savefig("training.png", dpi=150)
    """
    metric_names = (
        ["loss", "eval_loss", "learning_rate"] if metrics is None else metrics
    )

    # Grid shape: a lone panel for zero or one metric, otherwise two columns
    # with as many rows as needed.
    panel_count = len(metric_names)
    if panel_count > 1:
        n_cols = 2
        n_rows = (panel_count + 1) // 2
    else:
        n_cols = 1
        n_rows = 1

    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    panels = axes_grid.flatten()

    # Hide grid cells that have no metric assigned to them.
    for spare_panel in panels[panel_count:]:
        spare_panel.set_visible(False)

    # Carried across panels on purpose: a panel with nothing to draw reuses
    # the most recent label returned by _get_x_values.
    x_label = "Global Step"
    use_smoothing = bool(smooth_window and smooth_window > 1)
    x_window_given = not (x_min is None and x_max is None)
    y_window_given = not (y_min is None and y_max is None)

    for ax, metric in zip(panels, metric_names):
        drawn_series: List[List[float]] = []
        as_perplexity = perplexity and _is_perplexity_metric(metric)

        for run_idx, run in enumerate(logs):
            color = _get_color(run_idx)
            label = run.get_label(run_idx)

            # Pick the record subset that carries this metric.
            if metric == "eval_loss":
                records = run.get_eval_records()
            elif metric in ("loss", "grad_norm", "learning_rate", "max_grad_norm"):
                records = run.get_training_records()
            else:
                records = [entry for entry in run.records if metric in entry]

            if not records:
                continue

            y_values = run.get_metric_values(metric, records)
            x_values, x_label = _get_x_values(run, records, x_axis)

            if as_perplexity:
                y_values = _apply_perplexity(y_values)

            x_values, y_values = _clip_to_x_window(x_values, y_values, x_min, x_max)
            if not x_values:
                continue

            if not use_smoothing:
                ax.plot(x_values, y_values, label=label, linewidth=2, color=color)
                drawn_series.append(list(y_values))
                continue

            # Faint raw trace underneath, labelled smoothed trace on top.
            smoothed = smooth_values(y_values, smooth_window)
            ax.plot(x_values, y_values, alpha=0.15, linewidth=0.5, color=color)
            ax.plot(x_values, smoothed, label=label, linewidth=2, color=color)
            drawn_series.append(list(smoothed))

        panel_label = _metric_display_label(metric, perplexity)
        ax.set_xlabel(x_label)
        ax.set_ylabel(panel_label)
        ax.set_title(f"{panel_label} vs {x_label}")
        ax.grid(True, alpha=0.3)

        # Only add a legend when at least one labelled line was drawn.
        legend_handles, _ = ax.get_legend_handles_labels()
        if legend_handles:
            ax.legend()

        if log_scale:
            ax.set_yscale("log")

        if x_window_given:
            ax.set_xlim(left=x_min, right=x_max)

        # Only auto-clip y for loss-like metrics; don't squash LR.
        auto_clip = bool(drawn_series) and (
            _is_loss_like_metric(metric) or as_perplexity
        )
        if auto_clip:
            _apply_ylim(ax, drawn_series, log_scale, ignore_outliers, y_min, y_max)
        elif y_window_given:
            ax.set_ylim(bottom=y_min, top=y_max)

    if title:
        fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)

    plt.tight_layout()

    if output_path is not None:
        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(destination, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    return fig