突破极限:高效扩展 Transformer 模型推理

大型语言模型(LLM)在自然语言处理领域取得了巨大成功,其参数规模不断攀升,突破了千亿甚至万亿级别。然而,这些模型的推理效率却面临着严峻挑战,尤其是当需要处理长序列文本并满足严格的延迟要求时。本文将深入探讨 Transformer 模型推理的效率问题,并介绍一系列工程优化策略,旨在突破模型规模和推理效率之间的瓶颈。

推理成本的权衡

随着模型规模的增长,推理成本也随之增加。我们主要关注三个关键指标:延迟、吞吐量和模型 FLOPS 利用率 (MFU)。延迟是指完成一次推理所需的时间,可以细分为处理输入文本的时间(称为“预填充”)和生成输出文本的时间(称为“解码”)。解码延迟也可以按“每步”计算,即除以每个序列中的令牌数。吞吐量是指每秒处理或生成的令牌数。MFU 则是实际吞吐量与理论峰值吞吐量的比率,反映了硬件资源的利用效率。

大型模型通常无法完全容纳在一块加速器芯片的内存中,需要进行模型划分,将模型参数和激活张量分布在多个芯片上。这种划分虽然可以降低每个芯片的内存和计算压力,但也引入了芯片间通信的开销。

内存成本:模型参数和 KV 缓存(每个层中的注意力键和值张量)需要存储在芯片上的高带宽内存 (HBM) 中。每次前向传播(预填充或解码步骤)都需要将这些张量从 HBM 加载到计算核心,这会消耗一定的时间,称为“内存时间”。在小批量和短序列情况下,加载权重的时间占主导地位。而在大批量和长序列情况下,加载 KV 缓存的时间则会占主导地位。

计算成本:一个拥有 N 个参数的解码器模型,每个令牌需要进行 2N 次矩阵乘法运算。如果所有芯片都以峰值 FLOPS 运行,这些矩阵乘法需要一定的时间,称为“计算时间”。注意力机制的矩阵乘法通常占用的 FLOPS 较少,但在长序列情况下,KV 缓存的内存占用和带宽需求会显著增加。

优化策略:模型划分

为了高效地进行推理,我们需要对大型模型进行合理的划分。本文将介绍几种模型划分策略,并分析其在不同模型规模、序列长度和应用需求下的性能表现。

1. 前馈层划分

1D 权重固定布局:这是最简单的划分策略,将每个 E × F 权重矩阵沿 E 或 F 轴进行划分(或分片),每个权重分片在相应的芯片上与激活分片进行乘法运算,并将结果通过全聚合和/或降维散射操作进行聚合。这种策略在芯片数量较少时,内存延迟和计算延迟会随着芯片数量的增加而线性下降。然而,通信延迟基本保持不变,因为每次矩阵乘法都需要将整个激活矩阵进行聚合。当芯片数量增加时,通信成为瓶颈。

2D 权重固定布局:当芯片数量较多时,可以将每个 E × F 权重矩阵沿 E 和 F 轴进行划分,使得每个分片近似为正方形。这种策略称为 2D 权重固定布局。虽然计算成本与 1D 权重固定布局相同,但通信效率更高。通过交替地沿 E 和 F 轴进行激活聚合,可以确保每个芯片始终拥有所需的激活分片,而无需完全复制激活张量。通信时间随着芯片数量的增加而减小,因此即使在通信成为瓶颈的情况下,也可以通过增加芯片数量来降低延迟。

权重聚合布局:在权重固定布局中,每个芯片存储一个权重矩阵的分片,并负责将其与相应的激活分片进行乘法运算。每个芯片的矩阵乘法结果需要进行聚合,才能作为后续操作的输入。然而,当批量大小(和序列长度)增加时,输出激活的大小可能远大于权重的大小。在这种情况下,将激活固定在每个芯片上,并将权重在芯片之间进行传输会更经济。对于非常大的批量大小,最好将激活完全固定在连续的矩阵乘法之间,这需要将权重完全传输到所有芯片之间。我们称这种方法为 XYZ 权重聚合。对于中等批量大小,使用“混合”方法是有益的,即权重和激活都沿不同的轴进行部分传输。我们将这些方法称为 X 权重聚合和 XY 权重聚合。

2. 注意力层划分

多头注意力:多头注意力可以与前馈层类似地进行划分,将 nheads 视为 dff。然而,多头注意力在存储和加载 KV 缓存方面会产生大量的内存容量和带宽成本,这在大批量或长序列情况下可能会成为主要的性能瓶颈。

多查询注意力:多查询注意力是一种替代方案,它仍然为查询张量输出 nheads,但仅为键和值张量输出一个头,该头在所有 nheads 查询头之间共享。这将 KV 缓存张量的大小减少了 nheads 倍,从而减少了加载它们所需的内存时间。但它也去掉了原本用于并行化的一个轴,因此 KV 缓存和相关的计算需要进行不同的划分。

优化策略:为了最大程度地减少加载 KV 缓存所需的内存时间,我们将 Q、K 和 V 矩阵沿批次 B 轴进行划分。这种策略可以将每个芯片加载 KV 缓存的内存成本降低 nheads 倍,从而减少内存时间。虽然这种策略会增加额外的通信成本,但与沿头轴划分相比,它可以显著减少 KV 缓存的内存占用,从而提升推理效率。

3. 并行注意力/前馈层

PaLM 模型采用了一种并行化的 Transformer 块结构,将注意力层和前馈层并行计算,并将其结果相加得到输出。这种结构有两个主要优势:

  • 减少延迟:每个层只有一个层归一化操作,而不是两个,这在小批量情况下可以降低延迟。
  • 提高 FLOPS 利用率:可以将前馈层的输入矩阵与注意力层的查询投影矩阵 WQ 进行融合,将键/值投影矩阵 WK 和 WV 融合到注意力层中,并将前馈层的输出矩阵与注意力层的输出投影矩阵 WO 进行融合。这种融合可以提高 FLOPS 利用率,因为它可以更有效地运行加速器上的大型矩阵乘法。更重要的是,它还消除了每个 Transformer 层中用于 dff/nheads 并行化所需的两个全聚合操作之一,将沿该轴的通信时间减少了一半。

低级优化

除了模型划分策略,我们还采用了一系列低级优化技术,进一步提升推理效率:

  • 循环集体 einsum:将通信与计算并行执行,可以部分或完全隐藏大部分降维散射和全聚合操作的通信时间。
  • 集体 einsum 优化:通过开发集体 einsum 的多种变体,可以针对不同的场景进行优化,例如延迟与吞吐量、不同的环形轴数量以及与不同输入/输出集体进行融合。
  • 内存布局优化:优化张量在内存中的布局,可以减少矩阵乘法过程中的填充和复制操作。
  • 快速 top-k/top-p 实现:加速解码采样过程。
  • 快速 log-base-2 实现:加速 Softmax 和 Swish 函数的计算。
  • 增量序列处理:支持在预填充阶段对序列进行增量处理。

量化

为了降低权重存储的内存成本,我们使用 AQT 库将 16 位权重转换为 8 位整数。这可以节省权重加载所需的内存时间,尤其是在小批量情况下,并减少权重聚合布局中的通信量。

案例研究:PaLM 模型

为了验证上述优化策略的有效性,我们对 PaLM 模型家族进行了实验,包括 8B、62B 和 540B 参数模型,并使用 bfloat16 或 int8 格式的权重。

1. 前馈层划分

实验结果表明,2D 权重固定布局在解码阶段的性能优于 1D 权重固定布局,因为其在芯片数量增加时的扩展性更好。在预填充阶段,随着批量大小的增加,最优的划分布局从 2D 权重固定布局转变为权重聚合布局。权重聚合布局在小批量情况下效率较低,但在高批量情况下效率最高,可以实现高达 76% 的 MFU。

2. 注意力层划分

实验结果表明,将多查询注意力沿批次 B 轴进行划分可以显著提高推理效率,与沿头轴划分相比,它可以支持 32 倍以上的上下文长度。

3. 并行注意力/前馈层

实验结果表明,并行化 Transformer 块结构可以有效地降低延迟,尤其是在解码阶段。

4. 端到端结果

我们通过调整批量大小、芯片数量和划分策略,获得了 PaLM 模型家族在不同模型规模、序列长度和延迟要求下的 Pareto 前沿。结果表明,在高批量情况下,推理成本与模型参数数量成正比。通过降低批量大小,可以提高延迟,但会增加每个令牌的成本。

实验结果还表明,int8 权重量化可以有效地降低延迟。在低延迟目标情况下,int8 权重量化可以将成本降低一半以上。在高批量情况下,int8 和 bfloat16 权重的成本差异较小。

与 FasterTransformer 的比较

我们还将我们的实现与 FasterTransformer 基准进行了比较,结果表明,我们的实现可以实现更高的 MFU 和更低的延迟,尤其是在大批量情况下。这主要归功于我们采用的 2D 权重固定布局和 TPU v4 的高带宽互连网络。

总结

本文提出了一系列工程优化策略,可以有效地提高 Transformer 模型推理的效率,尤其是当需要处理长序列文本并满足严格的延迟要求时。这些策略可以应用于各种硬件平台,包括 GPU 和 TPU。

未来,我们希望通过进一步探索稀疏性技术、自适应计算技术等方法,进一步降低 Transformer 模型的 FLOPS 数量和通信量,从而实现更高的推理效率。

参考文献

2211.05102v1.pdf (https://arxiv.org/pdf/2211.05102)

发表评论