KV Cache 的时候,重新审视了一遍 collate_fn 里的变量——tgt_out 是标准答案,模型输出是 logits——这些每天都在用的东西,我停下来重新走了一遍它们在维度上的完整链路。

理清了数据怎么从 token id 变成 d_model、又从 d_model 变回 vocab_size 之后,我的问题落在了终点:这个 (batch, seq_len, vocab_size) 的三维张量,到底是怎么变成一个 loss 数字的?

然后每一层答案都把我推向下一个“为什么”:为什么要拍平?为什么不取 max?为什么必须可导?为什么用线性层?为什么需要激活函数?直到最后发现——整条技术栈的设计逻辑,全部收束到同一个点上。

模型输出的形状

用户输入的文本被 tokenizer 切分、转换成 id 序列 (batch, seq) ——一般是一组输入被同时处理,所以有 batch 维。

这些 id 通过嵌入层映射成高维向量,用户输入的形状变成 (batch, seq, d_model)d_model 是模型内部的工作维度,衔言渡意里是 384。从这一步起到输出之前,模型全程都在 d_model 空间里工作:encoder 的每一层、decoder 的每一层,进出都是 (batch, seq, d_model)

Decoder 最后一层出来,还是 (batch, seq, d_model) 。然后过一个线性投影层 nn.Linear(d_model, vocab_size) ,把 384 维映射到 48000 维——每个位置得到一个词表大小的 logits 向量。这才是模型最终输出的形状:(batch, seq, vocab_size)

两头各一次维度转换:嵌入把 token id 投影进 d_model,输出投影把 d_model 映射回 vocab_size。模型本体在中间,只认识 d_model。

训练时

Encoder 处理源语言,用的是双向注意力,每个位置看得到所有其他位置。Decoder 处理目标语言,因果掩码限制“在每个位置只能关注当前和过往位置,未来不可见”——一个 (seq, seq) 的三角矩阵。

实际运算是并行的。seq 个位置同时计算,掩码让位置 3 只能看到 0,1,2,3,位置 7 只能看到 0~7。每个位置各自做了一次“如果序列到此为止,下一个 token 应该是什么”的预测。

seq 个预测,seq 份 loss,一次反向传播搞定。

训练时不用模型自己的预测结果作为下一步的输入,用的是目标序列——如果用模型自己的预测,一步错步步错,错误会级联。这就是 teacher forcing:把正确答案喂给模型,只让它学“在正确上下文下该预测什么”。

推理时

没有目标序列可以喂了。模型自回归生成:从一个起点出发,预测下一个 token,拼上去,再预测,直到预测出结束符。

每一步只需要最后一个位置的输出 [:, -1, :] ,形状 (batch, vocab_size) ——这就是下一个 token 的 logits。

因果掩码在事实上不再限制可见性——没有未来位置了,都是已知的。在 KV Cache 下更极端:每步只送入最新一个 token 作为 Q,K/V 是全部历史缓存,因果掩码就是一个行向量,“上三角”的形状都没了。


而不论训练还是推理,无论有没有 KV Cache ,模型的输出在逻辑上都是 (batch, seq, vocab_size) 。但这个三维张量,是怎么变成一个数字、又怎么引导模型调整权重的?

Loss:从三维张量到一个数字,再到什么?

衔言渡意的 loss 计算是 CrossEntropy,它要求的输入格式是 (N, C) ,N 个样本,C 个类别,这在事实上就是“一组独立的分类问题”——对于每个样本,从 C 个可能类别中选择一个结果。

对于模型的训练输出格式而言, (batch, seq_len, vocab_size) ,正是“​batch \times seq\_len 个样本”,每个样本有 vocab_size 种可能,每个可能各有大小——logits;每个位置的目标,就是“ vocab_size 个类别中的一个”,其值为正确 token 的 id——N 个样本,C 个类别。

这里是把整个 batch 拍平,每个序列也拆开,变成了针对“这次输入中的所有单个 token”的分类问题。token 和 token 之间没有关联,只是凑巧放在了一起,批量算了而已。

\begin{aligned} input &= (batch \times seq\_len, vocab\_size)\\ target &= (batch \times seq\_len)\\ loss &= F.cross\_entropy(input, target) \end{aligned}

我一开始以为,cross_entropy 内部做的事情是通过 argmax 从 vocab_size 个值中取出最大的那个位置——这就是“模型的选择”。然后把那个位置的概率拿出来,和 1 比较,得到一个“以满分为基准的分数”,取平均——loss。

但这个思路有一个根本问题:一旦 argmax 取出了一个索引,这个数字就和词表断开了。模型猜了第 15021 个词,错了——然后呢?这个“错了”只是一个判定,它不携带任何关于正确答案的信息。模型不知道正确答案是第 8437 个词,不知道自己其实给第 8437 个词分配了 0.03 的概率还是 0.30 的概率,不知道差多少,不知道往哪调。一个离散的对错判定,就是一条死路。

事实上, cross_entropy 做的事情和我的猜测完全不同。它对 logits 做 softmax 得到概率分布,然后直接拿目标 token 的 id 作为索引,取出模型为正确答案分配的概率,再求负对数——每个位置得到一个数字,取均值,就是 loss。不做选择,不做判定,不丢弃任何信息。模型给正确答案的概率越高,loss 越小。

由此,每一次的 loss 都不是“对”或“错”的二元判断,而是一个连续的、精确的数字——模型知道自己这次“多好”或“多差”。

但“知道好坏程度”不等于“知道怎么调”。loss 告诉你“当前 2.3 分”,不告诉你 51.7M 个参数里,第 4 层第 1827 个权重该往大调还是往小调、调多少。

一个数字,几千万个旋钮——缺的是一种方法,能把这个数字翻译成每个旋钮各自的调整指令。

可导意味着什么

loss 算出来了,但它只是一个分数,一个模型“做得好坏与否”的标准,它只是没有破坏概率分布的连续关系,并没有“想要 loss 降低,需要把某个权重调大或者调小,调整多少”的信息。

loss 的得出是模型那么多个权重的共同作用结果,于是这里可以简化一下思考——loss 随权重的变化而变化,但它不会随机变化,固定的权重总是在固定的目标下得出固定的 loss,这不就是... loss = f(权重) 吗?

f 是权重与 loss 的对应关系。如果能对 f 求导,自然就得出了要想让 loss 变小,权重应该如何调整——导数同时包含了权重需要知道的方向和大小,学习链条连起来了。

这同时解释了为什么我之前思考中的 argmax 不能用来评估模型好坏。不只是因为它得出的分数与词表概率分布无关,还因为...它不可导。如果没有办法对它求导,自然就无法得出需要对权重做什么调整,才能让模型学习,让 loss 下降。

loss 如何影响每一个权重?

刚才把Transformer的 f 简化了一下,变成了一个整体。

现在把那个简化思考拆开。Transformer的 f 是一个很复杂的东西,线性变换,激活函数,残差连接...以及,loss 计算方式。

它过于复杂,没有办法一步到位把“第 4 层 第 1827 个权重的导数”算出来,但它的每个组成部分都很简单,线性变换是一次函数,激活函数是个没有权重的非线性变换,残差连接是个加法,loss 是个负对数,很容易对每个部件求导。

那么可以从 loss 开始求它的导数,再拿着导数把倒数第二步的导数求出来然后相乘,倒数第三步的...一直到第一层的第一个权重,把它的导数也求出来,再和一路乘过来的导数相乘,f 的求导就完成了,每个权重对 loss 的影响关系也就知道了,反向调整就好。

这个从 loss 到第一个权重的求导链条,就是反向传播。

但反向传播的计算是有代价的——模型每走过一个步骤都被记录了下来,包括那一步时的输入输出值,全部都得存着。

为了在反向传播时求导,针对每个权重,前向传播时,每一层、每个变换算出来的激活值都要留着,再加上优化器,加上计算图结构,显存就这么被吃掉了。

连乘的代价

反向传播的计算不只会产生内存层面的代价,这么多的乘法连在一起,很容易让导数出现数量级上的问题。

模型的每一个步骤都会产生导数,如果持续出现比较小的导数比如 0.1,十个步骤之后,导数就变成 1e-10 了。这个缩小效应累积起来,导数就会持续变小,模型权重每次更新的幅度也会越来越小,直到没有明显变化,学习停滞。

但如果持续出现比较大的导数,导数同样会持续变化,越来越大,直到模型单次学习时权重的跳跃幅度超过权重数值本身,学习失控。

所以要控制导数的值,不能让它太小,也不能让它太大。要控制,得先明确它现在是大是小。

一个权重的导数表示了 loss 相对于它的变化率,而一层的导数事实上也不是一个一个权重算再放到一起,是直接把权重张量整体拿来算的——GPU 也正好擅长干这个嘛。

标量对标量求导的结果是导数,标量对张量求导,就是梯度——矩阵微积分。

但梯度本身不是标量,没法直接说它是大还是小,没法对它进行监控,也就没法调整它。所以一般对梯度求 L2 范数,把它压成一个数字:

\|g\|_2 = \sqrt{\sum_i g_i^2}

这就是梯度范数。梯度范数接近 0 导致的学习停滞是梯度消失,梯度范数超过数值上限导致的权重剧烈波动,就是梯度爆炸。

LayerNorm 解决梯度问题的方式

要解决梯度突变问题,还是要看梯度本身的具体计算过程,通过对计算系数的影响,控制梯度范数的量级。

以线性层为例,它是标准的一次函数,​y = Wx + b。反向传播时,loss 对 W 的梯度里会直接乘上激活值 x——激活值直接参与了梯度的计算。

所以有了 LayerNorm,一个 LayerNorm 层的计算是这样的:

\begin{aligned} \mu &= \frac{1}{n}\sum_{i=1}^{n} x_i\\ \sigma^2 &= \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2\\ LayerNorm(x_i) &= \gamma \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \end{aligned}

​\epsilon 是一个极小的数,一般是 1e-5,防止出现除 0 问题。​\gamma​\beta 是可学习参数。

整体过程就是,每个激活值减去平均数,除以标准差,把它们控制在以0为界,上下分布散布程度为 1 且相对关系不变的状态。然后乘以​\gamma 再加上 ​\beta——做一个线性变换,让模型自已决定具体的激活值分布状态。

LayerNorm通过对激活值的控制,间接控制了梯度的状态。

梯度裁剪

LayerNorm 控制的只是激活值这一个影响因子。反向传播时梯度的另一条路径依赖权重矩阵本身——权重在训练中不断变化,LayerNorm 管不到这里。再加上 loss 曲面本身可能存在陡峭区域,单步产生的梯度可能突然飙升。所以 LayerNorm 之外还需要一道保险——梯度裁剪。

梯度范数就派上用场了——不止是看大小,还可以反向控制梯度本身。

\|cg\|_2 = \sqrt{\sum_i (cg_i)^2} = \sqrt{c^2 \sum_i g_i^2} = |c| \cdot \|g\|_2

给梯度乘上一个数 c 再算梯度范数,你会发现梯度范数也乘了一个 c——绝对齐次性。

于是直接设定一个梯度阈值,发现范数超过阈值时,计算范数和阈值的比例,把梯度乘上那个比例,梯度范数自然被调整到阈值大小。

梯度 = 梯度 \times (阈值 / 范数)

为什么线性函数是层的基本单元?

之前说,线性层的本质是个一次函数:​output = weight \times input + b。但是,函数千千万,可导的也不在少数,为什么是一次函数?为什么是 ​y = ax + b 这个这么简单的东西?它是否影响了模型的上限?

非线性函数的问题

用非线性函数作为一层的基本单元在逻辑上当然是可行的,我们看看用二次函数 ​y = ax^2 + bx + c 会如何,公式也不难:

output = weight_1 \times input^2 + weight_2 \times input + bias

对吧?很简单,很优美,可导,引入了更多复杂度,模型的上限提升了!

然后呢?走一遍流程,看看实际情况是怎样的。

模型的层与层之间的连接不是加也不是乘,是代入。上一层的计算结果直接作为参数传入下一层,公式是这样的:

\begin{aligned} layer_1(x) &= a_1 x^2 + b_1 x + c_1 \quad \text{(二次)}\\ layer_2(layer_1(x)) &= a_2 (layer_1(x))^2 + b_2 \cdot layer_1(x) + c_2 \quad \text{(四次)}\\ layer_3(layer_2(layer_1(x))) &= \text{(八次)} \end{aligned}

​layer_3 时,就是八次多项式了。

衔言渡意有八层,如果它是二次非线性模型,到第八层时,输出相对于输入就是 256 次多项式。

什么样的激活值能扛住 256 次的幂运算?什么样的 LayerNorm 能把 256 次增长的激活值拉回来?

线性函数的问题

线性函数没有指数爆炸问题,它是这样的:

\begin{aligned} layer_1(x) &= a_1 x + b_1 \\ layer_2(layer_1(x)) &= a_2(a_1 x + b_1) + b_2 = (a_2 a_1) x + (a_2 b_1 + b_2) \end{aligned}

无论是多少层的叠加,它都是线性函数——但问题也在这里,一万个乃至更多的线性函数叠加在一起,它都等价于一个线性函数。它太简单了,只能学线性关系,但现实世界不是线性的。

激活函数的作用

所以需要一个东西夹在线性层与层之间,把简单的线性叠加打破,把那个等价拆开,让层与层真的是两个相对独立的存在,让它们不能被代数合并,“层”才有了意义,也才有了我的特征提炼法这种工程总结。

这个东西就是激活函数,它做的事情很简单,所以它的实现也都很简单。以ReLU为例,小于 0 的值变成 0,大于 0 的不动。就这么一个操作,插在两个线性层之间,代数折叠就不成立了。

线性层负责学习,激活函数负责把线性层的学习结果扭一下,打破线性关系。

那么...如果给非线性层也配一个“扭”,把它的指数爆炸扭回来...会如何?

以二次层为例。我不知道数学上是否存在完美的“反二次”运算,但从逻辑上来看,如果真的在一个二次层后面加一个反二次,那...二次学习 + 反二次扭转 = 两步抵消 = 恒等映射 = 真正意义上什么也没做。线性层至少还能缩放平移呢!


把学习和非线性绑在一起,堆叠时会失控。分开之后,线性层怎么堆都是稳定的线性,表达能力不足时加入更多层即可;激活函数是固定的、精心选过的“温和的非线性”,把每个新的线性层变成真正的层,整体上可控——目前工程上的最优解。

一切收束到微分

从模型输出到 loss 计算,从 argmax 到 softmax,从线性层到激活函数,这一路追问下来,我把整个深度学习...深度解剖了一遍。

回头看这篇文章里的所有“为什么”:

为什么 loss 不用 argmax?为什么层数深了会出梯度问题?为什么线性层是基本单元?为什么不用二次层?为什么要有激活函数?

答案其实是同一句话:因为这条从参数到 loss 的计算链路,必须处处可导。

可导不是一个技术细节,是整个训练机制的前提。没有它,你面对的是 51.7M 个参数——每个都有自己的值,每个都可能需要调整,但你不知道往哪调。唯一的办法是随机试。这辈子试不完。

可导给了方向,微分是提取方向的数学工具。链式法则是让这个方向能穿过几十层网络传回来的传递机制,整个深度学习的技术栈——架构、激活函数、初始化、归一化——全部是在“这条链路必须可导”这个前提下运转的。

它是地基。不是其中一块砖。