FlashAttention安装教程
FlashAttention 是一种高效且内存优化的注意力机制实现,旨在提升大规模深度学习模型的训练和推理效率。
-
高效计算:通过优化 IO 操作,减少内存访问开销,提升计算效率。
-
内存优化:降低内存占用,使得在大规模模型上运行更加可行。
-
精确注意力:保持注意力机制的精确性,不引入近似误差。
-
FlashAttention-2 是 FlashAttention 的升级版本,优化了并行计算策略,充分利用硬件资源。改进了工作负载分配,进一步提升计算效率。
-
FlashAttention-3:FlashAttention-3 是专为 Hopper GPU(如 H100)优化的版本,目前处于 Beta 测试阶段。
常见问题:
安装成功后,实际模型代码运行时报错未安装,核心原因就是cxx11abiFALSE这个参数,表示该包在构建时不启用 C++11 ABI。
必须开启不使用才行。否则报错如下:
ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn.
最佳安装步骤(方法1)
- 安装依赖:
- 基础环境:cuda12.1、nvcc.
- 安装python,示例3.10。
- 安装PyTorch,示例orchtorch2.3.0; torchvision0.18.0
ninja
Python 包
- 获取releases对应的whl包:
- 地址:https://github.com/Dao-AILab/flash-attention/releases
- 按照系统环境选whl
3. 我的环境对应的包是:flash_attn-2.7.2.post1+cu12torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl,解释如下:- flash_attn: 包的名称,表示这个 Wheel 文件是
flash_attn
包的安装文件。 - 2.7.2.post1: 包的版本号,遵循 PEP 440 版本规范。
2.7.2
: 主版本号,表示这是flash_attn
的第 2.7.2 版本。post1
: 表示这是一个“后发布版本”(post-release),通常用于修复发布后的某些问题。
- +cu12torch2.3cxx11abiFALSE: 构建标签,表示该 Wheel 文件是在特定环境下构建的。
cu12
: 表示该包是针对 CUDA 12 构建的。torch2.3
: 表示该包是针对 PyTorch 2.3 构建的。cxx11abiFALSE
: 表示该包在构建时不启用 C++11 ABI(Application Binary Interface)。如果安装包后不识别,就要选为False的版本。
- cp310: Python 版本的标签,表示该包是为 Python 3.10 构建的。
cp310
: 是cpython 3.10
的缩写,表示该包适用于 CPython 解释器的 3.10 版本。
- linux_x86_64: 平台标签,表示该包是为 Linux 操作系统和 x86_64 架构(即 64 位 Intel/AMD 处理器)构建的。
- .whl: 文件扩展名,表示这是一个 Python Wheel 文件。Wheel 是 Python 的一种二进制分发格式,用于快速安装包。
- flash_attn: 包的名称,表示这个 Wheel 文件是
如何安装
可以使用 pip
安装这个 Wheel 文件:
pip install flash_attn-2.7.2.post1+cu12torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl --no-build-isolation
常规安装步骤(方法二)
-
安装依赖:
- CUDA 工具包或 ROCm 工具包
- PyTorch 1.12 及以上版本
packaging
和ninja
Python 包
pip install packaging ninja
-
安装 FlashAttention:
# 后面--no-build-isolation参数是为了pip 会直接在当前环境中构建包,使用当前环境中已安装的依赖项。 # 如果当前环境缺少构建所需的依赖项,构建过程可能会失败。 pip install flash-attn --no-build-isolation
或从源码编译:
# 下载源码后,进行编译 cd flash-attention python setup.py install
-
运行测试:
export PYTHONPATH=$PWD pytest -q -s test_flash_attn.py
-
补充说明:
4.1 上面运行时,建议设置参数MAX_JOBS,限制最大进程数,不然系统容易崩。本人在docker下安装,直接干重启了,所以建议如下方式运行:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
4.2 如果运行时会出现警告且推理速度依旧很慢,需要继续从源码安装rotary和layer_norm,cd到源码的那两个文件夹,执行 python setup.py install进行安装,如果命令报错弃用,可能要用easy_install命令。
接口使用
import flash_attn_interface
flash_attn_interface.flash_attn_func()
硬件支持
NVIDIA CUDA 支持
- 支持 GPU:Ampere、Ada 或 Hopper 架构 GPU(如 A100、RTX 3090、RTX 4090、H100)。
- 数据类型:FP16 和 BF16。
- 头维度:支持所有头维度,最大至 256。
AMD ROCm 支持
- 支持 GPU:MI200 或 MI300 系列 GPU。
- 数据类型:FP16 和 BF16。
- 后端:支持 Composable Kernel (CK) 和 Triton 后端。
性能优化
Triton 后端
Triton 后端的 FlashAttention-2 实现仍在开发中,目前支持以下特性:
- 前向和反向传播:支持因果掩码、变长序列、任意 Q 和 KV 序列长度、任意头大小。
- 多查询和分组查询注意力:目前仅支持前向传播,反向传播支持正在开发中。
性能改进
- 并行编译:使用
ninja
工具进行并行编译,显著减少编译时间。 - 内存管理:通过设置
MAX_JOBS
环境变量,限制并行编译任务数量,避免内存耗尽。
结论
FlashAttention 系列通过优化计算和内存使用,显著提升了注意力机制的效率。无论是研究人员还是工程师,都可以通过本文提供的安装和使用指南,快速上手并应用于实际项目中。随着 FlashAttention-3 的推出,针对 Hopper GPU 的优化将进一步推动大规模深度学习模型的发展。
参考链接
- FlashAttention 源码