Best Practice to Install Flash Attention
Step 1: Use a pre-built wheel from this repo.
Choose your pre-built wheel link here based on your PyTorch version, CUDA version, Python version, and system platform (linux x86_64, etc.).
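If you are unsure which versions you have, the standard commands below print everything needed to pick a wheel (a quick check, assuming PyTorch is already installed in the target environment):
python --version
python -c "import torch; print(torch.__version__, torch.version.cuda)"
uname -m   # platform, e.g. x86_64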
Then run the commands below (downloading the wheel with wget and installing the local file is recommended over pointing pip install directly at the wheel URL):
wget <your wheel link obtained from the above step>
pip install <downloaded wheel file>
# for example:
# wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.0/flash_attn-2.6.3+cu124torch2.5-cp312-cp312-linux_x86_64.whl
# pip install ./flash_attn-2.6.3+cu124torch2.5-cp312-cp312-linux_x86_64.whl
Now flash_attn is fully working; you can verify this by running import flash_attn.
However, FA3 is still not available at this point.
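A minimal sanity check (a sketch; the flash_attn_3 module only exists after Step 2 below):
python -c "import flash_attn; print(flash_attn.__version__)"                   # e.g. 2.6.3
python -c "import flash_attn_3" 2>/dev/null || echo "FA3 not installed yet"    # expected before Step 2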
Step 2 [Optional; only needed on Hopper GPUs]: Install FA3
FA3 is optimized for Hopper GPUs (H100, etc.).
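To confirm you are running on a Hopper-class GPU (compute capability 9.0), you can query nvidia-smi (the compute_cap query field requires a reasonably recent driver):
nvidia-smi --query-gpu=name,compute_cap --format=csv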
The easiest way to install it is to follow a Dockerfile that already installs FA3,
for instance the Miles Dockerfile.
The FA3 build commands are:
# 1. Clone and enter the repository
git clone https://github.com/Dao-AILab/flash-attention.git && cd flash-attention/
# 2. Checkout specific commit and sync submodules
git checkout fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89
git submodule update --init
# 3. Build and install the Hopper-specific kernels
cd hopper/
MAX_JOBS=96 python setup.py install
# 4. Manually install the interface to your python site-packages
export python_path=$(python -c "import site; print(site.getsitepackages()[0])")
mkdir -p $python_path/flash_attn_3
cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py
# 5. Cleanup
cd ../..
rm -rf flash-attention/
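After the build, you can verify that both the interface and the compiled kernels import correctly (a quick sketch; the flash_attn_3 package name comes from the directory created in step 4 above):
python -c "from flash_attn_3 import flash_attn_interface; print('FA3 OK:', flash_attn_interface.__file__)"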
Some notes:
- Set CUDA_HOME if your default CUDA version does not match the required one.
# Example
export CUDA_HOME=/usr/local/cuda-12.8
- Set MAX_JOBS according to your memory and CPU capacity, as detailed below.
Setting MAX_JOBS for CUDA builds: when compiling template-heavy CUDA projects like Flash Attention, each nvcc process can consume 20–30 GB of RAM, especially for backward-pass kernels with large head dimensions. To avoid OOM kills, set MAX_JOBS to the minimum of (available memory / 25 GB) and (available CPU cores / the --threads value passed to nvcc, typically 2, configurable via the NVCC_THREADS environment variable). For example, on a machine with 700 GB RAM and 32 CPUs with --threads 2, the memory bound is 700/25 ≈ 28 and the CPU bound is 32/2 = 16; the smaller of the two is 16, so MAX_JOBS=16, or a slightly more conservative 14, is a reasonable choice. In a Slurm environment, you can derive these limits from $SLURM_MEM_PER_NODE and $SLURM_CPUS_PER_TASK to keep your builds both fast and stable.
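A small bash sketch of this rule (the 25 GB-per-job figure and the NVCC_THREADS default of 2 are the assumptions from the note above):
NVCC_THREADS=${NVCC_THREADS:-2}
# Available memory (GB) and CPU cores; inside a Slurm job, prefer the allocation limits:
#   MEM_GB=$(( SLURM_MEM_PER_NODE / 1024 ))   # SLURM_MEM_PER_NODE is in MB
#   CPUS=$SLURM_CPUS_PER_TASK
MEM_GB=$(free -g | awk '/^Mem:/ {print $7}')
CPUS=$(nproc)
MEM_BOUND=$(( MEM_GB / 25 ))
CPU_BOUND=$(( CPUS / NVCC_THREADS ))
export MAX_JOBS=$(( MEM_BOUND < CPU_BOUND ? MEM_BOUND : CPU_BOUND ))
echo "MAX_JOBS=$MAX_JOBS (memory bound: $MEM_BOUND, CPU bound: $CPU_BOUND)"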