ICLR 2026 Orals

Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention

Haiquan Qiu, Quanming Yao

Efficiency, Systems & Kernels · Fri, Apr 24 · 11:18 AM–11:28 AM · 204 A/B · Avg rating: 6.50 (range 4–8)
Author-provided TL;DR

For the first time, we mechanistically explain why low-precision training with flash attention fails, identifying a vicious cycle of rounding errors and proposing a simple, effective fix.

Abstract

The pursuit of computational efficiency has driven the adoption of low-precision formats for training transformer models. However, this progress is often hindered by notorious training instabilities. This paper provides the first mechanistic explanation for a long-standing and unresolved failure case where training with flash attention in low-precision settings leads to catastrophic loss explosion. Our in-depth analysis reveals that the failure is not a random artifact but is caused by two intertwined phenomena: the emergence of similar low-rank representations within the attention mechanism and the compounding effect of biased rounding errors inherent in low-precision arithmetic. We demonstrate how these factors create a vicious cycle of error accumulation that corrupts weight updates, ultimately derailing the training dynamics. To validate our findings, we introduce a minimal modification to flash attention that mitigates the bias in rounding errors. This simple change stabilizes the training process, confirming our analysis and offering a practical solution to this persistent problem. Code is available at https://github.com/ucker/why-low-precision-training-fails.
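One ingredient of the failure mode the abstract describes, the compounding of rounding errors in BF16 accumulation, can be made concrete with a toy sketch. This is pure Python for illustration, not the authors' code or flash attention itself: BF16 keeps only the top 16 bits of a float32 (1 sign, 8 exponent, 7 mantissa bits), so once an accumulator grows large enough, small increments fall below half a unit in the last place and are rounded away entirely.

```python
import struct

def bf16_round(x: float) -> float:
    """Round a Python float to bfloat16 precision (round-to-nearest-even)
    by keeping only the top 16 bits of its float32 representation."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)       # round-to-nearest-even
    bits16 = ((bits + rounding_bias) >> 16) & 0xFFFF  # sign + exp + 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

# Accumulate a small increment many times: exact float64 vs. simulated BF16.
inc = 1e-3
exact = 0.0
acc = 0.0
for _ in range(10_000):
    exact += inc
    acc = bf16_round(acc + bf16_round(inc))

print(exact)  # close to 10.0
print(acc)    # stalls far below the true sum once inc < half a BF16 ulp of acc
```

The BF16 accumulator stops growing long before reaching the true sum, a simple instance of how low-precision rounding error compounds over many updates rather than averaging out.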

One-sentence summary · Auto-generated by claude-haiku-4-5-20251001

Analyzes low-precision flash attention training failure caused by low-rank representations and biased BF16 rounding errors.

Contributions · Auto-generated by claude-haiku-4-5-20251001
  • First mechanistic explanation showing failure stems from interplay between emergent low-rank representations and biased rounding errors
  • Demonstrates, through detailed analysis of the attention mechanism, how a vicious cycle of error accumulation corrupts weight updates
  • Minimal targeted modification to flash attention validates analysis by restoring training stability
Methods used · Auto-generated by claude-haiku-4-5-20251001
  • Flash attention
  • Low-precision arithmetic (BF16)
  • Numerical analysis
  • Spectral analysis
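The flash attention listed above hinges on an online (streaming) softmax that processes key/value blocks without materializing the full attention matrix. A minimal scalar sketch of that trick, for a single query with scalar values (illustrative only; the function name is hypothetical and this is not the paper's kernel):

```python
import math

def online_softmax_weighted_sum(scores, values, block=2):
    """Streaming softmax-weighted sum: process scores/values in blocks,
    maintaining a running max m, normalizer l, and weighted sum acc,
    rescaling prior partial results whenever a new maximum appears."""
    m = float("-inf")
    l = 0.0
    acc = 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new) if m != float("-inf") else 0.0
        l = l * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v
                                for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / l
```

The repeated rescaling by `exp(m - m_new)` is exactly the kind of low-precision multiply-accumulate step where, per the paper's analysis, biased rounding errors can compound.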
Datasets used · Auto-generated by claude-haiku-4-5-20251001
  • FineWeb
Limitations (author-stated) · Auto-generated by claude-haiku-4-5-20251001
  • Analysis focuses on a specific failure case in the GPT-2 model
  • Generalizability to other architectures, larger scales, or different low-precision formats such as FP8 requires further investigation
  • The proposed mitigation is tailored to the specific rounding error identified and may not address other sources of numerical instability
Future work (author-stated) · Auto-generated by claude-haiku-4-5-20251001
  • Extend the analysis to FP8 training, larger models, and different architectures
  • Develop automated tools to detect and mitigate numerical instabilities during training

Author keywords

  • low-precision training
  • transformer
  • attention
