解码加速:Flash-Decoding 让长文本推理更快

大型语言模型(LLM)如 ChatGPT 和 Llama 近年来备受关注,但它们的运行成本依然高昂。即使生成单个回复可能只需要 0.01 美元(在 AWS 上使用 8xA100 实例运行几秒钟),但当扩展到数十亿用户时,成本会迅速增加,因为这些用户每天可能与 LLM 进行多次交互。一些用例的成本更高,例如代码自动补全,因为每次输入新字符时都会运行。随着 LLM 应用的增加,即使生成时间略微提高效率,也会产生巨大的影响。

LLM 推理(或“解码”)是一个迭代过程:一次生成一个词元。生成包含 N 个词元的完整句子需要对模型进行 N 次前向传递。幸运的是,可以缓存先前计算的词元:这意味着单个生成步骤不依赖于上下文长度,除了一个操作:注意力机制。该操作的计算量随着上下文长度的增加而迅速增长。

LLM 的一些重要新兴用例利用了长上下文。有了更长的上下文,LLM 可以推理更长的文档,例如对文档进行摘要或回答有关文档的问题,它们可以跟踪更长的对话,甚至在编写代码之前处理整个代码库。例如,大多数 LLM 在 2022 年的上下文长度最多为 2k(GPT-3),但现在我们拥有上下文长度扩展到 32k(Llama-2-32k)甚至 100k(CodeLlama)的开源 LLM。在这种情况下,注意力机制在推理过程中占用了大量时间。

当扩展批次大小维度时,即使上下文相对较短,注意力机制也会成为瓶颈。这是因为要读取的内存量随着批次大小维度而扩展,而它仅取决于模型大小。

我们提出了一种名为 Flash-Decoding 的技术,它可以显著加快推理过程中的注意力机制,对于非常长的序列,可以使生成速度提高 8 倍。主要思想是尽可能快地并行加载键和值,然后分别重新缩放和组合结果以保持正确的注意力输出。

解码的多头注意力机制

在解码过程中,每个新生成的词元都需要关注所有先前的词元,以计算:

softmax(queries @ keys.transpose) @ values

此操作已在训练情况下使用 FlashAttention(最近的 v1 和 v2 版本)进行了优化,其中瓶颈是读取和写入中间结果(例如 Q @ K^T)的内存带宽。然而,这些优化不直接适用于推理情况,因为瓶颈不同。对于训练,FlashAttention 在批次大小和查询长度维度上并行化。在推理过程中,查询长度通常为 1:这意味着如果批次大小小于 GPU 上的流式多处理器数量(A100 为 108),则操作将仅使用 GPU 的一小部分!当使用长上下文时尤其如此,因为它需要更小的批次大小才能适应 GPU 内存。如果批次大小为 1,FlashAttention 将使用不到 GPU 的 1%!

FlashAttention 仅在查询块和批次大小上并行化,无法在解码过程中占用整个 GPU。

注意力机制也可以使用矩阵乘法原语来完成,而无需使用 FlashAttention。在这种情况下,操作会完全占用 GPU,但会启动许多内核来写入和读取中间结果,这并非最佳选择。

解码的更快注意力机制:Flash-Decoding

我们新的方法 Flash-Decoding 基于 FlashAttention,并添加了一个新的并行化维度:键/值序列长度。它结合了上述两种方法的优点。与 FlashAttention 一样,它存储到全局内存的额外数据很少,但即使批次大小很小,只要上下文长度足够长,它也能充分利用 GPU。

Flash-Decoding 也在键和值上并行化,但需要一个小的最终归约步骤。

Flash-Decoding 分三个步骤进行:

  1. 首先,我们将键/值分成更小的块。
  2. 我们使用 FlashAttention 并行计算查询与每个块的注意力。我们还为每行和每个块写入一个额外的标量:注意力值的 log-sum-exp。
  3. 最后,我们使用 log-sum-exp 来缩放每个块的贡献,通过对所有块进行归约来计算实际输出。

所有这些都是可能的,因为注意力/softmax 可以迭代计算。在 Flash-Decoding 中,它在两个级别上使用:在块内(类似于 FlashAttention),以及在块之间进行最终归约。

实际上,步骤 (1) 不涉及任何 GPU 操作,因为键/值块是完整键/值张量的视图。然后我们有两个独立的内核分别执行 (2) 和 (3)。

CodeLlama 34B 的基准测试

为了验证这种方法,我们对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,更一般而言,结果应该可以推广到许多 LLM。我们测量了不同序列长度(从 512 到 64k)下的解码速度(tok/s),并比较了几种计算注意力机制的方法:

  • Pytorch:使用纯 PyTorch 原语(不使用 FlashAttention)运行注意力机制。
  • FlashAttention v2。
  • FasterTransformer:使用 FasterTransformer 注意力内核。
  • Flash-Decoding。
  • 以及一个上限,计算为读取整个模型以及 KV 缓存所需的时间。

Flash-Decoding 在解码速度方面为非常长的序列带来了高达 8 倍的加速,并且比其他方法的扩展性更好。

所有方法在小提示情况下表现相似,但随着序列长度从 512 增加到 64k,扩展性都很差,除了 Flash-Decoding。在这种情况下(批次大小为 1),使用 Flash-Decoding,扩展序列长度对生成速度几乎没有影响。

组件级微基准测试

我们还在 A100 上对不同序列长度和批次大小的缩放多头注意力机制进行了微基准测试,输入为 f16。我们将批次大小设置为 1,并使用 16 个维度为 128 的查询头,用于 2 个键/值头(分组查询注意力),这与在 4 个 GPU 上运行 CodeLLaMa-34b 时使用的维度相匹配。

设置算法运行时间(us)
B=256, seqlen=256PyTorch Eager3058.6
B=256, seqlen=256Flash-Attention v2.0.9390.5
B=256, seqlen=256Flash-Decoding63.4
B=128, seqlen=512PyTorch Eager3151.4
B=128, seqlen=512Flash-Attention v2.0.9366.3
B=128, seqlen=512Flash-Decoding67.7
B=64, seqlen=1024PyTorch Eager3160.4
B=64, seqlen=1024Flash-Attention v2.0.9364.8
B=64, seqlen=1024Flash-Decoding77.7
B=32, seqlen=2048PyTorch Eager3158.3
B=32, seqlen=2048Flash-Attention v2.0.9352
B=32, seqlen=2048Flash-Decoding58.5
B=16, seqlen=4096PyTorch Eager3157
B=16, seqlen=4096Flash-Attention v2.0.9401.7
B=16, seqlen=4096Flash-Decoding57
B=8, seqlen=8192PyTorch Eager3173.1
B=8, seqlen=8192Flash-Attention v2.0.9529.2
B=8, seqlen=8192Flash-Decoding56.4
B=4, seqlen=16384PyTorch Eager3223
B=4, seqlen=16384Flash-Attention v2.0.9582.7
B=4, seqlen=16384Flash-Decoding58.2
B=2, seqlen=32768PyTorch Eager3224.1
B=2, seqlen=32768Flash-Attention v2.0.91156.1
B=2, seqlen=32768Flash-Decoding60.3
B=1, seqlen=65536PyTorch Eager1335.6
B=1, seqlen=65536Flash-Attention v2.0.92300.6
B=1, seqlen=65536Flash-Decoding64.4
B=1, seqlen=131072PyTorch Eager2664
B=1, seqlen=131072Flash-Attention v2.0.94592.2
B=1, seqlen=131072Flash-Decoding106.6

多头注意力机制的微基准测试,运行时间为 us。Flash-Decoding 在序列长度扩展到 64k 时,运行时间几乎保持不变。

先前测量的端到端高达 8 倍的加速是可能的,因为注意力机制本身比 FlashAttention 快 50 倍。在序列长度达到 32k 之前,注意力时间大致保持不变,因为 Flash-Decoding 设法充分利用了 GPU。

如何使用 Flash-Decoding

Flash-decoding 可在以下位置使用:

  • FlashAttention 包,从 2.2 版本开始。
  • xFormers,从 0.0.22 版本开始,通过 xformers.ops.memory_efficient_attention。调度器将根据问题大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。

LLaMa v2 / CodeLLaMa 的完整解码示例可在 FlashAttention 存储库 这里 和 xFormers 存储库 这里 找到。我们还提供了一个 LLaMa v1/v2 模型的最小高效解码代码示例,旨在快速、易于阅读、具有教育意义和可修改性。

致谢

感谢 Erich Elsen、Ashish Vaswani 和 Michaël Benesty 建议了将 KV 缓存加载拆分的这个想法。我们要感谢 Jeremy Reizenstein、Patrick Labatut 和 Andrew Tulloch 的宝贵讨论,以及 Quentin Carbonneaux 为 xFormers 提供高效的解码示例。我们还要感谢 Geeta Chauhan 和 Gregory Chanan 在写作方面提供的帮助,以及更广泛地为将本文发表在 PyTorch 博客上做出的贡献。

参考文献:

发表评论