原文title: Mamba与状态空间模型的可视化指南 url:
https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
Transformer架构在大型语言模型(LLMs)的成功中发挥了重要作用,几乎所有正在使用的LLM都采用了这一架构。为了进一步提升LLM的性能,研究人员开发了新的架构,甚至可能超越Transformer。其中一种方法是Mamba,一种状态空间模型,确实令人兴奋!
Mamba在论文Mamba: Linear-Time Sequence Modeling with Selective State Spaces中被提出,代码库:https://github.com/state-spaces/mamba
下面将有超过50个定制视觉图,帮助你直观理解Mamba和状态空间模型!
第一部分:Transformer的问题
为了说明为什么Mamba是一个如此有趣的架构,让我们先简要回顾一下Transformer,并探讨它的一个缺点。
Transformer将任何文本输入视为由token组成的序列。
Transformer的一个主要优势是,无论它接收到什么输入,它都可以回顾序列中的任何先前token来推导其表示。
Transformer的核心组件
Transformer由两个结构组成,一组用于表示文本的编码器块和一组用于生成文本的解码器块。这些结构可以用于多种任务,包括翻译。
我们可以通过仅使用解码器来创建生成模型。这种基于Transformer的模型,生成预训练Transformer(GPT),使用解码器块来完成某些输入文本。
让我们看看它是如何工作的!
训练优势
单个解码器块由两个主要组件组成,掩码自注意力机制后接一个前馈神经网络。
自注意力机制是这些模型表现如此出色的主要原因。它允许模型快速训练,同时对整个序列进行无压缩的观察。
那么它是如何工作的呢?
它创建了一个矩阵,将每个token与之前的所有token进行比较。矩阵中的权重由token对之间的相关性决定。
在训练期间,这个矩阵是一次性创建的。我们不需要先计算“My”和“name”之间的注意力,然后再计算“name”和“is”之间的注意力。
这使得并行化成为可能,从而极大地加快了训练速度!
推理缺陷!
然而,这里存在一个缺陷。当我们生成下一个token时,即使我们已经生成了一些token,我们仍然需要重新计算整个序列的注意力。
生成长度为L的序列的token大约需要L²次计算,如果序列长度增加,这可能会非常昂贵。
这种需要重新计算整个序列的情况是Transformer架构的一个主要瓶颈。
让我们看看“经典”技术——循环神经网络(RNN)如何解决推理速度慢的问题。
RNN是解决方案吗?
循环神经网络(RNN)是一种基于序列的网络。它在序列中的每个时间步接收两个输入,即时间步t的输入和前一个时间步t-1的隐藏状态,以生成下一个隐藏状态并预测输出。
RNN通过其循环机制,可以将前一步的信息传递到下一步。我们可以“展开”这个可视化,使其更加明确。
在生成输出时,RNN只需要考虑前一个隐藏状态和当前输入。它避免了重新计算所有先前的隐藏状态,而这正是Transformer会做的。
换句话说,RNN可以快速进行推理,因为它的计算量与序列长度呈线性关系!理论上,它甚至可以有无限上下文长度。
为了说明这一点,让我们将RNN应用于我们之前使用的输入文本。
[
每个隐藏状态是所有先前隐藏状态的聚合,通常是一个压缩的视图。
然而,这里有一个问题……
注意到在生成名字“Maarten”时,最后一个隐藏状态不再包含关于单词“Hello”的信息。RNN往往会随着时间的推移忘记信息,因为它们只考虑一个前一个状态。
尽管RNN在训练和推理中可能很快,但它们缺乏Transformer模型提供的准确性。
相反,我们来看状态空间模型(SSM),它能够有效地使用RNN(有时还使用卷积)。
第二部分:状态空间模型(SSM)
状态空间模型(SSM)与Transformer和RNN一样,处理信息序列,如文本或信号。在本节中,我们将介绍SSM的基础知识及其与文本数据的关系。
什么是状态空间?
状态空间包含完全描述一个系统所需的最少数量的变量。它是一种通过定义系统可能的状态来数学表示问题的方式。
让我们简化一下。想象我们正在迷宫中导航。“状态空间”是所有可能位置(状态)的地图。每个点代表迷宫中的一个独特位置,并包含特定细节,比如你距离出口有多远。
“状态空间表示”是这个地图的简化描述。它显示了你当前的位置(当前状态)、你可以去的下一个位置(可能的未来状态)以及引导你到下一个状态的变化(向右或向左移动)。
尽管状态空间模型使用方程和矩阵来跟踪这种行为,但它只是跟踪你在哪里、你可以去哪里以及你如何到达那里的一种方式。
描述状态的变量,在我们的例子中是X和Y坐标,以及到出口的距离,可以被表示为“状态向量”。
听起来熟悉吗?这是因为语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,你当前位置的向量(状态向量)可能看起来像这样:
在神经网络的背景下,系统的“状态”通常是其隐藏状态,而在大型语言模型的背景下,生成新token时,这是最重要的方面之一。
什么是状态空间模型?
SSM是用于描述这些状态表示并根据某些输入预测其下一个状态的模型。
传统上,在时间t,SSM:
-
将输入序列x(t)映射到潜在状态表示h(t)。
-
并推导出预测的输出序列y(t)。
然而,SSM不是使用离散序列(如移动一次),而是将连续序列作为输入并预测输出序列。
SSM假设动态系统(如在3D空间中移动的物体)可以通过其状态在时间t通过两个方程进行预测。
通过求解这些方程,我们假设可以揭示基于观察数据(输入序列和先前状态)预测系统状态的统计原理。
其目标是找到这个状态表示h(t),以便我们可以从输入到输出序列。
这两个方程是状态空间模型的核心。
这两个方程将在本指南中多次引用。为了使其更直观,它们被颜色编码,以便你可以快速引用它们。
状态方程描述了状态如何通过矩阵A变化,以及输入如何通过矩阵B影响状态。
正如我们之前看到的,h(t)指的是我们在任何给定时间t的潜在状态表示,而***x(t)***指的是某个输入。
输出方程描述了状态如何通过矩阵C转换为输出,以及输入如何通过矩阵D影响输出。
注意:矩阵A、B、C和D通常也被称为参数,因为它们是可学习的。
可视化这两个方程,我们得到以下架构:
让我们逐步了解这些矩阵如何影响学习过程。
假设我们有一些输入信号x(t),这个信号首先乘以矩阵B,它描述了输入如何影响系统。
更新后的状态(类似于神经网络的隐藏状态)是一个潜在空间,包含了环境的核心“知识”。我们将状态乘以矩阵A,它描述了所有内部状态如何连接,因为它们代表了系统的基本动态。
正如你可能已经注意到的,矩阵A在创建状态表示之前应用,并在状态表示更新后更新。
然后,我们使用矩阵C来描述状态如何转换为输出。
最后,我们可以使用矩阵D来提供从输入到输出的直接信号。这也通常被称为跳跃连接。
由于矩阵D类似于跳跃连接,SSM通常被视为以下形式,不包含跳跃连接。
我们还可以更详细地了解每个步骤:
回到我们简化的视角,我们现在可以专注于矩阵A、B和C,作为SSM的核心。
可以表示为:
我们可以更新原始方程(并添加一些漂亮的颜色)来标记每个矩阵的作用,就像我们之前所做的那样。
这两个方程共同目标是基于观察数据预测系统的状态。由于输入是连续的,SSM的主要表示是连续时间表示。
从连续信号到离散信号
如果你有一个连续信号,找到状态表示***h(t)***在解析上具有挑战性。此外,由于我们通常有一个离散输入(如文本序列),我们希望将模型离散化。
为此,我们使用零阶保持技术。它的工作原理如下。首先,每当我们接收到一个离散信号时,我们保持其值,直到接收到一个新的离散信号。这个过程创建了一个SSM可以使用的连续信号:
我们保持值的时间由一个可学习的参数表示,称为步长 ∆。它表示输入的分辨率。
现在我们有了一个连续信号用于输入,我们可以生成一个连续输出,并根据输入的时间步长对值进行采样。
这些采样值是我们的离散化输出!
数学上,我们可以应用零阶保持如下:
它们共同允许我们从连续SSM转变为离散SSM,表示为从函数到函数,x(t) → y(t),变为序列到序列,xₖ → yₖ:
在这里,矩阵A和B现在表示模型的离散化参数。
我们使用k而不是t来表示离散时间步长,以便更清楚地区分连续SSM和离散SSM。
注意:在训练期间,我们仍然保存矩阵A的连续形式,而不是离散版本。在训练期间,连续表示被离散化。
现在我们有了离散表示的公式,让我们探索如何实际计算模型。
循环表示
我们的离散化SSM允许我们将问题表示为特定的时间步长,而不是连续信号。正如我们之前看到的,循环方法(如RNN)在这里非常有用。
如果我们考虑离散时间步长而不是连续信号,我们可以将问题重新表示为时间步长:
在每个时间步长,我们计算当前输入(Bxₖ)如何影响前一个状态(Ahₖ₋₁),然后计算预测的输出(Chₖ)。
这种表示可能已经有点熟悉了!我们可以像之前处理RNN一样处理它。
我们可以将其展开(或展开)如下:
注意到我们可以使用这种离散化版本,利用RNN的基本方法。
卷积表示
我们可以使用的另一种表示是卷积。记住在经典的图像识别任务中,我们应用滤波器(核)来推导聚合特征:
由于我们处理的是文本而不是图像,我们需要一维的视角:
使用来自不同领域的技术,形成了一个有趣的管道:
我们用来表示这个“过滤器”的核是从SSM公式中推导出来的:
让我们探索一下这个核在实践中是如何工作的。与卷积一样,我们可以使用我们的SSM核来处理每一组token并计算输出:
这也说明了填充可能对输出的影响。我改变了填充的顺序以改善可视化,但我们通常在句子的末尾应用填充。
在下一步中,核被移动一次以执行计算的下一个步骤:
在最后一步中,我们可以看到核的完整效果:
将SSM表示为卷积的一个主要好处是它可以像卷积神经网络(CNN)一样并行训练。然而,由于核大小固定,它们的推理速度不如RNN快,并且无法处理无界序列。
三种表示
这三种表示——连续、循环和卷积——各有不同的优势和劣势:
有趣的是,我们现在有了高效的推理(通过循环SSM)和可并行化的训练(通过卷积SSM)。
有了这些表示,我们可以使用一个巧妙的技巧,即根据任务选择表示。在训练期间,我们使用可以并行化的卷积表示,而在推理期间,我们使用高效的循环表示:
这种模型被称为[线性状态空间层(LSSL)
这些表示共享一个重要属性,即线性时不变性(LTI)。LTI表示SSM的参数A、B和C在所有时间步长中都是固定的。这意味着矩阵A、B和C在SSM生成的每个token中都是相同的。
换句话说,无论你给SSM什么序列,A、B和C的值都保持不变。我们有一个静态表示,它不是内容感知的。
在我们探索Mamba如何解决这个问题之前,让我们先看看最后一块拼图,矩阵A。
矩阵A的重要性
可以说,SSM公式中最重要的方面之一是矩阵A。正如我们之前看到的循环表示,它捕获了关于前一个状态的信息,以构建新状态。
本质上,矩阵A生成了隐藏状态:
创建矩阵A的方式可能决定了是只记住几个先前的token,还是捕获我们迄今为止看到的每一个token。特别是在循环表示的背景下,因为它只回顾 前一个状态。
那么,我们如何创建矩阵A,以保留较大的记忆(上下文大小)?
我们使用“Hungry Hungry Hippo”(饥渴的河马)!或者HiPPO(高阶多项式投影算子)。
HiPPO试图将迄今为止看到的所有输入信号压缩成一个系数向量。它使用矩阵A来构建一个状态表示,该表示能够很好地捕捉最近的标记,并使较旧的标记逐渐衰减。其公式可以表示为:
假设我们有一个方阵矩阵A,这将给我们:
使用HiPPO构建矩阵A被证明比将其初始化为随机矩阵要好得多。因此,它能够更准确地重构较新的信号(最近的标记),而不是较旧的信号(初始标记)。
HiPPO矩阵背后的思想是它产生一个隐藏状态,能够记住其历史。从数学上讲,它是通过跟踪勒让德多项式的系数来实现的,这使得它能够近似之前的所有历史。
HiPPO随后被应用于我们之前看到的递归和卷积表示,以处理长距离依赖关系。其结果是结构化序列状态空间(S4),一类能够高效处理长序列的状态空间模型。
它由三部分组成:
- 状态空间模型
- HiPPO用于处理长距离依赖
- 离散化用于创建递归和卷积表示
根据你选择的表示(递归与卷积),这类状态空间模型具有多种优势。它还可以通过构建HiPPO矩阵高效地处理长文本序列并存储记忆。
第三部分:Mamba - 选择性状态空间模型
我们终于涵盖了理解Mamba独特之处所需的所有基础知识。状态空间模型可以用于建模文本序列,但仍存在一些我们希望避免的缺点。
在本节中,我们将探讨Mamba的两个主要贡献:
- 选择性扫描算法,允许模型过滤(不)相关信息。
- 硬件感知算法,允许通过并行扫描、内核融合和重新计算高效地存储(中间)结果。
它们共同创建了选择性状态空间模型或S6模型,可以像自注意力一样用于创建Mamba块。
在探讨这两个主要贡献之前,让我们先了解为什么它们是必要的。
它试图解决什么问题?
状态空间模型,甚至是S4(结构化状态空间模型),在语言建模和生成中某些关键任务上表现不佳,即专注于或忽略特定输入的能力。
我们可以通过两个合成任务来说明这一点,即选择性复制和归纳头。
在选择性复制任务中,SSM的目标是复制输入的一部分并按顺序输出:
然而,(递归/卷积)SSM在这一任务上表现不佳,因为它具有线性时不变性。正如我们之前所见,矩阵A、B和C对SSM生成的每个标记都是相同的。
因此,SSM无法进行基于内容的推理,因为它由于固定的A、B和C矩阵而将每个标记视为同等重要。这是一个问题,因为我们希望SSM能够推理输入(提示)。
SSM表现不佳的第二个任务是归纳头,其目标是重现输入中发现的模式:
在上面的例子中,我们实际上是在进行一次性提示,试图“教”模型在每个“Q:”之后提供一个“A:”响应。然而,由于SSM具有时不变性,它无法选择要从其历史中回忆哪些之前的标记。
让我们通过关注矩阵B来说明这一点。无论输入x是什么,矩阵B始终保持完全相同,并且因此独立于x:
同样,A和C也无论输入是什么都保持固定。这展示了我们迄今为止看到的SSM的静态特性。
相比之下,这些任务对于Transformer来说相对容易,因为它们会根据输入序列动态地改变注意力。它们可以选择性地“查看”或“关注”序列的不同部分。
SSM在这些任务上的表现不佳揭示了时不变SSM的潜在问题,A、B和C矩阵的静态特性导致了基于内容的感知问题。
选择性保留信息
SSM的递归表示创建了一个较小的状态,这是相当高效的,因为它压缩了整个历史。然而,与不压缩历史(通过注意力矩阵)的Transformer模型相比,它的能力要弱得多。
Mamba的目标是兼得两者的优点:拥有一个强大的小状态,就像Transformer的状态一样强大。
如上所述,它通过有选择性地将数据压缩到状态中来实现这一点。当你有一个输入句子时,其中往往有一些信息(如停用词)并没有太多意义。
为了有选择性地压缩信息,我们需要让参数依赖于输入。为此,我们首先探索在训练期间SSM中输入和输出的维度:
在结构化状态空间模型(S4)中,矩阵A、B和C独立于输入,因为它们的维度N和D是静态的,不会改变。
相反,Mamba通过引入输入的序列长度和批量大小,使矩阵B和C以及步长 ∆依赖于输入:
这意味着对于每个输入标记,我们现在都有不同的B和C矩阵,从而解决了内容感知问题!
注意:矩阵A保持不变,因为我们希望状态本身保持静态,但通过B和C对它的影响是动态的。
它们共同有选择性地决定什么信息要保留在隐藏状态中,什么信息要忽略,因为它们现在依赖于输入。
较小的步长 ∆会导致忽略特定单词,转而更多地使用之前的上下文,而较大的步长 ∆则更注重输入单词本身,而不是上下文:
扫描操作
由于这些矩阵现在是动态的,因此无法使用卷积表示来计算,因为卷积假设了一个固定的核。我们只能使用递归表示,并失去卷积提供的并行化能力。
为了实现并行化,我们来看看如何使用递归计算输出:
每个状态是前一个状态(乘以A)与当前输入(乘以B)的和。这被称为扫描操作,可以通过循环轻松计算。
然而,由于每个状态的计算依赖于前一个状态,似乎无法实现并行化。但Mamba通过并行扫描算法实现了这一点。
它假设操作的顺序并不重要(通过结合结合律)。因此,我们可以将序列分成若干部分,并逐步将它们组合起来:
动态矩阵B和C,以及并行扫描算法共同构成了选择性扫描算法,以实现递归表示的动态和快速特性。
硬件感知算法
现代GPU的一个缺点是其小而高效的SRAM与大但效率稍低的DRAM之间的有限传输(IO)速度。频繁地在SRAM和DRAM之间复制信息会成为瓶颈。
Mamba(类似于Flash Attention)试图限制从DRAM到SRAM(反之亦然)的传输次数。它通过内核融合实现这一点,允许模型在完成计算之前避免写入中间结果,并持续进行计算,直到完成。
我们可以通过可视化Mamba的基础架构来观察DRAM和SRAM的具体分配情况:
在这里,以下内容被融合到一个内核中:
- 步长 ∆的离散化步骤
- 选择性扫描算法
- 与C的乘法
硬件感知算法的最后一个部分是重新计算。
中间状态是反向传播计算梯度所必需的,但它们没有被保存。相反,作者在反向传播期间重新计算这些中间状态。
尽管这看起来可能效率不高,但与从相对较慢的DRAM中读取所有中间状态相比,它的成本要低得多。
现在,我们已经涵盖了其架构的所有组成部分,该架构如下图所示:
选择性状态空间模型。来源:Gu, Albert, and Tri Dao. “Mamba: Linear-time sequence modeling with selective state spaces.” arXiv preprint arXiv:2312.00752 (2023)
这种架构通常被称为选择性状态空间模型或S6模型,因为它本质上是通过选择性扫描算法计算的S4模型。
Mamba块
我们迄今为止探讨的选择性状态空间模型可以实现为一个块,就像我们可以在解码器块中表示自注意力一样。
与解码器类似,我们可以堆叠多个Mamba块,并将一个Mamba块的输出用作下一个Mamba块的输入:
它从一个线性投影开始,扩展输入嵌入。然后,在应用选择性状态空间模型之前,使用卷积来防止独立的标记计算。
选择性状态空间模型具有以下特性:
- 通过离散化创建的递归状态空间模型
- 在矩阵A上使用HiPPO初始化以捕捉长距离依赖
- 选择性扫描算法用于有选择性地压缩信息
- 硬件感知算法以加速计算
当我们进一步探讨代码实现时,可以更详细地了解这种架构,并探索一个端到端的例子:
注意一些变化,比如加入了归一化层和softmax,用于选择输出标记。
当我们把所有这些放在一起时,我们得到了快速的推理和训练,甚至实现了无界上下文。使用这种架构,作者发现它的表现与相同大小的Transformer模型相当,甚至在某些情况下超过了Transformer模型。