Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
Haiquan Qiu, Quanming Yao
For the first time, we mechanistically explain why low-precision training with flash attention fails, identifying a vicious cycle of rounding errors and proposing a simple, effective fix.
Abstract
The pursuit of computational efficiency has driven the adoption of low-precision formats for training transformer models. However, this progress is often hindered by notorious training instabilities. This paper provides the first mechanistic explanation for a long-standing and unresolved failure case in which training with flash attention in low-precision settings leads to catastrophic loss explosion. Our in-depth analysis reveals that the failure is not a random artifact but is caused by two intertwined phenomena: the emergence of similar low-rank representations within the attention mechanism and the compounding effect of biased rounding errors inherent in low-precision arithmetic. We demonstrate how these factors create a vicious cycle of error accumulation that corrupts weight updates, ultimately derailing the training dynamics. To validate our findings, we introduce a minimal modification to flash attention that mitigates the bias in rounding errors. This simple change stabilizes the training process, confirming our analysis and offering a practical solution to this persistent problem. Code is available at https://github.com/ucker/why-low-precision-training-fails.
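The biased-rounding mechanism the abstract describes can be illustrated with a small numerical sketch (this is not the paper's code; `to_bf16` is a hypothetical helper emulating the FP32-to-BF16 round-to-nearest-even cast): when many activations share nearly the same value, as in a low-rank representation, their rounding errors all take the same sign and accumulate rather than cancel.

```python
import numpy as np

def to_bf16(x):
    """Emulate the FP32 -> BF16 round-to-nearest-even cast via bit manipulation."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Add 0x7FFF plus the ties-to-even bit, then truncate the low 16 bits.
    rounding_bias = ((bits >> 16) & np.uint32(1)) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)

# Low-rank-like case: many activations share almost the same value, so every
# element rounds in the same direction and the mean error is biased.
similar = np.full(10_000, 1.001, dtype=np.float32)
bias_similar = float(np.mean(to_bf16(similar) - similar))

# Diverse values: individual rounding errors have mixed signs and mostly cancel.
diverse = rng.uniform(1.0, 2.0, 10_000).astype(np.float32)
bias_diverse = float(np.mean(to_bf16(diverse) - diverse))

print(f"mean rounding error, similar values: {bias_similar:+.2e}")  # ~ -1e-3
print(f"mean rounding error, diverse values: {bias_diverse:+.2e}")  # ~ 0
```

The per-element rounding error is bounded either way; the point is that averaging over similar values leaves a systematic bias that, repeated across flash attention's accumulation steps, can compound into corrupted weight updates.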
Analyzes low-precision flash attention training failure caused by low-rank representations and biased BF16 rounding errors.
- First mechanistic explanation showing failure stems from interplay between emergent low-rank representations and biased rounding errors
- Demonstrates error accumulation vicious cycle corrupting weight updates through detailed analysis of attention mechanism
- Minimal targeted modification to flash attention validates analysis by restoring training stability
- Flash attention
- Low-precision arithmetic (BF16)
- Numerical analysis
- Spectral analysis
- GPT-2
- FineWeb
Limitations
- The analysis focuses on a specific failure case in the GPT-2 model; generalizability to other architectures, larger scales, or different low-precision formats like FP8 requires further investigation
- The proposed mitigation is tailored to the specific rounding error identified and may not address other sources of numerical instability
Future work
- Extend the analysis to FP8 training, larger models, and different architectures
- Develop automated tools to detect and mitigate numerical instabilities during training
Author keywords
- low-precision training
- transformer
- attention
Related orals
TileLang: Bridge Programmability and Performance in Modern Neural Kernels
TileLang enables hardware-aware fused kernel programming with tile inference and recommendation, achieving a 5-6x speedup.
Probabilistic Kernel Function for Fast Angle Testing
Proposes probabilistic kernel functions for angle testing enabling efficient approximate nearest neighbor search.
SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer
Generates minute-long high-resolution videos efficiently with linear attention and constant-memory KV cache for block autoregression.
Efficient Resource-Constrained Training of Transformers via Subspace Optimization
WASI applies subspace-based training to transformer models reducing memory by 62x and FLOPs by 2x while maintaining accuracy on edge devices.
Speculative Actions: A Lossless Framework for Faster AI Agents
Speculative Actions accelerates agent systems by predicting and executing likely future actions in parallel.