Dao et al. (2022) from Stanford introduce FlashAttention, an IO-aware exact attention algorithm that uses tiling to minimize memory reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. Unlike approximate attention methods, FlashAttention computes exact attention while achieving significant wall-clock speedups and memory usage that scales linearly with sequence length.
Problem
Self-attention in transformer models has O(N^2) time and memory complexity in the sequence length N, making long sequences prohibitively expensive. Prior approximate attention methods reduce FLOPs but often fail to deliver actual wall-clock speedups because they ignore the dominant cost: memory access (IO) between GPU HBM and SRAM.
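To make the quadratic term concrete, the snippet below computes the memory needed just to hold one N x N attention matrix. It assumes fp16 storage (2 bytes per element), a single head and batch size 1; the helper `attention_matrix_bytes` is a hypothetical name for illustration. Standard attention writes both the score matrix and the softmax output to HBM, so the real traffic is a multiple of this figure.

```python
def attention_matrix_bytes(seq_len: int, n_heads: int = 1, batch: int = 1,
                           bytes_per_elem: int = 2) -> int:
    """Bytes to store the N x N attention matrix for every head and batch element (fp16 by default)."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


for n in (1024, 4096, 16384, 65536):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {mib:,.0f} MiB")  # e.g. N=16384 -> 512 MiB for a single head
```

Doubling N quadruples the figure, which is why the matrix, rather than the O(Nd) inputs and outputs, dominates memory at long sequence lengths.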
Key Contribution
An IO-aware reformulation of exact attention that never materializes the full N x N attention matrix in HBM. By tiling the computation and recomputing attention in the backward pass rather than storing it, FlashAttention needs Θ(N^2 d^2 M^{-1}) HBM accesses versus Θ(Nd + N^2) for standard attention, where d is the head dimension and M is the SRAM size. This is fewer whenever d^2 is small relative to M, and the authors prove it is optimal for a range of SRAM sizes.
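A back-of-the-envelope comparison of the two bounds, ignoring the constant factors hidden by Θ. When N ≫ d, the N^2 term dominates standard attention, so the ratio reduces to roughly d^2/M. The values d = 64 and M = 10^5 elements are illustrative assumptions, not measured hardware figures; the paper's analysis assumes d ≤ M ≤ Nd.

```latex
\begin{aligned}
\frac{N^2 d^2 M^{-1}}{Nd + N^2} &\approx \frac{N^2 d^2 M^{-1}}{N^2} = \frac{d^2}{M} && (N \gg d) \\
d = 64,\; M = 10^5 \;&\Rightarrow\; \frac{d^2}{M} = \frac{4096}{10^5} \approx 0.04
\end{aligned}
```

Under these assumptions that is on the order of 24x fewer HBM accesses, up to constants; a larger head dimension or a smaller SRAM shrinks the advantage.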
Method
FlashAttention splits Q, K, and V into blocks that are loaded into SRAM, computes the softmax incrementally across blocks using online softmax (rescaling partial results as each new block arrives), and writes only the final output to HBM. For the backward pass, it stores only the softmax normalization statistics and recomputes attention on-chip, trading extra FLOPs for drastically reduced memory access. The authors also extend the approach to block-sparse attention for further speedups.
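The tiling and online-softmax steps can be mirrored in a short NumPy sketch. It assumes a single head, no masking or dropout, and hypothetical block sizes `block_q` and `block_k`; all function names are illustrative. Plain arrays stand in for SRAM tiles, so this reproduces the arithmetic of the method, not its IO savings, which come from a fused GPU kernel. Query tiles form the outer loop here for readability, which may differ from the loop order in the paper's Algorithm 1. The sketch folds the running row max m and denominator l into a single log-sum-exp L = m + log l, and `recompute_probs` shows how a backward pass could rebuild a block of probabilities from L instead of storing them.

```python
import numpy as np


def standard_attention(Q, K, V):
    """Reference attention: materializes the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block_q=64, block_k=64):
    """Exact attention computed tile by tile with online softmax.

    Returns the output O and the per-row log-sum-exp L, the only
    softmax statistic kept for a recomputation-based backward pass.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty((N, V.shape[1]))
    L = np.empty(N)
    for qs in range(0, N, block_q):
        Qi = Q[qs:qs + block_q]                      # query tile
        rows = Qi.shape[0]
        m = np.full(rows, -np.inf)                   # running row max
        l = np.zeros(rows)                           # running softmax denominator
        acc = np.zeros((rows, V.shape[1]))           # unnormalized output
        for ks in range(0, K.shape[0], block_k):
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qi @ Kj.T * scale                    # scores for this tile only
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])           # unnormalized block probabilities
            alpha = np.exp(m - m_new)                # rescale earlier partial sums
            l = alpha * l + P.sum(axis=1)
            acc = alpha[:, None] * acc + P @ Vj
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]        # normalize once at the end
        L[qs:qs + block_q] = m + np.log(l)
    return O, L


def recompute_probs(Q, K, L, q_slice, k_slice):
    """Rebuild one block of softmax probabilities from the stored L."""
    S = Q[q_slice] @ K[k_slice].T / np.sqrt(Q.shape[1])
    return np.exp(S - L[q_slice][:, None])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((200, 32)) for _ in range(3))
    O, L = tiled_attention(Q, K, V, block_q=64, block_k=48)
    print(np.allclose(O, standard_attention(Q, K, V)))
```

Each new key block can raise the running maximum, so the factor alpha = exp(m_old - m_new) rescales both the accumulated output and the denominator before the new block's contribution is added. This is what makes the blockwise result equal to the full softmax, up to floating-point rounding, rather than an approximation.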
Main Results
- 15% end-to-end wall-clock speedup on BERT-large (seq. 512) over the MLPerf 1.1 training record.
- 3x speedup on GPT-2 (seq. 1K); 2.4x on Long Range Arena (seq. 1K-4K).
- Up to 7.6x speedup on the attention computation itself.
- Enables longer context: 0.7 better perplexity on GPT-2 and a 6.4-point lift on long-document classification.
- First transformer to achieve better-than-chance performance on Path-X (seq. 16K, 61.4% accuracy) and Path-256 (seq. 64K, 63.1% accuracy).
Limitations
Requires custom CUDA kernels, limiting portability across hardware backends. The tiling approach adds implementation complexity. Block-sparse FlashAttention still requires choosing a sparsity pattern.
Impact
FlashAttention became the de facto attention implementation in major training frameworks and helped make long-context LLMs practical to train. It influenced the training setups behind LLaMA, the GPT-4 technical report, and most subsequent large-scale transformer training. FlashAttention-2 and later versions further improved throughput.