作者:Xuan Luo, Weizhi Wang, Xifeng Yan
Department of Computer Science, UC Santa Barbara
xuan_luo@ucsb.edu, weizhiwang@ucsb.edu, xyan@cs.ucsb.edu
1. 引言与动机
1.1 背景
- LLM 的成功与挑战:
- 大型语言模型 (LLMs) 在翻译、代码生成、推理等任务上取得巨大成功。
- 核心问题: 当前LLM在生成每个token时,通常需要通过所有Transformer层进行完整的前向传播。
- 计算资源浪费:
- 这种统一的计算分配 (Uniform Allocation) 与直觉相悖:简单的任务/token(如重复词、常见短语)理应需要更少的计算资源,而复杂的任务/token(如推理、生成新信息)需要更多。
- 导致计算效率低下, 过拟合等。
1.2 研究问题与贡献
- 现有方法的局限:
- 已有的层跳过 (Layer-skipping) 或早退 (Early-Exit) 方法虽然能减少计算量,但大多忽略了一个根本问题:
- “不同 Token 的生成,其计算需求是如何变化的?” (How do computational demands vary across the generation of different tokens?)
- 已有的层跳过 (Layer-skipping) 或早退 (Early-Exit) 方法虽然能减少计算量,但大多忽略了一个根本问题:
- 本文动机:
- 深入探究Token生成过程中的计算需求异质性。
- 提出一种能在预训练LLM上实现自适应层跳过的方法,且不修改原始模型参数。
- 主要贡献:
- 提出 FlexiDepth: 一个动态调整Transformer层数的即插即用 (plug-in) 方法。
- 在 Llama-3-8B 上实现显著层跳过(跳过8/32层)同时保持100%基准性能。
- 揭示了LLM计算需求与Token类型显著相关(如重复Token vs. 计算密集型Token)。
- 开源了 FlexiDepth 模型 和 FlexiPatterns 数据集 (记录层分配模式)。
2. 相关工作
- 层跳过/效率提升方法分类:
- 基于统计信息跳过层: 利用层输入输出差异等信息判断并跳过不重要层 (如 ShortGPT [26])。
- 早退 (Early-Exit): 在中间层设置判断点,若置信度足够高则直接输出,跳过后续所有层 (如 [37, 18, 34])。
- 从头训练动态深度模型: 在训练时就加入路由机制,动态决定每层是否执行 (如 MoD [31], SkipLayer [41], Duo-LLM [2])。缺点:需要大量计算资源重新训练。
- Encoder中的条件计算: 如 PoWER-BERT [11], CoDA [21], COLT5 [1] 等,在Encoder中根据token重要性/复杂度分配不同计算路径。缺点:非因果性,不直接适用于Decoder-only模型。
- 预训练模型中的跳过: MindSkip [12] 可以在预训练模型上跳过,但主要探索跳过Attention,且本文作者认为其性能或方式有别。
- FlexiDepth 的定位:
- 专注于Decoder-only的预训练LLM。
- 逐层 (Layer-wise) 动态决策,而非早退。
- 通过轻量级插件实现,冻结原始模型参数。
- 不仅提升效率,更旨在理解和利用计算需求的变化规律。
3. FlexiDepth
3.1 整体架构
-
核心思想: 在预训练LLM的每个(或部分,如下文所述,通常是后半部分)Transformer Decoder层,增加决策和适配机制,动态决定每个Token是完整处理还是跳过该层核心计算。
-
FlexiDepth Block (图 2):
- 输入: Hidden State (X)。
- 两个并行路径:
- 完整处理路径 (Full-processing Path, 图2 左):
- Token 通过标准的 Attention 和 FFN 模块。
- 输出 = g * Original_Layer(X) (g 为路由得分)。
- 跳过路径 (Skipping Path, 图2 右):
- Token 绕过 Attention 和 FFN 模块。
- 通过一个轻量级的 Adapter 进行处理。
- 输出 = (1-g) * Adapter(Norm(X))。
- 完整处理路径 (Full-processing Path, 图2 左):
- 核心组件 (可训练):
- Router: 决定Token走哪条路径 (计算得分 g)。
- Adapter: 处理走跳过路径的Token,解决表征不匹配问题。
- 输出: 两条路径的输出加权合并。
-
关键特性: 原始LLM的Attention和FFN参数保持冻结。只训练Router和Adapter。
3.2 Router 设计
-
目标: 为每个输入Token x_i 计算一个门控分数 g_i ∈ (0, 1),表示其通过完整路径的倾向。
-
输入: 经过 RMSNorm 标准化的 Hidden State z = Norm(X)。
-
Router 结构 (Eq 2):
- 为什么不用简单的线性层? (消融实验会证明) 简单的线性层不足以捕捉路由决策所需的复杂模式,尤其是在冻结主干模型时。Bottleneck结构在参数高效的同时提供了足够的表达能力。
-
输出: Gating Score G = σ(Router(z)) (Eq 1),其中 σ 是 Sigmoid 函数。
-
路由决策: 使用预定义阈值 τ。若 g_i > τ,走完整路径;若 g_i <= τ,走跳过路径。
3.3 Attention Skipping 与 KV Cache
-
问题: 如果完全跳过Attention层,那么该Token对应的Key (K) 和 Value (V) 就不会被计算。对于自回归模型,后续的Token将无法Attention到这个被跳过的Token,导致上下文信息丢失,严重影响生成质量 (如图3 中间的 ‘No KV Cache’ 所示)。
-
FlexiDepth 的解决方案 (图3 右侧 ‘KV Cache’):
- 对于决定跳过Attention模块的Token (即 g_i <= τ):
- 仍然计算其对应的 Key (K) 和 Value (V) 并存入KV Cache。
- 跳过 Query (Q) 的计算以及后续的点积注意力计算 (Scaled Dot-Product Attention)。
- 对于决定跳过Attention模块的Token (即 g_i <= τ):
-
好处:
- 保留了完整的上下文信息,确保后续Token可以Attention到所有历史Token。
- 依然节省了Query计算和主要的Attention矩阵计算开销。
- 这是维护自回归生成完整性的关键设计。
3.4 FFN Skipping 与 Adapter
- 问题: FFN层包含非线性变换,直接跳过FFN会导致:
- 表征不匹配 (Representation Mismatch): 经过FFN处理的Token和直接跳过的Token处于不同的表示空间。
- 性能显著下降: (消融实验会证明) 简单跳过FFN效果很差。
- FlexiDepth 的解决方案 (图2 右侧):
- 引入一个轻量级 Adapter。
- 结构: 与原始FFN类似 (MLP结构),但中间维度显著减小 (例如,论文中提到减少16倍)。
- 功能: 对跳过FFN的Token进行变换,使其表示与经过完整FFN处理的Token对齐 (align)。
- 好处:
- 在计算开销很小的情况下,有效弥合了跳过FFN带来的表征差异。
- 是保证性能的另一个关键组件。
3.5 损失函数
-
目标: 平衡 生成质量 和 计算效率 (层跳过率)。
-
总损失 (Total Loss, Eq 4):
- L_lm: 标准的下一个Token预测损失 (Language Modeling Loss)。
- L_skip: 层跳过损失 (Layer-skipping Loss)。
- α: 平衡系数,控制层跳过损失的权重。
-
层跳过损失 (L_skip, Eq 3): L_skip = (1/T) * Σ_t (Σ_l g_tl)2 (原文公式似乎有误,应该是类似惩罚“使用层数”的平方和,更可能是 (1/T) * Σ_t Σ_l (g_tl)2 或者类似含义,需确认。但核心思想是惩罚使用的层数。)
- 惩罚每个Token使用的门控分数 (g) 的总和的平方 (或者各层g的平方和)。
- 为什么用平方? 对使用更多层的Token施加更大的惩罚,鼓励模型跳过层;同时避免模型陷入全跳或全不跳的极端。有助于稳定训练。
-
训练细节 (Section 3.1):
- 只在模型的后半部分层 (如 Llama-3-8B 的后16层) 应用FlexiDepth。原因:先前研究表明跳过早期层对性能影响更大。
- Router的Bottleneck维度 (dr = d/16),Adapter的中间层维度缩小16倍。
- 使用 Tulu-v2 数据集训练,AdamW优化器
4. 实验设置
- 基础模型: Llama-3-8B-Instruct (32层)。
- 评估基准 (Benchmarks):
- 单Token生成: MMLU, HellaSwag, Winogrande (考察知识、常识、推理)。
- 多Token生成: GSM8K (数学推理), HumanEval (代码生成), CoQA (对话式问答)。区分这两类很重要,因为性能差异在多Token任务上更明显。
- 评估指标 (Metrics): Accuracy (acc), Normalized Accuracy (acc_norm), Exact Match (EM), Pass@1, F1 score (根据不同任务选择)。
- 对比基线 (Baselines):
- Vanilla (原始 Llama-3-8B-Instruct)。
- LayerSkip [9] (早退最后k层 + 推测解码)。
- ShortGPT [26] (基于输入输出差异剪枝k层)。
- LaCo [39] (层合并,减少k层)。
- MindSkip [12] (探索Attention/FFN/Layer跳过,论文采用其Layer Skipping设置)。
- 公平比较: 所有基线方法都应用于 Llama-3-8B,并配置为跳过相同数量 (k=4 或 k=8) 的层进行比较 (通过调整FlexiDepth的α实现近似跳过层数)。
5. 主要结果与分析
5.1 基准性能比较
- 核心发现: FlexiDepth 在跳过层数(k=4, k=8)的情况下,显著优于所有基线方法,尤其是在多Token生成任务 (GSM8K, HumanEval) 上。
- Skip 8 Layers:
- 基线方法在 GSM8K 和 HumanEval 上性能几乎崩溃 (接近0)。
- FlexiDepth 保持了接近100% (100.7%) 的平均性能。
- 性能甚至略有提升?
- 在某些任务上,FlexiDepth 性能甚至略微超过了原始模型 (Retain % > 100%)。
- 假设: 作者推测这可能源于自适应跳过带来的隐式正则化 (implicit regularization) 效果,跳过了不信息或噪声参数。与完全微调的模型对比 (allenai/llama-3-tulu-2-8b),FlexiDepth在GSM8K/HumanEval上表现更好,说明提升不完全来自训练数据。
- 结论: FlexiDepth 可以在大幅减少计算(跳过8层)的同时,几乎无损甚至略微提升模型在各种任务上的性能,尤其擅长处理需要复杂推理的长序列生成任务。
5.2 跨模型尺寸表现
- 实验: 在不同尺寸的指令微调模型上应用FlexiDepth (Llama-2-13B, Llama-3-8B, Qwen-2.5-3B)。
- 发现:
- 模型越大,跳过的层数越多。
- Llama-2-13B: 平均跳过约 6-7 层。
- Llama-3-8B: 平均跳过约 6 层 (这里跳过层数比Table 1的8层少,可能是α取值不同)。
- Qwen-2.5-3B: 平均只跳过 1-2 层。
- 模型越大,跳过的层数越多。
- 解释:
- 这表明更大的模型固有地拥有更高的冗余度 (redundancy)。
- 因此,自适应层跳过方法在更大规模的LLM上具有更大的潜力。
5.3 层分配模式
- 主要发现:
-
任务依赖性
- Summarization (总结): 平均使用更多层 (e.g., 28.65层)。需要深入理解和抽象。
- Extractive QA (抽取式问答) / Copying (复制): 平均使用较少层 (e.g., 复制 21.95层)。依赖检索和直接输出。
- Continuation (续写): 使用最多层 (e.g., 30.27层)。需要创造性和上下文连贯性。
-
Token 类型依赖性
- 重复/简单复制: 如重复数字列表、公式左侧的数字,使用较少层。
- 计算/推理/高不确定性: 如数学运算的结果、总结或续写中的新信息,需要更多层。
-
- 结论: LLM的计算需求确实不是均匀的,而是与任务复杂度和当前Token的功能(是复制、计算还是生成新信息)密切相关。FlexiDepth的自适应机制能够捕捉并利用这种模式。
6. 消融实验
- 目的: 验证FlexiDepth中各个设计选择的必要性。基于Llama-3-8B进行。
- 实验设置:
- Linear Router: 将 MLP Router 替换为简单的线性层 + Sigmoid。
- No KV Cache: 跳过Attention时,不计算和存储 K, V。
- No Adapter: 跳过FFN时,移除Adapter。
- 结果:
- Linear Router: 性能显著下降 (Retain 68.7%),尤其在 GSM8K (0.657 -> 0.131)。说明复杂路由机制是必要的。
- No KV Cache: 性能大幅下降 (Retain 84.3%)。证明为跳过Token保留KV Cache对于维护上下文至关重要。
- No Adapter: 性能灾难性下降 (Retain 28.1%)。凸显Adapter在对齐跳过FFN的Token表征方面的关键作用。
- 结论: FlexiDepth 中的 Router、KV Cache 保留策略、以及 FFN Adapter 都是不可或缺的设计,共同保证了模型在层跳过时的性能。
7. 局限性与未来工作
- 主要局限性 (Limitation):
- 理论FLOPs减少 vs. 实际吞吐量提升: 当前实现未能在现有GPU硬件上带来显著的推理速度提升。
- 原因:
- 控制流开销 (Control-flow overhead): 同一个batch内的样本可能走不同的计算路径 (一些Token跳过,一些不跳过),需要复杂的管理。
- 不规则内存访问 (Irregular memory access): 不同的执行路径导致访存模式不规则,降低GPU并行效率。
- 未来工作 (Future Work):
- 硬件感知优化: 需要研究专门的优化技术来克服上述瓶颈,例如:
- Token Grouping [30]: 将计算需求相似的Token分组处理。
- Expert Sharding / Load Balancing [30, 15]: 在多GPU或专用硬件上更有效地分配计算负载。
- 深入研究正则化效应: 探索自适应跳过是否真的能作为一种有效的正则化手段。
- 将FlexiDepth应用于更广泛的模型和任务。
- 硬件感知优化: 需要研究专门的优化技术来克服上述瓶颈,例如:
8. 结论
- 核心贡献: 提出 FlexiDepth,一种在预训练LLM上实现动态自适应层跳过的方法,无需修改原始模型参数。
- 关键成果:
- 在保持SOTA性能(甚至略有超越)的同时,实现了显著的层跳过(如Llama-3-8B跳过8/32层)。
- 显著优于现有兼容预训练模型的层跳过方法,尤其在复杂生成任务上。
- 重要洞见:
- 首次系统地揭示并量化了LLM中Token生成的计算需求异质性,发现其与任务类型和Token功能强相关。
- 验证了更大模型具有更高冗余度,为自适应方法提供了更大空间。
- 价值: 提供了一种有效的方法来提升LLM效率(潜力巨大,待硬件优化),并为理解LLM内部计算动态提供了新的视角和工具 (FlexiPatterns数据集)。
9. 代码
https://huggingface.co/xuan-luo/FlexiDepth-Llama-3-8B-Instruct/blob/main/modeling_ddllama.py