context
- 1. 剪枝方案图释
- 2. 正交矩阵Q
1. 剪枝方案图释
图中的阴影是表示丢弃掉这部分数据。通过引入正交矩阵 Q Q Q使 Q ⊤ Q = Q Q ⊤ = I \mathrm{Q}^\top\mathrm{Q}=\mathrm{Q}\mathrm{Q}^\top=\mathrm{I} Q⊤Q=QQ⊤=I,来大量缩减 X X X的列数和 W W W的行数。
由于 Q Q Q是正交矩阵,有 ∥ Q x ∥ = x ⊤ Q ⊤ Q x = x ⊤ x = ∥ x ∥ \|\mathbf{Q}x\|=\sqrt{x^\top\mathbf{Q}^\top\mathbf{Q}x}=\sqrt{x^\top x}=\|x\| ∥Qx∥=x⊤Q⊤Qx=x⊤x=∥x∥,所以 Q Q Q与 x x x相乘不会影响 x x x的范数。
在一般情况下,假设 X ℓ \mathbf{X}_{\ell} Xℓ是transformer中一个块的输出,在经过RMSNorm(对每一行 x ← X ∣ ∣ X ∣ ∣ x\leftarrow \frac{\mathbf{X}}{\left|\left|\mathbf{X}\right|\right|} x←∣∣X∣∣X处理),然后 R M S N o r m ( X ℓ ) \mathrm{RMSNorm}(\mathbf{X}_{\ell}) RMSNorm(Xℓ)作为下一块的输入。若引入矩阵 Q Q Q,则有 R M S N o r m ( X ℓ ) = R M S N o r m ( X ℓ Q ) Q ⊤ \mathrm{RMSNorm}(\mathbf{X}_\ell)=\mathrm{RMSNorm}(\mathbf{X}_\ell\mathbf{Q})\mathbf{Q}^\top RMSNorm(Xℓ)=RMSNorm(XℓQ)Q⊤,所以实际上引入 Q Q Q不改变transformer的结构。对于transformer中的每一attention或FFN层都有线性层,同时由于transformer中有残差连接(图中的 + ◯ \textcircled{+} +◯操作),这里把矩阵 Q Q Q引入每一块的线性层,所以需要把矩阵 Q Q Q引入到所有之前的层(一直到编码阶段)和所有之后的层(一直到LM头)。
令 W i n ℓ \mathbf{W}_{in}^\ell Winℓ和 W o u t ℓ \mathbf{W}_{out}^\ell Woutℓ为transformer的第 ℓ \ell ℓ块的线性层的权重矩阵, b i n ℓ \mathbf{b}_{in}^\ell binℓ和 b o u t ℓ \mathbf{b}_{out}^\ell boutℓ为相对应的偏置, W e m b d \mathbf{W}_{embd} Wembd和 W h e a d \mathbf{W}_{head} Whead为编码和头矩阵, Q Q Q为 D D D维矩阵,则可以用以下矩阵来模型不变性变换
W ~ e m b d = W e m b d Q , (1) b ~ o u t ℓ = Q ⊤ b o u t ℓ , (4) W ~ i n ℓ = Q ⊤ W i n ℓ , (2) W ~ h e a d = Q ⊤ W h e a d . (5) W ~ o u t ℓ = W o u t ℓ Q , (3) \begin{aligned}\tilde{\mathbf{W}}_{embd}&=\mathbf{W}_{embd}\mathbf{Q} ,&&\text{(1)}&&\tilde{b}_{out}^{\ell}=\mathbf{Q}^{\top}b_{out}^{\ell} ,&&\text{(4)}\\\tilde{\mathbf{W}}_{in}^{\ell}&=\mathbf{Q}^{\top}\mathbf{W}_{in}^{\ell},&&\text{(2)}&&\tilde{\mathbf{W}}_{head}=\mathbf{Q}^{\top}\mathbf{W}_{head} .&&\text{(5)}\\\tilde{\mathbf{W}}_{out}^{\ell}&=\mathbf{W}_{out}^{\ell}\mathbf{Q} ,&&\text{(3)}\end{aligned} W~embdW~inℓW~outℓ=WembdQ,=Q⊤Winℓ,=WoutℓQ,(1)(2)(3)b~outℓ=Q⊤boutℓ,W~head=Q⊤Whead.(4)(5)偏置矩阵保持不变 b ~ i n ℓ = b i n ℓ , b ~ h e a d = b h e a d \tilde{b}_{in}^{\ell}=b_{in}^{\ell},\tilde{b}_{head}=b_{head} b~inℓ=binℓ,b~head=bhead
文章主题思想如图Fig. 1.2
图中,(a)中的 W Q W_Q WQ、 W K W_K WK和 W V W_V WV是注意力中的QKV操作, W V W_V WV表示注意力机制的输出矩阵, M = I − 1 D 1 1 ⊤ \mathbf{M}=\mathbf{I}-\frac{1}{D}\mathbf{1}\mathbf{1}^{\top} M=I−D111⊤是用来使矩阵 X X X中的每一个元素拉回到0上下,与下一步的 x ← X ∣ ∣ X ∣ ∣ x\leftarrow \frac{\mathbf{X}}{\left|\left|\mathbf{X}\right|\right|} x←∣∣X∣∣X共同完成归一化处理, W 1 W_1 W1和 W 2 W_2 W2是MLP操作。(b)与(c)中的 ( α ) (\alpha) (α)就是diag( α \alpha α),矩阵 ( α ′ ) (\alpha^{'}) (α′)来自前一块。向量 α \alpha α和偏置 β \beta β在每个LayerNorm实例上独立学习。diag( α \alpha α)是一个矩阵操作,表示将一个向量 ( α ) (\alpha) (α)作为对角线元素创建一个对角矩阵。
最后移除一些不重要的行和列。
2. 正交矩阵Q
使用主成分分析(PCA)来求解 Q ℓ Q_{\ell} Qℓ(transformer中第 ℓ \ell ℓ块),在训练集中抽取一些数据作为校准数据,喂给模型用来从前到后逐层提取正交矩阵。对于校准数据集中的 i i i条数据,使模型中第 ℓ \ell ℓ层输出为 X ℓ , i X_{\ell,i} Xℓ,i,则有
C ℓ = ∑ i X ℓ , i ⊤ X ℓ , i \mathrm{C}_{\ell}=\sum_{i}\mathrm{X}_{\ell,i}^{\top}\mathrm{X}_{\ell,i} Cℓ=i∑Xℓ,i⊤Xℓ,i则 Q ℓ Q_{\ell} Qℓ是 C ℓ \mathrm{C}_{\ell} Cℓ的降序排列特征值的特征矩阵。