DiLoCo Architecture & Maintainer Guide¶
This document describes the internal architecture of Forgather's DiLoCo system. It is intended for developers who need to understand how the system works, troubleshoot issues, or implement new features.
For user-facing documentation (CLI usage, quick start, API examples), see diloco.md.
Contents¶
- System Overview
- Source Layout
- Data Structures and State
- Wire Protocol
- Threading Model
- Synchronization Modes
- Streaming DiLoCo (Fragments)
- Outer Optimizer Integration
- Lifecycle and Data Flow
- Fault Tolerance
- Server State Persistence
- CLI Layer
- Testing
- Troubleshooting
- Extension Points
- Known Limitations
System Overview¶
DiLoCo is a client-server distributed training system. A central server holds the global model parameters and outer optimizer state. Multiple workers (each running any Forgather trainer) train locally and periodically submit pseudo-gradients to the server over HTTP. The server applies an outer optimizer step and returns updated global parameters.
                   DiLoCoServer (CPU-only process)
                +-----------------------------------+
                | _param_list: ParameterList        |
                | outer_optimizer: torch.optim.SGD  |
                | _workers: Dict[str, WorkerInfo]   |
                | ThreadingHTTPServer               |
                +-----------------------------------+
                                  |
                     HTTP (torch.save payloads)
                                  |
               +------------------+------------------+
               |                                     |
   DiLoCoWorker (GPU machine)            DiLoCoWorker (GPU machine)
+------------------------------+      +------------------------------+
| model: nn.Module             |      | model: nn.Module             |
| optimizer: inner optimizer   |      | optimizer: inner optimizer   |
| _global_params: CPU snapshot |      | _global_params: CPU snapshot |
| DiLoCoClient                 |      | DiLoCoClient                 |
| optimizer post-step hook     |      | optimizer post-step hook     |
+------------------------------+      +------------------------------+
The system has three operating dimensions that can be combined:
| Dimension | Options | Key mechanism |
|---|---|---|
| Sync mode | Synchronous / Asynchronous | Server barrier vs immediate apply |
| Momentum | Standard / Delayed Nesterov (DN) | Buffer submissions before momentum |
| Communication | Full-model / Streaming (fragments) | Background thread per fragment |
Source Layout¶
src/forgather/ml/diloco/
    __init__.py              Exports: DiLoCoServer, DiLoCoClient, DiLoCoWorker, FragmentManager, HealthMonitor
    server.py                HTTP server, outer optimizer, sync barrier, fragment handling, fault tolerance
    client.py                HTTP client, tensor serialization, request construction, retry logic
    worker.py                Optimizer hook, pseudo-gradient computation, streaming, reconnection
    fragments.py             FragmentManager: parameter splitting, scheduling
    health.py                HealthMonitor: background worker liveness detection
src/forgather/cli/
    diloco.py                CLI command handlers (_server_cmd, _status_cmd, _worker_cmd)
    diloco_args.py           Argument parser (create_diloco_parser)
tests/unit/ml/diloco/
    test_server.py           Server: outer optimizer correctness, serialization
    test_server_client.py    HTTP round-trip: register, submit, status
    test_worker.py           Worker: pseudo-gradients, optimizer hooks, full sync cycle
    test_async.py            Async mode, DN momentum, DyLU
    test_streaming.py        FragmentManager, fragment server/client, streaming worker
    test_fault_tolerance.py  Health monitor, worker death, barrier release, reconnection
docs/trainers/
    diloco.md                User-facing documentation
    diloco-architecture.md   This file
Data Structures and State¶
Server (DiLoCoServer)¶
Global parameters:
_param_names: List[str] # Ordered parameter names (matches model.state_dict() order)
_param_list: nn.ParameterList # Global params as nn.Parameter (float32, CPU, requires_grad=False)
_param_name_to_idx: Dict[str, int] # Reverse lookup: name -> index in _param_list
Parameters are stored as nn.Parameter objects inside a ParameterList so that
any standard torch.optim.Optimizer can be constructed against them. All
parameters are float32 on CPU regardless of what workers send (incoming bf16
pseudo-gradients are cast to float32 before accumulation).
Worker registry:
_workers: Dict[str, WorkerInfo] # worker_id -> metadata
_workers_lock: threading.Lock # Protects _workers dict
WorkerInfo is a dataclass with: worker_id, hostname, registered_at,
last_heartbeat, sync_round (worker's count), last_sync_server_round
(server round at last sync), steps_per_second, extra.
Synchronous state:
_sync_round: int # Global monotonic counter
_pending_pseudograds: Dict[str, Dict[str, Tensor]] # worker_id -> pseudograds (waiting for barrier)
_sync_cond: threading.Condition # Barrier notification
_completed_rounds: Dict[int, Dict[str, Tensor]] # round_number -> result (cached for late wakers)
Async state:
_async_lock: threading.Lock # Serializes async submissions
_total_submissions: int # Total submissions received
_dn_grad_buffer: List[Dict[str, Tensor]] # Delayed Nesterov buffer
Fragment state (sync + async):
_fragment_pending: Dict[int, Dict[str, Dict[str, Tensor]]] # frag_id -> worker_id -> pseudograds
_fragment_rounds: Dict[int, int] # frag_id -> current round
_completed_fragment_rounds: Dict[Tuple[int,int], Dict[str, Tensor]] # (frag_id, round) -> result
_fragment_submissions: int # Total fragment submissions
Fault tolerance state:
_round_expected_workers: Optional[set] # Worker IDs expected for current sync round (None before first submission)
_health_monitor: Optional[HealthMonitor] # Background health checker (None if heartbeat_timeout=0)
_total_worker_deaths: int # Cumulative dead worker count
heartbeat_timeout: float # Seconds before a worker is considered dead (0 = disabled)
min_workers: int # Floor for num_workers during death handling
_round_expected_workers is the key data structure for fault-tolerant barriers.
It is snapshotted from _workers.keys() when a sync round completes (or lazily
on the first submission of a round). Workers that join mid-round are not added
to the current snapshot -- they participate starting next round. When a worker
dies, it is removed from this set, which may cause the barrier to release early
if the remaining submissions satisfy the reduced expected count.
Worker (DiLoCoWorker)¶
model: nn.Module # Live model (on GPU)
optimizer: torch.optim.Optimizer # Inner optimizer (AdamW, etc.)
client: DiLoCoClient # HTTP client for server communication
_global_params: Dict[str, Tensor] # CPU snapshot taken after each sync
_local_step: int # Steps since last sync (reset to 0 after sync)
_sync_count: int # Completed sync rounds
_hooks: List # Optimizer post-step hook handles
_fragment_manager: Optional[FragmentManager] # None when num_fragments <= 1
_inflight_thread: Optional[Thread] # Background thread for current fragment
_inflight_result: Optional[Tuple[int, Optional[Dict[str, Tensor]]]] # (frag_id, result)
max_sync_retries: int # Max retry attempts per sync (default: 3)
_sync_retries: int # Cumulative sync retry count
_reconnections: int # Cumulative reconnection count
FragmentManager¶
fragments: List[List[str]] # fragment_id -> list of param names
param_to_fragment: Dict[str, int] # param_name -> fragment_id
num_fragments: int
Parameters are split into contiguous groups of roughly equal count (not equal
tensor size). The first total % N fragments get one extra parameter.
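The splitting rule above can be sketched in a few lines. This is an illustration of the described behavior, not the actual FragmentManager code; the function name split_into_fragments is hypothetical.

```python
from typing import Dict, List, Tuple


def split_into_fragments(
    param_names: List[str], num_fragments: int
) -> Tuple[List[List[str]], Dict[str, int]]:
    """Split names into contiguous groups of roughly equal count.

    The first (total % num_fragments) groups receive one extra name.
    """
    if num_fragments < 1:
        raise ValueError("num_fragments must be >= 1")
    base, extra = divmod(len(param_names), num_fragments)
    fragments: List[List[str]] = []
    start = 0
    for frag_id in range(num_fragments):
        size = base + (1 if frag_id < extra else 0)
        fragments.append(param_names[start:start + size])
        start += size
    # Reverse lookup: param_name -> fragment_id
    param_to_fragment = {
        name: frag_id for frag_id, names in enumerate(fragments) for name in names
    }
    return fragments, param_to_fragment
```

With 7 parameters and 3 fragments this yields groups of 3, 2 and 2 names. Because the split ignores tensor sizes, a fragment containing an embedding matrix can be far larger in bytes than its neighbors.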
Wire Protocol¶
All communication uses HTTP/1.1 over TCP. The server runs a
ThreadingHTTPServer (one thread per request).
Endpoints¶
| Method | Path | Request | Response |
|---|---|---|---|
| POST | /register | JSON: {worker_id, hostname, ...} | Tensor: global params |
| POST | /submit_pseudograd | Binary: header + tensors | Tensor: updated global params |
| POST | /submit_fragment_pseudograd | Binary: header + tensors | Tensor: updated fragment params |
| GET | /global_params | (none) | Tensor: global params |
| POST | /heartbeat | JSON: {worker_id, steps_per_second} | JSON: {status, sync_round, recommended_sync_every?} |
| POST | /deregister | JSON: {worker_id} | JSON: {status: "ok"} |
| GET | /status | (none) | JSON: server state |
Binary tensor format (submit endpoints)¶
Pseudo-gradient submissions use a length-prefixed header format:
[4 bytes: header length (big-endian uint32)]
[header_length bytes: JSON header (UTF-8)]
[remaining bytes: torch.save payload]
The JSON header contains worker_id and optionally fragment_id. The tensor
payload is a serialized Dict[str, torch.Tensor] produced by
torch.save(state_dict, BytesIO).
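As a rough illustration of this framing, the sketch below packs and unpacks a submission body using only the standard library. The helper names pack_submission and unpack_submission are hypothetical, and the tensor payload is treated as opaque bytes (in the real system it would be the output of torch.save).

```python
import json
import struct
from typing import Optional, Tuple


def pack_submission(
    worker_id: str, payload: bytes, fragment_id: Optional[int] = None
) -> bytes:
    """Build [4-byte big-endian header length][JSON header][payload]."""
    header = {"worker_id": worker_id}
    if fragment_id is not None:
        header["fragment_id"] = fragment_id
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack(">I", len(header_bytes)) + header_bytes + payload


def unpack_submission(body: bytes) -> Tuple[dict, bytes]:
    """Split a request body back into (header dict, payload bytes)."""
    if len(body) < 4:
        raise ValueError("body too short for length prefix")
    (header_len,) = struct.unpack(">I", body[:4])
    end = 4 + header_len
    if end > len(body):
        raise ValueError("header length exceeds body size")
    header = json.loads(body[4:end].decode("utf-8"))
    return header, body[end:]
```

Validating the declared header length before slicing keeps a truncated or malformed request from being silently parsed as an empty payload.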
Tensor serialization¶
Both client and server use torch.save / torch.load with map_location="cpu"
and weights_only=True. This is the same format used by PyTorch checkpoints.
Response payloads use Content-Type: application/octet-stream.
Client retry behavior¶
- JSON requests (register, heartbeat, deregister): retried up to max_retries times with exponential backoff (default: 3 retries, 1s base).
- Tensor requests (submit_pseudograd, submit_fragment_pseudograd, get_global_params): configurable retries via the retries parameter on _request_tensor(). Default is 0 (no retries). The worker's _sync() method handles retry at a higher level via _reconnect() + resubmit.
- Default timeout: 600 seconds (sync submissions may block for a long time at the server barrier).
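The JSON retry policy described above amounts to a standard exponential backoff loop. The following is a minimal sketch under that reading; call_with_backoff and its sleep parameter are hypothetical names, not part of DiLoCoClient.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(
    fn: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on ConnectionError with delays base, 2*base, 4*base, ..."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except ConnectionError:
            if attempt == max_retries:
                raise  # retries exhausted: surface the last error
            sleep(base_delay * (2 ** attempt))
    raise AssertionError("unreachable")
```

Passing max_retries=0 reproduces the default tensor-request behavior (fail immediately). The injectable sleep makes the loop testable without real delays.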
Threading Model¶
Server threads¶
The server uses ThreadingHTTPServer which spawns a new daemon thread for each
incoming HTTP request. This is required because in synchronous mode, multiple
worker requests block concurrently waiting at the barrier.
Critical locking:
| Lock | Protects | Used by |
|---|---|---|
| _sync_cond (Condition) | _pending_pseudograds, _completed_rounds, _fragment_pending, _fragment_rounds, _completed_fragment_rounds | Sync submit, fragment sync submit |
| _async_lock (Lock) | All async state, global param reads/writes in async mode | Async submit, async fragment submit, register (async), get_global_params (async) |
| _workers_lock (Lock) | _workers dict | All handlers that read/update worker info |
Lock ordering (always acquire in this order to avoid deadlocks):
1. _sync_cond or _async_lock (never both at once)
2. _workers_lock (acquired inside the above)
In sync mode, _sync_cond is used as a Condition (with wait/notify_all),
not just a lock. In async mode, _async_lock is a simple Lock (no wait
needed). The two modes are mutually exclusive; the server either uses
_sync_cond or _async_lock, never both for the same submission type.
Server health monitor thread¶
When heartbeat_timeout > 0, the server creates a HealthMonitor (from
health.py) that runs a daemon thread checking worker liveness every
check_interval seconds (default: heartbeat_timeout / 3). On each check it
reads _workers under _workers_lock, compares last_heartbeat timestamps to
the current time, and calls _handle_worker_death() for any worker that
exceeds the timeout. The health monitor is started in start() / run() and
stopped in stop().
Worker threads¶
The worker has up to two background threads:
- Heartbeat thread: sends periodic heartbeats to the server to report training speed and maintain liveness. Runs when heartbeat_interval > 0 (default: 30s). Stopped via the _heartbeat_stop Event. When DyLU is enabled, the worker also reads back recommended_sync_every from the heartbeat response and adjusts its sync_every.
- Fragment inflight thread (streaming mode only): submits one fragment's pseudo-gradients to the server in the background. At most one inflight thread exists at any time. The main training thread joins this thread before starting the next fragment submission.
Invariant: At most one background fragment is in-flight. Before starting a
new fragment submission, _wait_and_apply_inflight_fragment() joins the previous
thread and applies its result to the model. This prevents concurrent model
modifications and simplifies reasoning about parameter consistency.
Main thread:     [train]--[join prev, compute pg, launch bg]--[train]--[join, compute, launch]--...
Fragment thread:                                            [submit to server]                [submit to server]
Synchronization Modes¶
Synchronous mode (default)¶
All workers must submit pseudo-gradients before any receives the updated global parameters. The server uses a Condition variable as a barrier:
Server thread per worker:
1. Acquire _sync_cond
2. Store pseudograds in _pending_pseudograds[worker_id]
3. Record current _sync_round as my_round
4. If all workers submitted:
a. Average pseudo-gradients
b. Apply outer optimizer
c. Store result in _completed_rounds[my_round]
d. Increment _sync_round
e. notify_all()
5. While my_round not in _completed_rounds: wait()
6. Return _completed_rounds[my_round]
Per-round result caching: _completed_rounds maps round number to global
params. This prevents a race where a late-waking thread reads the wrong round's
result (the server might already be in the next round). Stale entries are pruned
to keep only the last two rounds.
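The barrier steps above can be reproduced in a toy form with threading.Condition. This sketch is an illustration only: the class name RoundBarrier is hypothetical, each worker submits a single float instead of a tensor dict, and the "outer optimizer step" is replaced by a plain average.

```python
import threading
from typing import Dict


class RoundBarrier:
    """Toy sync barrier with per-round result caching."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self.cond = threading.Condition()
        self.sync_round = 0
        self.pending: Dict[str, float] = {}
        self.completed: Dict[int, float] = {}  # round -> cached result

    def submit(self, worker_id: str, value: float) -> float:
        with self.cond:
            self.pending[worker_id] = value
            my_round = self.sync_round
            if len(self.pending) >= self.num_workers:
                # Last arrival completes the round (stand-in for the outer step).
                self.completed[my_round] = sum(self.pending.values()) / len(self.pending)
                self.pending.clear()
                self.sync_round += 1
                # Prune everything except the two most recent rounds.
                for old in [r for r in self.completed if r < my_round - 1]:
                    del self.completed[old]
                self.cond.notify_all()
            # Waiters read the result for the round they joined, not the current one.
            while my_round not in self.completed:
                self.cond.wait()
            return self.completed[my_round]
```

Keying the cache by my_round is what protects a slow-waking thread: even if sync_round has already advanced, it still returns the result of the round it submitted to.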
Asynchronous mode¶
Each worker's pseudo-gradients are applied immediately under _async_lock. No
barrier, no waiting. The worker receives updated global params in the response.
Staleness tracking: When a worker submits, the server computes staleness as
current_sync_round - worker.last_sync_server_round. High staleness means the
pseudo-gradient was computed against parameters that are many updates behind.
Staleness is logged but not currently used for weighting or rejection.
Delayed Nesterov (DN)¶
A server-side strategy for async mode. When dn_buffer_size > 0:
- Intermediate submissions (buffer not full): Apply direct gradient descent param -= lr * grad without calling the optimizer (no momentum update).
- Buffer-full submissions (every N-th): Average the buffer, set as .grad, call outer_optimizer.step() (full momentum update), clear buffer.
This prevents momentum from tracking stale individual worker directions.
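To make the two submission paths concrete, here is a pure-Python sketch of the Delayed Nesterov rule as described above. It is not the server's code: the class name DelayedNesterov is hypothetical, parameters are plain floats rather than tensors, and the momentum step is written out by hand following the Nesterov form of torch.optim.SGD with zero dampening.

```python
from typing import List


class DelayedNesterov:
    def __init__(
        self,
        params: List[float],
        lr: float = 0.7,
        momentum: float = 0.9,
        buffer_size: int = 2,
    ) -> None:
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.buffer_size = buffer_size
        self.grad_buffer: List[List[float]] = []
        self.velocity = [0.0] * len(params)  # momentum buffer

    def submit(self, grad: List[float]) -> List[float]:
        self.grad_buffer.append(list(grad))
        if len(self.grad_buffer) < self.buffer_size:
            # Intermediate submission: plain descent, momentum untouched.
            self.params = [p - self.lr * g for p, g in zip(self.params, grad)]
        else:
            # Buffer full: average, then one Nesterov momentum step.
            n = len(self.grad_buffer)
            avg = [sum(col) / n for col in zip(*self.grad_buffer)]
            self.velocity = [self.momentum * v + g for v, g in zip(self.velocity, avg)]
            step = [g + self.momentum * v for g, v in zip(avg, self.velocity)]
            self.params = [p - self.lr * s for p, s in zip(self.params, step)]
            self.grad_buffer.clear()
        return list(self.params)
```

Note that, read literally, the description applies intermediate gradients twice: once directly and once more as part of the buffered average. The sketch follows that reading; whether the real server compensates for it is not stated here.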
Dynamic Local Updates (DyLU)¶
Server-side computation, communicated via heartbeats:
Workers report steps_per_second in heartbeats. The server computes the
recommended sync interval proportional to the worker's relative speed and returns
it in the heartbeat response. Workers that opt in (dylu=True) adjust their
sync_every accordingly.
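The exact DyLU formula is not spelled out here. One plausible reading of "proportional to the worker's relative speed" is to scale a base interval by the ratio of a worker's speed to the fastest reported speed; the sketch below assumes exactly that. The function name, the base_sync_every parameter, and the choice of the fastest worker as reference are all assumptions.

```python
from typing import Dict


def recommend_sync_every(
    speeds: Dict[str, float],
    worker_id: str,
    base_sync_every: int,
    min_steps: int = 1,
) -> int:
    """Scale base_sync_every by this worker's speed relative to the fastest worker."""
    fastest = max(speeds.values(), default=0.0)
    own = speeds.get(worker_id, 0.0)
    if fastest <= 0.0 or own <= 0.0:
        return base_sync_every  # no usable speed data yet
    return max(min_steps, round(base_sync_every * own / fastest))
```

Under this assumption a worker running at half the speed of the fastest one takes half as many local steps per round, so all workers reach the sync point at roughly the same wall-clock time.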
Streaming DiLoCo (Fragments)¶
Fragment scheduling¶
FragmentManager.get_fragment_schedule(local_step, sync_every) determines which
fragment syncs at a given step:
fragment_interval = sync_every // num_fragments
fragment_idx = (local_step // fragment_interval - 1) % num_fragments
Example with sync_every=600, num_fragments=3:
- fragment_interval = 200
- Step 200: (200 // 200 - 1) % 3 = 0 -> fragment 0
- Step 400: (400 // 200 - 1) % 3 = 1 -> fragment 1
- Step 600: (600 // 200 - 1) % 3 = 2 -> fragment 2
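The schedule formula can be checked with a small standalone function. This is a sketch rather than the real method: fragment_for_step is a hypothetical free function, and returning None for steps that do not fall on a fragment boundary is an assumption based on how _post_step_hook uses the result.

```python
from typing import Optional


def fragment_for_step(
    local_step: int, sync_every: int, num_fragments: int
) -> Optional[int]:
    """Return the fragment scheduled at local_step, or None if no sync is due."""
    fragment_interval = sync_every // num_fragments
    if fragment_interval <= 0:
        raise ValueError("sync_every must be >= num_fragments")
    if local_step <= 0 or local_step % fragment_interval != 0:
        return None  # not on a fragment boundary
    return (local_step // fragment_interval - 1) % num_fragments
```

For sync_every=600 and num_fragments=3 this reproduces the example above: steps 200, 400 and 600 map to fragments 0, 1 and 2, and any other step in between maps to None.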
Background sync flow¶
    # In _post_step_hook, when a fragment is scheduled (simplified; "..." marks elided arguments):
    def _sync_fragment(self, fragment_id):
        # 1. Wait for previous in-flight fragment to complete, apply its result
        self._wait_and_apply_inflight_fragment()
        # 2. Compute pseudo-gradients for this fragment (CPU, main thread)
        pseudograds = self._fragment_manager.compute_fragment_pseudogradients(...)
        # 3. Launch background thread to submit to server
        self._inflight_thread = Thread(target=self._submit_fragment_background, ...)
        self._inflight_thread.start()
        # Main thread returns immediately, training continues
Server-side fragment handling¶
The server has separate handlers for fragment submissions:
- Sync fragment: Per-fragment barrier using _sync_cond. Each fragment has its own round counter (_fragment_rounds[frag_id]). When all workers submit the same fragment, the server applies the outer optimizer to just that fragment's parameters.
- Async fragment: Under _async_lock, set .grad on fragment parameters only, call step(), return the updated fragment params.
Outer optimizer correctness with fragments: Only the fragment's parameters
have .grad set. PyTorch optimizers skip parameters with None grad. SGD's
momentum buffers for other parameters remain untouched because step() only
processes parameters that have a non-None .grad.
Fragment-standard mode boundary¶
When num_fragments=1 (default), the worker's _fragment_manager is None.
The _post_step_hook takes the standard path (full-model sync via _sync())
with zero overhead. No background threads are created. This is a hard branch in
_post_step_hook:
    if self._fragment_manager is None:
        # Standard: full model sync at sync_every
        if self._local_step >= self.sync_every:
            self._sync()
    else:
        # Streaming: check fragment schedule
        frag_id = self._fragment_manager.get_fragment_schedule(...)
        if frag_id is not None:
            self._sync_fragment(frag_id)
Outer Optimizer Integration¶
The outer optimizer is a standard torch.optim.Optimizer instance. The server
constructs it by passing _param_list.parameters() to a factory function:
factory = outer_optimizer_factory or _default_outer_optimizer_factory
self.outer_optimizer = factory(self._param_list.parameters())
Default: torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)
To apply pseudo-gradients:
- Average pseudo-gradients across workers
- Set _param_list[i].grad = avg_grad for each parameter
- Call self.outer_optimizer.step()
- Call self.outer_optimizer.zero_grad()
This pattern works with any optimizer (Adam, Adafactor, etc.) without code
changes. The optimizer's state_dict() is included in server state saves.
For fragments, only the fragment's parameters have .grad set. All other
parameters have None grad. PyTorch optimizers iterate all parameter groups
but skip parameters where grad is None.
LR extraction for DN: The server extracts _outer_lr from the optimizer's
first param group for use in DN direct gradient steps
(param -= lr * grad). This assumes a single learning rate. If different
parameter groups have different LRs, DN would need modification.
Lifecycle and Data Flow¶
Full sync round (synchronous, no fragments)¶
Worker                                Server
|                                       |
|-- register(worker_id, info) --------->|  POST /register
|<--- global_params (torch.save) -------|
|                                       |
| [save CPU snapshot of global params]  |
| [train for sync_every steps]          |
|                                       |
| [compute pseudograds on CPU]          |
| [cast to bf16 if enabled]             |
|                                       |
|-- submit_pseudograd(wid, pgs) ------->|  POST /submit_pseudograd
|                                       |  [store in _pending_pseudograds]
|                                       |  [wait at barrier...]
|                                       |  [all workers submitted]
|                                       |  [average pseudograds]
|                                       |  [set as .grad on _param_list]
|                                       |  [outer_optimizer.step()]
|<--- updated global_params ------------|
|                                       |
| [copy global params to model (GPU)]   |
| [save new CPU snapshot]               |
| [reset _local_step = 0]               |
| [continue training...]                |
Streaming sync round (3 fragments)¶
Worker                               Server
|                                      |
| [train 200 steps]                    |
|                                      |
| [compute frag 0 pseudograds, CPU]    |
| [launch background thread]           |
|    |-- submit_fragment(wid, 0, pgs)->|  POST /submit_fragment_pseudograd
|    |                                 |  [barrier for frag 0]
|    |                                 |  [apply outer opt to frag 0 params]
| [train 200 steps]                    |
|    |<-- frag 0 updated params -------|
|                                      |
| [join bg thread, apply frag 0 result]|
| [compute frag 1 pseudograds]         |
| [launch background thread]           |
|    |-- submit_fragment(wid, 1, pgs)->|
| [train 200 steps]                    |
|    |<-- frag 1 updated params -------|
|                                      |
| [join bg thread, apply frag 1 result]|
| [compute frag 2 pseudograds]         |
| [launch background thread]           |
|    |-- submit_fragment(wid, 2, pgs)->|
| [train 200 steps (next round)]       |
|    |<-- frag 2 updated params -------|
| ...                                  |
Worker startup (start() / __enter__)¶
- Send registration request to server
- Receive global parameters from server
- Copy global params into model (_apply_global_params)
- Save CPU snapshot (_save_global_params_snapshot)
- Register optimizer post-step hook
- Start heartbeat thread if heartbeat_interval > 0 (default: 30s)
Worker shutdown (stop() / __exit__)¶
- Wait for any in-flight fragment to complete
- Stop heartbeat thread
- Remove optimizer hooks
- Send deregistration request to server
Fault Tolerance¶
The system handles four fault scenarios: worker death, dynamic joining, worker reconnection after transient failures, and server restart recovery.
Worker death detection¶
The HealthMonitor (in health.py) runs a background daemon thread on the
server. Every check_interval seconds (default: heartbeat_timeout / 3) it
scans all registered workers:
    for each worker in _workers:
        if now - worker.last_heartbeat > heartbeat_timeout:
            server._handle_worker_death(worker.worker_id)
Workers update last_heartbeat via the /heartbeat endpoint. The heartbeat
thread runs unconditionally on workers when heartbeat_interval > 0 (default
30s), regardless of DyLU setting.
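The liveness check itself is a simple timestamp comparison. The sketch below shows it as a pure function so it can be tested without threads; find_dead_workers is a hypothetical name, and the use of time.time() as the clock is an assumption about how last_heartbeat is recorded.

```python
import time
from typing import Dict, List, Optional


def find_dead_workers(
    last_heartbeat: Dict[str, float],
    heartbeat_timeout: float,
    now: Optional[float] = None,
) -> List[str]:
    """Return IDs of workers whose last heartbeat is older than the timeout."""
    if heartbeat_timeout <= 0:
        return []  # a timeout of 0 disables health monitoring
    current = time.time() if now is None else now
    return [
        worker_id
        for worker_id, ts in last_heartbeat.items()
        if current - ts > heartbeat_timeout
    ]
```

Collecting the dead IDs first and handling them afterwards, rather than calling the death handler while iterating, avoids mutating the registry during the scan.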
Worker death handling (_handle_worker_death)¶
When a worker is declared dead (by HealthMonitor or explicit deregistration):
1. Acquire _sync_cond -> _workers_lock (lock ordering preserved)
2. Remove worker from _workers registry
3. Increment _total_worker_deaths
4. Update num_workers = max(min_workers, remaining)
5. Remove worker's pending pseudo-gradients (if any)
6. Remove worker from _round_expected_workers set
7. Re-evaluate full-model sync barrier:
   - expected = len(_round_expected_workers)
   - if submitted >= expected: apply outer optimizer, complete round
8. Re-evaluate per-fragment barriers (for each active fragment):
   - Remove dead worker's fragment submission
   - If remaining submissions satisfy expected count: apply and complete
9. notify_all() to wake waiting threads
This ensures that a worker dying mid-sync doesn't deadlock the remaining workers. The barrier dynamically adjusts to the reduced worker count.
min_workers floor: The num_workers field never drops below
min_workers (default 1). This prevents a scenario where all workers die and
the barrier releases with zero submissions.
Dynamic worker joining¶
New workers can register at any time via /register. The registration handler:
- If _round_expected_workers already exists (mid-round), the new worker is not added to it. The new worker participates starting the next round.
- If more workers register than the current num_workers, num_workers is increased to accommodate them.
- The new worker receives the current global parameters and begins local training immediately.
This design prevents a new worker from blocking the current round's barrier (which would deadlock because existing workers already have the expected count computed).
Worker reconnection¶
Workers handle transient connection failures via retry with reconnection:
    # In _sync(): retry loop (simplified)
    retry_delay = 2.0
    for attempt in range(max_sync_retries + 1):
        try:
            new_global = client.submit_pseudogradients(worker_id, pseudograds)
            break
        except ConnectionError:
            if attempt < max_sync_retries:
                sleep(retry_delay)  # exponential backoff: 2s, 4s, 8s, ...
                retry_delay *= 2
                self._reconnect()  # re-register, get fresh global params
                pseudograds = self._compute_pseudogradients()  # recompute
            else:
                pass  # retries exhausted: skip this sync round, continue training
The _reconnect() method re-registers the worker with the server, receives
the current global parameters, and updates the local snapshot. This handles:
- Server restart: Server comes back with saved state, worker re-registers and gets the latest global params.
- Network partition: Temporary disconnection resolves, worker re-registers.
- Worker eviction: If the server's HealthMonitor evicted this worker, re-registration adds it back.
After reconnection, pseudo-gradients are recomputed against the new global params snapshot to avoid stale deltas.
Client tensor retry¶
The DiLoCoClient._request_tensor() method accepts an optional retries
parameter. When set (used by internal reconnection logic), failed tensor
requests are retried with exponential backoff before raising ConnectionError.
By default (retries=0), tensor requests fail immediately (they are large,
stateful payloads where blind retry is not always appropriate).
Interaction with async mode¶
In async mode, there is no barrier to deadlock, so worker death is less
critical. The _handle_worker_death() method still removes the worker from the
registry and adjusts num_workers. The HealthMonitor runs identically in both
modes.
Status monitoring¶
The /status endpoint includes fault tolerance fields:
- heartbeat_timeout: configured timeout value
- min_workers: configured minimum workers
- total_worker_deaths: cumulative death count
Worker sync_metrics include sync_retries and reconnections counters.
Server State Persistence¶
save_state(path) saves a dict via torch.save:
    {
        "global_params": Dict[str, Tensor],  # Current global parameters
        "outer_optimizer": optimizer.state_dict(),
        "sync_round": int,
        "num_workers": int,
        "param_names": List[str],
        "async_mode": bool,
        "total_submissions": int,
    }
load_state(path) restores parameters and optimizer state. Note that
weights_only=False is used for loading because the optimizer state dict
contains non-tensor values. Because this mode unpickles arbitrary Python
objects, only load state files from a trusted source.
Automatic save: when save_dir is set, the server saves every
save_every_n_rounds sync rounds. Two files are written: a versioned file
(diloco_server_state_round{N}.pt) and a diloco_server_state_latest.pt
symlink.
Fragment state is not persisted. Fragment round counters and pending submissions are transient. After a server restart, workers should re-register and start a fresh sync cycle.
CLI Layer¶
diloco_args.py¶
Builds the argument parser with three subcommands: server, status, worker.
The parser is created by create_diloco_parser(global_args) which is called
from fgcli.py.
diloco.py¶
diloco_cmd(args) dispatches to _server_cmd, _status_cmd, or _worker_cmd.
_worker_cmd does not create a DiLoCoWorker directly. Instead, it sets
environment variables and spawns a subprocess running forgather train:
DILOCO_SERVER -> server address
DILOCO_SYNC_EVERY -> sync interval
DILOCO_BF16_COMM -> "0" or "1"
DILOCO_DYLU -> "0" or "1"
DILOCO_HEARTBEAT_INTERVAL -> seconds
DILOCO_NUM_FRAGMENTS -> number of fragments
DILOCO_WORKER_ID -> optional worker ID
The training script reads these environment variables and constructs a
DiLoCoWorker internally. This keeps the CLI layer thin and avoids
reimplementing training logic.
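On the training-script side, reading these variables might look like the sketch below. The variable names come from the list above; the DiLoCoEnvConfig dataclass, the read_diloco_env function, and every default except heartbeat_interval=30 and num_fragments=1 are assumptions made for illustration.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class DiLoCoEnvConfig:
    server: str
    sync_every: int
    bf16_comm: bool
    dylu: bool
    heartbeat_interval: float
    num_fragments: int
    worker_id: Optional[str]


def read_diloco_env(env: Mapping[str, str] = os.environ) -> Optional[DiLoCoEnvConfig]:
    """Parse DILOCO_* variables; return None when DiLoCo is not configured."""
    server = env.get("DILOCO_SERVER")
    if not server:
        return None
    return DiLoCoEnvConfig(
        server=server,
        sync_every=int(env.get("DILOCO_SYNC_EVERY", "500")),  # default is an assumption
        bf16_comm=env.get("DILOCO_BF16_COMM", "1") == "1",  # default is an assumption
        dylu=env.get("DILOCO_DYLU", "0") == "1",  # default is an assumption
        heartbeat_interval=float(env.get("DILOCO_HEARTBEAT_INTERVAL", "30")),
        num_fragments=int(env.get("DILOCO_NUM_FRAGMENTS", "1")),
        worker_id=env.get("DILOCO_WORKER_ID") or None,
    )
```

Accepting any Mapping instead of reading os.environ directly makes the parser easy to test with a plain dict.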
Testing¶
Test organization¶
| File | Focus | Approach |
|---|---|---|
| test_server.py | Outer optimizer math, serialization | Direct method calls, no HTTP |
| test_server_client.py | HTTP round-trip | Real ThreadingHTTPServer, real DiLoCoClient |
| test_worker.py | Pseudo-gradient computation, hook lifecycle | Full server + worker integration |
| test_async.py | Async mode, DN momentum, DyLU | Multi-threaded workers against real server |
| test_streaming.py | FragmentManager, fragment endpoints, streaming worker | Unit + integration |
| test_fault_tolerance.py | Health monitor, worker death, barrier release, reconnection | Unit + integration |
Test patterns¶
Server fixture: Most integration tests create a DiLoCoServer with
start() (background thread), run their test, then stop() the server. The
server auto-selects a port to avoid conflicts.
TinyModel: Tests use minimal models (2 nn.Linear layers, dim=4-8) to keep
tests fast. The model is the same one used to initialize the server, ensuring
parameter names match.
Simulated training: Tests simulate training by directly modifying model
parameters (p.data.sub_(0.01)) or running optimizer.step() with synthetic
gradients. No actual data loading or forward passes.
Multi-worker sync: Tests spawn multiple workers as separate threads or
sequential submissions to the same server. For synchronous tests, threading is
required because each worker's submit_pseudogradients blocks until all workers
have submitted.
Running tests¶
# All DiLoCo tests (102 tests)
pytest tests/unit/ml/diloco/ -v
# By phase
pytest tests/unit/ml/diloco/test_server.py tests/unit/ml/diloco/test_server_client.py tests/unit/ml/diloco/test_worker.py -v # Phase 1 (32)
pytest tests/unit/ml/diloco/test_async.py -v # Phase 2 (18)
pytest tests/unit/ml/diloco/test_streaming.py -v # Phase 3 (25)
pytest tests/unit/ml/diloco/test_fault_tolerance.py -v # Phase 4 (27)
# Quick smoke test
pytest tests/unit/ml/diloco/test_server.py::TestOuterOptimizer::test_single_worker_outer_step -v
Troubleshooting¶
Worker hangs at sync (synchronous mode)¶
Symptom: One or more workers block at submit_pseudogradients for a long
time.
Cause: The server barrier waits for all expected workers to submit. If a worker crashes and health monitoring is disabled (or timeout is too long), the remaining workers wait until the 600-second HTTP timeout.
Diagnosis:
1. Check server status: forgather diloco status --server host:port
2. Look at pending_submissions in the response. If it lists some workers but
not all, a worker has failed to submit.
3. Check total_worker_deaths to see if the HealthMonitor has already
evicted the dead worker.
4. Check server logs for health monitor warnings.
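Mitigation: Ensure --heartbeat-timeout is set (default: 120s). The
HealthMonitor then detects a dead worker and releases the barrier at most
about heartbeat_timeout + check_interval after the worker's last heartbeat
(roughly 160s with the defaults, since check_interval defaults to
heartbeat_timeout / 3). The min_workers setting prevents the system from
continuing with zero workers.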
Server port already in use¶
Symptom: OSError: [Errno 98] Address already in use
Cause: Previous server didn't shut down cleanly (socket in TIME_WAIT).
Fix: Either wait ~60 seconds, use a different port, or set
SO_REUSEADDR (not currently done). When port=None, the server
auto-selects an available port starting at 8512.
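Port auto-selection of this kind is usually a linear probe. The following sketch shows one way to do it with the standard socket module; find_free_port, its host default, and the max_tries limit are hypothetical and not taken from the server code.

```python
import socket


def find_free_port(
    start: int = 8512, host: str = "127.0.0.1", max_tries: int = 100
) -> int:
    """Return the first port >= start that can currently be bound on host."""
    for port in range(start, start + max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))
            except OSError:
                continue  # port in use (or otherwise unavailable), try the next one
            return port
    raise OSError(f"no free port in [{start}, {start + max_tries})")
```

The probe socket is closed before the caller binds the real server, so another process can still claim the port in between. Binding the server socket directly inside the loop avoids that race.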
Pseudo-gradients all zeros¶
Symptom: Training makes no progress, global params don't change.
Cause: The worker model isn't actually training (no gradients flowing).
Diagnosis: Check _local_step is incrementing. Check that _sync() is
being called (look for "starting sync" log messages). Check that the optimizer
hook is installed (verify len(diloco._hooks) > 0).
BFloat16 precision issues¶
Symptom: Numerical differences after sync, especially with very small parameter values.
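Cause: BFloat16 keeps float32's exponent range but has only 8 bits of significand precision (about 2-3 significant decimal digits). Each pseudo-gradient element (the difference between global and local params) is therefore rounded to a relative precision of roughly 0.4%, and fine-grained differences can be lost.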
Mitigation: Disable bf16 communication with --no-bf16 or
bf16_comm=False. This doubles bandwidth usage.
Fragment sync deadlock¶
Symptom: Workers hang when using --num-fragments > 1 in sync mode.
Cause: Per-fragment barriers require all workers to submit the same fragment
in the same round. If workers have different sync_every values (e.g., from
DyLU) or different num_fragments, their fragment schedules won't align.
Requirement: All workers in synchronous fragment mode must use the same
sync_every and num_fragments.
Async staleness drift¶
Symptom: Training loss oscillates or diverges in async mode.
Diagnosis: Check staleness values in server logs. Staleness > 5-10 may indicate that pseudo-gradients are too stale to be useful.
Mitigation:
1. Enable DN momentum (--dn-buffer-size N where N = num_workers)
2. Enable DyLU to equalize submission rates
3. Increase sync_every so each submission is more meaningful
Memory: server accumulates state¶
Symptom: Server memory grows over time.
Cause: _completed_rounds and _completed_fragment_rounds cache results.
Both are pruned to the 2 most recent rounds (per fragment, for the fragment
cache), so the cache is bounded but not small.
Check: Assuming each cached fragment result holds only that fragment's
parameters (which is what the fragment endpoint returns), a 1B-parameter model
in float32 (~4 GB total) split into 7 similarly sized fragments gives entries of
roughly 4 GB / 7 ≈ 0.57 GB. With at most 2 entries per fragment, the fragment
cache is bounded by about 2 * 7 * 0.57 GB ≈ 8 GB, i.e. two full copies of the
model regardless of num_fragments. The full-model cache (_completed_rounds) has
the same bound: 2 * 4 GB = 8 GB.
Extension Points¶
Adding a new outer optimizer¶
Pass a custom factory function:
    server = DiLoCoServer(
        model_state_dict=sd,
        num_workers=2,
        outer_optimizer_factory=lambda p: torch.optim.Adam(p, lr=0.001),
    )
The server calls factory(self._param_list.parameters()) once at init. The
optimizer's state_dict is included in saves/loads automatically.
For CLI support, modify _server_cmd in diloco.py to add new --outer-*
flags and build the factory accordingly.
Adding a new server endpoint¶
- Add a handler method _handle_foo(self, handler) on DiLoCoServer
- Register it in DiLoCoRequestHandler.do_POST or do_GET (in _create_handler())
- Add a corresponding method on DiLoCoClient using _request_json or _request_tensor
Adding new communication compression¶
Currently, bf16 casting happens in the worker (_compute_pseudogradients). To
add quantization (e.g., int8, sparse encoding):
- Modify _compute_pseudogradients to apply the compression
- Modify the server's deserialization to decompress
- Alternatively, implement as a custom serialization format that replaces torch.save payloads with a compressed format
Integrating with Forgather callbacks¶
A future callback integration would:
- Create a DiLoCoCallback implementing TrainerCallback
- In on_train_begin: create and start DiLoCoWorker
- In on_train_end: stop the worker
- In on_log: report diloco.sync_metrics to the logger
The DiLoCoWorker context manager and optimizer hook design makes this
straightforward -- the callback just manages the worker lifecycle.
Adding P2P allreduce (replacing server)¶
The current architecture is client-server. To add peer-to-peer allreduce:
- Create a new sync backend (e.g., allreduce.py) that replaces DiLoCoClient.submit_pseudogradients with a collective allreduce
- The DiLoCoWorker would accept a backend abstraction instead of a DiLoCoClient directly
- The outer optimizer would run on each worker locally (all workers compute the same average pseudo-gradient, so they'd arrive at the same global params)
Known Limitations¶
- Single-threaded outer optimizer. The server applies the outer optimizer step in the HTTP handler thread. For very large models, this could delay response time.
- Fragment split by parameter count, not size. Two fragments may have very different total tensor sizes if parameter dimensions vary (e.g., embedding layer vs attention layers). A size-balanced split would improve streaming overlap.
- No gradient compression beyond bf16. Int8, sparse, or top-k compression could further reduce bandwidth for larger models.
- DN direct gradient step uses single LR. The _outer_lr is extracted from the first param group. Multiple param groups with different LRs would need per-group direct steps.
- No per-worker weighting. All workers' pseudo-gradients are equally averaged. Workers with more data or better hardware could be weighted proportionally.
- ThreadingHTTPServer scalability. One thread per request is fine for 2-10 workers but would need replacement (asyncio, gRPC) for hundreds.
- No fragment-level reconnection. Worker reconnection (_reconnect()) re-registers and fetches full global params. If a streaming sync was in-flight when the connection dropped, the fragment result is lost and the fragment re-syncs from scratch on the next cycle.