Mamba:线性时间序列建模,超越Transformer的效率与性能

近年来,深度学习领域涌现出许多令人惊叹的应用,而这些应用背后的核心力量往往是基于Transformer架构及其核心注意力机制的“基础模型”。为了解决Transformer在处理长序列时计算效率低下的问题,许多次二次时间复杂度的架构被开发出来,例如线性注意力、门控卷积、循环模型和结构化状态空间模型(SSMs)。然而,这些模型在语言等重要模态上的表现却不如注意力机制出色。


友情链接:ACEJoy


 

本文将深入探讨这些模型的局限性,并提出一种全新的选择性状态空间模型,它在多个方面超越了先前工作,在保持线性时间复杂度的同时,实现了与Transformer相当的建模能力。

Transformer的局限性:效率与效果的权衡

Transformer凭借其强大的自注意力机制,能够在上下文窗口内密集地传递信息,从而对复杂数据进行建模。然而,这种机制也带来了两个关键问题:

  1. 有限的上下文窗口: Transformer无法处理超过窗口大小的任何信息。
  2. 二次复杂度: 随着窗口大小的增加,Transformer的计算量呈平方增长。

为了克服这些问题,研究人员一直在探索更有效的注意力机制变体,但往往以牺牲注意力机制的有效性为代价。至今,还没有一种变体能够在多个领域中展现出大规模的有效性。

结构化状态空间模型(SSMs):线性时间复杂度的希望

近年来,结构化状态空间序列模型(SSMs)作为一种很有潜力的序列建模架构,逐渐崭露头角。这些模型可以被看作是循环神经网络(RNNs)和卷积神经网络(CNNs)的结合,并受到了经典状态空间模型(Kalman 1960)的启发。SSMs可以高效地以循环或卷积的形式进行计算,其时间复杂度和空间复杂度都与序列长度呈线性或近线性关系。此外,它们还拥有建模某些数据模态中长程依赖关系的机制,并在Long Range Arena(Tay, Dehghani, Abnar, et al. 2021)等基准测试中取得了领先地位。

然而,SSMs在建模离散和信息密集型数据(如文本)方面表现不佳。

选择性状态空间模型:突破SSMs的局限

本文提出的选择性状态空间模型,通过以下几个方面的改进,克服了先前工作的局限性,实现了与Transformer相当的建模能力,并保持了线性时间复杂度:

选择机制:基于内容的推理

先前模型的一个关键局限性在于它们无法以输入依赖的方式高效地选择数据(例如,关注或忽略特定输入)。受选择性复制归纳头等重要合成任务的启发,本文设计了一种简单的选择机制,通过将SSM参数设置为输入的函数,从而使模型能够根据当前标记有选择地传播或遗忘信息。

硬件感知算法:高效的并行计算

这种简单的改变为模型的计算带来了技术挑战,因为先前所有的SSM模型都必须是时间和输入不变的,才能保证计算效率。本文通过一种硬件感知算法克服了这一挑战,该算法以扫描而不是卷积的方式递归地计算模型,但不会将扩展后的状态具体化,从而避免了GPU内存层次结构不同级别之间的IO访问。这种实现方法在理论上比以前的方法更快(时间复杂度为线性,而所有基于卷积的SSMs的时间复杂度为伪线性),并且在现代硬件上也更快(在A100 GPU上快3倍)。

Mamba架构:简洁而强大的模型设计

本文将先前SSM架构(Dao, Fu, Saab, et al. 2023)的设计与Transformer的MLP块结合,形成一个简单的同质架构设计(Mamba),该架构包含选择性状态空间。

Mamba作为一种通用的序列模型主干,具有以下特点:

  • 高性能: 选择机制在语言和基因组等密集模态上带来了强大的性能。
  • 快速训练和推理: 训练过程中的计算量和内存使用量与序列长度呈线性关系,而推理过程中的自回归展开仅需要每个步骤恒定的时间,因为不需要缓存先前元素。
  • 长上下文: 性能和效率的结合,使模型能够在长达百万个标记的真实数据上取得性能提升。

Mamba的实验验证:超越Transformer的性能

本文通过一系列实验验证了Mamba作为通用序列基础模型主干的潜力,包括在预训练质量和特定领域任务性能方面的评估。

合成任务:选择性复制和归纳头

选择性复制任务中,Mamba展现出强大的内容感知推理能力,能够有效地记住相关标记,并忽略无关标记。在归纳头任务中,Mamba能够完美地解决该任务,并将其解决方案扩展到百万个标记的序列长度,而其他方法只能扩展到训练序列长度的两倍。

语言建模:与Transformer相当的性能

在语言建模方面,Mamba是第一个真正实现Transformer级性能的线性时间序列模型,无论是在预训练困惑度还是在零样本评估方面。在高达10亿个参数的规模下,Mamba的性能超越了各种基线,包括基于LLaMa(Touvron et al. 2023)的现代Transformer训练方案。Mamba语言模型的生成吞吐量是同等规模Transformer的5倍,Mamba-3B的质量与Transformer的两倍规模(例如,Pythia-7B)相当,甚至在常识推理方面超过了Pythia-7B。

DNA建模:高效的基因组分析

在基因组分析方面,Mamba在DNA序列预训练和微调方面都超越了先前最先进的模型,例如SaShiMi、Hyen和Transformer。在预训练质量和下游指标(例如,将具有挑战性的语音生成数据集上的FID降低一半以上)方面,Mamba都取得了优异的成绩。在两种情况下,Mamba的性能都随着上下文长度的增加而提升,直至百万个标记的序列长度。

音频建模:高质量的语音生成

在音频建模方面,Mamba在自回归音频建模方面超越了先前最先进的模型(SaShiMi),并能够处理长达分钟的上下文或百万个标记的序列(控制计算量)。在SC09语音生成数据集上,Mamba模型在保真度指标方面显著优于现有模型,包括WaveNet、SampleRNN、WaveGAN、DiffWave和SaShiMi。

Mamba的效率:超越Transformer的计算速度

Mamba在训练和推理方面都展现出优异的效率。在训练方面,Mamba的效率扫描比PyTorch中的标准扫描实现快20-40倍。在推理方面,Mamba的推理吞吐量是同等规模Transformer的4-5倍,因为Mamba不需要KV缓存,因此可以采用更高的批次大小。

Mamba的未来展望:更强大的基础模型

本文提出的选择性状态空间模型为构建不同领域的基础模型开辟了新的可能性,尤其是在需要长上下文的领域,例如基因组学、音频和视频。Mamba有望成为一种通用的序列模型主干,为未来的深度学习研究带来新的突破。

参考文献

  • Brown, Tom, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, et al. “Language Models Are Few-Shot Learners.” arXiv preprint arXiv:2005.14165, 2020.
  • Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. “Attention Is All You Need.” In Advances in Neural Information Processing Systems, 5998–6008, 2017.
  • Gu, Albert, Karan Goel, and Christopher Ré. “Efficiently Modeling Long Sequences with Structured State Spaces.” In International Conference on Learning Representations, 2022.
  • Dao, Tri, Daniel Y Fu, Andrew M Saab, et al. “H3: A Unified Architecture for Long Sequence Modeling.” In International Conference on Learning Representations, 2023.
  • Poli, Luca, Thomas Scialom, Guillaume Lample, et al. “Hyena: A Hybrid Convolutional Architecture for Long Sequence Modeling.” In International Conference on Learning Representations, 2023.

Q: 这篇论文试图解决什么问题?

A: 这篇论文试图解决的问题是深度学习中基础模型(Foundation Models, FMs)在处理长序列数据时的计算效率问题。基础模型通常基于Transformer架构,尽管它们在多种领域(如语言、图像、语音、音频、时间序列和基因组学)的应用中表现出色,但它们的计算复杂度随着序列长度的增加而呈二次方增长,这限制了它们在处理长序列数据时的应用。为了克服这一挑战,论文提出了一种新的序列建模方法,称为选择性状态空间模型(Selective State Space Models, SSMs),并将其集成到一个简化的神经网络架构中,称为Mamba。Mamba旨在实现线性时间复杂度的序列建模,同时保持或超越Transformer模型在多种模态(如语言、音频和基因组学)上的性能。

Q: 有哪些相关研究?

A: 这篇论文提到了多个与选择性状态空间模型(Selective State Space Models, SSMs)和Mamba架构相关的研究。以下是一些关键的相关研究:

  1. Transformer架构:这是由Vaswani等人在2017年提出的,它的核心是自注意力机制,能够处理长距离依赖关系,但计算复杂度随序列长度的增加而增加。
  2. 结构化状态空间序列模型(Structured State Space Models, SSMs):这些模型是RNN和CNN的结合,能够高效地处理序列数据,具有线性或近线性的序列长度缩放特性。Gu, Goel, and Ré (2022) 和 Gu, Johnson, Goel, et al. (2021) 在这方面的工作为Mamba架构提供了理论基础。
  3. 线性注意力(Linear Attention):Katharopoulos等人在2020年提出了一种近似自注意力的方法,通过递归来实现,这为后续的SSM研究提供了灵感。
  4. H3架构:Dao, Fu, Saab等人在2023年提出了这种架构,它结合了线性注意力和SSM,通过门控连接来增强模型的性能。
  5. Hyena模型:Poli等人在2023年提出了这种模型,它使用全局卷积来近似SSM,并且在某些连续信号数据领域(如音频和视觉)取得了成功。
  6. RetNet:Sun等人在2023年提出了这种模型,它通过引入额外的门控来优化SSM的并行计算路径。
  7. RWKV(Reinventing RNNs for the Transformer Era):Peng等人在2023年提出了这种模型,它基于线性注意力近似,旨在为语言建模设计高效的RNN。
  8. S5模型:Smith, Warrington, and Linderman在2023年提出了这种模型,它通过简化SSM架构来提高计算效率。

这些研究为Mamba架构的发展提供了理论基础和实践经验,特别是在处理长序列数据和提高计算效率方面。

Q: 论文如何解决这个问题?

A: 论文通过以下几个关键创新来解决长序列数据处理中的计算效率问题:

  1. 选择性状态空间模型(Selective State Space Models, SSMs):论文提出了一种新的SSM,它通过允许模型参数(如状态转移矩阵A、输入门B和输出门C)依赖于输入数据,从而实现对信息的选择性传播或遗忘。这种机制使得模型能够根据当前输入动态地调整其状态,有效地过滤掉不相关的信息。
  2. 硬件感知的并行算法:尽管选择性SSMs的设计破坏了模型的线性时间不变性(LTI),但论文设计了一种硬件感知的并行算法,通过在GPU内存层次结构中更高效地扩展状态来计算模型。这种算法通过在快速SRAM中执行离散化和递归操作,然后仅将最终输出写回HBM,从而减少了内存访问次数,提高了计算效率。
  3. 简化的神经网络架构(Mamba):论文提出了一个简化的端到端神经网络架构,该架构不包含注意力或多层感知器(MLP)块,而是将SSMs直接集成到网络中。这种架构通过合并H3架构中的线性注意力和MLP块,形成了一个简单且同质的架构设计。
  4. 跨模态性能验证:论文在多种模态(如语言、音频和基因组学)上验证了Mamba架构的有效性。在语言建模任务中,Mamba-3B模型在预训练和下游评估中的表现超过了相同大小的Transformer模型,并与两倍大小的Transformer模型相匹配。

通过这些创新,论文展示了Mamba架构在处理长序列数据时的高效性和优越性能,特别是在保持或提高模型性能的同时,实现了线性时间复杂度的序列建模。

Q: 论文做了哪些实验?

A: 论文进行了一系列的实验来验证Mamba架构的有效性,这些实验涵盖了不同的数据模态和任务。以下是主要的实验内容:

  1. 合成任务
    • 选择性复制(Selective Copying):测试模型在处理随机间隔的输入序列时的记忆能力,这要求模型能够记住相关输入并忽略不相关的“噪声”输入。
    • 归纳头(Induction Heads):评估模型在关联记忆任务中的表现,这要求模型能够在给定上下文的情况下正确地预测输出。
  2. 语言建模
    • 预训练质量(Scaling Laws):在Pile数据集上,比较Mamba与其他架构(如Transformer、Transformer++、H3等)在不同模型大小和序列长度下的困惑度(Perplexity)。
    • 零样本下游评估(Zero-shot Evaluations):在多个下游任务上评估Mamba模型的性能,包括LAMBADA、HellaSwag、PIQA、ARC-challenge等。
  3. DNA序列建模
    • 模型规模(Scaling: Model Size):在HG38(人类基因组)数据集上,研究不同模型背骨在不同规模下的性能。
    • 上下文长度(Scaling: Context Length):固定模型规模,增加序列长度,观察模型性能随序列长度的增长情况。
    • 合成物种分类(Synthetic Species Classification):在预训练模型的基础上,进行物种DNA序列的分类任务,以评估模型在长序列上的泛化能力。
  4. 音频建模和生成
    • 长上下文自回归预训练(Long-Context Autoregressive Pretraining):在YouTubeMix数据集上,评估Mamba在不同序列长度下的预训练质量。
    • 自回归语音生成(Autoregressive Speech Generation):在SC09数据集上,比较Mamba模型与现有最先进模型在自动语音生成任务上的性能。
  5. 速度和内存基准测试(Speed and Memory Benchmarks)
    • SSM扫描操作(Scan Operation):在A100 GPU上,比较Mamba的SSM扫描操作与标准卷积和注意力操作的速度。
    • 端到端推理吞吐量(End-to-end Inference Throughput):测量Mamba模型在不同批处理大小下的推理速度。
    • 内存消耗(Memory Consumption):报告Mamba模型在训练过程中的内存使用情况,并与优化的Transformer模型进行比较。

这些实验旨在全面评估Mamba架构在不同场景下的性能,包括其在处理长序列数据时的效率和效果。

Q: 有什么可以进一步探索的点?

A: 尽管论文提出了Mamba架构并在多个领域展示了其有效性,但仍有一些潜在的研究方向可以进一步探索:

  1. 模型规模的扩展:论文中的实验主要集中在较小的模型规模上。研究Mamba在更大模型规模(例如数十亿参数)上的性能,以及它是否能够保持或提高其在长序列任务上的优势。
  2. 多模态学习:探索Mamba在处理多模态数据(如结合文本、图像和音频)时的性能,以及如何有效地整合不同模态的信息。
  3. 跨领域迁移学习:研究Mamba在不同领域(如医疗、金融、法律等)的迁移学习能力,以及如何调整模型以适应特定领域的知识。
  4. 模型解释性:提高Mamba模型的可解释性,以便更好地理解其在处理复杂任务时的决策过程。
  5. 计算效率的进一步优化:尽管Mamba在计算效率上有所改进,但仍然可以探索新的算法和硬件优化技术,以进一步提高其在大规模部署中的实用性。
  6. 长期记忆和上下文建模:研究Mamba在处理需要长期记忆和复杂上下文理解的任务(如故事生成、对话系统等)中的性能。
  7. 模型鲁棒性:评估Mamba在面对噪声、异常值和对抗性攻击时的鲁棒性,并探索提高模型鲁棒性的策略。
  8. 环境交互和强化学习:将Mamba应用于需要与环境交互的任务,如强化学习中的决策制定,以及探索其在这些任务中的性能和适用性。
  9. 模型压缩和部署:研究如何有效地压缩Mamba模型以便于在资源受限的设备上部署,同时保持其性能。
  10. 社会影响和伦理考量:随着模型规模的扩大,研究Mamba在实际应用中可能产生的社会影响,以及如何确保其符合伦理和公平性标准。

这些研究方向不仅有助于推动Mamba架构的发展,还能够为深度学习和人工智能领域的其他研究提供新的视角和方法。

Q: 总结一下论文的主要内容

A: 这篇论文介绍了一种名为Mamba的新型线性时间序列建模架构,旨在提高深度学习中基础模型(Foundation Models, FMs)在处理长序列数据时的计算效率。Mamba的核心是选择性状态空间模型(Selective State Space Models, SSMs),这些模型通过引入选择机制来改进传统的结构化状态空间模型(Structured State Space Models, SSMs),使其能够根据输入内容动态地选择性地处理信息。这种选择性允许模型在保持线性时间复杂度的同时,有效地处理长序列数据。

论文的主要贡献包括:

  1. 选择性SSMs:提出了一种新的SSM,通过让模型参数依赖于输入,实现了对信息的选择性传播和遗忘,从而提高了模型在处理离散和信息密集型数据(如文本)时的性能。
  2. 硬件感知的并行算法:设计了一种新的并行计算方法,通过在GPU内存层次结构中更高效地处理状态扩展,解决了选择性SSMs的计算效率问题。
  3. 简化的神经网络架构:提出了Mamba架构,它将SSMs与Transformer的MLP块结合,形成了一个简单且同质的网络设计,无需额外的注意力或MLP块。
  4. 跨模态性能验证:在多种数据模态(如语言、音频和基因组学)上验证了Mamba架构的有效性,展示了其在预训练质量和下游任务性能上的优势。
  5. 效率基准测试:展示了Mamba在训练和推理速度上的优势,特别是在处理长序列数据时,其性能优于现有的Transformer模型。

论文的实验结果表明,Mamba在处理长序列数据时不仅提高了计算效率,而且在多个领域内实现了与Transformer模型相媲美或更优的性能。这表明Mamba有潜力成为通用序列模型的有力候选者,特别是在需要处理长序列数据的新兴领域,如基因组学、音频和视频。

发表评论