Definition
FlashAttention is an IO-aware exact attention algorithm that computes self-attention without materializing the full N x N attention matrix in GPU high-bandwidth memory (HBM). By tiling the computation and keeping intermediate results in fast on-chip SRAM, it reduces attention's memory footprint from O(N^2) to O(N) while producing results mathematically identical to those of standard attention.
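To make the contrast concrete, here is a minimal NumPy sketch of standard (naive) attention, which explicitly forms the N x N score matrix that FlashAttention avoids writing to HBM. The function name and shapes are illustrative, not from any particular library:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N score matrix.

    Q, K, V: arrays of shape (N, d). The (N, N) intermediate S is
    exactly the O(N^2) object FlashAttention never stores in HBM.
    """
    scale = 1.0 / np.sqrt(Q.shape[-1])
    S = Q @ K.T * scale                              # (N, N) scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))    # numerically stable softmax
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                                     # (N, d) output
```

For sequence length N and head dimension d, the (N, N) matrix S dominates memory once N >> d, which is why long contexts are where the quadratic cost bites.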
Key Intuition
Standard attention implementations are bottlenecked not by arithmetic (FLOPs) but by memory reads and writes between slow HBM and fast SRAM. FlashAttention restructures the computation into blocks that fit in SRAM, fusing the softmax normalization across blocks using an online softmax trick. This avoids the expensive round-trip of writing and reading the full attention matrix.
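The online softmax trick can be sketched as follows: process K and V in blocks, keeping per-row running maxima and running softmax denominators, and rescale the partial output whenever a new block raises the maximum. This is a simplified single-head NumPy illustration of the idea (real implementations fuse this into a GPU kernel with SRAM-resident tiles); the function name and block size are assumptions for the example:

```python
import numpy as np

def flash_attention_sketch(Q, K, V, block=64):
    """Tiled attention with online softmax; never forms the full N x N matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))         # unnormalized running output
    m = np.full(N, -np.inf)      # running row-wise max of scores
    l = np.zeros(N)              # running softmax denominator
    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                  # (N, block) tile only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)             # rescale old accumulators
        P = np.exp(S - m_new[:, None])        # tile softmax numerator
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]                     # normalize once at the end
```

Because each tile's contribution is rescaled by exp(m_old - m_new) when the running max changes, the final normalized output equals the exact softmax result, computed one block at a time.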
History/Origin
Dao et al. (2022) introduced FlashAttention (see flash-attention-paper), framing attention efficiency as an IO-complexity problem rather than a FLOP-reduction problem. FlashAttention-2 (Dao, 2023) improved the work partitioning across GPU thread blocks and warps, reaching 50-73% of theoretical peak throughput on A100 GPUs. FlashAttention-3 (2024) added support for Hopper GPU features such as asynchronous execution and low-precision computation.
Relationship to Other Concepts
FlashAttention accelerates self-attention within the transformer architecture. It is orthogonal to approximation-based approaches (sparse attention, linear attention) since it computes exact attention. It has become a standard component in modern LLM training and inference stacks, integrated into frameworks used by llama and other open models. The longer sequences it enables interact with positional-encoding methods designed for length extrapolation.
Notable Results
FlashAttention achieved a 2-4x wall-clock speedup for transformer training and enabled context lengths up to 16K (later 64K+) where standard attention runs out of memory. By reducing attention's memory footprint from quadratic to linear in sequence length, it made long-context models practical.
Open Questions
- Extending IO-aware algorithms to other bottleneck operations beyond attention.
- Optimal tiling strategies for emerging hardware architectures.
- Whether exact attention remains necessary or whether carefully designed approximations can match FlashAttention’s practical efficiency.