No More Quadratic Complexity for Transformers: Discover the Power of Flash Attention!

Agent Issue
Jul 25, 2023


The traditional transformer attention mechanism compares every token with every other token, leading to quadratic complexity. What’s the solution? Let’s take a look at what Flash Attention offers!


Introduction

Quadratic complexity arises because the attention mechanism is designed to capture dependencies between every pair of tokens in the sequence, regardless of their positions.
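To see where the quadratic term comes from, here is a minimal sketch (my own illustration, not code from the paper) of standard scaled dot-product attention for a single head. The score matrix it materializes has shape N × N, so memory and compute grow quadratically with sequence length.

```python
# Minimal sketch of standard (materialized) attention for one head.
# The "scores" tensor has shape (N, N) -- this is the quadratic term.
import torch

def standard_attention(q, k, v):
    # q, k, v: (N, d) tensors for a single attention head
    d = q.shape[-1]
    scores = q @ k.T / d ** 0.5              # (N, N) score matrix
    weights = torch.softmax(scores, dim=-1)  # (N, N) attention weights
    return weights @ v                       # (N, d) output

N, d = 4096, 64
q, k, v = (torch.randn(N, d) for _ in range(3))
out = standard_attention(q, k, v)            # materializes a 4096 x 4096 matrix
```

Doubling the sequence length quadruples the size of that score matrix, which is exactly the scaling problem Flash Attention targets.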

You may be thinking that not all tokens need to be compared with all other tokens, especially in the context of natural language where words often have stronger dependencies with nearby words.

To address this quadratic time and memory complexity, various approximate attention methods have been proposed that trade off model quality for lower compute. In practice, however, these methods often fail to deliver wall-clock speedups.

Flash Attention

Flash Attention addresses the computational and memory limitations of the traditional transformer attention mechanism.

It was introduced in a paper titled “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré.

Figure 1: Left: FlashAttention uses tiling to prevent materialization of the large 𝑁 × 𝑁 attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the Q matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large 𝑁 × 𝑁 attention matrix to HBM, resulting in a 7.6× speedup on the attention computation.

Flash Attention introduces a new principle to the attention mechanism: making it IO-aware. This means accounting for reads and writes between different levels of GPU memory.

It uses a technique called tiling to reduce the number of memory reads/writes between the GPU’s high bandwidth memory (HBM) and its on-chip SRAM. This results in a faster and more memory-efficient attention mechanism.
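To make the tiling idea concrete, here is a minimal Python sketch (my own illustration; the real implementation is a fused CUDA kernel). It processes K and V in blocks and keeps running softmax statistics per query row, so the full N × N score matrix is never materialized.

```python
# Simplified sketch of tiled attention with an online softmax, for one head.
# Only a (N, block_size) slice of scores exists at any time.
import torch

def tiled_attention(q, k, v, block_size=256):
    N, d = q.shape
    scale = d ** -0.5
    # Running statistics for the online softmax (one row per query).
    m = torch.full((N, 1), float("-inf"))  # running row-wise max of scores
    l = torch.zeros(N, 1)                  # running softmax denominator
    acc = torch.zeros(N, d)                # running (unnormalized) output
    for start in range(0, N, block_size):
        k_blk = k[start:start + block_size]               # (B, d)
        v_blk = v[start:start + block_size]               # (B, d)
        s = (q @ k_blk.T) * scale                         # (N, B) scores for this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                          # (N, B)
        correction = torch.exp(m - m_new)                 # rescale old stats to the new max
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l
```

On the same inputs this matches the standard implementation up to floating-point error, which is the point: FlashAttention reorganizes the computation for memory efficiency but remains exact, not approximate.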

The IO complexity of Flash Attention is analyzed in the paper, showing that it requires fewer HBM accesses than standard attention and is optimal for a range of SRAM sizes.
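Concretely, for sequence length N, head dimension d, and SRAM size M (with d ≤ M ≤ Nd), the paper shows that standard attention requires Θ(Nd + N²) HBM accesses, while FlashAttention requires only Θ(N²d²/M). Since M is typically much larger than d², FlashAttention touches HBM far less often even though it performs the same exact computation.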

Flash Attention has also been extended to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method.

In terms of performance, Flash Attention has been shown to train transformers faster than existing baselines. It also enables longer context in transformers, yielding higher quality models and new capabilities.

For example, it has been used to train the first transformers to achieve better-than-chance performance on the Path-X challenge (sequence length 16K) and the Path-256 challenge (sequence length 64K), both of which involve very long sequences.

The Flash Attention algorithm has been implemented and is available in the Dao-AILab/flash-attention repository on GitHub: https://github.com/Dao-AILab/flash-attention
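As a rough usage sketch, the package exposes a function commonly called as below. The exact API (function name, tensor layout, dtype and device requirements) can differ between versions, so treat this as an illustration and check the repository README.

```python
# Hypothetical usage sketch of the flash-attention package; the exact API
# may differ by version -- consult the Dao-AILab/flash-attention README.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 4096, 16, 64
# The kernels expect half-precision tensors on a CUDA device,
# laid out as (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, nheads, headdim)
```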

Let us know if you have used or tested it in any of your scenarios.

Written by Agent Issue

Your front-row seat to the future of Agents.
