这篇工作比较缝合,它引用了 3 个 idea:
- Flash Attention 和 pass 抽象
- TeAAL 和它使用的 Einsum 规范
- Extended Einsum
TeAAL 提出用 Einsum 的形式去描述算子,并指导加速器的设计。但是 Einsum 只能描述仅包含加法和乘法的算子,对于 Flash-Attension 这种包含指数运算(softmax)和迭代运算的算子无法描述,于是作者借用了 Extended Einsum 的形式描述了 Flash Attension 算子并实现了 FuseMax 加速器,同时论证了 pass 抽象的合理性。
而实际上,TeAAL 提出的用 Einsum 指导加速器设计的思路并没有因为使用了 Extended Einsum 而被拓展;FuseMax 所展现的性能优势,大部分来自于 Flash-Attension 算法本身,而不是其硬件实现。
一、Background
1.1 Flash Attention, Pass
Flash-Attention 是一种 Attention 算子的实现,相比于传统的实现,它可以降低内存带宽的需求,并且使用更少的片上内存,更适合当前加速器存在的 memory bound。为了达到这个目的,我们需要:
- 尽可能少的从内存中读取数据 -> 算法设计的 pass 数要少
- 尽可能少使用片上内存 -> tile 后 reduce