No More Quadratic Complexity for Transformers: Discover the Power of Flash Attention!
The traditional transformer-style attention mechanism compares every token with every other token, leading to quadratic complexity in the sequence length. What’s the solution? Let’s have a look at what Flash Attention offers!
Introduction
Quadratic complexity arises because the attention mechanism is designed to capture dependencies between every pair of tokens, regardless of their positions in the sequence.
You may be thinking that not all tokens need to be compared with all other tokens, especially in natural language, where words often depend most strongly on nearby words.
So, to address the quadratic time and memory complexity, various approximate attention methods have been proposed that trade off model quality for lower computational complexity. However, these methods often fail to deliver a real wall-clock speedup, because they focus on reducing FLOPs while ignoring the cost of memory access.
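To make the cost concrete, here is a minimal sketch of standard attention in PyTorch (the function name naive_attention is just illustrative): the full N×N score matrix is materialized before the softmax, which is exactly where the memory footprint and memory traffic blow up.

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    """Standard scaled dot-product attention.

    q, k, v: tensors of shape (batch, heads, seq_len, head_dim).
    The (seq_len x seq_len) score matrix is materialized in full,
    which is where the quadratic memory cost comes from.
    """
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d**0.5  # (batch, heads, N, N)
    probs = F.softmax(scores, dim=-1)          # still (batch, heads, N, N)
    return probs @ v                           # (batch, heads, N, head_dim)

# For a 16k-token sequence, the score matrix alone holds 16384**2 floats
# per head -- roughly 1 GiB in fp32 -- before the softmax is even applied.
```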
Flash Attention
Flash Attention addresses the speed and memory limitations of the traditional transformer-style attention mechanism, and it does so while still computing exact attention rather than an approximation.
It was introduced in a paper titled “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré.
Flash Attention introduces a new principle to the attention mechanism: making it IO-aware. This means accounting for reads and writes between different levels of GPU memory.
It uses a technique called tiling: the inputs are processed in blocks small enough to fit in the GPU’s on-chip SRAM, so the full attention matrix never has to be written to and read back from high-bandwidth memory (HBM). This sharply reduces memory reads/writes and results in a faster, more memory-efficient attention mechanism.
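Below is a minimal, pure-PyTorch sketch of the tiling idea (single head, no masking; the name tiled_attention and the block_size value are my own choices, and the real FlashAttention kernel fuses these steps into CUDA code that keeps each tile in SRAM). It walks over K and V in blocks and maintains a running softmax, so only one small block of scores exists at any time, yet the result matches standard attention:

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    """Illustrative single-head attention computed over K/V tiles.

    q, k, v: (seq_len, head_dim) tensors. Instead of building the full
    seq_len x seq_len score matrix, we walk over K/V in blocks and keep
    a running softmax (running max `m`, running normalizer `l`, and a
    running weighted sum `acc`), so only one block of scores is ever
    held at a time.
    """
    n, d = q.shape
    scale = d ** -0.5
    acc = torch.zeros_like(q)               # running (unnormalized) output
    m = torch.full((n, 1), float("-inf"))   # running row-wise max of the scores
    l = torch.zeros((n, 1))                 # running softmax normalizer

    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]          # (B, d)
        v_blk = v[start:start + block_size]          # (B, d)
        s = (q @ k_blk.T) * scale                    # (n, B) scores for this tile

        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                     # tile's softmax numerators
        correction = torch.exp(m - m_new)            # rescale previous partial sums
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        m = m_new

    return acc / l                                   # final softmax normalization

# Sanity check against the naive version:
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64**0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```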
The IO complexity of Flash Attention is analyzed in the paper, showing that it requires fewer HBM accesses than standard attention and is optimal for a range of SRAM sizes.
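Concretely, with N the sequence length, d the head dimension, and M the SRAM size (d ≤ M ≤ Nd), the paper’s analysis gives the following HBM access counts:

```latex
\underbrace{\Theta\left(Nd + N^{2}\right)}_{\text{standard attention}}
\qquad \text{vs.} \qquad
\underbrace{\Theta\left(N^{2} d^{2} / M\right)}_{\text{FlashAttention}}
```

Since the head dimension d is small (typically 64 or 128) while M corresponds to on the order of a hundred kilobytes of on-chip SRAM, the ratio d²/M is well below 1, so FlashAttention touches HBM several times less often than the standard implementation.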
Flash Attention has also been extended to block-sparse attention, yielding an approximate attention algorithm that the authors report is faster than existing approximate attention methods.
In terms of performance, the paper reports that Flash Attention trains transformers faster than existing baselines, for example about a 15% end-to-end speedup on BERT-large (sequence length 512) over the MLPerf 1.1 training record and roughly a 3× speedup on GPT-2 (sequence length 1K). Because its memory footprint grows linearly rather than quadratically with sequence length, it also enables longer context in transformers, yielding higher-quality models and new capabilities.
For example, it has been used to train the first transformers to achieve better-than-chance performance on the Path-X (sequence length 16K) and Path-256 (sequence length 64K) challenges, which involve very long sequences.
The Flash Attention algorithm has been implemented and is available for use in the Dao-AILab/flash-attention repository on GitHub: https://github.com/Dao-AILab/flash-attention
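If you want to try it, here is a minimal usage sketch assuming the flash_attn Python package installed from that repository (pip install flash-attn) and an NVIDIA GPU; the function name and tensor layout follow the interface documented in the repo’s README, but check the exact API of the version you install:

```python
# Minimal usage sketch, assuming `pip install flash-attn` and a CUDA GPU.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 4096, 16, 64

# flash-attn expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on the GPU.
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Exact (not approximate) attention, computed without ever materializing
# the seqlen x seqlen score matrix; causal=True applies a causal mask.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 4096, 16, 64])
```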
Let us know if you have used or tested it in any of your scenarios.