year: 2025/03
paper: https://openreview.net/forum?id=stFPf3gzq1
website: https://blog.exolabs.net/day-12/
code:
connections: exolabs, distributed training, ensembles, DiLoCo, model merging
TLDR:
- Share / sync / average only a small percentage of randomly sampled parameters each step (0.05-0.5%); see the sketch after this list
- Reduces network communication by up to 1000x (e.g. sharing 0.1% of parameters)
- Models stay highly correlated (>0.9, for reasonable percentages like 0.1%)
- This even seems to work with weights from some steps in the past, i.e. asynchronously
- Every node has a full instance of the model; it's effectively an ensemble of models
- Training is more stable and works with higher learning rates
- At the end, just do a full average, or use the models for Bayesian approximation @ inference
- Doesn’t scale well beyond 16 nodes :(, where the current AllReduce-style averaging becomes inefficient. Future work might explore gossip protocols to reduce O(n²) communication to O(n).
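A minimal sketch of the per-step sparse averaging, assuming an initialized torch.distributed process group; the function name, seeding scheme, and index selection are illustrative assumptions, not the paper's code, and the real system overlaps this communication with compute (using slightly stale values) rather than blocking as shown here.

```python
import torch
import torch.distributed as dist

def sparse_average_(model: torch.nn.Module, p: float = 0.001, step: int = 0) -> None:
    """Average a random fraction p of each parameter tensor across all workers, in place."""
    world_size = dist.get_world_size()
    for i, param in enumerate(model.parameters()):
        # Deterministic per-step seed: every worker samples the same indices,
        # so no index lists need to be exchanged.
        gen = torch.Generator(device=param.device)
        gen.manual_seed(step * 100_003 + i)
        flat = param.data.view(-1)
        k = max(1, int(p * flat.numel()))
        idx = torch.randint(flat.numel(), (k,), generator=gen, device=param.device)
        shared = flat[idx]                             # advanced indexing copies the subset
        dist.all_reduce(shared, op=dist.ReduceOp.SUM)  # sum that subset across nodes
        flat[idx] = shared / world_size                # write back the average
```

With p = 0.001, each step moves roughly 1000x less data than a full all-reduce, which is where the communication savings above come from.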
Intuition
- This works because model weights evolve slowly compared to gradients. Even slightly outdated parameters provide value, enabling fully overlapped communication and computation.
- After n steps with per-step probability p, the fraction of parameters shared at least once is 1 - (1 - p)^n. With p = 0.05% over 10,000 steps, that's 99.3% coverage - explaining why such tiny p values work so well (quick check below).
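A quick arithmetic check of that coverage claim (values taken from the bullet above):

```python
# Fraction of parameters sampled at least once after n steps, each with probability p.
p = 0.0005          # 0.05% per step
n = 10_000          # training steps
coverage = 1 - (1 - p) ** n
print(f"coverage ≈ {coverage:.1%}")   # ~99.3%
```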
Improving DiLoCo
Pure DiLoCo struggles with large synchronization intervals H. When H grows from 100 to 10,000, models diverge too far between syncs, causing performance drops. SPARTA solves this by maintaining alignment through continuous sparse updates between the full DiLoCo synchronizations.
The combination enables a 100× increase in the DiLoCo interval while achieving 14.3% lower perplexity than DiLoCo alone. SPARTA’s regularization effect also allows 2× higher learning rates, accelerating convergence.
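A rough sketch of how the two could compose, reusing the hypothetical sparse_average_ above; outer_sync_ is a simplified full parameter average standing in for DiLoCo's actual outer (Nesterov) optimizer step:

```python
import torch
import torch.distributed as dist

def outer_sync_(model: torch.nn.Module) -> None:
    """Full parameter average across workers (simplified stand-in for DiLoCo's outer step)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= world_size

def train(model, optimizer, data_iter, loss_fn, total_steps, H=10_000, p=0.001):
    for step in range(total_steps):
        x, y = next(data_iter)
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        sparse_average_(model, p=p, step=step)   # every step: cheap, keeps replicas aligned

        if (step + 1) % H == 0:
            outer_sync_(model)                   # every H steps: the expensive full sync
```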
SPARTA acts as a low-pass filter, keeping models in the same loss basin without expensive full synchronization.
This connects to broader model merging principles - both SPARTA and PMA exploit the insight that weight space moves slowly enough for asynchronous averaging. PMA averages full checkpoints at intervals; SPARTA averages sparse subsets continuously.