基础 · 15

Transformer 一步步拆解

把 GPT 级别的 Transformer 逐层拆开:位置编码、残差、LayerNorm、FFN——你会发现除了注意力,剩下的部分比想象的简单。

18 min read

一个 block 到底长什么样

上一篇我们拆清了注意力机制——multi-head attention 让每个 token 能"看到"其他 token 并融合信息。但一个 Transformer 不只有注意力。

GPT、Llama、Claude 用的 decoder-only Transformer 由完全相同的 block 堆叠而成(GPT-3 是 96 层)。每个 block 的结构出奇地简单,只有两个子层:

  1. Multi-Head Attention —— 跨 token 交互(上一篇讲的)
  2. Feed-Forward Network (FFN) —— 逐 token 独立计算

每个子层外面都包着残差连接 + Layer Norm。就这些。没有池化、没有卷积、没有递归——整个架构的复杂度全藏在注意力和叠层数里。

点击「下一步」看数据怎么流过一个 block:

输入
Input Embedding[seq, d]Layer NormMulti-Head AttentionAdd (residual)Layer NormLinear + GELULinear (project)Add (residual)Output → next block
Pre-Norm Transformer block。注意 FFN 先膨胀到 4d 再压回 d,以及残差连接(虚线)跳过整个子层直接加回。

注意两条虚线——它们是残差连接(skip connection)。输入直接跳过子层加到输出上:y=SubLayer(x)+x\mathbf{y} = \text{SubLayer}(\mathbf{x}) + \mathbf{x}。这是让 96 层网络能训起来的关键,下面会详细讲。

Token 嵌入与反嵌入

在数据进入 block 之前,有一个容易被忽略的关键步骤:把离散的 token 变成连续向量

模型维护一个巨大的嵌入矩阵 ERV×dE \in \mathbb{R}^{V \times d}——VV 是词表大小(GPT-2 = 50257,Llama 3 = 128k),dd 是模型维度(768 到 8192 不等)。每一行就是一个 token 的「身份向量」。

输入 token 的 ID 就是行号——嵌入操作本质上只是一次查表

xi=E[token_idi]\mathbf{x}_i = E[\text{token\_id}_i]

没有矩阵乘法,只有一次索引。这个简单操作把 "hello" 这样的符号变成了一个 768 维的浮点向量——后面所有的注意力、FFN 都在这个向量空间里运算。

反嵌入(un-embedding) 是整件事的逆过程。最后一层 block 输出一个 [seq,d][seq, d] 的矩阵,我们拿最后一个位置的向量 h\mathbf{h}EE^\top 做矩阵乘:

logits=hERV\text{logits} = \mathbf{h} \cdot E^\top \in \mathbb{R}^V

得到词表上每个 token 的分数(logit),再过 softmax 变成概率分布——这就是「下一个 token」的预测。直觉上,un-embedding 在问:"输出向量和哪个 token 的嵌入最相似?"

GPT-2 用了 weight tying——embedding 和 un-embedding 共享同一个 EE。这不只是省参数:它强制模型的输入空间和输出空间是同一个语义空间,让"相似的 token 有相似的 embedding"这个归纳偏置更强。

位置编码:让模型知道顺序

注意力机制有一个根本问题:它是置换不变的

Attention(Q,K,V)\text{Attention}(Q, K, V) 的计算只涉及点积和加权求和——如果你把输入 token 的顺序打乱,attention 矩阵的值会重新排列,但每个 token 得到的结果不变(只是行列换了位置)。换句话说,纯 attention 根本不知道 "the cat sat" 和 "sat the cat" 有区别

解决方案:在输入 embedding 上一个位置信号。Vaswani 等人的原始方案是正弦/余弦编码:

PE(pos,2i)=sin ⁣(pos100002i/d),PE(pos,2i+1)=cos ⁣(pos100002i/d)\text{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad \text{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)

每个维度是一个不同频率的正弦波——低维变化快,高维变化慢。这样每个位置都有一个唯一的"指纹",并且相对位置可以通过线性变换恢复(sin(a+b)\sin(a+b) 可以表示为 sina\sin acosa\cos a 的线性组合)。

PE 矩阵 (16 位置 × 32 维度)
+sin/cos−sin/cos
维度 →位置 →
正弦位置编码的热力图。低维变化快(高频),高维变化慢(低频)。每一行是一个位置的「指纹」——模型靠这个知道 token 在第几个位置。

鼠标悬停看具体值。注意:前几列(低维)震荡很快,越往右(高维)波长越长。这组频率的设计让模型能同时感知"这两个 token 紧挨着"和"这个 token 在句子开头还是结尾"。

现代 LLM 大多已经换成了可学习的位置编码RoPE(旋转位置编码),但核心思想不变:给每个位置一个独特的、可被线性操作区分的信号。

残差连接:梯度的高速公路

神经网络篇 我们讲过:反向传播是链式乘法。如果每一层的局部导数都小于 1,乘 96 次之后梯度趋近零——前面的层完全学不动。

残差连接提供了一条绕过每个子层的直通路

y=f(x)+x\mathbf{y} = f(\mathbf{x}) + \mathbf{x}

x\mathbf{x} 求梯度:

yx=fx+I\frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial f}{\partial \mathbf{x}} + I

不管 ff 的梯度多小,加上单位矩阵 II 之后,梯度至少为 1。这就像在每个收费站旁边修了条免费高速——梯度可以一路畅通地流回第一层。

这也是为什么 Transformer 用的是 Pre-Norm(先 LayerNorm 再进子层)而不是原论文的 Post-Norm(先子层再 LayerNorm):Pre-Norm 让残差支路完全干净——x\mathbf{x} 不经过任何变换直接加回来,梯度流更稳定。

Layer Norm:稳定训练的关键

神经网络篇 讲了 Batch Norm——对 batch 维度做标准化。但 Transformer 处理的是变长序列,batch size 往往很小,BN 效果差。

Layer Normalization 换了一个方向:对单个样本的所有特征维度做标准化:

LN(x)=γxμσ2+ϵ+β\text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

其中 μ\muσ2\sigma^2这个 token 自己所有维度的均值和方差(不涉及 batch 里其他样本),γ\gammaβ\beta 是可学习参数。

Layer Norm 的直觉:每个 token 的向量表示可能因为残差叠加而越来越大。LN 把它拉回单位球面附近(均值 0、方差 1),然后用 γ\gammaβ\beta 恢复表达力。这让后续的注意力点积和 FFN 的输入范围始终稳定——96 层叠下来不会爆炸也不会消失。

RMSNorm 是进一步简化:只除以均方根,不减均值。Llama 系列用的就是 RMSNorm,计算更快,效果相当。

FFN:每个 token 的独立计算

每个 block 的第二个子层是前馈网络(FFN)。它对每个 token 独立地做同样的变换——token 之间不交互(交互已经在 attention 里做完了)。

标准 FFN 就是两层全连接 + 激活:

FFN(x)=W2GELU(W1x+b1)+b2\text{FFN}(\mathbf{x}) = W_2 \cdot \text{GELU}(W_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2

W1W_1 把维度从 dd 扩大到 4d4d("膨胀"),激活后 W2W_2 再压回 dd。这个膨胀-压缩的沙漏结构是所有 Transformer 的标配。

为什么要扩大再缩小?因为 attention 做的是信息路由(决定看谁),而 FFN 做的是信息处理(对融合后的内容做非线性变换)。扩大到 4d4d 给了网络足够的"计算空间"来做复杂的特征组合。

SwiGLU——现代 LLM(Llama、PaLM、Gemini)的实际选择:

SwiGLU(x)=(Swish(xW1)xV)W2\text{SwiGLU}(\mathbf{x}) = (\text{Swish}(\mathbf{x} W_1) \odot \mathbf{x} V) W_2

用一路信号门控另一路,实验上比纯 GELU 效果更好。代价是多了一个矩阵 VV,通常把膨胀比从 4× 降到 83×\tfrac{8}{3}\times 来保持参数量不变。

FFN 的参数量占整个 Transformer 的约 2/3。在 GPT-3 级别的模型里,FFN 有数十亿参数。

FFN 即 key-value 记忆库

有一个越来越受认可的直觉:FFN 就是一个巨大的 key-value 存储器。把公式展开看:

FFN(x)=j=14dGELU(w1(j)x)匹配程度(soft gate)w2(j)存储的内容\text{FFN}(\mathbf{x}) = \sum_{j=1}^{4d} \underbrace{\text{GELU}(\mathbf{w}_1^{(j)} \cdot \mathbf{x})}_{\text{匹配程度(soft gate)}} \cdot \underbrace{\mathbf{w}_2^{(j)}}_{\text{存储的内容}}

W1W_1 的每一行 w1(j)\mathbf{w}_1^{(j)} 是一个模式检测器(key):它和输入做点积,激活后得到一个"匹配分数"。W2W_2 的对应列 w2(j)\mathbf{w}_2^{(j)} 是这个模式关联的知识(value)。GELU 是一个 soft gate——只有匹配程度高的 key 才会把对应 value "释放"出来。

具体例子:假设 W1W_1 的某一行学会了检测"这个 token 处于 '法国的首都是___' 的填空位"这个 pattern。当输入匹配时,W2W_2 的对应列就会输出一个指向 "Paris" embedding 附近的向量。

这个视角解释了两件事:为什么 FFN 这么大(需要存储海量知识),以及为什么 MoE 是 FFN 的自然升级——不是让每个 token 查全部记忆,而是先路由到相关的 "专家"(记忆子集),大幅降低计算量。

Decoder-Only 架构

原始 Transformer 论文(2017)设计了 encoder-decoder 结构——encoder 看完整输入,decoder 自回归生成输出。但 GPT 系列证明了一件事:

只要 decoder 够大,encoder 完全可以去掉。

Decoder-only 架构只保留了一种 block,和标准 block 唯一的区别是 attention 加了causal mask

Maskij={0if jiif j>i\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}

加到 attention score 上之后,softmax 会把 j>ij > i 的位置概率置零——每个 token 只能看自己和之前的 token。这让模型能做自回归生成:每次预测下一个 token,然后把它拼到序列末尾继续生成。

整个 GPT / Llama / Claude 的推理过程就是:

  1. 输入 prompt 过 N 个 block(prefill 阶段,并行计算)
  2. 最后一个 token 的输出过一个线性层 + softmax → 得到下一个 token 的概率分布
  3. 采样一个 token,拼到序列末尾
  4. 新 token 再过 N 个 block(decode 阶段,只算一个 token,用 KV cache 避免重复)
  5. 重复 3-4 直到输出结束符

整个模型没有任何"理解"和"生成"的切换——它始终在做同一件事:给定前文,预测下一个 token

Prefill vs Decode:两种完全不同的计算模式

虽然数学上两个阶段做的事一模一样,但它们的计算特征截然不同:

Prefill(首次灌入)Decode(逐 token 生成)
处理对象整个 prompt,一次性并行每步只算 1 个新 token
瓶颈GPU 算力(compute-bound)显存带宽(memory-bound)
KV Cache从零计算,写入缓存从缓存读取,只追加一行
耗时占比快(高并行度)慢(串行生成)

Prefill 阶段像做一道大题——所有 token 同时算 attention,GPU 的算力被充分利用。Decode 阶段像逐字默写——每一步都要从显存把整个 KV Cache 搬到计算单元,但只产出一个 token,GPU 大部分时间在等数据。

这就是为什么 "长 prompt + 短回答" 比 "短 prompt + 长回答" 快得多——前者在 prefill 里并行消化了大部分计算,后者被 decode 的串行瓶颈卡住。

这个想法在前沿里