在衔言渡意的训练收尾阶段,我给自己的最后一项技术任务是给推理加上 KV Cache。

我对这个东西的第一印象来自行业讨论——到处都在说 KV Cache 管理、PagedAttention、prefix sharing。脑子里自动补全了一整套分布式系统的画面:用户登录鉴权,会话状态持久化,多轮对话中历史推理结果的缓存和隐私隔离,怎么保证不重复计算跨对话的历史……

然后我去看了它到底是什么。

就是一个生成循环里的局部变量。调一次 generate(),cache 从空开始,token 一个个生成时逐步填充,生成结束,cache 扔掉。下次调 generate(),又是全新的空 cache。没有任何跨推理的持久化。性质上更像循环里的临时变量,不像权重那种跟着模型走的东西。

我脑补的那些——用户鉴权、会话管理、隐私隔离——确实存在,但那是 serving 层的工程问题,和 cache 本身没有关系。行业的一个经典操作:把一个简单概念包在复杂的工程上下文里讲,听众就以为概念本身很复杂。

它到底省了什么

自回归生成时,每一步只新增一个 token,但 self-attention 需要对所有已生成的 token 计算 Key 和 Value。没有 cache 的情况下,每生成一个 token,前面所有 token 都要重新过一遍线性投影算 K 和 V——这是纯粹的重复计算。

Cache 做的事就是:把每一层算过的 K 和 V 存起来,下一步只对新 token 算 K、V,然后拼接到旧的后面。Query 永远只有当前这一个 token,但它要和完整历史的 K 做点积、从完整历史的 V 里加权取值。省掉的是每一层 K、V 投影的重复计算,不是 attention 计算本身。

在 encoder-decoder 架构下还有一个额外的缓存机会:cross-attention 里 encoder 输出的 K 和 V。Encoder 只在最开始跑一次,输出不变,每个 decoder 层的 cross-attention 都在对同样的 encoder output 做投影——存下来就是一次性的事。

所以完整的 cache 结构是每层一个四元组 (self_k, self_v, cross_k, cross_v),整体是一个长度等于 decoder 层数的 list:

# 8 层 decoder,每层 4 个张量
kv_caches: list[tuple[Tensor, Tensor, Tensor, Tensor] | None] = [None] * 8

初始全是 None。随着生成推进,self-attention 的 cache 每步增长,cross-attention 的 cache 第一步算完后就不再变化。

拆开 MultiheadAttention

理解了 cache 的原理之后,第一反应是觉得实现很简单——一个 list,存张量,每步 cat 一下,完事。

然后发现:PyTorch 的 nn.TransformerDecoderLayer 把 attention 的 K、V 投影包在 nn.MultiheadAttention 内部,外面够不到中间产物。你没法在“投影完成”和“进入 attention 计算”之间截住 K、V 做缓存。标准组件的 forward 签名里根本没有 past_key_values 参数。

所以必须手写 decoder layer。但需要拆的只有 MHA——FFN、LayerNorm、residual connection 全部原样搬,没有任何变化。一个 TransformerDecoderLayer 内部做的事,展开就是:

  1. LayerNorm → self-attention → dropout → residual add
  2. LayerNorm → cross-attention → dropout → residual add
  3. LayerNorm → FFN → dropout → residual add

我用的是 norm_first=True(pre-norm),所以是先 norm 再 attention 再加回去。这个结构手写也就几十行,真正的工作量在 MHA 本身。

in_proj_weight:一个大矩阵藏三份投影

PyTorch 默认的 nn.MultiheadAttention 把 Q、K、V 三个投影矩阵拼成一个 in_proj_weight,shape 是 (3 * d_model, d_model)。这样做是因为一次大矩阵乘法在 GPU 上比三次小的更快。

使用时需要拆开。self-attention 中 Q、K、V 来自同一个输入,三等分 chunk:

q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)

Cross-attention 中 Q 来自 decoder,K、V 来自 encoder output,需要分别投影。把权重 split 成 Q 的部分和 KV 的部分:

w_q, w_kv = self.in_proj_weight.split([self.d_model, self.d_model * 2])
b_q, b_kv = self.in_proj_bias.split([self.d_model, self.d_model * 2])
q = F.linear(query, w_q, b_q)
k, v = F.linear(memory, w_kv, b_kv).chunk(2, dim=-1)

我保持了 in_proj_weight 这个合并结构,是为了兼容已有的训练权重——直接 load_state_dict 就能用,不需要做任何转换。如果以后从头训练一个新模型,一开始就打算支持 cache,那直接定义成三个独立的 nn.Linear 会更清晰——语义一目了然,加载完不用拆分直接算。

从投影到注意力:数据流全貌

输入 X 进入 MHA 之后完整的数据流:

投影阶段——X 和 in_proj_weight 做线性变换,产出 Q、K、V。投影的输出维度不一定等于 d_model——它可以是任何能被 n_head 均分的维度,只要最后 out_proj 的维度对应回来即可。我的模型里投影维度就是 d_model,所以每个的 shape 是 (batch, seq_len, d_model)

多头拆分——把投影输出拆成 n_head × d_head

# (batch, seq, d_model) → (batch, seq, n_head, d_head) → (batch, n_head, seq, d_head)
q = q.view(batch_size, q_len, self.n_head, d_head).transpose(1, 2)

KV Cache 的截取点就在这里——投影和 reshape 完成之后、进入 attention 计算之前。这一刻的 K 和 V 就是要缓存的东西。Self-attention 每步 cat 上新的,cross-attention 第一次算完就不再变。

Attention 计算——Q、K、V 送入 F.scaled_dot_product_attention,PyTorch 内部自动选择最优实现(flash attention 等)。输出 shape 是 (batch, n_head, q_len, d_head)

输出合并——转置回来,拼成 (batch, seq, d_model),过 out_proj 线性层。

attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)

最后返回的是一个元组:(attention 输出, (k, v))。上层收到 cache 后直接替换旧的——不是 append,不是增量拼接,因为 self-attention 里返回的 K、V 已经是 cat 过历史的完整版本。

训练和推理共用一套代码

Cache 在整个调用链上——decode、decoder、layer、MHA——都是 Optional。训练时传 None,所有 cache 相关的分支都不走,行为和替换前的 nn.TransformerDecoderLayer 完全一致。推理时传 cache 对象,走缓存逻辑。

# MHA 内部的核心分支
if cache_k is not None and cache_v is not None:
    k = torch.cat([cache_k, k], dim=2)
    v = torch.cat([cache_v, v], dim=2)

这也提供了一个天然的验证方式:cache 传 None 跑训练,如果 loss 和替换前一致,就说明拆 MHA 没有引入任何偏差。实际验证结果——加载 epoch 47 的存档继续训练,val loss 2.395 附近正常波动,复刻正确。

踩坑记录

Mask 形状:报错说 bias,其实是 seq 没对上

Cache 模式下,Q 的长度是 1(只有当前 token),K 的长度是 past_len + 1(完整历史)。Mask 的 seq 维度必须和 K 的 seq 维度对齐——这一点我漏了,用单步输入的长度算了 mask,忘了加上历史长度。

关键在于报错信息的误导性。scaled_dot_product_attention 把 mask 当作 additive bias 加到点积分数上。当 mask 的形状不匹配时,报的不是“shape mismatch”,而是:

(*bias): last dimension must be contiguous

一开始我以为是投影的偏置出了问题——复刻测试都过了,偏置怎么还能出问题?为了追查这个“偏置问题”,我把 MHA 内部的完整数据流从头到尾整理了一遍(倒也不算白费),最后才意识到:报错里的 bias 不是 nn.Linear 的偏置,而是 attention score 上的 additive bias——也就是我的 mask。

第一步时 mask 是 (1, 1, 1, 1),这是一个完美的广播张量,怎么都不会出错。第二步时 K 的长度变成了 2,而 mask 还是 (1, 1, 1, 1)——长度没有累加——但形状依然能广播,真正触发报错的是连续性检查。一个形状问题伪装成了连续性问题。

修复很简单——从 cache 里取 self_k 的长度作为 past_len,mask 的 seq 维度用 past_len + current_len

first_layer = kv_caches[0]
past_len = first_layer[0].shape[2] if first_layer is not None else 0
query_mask = torch.ones((batch_size, past_len + 1), device=device)

位置编码:不是从 0 开始

这个错误更隐蔽。我的 pos_encoding 是 learned positional embedding,推理时每步的位置应该是 past_len + current_step。但我改 cache 逻辑时完全没动位置编码那部分——它还是按输入长度从 0 开始算。

结果就是每个新 token 都以为自己在位置 0。模型生成出了明显的重复模式——因为对它来说,每一步都是“新的开始”。

修复是从 cache 长度推算当前位置:

past_len = first_layer[0].shape[2] if first_layer is not None else 0
decoder_in_pos = torch.arange(past_len, past_len + decoder_in_len, device=decoder_in.device)
decoder_in_pe = self.pos_encoding(decoder_in_pos)

两个坑有个共同点:都是在 cache 模式下,某个本来按“完整序列”设计的逻辑忘了适配“只有当前 token + 历史在 cache 里”的新语境。全量推理时一切正常,切到 cache 模式才暴露。

实际收益:规模决定一切

最后是性能数据。测试输入是一条较长的英文句子(翻译成中文,生成 28 个 token)。

无 cache——逐 token 耗时从 12ms 涨到 17ms。符合预期:每步重算完整序列的 K、V 投影,序列越长计算量越大,复杂度 O(n²)。

有 cache——逐 token 耗时恒定在 ~19ms。也符合预期:每步只投影一个 token 的 K、V,计算量不随序列长度增长,复杂度 O(n)。

结论是:在我的模型规模下,cache 反而更慢。

384 维度、8 层 decoder、序列长度 30 左右——全量重算的成本本来就很低。而手写的 CachedDecoderLayer 拆开了 MHA,多了 cat 操作,多了 Python 层面的 cache 元组管理,每步的固定开销比 PyTorch 高度优化的融合算子更大。省下的计算量抵不过额外的开销。

但 cache 的价值不在绝对速度,而在增长模式。无 cache 时后半段明显变慢,有 cache 时始终恒定。序列越长、模型越大,这个差距会越来越明显。在 1024 维度、序列长度 512 的模型上,无 cache 的后半段会肉眼可见地慢下去——那时候 cache 的恒定耗时才会真正碾压。

翻译结果完全一致。这是最重要的验证——数学上等价,实现正确。


这是衔言渡意的最后一项技术任务。KV Cache 本身不复杂,拆 MHA 也不复杂——复杂的是错误信息把你指向错误的方向,以及那些"全量模式下不存在、cache 模式下才暴露"的隐蔽适配点。

以为它是个分布式系统时,我高估了。以为它只是一个 list 加 __getitem__ 时,我低估了。拆开 PyTorch 的标准组件亲手写一遍之后,才真正知道边界在哪。