深度解析原论文中的 GRPO:带 clip
操作的完整公式与示例代码
在论文 “DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models” 中,Group Relative Policy Optimization (GRPO) 被提出来强化语言模型的数学推理能力。它在 PPO 的基础上进行修改,一方面去掉了价值函数(value function),另一方面利用同一道题上一次性采样多条回答(相同 prompt)来做相对奖励(Relative Reward)。
在之前的简化示例中 从公式到代码:DeepSeek大模型GRPO算法中的 compute_loss如何实现(基于TRL源代码),我们直接写了:
loss ≈ − [ exp ( log p θ − log p θ o l d ) A ^ − β K L ] . \text{loss} \;\approx\; -\Bigl[\exp(\log p_\theta - \log p_{\theta_\mathrm{old}})\,\hat{A} \;-\;\beta \,\mathrm{KL}\Bigr]. loss≈−[exp(logpθ−logpθold)A^−βKL].
但在原论文或作者实现中,还有更接近 PPO 的 clip
操作,用来稳定训练,减少更新过量的风险。下面就是一个更完整的 GRPO 目标函数示例(与 PPO 十分相似):
J G R P O ( θ ) = E q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) [ 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ min ( r t ( θ ) A ^ i , t , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ i , t ) − β K L [ π θ ∥ π r e f ] ] , \begin{aligned} J_{\mathrm{GRPO}}(\theta) =\;\mathbb{E}_{\,q\sim P(Q),\,\{o_i\}_{i=1}^G\sim \pi_{\theta_{\mathrm{old}}}(O|q)} \biggl[ \frac{1}{G}\sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \Bigl( r_{t}(\theta)\,\hat{A}_{i,t},\; \mathrm{clip}\bigl( r_{t}(\theta),\,1-\epsilon,\,1+\epsilon )\,\hat{A}_{i,t} \Bigr) \;-\;\beta\,\mathrm{KL}\bigl[\pi_{\theta}\|\pi_{\mathrm{ref}}\bigr] \biggr], \end{aligned} JGRPO(θ)=Eq∼P(Q),{oi}i=1G∼πθold(O∣q)[G1i=1∑G∣oi∣1t=1∑∣oi∣min(rt(θ)A^i,t,clip(rt(θ),1−ϵ,1+ϵ)A^i,t)−βKL[πθ∥πref]],
其中:
- ( q q q) 表示一道题( p r o m p t prompt prompt),一次性采样 ( G G G) 条回答 ({ o i o_i oi}) 组成一个 group。
- ( ∣ o i ∣ |o_i| ∣oi∣) 表示回答 ( o i o_i oi) 的 token 数。
- ( r t ( θ ) = π θ ( o i , t ) π θ o l d ( o i , t ) r_{t}(\theta) = \frac{\pi_\theta(o_{i,t})}{\pi_{\theta_{\mathrm{old}}}(o_{i,t})} rt(θ)=πθold(oi,t)πθ(oi,t)) 是新旧策略对该 token 的概率比(ratio)。
- ( A ^ i , t \hat{A}_{i,t} A^i,t) 是相对优势(relative advantage),基于同一道题多条输出之间的打分差异来估计;
- ( c l i p ( … , 1 − ϵ , 1 + ϵ ) \mathrm{clip}(\dots, 1-\epsilon, 1+\epsilon) clip(…,1−ϵ,1+ϵ)) 与 PPO 类似,把 ( r t r_t rt) 限制在 ( [ 1 − ϵ , 1 + ϵ ] [1-\epsilon,\,1+\epsilon] [1−ϵ,1+ϵ]) 区间,避免更新过度;
- ( β K L [ π θ ∥ π r e f ] \beta\,\mathrm{KL}[\pi_{\theta}\|\pi_{\mathrm{ref}}] βKL[πθ∥πref]) 是对参考策略的 KL 正则,用来抑制策略过离谱地偏离初始模型。
下面我们用一段简化的 PyTorch 伪代码,来演示如何实现这个带 clip
操作的 GRPO loss。
PyTorch 伪代码示例
import torch
import torch.nn.functional as Fdef compute_grpo_loss(current_model,old_model,ref_model,input_ids,attention_mask,advantages,beta,epsilon,
):"""Args:current_model: 当前策略模型 pi_thetaold_model: 旧策略(或快照) pi_{theta_old},仅推断用,不更新ref_model: 参考模型 pi_ref,用来算 KL 的惩罚input_ids, attention_mask: 对应一批完整序列 (prompt + generated)advantages: A_{i,t},由分组得分计算得到的相对优势beta: KL 正则系数epsilon: Clip 范围 (1-epsilon, 1+epsilon)"""# 1) 计算当前模型在序列上每个Token的对数概率 log p_thetaoutputs_curr = current_model(input_ids=input_ids, attention_mask=attention_mask)logps_curr = F.log_softmax(outputs_curr.logits, dim=-1) # (B, L, V)# 2) 计算旧模型 pi_{theta_old} 在序列上的对数概率 log p_{theta_old}with torch.no_grad():outputs_old = old_model(input_ids=input_ids, attention_mask=attention_mask)logps_old = F.log_softmax(outputs_old.logits, dim=-1) # (B, L, V)# 3) 计算参考模型 pi_ref 的对数概率,用来做KLwith torch.no_grad():outputs_ref = ref_model(input_ids=input_ids, attention_mask=attention_mask)logps_ref = F.log_softmax(outputs_ref.logits, dim=-1)# 注意:input_ids 形状 [B, L],要 gather 出每个token位置实际的 logp# gather 出来后 shape = [B, L]curr_token_logp = logps_curr.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)old_token_logp = logps_old.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)ref_token_logp = logps_ref.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)# 4) ratio: r_{t}(\theta) = exp( log p_theta - log p_{theta_old} )ratio = torch.exp(curr_token_logp - old_token_logp)# 5) clip_ratioclipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)# 6) 计算 KL:常见做法 = exp( ref - curr ) - (ref-curr) - 1 (或别的近似)# 也可根据 log p_ref, log p_curr 全分布做更精确的KL,这里只示例token级approxkl_diff = ref_token_logp - curr_token_logpper_token_kl = torch.exp(kl_diff) - kl_diff - 1# 7) 构造PPO-like的目标: min( ratio*adv, clip_ratio*adv ) - beta * KL# 这里为了简单,假设 advantages, ratio, kl 都是 [B, L] 维度,后面再mask掉paddingadv_loss_1 = ratio * advantagesadv_loss_2 = clipped_ratio * advantagesadv_loss = torch.min(adv_loss_1, adv_loss_2)# PPO/GRPO里的损失是 -(上面的期望 - beta * KL)# 由于要最小化loss,而上面J里是一个最大化目标 => loss取负per_token_loss = - (adv_loss - beta * per_token_kl)# 8) 处理padding (比如 attention_mask 或 completion_mask, 这里简写)# 并对batch做聚合# 注意: 下面这个是个示例:可能只对最后若干tokens做mean# 也可能分组对 prompt/生成进行分开算mask = attention_mask # shape = [B, L], 1/0valid_token_count = mask.sum(dim=1) + 1e-10loss_per_seq = (per_token_loss * mask).sum(dim=1) / valid_token_countloss = loss_per_seq.mean()return loss
我们来逐步对应论文中的公式:
- 从旧策略(
old_model
)中拿到 ( log p θ o l d \log p_{\theta_{\mathrm{old}}} logpθold),再结合当前策略的 ( log p θ \log p_\theta logpθ),得到ratio = exp( logps_curr - logps_old )
; - 对 ratio 进行 clip:
clipped_ratio = clamp(ratio, 1-epsilon, 1+epsilon)
,以防更新过量; - 与相对优势(
advantages
)相乘,做min(ratio*adv, clipped_ratio*adv)
; - 参考策略(
ref_model
)的 log-prob 拿来做 KL 惩罚,这里只是一个简化的 token 级近似; - 最终的损失里,将负号(要最大化就取负)与 KL 惩罚项 ( β ∗ K L \beta * \mathrm{KL} β∗KL) 结合,并对 batch 里的 token 做平均。
以上流程和传统 PPO 的实现非常像,只是advantages 的来源在 GRPO 场景中是通过同一问题多条输出的分组对比算出来的——这部分可以在采样和打分阶段完成,然后在这里作为参数传入。
公式与代码的对照
-
公式里的 ( min ( r t ( θ ) A ^ , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ ) \min(\,r_t(\theta) \hat{A},\, \mathrm{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\,\hat{A}) min(rt(θ)A^,clip(rt(θ),1−ϵ,1+ϵ)A^)) 在代码中体现为:
adv_loss_1 = ratio * advantages adv_loss_2 = clipped_ratio * advantages adv_loss = torch.min(adv_loss_1, adv_loss_2)
-
KL 惩罚
beta * KL(...)
由per_token_kl = torch.exp(kl_diff) - kl_diff - 1 ... per_token_loss = - (adv_loss - beta * per_token_kl)
这样结合起来。这是一种近似在 token 级别上与参考模型做对比,也有别的写法。
-
最后
loss
会累加到batch做平均:这与论文里所说的“取期望”(( E q , { o i } \mathbb{E}_{q,\{o_i\}} Eq,{oi}))是对应的。
小结
在原文更“完整”的 GRPO 设计里,除了去除价值函数、利用分组相对奖励、加 KL 惩罚外,还保留了类似 PPO 的 clip
操作 来稳定训练。在实现时:
-
与 PPO 相似:
- 先算出
ratio
(新旧策略概率比),再用clip(ratio,1-\epsilon,1+\epsilon)
做“截断”; - 再与优势
advantages
相乘,取min(...)
来形成核心损失; - 最后加上
-\beta * KL
。
- 先算出
-
相对优势:
- 来自同一道题一次性采样的多条回答,通过打分(reward model 或自动判定)并在组内做均值/方差归一化,进而得到各条回答/各个 token 的“相对好坏”。
-
无价值函数:
- 与传统 PPO 里需要额外训练一个 value function 不同,GRPO 的“优势”直接通过分组比较得到,减少了对于价值模型(critic)的需求,更轻量,但也需要足够多的采样(group size)来保证估计质量。
这样就完成了 GRPO 算法在原论文中(带 clip 操作)那种更接近 PPO 的目标函数的 PyTorch 伪代码实现。希望这篇文章能帮助你看懂“从公式到代码”的具体过程,也理解 GRPO 在 PPO 框架上是做了哪些改变、又继承了哪些思路。
后记
2025年2月22日12点33分于上海,在GPT o1大模型辅助下完成。