投机解码
投机解码(Speculative Decoding)
自回归生成的串行依赖是 Decode 慢的根本原因——必须等上一个 token 出来才能算下一个。投机解码是唯一试图突破这个限制的方法。
核心思路:先用小模型猜接下来几个 token,再让大模型一次性并行验证。
流程
系统里有两个模型:小模型(Draft Model,快但质量差)和大模型(Target Model,质量好但慢)。
每一轮:
- 小模型串行猜 N 个 token(比如4个):“关”→“于”→“秋”→“天”
- 大模型一次性并行验证:把原始上下文 + 4个候选 token 一起喂进去,像做 Prefill 一样并行处理
- 从左到右检查:前3个被接受,第4个被拒绝,用大模型的概率分布重采样
这一轮大模型只做了一次前向传播,但有效产出了 3 个 token(普通 Decode 同样一次前向传播只产出 1 个)。
验证规则:rejection sampling + 残差重采样
接受率 = min(1, q(x) / p(x)),其中 q 是 target model 概率,p 是 draft model 概率。
被拒绝后,从残差分布 max(q - p, 0)(归一化后)重新采样。
warning: 反直觉 最终输出分布和纯用大模型生成完全一致——数学上严格等价,不是近似。直觉上小模型质量差,但验证规则恰好抵消了它的偏差。
数学验证(两词词表 {A, B} 的例子)
Target model: q(A)=0.3, q(B)=0.7
Draft model: p(A)=0.6, p(B)=0.4 ← 小模型高估了 A
计算 P[最终采到 A]:
- 路径一:draft 猜 A,被接受 → 0.6 × min(1, 0.3/0.6) = 0.6 × 0.5 = 0.3
- 路径二:draft 猜 B,被拒绝后残差采到 A → 残差 = max([0.3-0.6, 0.7-0.4], 0) = [0, 0.3],归一化后 A=0
- P[采到A] = 0.3 + 0 = 0.3 = q(A) ✓
计算 P[最终采到 B]:
- 路径一:draft 猜 A,被拒绝,残差采到 B → 0.6 × 0.5 × 1 = 0.3
- 路径二:draft 猜 B,被接受 → 0.4 × min(1, 0.7/0.4) = 0.4 × 1 = 0.4
- P[采到B] = 0.3 + 0.4 = 0.7 = q(B) ✓
question: 待手推 上面两个等式的成立是巧合还是必然?尝试用一般形式证明:对任意 p、q 和词表,这个 rejection sampling 方案都能还原 q 的分布。提示:分 p(x) ≤ q(x) 和 p(x) > q(x) 两种情况分别讨论。
关键规律:
- draft 低估某 token(p < q)→ 接受率 = 1,无条件接受,且残差中该 token 权重为 0
- draft 高估某 token(p > q)→ 接受率 < 1,被拒绝后残差中该 token 权重也为 0
两种情况都确保高估的 token 不会被过量采样。
效果与扩展
加速比取决于 acceptance rate(draft 被接受的比例):
- acceptance rate 高(draft 和 target 分布接近):接近 N 倍加速
- acceptance rate 低:退化到接近 1x,但不会更慢
- 实际场景:稳定提速 2-3 倍
这是无损加速——与 量化 的有损加速有本质区别。
提升 draft model 质量
- 独立 draft model:用 target model 蒸馏训练,让 p 尽量接近 q
- Medusa:在 target model 顶部加并行 head,每个 head 预测不同偏移位置的 token
- EAGLE:draft model 直接读取 target model 最后一层的 hidden state,而不只看 token,猜测质量更高
与约束解码的交互
约束解码(constrained decoding)在 decode 的每一步 mask 掉不合法的 token。如果 draft model 不感知 grammar 状态,它猜出的 token 可能被 grammar mask 拒掉——即使 target model 的验证本来会接受这个 token。这会降低 acceptance rate,削弱加速效果。解决方案是让 draft model 也共享 grammar 状态,只在合法 token 范围内猜测。
相关来源
- llm-inference-tutorial:第7章
- cs336-lecture10-inference:Speculative Sampling 一节