
Trainer Callbacks

Callbacks extend trainer behaviour at well-defined lifecycle events (step start/end, epoch start/end, evaluation, checkpoint save/load, etc.) without modifying trainer source. Pass a list of callback instances to any trainer's callbacks argument.

Related documentation:


Base Class

forgather.ml.trainer.trainer_types.TrainerCallback

Base class for trainer event callbacks.

Subclasses implement only the event methods they need. Any method not defined is simply never called for that callback. The trainer maintains a lazy index mapping event names to the callbacks that define handlers, so only relevant callbacks are invoked per event.
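The following sketch shows how a lazily built event index might work, to make the dispatch rule concrete. It is not Forgather's implementation. TrainerCallback below is a local stand-in that reproduces only the default name property. CallbackDispatcher, handlers_for and dispatch are hypothetical names. A callback counts as a handler for an event when it has a callable attribute with that name. A non-None return value replaces the control object, following the return contract given in the event list.

```python
from typing import Any


class TrainerCallback:
    """Stand-in for the base class: only the default `name` property."""

    @property
    def name(self) -> str:
        return type(self).__name__


class CallbackDispatcher:
    """Hypothetical dispatcher with a lazily built event -> callbacks index."""

    def __init__(self, callbacks: list[TrainerCallback]):
        self.callbacks = list(callbacks)
        self._index: dict[str, list[TrainerCallback]] = {}

    def handlers_for(self, event: str) -> list[TrainerCallback]:
        # Scan the callbacks once per event name, on first use, then cache.
        if event not in self._index:
            self._index[event] = [
                cb for cb in self.callbacks if callable(getattr(cb, event, None))
            ]
        return self._index[event]

    def dispatch(self, event: str, args: Any, state: Any, control: Any, **kwargs: Any) -> Any:
        for cb in self.handlers_for(event):
            result = getattr(cb, event)(args, state, control, **kwargs)
            # None means "no change"; anything else replaces the control object.
            if result is not None:
                control = result
        return control


class StepCounter(TrainerCallback):
    def __init__(self) -> None:
        self.steps = 0

    def on_step_end(self, args, state, control, **kwargs):
        self.steps += 1


class NoHandlers(TrainerCallback):
    pass  # defines no event methods, so it is never invoked


counter = StepCounter()
dispatcher = CallbackDispatcher([counter, NoHandlers()])
for _ in range(3):
    dispatcher.dispatch("on_step_end", args=None, state=None, control=None)

print(counter.steps)  # 3
print([cb.name for cb in dispatcher.handlers_for("on_step_end")])  # ['StepCounter']
```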

Available events (each receives args, state, control, **kwargs and may return None or an updated TrainerControl):

on_init_end          - After trainer initialization
on_train_begin       - Before training loop starts
on_train_end         - After training loop ends
on_epoch_begin       - Before each epoch
on_epoch_end         - After each epoch
on_step_begin        - Before each training step
on_step_end          - After each training step
on_substep_end       - After each gradient-accumulation sub-step
on_forward_backward_begin - Before each forward+backward micro-step
                       (inside gradient accumulation loop, after data
                       loading; fires once per micro-batch)
on_forward_backward_end   - After each forward+backward micro-step
                       (before optimizer, grad clipping, LR scheduler;
                       fires once per micro-batch)
on_optimizer_step    - After optimizer.step()
on_pre_optimizer_step - Before optimizer.step()
on_evaluate          - After evaluation
on_predict           - After prediction (also receives metrics)
on_prediction_step   - After each prediction batch
on_save              - After checkpoint save
on_log               - After metric logging (receives logs kwarg)

Forgather extensions (not in HF Trainer):

on_log_step          - Called before on_log; receives the mutable logs dict
                       so callbacks can inject custom metrics before logging
on_train_metrics     - Called each training step with per-step metrics
                       (loss, grad_norm, tokens, etc.) for fine-grained
                       monitoring or adaptive control
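To illustrate on_log_step, the hypothetical callback below adds a smoothed loss to the mutable logs dict before on_log runs. It assumes the (state, logs, **kwargs) signature used by DefaultMetrics.on_log_step. It also assumes the state.is_world_process_zero attribute checked by the built-in callbacks. The base class and the state object are local stand-ins, so the sketch runs without Forgather installed.

```python
from types import SimpleNamespace
from typing import Optional


class TrainerCallback:
    """Stand-in for the base class (defines no event handlers)."""


class LossEMA(TrainerCallback):
    """Hypothetical callback: injects an exponential moving average of the loss."""

    def __init__(self, decay: float = 0.9):
        self.decay = decay
        self._ema: Optional[float] = None

    def on_log_step(self, state, logs, **kwargs):
        # Follow the built-in convention of only touching logs on rank 0.
        if not state.is_world_process_zero or "loss" not in logs:
            return
        loss = logs["loss"]
        if self._ema is None:
            self._ema = loss
        else:
            self._ema = self.decay * self._ema + (1 - self.decay) * loss
        logs["loss_ema"] = self._ema  # mutated in place, so on_log sees it later


state = SimpleNamespace(is_world_process_zero=True)
cb = LossEMA(decay=0.5)
first = {"loss": 4.0}
cb.on_log_step(state, first)
second = {"loss": 2.0}
cb.on_log_step(state, second)
print(first["loss_ema"], second["loss_ema"])  # 4.0 3.0
```

ProgressCallback only displays keys configured in step_columns. A new key such as loss_ema would therefore probably need a matching column entry to appear in the console. JsonLogger receives it regardless.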

kwargs always include: model, processing_class, optimizer, lr_scheduler, train_dataloader, eval_dataloader, trainer

Compatible with HuggingFace TrainerCallback for easier porting. See: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py

Identification: Each callback has a name property used for logging (e.g. when the trainer reports which callback requested an early stop). Subclasses inherit the default, which returns the class name, or can override name to provide a more descriptive label (useful when multiple instances of the same class are registered with different configurations).
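The sketch below overrides name on a hypothetical early-stop callback. Two instances with different thresholds can then be told apart in log messages. TrainerControl here is a stand-in dataclass that carries only should_training_stop, a flag name borrowed from the HuggingFace TrainerControl. That Forgather's control object uses the same field is an assumption based on the stated HF compatibility.

```python
from dataclasses import dataclass


class TrainerCallback:
    """Stand-in reproducing the base class's default `name` property."""

    @property
    def name(self) -> str:
        return type(self).__name__


@dataclass
class TrainerControl:
    """Stand-in: only the HF-style flag used below."""
    should_training_stop: bool = False


class StopAboveLoss(TrainerCallback):
    """Hypothetical divergence guard; several instances may be registered."""

    def __init__(self, threshold: float, metric: str = "loss"):
        self.threshold = threshold
        self.metric = metric

    @property
    def name(self) -> str:
        # Distinguish instances of the same class in log messages.
        return f"{type(self).__name__}({self.metric}>{self.threshold})"

    def on_log(self, args, state, control, logs, **kwargs):
        value = logs.get(self.metric)
        if value is not None and value > self.threshold:
            control.should_training_stop = True
            return control
        return None


cb = StopAboveLoss(threshold=10.0)
result = cb.on_log(None, None, TrainerControl(), logs={"loss": 12.5})
print(cb.name, result.should_training_stop)  # StopAboveLoss(loss>10.0) True
print(TrainerCallback().name)  # TrainerCallback
```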

Source code in src/forgather/ml/trainer/trainer_types.py
class TrainerCallback:
    """
    Base class for trainer event callbacks.

    Subclasses implement only the event methods they need. Any method not
    defined is simply never called for that callback. The trainer maintains
    a lazy index mapping event names to the callbacks that define handlers,
    so only relevant callbacks are invoked per event.

    Available events (each receives args, state, control, **kwargs and
    may return None or an updated TrainerControl):

        on_init_end          - After trainer initialization
        on_train_begin       - Before training loop starts
        on_train_end         - After training loop ends
        on_epoch_begin       - Before each epoch
        on_epoch_end         - After each epoch
        on_step_begin        - Before each training step
        on_step_end          - After each training step
        on_substep_end       - After each gradient-accumulation sub-step
        on_forward_backward_begin - Before each forward+backward micro-step
                               (inside gradient accumulation loop, after data
                               loading; fires once per micro-batch)
        on_forward_backward_end   - After each forward+backward micro-step
                               (before optimizer, grad clipping, LR scheduler;
                               fires once per micro-batch)
        on_optimizer_step    - After optimizer.step()
        on_pre_optimizer_step - Before optimizer.step()
        on_evaluate          - After evaluation
        on_predict           - After prediction (also receives metrics)
        on_prediction_step   - After each prediction batch
        on_save              - After checkpoint save
        on_log               - After metric logging (receives logs kwarg)

    Forgather extensions (not in HF Trainer):

        on_log_step          - Called before on_log; receives the mutable logs dict
                               so callbacks can inject custom metrics before logging
        on_train_metrics     - Called each training step with per-step metrics
                               (loss, grad_norm, tokens, etc.) for fine-grained
                               monitoring or adaptive control

    kwargs always include:
        model, processing_class, optimizer, lr_scheduler,
        train_dataloader, eval_dataloader, trainer

    Compatible with HuggingFace TrainerCallback for easier porting.
    See: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py

    Identification:
        Each callback has a `name` property used for logging (e.g. when
        the trainer reports which callback requested an early stop).
        Subclasses inherit the default, which returns the class name, or
        can override `name` to provide a more descriptive label (useful
        when multiple instances of the same class are registered with
        different configurations).
    """

    @property
    def name(self) -> str:
        """Human-readable identifier for this callback, used in log messages."""
        return type(self).__name__

name property

Human-readable identifier for this callback, used in log messages.


Built-in Callbacks

Default callbacks included automatically by all trainers.

forgather.ml.trainer.callbacks.DefaultMetrics

Bases: TrainerCallback

Compute derived performance metrics and inject them into logs.

Runs via on_log_step (before on_log), so computed values are available to all downstream loggers (ProgressCallback, TBLogger, etc.).

Computed metrics:

tok_per_sec -- tokens processed per wall-clock second between log steps.
mfu         -- Model FLOPs Utilization (requires peak_hardware_flops).
peak_mem    -- per-rank peak CUDA memory allocated (list of bytes), aliased from peak_mem_allocated for display formatting (default reduction: max across ranks).
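The hypothetical helper below works through the arithmetic behind tok_per_sec and mfu. It mirrors the calculations in DefaultMetrics.on_log_step, shown in the source listing, without the timing bookkeeping. The inputs are made-up numbers for a single RTX 4090. Its 165.2e12 FLOP/s peak comes from the constructor docs.

```python
from typing import Optional


def derived_metrics(
    tokens: int,
    wall_elapsed: float,
    delta_flos: float,
    train_elapsed: float,
    peak_hardware_flops: Optional[float],
) -> dict:
    """Hypothetical helper mirroring the arithmetic in DefaultMetrics.on_log_step."""
    out = {}
    if wall_elapsed > 0:
        # End-to-end throughput over the whole interval between log steps.
        out["tok_per_sec"] = round(tokens / wall_elapsed)
    if peak_hardware_flops is not None and train_elapsed > 0 and delta_flos > 0:
        # Achieved FLOP/s over time spent inside training steps only.
        out["mfu"] = (delta_flos / train_elapsed) / peak_hardware_flops
    return out


# 12 s between log steps, 10 s of it inside training steps.
m = derived_metrics(
    tokens=122_880,
    wall_elapsed=12.0,
    delta_flos=6.608e14,
    train_elapsed=10.0,
    peak_hardware_flops=165.2e12,
)
print(m["tok_per_sec"], round(m["mfu"], 3))  # 10240 0.4
```

The two metrics use different denominators. tok_per_sec divides by wall-clock time between log steps, so pauses for evaluation or checkpointing lower it. mfu divides by the time accumulated between on_step_begin and on_step_end, so it reflects utilization only while training steps are running.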

Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
class DefaultMetrics(TrainerCallback):
    """Compute derived performance metrics and inject them into logs.

    Runs via ``on_log_step`` (before ``on_log``), so computed values are
    available to all downstream loggers (ProgressCallback, TBLogger, etc.).

    Computed metrics:
        tok_per_sec   -- tokens processed per wall-clock second between log steps.
        mfu           -- Model FLOPs Utilization (requires *peak_hardware_flops*).
        peak_mem      -- per-rank peak CUDA memory allocated (list of bytes),
                         aliased from ``peak_mem_allocated`` for display
                         formatting (default reduction: max across ranks).
    """

    def __init__(
        self,
        peak_hardware_flops: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        peak_hardware_flops : float, optional
            Aggregate peak BF16 FLOP/s across all GPUs used in training,
            used to compute MFU. Must be the total across all ranks since
            ``total_flos`` accounts for tokens processed across all ranks.
            When ``None``, MFU is not computed.
            Example values (dense BF16, FP32 accumulate):
            Single RTX 4090: 165.2e12; Single RTX 3090: 71.2e12;
            4x RTX 4090: 660.8e12; A100 SXM: 312e12; H100 SXM: 989e12.
        """
        super().__init__()
        self.peak_hardware_flops = peak_hardware_flops

        # Wall-clock time at each log step, for tok/s end-to-end throughput.
        self._last_log_time: Optional[float] = None
        # Training step timing (on_step_begin/end), for FLOPs/MFU.
        # Records the start time of each training step.
        self._step_start_time: Optional[float] = None
        self._accumulated_train_time: float = 0.0
        self._last_total_flos: float = 0.0

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        self._last_log_time = None
        self._last_total_flos = state.total_flos
        self._accumulated_train_time = 0.0
        self._step_start_time = None

    def on_step_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        self._step_start_time = time.monotonic()

    def on_step_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self._step_start_time is not None:
            self._accumulated_train_time += time.monotonic() - self._step_start_time
            self._step_start_time = None

    def on_log_step(self, state, logs, **kwargs):
        if not state.is_world_process_zero:
            return

        now = time.monotonic()

        if "tokens" in logs:
            # tok/s: wall-clock throughput between log steps
            if self._last_log_time is not None:
                wall_elapsed = now - self._last_log_time
                if wall_elapsed > 0:
                    logs["tok_per_sec"] = round(logs["tokens"] / wall_elapsed)

            # MFU: hardware utilization during forward/backward only
            train_elapsed = self._accumulated_train_time
            if self.peak_hardware_flops is not None and "total_flos" in logs:
                if train_elapsed > 0:
                    delta_flos = logs["total_flos"] - self._last_total_flos
                    if delta_flos > 0:
                        achieved_flops = delta_flos / train_elapsed
                        logs["mfu"] = achieved_flops / self.peak_hardware_flops

        # Peak CUDA memory: alias for display formatting
        peak_mem = logs.get("peak_mem_allocated")
        if peak_mem is not None:
            logs["peak_mem"] = peak_mem

        # Reset interval tracking for the next log period
        self._last_log_time = now
        self._accumulated_train_time = 0.0
        self._last_total_flos = logs.get("total_flos", self._last_total_flos)

__init__(peak_hardware_flops=None)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| peak_hardware_flops | float | Aggregate peak BF16 FLOP/s across all GPUs used in training, used to compute MFU. Must be the total across all ranks since total_flos accounts for tokens processed across all ranks. When None, MFU is not computed. Example values (dense BF16, FP32 accumulate): Single RTX 4090: 165.2e12; Single RTX 3090: 71.2e12; 4x RTX 4090: 660.8e12; A100 SXM: 312e12; H100 SXM: 989e12. | None |
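peak_hardware_flops must be the aggregate across all ranks, so each per-GPU figure listed in the parameter description has to be multiplied by the number of GPUs. The helper and lookup table below are a hypothetical convenience, not part of Forgather. The values are the ones quoted in the parameter description.

```python
# Per-GPU dense BF16 peaks (FP32 accumulate), as quoted in the parameter docs.
PER_GPU_BF16_PEAK = {
    "RTX 4090": 165.2e12,
    "RTX 3090": 71.2e12,
    "A100 SXM": 312e12,
    "H100 SXM": 989e12,
}


def aggregate_peak_flops(gpu: str, world_size: int) -> float:
    """Hypothetical helper: total peak FLOP/s across all ranks."""
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    return PER_GPU_BF16_PEAK[gpu] * world_size


peak = aggregate_peak_flops("RTX 4090", 4)
print(f"{peak:.4g}")  # 6.608e+14
```

The result would then be passed as DefaultMetrics(peak_hardware_flops=peak). It matches the "4x RTX 4090: 660.8e12" entry above.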

forgather.ml.trainer.callbacks.ProgressCallback

Bases: TrainerCallback

A TQDM progress-bar callback class based upon: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py

Controls which metrics are displayed in console logs during training via configurable column specifications. All metrics are still logged to JsonLogger regardless of display settings.

Derived performance metrics (tok/s, MFU, peak_mem) are computed by DefaultMetrics via on_log_step and are available in the logs dict by the time on_log fires.

Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
class ProgressCallback(TrainerCallback):
    """
    A TQDM progress-bar callback class based upon:
    https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_callback.py

    Controls which metrics are displayed in console logs during training
    via configurable column specifications.  All metrics are still logged
    to JsonLogger regardless of display settings.

    Derived performance metrics (tok/s, MFU, peak_mem) are computed by
    ``DefaultMetrics`` via ``on_log_step`` and are available in the logs
    dict by the time ``on_log`` fires.
    """

    def __init__(
        self,
        use_tqdm: Optional[bool] = None,
        output_stream: Optional[OutputStream] = None,
        step_columns: Optional[dict] = None,
        final_metrics: Optional[dict] = None,
        header_interval: int = 20,
    ):
        """
        Parameters
        ----------
        use_tqdm : bool, optional
            If ``True``, use TQDM; if ``False``, use logging; if ``None``,
            auto-select.
        output_stream : OutputStream, optional
            The output stream to use when not using TQDM.
        step_columns : dict, optional
            Column spec overrides merged with ``default_step_columns()``.
            Each key maps a metric name to a dict of ``ColumnSpec`` fields
            (``label``, ``width``, ``fmt``). Set a key to ``None`` to erase
            that column from the defaults. When ``None``, uses defaults
            unmodified. Column order follows insertion order of the merged
            result. Only columns whose key appears in the current log entry
            are shown.
        final_metrics : dict, optional
            Final metric spec overrides merged with
            ``default_final_metrics()``. Each key maps a metric name to a
            dict of ``FinalMetricSpec`` fields (``label``, ``fmt``,
            ``suffix``). Set a key to ``None`` to erase that metric. When
            ``None``, uses defaults unmodified.
        header_interval : int, optional
            Print a column header row every this many log steps, and also
            whenever the set of active columns changes. Default is ``20``.
        """
        super().__init__()
        self.train_progress_bar = None
        self.eval_progress_bar = None
        self.header_interval = header_interval

        # Merge step_columns overrides with defaults, then convert to ColumnSpec list.
        merged_columns = _merge_spec_dicts(default_step_columns(), step_columns)
        self.step_columns: list[ColumnSpec] = _normalize_columns(merged_columns)

        # Merge final_metrics overrides with defaults, then convert to FinalMetricSpec list.
        merged_final = _merge_spec_dicts(default_final_metrics(), final_metrics)
        self.final_metrics: list[FinalMetricSpec] = _normalize_final_metrics(
            merged_final
        )

        self._column_keys: frozenset[str] = frozenset(c.key for c in self.step_columns)

        # Column header tracking: print header every header_interval rows and
        # whenever the active column set changes.
        self._log_row_count: int = 0
        self._last_active_keys: frozenset[str] = frozenset()

        # Remember actual eval steps from previous run for accurate progress bar
        self._last_eval_steps: Optional[int] = None

        if use_tqdm is None:
            self.use_tqdm = get_env_type() != "file"
        else:
            self.use_tqdm = use_tqdm

        if not self.use_tqdm:
            self.logger = logging.getLogger("progress_logger")
            self.logger.setLevel(logging.INFO)
            self.logger.propagate = False

            console_handler = logging.StreamHandler(
                self._get_output_stream(output_stream)
            )
            log_format = logging.Formatter(
                fmt="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
            )
            console_handler.setFormatter(log_format)
            self.logger.addHandler(console_handler)

    @staticmethod
    def _get_output_stream(output_stream: Optional[OutputStream]) -> TextIOBase:
        if output_stream is None:
            # sys.stdout satisfies the TextIOBase interface at runtime
            return cast(TextIOBase, sys.stdout)  # type: ignore[return-value]
        elif isinstance(output_stream, TextIOBase):
            return output_stream
        else:
            assert isinstance(output_stream, str)
            if output_stream == "stderr":
                # sys.stderr satisfies the TextIOBase interface at runtime
                return cast(TextIOBase, sys.stderr)  # type: ignore[return-value]
            elif output_stream == "stdout":
                # sys.stdout satisfies the TextIOBase interface at runtime
                return cast(TextIOBase, sys.stdout)  # type: ignore[return-value]
            else:
                raise ValueError("Must be one of 'stderr' or 'stdout'")

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        self.last_step = state.global_step
        self._log_row_count = 0
        self._last_active_keys = frozenset()
        if self.use_tqdm:
            self.train_progress_bar = tqdm(
                initial=state.global_step,
                smoothing=0.03,
                total=state.max_steps,
                dynamic_ncols=True,
            )

    def on_train_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.use_tqdm:
            if self.train_progress_bar is not None:
                self.train_progress_bar.close()
            self.train_progress_bar = None

    def on_step_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.use_tqdm:
            self.train_progress_bar.update(state.global_step - self.last_step)
        self.last_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.use_tqdm:
            if self.eval_progress_bar is None:
                if self._last_eval_steps is not None:
                    total = self._last_eval_steps
                else:
                    max_eval_steps = getattr(state, "max_eval_steps", -1)
                    total = max(len(eval_dataloader), max_eval_steps, 1)
                self.eval_progress_bar = tqdm(
                    initial=1,
                    total=total,
                    leave=self.train_progress_bar is None,
                    dynamic_ncols=True,
                )
            else:
                self.eval_progress_bar.update(1)

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.use_tqdm:
            if self.eval_progress_bar is not None:
                # Remember actual step count for next eval's progress bar
                self._last_eval_steps = self.eval_progress_bar.n
                self.eval_progress_bar.write(
                    format_timestamp() + format_eval_log(state, metrics)
                )
                self.eval_progress_bar.close()
                self.eval_progress_bar = None
        else:
            self.logger.info(format_eval_log(state, metrics))

    def on_log(self, args, state, control, logs, **kwargs):
        if not state.is_world_process_zero:
            return

        # Final training metrics get their own formatted summary
        if "train_runtime" in logs:
            summary = format_final_metrics(logs, self.final_metrics)
            if self.use_tqdm:
                if self.train_progress_bar is not None:
                    self.train_progress_bar.write(format_timestamp() + summary)
                else:
                    tqdm.write(format_timestamp() + summary)
            else:
                self.logger.info(summary)
            return

        # Computed metrics (tok_per_sec, mfu, peak_mem, etc.) are already
        # in logs, injected by DefaultMetrics via on_log_step.

        # Filter to only keys present in step_columns
        display_metrics = {k: v for k, v in logs.items() if k in self._column_keys}

        # Print a column header when the interval fires or the active column set changes
        active_keys = frozenset(display_metrics)
        if (
            self._log_row_count % self.header_interval == 0
            or active_keys != self._last_active_keys
        ):
            header_line = format_train_header(self.step_columns, display_metrics)
            if self.use_tqdm:
                if self.train_progress_bar is not None:
                    self.train_progress_bar.write(format_timestamp() + header_line)
            else:
                self.logger.info(header_line)
            self._last_active_keys = active_keys
        self._log_row_count += 1

        if self.use_tqdm:
            if self.train_progress_bar is not None:
                # Update steps, if max steps changes
                if self.train_progress_bar.total != state.max_steps:
                    self.train_progress_bar.total = state.max_steps
                    self.train_progress_bar.refresh()
                self.train_progress_bar.write(
                    format_timestamp()
                    + format_train_log(state, self.step_columns, display_metrics)
                )
        else:
            self.logger.info(
                format_train_log(state, self.step_columns, display_metrics)
            )
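The header logic in on_log can be isolated as follows. This hypothetical helper replays the same rule over a sequence of log rows: a header is printed when the row counter is a multiple of header_interval, or when the set of displayed keys differs from the set at the last header. It returns the indices of the rows that get a header.

```python
def header_rows(active_sets: list[frozenset[str]], header_interval: int = 20) -> list[int]:
    """Hypothetical helper: indices of log rows preceded by a column header."""
    rows = []
    last_active: frozenset[str] = frozenset()
    for row_count, active in enumerate(active_sets):
        if row_count % header_interval == 0 or active != last_active:
            rows.append(row_count)
            last_active = active
    return rows


train = frozenset({"loss", "grad_norm"})
with_mfu = frozenset({"loss", "grad_norm", "mfu"})
seq = [train, train, train, train, with_mfu, with_mfu, with_mfu]
print(header_rows(seq, header_interval=3))  # [0, 3, 4, 6]
```

The counter starts at 0, and on_train_begin resets it. The first logged row of each run therefore always gets a header.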

__init__(use_tqdm=None, output_stream=None, step_columns=None, final_metrics=None, header_interval=20)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| use_tqdm | bool | If True, use TQDM; if False, use logging; if None, auto-select: TQDM is used unless get_env_type() returns "file". | None |
| output_stream | OutputStream | The output stream to use when not using TQDM. This is a TextIOBase instance, "stdout", or "stderr"; any other string raises ValueError. When None, sys.stdout is used. | None |
| step_columns | dict | Column spec overrides merged with default_step_columns(). Each key maps a metric name to a dict of ColumnSpec fields (label, width, fmt). Set a key to None to erase that column from the defaults. When None, uses defaults unmodified. Column order follows insertion order of the merged result. Only columns whose key appears in the current log entry are shown. | None |
| final_metrics | dict | Final metric spec overrides merged with default_final_metrics(). Each key maps a metric name to a dict of FinalMetricSpec fields (label, fmt, suffix). Set a key to None to erase that metric. When None, uses defaults unmodified. | None |
| header_interval | int | Print a column header row every this many log steps, and also whenever the set of active columns changes. | 20 |
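The documented override rules for step_columns can be sketched as follows. Overrides are merged with the defaults, None erases a column, order follows the merged dict, and only keys present in the current log entry are shown. The real merge is done by _merge_spec_dicts, whose code is not shown here. This sketch assumes an override for an existing key is merged field by field, and that new keys are appended after the defaults. The default column keys and specs below are hypothetical.

```python
from typing import Optional


def merge_spec_dicts(defaults: dict, overrides: Optional[dict]) -> dict:
    """Sketch of the documented override rules (not Forgather's code)."""
    merged = {key: dict(spec) for key, spec in defaults.items()}
    for key, spec in (overrides or {}).items():
        if spec is None:
            merged.pop(key, None)  # None erases a default column
        else:
            # Assumption: fields are merged into an existing spec.
            merged[key] = {**merged.get(key, {}), **spec}
    return merged


# Hypothetical defaults; the real ones come from default_step_columns().
defaults = {
    "loss": {"label": "loss", "width": 8, "fmt": ".4f"},
    "grad_norm": {"label": "gnorm", "width": 7, "fmt": ".3f"},
    "learning_rate": {"label": "lr", "width": 9, "fmt": ".2e"},
}
overrides = {
    "grad_norm": None,  # hide this column
    "loss": {"fmt": ".3f"},  # tweak one field
    "loss_ema": {"label": "ema", "width": 8, "fmt": ".3f"},  # add a column
}
columns = merge_spec_dicts(defaults, overrides)
print(list(columns))  # ['loss', 'learning_rate', 'loss_ema']

# Only columns whose key is present in the current log entry are displayed.
logs = {"loss": 2.5, "loss_ema": 2.7, "epoch": 0.1}
print([k for k in columns if k in logs])  # ['loss', 'loss_ema']
```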

forgather.ml.trainer.callbacks.InfoCallback

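Bases: TrainerCallback

Logs a summary of the training configuration to stdout on rank 0 when training begins. With verbose=True, additional details are also logged at DEBUG level; otherwise only the INFO-level summary is shown.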

Source code in src/forgather/ml/trainer/callbacks/default_callbacks.py
class InfoCallback(TrainerCallback):
    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self.logger = logging.getLogger("info_logger")
        if verbose:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.INFO)

        self.logger.propagate = False

        console_handler = logging.StreamHandler(sys.stdout)
        log_format = logging.Formatter(fmt="[%(levelname)s|%(name)s] %(message)s")
        console_handler.setFormatter(log_format)
        self.logger.addHandler(console_handler)

    def on_train_begin(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not state.is_world_process_zero:
            return
        info, extra_info = format_train_info(args, state, control, **kwargs)
        self.logger.info("\n" + format_mapping(info))
        self.logger.debug("\n" + format_mapping(extra_info))

Job Control

forgather.ml.trainer.callbacks.TrainerControlCallback

Bases: TrainerCallback

Callback that enables external control of training jobs via HTTP API.

Features:

- Graceful stop: Stop training cleanly after the current step
- Save checkpoint: Trigger a checkpoint save (with evaluation if needed)
- Save and stop: Save a checkpoint, then stop training
- Status queries: Get the current training status

Only rank 0 runs the HTTP server. Commands are broadcast to all ranks via torch.distributed for coordination.
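The following single-process sketch shows how the features listed above might map onto control flags. Commands are applied in on_step_end, which matches "stop training cleanly after the current step". The command names, the PendingCommandCallback class and the TrainerControl stand-in are all assumptions; the stand-in uses the HF-style should_save and should_training_stop fields. The sketch omits the HTTP server, authentication, the "with evaluation if needed" behaviour, and the torch.distributed broadcast.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainerControl:
    """Stand-in with HF-style flags (assumed to match Forgather's)."""
    should_training_stop: bool = False
    should_save: bool = False


class PendingCommandCallback:
    """Hypothetical sketch: queues one command and applies it after a step."""

    COMMANDS = {"stop", "save", "save_and_stop"}  # hypothetical command names

    def __init__(self) -> None:
        self.pending: Optional[str] = None

    def request(self, command: str) -> None:
        if command not in self.COMMANDS:
            raise ValueError(f"unknown command: {command}")
        self.pending = command

    def on_step_end(self, args, state, control, **kwargs):
        # Consume the pending command so it is applied exactly once.
        command, self.pending = self.pending, None
        if command is None:
            return None
        if command in ("save", "save_and_stop"):
            control.should_save = True
        if command in ("stop", "save_and_stop"):
            control.should_training_stop = True
        return control


cb = PendingCommandCallback()
cb.request("save_and_stop")
control = cb.on_step_end(None, None, TrainerControl())
print(control.should_save, control.should_training_stop)  # True True
```

In a multi-rank job, every rank has to set the same flags on the same step. This is likely why rank 0 broadcasts each command rather than letting ranks act on it independently.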

Source code in src/forgather/ml/trainer/callbacks/control_callback.py
class TrainerControlCallback(TrainerCallback):
    """
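    Callback that enables external control of training jobs via an HTTP API.

    Features:
    - Graceful stop: Stop training cleanly at the next logging step
    - Save checkpoint: Trigger a checkpoint save (preceded by evaluation when
      load_best_model_at_end is set and eval_strategy is not "no")
    - Save and stop: Save a checkpoint, then stop training
    - Abort: Stop training without saving (needs a TrainerControl with
      should_abort_without_save; otherwise degrades to a graceful stop)
    - Status queries: Report job id, step, epoch, max_steps and latest logs

    Commands are polled only in on_log, so they take effect at logging steps.
    Only rank 0 runs the HTTP server; commands are broadcast to all ranks
    via torch.distributed for coordination.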
    """

    def __init__(
        self,
        job_id: Optional[str] = None,
        port: Optional[int] = None,
        enable_http: Optional[bool] = None,
        host: Optional[str] = None,
        auth_token: Optional[str] = None,
        disable_auth: bool = False,
    ):
        """
        Initialize the control callback.

        Parameters
        ----------
        job_id : str, optional
            Unique identifier for this training job. Auto-generated if ``None``.
        port : int, optional
            HTTP server port. Auto-selected if ``None``.
        enable_http : bool, optional
            Whether to enable HTTP server. Auto-detected based on ``aiohttp``
            availability.
        host : str, optional
            Bind address. Defaults to ``127.0.0.1`` (loopback only). Pass
            ``"0.0.0.0"`` to expose the control endpoint on every interface;
            a warning is logged in that case.
        auth_token : str, optional
            Pre-shared bearer token. Generated via ``secrets.token_hex(32)``
            when ``None`` (the common case). Persisted to
            ``<forgather_config_dir>/jobs/{job_id}/auth_token`` at mode ``0o600`` so
            local clients (CLI, server proxy) can read it back.
        disable_auth : bool, optional
            Skip bearer-token enforcement on /control, /status, /jobs.
            Off by default; the trainer logs a warning when this is set.
        """
        super().__init__()

        if enable_http is None:
            enable_http = AIOHTTP_AVAILABLE

        if enable_http and not AIOHTTP_AVAILABLE:
            logger.warning("aiohttp not available, disabling HTTP control server")
            enable_http = False

        self.enable_http = enable_http
        self.job_id = job_id or self._generate_job_id()
        self.port = port
        self.host = host if host is not None else "127.0.0.1"
        self.control_dir = Path(forgather_config_dir()) / "jobs" / self.job_id

        # Auth state. Token is set lazily on rank 0 in on_train_begin so we
        # don't burn entropy in non-rank-0 processes that never serve.
        self.disable_auth = disable_auth
        self.auth_token: Optional[str] = auth_token

        # Command handling
        self.command_queue: Optional[asyncio.Queue] = None
        self.pending_commands: List[ControlCommand] = []

        # HTTP server
        self.server_task: Optional[asyncio.Task] = None
        self.server_runner: Optional[aiohttp.web.AppRunner] = None
        self.server_thread: Optional[threading.Thread] = None
        self.event_loop: Optional[asyncio.AbstractEventLoop] = None

        # State tracking
        self.trainer_args: Optional[MinimalTrainingArguments] = None
        self.trainer_state: Optional[TrainerState] = None
        self.last_status: Dict[str, Any] = {}

    def _generate_job_id(self) -> str:
        """Generate a unique job ID."""
        import platform

        timestamp = int(time.time())
        hostname = platform.node()
        pid = os.getpid()
        return f"job_{timestamp}_{hostname}_{pid}"

    def _find_available_port(
        self, start_port: int = 8900, max_attempts: int = 100
    ) -> int:
        """Find an available port for the HTTP server."""
        for i in range(max_attempts):
            port = start_port + i
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                continue
        raise RuntimeError(
            f"Could not find available port in range {start_port}-{start_port + max_attempts - 1}"
        )

    def _ensure_control_dir(self) -> None:
        """Create control_dir at 0o700 (best-effort chmod after mkdir)."""
        self.control_dir.mkdir(parents=True, exist_ok=True)
        try:
            os.chmod(self.control_dir, 0o700)
        except OSError:
            pass

    def _write_auth_token(self) -> None:
        """Persist the auth token at mode 0o600 next to endpoint.json.

        Local CLI/server clients read this file to attach the bearer header.
        Keeping it out of endpoint.json avoids duplicating the secret.
        """
        if self.disable_auth or not self.auth_token:
            return
        self._ensure_control_dir()
        token_file = self.control_dir / "auth_token"
        tmp = token_file.with_suffix(token_file.suffix + ".tmp")
        with open(tmp, "w") as f:
            f.write(self.auth_token)
            f.flush()
            os.fsync(f.fileno())
        try:
            os.chmod(tmp, 0o600)
        except OSError:
            pass
        os.replace(tmp, token_file)

    def _write_endpoint_file(self, port: int):
        """Write endpoint information for service discovery."""
        self._ensure_control_dir()
        endpoint_file = self.control_dir / "endpoint.json"

        # Also publish the run's output_dir and logging_dir so external
        # tooling (e.g. the Forgather server) can locate checkpoints, log
        # files, and stdout/stderr without having to re-materialize the
        # config. Both come from the trainer's already-resolved args.
        output_dir = None
        logging_dir = None
        if self.trainer_args is not None:
            raw_output = getattr(self.trainer_args, "output_dir", None)
            raw_logging = getattr(self.trainer_args, "logging_dir", None)
            output_dir = os.path.abspath(raw_output) if raw_output else None
            logging_dir = os.path.abspath(raw_logging) if raw_logging else None

        # ``host`` is the actual bind address. Earlier versions wrote
        # ``platform.node()`` (the FQDN), which is wrong for a loopback
        # bind — clients hitting the FQDN would get connection-refused.
        endpoint_info = {
            "job_id": self.job_id,
            "host": self.host,
            "port": port,
            "pid": os.getpid(),
            "started_at": time.time(),
            "output_dir": output_dir,
            "logging_dir": logging_dir,
        }

        tmp = endpoint_file.with_suffix(endpoint_file.suffix + ".tmp")
        with open(tmp, "w") as f:
            json.dump(endpoint_info, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        try:
            os.chmod(tmp, 0o600)
        except OSError:
            pass
        os.replace(tmp, endpoint_file)

        logger.info(
            f"Trainer control endpoint: http://{self.host}:{port}/jobs/{self.job_id}"
        )

    def _setup_signal_handler(self):
        """Setup signal handler for lightweight command notification."""

        def signal_handler(signum, frame):
            if self.event_loop and not self.event_loop.is_closed():
                self.event_loop.call_soon_threadsafe(self._check_for_signals)

        signal.signal(signal.SIGUSR1, signal_handler)

    def _check_for_signals(self):
        """Check for signal-based commands (placeholder for future enhancement)."""
        pass

    def _get_device(self):
        """Get the correct device for tensor operations in distributed training."""
        if (
            self.trainer_args
            and hasattr(self.trainer_args, "device")
            and self.trainer_args.device is not None
        ):
            return self.trainer_args.device
        elif torch.cuda.is_available():
            return torch.cuda.current_device()
        else:
            return "cpu"

    def _make_auth_middleware(self):
        """Bearer-token middleware. Returns 401 on missing/invalid token.

        Uses ``hmac.compare_digest`` so response timing does not reveal how
        much of a guessed token matches (the token's length may still be
        observable). The realm string is mostly cosmetic but tells curl
        users which endpoint refused them.
        """

        @aiohttp.web.middleware
        async def auth_middleware(request, handler):
            if self.disable_auth or not self.auth_token:
                return await handler(request)
            header = request.headers.get("Authorization", "")
            expected_prefix = "Bearer "
            if not header.startswith(expected_prefix) or not hmac.compare_digest(
                header[len(expected_prefix) :], self.auth_token
            ):
                return aiohttp.web.json_response(
                    {"detail": "authentication required"},
                    status=401,
                    headers={"WWW-Authenticate": 'Bearer realm="forgather-trainer"'},
                )
            return await handler(request)

        return auth_middleware

    async def _run_http_server(self):
        """Run the HTTP server in async mode."""
        try:
            self.command_queue = asyncio.Queue()

            middlewares = []
            if not self.disable_auth and self.auth_token:
                middlewares.append(self._make_auth_middleware())

            app = aiohttp.web.Application(middlewares=middlewares)
            app.router.add_post(
                f"/jobs/{self.job_id}/control", self._handle_control_request
            )
            app.router.add_get(
                f"/jobs/{self.job_id}/status", self._handle_status_request
            )
            app.router.add_get("/jobs", self._handle_list_jobs)

            # access_log=None suppresses aiohttp's per-request INFO access log.
            # The Forgather server polls /status every few seconds, so those
            # lines would clutter training output with no useful signal.
            self.server_runner = aiohttp.web.AppRunner(app, access_log=None)
            await self.server_runner.setup()

            if self.port is None:
                self.port = self._find_available_port()

            site = aiohttp.web.TCPSite(self.server_runner, self.host, self.port)
            await site.start()

            self._write_endpoint_file(self.port)
            logger.info(f"Trainer control server started on port {self.port}")

            # Keep server running
            while True:
                await asyncio.sleep(1)

        except asyncio.CancelledError:
            logger.info("HTTP server shutting down")
            raise
        except Exception as e:
            logger.error(f"HTTP server error: {e}")
            raise

    async def _start_http_server(self):
        """Schedule the HTTP server as a cancellable background task."""
        self.server_task = asyncio.create_task(self._run_http_server())

    def _run_server_thread(self):
        """Run HTTP server in separate thread."""
        asyncio.set_event_loop(self.event_loop)
        assert self.event_loop is not None
        try:
            self.event_loop.run_forever()
        except Exception as e:
            logger.error(f"Server thread error: {e}")

    async def _handle_control_request(
        self, request: aiohttp.web.Request
    ) -> aiohttp.web.Response:
        """Handle incoming control commands."""
        try:
            data = await request.json()
            command = data.get("command")

            if command not in COMMAND_CODES:
                return aiohttp.web.json_response(
                    {
                        "error": f"Unknown command: {command}",
                        "valid_commands": list(COMMAND_CODES.keys()),
                    },
                    status=400,
                )

            control_command = ControlCommand(
                command=command, timestamp=time.time(), data=data.get("data", {})
            )

            command_queue = self.command_queue
            assert command_queue is not None
            await command_queue.put(control_command)

            return aiohttp.web.json_response(
                {
                    "status": "acknowledged",
                    "command": command,
                    "timestamp": control_command.timestamp,
                    "message": f"Command {command} queued for execution",
                }
            )

        except Exception as e:
            logger.error(f"Error handling control request: {e}")
            return aiohttp.web.json_response({"error": str(e)}, status=500)

    async def _handle_status_request(
        self, request: aiohttp.web.Request
    ) -> aiohttp.web.Response:
        """Handle status requests."""
        try:
            status = {
                "job_id": self.job_id,
                "status": "running",
                "timestamp": time.time(),
                **self.last_status,
            }

            if self.trainer_state:
                status.update(
                    {
                        "global_step": self.trainer_state.global_step,
                        "epoch": self.trainer_state.epoch,
                        "max_steps": self.trainer_state.max_steps,
                    }
                )

            return aiohttp.web.json_response(status)

        except Exception as e:
            logger.error(f"Error handling status request: {e}")
            return aiohttp.web.json_response({"error": str(e)}, status=500)

    async def _handle_list_jobs(
        self, request: aiohttp.web.Request
    ) -> aiohttp.web.Response:
        """Handle job listing requests."""
        try:
            jobs_dir = Path(forgather_config_dir()) / "jobs"
            jobs = []

            if jobs_dir.exists():
                for job_dir in jobs_dir.iterdir():
                    if job_dir.is_dir():
                        endpoint_file = job_dir / "endpoint.json"
                        if endpoint_file.exists():
                            try:
                                with open(endpoint_file) as f:
                                    job_info = json.load(f)
                                    jobs.append(job_info)
                            except Exception as e:
                                logger.warning(
                                    f"Could not read job info from {endpoint_file}: {e}"
                                )

            return aiohttp.web.json_response({"jobs": jobs})

        except Exception as e:
            logger.error(f"Error listing jobs: {e}")
            return aiohttp.web.json_response({"error": str(e)}, status=500)

    def _check_commands_non_blocking(self):
        """Check for new commands without blocking."""
        if not self.command_queue:
            return

        try:
            while True:
                try:
                    command = self.command_queue.get_nowait()
                    self.pending_commands.append(command)
                    logger.info(f"Received command: {command.command}")
                except asyncio.QueueEmpty:
                    break
        except Exception as e:
            logger.error(f"Error checking commands: {e}")

    def _pack_commands_for_broadcast(
        self, commands: List[ControlCommand]
    ) -> torch.Tensor:
        """Pack commands into tensor for broadcast."""
        if not commands:
            data = [0]
        else:
            # Pack: [num_commands, cmd1_code, cmd1_timestamp, cmd2_code, cmd2_timestamp, ...]
            data = [len(commands)]
            for cmd in commands:
                data.extend([COMMAND_CODES[cmd.command], int(cmd.timestamp)])

        # Create tensor on correct device for distributed communication
        device = self._get_device()
        return torch.tensor(data, dtype=torch.long, device=device)

    def _unpack_commands_from_broadcast(
        self, tensor: torch.Tensor
    ) -> List[ControlCommand]:
        """Unpack commands from broadcast tensor."""
        data = tensor.tolist()
        if not data or data[0] == 0:
            return []

        commands = []
        num_commands = data[0]
        idx = 1

        for _ in range(num_commands):
            command_code = data[idx]
            timestamp = data[idx + 1]
            idx += 2

            command = ControlCommand(
                command=COMMAND_NAMES[command_code], timestamp=float(timestamp)
            )
            commands.append(command)

        return commands

    def _broadcast_and_handle_commands(self, control: TrainerControl) -> TrainerControl:
        """Broadcast commands to all ranks and handle them."""
        commands_to_process = []

        if not torch.distributed.is_initialized():
            # Single process mode - just check for commands on rank 0
            if (
                self.trainer_state
                and self.trainer_state.is_world_process_zero
                and self.enable_http
            ):
                self._check_commands_non_blocking()
                commands_to_process = self.pending_commands.copy()
                self.pending_commands.clear()
        else:
            device = self._get_device()

            # Rank 0: Broadcast pending commands
            if self.trainer_state and self.trainer_state.is_world_process_zero:
                if self.enable_http:
                    self._check_commands_non_blocking()

                commands_tensor = self._pack_commands_for_broadcast(
                    self.pending_commands
                )

                # First broadcast the tensor size so other ranks can allocate correctly
                size_tensor = torch.tensor(
                    [commands_tensor.size(0)], dtype=torch.long, device=device
                )
                torch.distributed.broadcast(size_tensor, src=0)

                # Then broadcast the actual commands
                torch.distributed.broadcast(commands_tensor, src=0)

                commands_to_process = self.pending_commands.copy()
                self.pending_commands.clear()

            # All other ranks: Receive broadcast commands
            else:
                # First receive the size
                size_tensor = torch.tensor([0], dtype=torch.long, device=device)
                torch.distributed.broadcast(size_tensor, src=0)

                # Then receive the commands tensor with the correct size
                tensor_size = int(size_tensor.item())
                if tensor_size > 0:
                    commands_tensor = torch.zeros(
                        tensor_size, dtype=torch.long, device=device
                    )
                    torch.distributed.broadcast(commands_tensor, src=0)
                    commands_to_process = self._unpack_commands_from_broadcast(
                        commands_tensor
                    )

        # All ranks: Process commands
        for command in commands_to_process:
            control = self._apply_command(command, control)

        return control

    def _apply_command(
        self, command: ControlCommand, control: TrainerControl
    ) -> TrainerControl:
        """Apply a control command to the trainer control state."""
        logger.info(f"Applying command: {command.command}")

        if command.command == "graceful_stop":
            control.should_training_stop = True
            logger.info(
                "Graceful stop requested - training will stop after current step"
            )

        elif command.command == "save_checkpoint":
            control.should_save = True
            # If we're tracking best model, trigger evaluation first
            if (
                self.trainer_args
                and self.trainer_args.load_best_model_at_end
                and getattr(
                    self.trainer_args.eval_strategy,
                    "value",
                    self.trainer_args.eval_strategy,
                )
                != "no"
            ):
                control.should_evaluate = True
            logger.info("Checkpoint save requested")

        elif command.command == "save_and_stop":
            control.should_save = True
            if (
                self.trainer_args
                and self.trainer_args.load_best_model_at_end
                and getattr(
                    self.trainer_args.eval_strategy,
                    "value",
                    self.trainer_args.eval_strategy,
                )
                != "no"
            ):
                control.should_evaluate = True
            control.should_training_stop = True
            logger.info("Save and stop requested - will save checkpoint then stop")

        elif command.command == "abort":
            # Check if the TrainerControl supports the forgather extension
            if hasattr(control, "should_abort_without_save"):
                control.should_abort_without_save = True
                control.should_training_stop = True
                logger.info(
                    "Abort requested - training will stop WITHOUT saving checkpoint"
                )
            else:
                # Fallback for standard HF TrainerControl
                control.should_training_stop = True
                logger.warning(
                    "Abort requested but TrainerControl doesn't support abort_without_save - will stop gracefully"
                )

        return control

    # TrainerCallback interface methods

    def on_train_begin(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Initialize control system when training begins."""
        self.trainer_args = args
        self.trainer_state = state

        # Only rank 0 runs the HTTP server
        if state.is_world_process_zero and self.enable_http:
            try:
                # Generate the per-job bearer token (rank-0 only) and persist
                # it at 0o600 so local clients can read it back. Skip when
                # disable_auth is set; warn in that case so it's loud in
                # the log.
                loopback_hosts = {"127.0.0.1", "localhost", "::1"}
                if self.host not in loopback_hosts:
                    logger.warning(
                        "Trainer control endpoint binding to %s is exposed "
                        "beyond loopback; any user reaching this host:port "
                        "with the bearer token can save/abort the job.",
                        self.host,
                    )
                if self.disable_auth:
                    logger.warning(
                        "Trainer control endpoint started with auth DISABLED; "
                        "anyone who can reach %s:%s can save/abort this job.",
                        self.host,
                        self.port if self.port is not None else "<auto>",
                    )
                else:
                    if not self.auth_token:
                        self.auth_token = secrets.token_hex(32)
                    self._write_auth_token()

                self.event_loop = asyncio.new_event_loop()
                self.server_thread = threading.Thread(
                    target=self._run_server_thread, daemon=True
                )
                self.server_thread.start()

                # Give the thread time to start the loop, then schedule the server task
                time.sleep(0.1)
                asyncio.run_coroutine_threadsafe(
                    self._start_http_server(), self.event_loop
                ).result(timeout=5)

                logger.info(f"Trainer control system initialized for job {self.job_id}")

            except Exception as e:
                logger.error(f"Failed to start control system: {e}")
                self.enable_http = False

    def on_log(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[dict] = None,
        **kwargs,
    ):
        """Check for control commands on each log event."""
        self.trainer_state = state

        # Update status for queries
        if logs:
            self.last_status.update(logs)
            self.last_status["timestamp"] = time.time()

        # Check for and broadcast commands
        return self._broadcast_and_handle_commands(control)

    def on_train_end(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Clean up when training ends."""
        if state.is_world_process_zero and self.enable_http:
            try:
                event_loop = self.event_loop
                if event_loop is None or event_loop.is_closed():
                    return

                # Shut down the aiohttp runner and cancel the server task on
                # the loop, awaiting the task so the CancelledError is actually
                # delivered before the loop stops. If we only scheduled the
                # cancel and then stopped the loop, run_forever() would return
                # with the task still pending, producing "Task was destroyed
                # but it is pending!".
                async def _shutdown_coro():
                    if self.server_runner is not None:
                        try:
                            await self.server_runner.cleanup()
                        except Exception as e:
                            logger.warning(f"aiohttp runner cleanup error: {e}")
                    if self.server_task and not self.server_task.done():
                        self.server_task.cancel()
                        try:
                            await self.server_task
                        except (asyncio.CancelledError, Exception):
                            pass

                asyncio.run_coroutine_threadsafe(_shutdown_coro(), event_loop).result(
                    timeout=10
                )

                # Task is done; now it's safe to stop the loop.
                event_loop.call_soon_threadsafe(event_loop.stop)

                # Wait for the server thread to exit cleanly
                if self.server_thread:
                    self.server_thread.join(timeout=10)

                # Clean up endpoint file and the per-job directory so
                # the jobs dir doesn't accumulate empty job_* dirs
                # over time. rmdir is best-effort — leave the dir behind
                # if anything else is still in there (e.g. unread
                # control/*.json command files from a flaky shutdown).
                endpoint_file = self.control_dir / "endpoint.json"
                if endpoint_file.exists():
                    endpoint_file.unlink()
                token_file = self.control_dir / "auth_token"
                if token_file.exists():
                    try:
                        token_file.unlink()
                    except OSError:
                        pass
                try:
                    self.control_dir.rmdir()
                except OSError:
                    pass

                logger.info("Trainer control system shutdown complete")

            except Exception as e:
                logger.warning(f"Error during control system shutdown: {e}")
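
The auth middleware admits a request only when its Authorization header is `Bearer <token>` with exactly the configured token. The same decision can be written as a plain function, which makes the rules easy to test in isolation. The sketch below is one such function. `is_authorized` is a hypothetical helper, not part of Forgather. Unlike the listing above, it compares bytes rather than str.

```python
import hmac
from typing import Optional

BEARER_PREFIX = "Bearer "


def is_authorized(header: str, token: Optional[str], disable_auth: bool = False) -> bool:
    """Return True if a request with this Authorization header would be let through."""
    # No enforcement when auth is disabled or no token has been configured.
    if disable_auth or not token:
        return True
    if not header.startswith(BEARER_PREFIX):
        return False
    presented = header[len(BEARER_PREFIX):]
    # Compare as bytes: str arguments to compare_digest must be ASCII-only,
    # so a non-ASCII header would otherwise raise TypeError instead of failing.
    return hmac.compare_digest(presented.encode("utf-8"), token.encode("utf-8"))
```

The prefix match is case-sensitive, as in the middleware, so `bearer <token>` is rejected.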

__init__(job_id=None, port=None, enable_http=None, host=None, auth_token=None, disable_auth=False)

Initialize the control callback.

Parameters:

Name          Type  Default  Description
job_id        str   None     Unique identifier for this training job. If None, generated as job_{timestamp}_{hostname}_{pid}.
port          int   None     HTTP server port. If None, the first free port in 8900-8999 is used; a RuntimeError is raised if none is free.
enable_http   bool  None     Whether to run the HTTP server. If None, enabled exactly when aiohttp is importable. If True but aiohttp is missing, a warning is logged and the server is disabled.
host          str   None     Bind address. None means 127.0.0.1 (loopback only). Pass "0.0.0.0" to expose the control endpoint on every interface; a warning is logged for any non-loopback address.
auth_token    str   None     Pre-shared bearer token. When None (the common case), rank 0 generates one with secrets.token_hex(32) in on_train_begin. Unless auth is disabled, the token is persisted to <forgather_config_dir>/jobs/{job_id}/auth_token at mode 0o600 so local clients (CLI, server proxy) can read it back.
disable_auth  bool  False    Skip bearer-token enforcement on /jobs/{job_id}/control, /jobs/{job_id}/status and /jobs. Off by default; a warning is logged when the server starts with auth disabled.
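
`endpoint.json` and `auth_token` sit side by side in the job directory, so a local client needs nothing else to send a command. The sketch below uses only the standard library. It assumes the file layout written by `_write_endpoint_file` and `_write_auth_token`. `build_control_request` is hypothetical. The accepted command names come from the module's `COMMAND_CODES`; the callback acts on `graceful_stop`, `save_checkpoint`, `save_and_stop` and `abort`.

```python
import json
import urllib.request
from pathlib import Path


def build_control_request(job_dir: Path, command: str) -> urllib.request.Request:
    """Build an authenticated POST to /jobs/{job_id}/control from a job directory."""
    endpoint = json.loads((job_dir / "endpoint.json").read_text())
    host = endpoint["host"]
    # Assumption: the client runs on the trainer's machine, so a wildcard
    # bind is reached through loopback.
    if host == "0.0.0.0":
        host = "127.0.0.1"
    if ":" in host:  # IPv6 literals such as ::1 need brackets in a URL
        host = f"[{host}]"
    url = f"http://{host}:{endpoint['port']}/jobs/{endpoint['job_id']}/control"
    headers = {"Content-Type": "application/json"}
    token_file = job_dir / "auth_token"
    # auth_token is only written when auth is enabled.
    if token_file.exists():
        headers["Authorization"] = "Bearer " + token_file.read_text().strip()
    body = json.dumps({"command": command}).encode("utf-8")
    return urllib.request.Request(url, data=body, headers=headers, method="POST")


# Usage (needs a running trainer):
#   with urllib.request.urlopen(build_control_request(job_dir, "save_checkpoint")) as resp:
#       print(json.load(resp)["status"])  # the server replies "acknowledged"
```

An unknown command gets a 400 response listing the valid ones. A missing or wrong token gets a 401.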

on_train_begin(args, state, control, **kwargs)

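Initialize the control system when training begins. On rank 0, and only if HTTP is enabled, this warns about non-loopback binds or disabled auth. It then generates and persists the bearer token and starts the aiohttp server on an event loop in a daemon thread. Any startup failure is logged and disables HTTP control for the rest of the run.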
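
The server runs on its own event loop in a daemon thread, so network I/O never blocks the training loop. The sketch below reproduces the start/stop lifecycle used by on_train_begin and on_train_end. `placeholder_server` is a hypothetical stand-in for `_run_http_server`, so the sketch runs without aiohttp.

```python
import asyncio
import threading
from typing import Tuple


async def placeholder_server() -> None:
    """Stands in for _run_http_server: idles until it is cancelled."""
    while True:
        await asyncio.sleep(1)


async def _spawn_server() -> "asyncio.Task[None]":
    # Executes on the background loop, so the task belongs to that loop.
    return asyncio.create_task(placeholder_server())


async def _cancel_and_wait(task: "asyncio.Task[None]") -> None:
    task.cancel()
    try:
        # Awaiting delivers CancelledError inside the task before the loop
        # stops, avoiding "Task was destroyed but it is pending!".
        await task
    except asyncio.CancelledError:
        pass


def start_server_loop() -> Tuple[asyncio.AbstractEventLoop, threading.Thread, "asyncio.Task[None]"]:
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()
    # Blocks the caller (the training thread) until the task exists on the loop.
    task = asyncio.run_coroutine_threadsafe(_spawn_server(), loop).result(timeout=5)
    return loop, thread, task


def stop_server_loop(loop, thread, task) -> None:
    asyncio.run_coroutine_threadsafe(_cancel_and_wait(task), loop).result(timeout=5)
    loop.call_soon_threadsafe(loop.stop)  # safe now that the task has finished
    thread.join(timeout=5)
    loop.close()
```

The callback also sleeps 0.1 s before scheduling the server. The sketch omits that sleep because `run_coroutine_threadsafe` hands work over with `call_soon_threadsafe`, which queues the callback even if `run_forever` has not started yet.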


on_log(args, state, control, logs=None, **kwargs)

Check for control commands on each log event.

Source code in src/forgather/ml/trainer/callbacks/control_callback.py
def on_log(
    self,
    args: MinimalTrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    logs: Optional[dict] = None,
    **kwargs,
):
    """Check for control commands on each log event."""
    self.trainer_state = state

    # Update status for queries
    if logs:
        self.last_status.update(logs)
        self.last_status["timestamp"] = time.time()

    # Check for and broadcast commands
    return self._broadcast_and_handle_commands(control)

on_train_end(args, state, control, **kwargs)

Clean up when training ends.

Source code in src/forgather/ml/trainer/callbacks/control_callback.py
def on_train_end(
    self,
    args: MinimalTrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    **kwargs,
):
    """Clean up when training ends."""
    if state.is_world_process_zero and self.enable_http:
        try:
            event_loop = self.event_loop
            if event_loop is None or event_loop.is_closed():
                return

            # Shut down the aiohttp runner and cancel the server task on
            # the loop, awaiting the task so the CancelledError is actually
            # delivered before the loop stops. If we only scheduled the
            # cancel and then stopped the loop, run_forever() would return
            # with the task still pending, producing "Task was destroyed
            # but it is pending!".
            async def _shutdown_coro():
                if self.server_runner is not None:
                    try:
                        await self.server_runner.cleanup()
                    except Exception as e:
                        logger.warning(f"aiohttp runner cleanup error: {e}")
                if self.server_task and not self.server_task.done():
                    self.server_task.cancel()
                    try:
                        await self.server_task
                    except (asyncio.CancelledError, Exception):
                        pass

            asyncio.run_coroutine_threadsafe(_shutdown_coro(), event_loop).result(
                timeout=10
            )

            # Task is done; now it's safe to stop the loop.
            event_loop.call_soon_threadsafe(event_loop.stop)

            # Wait for the server thread to exit cleanly
            if self.server_thread:
                self.server_thread.join(timeout=10)

            # Clean up endpoint file and the per-job directory so
            # the jobs dir doesn't accumulate empty job_* dirs
            # over time. rmdir is best-effort — leave the dir behind
            # if anything else is still in there (e.g. unread
            # control/*.json command files from a flaky shutdown).
            endpoint_file = self.control_dir / "endpoint.json"
            if endpoint_file.exists():
                endpoint_file.unlink()
            token_file = self.control_dir / "auth_token"
            if token_file.exists():
                try:
                    token_file.unlink()
                except OSError:
                    pass
            try:
                self.control_dir.rmdir()
            except OSError:
                pass

            logger.info("Trainer control system shutdown complete")

        except Exception as e:
            logger.warning(f"Error during control system shutdown: {e}")
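The shutdown comment above explains the ordering. The server task is cancelled and awaited on the loop before the loop is stopped, so no task is left pending. Below is a stdlib-only sketch of that ordering. The long-running serve() coroutine and the _spawn_server helper are hypothetical stand-ins for the aiohttp server task.

```python
import asyncio
import threading

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()


async def serve() -> None:
    # Hypothetical stand-in for a server task that runs until cancelled.
    while True:
        await asyncio.sleep(3600)


async def _spawn_server() -> asyncio.Task:
    return asyncio.get_running_loop().create_task(serve())


server_task = asyncio.run_coroutine_threadsafe(_spawn_server(), loop).result(timeout=5)


async def shutdown() -> bool:
    """Cancel the server task and wait until the cancellation is delivered."""
    server_task.cancel()
    try:
        await server_task
    except asyncio.CancelledError:
        pass
    return server_task.cancelled()


# 1. Cancel and await the task on the loop, so it finishes before the loop stops.
was_cancelled = asyncio.run_coroutine_threadsafe(shutdown(), loop).result(timeout=5)
# 2. Only then stop the loop and join its thread; nothing is left pending.
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=5)
loop.close()
print(was_cancelled, server_task.done())  # True True
```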

Divergence Detection

forgather.ml.trainer.callbacks.DivergenceDetector

Bases: TrainerCallback, Stateful

Detects training divergence by comparing smoothed loss against its best observed value.

Maintains an exponential moving average (EMA) of the loss and tracks the EMA's running minimum. Triggers when the smoothed loss exceeds that minimum by at least a configurable threshold for patience consecutive observations.

Supports absolute threshold (smoothed - best >= threshold), relative threshold (smoothed >= best * factor), or both simultaneously (triggers on whichever fires first).

Also detects NaN/Inf loss values immediately; this check bypasses both patience and warmup.

Defaults are calibrated against real training runs where loss decreases from ~10 to ~3.8 then spikes to ~9.7 on divergence. With default settings (smoothing=0.3, threshold=1.0, patience=3), divergence is detected within 3 log entries (~96 training steps at 32-step log intervals) of the spike, with zero false positives on healthy runs.
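The detection delay in the calibration claim can be checked by hand. Assume, for illustration only, that the EMA has settled exactly at 3.8 and the raw loss jumps straight to 9.7 and stays there. This is an idealized spike, not data from those runs. Each observation then moves the smoothed loss 30% of the way toward 9.7. The first observation gives s1 = 0.3 × 9.7 + 0.7 × 3.8 = 5.57, which is already 1.77 above best, so patience=3 is satisfied on the third observation. The helper name below is hypothetical.

```python
def observations_until_trigger(
    best: float,
    spike: float,
    smoothing: float = 0.3,
    threshold: float = 1.0,
    patience: int = 3,
    max_observations: int = 1000,
) -> tuple[int, list[float]]:
    """Feed a constant post-spike loss into an EMA that starts at `best`.

    Returns the 1-based observation at which `patience` consecutive
    above-threshold observations are reached, plus the rounded EMA values.
    """
    smoothed = best
    consecutive = 0
    history: list[float] = []
    for n in range(1, max_observations + 1):
        smoothed = smoothing * spike + (1 - smoothing) * smoothed
        history.append(round(smoothed, 4))
        # `best` never improves here, because every value is above it.
        if smoothed - best >= threshold:
            consecutive += 1
            if consecutive >= patience:
                return n, history
        else:
            consecutive = 0
    raise RuntimeError("divergence was never flagged")


n, history = observations_until_trigger(best=3.8, spike=9.7)
print(n, history)  # 3 [5.57, 6.809, 7.6763]
print(n * 32)      # 96 training steps at a 32-step logging interval
```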

Examples:

>>> detector = DivergenceDetector(
...     smoothing=0.3,        # EMA alpha (higher = more responsive)
...     threshold=1.0,        # Absolute: stop if smoothed - best >= 1.0
...     patience=3,           # Require 3 consecutive observations
...     action="abort",
... )
>>> trainer = Trainer(..., callbacks=[detector])
>>> trainer.train()

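Using only a relative threshold (e.g., 50% increase from best). Because threshold defaults to 1.0, pass threshold=None to disable the absolute check; otherwise both conditions are active:

>>> detector = DivergenceDetector(
...     smoothing=0.3,
...     threshold=None,          # Disable the default absolute threshold
...     relative_threshold=1.5,  # Stop if smoothed >= 1.5 * best
...     patience=3,
...     action="stop",
... )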
Source code in src/forgather/ml/trainer/callbacks/divergence_detector.py
class DivergenceDetector(TrainerCallback, Stateful):
    """
    Detects training divergence by comparing smoothed loss against its best observed value.

    Maintains a smoothed loss (EMA) and tracks its running minimum. Triggers when
    the smoothed loss exceeds the baseline minimum by a configurable threshold,
    sustained for ``patience`` consecutive observations.

    Supports absolute threshold (smoothed - best >= threshold), relative threshold
    (smoothed >= best * factor), or both simultaneously (triggers on whichever
    fires first).

    Also detects NaN/Inf loss values immediately (no patience required).

    Defaults are calibrated against real training runs where loss decreases from
    ~10 to ~3.8 then spikes to ~9.7 on divergence. With default settings
    (smoothing=0.3, threshold=1.0, patience=3), divergence is detected within
    3 log entries (~96 training steps at 32-step log intervals) of the spike,
    with zero false positives on healthy runs.

    Examples
    --------
    >>> detector = DivergenceDetector(
    ...     smoothing=0.3,        # EMA alpha (higher = more responsive)
    ...     threshold=1.0,        # Absolute: stop if smoothed - best >= 1.0
    ...     patience=3,           # Require 3 consecutive observations
    ...     action="abort",
    ... )
    >>> trainer = Trainer(..., callbacks=[detector])
    >>> trainer.train()

    Using relative threshold (e.g., 50% increase from best):

    >>> detector = DivergenceDetector(
    ...     smoothing=0.3,
    ...     relative_threshold=1.5,  # Stop if smoothed >= 1.5 * best
    ...     patience=3,
    ...     action="stop",
    ... )
    """

    def __init__(
        self,
        smoothing: float = 0.3,
        threshold: float | None = 1.0,
        relative_threshold: float | None = None,
        patience: int = 3,
        warmup: int = 10,
        action: Literal["stop", "abort"] = "stop",
        use_eval_loss: bool = False,
        metric_key: str | None = None,
    ):
        """
        Initialize divergence detector.

        Parameters
        ----------
        smoothing : float, optional
            EMA alpha for smoothing raw loss (0–1). Higher = more responsive
            to recent values. Effective window ~ 1/alpha observations.
        threshold : float or None, optional
            Absolute divergence threshold. Triggers when
            ``(smoothed_loss - best_smoothed_loss) >= threshold``.
            Set to ``None`` to disable absolute threshold.
        relative_threshold : float or None, optional
            Relative divergence threshold. Triggers when
            ``smoothed_loss >= best_smoothed_loss * relative_threshold``.
            For example, ``1.5`` means "50% increase from best".
            Set to ``None`` to disable relative threshold.
        patience : int, optional
            Number of consecutive observations above threshold required
            before triggering. Higher values reduce false positives from
            transient spikes. Set to ``1`` for immediate triggering.
        warmup : int, optional
            Number of initial observations to skip before checking divergence.
            Avoids false positives from the high-loss early training phase.
        action : {"stop", "abort"}, optional
            What to do when divergence is detected. ``"stop"`` gracefully
            stops training (saves checkpoint first); ``"abort"`` stops
            immediately without saving.
        use_eval_loss : bool, optional
            If ``True``, monitor ``eval_loss``; if ``False``, monitor train
            loss. Defaults to ``False`` because train loss is logged much more
            frequently, enabling faster detection.
        metric_key : str, optional
            Custom metric key to monitor (overrides ``use_eval_loss``).
        """
        super().__init__()

        if not 0 < smoothing <= 1:
            raise ValueError(f"smoothing must be in (0, 1], got {smoothing}")
        if threshold is not None and threshold <= 0:
            raise ValueError(f"threshold must be > 0, got {threshold}")
        if relative_threshold is not None and relative_threshold <= 1:
            raise ValueError(
                f"relative_threshold must be > 1, got {relative_threshold}"
            )
        if threshold is None and relative_threshold is None:
            raise ValueError(
                "At least one of threshold or relative_threshold must be set"
            )
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if warmup < 0:
            raise ValueError(f"warmup must be >= 0, got {warmup}")

        self.smoothing = smoothing
        self.threshold = threshold
        self.relative_threshold = relative_threshold
        self.patience = patience
        self.warmup = warmup
        self.action = action
        self.use_eval_loss = use_eval_loss
        self.metric_key = metric_key

        # State
        self.smoothed_loss: float | None = None
        self.best_smoothed_loss: float | None = None
        self.observation_count: int = 0
        self.consecutive_above: int = 0

    def _get_metric(self, data):
        """Extract the monitored metric from a data dict. Returns (value, key_name)."""
        if self.metric_key:
            return data.get(self.metric_key), self.metric_key
        elif self.use_eval_loss:
            return data.get("eval_loss"), "eval_loss"
        else:
            return data.get("loss"), "loss"

    def _check_divergence(self, args, state, control, logs=None, metrics=None):
        """Check for divergence given metrics dict."""
        data = logs or metrics
        if not data:
            return control

        loss, key_name = self._get_metric(data)
        if loss is None:
            return control

        # Detect NaN/Inf immediately
        if math.isnan(loss) or math.isinf(loss):
            if state is None or state.is_world_process_zero:
                logger.error(
                    f"Training divergence detected! {key_name}={loss} (NaN/Inf)\n"
                    f"Action: {self.action}"
                )
            self._trigger_action(control)
            return control

        self.observation_count += 1

        # Update smoothed loss
        if self.smoothed_loss is None:
            self.smoothed_loss = loss
            self.best_smoothed_loss = loss
            if state is None or state.is_world_process_zero:
                logger.info(
                    f"DivergenceDetector initialized with {key_name}={loss:.4f}"
                )
            return control

        # Both are guaranteed non-None after the initialization branch above
        assert self.smoothed_loss is not None
        assert self.best_smoothed_loss is not None

        smoothed = self.smoothing * loss + (1 - self.smoothing) * self.smoothed_loss
        self.smoothed_loss = smoothed

        # Update best (only track improvements)
        best = self.best_smoothed_loss
        if smoothed < best:
            best = smoothed
            self.best_smoothed_loss = best

        # Skip warmup period
        if self.observation_count <= self.warmup:
            return control

        # Check thresholds
        above = False
        abs_divergence = smoothed - best
        rel_ratio = smoothed / best if best > 0 else 0.0

        if self.threshold is not None and abs_divergence >= self.threshold:
            above = True
        if self.relative_threshold is not None and rel_ratio >= self.relative_threshold:
            above = True

        logger.debug(
            f"Divergence detector: {key_name}={loss:.4f}, "
            f"smoothed={smoothed:.4f}, best={best:.4f}, "
            f"abs_div={abs_divergence:.4f}, rel={rel_ratio:.4f}, "
            f"consecutive={self.consecutive_above}/{self.patience}"
        )

        if above:
            self.consecutive_above += 1
            if self.consecutive_above >= self.patience:
                if state is None or state.is_world_process_zero:
                    parts = [
                        f"Training divergence detected! {key_name}={loss:.4f}",
                        f"Smoothed loss: {smoothed:.4f}",
                        f"Best smoothed loss: {best:.4f}",
                    ]
                    if self.threshold is not None:
                        parts.append(
                            f"Absolute divergence: {abs_divergence:.4f} "
                            f"(threshold: {self.threshold:.4f})"
                        )
                    if self.relative_threshold is not None:
                        parts.append(
                            f"Relative ratio: {rel_ratio:.4f} "
                            f"(threshold: {self.relative_threshold:.4f})"
                        )
                    parts.append(f"Action: {self.action}")
                    logger.error("\n".join(parts))

                self._trigger_action(control)
        else:
            self.consecutive_above = 0

        return control

    def _trigger_action(self, control):
        """Apply the configured action to the trainer control."""
        if self.action == "stop":
            control.should_training_stop = True
        elif self.action == "abort":
            control.should_training_stop = True
            if hasattr(control, "should_abort_without_save"):
                control.should_abort_without_save = True

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check for divergence when training metrics are logged."""
        return self._check_divergence(args, state, control, logs=logs)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Check for divergence when evaluation metrics are available."""
        return self._check_divergence(args, state, control, metrics=metrics)

    def state_dict(self):
        """Return callback state to save with checkpoint."""
        return {
            "smoothed_loss": self.smoothed_loss,
            "best_smoothed_loss": self.best_smoothed_loss,
            "observation_count": self.observation_count,
            "consecutive_above": self.consecutive_above,
        }

    def load_state_dict(self, state_dict):
        """Restore callback state from checkpoint."""
        self.smoothed_loss = state_dict["smoothed_loss"]
        self.best_smoothed_loss = state_dict["best_smoothed_loss"]
        self.observation_count = state_dict["observation_count"]
        self.consecutive_above = state_dict["consecutive_above"]
        if self.smoothed_loss is not None and self.best_smoothed_loss is not None:
            logger.debug(
                f"Restored DivergenceDetector state: "
                f"smoothed={self.smoothed_loss:.4f}, best={self.best_smoothed_loss:.4f}, "
                f"observations={self.observation_count}"
            )

__init__(smoothing=0.3, threshold=1.0, relative_threshold=None, patience=3, warmup=10, action='stop', use_eval_loss=False, metric_key=None)

Initialize divergence detector.

Parameters:

smoothing (float, default 0.3): EMA alpha used to smooth the raw loss; must be in (0, 1]. Higher values respond faster to recent values. The effective window is roughly 1/alpha observations.

threshold (float or None, default 1.0): Absolute divergence threshold; must be > 0. Triggers when (smoothed_loss - best_smoothed_loss) >= threshold. Set to None to disable the absolute check.

relative_threshold (float or None, default None): Relative divergence threshold; must be > 1. Triggers when smoothed_loss >= best_smoothed_loss * relative_threshold. For example, 1.5 means "50% increase from best". Set to None to disable the relative check. At least one of threshold and relative_threshold must be set.

patience (int, default 3): Number of consecutive above-threshold observations required before triggering; must be >= 1. Higher values reduce false positives from transient spikes. Set to 1 for immediate triggering.

warmup (int, default 10): Number of initial observations to skip before checking for divergence; must be >= 0. This avoids false positives during the high-loss early phase of training. The EMA and its best value are still updated during warmup.

action ("stop" or "abort", default "stop"): What to do when divergence is detected. "stop" gracefully stops training (saving a checkpoint first); "abort" stops immediately without saving.

use_eval_loss (bool, default False): If True, monitor eval_loss; if False, monitor the training loss. Defaults to False because the training loss is logged much more frequently, enabling faster detection.

metric_key (str or None, default None): Custom metric key to monitor; overrides use_eval_loss.

Violating any of the constraints above raises ValueError.

on_log(args, state, control, logs=None, **kwargs)

Check for divergence when training metrics are logged.

on_evaluate(args, state, control, metrics=None, **kwargs)

Check for divergence when evaluation metrics are available.

state_dict()

Return callback state to save with checkpoint.

load_state_dict(state_dict)

Restore callback state from checkpoint.

Logging

forgather.ml.trainer.callbacks.JsonLogger

Bases: TrainerCallback, Stateful

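A callback that writes training metrics to trainer_logs.json in args.logging_dir as a single JSON array.

On the main process (rank 0), each on_log or on_evaluate call appends one record containing a UTC Unix timestamp, global_step, epoch, and all reported metrics. Nothing is written when args.logging_dir is None. The closing bracket is written by close(), which is called from on_train_end, so the file is complete, valid JSON only after training ends.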
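The logger builds its JSON array incrementally. It writes "[\n" when the file is opened, prefixes every record after the first with ",\n", and writes "\n]" on close. Below is a stdlib-only sketch of that pattern using an in-memory buffer. The helper write_records is hypothetical.

```python
import io
import json


def write_records(records: list[dict]) -> str:
    buf = io.StringIO()
    buf.write("[\n")                     # opened in on_train_begin
    prefix = ""
    for record in records:               # one record per on_log / on_evaluate
        buf.write(prefix + json.dumps(record))
        prefix = ",\n"
    buf.write("\n]")                     # written by close()
    return buf.getvalue()


text = write_records([
    {"global_step": 32, "epoch": 0.1, "loss": 5.2},
    {"global_step": 64, "epoch": 0.2, "loss": 4.7},
])
print(json.loads(text)[-1]["loss"])      # 4.7
```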

Implements the Stateful protocol so that the log file path and last written step are saved with checkpoints. When training resumes from a checkpoint, the logger reopens the original file, truncates any entries recorded after the checkpoint step, and continues appending.
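On resume, records logged after the checkpoint step are dropped before appending continues. _parse_json_log is not shown in this section. The sketch below assumes the parser must tolerate a file whose closing "]" was never written, for example after a crash. Both helper names are hypothetical. If parsing fails outright, the real logger instead backs the file up to a .bak copy and starts fresh.

```python
import json


def parse_json_log_sketch(content: str) -> list[dict]:
    """Parse a JSON-array log that may be missing its closing bracket."""
    text = content.strip()
    if not text.endswith("]"):
        # Assumption: an unterminated array left behind by an interrupted run.
        text = text.rstrip(",\n ") + "\n]"
    return json.loads(text)


def truncate_records(records: list[dict], resume_step: int) -> list[dict]:
    # Same filter as _truncate_and_reopen: keep entries at or before the checkpoint.
    return [r for r in records if r.get("global_step", 0) <= resume_step]


crashed = (
    '[\n{"global_step": 32, "loss": 5.2},\n'
    '{"global_step": 64, "loss": 4.7},\n'
    '{"global_step": 96, "loss": 4.4}'
)
kept = truncate_records(parse_json_log_sketch(crashed), resume_step=64)
print([r["global_step"] for r in kept])  # [32, 64]
```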

Source code in src/forgather/ml/trainer/callbacks/json_logger.py
class JsonLogger(TrainerCallback, Stateful):
    """
    A JSON logger callback that writes training metrics to a JSON file.

    Writes a JSON record (with UTC timestamp, global_step, epoch, and all
    reported metrics) each time ``on_log`` or ``on_evaluate`` is called.

    Implements the ``Stateful`` protocol so that the log file path and last
    written step are saved with checkpoints.  When training resumes from a
    checkpoint, the logger reopens the original file, truncates any entries
    recorded after the checkpoint step, and continues appending.
    """

    def __init__(self, **kwargs):
        """
        The contents of kwargs will be recorded when training starts
        """
        super().__init__()
        self.log_file = None
        self.log_path = None
        self.kwargs = kwargs
        self.prefix = ""
        self._last_step = -1

        # Set by load_state_dict when resuming from checkpoint
        self._original_log_path: str | None = None
        self._resume_step: int | None = None

    def __del__(self):
        self.close()

    def close(self):
        if self.log_file is not None:
            self.log_file.write("\n]")
            self.log_file.close()
            self.log_file = None

    # -- Stateful protocol --------------------------------------------------

    def state_dict(self) -> dict:
        return {
            "log_path": self.log_path,
            "last_step": self._last_step,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self._original_log_path = state_dict.get("log_path")
        self._resume_step = state_dict.get("last_step", -1)
        logger.debug(
            "JsonLogger: loaded state (path=%s, step=%s)",
            self._original_log_path,
            self._resume_step,
        )

    # -- Callback hooks -----------------------------------------------------

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero or args.logging_dir is None:
            return

        if self._original_log_path and os.path.isfile(self._original_log_path):
            self.log_path = self._original_log_path
            self._truncate_and_reopen()
        else:
            if self._original_log_path:
                logger.warning(
                    "JsonLogger: original log file not found (%s), "
                    "starting fresh in %s",
                    self._original_log_path,
                    args.logging_dir,
                )
            os.makedirs(args.logging_dir, exist_ok=True)
            self.log_path = os.path.join(args.logging_dir, "trainer_logs.json")
            self.log_file = open(self.log_path, "x")
            self.log_file.write("[\n")

    def on_evaluate(self, args, state, control, **kwargs):
        metrics = kwargs.get("metrics", {})
        if self.log_file is None:
            return
        self._write_log(state, metrics)

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs", {})
        if self.log_file is None:
            return
        self._write_log(state, logs)

    def on_train_end(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        self.close()

    # -- Internal -----------------------------------------------------------

    def _write_log(self, state, data: dict):
        assert self.log_file is not None
        new_fields = dict(
            timestamp=datetime.datetime.now(datetime.UTC).timestamp(),
            global_step=state.global_step,
            epoch=state.epoch,
        )
        self.log_file.write(self.prefix + json.dumps(new_fields | data))
        self.prefix = ",\n"
        self._last_step = state.global_step

    def _truncate_and_reopen(self):
        """Reopen the original JSON log file and truncate entries after
        the checkpoint step."""
        log_path = self.log_path
        assert log_path is not None
        resume_step = self._resume_step if self._resume_step is not None else -1

        try:
            with open(log_path, "r") as f:
                content = f.read()

            records = _parse_json_log(content)
            kept = [r for r in records if r.get("global_step", 0) <= resume_step]

            logger.info(
                "JsonLogger: resuming %s, kept %d/%d records (up to step %d)",
                log_path,
                len(kept),
                len(records),
                resume_step,
            )

            self.log_file = open(log_path, "w")
            self.log_file.write("[\n")
            self.prefix = ""
            for record in kept:
                self.log_file.write(self.prefix + json.dumps(record))
                self.prefix = ",\n"
            self.log_file.flush()

        except Exception as e:
            logger.warning(
                "JsonLogger: failed to parse/truncate %s: %s. "
                "Backing up and starting fresh.",
                log_path,
                e,
            )
            backup = log_path + ".bak"
            try:
                os.rename(log_path, backup)
                logger.info("JsonLogger: backed up corrupted file to %s", backup)
            except OSError:
                pass
            self.log_file = open(log_path, "w")
            self.log_file.write("[\n")
            self.prefix = ""

__init__(**kwargs)

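The keyword arguments are stored as self.kwargs and are intended to be recorded when training starts. Note that on_train_begin, as shown in the class source above, does not currently write them to the log file.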

forgather.ml.trainer.callbacks.TBLogger

Bases: TrainerCallback

A Trainer callback that logs scalars to TensorBoard.

Scalars are configured as a dict mapping TensorBoard tags to spec dicts with optional source and transform fields. The dict is merged with default_tb_scalars() so only deltas need to be specified. Set a key to None to erase a default scalar.

Source code in src/forgather/ml/trainer/callbacks/tb_logger.py
class TBLogger(TrainerCallback):
    """A Trainer callback that logs scalars to TensorBoard.

    Scalars are configured as a dict mapping TensorBoard tags to spec
    dicts with optional ``source`` and ``transform`` fields.  The dict
    is merged with ``default_tb_scalars()`` so only deltas need to be
    specified.  Set a key to ``None`` to erase a default scalar.
    """

    def __init__(
        self,
        summary_writer,
        scalars: Optional[dict] = None,
        experiment_info: Optional[dict] = None,
    ):
        super().__init__()
        merged = _merge_spec_dicts(default_tb_scalars(), scalars)
        self.scalars: list[TBScalarSpec] = _normalize_tb_scalars(merged)
        self.summary_writer = summary_writer
        self.last_step = -1
        if experiment_info is not None:
            self.experiment_info = self.mapping_as_markdown(experiment_info)
        else:
            self.experiment_info = None

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.experiment_info is not None:
            self.summary_writer.add_text(
                "experiment", self.experiment_info, global_step=state.global_step
            )

        info, extra_info = format_train_info(args, state, control, **kwargs)
        self.summary_writer.add_text(
            "training_info",
            self.mapping_as_markdown(info | extra_info),
            global_step=state.global_step,
        )

    @staticmethod
    def mapping_as_markdown(mapping):
        """
        Format a mapping as a Markdown code block.

        TensorBoard renders text summaries as Markdown, so the formatted
        mapping is wrapped in a fenced block to preserve its layout.
        """
        s = "```\n"
        s += format_mapping(mapping)
        s += "```"
        return s

    def _log_metrics(self, global_step, metrics):
        for spec in self.scalars:
            value = metrics.get(spec.source)
            if value is None:
                continue
            if spec.transform is not None:
                value = spec.transform(value, metrics)
                if value is None:
                    continue
            self.summary_writer.add_scalar(spec.tag, value, global_step=global_step)

    def on_evaluate(self, args, state, control, **kwargs):
        metrics = kwargs.get("metrics", {})
        if not state.is_world_process_zero:
            return
        global_step = state.global_step
        if self.last_step == global_step:
            return
        self.last_step = global_step

        self._log_metrics(global_step, metrics)

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs", {})
        if not state.is_world_process_zero:
            return

        self._log_metrics(state.global_step, logs)
        self.summary_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        # Only rank 0 writes, and only when there is a final log entry to report.
        if not state.is_world_process_zero or not state.log_history:
            return
        self.summary_writer.add_text(
            "train_results",
            self.mapping_as_markdown(
                state.log_history[-1],
            ),
            global_step=state.global_step,
        )

mapping_as_markdown(mapping) staticmethod

Format a mapping as a Markdown code block.

TensorBoard renders text summaries as Markdown, so the formatted mapping is wrapped in a fenced block to preserve its layout.

forgather.ml.trainer.callbacks.GradNormLogger

Bases: TrainerCallback, Stateful

Logs per-parameter gradient L2 norms to a JSON file.

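Gradient norms are captured in on_pre_optimizer_step (after gradient clipping, before the optimizer step and zero_grad) and written to the log file in on_evaluate. Each capture replaces the previous one, so only the norms from the last optimizer step before each evaluation are logged. The norms are still computed on every optimizer step, but file writes happen only at eval frequency.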

The log file uses JSON array format with checkpoint resume support via the Stateful protocol.

When fuse_optim_with_backward is enabled, gradients are consumed during the backward pass and are not available for capture. The callback detects this and disables itself with a warning.
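The capture-then-flush cycle can be modelled without torch or a trainer. In this hypothetical sketch, each parameter's gradient is a flat list (None for a parameter without a gradient), and records is an in-memory list standing in for the JSON log file. It shows that every on_pre_optimizer_step overwrites the buffer, so an evaluation writes only the last step's norms, and a second evaluation with nothing new buffered writes nothing.

```python
import math
from collections import OrderedDict
from typing import Optional


class BufferedGradNormLog:
    """Toy model of GradNormLogger's capture/flush cycle (hypothetical)."""

    def __init__(self):
        self.records: list[dict] = []  # stands in for gradient_norms.json
        self._buffered: Optional[OrderedDict] = None
        self._buffered_step = -1

    def on_pre_optimizer_step(self, step: int, grads: dict[str, Optional[list[float]]]):
        # Called every optimizer step; each call replaces the previous buffer.
        norms = OrderedDict()
        for name, g in grads.items():
            if g is not None:  # parameters without gradients are skipped
                norms[name] = math.sqrt(sum(x * x for x in g))
        self._buffered = norms
        self._buffered_step = step

    def on_evaluate(self):
        # Called at eval frequency; writes at most one record per evaluation.
        if self._buffered is None:
            return
        self.records.append({"step": self._buffered_step, "grad_norms": self._buffered})
        self._buffered = None


log = BufferedGradNormLog()
for step in range(1, 5):
    log.on_pre_optimizer_step(step, {"w": [3.0 * step, 4.0 * step], "frozen": None})
log.on_evaluate()
log.on_evaluate()  # nothing buffered since the last write, so nothing is added
```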

Source code in src/forgather/ml/trainer/callbacks/grad_logger.py
class GradNormLogger(TrainerCallback, Stateful):
    """Logs per-parameter gradient L2 norms to a JSON file.

    Gradient norms are captured in ``on_pre_optimizer_step`` (after gradient
    clipping, before optimizer step and zero_grad) and written to the log
    file in ``on_evaluate``. This means gradient data is logged at eval
    frequency, keeping overhead minimal.

    The log file uses JSON array format with checkpoint resume support
    via the Stateful protocol.

    When ``fuse_optim_with_backward`` is enabled, gradients are consumed
    during the backward pass and are not available for capture. The callback
    detects this and disables itself with a warning.
    """

    LOG_FILENAME = "gradient_norms.json"

    def __init__(self):
        super().__init__()
        self._writer = JsonLogWriter(self.LOG_FILENAME)
        self._buffered_norms: dict[str, float] | None = None
        self._buffered_step: int = -1
        self._buffered_epoch: float = 0.0
        self._disabled = False
        self._warned_meta = False

    # -- Stateful protocol ----------------------------------------------------

    def state_dict(self) -> dict:
        return {"writer": self._writer.state_dict()}

    def load_state_dict(self, state_dict: dict) -> None:
        writer_state = state_dict.get("writer", {})
        self._writer.load_state_dict(writer_state)

    # -- Callback hooks -------------------------------------------------------

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero or args.logging_dir is None:
            return

        if getattr(args, "fuse_optim_with_backward", False):
            logger.warning(
                "GradNormLogger: fuse_optim_with_backward is enabled, "
                "gradients are consumed during backward. "
                "Gradient norm logging is disabled."
            )
            self._disabled = True
            return

        self._writer.open(args.logging_dir)

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        """Capture per-parameter gradient norms before optimizer step."""
        if self._disabled or not state.is_world_process_zero:
            return

        model = kwargs.get("model")
        if model is None:
            return

        # Pipeline-parallel guard
        try:
            first_param = next(model.parameters())
        except StopIteration:
            first_param = None

        if first_param is None or first_param.device.type == "meta":
            if not self._warned_meta:
                logger.warning(
                    "GradNormLogger: model parameters are on the meta "
                    "device (pipeline-parallel training). Logging disabled."
                )
                self._warned_meta = True
            return

        norms = OrderedDict()
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.grad is not None:
                    norms[name] = p.grad.float().norm().item()

        self._buffered_norms = norms
        self._buffered_step = state.global_step
        self._buffered_epoch = state.epoch

    def on_evaluate(self, args, state, control, **kwargs):
        """Write buffered gradient norms to log file."""
        if self._disabled or not state.is_world_process_zero:
            return
        if not self._writer.is_open or self._buffered_norms is None:
            return

        record = {"grad_norms": self._buffered_norms}
        self._writer.write_record(self._buffered_step, self._buffered_epoch, record)
        self._buffered_norms = None

    def on_train_end(self, args, state, control, **kwargs):
        self._writer.close()
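The state_dict/load_state_dict pair is what lets a resumed run keep appending to the same JSON array. JsonLogWriter itself is not shown here, so the class below is a hypothetical sketch of one way such a writer could work: its state is just the log path, and after load_state_dict the next open() reuses that file instead of creating a new one. The record field names (global_step, epoch) are assumptions.

```python
import json
import os
import tempfile
from typing import Any, Optional


class ResumableJsonArrayLog:
    """Hypothetical resumable JSON-array log; not forgather's JsonLogWriter."""

    def __init__(self, filename: str):
        self.filename = filename
        self.path: Optional[str] = None
        self._records: list[dict[str, Any]] = []
        self._resume_path: Optional[str] = None  # set by load_state_dict

    # -- Stateful protocol --
    def state_dict(self) -> dict:
        return {"path": self.path}

    def load_state_dict(self, state_dict: dict) -> None:
        self._resume_path = state_dict.get("path")

    def open(self, logging_dir: str) -> None:
        if self._resume_path and os.path.exists(self._resume_path):
            # Resuming: reload the existing array and keep writing to it.
            self.path = self._resume_path
            with open(self.path) as f:
                self._records = json.load(f)
        else:
            os.makedirs(logging_dir, exist_ok=True)
            self.path = os.path.join(logging_dir, self.filename)
            self._records = []

    def write_record(self, step: int, epoch: float, record: dict) -> None:
        self._records.append({"global_step": step, "epoch": epoch, **record})
        # Rewrite the whole array so the file is valid JSON after every write.
        with open(self.path, "w") as f:
            json.dump(self._records, f)


# Usage: the first run writes one record; a resumed run restores the state
# and continues the same file, even though it is given a new logging_dir.
first_dir, second_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
run1 = ResumableJsonArrayLog("gradient_norms.json")
run1.open(first_dir)
run1.write_record(100, 0.5, {"grad_norms": {"w": 1.0}})
checkpoint = run1.state_dict()

run2 = ResumableJsonArrayLog("gradient_norms.json")
run2.load_state_dict(checkpoint)
run2.open(second_dir)
run2.write_record(200, 1.0, {"grad_norms": {"w": 0.5}})

with open(run2.path) as f:
    steps = [r["global_step"] for r in json.load(f)]
```

Rewriting the full array on each write keeps the file parseable at all times, at the cost of work proportional to the number of records; an append-oriented format would avoid that.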

on_pre_optimizer_step(args, state, control, **kwargs)

Capture per-parameter gradient norms before optimizer step.

on_evaluate(args, state, control, **kwargs)

Write buffered gradient norms to log file.

forgather.ml.trainer.callbacks.ParameterNormLogger

Bases: TrainerCallback, Stateful

Logs per-parameter L2 norms and/or spectral norms to a JSON file.

Data is written on each evaluation step. The log file uses JSON array format with checkpoint resume support via the Stateful protocol.

The existing WeightNormLogger continues to handle the total parameter norm for TensorBoard/console logging. This callback provides the per-parameter breakdown for diagnostic analysis and heatmap visualization.

In pipeline-parallel training the model shell passed to callbacks contains only meta-device tensors. This callback detects that case, warns once, and skips logging for the remainder of training.

Source code in src/forgather/ml/trainer/callbacks/parameter_norm_logger.py
class ParameterNormLogger(TrainerCallback, Stateful):
    """Logs per-parameter L2 norms and/or spectral norms to a JSON file.

    Data is written on each evaluation step. The log file uses JSON array
    format with checkpoint resume support via the Stateful protocol.

    The existing ``WeightNormLogger`` continues to handle the total
    parameter norm for TensorBoard/console logging. This callback provides
    the per-parameter breakdown for diagnostic analysis and heatmap
    visualization.

    In pipeline-parallel training the model shell passed to callbacks
    contains only meta-device tensors. This callback detects that case,
    warns once, and skips logging for the remainder of training.
    """

    LOG_FILENAME = "parameter_norms.json"

    def __init__(
        self,
        log_norms: bool = True,
        log_spectral_norms: bool = True,
        power_iter_steps: int = 10,
    ):
        """
        Parameters
        ----------
        log_norms : bool, optional
            Whether to log per-parameter L2 norms.
        log_spectral_norms : bool, optional
            Whether to log per-parameter spectral norms.
        power_iter_steps : int, optional
            Number of power iteration steps for spectral norm estimation.
            First evaluation uses 2x this value for cold-start convergence.
        """
        super().__init__()
        self.log_norms = log_norms
        self.log_spectral_norms = log_spectral_norms
        self.power_iter_steps = power_iter_steps

        self._writer = JsonLogWriter(self.LOG_FILENAME)
        self._warned_meta = False
        # Cached direction vectors for power iteration warm-starting,
        # keyed by parameter FQN.
        self._u_vectors: dict[str, torch.Tensor] = {}
        self._first_eval = True

    # -- Stateful protocol ----------------------------------------------------

    def state_dict(self) -> dict:
        return {"writer": self._writer.state_dict()}

    def load_state_dict(self, state_dict: dict) -> None:
        writer_state = state_dict.get("writer", {})
        self._writer.load_state_dict(writer_state)

    # -- Callback hooks -------------------------------------------------------

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero or args.logging_dir is None:
            return
        self._writer.open(args.logging_dir)

    def on_evaluate(self, args, state, control, **kwargs):
        if not state.is_world_process_zero or not self._writer.is_open:
            return

        model = kwargs.get("model")
        if model is None:
            return

        # Pipeline-parallel guard: detect meta-device tensors.
        try:
            first_param = next(model.parameters())
        except StopIteration:
            first_param = None

        if first_param is None or first_param.device.type == "meta":
            if not self._warned_meta:
                logger.warning(
                    "ParameterNormLogger: model parameters are on the meta "
                    "device (pipeline-parallel training). Logging disabled."
                )
                self._warned_meta = True
            return

        record = {}

        n_iters = self.power_iter_steps
        if self._first_eval:
            n_iters *= 2
            self._first_eval = False

        with torch.no_grad():
            if self.log_norms:
                norms = OrderedDict()
                for name, p in model.named_parameters():
                    norms[name] = p.float().norm().item()
                record["norms"] = norms

            if self.log_spectral_norms:
                spectral_norms = OrderedDict()
                for name, p in model.named_parameters():
                    u = self._u_vectors.get(name)
                    sigma, u_new = _spectral_norm_power_iter(p.data, n_iters, u)
                    spectral_norms[name] = sigma
                    if u_new is not None:
                        self._u_vectors[name] = u_new
                record["spectral_norms"] = spectral_norms

        self._writer.write_record(state.global_step, state.epoch, record)

    def on_train_end(self, args, state, control, **kwargs):
        self._writer.close()

__init__(log_norms=True, log_spectral_norms=True, power_iter_steps=10)

Parameters:

Name                Type  Default  Description
log_norms           bool  True     Whether to log per-parameter L2 norms.
log_spectral_norms  bool  True     Whether to log per-parameter spectral norms.
power_iter_steps    int   10       Number of power iteration steps for spectral norm estimation. The first evaluation uses 2x this value for cold-start convergence.
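The power_iter_steps parameter controls a power iteration for the largest singular value. The private _spectral_norm_power_iter helper is not shown, so the function below is a hypothetical pure-Python sketch for a 2-D matrix given as a list of rows; it assumes the cached direction vector is the left singular vector estimate u. Passing u back in warm-starts the next call, which is why later evaluations can converge with fewer iterations than the doubled cold-start count.

```python
import math
import random
from typing import Optional

Matrix = list[list[float]]
Vector = list[float]


def _matvec(W: Matrix, x: Vector) -> Vector:
    """Compute W @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def _rmatvec(W: Matrix, y: Vector) -> Vector:
    """Compute W.T @ y."""
    return [sum(W[i][j] * y[i] for i in range(len(W))) for j in range(len(W[0]))]


def _normalize(x: Vector) -> tuple[Vector, float]:
    """Return (x / ||x||, ||x||); a zero vector is returned unchanged."""
    n = math.sqrt(sum(v * v for v in x))
    return ([v / n for v in x], n) if n > 0 else (x, 0.0)


def spectral_norm_power_iter(
    W: Matrix, n_iters: int, u: Optional[Vector] = None, seed: int = 0
) -> tuple[float, Vector]:
    """Estimate the largest singular value of W; return (sigma, u) for warm starts."""
    if u is None:
        rng = random.Random(seed)
        u = [rng.gauss(0.0, 1.0) for _ in range(len(W))]  # cold start
    u, _ = _normalize(u)
    sigma = 0.0
    for _ in range(n_iters):
        v, _ = _normalize(_rmatvec(W, u))  # right vector: v = normalize(W^T u)
        u, sigma = _normalize(_matvec(W, v))  # left vector: u = normalize(W v), sigma = ||W v||
    return sigma, u


# Usage: diag(3, 1) has spectral norm 3.
W = [[3.0, 0.0], [0.0, 1.0]]
sigma_cold, u_cached = spectral_norm_power_iter(W, n_iters=20)  # e.g. 2 x 10 on the first eval
sigma_warm, _ = spectral_norm_power_iter(W, n_iters=10, u=u_cached)
```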

forgather.ml.trainer.callbacks.WeightNormLogger

Bases: TrainerCallback

Adds the total L2 norm of all model parameters to the logs dict as weight_norm each time the trainer logs metrics (via the on_log_step hook).

Computed the same way as the total gradient norm (the square root of the summed squares of every element), but over the weight tensors themselves. A value that grows over training indicates that weights are increasing in magnitude, which usually means weight decay is too weak. A stable or shrinking value while gradient norms rise points to a different cause.

In pipeline-parallel training the model shell passed to callbacks contains only meta-device tensors. This callback detects that case, warns once, and skips logging for the remainder of training.

Source code in src/forgather/ml/trainer/callbacks/weight_norm_logger.py
class WeightNormLogger(TrainerCallback):
    """
    Adds the total L2 norm of all model parameters to the logs dict as
    ``weight_norm`` each time the trainer logs metrics (``on_log_step``).

    Computed the same way as the total gradient norm (square root of the
    summed squares of every element), but over the weight tensors
    themselves. A value that grows over training indicates that weights are
    increasing in magnitude, which usually means weight decay is too weak.
    A stable or shrinking value while gradient norms rise points to a
    different cause.

    In pipeline-parallel training the model shell passed to callbacks contains
    only meta-device tensors. This callback detects that case, warns once, and
    skips logging for the remainder of training.
    """

    def __init__(self):
        super().__init__()
        self._warned_meta = False

    def on_log_step(self, state, logs, **kwargs):
        if not state.is_world_process_zero:
            return

        model = kwargs.get("model")
        if model is None:
            return

        # In pipeline-parallel training the model shell holds no real tensors;
        # parameters live on the meta device and actual stage weights are stored
        # elsewhere in the trainer. Detect this and bail out.
        try:
            first_param = next(model.parameters())
        except StopIteration:
            first_param = None

        if first_param is None or first_param.device.type == "meta":
            if not self._warned_meta:
                logger.warning(
                    "WeightNormLogger: model parameters are on the meta device "
                    "(pipeline-parallel training). Weight norm logging is disabled."
                )
                self._warned_meta = True
            return

        total_norm_sq = 0.0
        with torch.no_grad():
            for p in model.parameters():
                total_norm_sq += p.float().square().sum().item()

        logs["weight_norm"] = math.sqrt(total_norm_sq)
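Because the global L2 norm is the square root of the sum of squares over every element, it can also be rebuilt from per-tensor norms: weight_norm = sqrt(sum over parameters of norm(p)^2). The hypothetical sketch below checks that identity on toy flat lists, so at a step where both WeightNormLogger and ParameterNormLogger record values, the per-parameter norms should recombine to weight_norm up to floating-point rounding.

```python
import math

Tensors = dict[str, list[float]]


def global_l2_norm(tensors: Tensors) -> float:
    """Square root of the sum of squares over every element of every tensor."""
    return math.sqrt(sum(x * x for values in tensors.values() for x in values))


def recombine_norms(per_param_norms: dict[str, float]) -> float:
    """Rebuild the global norm from per-parameter L2 norms."""
    return math.sqrt(sum(n * n for n in per_param_norms.values()))


# Toy "parameters" flattened to lists (names are illustrative only).
weights: Tensors = {
    "embed.weight": [3.0, 0.0],
    "head.weight": [0.0, 4.0],
    "head.bias": [12.0],
}
per_param = {name: math.sqrt(sum(x * x for x in vals)) for name, vals in weights.items()}
total = global_l2_norm(weights)
recombined = recombine_norms(per_param)
```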

forgather.ml.trainer.callbacks.PeakMemory

Bases: TrainerCallback

PeakMemory is a TrainerCallback for monitoring and logging the peak CUDA memory usage during model training. This callback is designed to help diagnose and optimize GPU memory consumption in PyTorch-based training loops, especially when using distributed training. It records the maximum memory allocated on each GPU device throughout the training process, and can optionally log detailed memory statistics and write them to TensorBoard for visualization.

IMPORTANT: Memory history recording is disabled by default to prevent memory leaks. The torch.cuda.memory._record_memory_history feature can consume 1GB+ of memory during training.

Key Features:

- Tracks the peak CUDA memory allocated on each GPU during training.
- Supports both single-GPU and multi-GPU (distributed) training environments.
- Optionally logs detailed CUDA memory statistics for further analysis.
- Can write memory usage metrics to a TensorBoard SummaryWriter for visualization.
- Configurable verbosity (show_details) and outputs (do_log for the console, summary_writer for TensorBoard); logging follows the trainer's on_log cadence.

Parameters:

Name                    Type           Default            Description
summary_writer          SummaryWriter  None               TensorBoard SummaryWriter instance for logging memory statistics.
show_details            bool           False              If True, logs detailed CUDA memory statistics at each logging step and at the end of training.
do_log                  bool           False              If True, logs peak memory usage at each logging step (on_log callback).
enable_memory_snapshot  bool           False              If True, enables CUDA memory history recording at the start of training and writes one pickled snapshot per rank at the first on_log call (requires do_log or summary_writer), after which recording stops. WARNING: This can consume 1 GB+ of memory and cause memory leaks.
file_prefix             str            'memory_snapshot'  Filename prefix for the per-rank memory snapshot pickle ({file_prefix}_rank{rank}.pickle).

Attributes:

Name                    Type                   Description
rank                    int                    The process rank in distributed training.
world_size              int                    The total number of processes in distributed training.
summary_writer          SummaryWriter or None  The TensorBoard SummaryWriter for logging.
enabled                 bool                   Whether CUDA is available and memory tracking is enabled.
show_details            bool                   Whether to log detailed memory statistics.
do_log                  bool                   Whether to log memory usage on each log step.
enable_memory_snapshot  bool                   Whether memory history recording and snapshot dumping is enabled.
max_allocated           int                    The maximum CUDA memory allocated during training (in bytes).

Source code in src/forgather/ml/trainer/callbacks/peak_memory.py
class PeakMemory(TrainerCallback):
    """
    PeakMemory is a TrainerCallback for monitoring and logging the peak CUDA memory usage during model training.
    This callback is designed to help diagnose and optimize GPU memory consumption in PyTorch-based training loops,
    especially when using distributed training. It records the maximum memory allocated on each GPU device throughout
    the training process, and can optionally log detailed memory statistics and write them to TensorBoard for visualization.

    IMPORTANT: Memory history recording is disabled by default to prevent memory leaks.
    The torch.cuda.memory._record_memory_history feature can consume 1GB+ of memory during training.

    Key Features:
    - Tracks the peak CUDA memory allocated on each GPU during training.
    - Supports both single-GPU and multi-GPU (distributed) training environments.
    - Optionally logs detailed CUDA memory statistics for further analysis.
    - Can write memory usage metrics to a TensorBoard SummaryWriter for visualization.
    - Configurable verbosity (show_details) and outputs (do_log, summary_writer); logging follows the trainer's on_log cadence.

    Parameters
    ----------
    summary_writer : SummaryWriter, optional
        TensorBoard ``SummaryWriter`` instance for logging memory statistics.
    show_details : bool, optional
        If ``True``, logs detailed CUDA memory statistics at each logging
        step and at the end of training.
    do_log : bool, optional
        If ``True``, logs peak memory usage at each logging step
        (``on_log`` callback).
    enable_memory_snapshot : bool, optional
        If ``True``, enables CUDA memory history recording at the start of
        training and writes one pickled snapshot per rank at the first
        ``on_log`` call (requires ``do_log`` or ``summary_writer``), after
        which recording stops.
        WARNING: This can consume 1 GB+ of memory and cause memory leaks.
    file_prefix : str, optional
        Filename prefix for the per-rank memory snapshot pickle.
        Defaults to ``"memory_snapshot"``.

    Attributes
    ----------
    rank : int
        The process rank in distributed training.
    world_size : int
        The total number of processes in distributed training.
    summary_writer : SummaryWriter or None
        The TensorBoard ``SummaryWriter`` for logging.
    enabled : bool
        Whether CUDA is available and memory tracking is enabled.
    show_details : bool
        Whether to log detailed memory statistics.
    do_log : bool
        Whether to log memory usage on each log step.
    enable_memory_snapshot : bool
        Whether memory history recording and snapshot dumping is enabled.
    max_allocated : int
        The maximum CUDA memory allocated during training (in bytes).
    """

    def __init__(
        self,
        summary_writer=None,
        show_details=False,
        do_log=False,
        enable_memory_snapshot=False,
        file_prefix="memory_snapshot",
    ):
        """
        :param summary_writer: Optional TensorBoard SummaryWriter to log peak memory
        :param show_details: Whether to log detailed memory stats
        :param do_log: Whether to log on each log step
        :param enable_memory_snapshot: Whether to enable CUDA memory history recording and snapshot
        :param file_prefix: Filename prefix for the per-rank memory snapshot pickle
        """
        super().__init__()
        self.rank = int(os.environ.get("RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.summary_writer = summary_writer
        self.enabled = torch.cuda.is_available()
        self.show_details = show_details
        self.do_log = do_log
        self.enable_memory_snapshot = enable_memory_snapshot
        self.file_prefix = file_prefix

    def on_train_begin(self, args, state, control, **kwargs):
        # Disable, if not CUDA
        device = torch.device(args.device)
        if device.type != "cuda":
            self.enabled = False

        if not self.enabled:
            return
        self.max_allocated = 0
        # This feature can consume 1GB+ of memory during training
        if self.enable_memory_snapshot:
            torch.cuda.memory._record_memory_history(enabled="all")

    @staticmethod
    def _format_peak_memory(max_allocated):
        """
        Format peak memory in GiB (binary GB, 1024 ** 3 bytes)
        """
        gib = 1024**3
        return f"{max_allocated / gib:.3f} GiB" if max_allocated else "0 GiB"

    @staticmethod
    def _mapping_as_markdown(mapping):
        """
        Format a mapping as a Markdown code block.

        TensorBoard renders text summaries as Markdown, so the formatted
        mapping is wrapped in a fenced block to preserve its layout.
        """
        s = "```\n"
        s += format_mapping(mapping)
        s += "```"
        return s

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs", {})
        if not self.enabled or (not self.do_log and not self.summary_writer):
            return
        device = torch.cuda.current_device()
        if self.enable_memory_snapshot:
            # Decode at https://docs.pytorch.org/memory_viz
            output_file = f"{self.file_prefix}_rank{self.rank}.pickle"
            logger.info(f"Saving memory snapshot to {output_file}")
            try:
                torch.cuda.memory._dump_snapshot(output_file)
            except Exception as e:
                logger.error(f"Failed to capture memory snapshot {output_file}: {e}")
            torch.cuda.memory._record_memory_history(enabled=None)
            # Only take a single snapshot
            self.enable_memory_snapshot = False

        # peak_mem_allocated is a per-rank list populated by Trainer._log_step
        # via _distributed_peak_mem. Every rank sees the full list.
        max_allocated_list = logs.get("peak_mem_allocated")
        if not max_allocated_list:
            return
        self.max_allocated = max(self.max_allocated, max_allocated_list[self.rank])

        # Per-rank TB scalars and the console line are rank-0 only.
        if self.rank != 0:
            return

        if self.summary_writer:
            for i, mem in enumerate(max_allocated_list):
                self.summary_writer.add_scalar(
                    f"peak_memory_rank{i}", mem, global_step=state.global_step
                )
            self.summary_writer.flush()
            if self.show_details:
                details = torch.cuda.memory_stats(device)
                self.summary_writer.add_text(
                    "peak_memory_details",
                    self._mapping_as_markdown(details),
                    global_step=state.global_step,
                )
        if self.do_log:
            s = "Peak CUDA Memory Allocated: " + ", ".join(
                f"RANK{i} {self._format_peak_memory(mem)}"
                for i, mem in enumerate(max_allocated_list)
            )
            logger.info(s)
            if self.show_details and not self.summary_writer:
                details = torch.cuda.memory_stats(device)
                logger.info(f"RANK{self.rank} Peak Memory Details: {pformat(details)}")

    def on_train_end(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if not self.enabled:
            return
        max_allocated = torch.cuda.max_memory_allocated()
        self.max_allocated = max(self.max_allocated, max_allocated)
        if self.enable_memory_snapshot:
            torch.cuda.memory._record_memory_history(enabled=None)
        logger.info(
            f"RANK{self.rank} MAX CUDA MEMORY ALLOCATED: {self._format_peak_memory(self.max_allocated)}"
        )
        if self.show_details and not self.do_log:
            details = torch.cuda.memory_stats(torch.cuda.current_device())
            if self.summary_writer:
                self.summary_writer.add_text(
                    "peak_memory_details",
                    self._mapping_as_markdown(details),
                    global_step=state.global_step,
                )
            logger.info(f"RANK{self.rank}: {pformat(details)}")
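As a worked example of the units, PeakMemory reports binary gigabytes (1 GiB = 1024**3 bytes) with three decimals. The hypothetical helpers below mirror _format_peak_memory and the rank-0 console line for a three-rank run; this sketch joins the per-rank entries with ", ".

```python
GIB = 1024**3  # binary gigabyte


def format_peak_memory(n_bytes: int) -> str:
    """Bytes -> 'X.XXX GiB', with 0 shown as '0 GiB'."""
    return f"{n_bytes / GIB:.3f} GiB" if n_bytes else "0 GiB"


def peak_memory_line(peaks_by_rank: list[int]) -> str:
    """Console line listing each rank's peak allocation."""
    return "Peak CUDA Memory Allocated: " + ", ".join(
        f"RANK{rank} {format_peak_memory(peak)}"
        for rank, peak in enumerate(peaks_by_rank)
    )


# Three ranks: 3 GiB, 1536 MiB (= 1.5 GiB), and nothing allocated.
line = peak_memory_line([3 * GIB, 1536 * 1024**2, 0])
```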

__init__(summary_writer=None, show_details=False, do_log=False, enable_memory_snapshot=False, file_prefix='memory_snapshot')

:param summary_writer: Optional TensorBoard SummaryWriter to log peak memory
:param show_details: Whether to log detailed memory stats
:param do_log: Whether to log on each log step
:param enable_memory_snapshot: Whether to enable CUDA memory history recording and snapshot
:param file_prefix: Filename prefix for the per-rank memory snapshot pickle

Text Generation

forgather.ml.trainer.callbacks.TextgenCallback

Bases: TrainerCallback

Periodically generate and log text from a set of prompts for subjective model evaluation.

Automatically dispatches between single-rank generation (via model.generate()) and pipeline-parallel generation (via trainer.pipeline_generate()) based on whether the trainer exposes a pipeline_generate method. The same callback works unchanged with SimpleTrainer, AccelTrainer/DDPTrainer, and PipelineTrainer.
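Because generation is only checked inside on_evaluate, the effective interval between generations is never shorter than eval_steps. The sketch below models that gating with a hypothetical generation_hits helper that starts from next_gen_step = 0, as the constructor does. It schedules the next generation at the next multiple of generation_steps after the current step; that scheduling rule is an assumption and may differ from how a given version of the callback advances next_gen_step, particularly after resuming from a checkpoint.

```python
from typing import Iterable


def generation_hits(eval_steps: Iterable[int], generation_steps: int) -> list[int]:
    """Return the evaluation steps at which text would be generated (hypothetical)."""
    next_gen_step = 0  # as in TextgenCallback.__init__
    hits = []
    for step in eval_steps:  # generation is only checked on evaluation steps
        if step < next_gen_step:
            continue
        hits.append(step)
        # Assumption: the next generation is the next multiple of generation_steps.
        next_gen_step = (step // generation_steps + 1) * generation_steps
    return hits


# Evaluations every 200 steps, generation requested every 500 steps.
fresh = generation_hits(range(200, 1201, 200), generation_steps=500)
# Same settings, but the run resumed from a checkpoint at step 5000.
resumed = generation_hits(range(5200, 6201, 200), generation_steps=500)
```

With generation_steps smaller than eval_steps, for example generation_hits(range(200, 801, 200), 100), every evaluation generates, so the interval collapses to eval_steps.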

Source code in src/forgather/ml/trainer/callbacks/textgen_callback.py
class TextgenCallback(TrainerCallback):
    """
    Periodically generate and log text from a set of prompts for subjective model evaluation.

    Automatically dispatches between single-rank generation (via model.generate()) and
    pipeline-parallel generation (via trainer.pipeline_generate()) based on whether the
    trainer exposes a pipeline_generate method. The same callback works unchanged with
    SimpleTrainer, AccelTrainer/DDPTrainer, and PipelineTrainer.
    """

    # generation_steps is the number of steps between text generations
    def __init__(
        self,
        summary_writer: SummaryWriter,
        prompts: List[str] | str,
        generation_config: Optional[dict] = None,
        generation_steps: Optional[int] = None,
        max_new_tokens: int = 200,
    ):
        """
        Periodically generate and log text from a set of prompts for subjective model evaluation.

        Generation only runs on evaluation steps, so eval_steps sets the minimum interval between generations.

        args:
            summary_writer: The TensorBoard SummaryWriter to log to.
            prompts: Either a list of prompts (List[str]) or a path to a YAML file defining a list of prompts.
            generation_config: A dictionary of arguments for HF GenerationConfig.
            generation_steps: The number of steps between generations. If None, it defaults to eval_steps.
            max_new_tokens: The maximum number of new tokens to generate for each prompt.
        """
        super().__init__()
        self.summary_writer = summary_writer
        if isinstance(prompts, list):
            self.prompts = prompts
        else:
            if not isinstance(prompts, str):
                raise ValueError(
                    f"'prompts' must be List[str] | str, found {type(prompts)}"
                )
            with open(prompts, "r") as file:
                self.prompts = yaml.safe_load(file)

            if not isinstance(self.prompts, list):
                raise ValueError(
                    f"From file {prompts}, expected 'prompts' to be a list but found {type(self.prompts)}"
                )

        for s in self.prompts:
            if not isinstance(s, str):
                raise ValueError(
                    f"Expected all prompts to be strings, but found {type(s)}"
                )

        # To construct GenerationConfig, we need token ids from the model or tokenizer
        # We don't have these here, so defer construction until callback.
        if generation_config is None:
            self.gen_config_args = dict(
                do_sample=True,
                top_k=20,
                temperature=0.7,
                repetition_penalty=1.15,
            )
        else:
            self.gen_config_args = generation_config

        self.generation_steps = generation_steps
        self.max_new_tokens = max_new_tokens
        self.next_gen_step = 0

    def on_evaluate(self, args, state, control, **kwargs):
        trainer = kwargs.get("trainer")
        # Pipeline trainers expose pipeline_generate() and require collective participation
        # from all ranks. Single-rank trainers (including DDP) use model.generate() on rank 0.
        if hasattr(trainer, "pipeline_generate"):
            self._on_evaluate_pipeline(args, state, control, **kwargs)
        else:
            self._on_evaluate_single_rank(args, state, control, **kwargs)

    def _on_evaluate_single_rank(self, args, state, control, **kwargs):
        model = kwargs.get("model")
        processing_class = kwargs.get("processing_class")
        if self.generation_steps is None:
            self.generation_steps = args.eval_steps
        if not state.is_world_process_zero or state.global_step < self.next_gen_step:
            return
        self.next_gen_step += self.generation_steps
        text = ""
        for output in self.generate(args.device, model, processing_class):
            text += output + "\n\n---\n\n"
        self.summary_writer.add_text("eval-text", text, global_step=state.global_step)
        self.summary_writer.flush()

    def _on_evaluate_pipeline(self, args, state, control, **kwargs):
        trainer = kwargs.get("trainer")
        processing_class = kwargs.get("processing_class")

        if self.generation_steps is None:
            self.generation_steps = args.eval_steps

        # Coordinate "should we generate this step?" across all ranks.
        # Only rank 0 has the authoritative next_gen_step counter.
        should_gen = torch.tensor(
            [
                int(
                    state.is_world_process_zero
                    and state.global_step >= self.next_gen_step
                )
            ],
            dtype=torch.long,
            device=args.device,
        )
        torch.distributed.broadcast(should_gen, src=0)
        if not should_gen.item():
            return

        if state.is_world_process_zero:
            self.next_gen_step += self.generation_steps

        # Rank 0 tokenizes; broadcast shape then ids to all ranks.
        if state.is_world_process_zero:
            enc = processing_class(
                self.prompts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                padding_side="left",
            )
            input_ids = enc["input_ids"].to(args.device)
            shape_t = torch.tensor(list(input_ids.shape), device=args.device)
        else:
            shape_t = torch.zeros(2, dtype=torch.long, device=args.device)

        torch.distributed.broadcast(shape_t, src=0)

        if not state.is_world_process_zero:
            input_ids = torch.zeros(
                shape_t.tolist(), dtype=torch.long, device=args.device
            )
        torch.distributed.broadcast(input_ids, src=0)

        # All ranks generate together.
        gen_config = dict(
            eos_token_id=processing_class.eos_token_id,
            pad_token_id=processing_class.pad_token_id or processing_class.eos_token_id,
            **self.gen_config_args,
        )
        generated_ids = trainer.pipeline_generate(
            input_ids=input_ids,
            max_new_tokens=self.max_new_tokens,
            **gen_config,
        )

        # Only rank 0 decodes and logs.
        if state.is_world_process_zero:
            texts = processing_class.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            body = ""
            for prompt, decoded in zip(self.prompts, texts):
                body += (
                    prompt + " [START] " + decoded[len(prompt) + 1 :] + "\n\n---\n\n"
                )
            self.summary_writer.add_text(
                "eval-text", body, global_step=state.global_step
            )
            self.summary_writer.flush()

    def generate(self, device, model, tokenizer):
        generation_config = GenerationConfig(
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=self.max_new_tokens,
            return_dict_in_generate=True,
            **self.gen_config_args,
        )

        tokenizer_outputs = tokenizer(
            self.prompts,
            truncation=True,
            padding=True,
            return_tensors="pt",
            padding_side="left",
        )

        # Temporarily remove torch.compile from the model for generation.
        # The compiled model may use flex_attention + max-autotune, which can fail
        # during the generate loop (different shapes/cache behavior than training).
        # Pointing _compiled_call_impl at the eager _call_impl reverts to eager mode
        # for this call; the finally block restores it even if generation raises.
        compiled_call = getattr(model, "_compiled_call_impl", None)
        if compiled_call is not None:
            model._compiled_call_impl = model._call_impl

        input_ids = tokenizer_outputs["input_ids"].to(device)
        try:
            with torch.inference_mode():
                outputs = model.generate(
                    input_ids,
                    generation_config=generation_config,
                    tokenizer=tokenizer,
                )
        finally:
            # Restore compiled forward for training
            if compiled_call is not None:
                model._compiled_call_impl = compiled_call

        output_text = tokenizer.batch_decode(
            outputs.sequences,
            skip_special_tokens=True,
        )

        for prompt, y in zip(self.prompts, output_text):
            s = prompt + " [START] " + y[len(prompt) + 1 :]
            yield s
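The constructor stores generation_config as-is and later splats it after keyword arguments the callback fills in itself (token ids, max_new_tokens, return_dict_in_generate). Read literally, the listing means that a generation_config repeating one of those keys raises TypeError the first time text is generated instead of overriding the value. The sketch below illustrates this with plain dicts. RESERVED_KEYS, reserved_keys_in and merge_like_callback are hypothetical names, and the key set is read off the listing above rather than taken from a documented API.

```python
# Keyword arguments TextgenCallback passes itself before splatting in
# generation_config (read off the listing: GenerationConfig(...) in generate(),
# dict(...) and pipeline_generate(...) in the pipeline path).
RESERVED_KEYS = {
    "eos_token_id",
    "bos_token_id",
    "pad_token_id",
    "max_new_tokens",
    "return_dict_in_generate",
}


def reserved_keys_in(generation_config: dict) -> set:
    """Hypothetical helper: keys that would collide with the callback's own kwargs."""
    return RESERVED_KEYS & set(generation_config)


def merge_like_callback(generation_config: dict) -> dict:
    """Mimic the pipeline path: fixed token ids first, then the user's dict."""
    return dict(eos_token_id=2, pad_token_id=0, **generation_config)


# A safe config: only sampling knobs.
greedy = {"do_sample": False, "repetition_penalty": 1.1}
print(reserved_keys_in(greedy))  # set()
merged = merge_like_callback(greedy)

# An unsafe config: max_new_tokens belongs in the constructor argument instead,
# and pad_token_id is always taken from the tokenizer.
bad = {"do_sample": True, "max_new_tokens": 64, "pad_token_id": 0}
print(sorted(reserved_keys_in(bad)))  # ['max_new_tokens', 'pad_token_id']
try:
    merge_like_callback(bad)
except TypeError as err:
    print("TypeError:", err)
```

Set the generation length through the max_new_tokens constructor argument and keep generation_config to sampling options such as do_sample, top_k, temperature and repetition_penalty.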

__init__(summary_writer, prompts, generation_config=None, generation_steps=None, max_new_tokens=200)

Periodically generates and logs text from a set of prompts for subjective model evaluation.

Generation can only trigger on evaluation steps, so the evaluation interval sets the minimum interval between generations.

args:
    summary_writer: The TensorBoard SummaryWriter to log to.
    prompts: Either a list of prompts (List[str]) or a path to a YAML file defining a list of prompts.
    generation_config: A dictionary of arguments for HF GenerationConfig. If None, defaults to do_sample=True, top_k=20, temperature=0.7, repetition_penalty=1.15.
    generation_steps: The number of steps between generations. If None, defaults to eval_steps.
    max_new_tokens: The maximum number of new tokens to generate for each prompt.
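How generation_steps interacts with the evaluation schedule can be replayed without a model. The sketch below reproduces the next_gen_step arithmetic from _on_evaluate_single_rank in the listing above (the pipeline path uses the same counter on rank 0). The function name generation_steps_hit is hypothetical, and the example assumes evaluations every 500 steps.

```python
from typing import Iterable, List, Optional


def generation_steps_hit(
    eval_steps_seen: Iterable[int],
    generation_steps: Optional[int],
    eval_steps: int,
) -> List[int]:
    """Return the evaluation steps at which TextgenCallback would generate text.

    Mirrors the listing: next_gen_step starts at 0, text is generated when
    global_step >= next_gen_step, and next_gen_step then advances by
    generation_steps (which defaults to eval_steps when None).
    """
    interval = eval_steps if generation_steps is None else generation_steps
    next_gen_step = 0
    hits = []
    for global_step in eval_steps_seen:
        if global_step < next_gen_step:
            continue  # evaluation ran, but it is too soon for another generation
        next_gen_step += interval
        hits.append(global_step)
    return hits


evals = range(500, 3001, 500)  # evaluations at steps 500, 1000, ..., 3000
print(generation_steps_hit(evals, generation_steps=None, eval_steps=500))
# [500, 1000, 1500, 2000, 2500, 3000]
print(generation_steps_hit(evals, generation_steps=1000, eval_steps=500))
# [500, 1000, 2000, 3000]

# next_gen_step is not checkpointed, so after resuming at step 5000 it restarts at 0
# and text is generated at every evaluation until the counter catches up.
resumed = range(5500, 7001, 500)
print(generation_steps_hit(resumed, generation_steps=1000, eval_steps=500))
# [5500, 6000, 6500, 7000]
```

A generation_steps value smaller than eval_steps therefore still yields at most one generation per evaluation.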

Advanced

forgather.ml.trainer.callbacks.ProfilerCallback

Profiles a window of training steps with torch.profiler, exports a Chrome trace per rank, and prints CPU and CUDA time summary tables on rank 0.

Source code in src/forgather/ml/trainer/callbacks/profiler_callback.py
class ProfilerCallback:
    """Profiles training steps and exports Chrome traces + summary tables."""

    def __init__(
        self,
        start_step: int = 3,
        num_steps: int = 5,
        output_dir: str = "benchmarks/profiles",
        with_stack: bool = True,
        with_flops: bool = True,
        record_shapes: bool = True,
    ):
        self.start_step = start_step
        self.end_step = start_step + num_steps
        self.output_dir = output_dir
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.record_shapes = record_shapes
        self._profiler = None
        self._rank = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self._rank = getattr(args, "local_rank", 0) or 0
        os.makedirs(self.output_dir, exist_ok=True)
        if self._rank == 0:
            logger.info(
                f"ProfilerCallback: will profile steps {self.start_step}-{self.end_step - 1}, "
                f"output to {self.output_dir}"
            )

    def on_step_begin(self, args, state, control, **kwargs):
        step = state.global_step
        if step == self.start_step:
            if self._rank == 0:
                logger.info(f"ProfilerCallback: starting profiler at step {step}")
            self._profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=self.record_shapes,
                with_stack=self.with_stack,
                with_flops=self.with_flops,
            )
            self._profiler.__enter__()

    def on_step_end(self, args, state, control, **kwargs):
        step = state.global_step
        if self._profiler is not None and step >= self.end_step:
            self._profiler.__exit__(None, None, None)
            prof = self._profiler
            self._profiler = None

            trace_path = os.path.join(
                self.output_dir,
                f"real_training_rank{self._rank}.json",
            )
            prof.export_chrome_trace(trace_path)

            if self._rank == 0:
                logger.info(f"ProfilerCallback: trace saved to {trace_path}")
                print("\n" + "=" * 80)
                print(
                    f"CPU time summary (steps {self.start_step}-{self.end_step - 1}):"
                )
                print("=" * 80)
                print(
                    prof.key_averages().table(
                        sort_by="cpu_time_total",
                        row_limit=30,
                    )
                )
                print("\n" + "=" * 80)
                print(
                    f"CUDA time summary (steps {self.start_step}-{self.end_step - 1}):"
                )
                print("=" * 80)
                print(
                    prof.key_averages().table(
                        sort_by="cuda_time_total",
                        row_limit=30,
                    )
                )
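The start and stop checks in on_step_begin and on_step_end determine which steps land in the trace. The sketch below replays them without torch; ProfileWindow is a hypothetical stand-in. It assumes that, as in the HF Trainer, global_step is incremented before on_step_end fires. Under that assumption the defaults (start_step=3, num_steps=5) profile the five steps that begin at global_step 3 through 7, which matches the range the callback logs.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ProfileWindow:
    """Replays ProfilerCallback's start/stop conditions (illustrative only)."""

    start_step: int = 3
    num_steps: int = 5
    active: bool = False
    profiled: List[int] = field(default_factory=list)

    @property
    def end_step(self) -> int:
        return self.start_step + self.num_steps

    def on_step_begin(self, global_step: int) -> None:
        if global_step == self.start_step:
            self.active = True  # ProfilerCallback enters torch.profiler.profile here

    def on_step_end(self, global_step: int) -> None:
        if self.active and global_step >= self.end_step:
            self.active = False  # ...and exits, exports the trace, and prints tables here


window = ProfileWindow(start_step=3, num_steps=5)
for step in range(12):
    window.on_step_begin(global_step=step)
    if window.active:
        window.profiled.append(step)  # this step's work is captured by the profiler
    # Assumption: the trainer increments global_step before calling on_step_end.
    window.on_step_end(global_step=step + 1)

print(window.profiled)  # [3, 4, 5, 6, 7]
```

Because the start check is an equality (global_step == start_step), a run resumed from a checkpoint that is already past start_step never starts the profiler.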

forgather.ml.trainer.callbacks.DiLoCoCallback

Bases: TrainerCallback

Trainer callback that manages a DiLoCoWorker for distributed local-SGD training.

Implements the Stateful protocol for checkpoint persistence. The checkpoint manager auto-discovers Stateful callbacks and saves/restores their state.

When server_addr is empty (and DILOCO_SERVER is unset), all methods are no-ops. This allows a single training configuration to work both with and without a DiLoCo server.

Parameters:

server_addr (str, default None) - DiLoCo server address ("host:port"). Falls back to the DILOCO_SERVER env var.
sync_every (int, default None) - Local optimizer steps between syncs. Falls back to the DILOCO_SYNC_EVERY env var, then 500.
worker_id (str, default None) - Unique worker ID. Falls back to the DILOCO_WORKER_ID env var; auto-generated if neither is set.
bf16_comm (bool, default None) - Cast pseudo-gradients to bfloat16. Falls back to the DILOCO_BF16_COMM env var, then True.
dylu (bool, default None) - Enable Dynamic Local Updates (DyLU). Falls back to the DILOCO_DYLU env var, then False.
heartbeat_interval (float, default None) - Seconds between heartbeats. Falls back to the DILOCO_HEARTBEAT_INTERVAL env var, then 30.0.
num_fragments (int, default None) - Number of streaming fragments. Falls back to the DILOCO_NUM_FRAGMENTS env var, then 1 (no streaming).
timeout (float, default 600) - Client timeout in seconds.
max_sync_retries (int, default 3) - Maximum number of retries for sync failures.
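Every optional argument resolves in the same order: an explicit value wins, then the environment variable, then the built-in default. The __init__ listing below shows one asymmetry: server_addr and worker_id are resolved with `or`, so an empty string also falls through to the environment, while the numeric and boolean options use `is not None` checks. The module's _env_int, _env_bool and _env_float helpers are not shown on this page, so the parsing in this sketch (resolve_int, resolve_server_addr) is a hypothetical stand-in, and the environment values are made up.

```python
from typing import Mapping, Optional


def resolve_int(arg: Optional[int], env: Mapping[str, str], var: str, default: int) -> int:
    """Explicit argument first, then the env var, then the default (`is not None` style)."""
    if arg is not None:
        return arg  # 0 is kept as-is; it does not fall through
    raw = env.get(var)
    return int(raw) if raw else default  # hypothetical parsing; the real helper is not shown


def resolve_server_addr(arg: Optional[str], env: Mapping[str, str]) -> str:
    """Mirrors `server_addr or os.environ.get("DILOCO_SERVER", "")`."""
    return arg or env.get("DILOCO_SERVER", "")


no_env: dict = {}
env = {"DILOCO_SYNC_EVERY": "250", "DILOCO_SERVER": "10.0.0.5:8512"}  # made-up values

print(resolve_int(None, no_env, "DILOCO_SYNC_EVERY", 500))  # 500: built-in default
print(resolve_int(None, env, "DILOCO_SYNC_EVERY", 500))     # 250: env var
print(resolve_int(1000, env, "DILOCO_SYNC_EVERY", 500))     # 1000: explicit argument wins
print(repr(resolve_server_addr("", no_env)))                # '': callback inactive, hooks are no-ops
print(resolve_server_addr("", env))                         # 10.0.0.5:8512: "" falls through
```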
Source code in src/forgather/ml/trainer/callbacks/diloco_callback.py
class DiLoCoCallback(TrainerCallback):
    """
    Trainer callback that manages a DiLoCoWorker for distributed local-SGD training.

    Implements the Stateful protocol for checkpoint persistence. The checkpoint
    manager auto-discovers Stateful callbacks and saves/restores their state.

    When ``server_addr`` is empty (and DILOCO_SERVER is unset), all methods are
    no-ops. This allows a single training configuration to work both with and
    without a DiLoCo server.

    Parameters
    ----------
    server_addr : str, optional
        DiLoCo server address (``"host:port"``). Falls back to
        ``DILOCO_SERVER`` env var.
    sync_every : int, optional
        Local optimizer steps between syncs. Falls back to
        ``DILOCO_SYNC_EVERY`` env var. Default ``500``.
    worker_id : str, optional
        Unique worker ID. Falls back to ``DILOCO_WORKER_ID`` env var.
        Auto-generated if unset.
    bf16_comm : bool, optional
        Cast pseudo-gradients to bfloat16. Falls back to
        ``DILOCO_BF16_COMM`` env var. Default ``True``.
    dylu : bool, optional
        Enable Dynamic Local Updates. Falls back to ``DILOCO_DYLU`` env var.
        Default ``False``.
    heartbeat_interval : float, optional
        Seconds between heartbeats. Falls back to
        ``DILOCO_HEARTBEAT_INTERVAL`` env var. Default ``30.0``.
    num_fragments : int, optional
        Number of streaming fragments. Falls back to
        ``DILOCO_NUM_FRAGMENTS`` env var. Default ``1`` (no streaming).
    timeout : float, optional
        Client timeout in seconds. Default ``600``.
    max_sync_retries : int, optional
        Max retries for sync failures. Default ``3``.
    """

    def __init__(
        self,
        server_addr: Optional[str] = None,
        sync_every: Optional[int] = None,
        worker_id: Optional[str] = None,
        bf16_comm: Optional[bool] = None,
        dylu: Optional[bool] = None,
        heartbeat_interval: Optional[float] = None,
        num_fragments: Optional[int] = None,
        timeout: float = 600,
        max_sync_retries: int = 3,
    ):
        # Resolve with env var fallbacks
        self.server_addr = server_addr or os.environ.get("DILOCO_SERVER", "")
        self.sync_every = (
            sync_every if sync_every is not None else _env_int("DILOCO_SYNC_EVERY", 500)
        )
        self.worker_id = worker_id or os.environ.get("DILOCO_WORKER_ID", "") or None
        self.bf16_comm = (
            bf16_comm if bf16_comm is not None else _env_bool("DILOCO_BF16_COMM", True)
        )
        self.dylu = dylu if dylu is not None else _env_bool("DILOCO_DYLU", False)
        self.heartbeat_interval = (
            heartbeat_interval
            if heartbeat_interval is not None
            else _env_float("DILOCO_HEARTBEAT_INTERVAL", 30.0)
        )
        self.num_fragments = (
            num_fragments
            if num_fragments is not None
            else _env_int("DILOCO_NUM_FRAGMENTS", 1)
        )
        self.timeout = timeout
        self.max_sync_retries = max_sync_retries

        # Worker instance (created in on_train_begin)
        self._worker = None

        # Deferred checkpoint state (loaded before on_train_begin)
        self._pending_state: Optional[Dict[str, Any]] = None

    @property
    def active(self) -> bool:
        """Whether DiLoCo integration is configured (server_addr is set)."""
        return bool(self.server_addr)

    def on_train_begin(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Create and start the DiLoCoWorker."""
        if not self.active:
            logger.info("DiLoCoCallback: no server_addr configured, running as no-op")
            return

        model = kwargs.get("model")
        optimizer = kwargs.get("optimizer")
        if model is None or optimizer is None:
            logger.error(
                "DiLoCoCallback: model or optimizer not provided in kwargs. "
                "Cannot initialize DiLoCoWorker."
            )
            return

        from forgather.ml.diloco.worker import DiLoCoWorker

        self._worker = DiLoCoWorker(
            model=model,
            optimizer=optimizer,
            server_addr=self.server_addr,
            sync_every=self.sync_every,
            worker_id=self.worker_id,
            bf16_comm=self.bf16_comm,
            timeout=self.timeout,
            dylu=self.dylu,
            heartbeat_interval=self.heartbeat_interval,
            num_fragments=self.num_fragments,
            max_sync_retries=self.max_sync_retries,
        )
        self._worker.start()

        # Apply deferred checkpoint state
        if self._pending_state is not None:
            self._apply_pending_state()
            self._pending_state = None

        logger.info(
            f"DiLoCoCallback: worker started "
            f"(server={self.server_addr}, sync_every={self.sync_every})"
        )

    def _apply_pending_state(self):
        """Apply deferred state from load_state_dict to the active worker."""
        if self._worker is None or self._pending_state is None:
            return

        st = self._pending_state
        self._worker._sync_count = st.get("sync_count", 0)
        self._worker._local_step = st.get("local_step", 0)
        self._worker._total_sync_time = st.get("total_sync_time", 0.0)
        self._worker._sync_retries = st.get("sync_retries", 0)
        self._worker._reconnections = st.get("reconnections", 0)
        self._worker._dylu_adjustments = st.get("dylu_adjustments", 0)
        self._worker._fragment_syncs = st.get("fragment_syncs", 0)

        # Restore sync_every (may have been adjusted by DyLU)
        if "sync_every" in st:
            self._worker.sync_every = st["sync_every"]

        logger.info(
            f"DiLoCoCallback: restored state from checkpoint "
            f"(sync_count={st.get('sync_count', 0)}, "
            f"local_step={st.get('local_step', 0)})"
        )

    def on_log(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[dict] = None,
        **kwargs,
    ):
        """Inject DiLoCo sync metrics into the logs dict."""
        if self._worker is not None and logs is not None:
            logs.update(self._worker.sync_metrics)

    def on_train_end(
        self,
        args: MinimalTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Stop the DiLoCoWorker."""
        if self._worker is not None:
            self._worker.stop()
            logger.info("DiLoCoCallback: worker stopped")
            self._worker = None

    # -- Stateful protocol --

    def state_dict(self) -> Dict[str, Any]:
        """Save DiLoCo state for checkpointing.

        Does NOT save global_params snapshot -- the server provides fresh
        params when the worker re-registers on resume.
        """
        if self._worker is None:
            return {}

        return {
            "sync_count": self._worker._sync_count,
            "local_step": self._worker._local_step,
            "sync_every": self._worker.sync_every,
            "worker_id": self._worker.worker_id,
            "total_sync_time": self._worker._total_sync_time,
            "sync_retries": self._worker._sync_retries,
            "reconnections": self._worker._reconnections,
            "dylu_adjustments": self._worker._dylu_adjustments,
            "fragment_syncs": self._worker._fragment_syncs,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Defer state restoration until on_train_begin.

        Checkpoint loading happens during _prepare() before on_train_begin,
        so the worker doesn't exist yet. We store the state and apply it
        once the worker is created.
        """
        if not state_dict:
            return
        self._pending_state = state_dict
        logger.debug("DiLoCoCallback: checkpoint state deferred until on_train_begin")

active property

Whether DiLoCo integration is configured (server_addr is set).

on_train_begin(args, state, control, **kwargs)

Create and start the DiLoCoWorker.

on_log(args, state, control, logs=None, **kwargs)

Inject DiLoCo sync metrics into the logs dict.

on_train_end(args, state, control, **kwargs)

Stop the DiLoCoWorker.

state_dict()

Save DiLoCo state for checkpointing.

Does NOT save the global_params snapshot; the server provides fresh params when the worker re-registers on resume.

load_state_dict(state_dict)

Defer state restoration until on_train_begin.

Checkpoint loading happens during _prepare() before on_train_begin, so the worker doesn't exist yet. We store the state and apply it once the worker is created.
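The ordering problem that load_state_dict works around can be reproduced with stand-ins. In this sketch FakeWorker and DeferredStateCallback are hypothetical, and only three of the restored fields are kept. The control flow follows the listing above: stash the state in load_state_dict, apply it in on_train_begin once the worker exists, and return {} from state_dict when there is no worker.

```python
from typing import Any, Dict, Optional


class FakeWorker:
    """Stand-in for DiLoCoWorker holding only the counters restored below."""

    def __init__(self, sync_every: int) -> None:
        self.sync_every = sync_every
        self._sync_count = 0
        self._local_step = 0


class DeferredStateCallback:
    """Minimal sketch of DiLoCoCallback's load-before-create ordering."""

    def __init__(self, sync_every: int = 500) -> None:
        self.sync_every = sync_every
        self._worker: Optional[FakeWorker] = None
        self._pending_state: Optional[Dict[str, Any]] = None

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if state_dict:  # an empty dict (saved while inactive) is ignored
            self._pending_state = state_dict

    def on_train_begin(self) -> None:
        self._worker = FakeWorker(self.sync_every)
        if self._pending_state is not None:
            st = self._pending_state
            self._worker._sync_count = st.get("sync_count", 0)
            self._worker._local_step = st.get("local_step", 0)
            if "sync_every" in st:  # DyLU may have changed it before the checkpoint
                self._worker.sync_every = st["sync_every"]
            self._pending_state = None

    def state_dict(self) -> Dict[str, Any]:
        if self._worker is None:
            return {}
        return {
            "sync_count": self._worker._sync_count,
            "local_step": self._worker._local_step,
            "sync_every": self._worker.sync_every,
        }


# Checkpoint loading runs first (during the trainer's _prepare()), then training begins.
cb = DeferredStateCallback(sync_every=500)
cb.load_state_dict({"sync_count": 7, "local_step": 120, "sync_every": 320})
print(cb._worker)  # None: nothing to restore into yet
cb.on_train_begin()
print(cb.state_dict())  # {'sync_count': 7, 'local_step': 120, 'sync_every': 320}
```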

forgather.ml.trainer.callbacks.ResumableSummaryWriter

Bases: TrainerCallback, Stateful

A lazy, resumable wrapper around TensorBoard SummaryWriter.

When registered as a callback, it persists the active logging directory in checkpoint metadata via the Stateful protocol. On resume from a checkpoint, it redirects logging to the original directory and sets SummaryWriter's purge_step to the checkpoint step, so TensorBoard discards every event recorded at or after that step, including the stale events logged after the checkpoint was taken.

When used as a SummaryWriter (passed to TBLogger, GradLogger, etc.), it proxies method calls to the underlying writer, constructing it lazily on first use.

Parameters:

log_dir (str, required) - Logging directory path (typically ns.logging_dir from the template system).
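On resume, three hooks run in order: load_state_dict (checkpoint restore), on_train_begin, and then the first logging call, which constructs the writer. The sketch below mirrors that order using a recording stand-in instead of TensorBoard. RecordingWriter and LazyResumableWriter are hypothetical names, only add_scalar is proxied, and "runs/new_run" is a made-up path.

```python
import os
import tempfile
from typing import List, Optional, Tuple


class RecordingWriter:
    """Stand-in for torch.utils.tensorboard.SummaryWriter that records its calls."""

    def __init__(self, log_dir: str, purge_step: Optional[int] = None) -> None:
        self.log_dir = log_dir
        self.purge_step = purge_step
        self.scalars: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, value: float, global_step: int) -> None:
        self.scalars.append((tag, value, global_step))


class LazyResumableWriter:
    """Sketch of ResumableSummaryWriter's resume and lazy-construction flow."""

    def __init__(self, log_dir: str) -> None:
        self._active_log_dir = log_dir
        self._purge_step: Optional[int] = None
        self._resumed = False
        self._writer: Optional[RecordingWriter] = None

    def state_dict(self) -> dict:
        return {"log_dir": self._active_log_dir}

    def load_state_dict(self, state_dict: dict) -> None:
        original = state_dict.get("log_dir")
        if original and os.path.isdir(original):  # resume only into a dir that still exists
            self._active_log_dir = original
            self._resumed = True

    def on_train_begin(self, global_step: int) -> None:
        if self._resumed and global_step > 0:
            self._purge_step = global_step

    def add_scalar(self, *args, **kwargs) -> None:
        if self._writer is None:  # first write constructs the underlying writer
            self._writer = RecordingWriter(self._active_log_dir, purge_step=self._purge_step)
        self._writer.add_scalar(*args, **kwargs)


with tempfile.TemporaryDirectory() as old_run:
    w = LazyResumableWriter(log_dir="runs/new_run")  # fresh dir chosen by this launch
    w.load_state_dict({"log_dir": old_run})           # as restored from the checkpoint
    w.on_train_begin(global_step=1200)
    print(w._writer)  # None: nothing constructed before the first write
    w.add_scalar("train/loss", 2.31, 1210)
    print(w._writer.log_dir == old_run, w._writer.purge_step)  # True 1200
```

With the real SummaryWriter, purge_step=T hides events whose step is greater than or equal to T, so values logged exactly at the checkpoint step are hidden as well. Since the purge step is applied only when the writer is first constructed, it has no effect if something was written through the wrapper before on_train_begin.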
Source code in src/forgather/ml/trainer/callbacks/resumable_summary_writer.py
class ResumableSummaryWriter(TrainerCallback, Stateful):
    """
    A lazy, resumable wrapper around TensorBoard SummaryWriter.

    When registered as a callback, it persists the active logging directory
    in checkpoint metadata via the Stateful protocol. On resume from a
    checkpoint, it redirects logging to the original directory and uses
    SummaryWriter's ``purge_step`` to discard stale events; TensorBoard
    purges every event recorded at or after the checkpoint step.

    When used as a SummaryWriter (passed to TBLogger, GradLogger, etc.),
    it proxies method calls to the underlying writer, constructing it
    lazily on first use.

    Parameters
    ----------
    log_dir : str
        Logging directory path (typically ``ns.logging_dir``
        from the template system).
    """

    def __init__(self, log_dir: str):
        super().__init__()
        self._new_log_dir = log_dir
        self._active_log_dir = log_dir
        self._writer: SummaryWriter | None = None
        self._purge_step: int | None = None
        self._resumed = False

    # -- Stateful protocol --------------------------------------------------

    def state_dict(self) -> dict:
        return {"log_dir": self._active_log_dir}

    def load_state_dict(self, state_dict: dict) -> None:
        original_dir = state_dict.get("log_dir")
        if original_dir and os.path.isdir(original_dir):
            logger.info(
                "ResumableSummaryWriter: resuming into original log dir: %s",
                original_dir,
            )
            self._active_log_dir = original_dir
            self._resumed = True
        else:
            logger.warning(
                "ResumableSummaryWriter: original log dir not found (%s), "
                "using new directory: %s",
                original_dir,
                self._new_log_dir,
            )

    # -- TrainerCallback protocol -------------------------------------------

    def on_train_begin(self, args, state, control, **kwargs):
        if self._resumed and state.global_step > 0:
            self._purge_step = state.global_step
            logger.info(
                "ResumableSummaryWriter: will purge TensorBoard events "
                "after step %d",
                self._purge_step,
            )

    # -- Lazy writer construction -------------------------------------------

    def _ensure_writer(self) -> SummaryWriter:
        if self._writer is None:
            kwargs: dict = {}
            if self._purge_step is not None:
                kwargs["purge_step"] = self._purge_step
            os.makedirs(self._active_log_dir, exist_ok=True)
            self._writer = SummaryWriter(self._active_log_dir, **kwargs)
            logger.info(
                "ResumableSummaryWriter: created SummaryWriter at %s",
                self._active_log_dir,
            )
        return self._writer

    # -- SummaryWriter method proxies ---------------------------------------

    def add_scalar(self, *args, **kwargs):
        return self._ensure_writer().add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._ensure_writer().add_scalars(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._ensure_writer().add_text(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._ensure_writer().add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._ensure_writer().add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._ensure_writer().add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._ensure_writer().add_figure(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._ensure_writer().add_graph(*args, **kwargs)

    def flush(self):
        if self._writer is not None:
            self._writer.flush()

    def close(self):
        if self._writer is not None:
            self._writer.close()
            self._writer = None

    def __del__(self):
        self.close()