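Introduction
Since the 2017 paper "Attention Is All You Need" introduced it, the Transformer architecture has transformed natural language processing and, more broadly, the entire field of deep learning. From BERT to GPT, from Vision Transformer to Stable Diffusion, the Transformer has become the backbone of most state-of-the-art (SOTA) models. This article examines each of its core components (Self-Attention, Multi-Head Attention, Positional Encoding, and Layer Normalization) and compares the architectures of important variants such as GPT, BERT, and T5.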
graph TB
subgraph "Encoder (N layers)"
A[Input Embedding<br/>+ Position Encoding] --> B[Multi-Head<br/>Self-Attention]
B --> C[Add & Norm]
C --> D[Feed-Forward<br/>Network]
D --> E[Add & Norm]
end
subgraph "Decoder (N layers)"
F[Output Embedding<br/>+ Position Encoding] --> G[Masked Multi-Head<br/>Self-Attention]
G --> H[Add & Norm]
H --> I[Multi-Head<br/>Cross-Attention]
E --> I
I --> J[Add & Norm]
J --> K[Feed-Forward<br/>Network]
K --> L[Add & Norm]
end
L --> M[Linear + Softmax]
M --> N[Output Probabilities]
style B fill:#e74c3c,color:#fff
style I fill:#3498db,color:#fff
style D fill:#2ecc71,color:#fff
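Self-Attention Mechanism
Self-Attention is the core of the Transformer: it lets every position in a sequence attend to every position (including itself), so dependencies are captured directly regardless of how far apart the tokens are.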
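Formula

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where
- $Q = XW_Q$ (Query: what am I looking for?)
- $K = XW_K$ (Key: what do I contain?)
- $V = XW_V$ (Value: what information do I pass on?)
- $d_k$ is the key dimension. Dividing by $\sqrt{d_k}$ is a scaling factor that keeps the dot products from growing so large that the softmax saturates (putting almost all weight on one key and producing vanishingly small gradients).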
graph LR
X[Input X] --> Q["Q = X·W_Q"]
X --> K["K = X·W_K"]
X --> V["V = X·W_V"]
Q --> MM["QK^T"]
K --> MM
MM --> S["÷ √d_k"]
S --> SM["Softmax"]
SM --> MV["× V"]
V --> MV
MV --> O[Output]
style Q fill:#e74c3c,color:#fff
style K fill:#3498db,color:#fff
style V fill:#2ecc71,color:#fff
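Why the $\sqrt{d_k}$ factor matters: if the entries of q and k are independent with mean 0 and variance 1, then q·k has variance $d_k$, so raw scores grow with the head dimension. The pure-Python sketch below (the function name is hypothetical, and the Gaussian inputs are an assumption for illustration) estimates this empirically and shows that dividing by $\sqrt{d_k}$ brings the variance back to about 1:

```python
import math
import random


def dot_product_variance(d_k: int, trials: int = 2000, seed: int = 0):
    """Estimate Var(q.k) and Var(q.k / sqrt(d_k)) for q, k with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    raw = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(qi * ki for qi, ki in zip(q, k)))
    scaled = [s / math.sqrt(d_k) for s in raw]

    def variance(xs):
        mean = sum(xs) / len(xs)
        return sum((x - mean) ** 2 for x in xs) / len(xs)

    return variance(raw), variance(scaled)


for d_k in (16, 64, 256):
    raw_var, scaled_var = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  Var(q.k)~{raw_var:7.1f}  Var(q.k/sqrt(d_k))~{scaled_var:.2f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays near 1 for every head size, which keeps the softmax inputs in a range where its gradients are useful.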
Implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        # Learned projections from the model dimension to the query/key/value space
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: (batch_size, seq_len, seq_len) or None; 0 marks positions to block
        """
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Scaled dot-product scores: (batch_size, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Each row becomes a probability distribution over the keys
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
```
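The `mask` argument is how a decoder hides future tokens (the "Masked Multi-Head Self-Attention" in the architecture diagram): with a lower-triangular mask, position i can attend only to positions 0..i. A pure-Python sketch using the same convention as the code above, where 0 means blocked (the helper names are hypothetical):

```python
import math


def causal_mask(seq_len):
    """1 = may attend, 0 = blocked: position i sees positions 0..i only."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax(scores, mask):
    """Row-wise softmax after setting blocked scores to -inf (mirrors masked_fill)."""
    weights = []
    for row, mask_row in zip(scores, mask):
        masked = [s if m != 0 else float('-inf') for s, m in zip(row, mask_row)]
        row_max = max(masked)  # finite, because the diagonal is never masked
        exps = [math.exp(s - row_max) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[0.5, 1.0, 2.0],
          [0.3, 0.8, 1.5],
          [1.2, 0.1, 0.4]]
for row in masked_softmax(scores, causal_mask(3)):
    print([round(w, 3) for w in row])
```

Blocked positions receive exactly zero weight, and each row still sums to 1. For the `SelfAttention` class above, `torch.tril(torch.ones(seq_len, seq_len))` produces the same mask as a tensor and broadcasts over the batch dimension.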
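Visualizing the Attention Matrix
Input: "I love natural language processing" (original Chinese tokens: 我 / 爱 / 自然 / 语言 / 处理). Attention matrix after softmax, where each row is a query token and each column a key token (every row sums to 1):

| Query \ Key | I | love | natural | language | processing |
|---|---|---|---|---|---|
| I | 0.15 | 0.10 | 0.25 | 0.30 | 0.20 |
| love | 0.12 | 0.08 | 0.30 | 0.28 | 0.22 |
| natural | 0.10 | 0.15 | 0.20 | 0.35 | 0.20 |
| language | 0.08 | 0.12 | 0.35 | 0.25 | 0.20 |
| processing | 0.10 | 0.10 | 0.30 | 0.25 | 0.25 |

In the "language" row, the largest weight (0.35) goes to "natural", reflecting the semantic link between the two words in "natural language".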
Multi-Head Attention
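Multi-head attention lets the model learn attention patterns in several representation subspaces at once: the model dimension d_model is split into n_heads heads of size d_k = d_model / n_heads, each head runs scaled dot-product attention independently, and W_O projects the concatenated results back to d_model: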
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head dimension

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Split into heads: (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Broadcast a (batch, seq, seq) mask across the head dimension
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # Merge heads: (batch, n_heads, seq, d_k) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        output = self.W_O(context)
        return output
```
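The `view(...).transpose(1, 2)` step is where heads are created: each token's d_model-dimensional projection is cut into n_heads contiguous chunks of size d_k. A pure-Python sketch of that split, its inverse, and the resulting weight count (all helper names are hypothetical); note that the number of parameters does not depend on the number of heads:

```python
def split_heads(vec, n_heads):
    """Cut one token's d_model-vector into n_heads contiguous chunks of size d_k,
    matching view(..., n_heads, d_k) on the last dimension."""
    d_model = len(vec)
    assert d_model % n_heads == 0
    d_k = d_model // n_heads
    return [vec[h * d_k:(h + 1) * d_k] for h in range(n_heads)]


def merge_heads(chunks):
    """Inverse of split_heads: concatenate per-head vectors back to d_model."""
    return [x for chunk in chunks for x in chunk]


def mha_param_count(d_model, n_heads):
    """Weights of W_Q, W_K, W_V, W_O (no biases), as in MultiHeadAttention above."""
    d_k = d_model // n_heads
    per_head_qkv = 3 * d_model * d_k  # each head projects d_model -> d_k for Q, K, V
    return n_heads * per_head_qkv + d_model * d_model  # plus W_O


print(split_heads(list(range(8)), n_heads=2))            # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(mha_param_count(512, 1), mha_param_count(512, 8))  # 1048576 1048576
```

More heads therefore means more, smaller subspaces at the same cost, rather than a larger layer.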
graph TB
A[Input X] --> B1["Head 1<br/>syntactic relations"]
A --> B2["Head 2<br/>semantic relations"]
A --> B3["Head 3<br/>positional relations"]
A --> B4["Head h<br/>other patterns"]
B1 --> C[Concat]
B2 --> C
B3 --> C
B4 --> C
C --> D["W_O linear projection"]
D --> E[Output]
style B1 fill:#e74c3c,color:#fff
style B2 fill:#3498db,color:#fff
style B3 fill:#2ecc71,color:#fff
style B4 fill:#f39c12,color:#000
Positional Encoding
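The Transformer has no recurrence, and self-attention on its own does not take token order into account, so position information must be supplied separately. The original paper adds sinusoidal (sine/cosine) positional encodings to the input embeddings: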
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Equivalent to 1 / 10000^(2i / d_model), written with exp and log
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        # Buffer: saved with the model and moved across devices, but not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (batch_size, seq_len, d_model)"""
        return x + self.pe[:, :x.size(1), :]
```
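For reference, the code above computes the sinusoidal encoding defined in the original paper, where $pos$ is the token position and $i$ indexes pairs of dimensions:

$$PE_{(pos,\,2i)} = \sin\left(pos \cdot \omega_i\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(pos \cdot \omega_i\right)$$

$$\omega_i = \frac{1}{10000^{2i/d_{\text{model}}}} = \exp\left(-\frac{2i \ln 10000}{d_{\text{model}}}\right)$$

The second form is exactly what `div_term` computes: `torch.arange(0, d_model, 2)` supplies the values $2i$, which are multiplied by $-\ln(10000)/d_{\text{model}}$ and exponentiated.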
Positional Encoding Variants
graph TB
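A[Positional encoding methods] --> B[Absolute position encoding]
A --> C[Relative position encoding]
A --> D[Rotary position embedding RoPE]
B --> B1[Sine/cosine<br/>original Transformer]
B --> B2[Learned position embeddings<br/>BERT/GPT]
C --> C1[Relative position bias<br/>T5 / ALiBi]
D --> D1[Llama / Qwen<br/>context extension via scaling]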
style B1 fill:#3498db,color:#fff
style D1 fill:#2ecc71,color:#fff
RoPE (Rotary Position Embedding)
```python
class RotaryPositionEmbedding(nn.Module):
    """RoPE: used in Llama, Qwen, Mistral."""

    def __init__(self, dim: int, max_seq_len: int = 8192, base: int = 10000):
        super().__init__()
        # One rotation frequency per pair of dimensions: base^(-2i / dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation angles for every position: (max_seq_len, dim / 2)
        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate so dimensions j and j + dim/2 share the same angle
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos', emb.cos())
        self.register_buffer('sin', emb.sin())

    def forward(self, q, k, seq_len):
        """q, k: (..., seq_len, dim); rotates each (j, j + dim/2) pair by its angle."""
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]
        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)
        return q_rotated, k_rotated

    @staticmethod
    def _rotate_half(x):
        # (x1, x2) -> (-x2, x1): a 90-degree rotation of each dimension pair
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)
```
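The defining property of RoPE is that after q is rotated for position m and k for position n, their dot product depends only on the offset m - n: each dimension pair is rotated by an angle proportional to its position, and the two rotations partly cancel inside the dot product. A pure-Python check with a hypothetical helper `rope_rotate`, which uses the same pairing of dimension j with j + d/2 as `_rotate_half` above:

```python
import math


def rope_rotate(x, pos, base=10000.0):
    """Rotate an even-length vector for position `pos`, pairing dimension j with
    j + d/2 (the same pairing as _rotate_half above)."""
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for j in range(half):
        theta = pos * base ** (-2 * j / d)  # angle = pos * inv_freq[j]
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[j], x[j + half]
        out[j] = a * c - b * s
        out[j + half] = b * c + a * s
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.5, 2.0]
k = [1.0, 0.4, -0.7, 0.9]
# Same offset (m - n = 3) at different absolute positions gives the same score
print(dot(rope_rotate(q, 5), rope_rotate(k, 2)))
print(dot(rope_rotate(q, 105), rope_rotate(k, 102)))
```

Because rotations preserve length, RoPE also leaves the norms of q and k unchanged; only their relative angle carries position.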
Feed-Forward Network
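Each Transformer layer also contains a position-wise feed-forward network with two linear layers, applied to every position independently. The original ReLU version is shown alongside the SwiGLU variant used in Llama and Mistral: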
```python
class FeedForward(nn.Module):
    """Original: FFN(x) = max(0, xW1 + b1)W2 + b2"""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class SwiGLUFeedForward(nn.Module):
    """SwiGLU variant: used in Llama, Mistral.

    FFN(x) = (SiLU(x W_gate) * (x W_up)) W_down
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))  # SiLU(x) = x * sigmoid(x), also called Swish
        up = self.up_proj(x)
        return self.down_proj(gate * up)  # element-wise gating, then project back
```
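SwiGLU uses three weight matrices instead of two, so at the same hidden size it has 1.5× the parameters. The Llama paper compensates by shrinking the hidden size from 4d to (2/3)·4d. A parameter-count sketch in pure Python (function names are hypothetical); the rounding in `llama_ffn_hidden` is an assumption modeled on Meta's reference Llama code, and for d_model = 4096 it yields 11008, the FFN size of Llama 7B:

```python
def ffn_params(d_model, d_ff):
    """FeedForward above: two Linear layers with biases."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def swiglu_params(d_model, d_ff):
    """SwiGLUFeedForward above: three bias-free Linear layers."""
    return 3 * d_model * d_ff


def llama_ffn_hidden(d_model, multiple_of=256):
    """Shrink 4 * d_model by 2/3, then round up to a multiple of `multiple_of` (assumed rule)."""
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


# ReLU FFN with d_ff = 4d vs. SwiGLU with d_ff = (2/3) * 4d, for d_model = 768
print(ffn_params(768, 3072), swiglu_params(768, 2048))  # 4722432 4718592
print(llama_ffn_hidden(4096))                           # 11008
```

Apart from the small bias terms, the two FFNs then cost the same, so any quality gain from SwiGLU is not simply bought with extra parameters.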
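Combining multi-head attention, the FFN, residual connections, and LayerNorm yields a complete encoder layer. This implementation uses Pre-LN (LayerNorm applied to each sublayer's input), whereas the architecture diagram at the top follows the original paper's Post-LN ordering (Add & Norm after each sublayer):

```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LN: normalize the sublayer input; the residual path stays un-normalized
        attn_output = self.self_attn(self.norm1(x), mask)
        x = x + self.dropout(attn_output)

        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_output)
        return x
```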
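Layer Normalization, one of the core components listed in the introduction, normalizes each token's feature vector to zero mean and unit variance and then applies a learned scale and shift. A pure-Python sketch of that computation and of where it sits in Post-LN versus Pre-LN blocks (all function names, and the toy sublayer, are hypothetical). In Pre-LN the residual stream x passes through unnormalized, which is the commonly cited reason it trains more stably:

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one token vector to mean 0 / variance 1, then scale and shift.
    Uses the biased variance, like torch.nn.LayerNorm."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    if gamma is None:
        gamma = [1.0] * n
    if beta is None:
        beta = [0.0] * n
    return [g * v + b for g, v, b in zip(gamma, normed, beta)]


def post_ln_block(x, sublayer):
    """Original Transformer: Add & Norm after the sublayer."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_ln_block(x, sublayer):
    """Pre-LN: normalize the sublayer input; the residual sum is left un-normalized."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def toy_sublayer(v):
    return [2.0 * t for t in v]  # hypothetical stand-in for attention or the FFN


x = [1.0, 2.0, 3.0, 4.0]
print(post_ln_block(x, toy_sublayer))  # re-normalized: mean 0
print(pre_ln_block(x, toy_sublayer))   # keeps the residual's mean of 2.5
```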
Training Paradigm: Pre-training and Fine-tuning
graph TB
A[Pre-training] --> B{Training objective}
B --> C[Masked Language Model<br/>BERT: predict masked tokens]
B --> D[Causal Language Model<br/>GPT: predict the next token]
B --> E[Seq2Seq<br/>T5: text-to-text]
F[Fine-tuning] --> G[Full fine-tuning]
F --> H[LoRA / QLoRA]
F --> I[Prompt Tuning]
A --> F
style C fill:#3498db,color:#fff
style D fill:#e74c3c,color:#fff
style E fill:#2ecc71,color:#fff
Masked Language Model (BERT)
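The BERT paper's masking recipe selects about 15% of token positions as prediction targets; of those, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged, and the loss is computed only at the selected positions. A pure-Python sketch (the function name is hypothetical, and sampling each position independently is a simplification):

```python
import random

MASK = "[MASK]"


def mlm_mask(tokens, vocab, mask_prob=0.15, seed=0):
    """BERT-style masking, simplified: each position is sampled independently.

    Returns (inputs, labels); labels[i] is None where no loss is computed.
    """
    rng = random.Random(seed)
    inputs = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected: keep the token, no prediction target
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: leave the original token in place
    return inputs, labels


tokens = "the cat sat on the mat".split()
vocab = sorted(set(tokens))
inputs, labels = mlm_mask(tokens * 20, vocab)
print(inputs[:12])
print(labels[:12])
```

Keeping or randomizing some selected tokens means the model cannot assume that only [MASK] positions need predicting, which narrows the gap to fine-tuning, where [MASK] never appears.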
Causal Language Model (GPT)
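GPT is trained to predict each token from the tokens before it. In practice the whole sequence goes through one forward pass with a causal mask, and the targets are simply the inputs shifted left by one position. A pure-Python sketch with hypothetical helper names and toy token IDs:

```python
def clm_inputs_targets(token_ids):
    """Targets are the inputs shifted left by one position."""
    return token_ids[:-1], token_ids[1:]


def clm_examples(token_ids):
    """The (context, next token) pairs that one causal forward pass trains on at once."""
    return [(token_ids[:t + 1], token_ids[t + 1]) for t in range(len(token_ids) - 1)]


ids = [5, 8, 2, 9]
print(clm_inputs_targets(ids))  # ([5, 8, 2], [8, 2, 9])
print(clm_examples(ids))        # [([5], 8), ([5, 8], 2), ([5, 8, 2], 9)]
```

Unlike MLM, every position contributes a prediction, so one sequence of n tokens yields n - 1 training targets.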
Architecture Variants Compared
graph LR
subgraph "Encoder-Only (BERT)"
A1[Input] --> B1["Bidirectional<br/>Self-Attention"]
B1 --> C1["[CLS] Token<br/>classification / similarity"]
end
subgraph "Decoder-Only (GPT)"
A2[Input] --> B2["Causal (Masked)<br/>Self-Attention"]
B2 --> C2["Next Token<br/>Prediction"]
end
subgraph "Encoder-Decoder (T5)"
A3[Input] --> B3["Bidirectional<br/>Encoder"]
B3 --> D3["Cross-Attention<br/>Decoder"]
D3 --> C3["Seq2Seq<br/>Output"]
end
style B1 fill:#3498db,color:#fff
style B2 fill:#e74c3c,color:#fff
style B3 fill:#2ecc71,color:#fff
style D3 fill:#f39c12,color:#000
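| Architecture | Representative models | Pre-training objective | Typical tasks |
|---|---|---|---|
| Encoder-Only | BERT, RoBERTa | MLM (+ NSP in BERT; RoBERTa drops NSP) | Classification, NER, similarity |
| Decoder-Only | GPT, Llama, Qwen | CLM (next-token prediction) | Text generation, dialogue |
| Encoder-Decoder | T5, BART | Span corruption (T5), denoising (BART) | Translation, summarization, question answering |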
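T5's span corruption replaces contiguous spans of the input with sentinel tokens and trains the decoder to output the dropped spans, each introduced by its sentinel. The T5 paper corrupts about 15% of tokens with a mean span length of 3; this pure-Python sketch (hypothetical function name) takes the spans as given instead of sampling them, and names sentinels in the `<extra_id_N>` style of the T5 vocabulary:

```python
def span_corrupt(tokens, spans):
    """spans: sorted, non-overlapping (start, end) index pairs to drop.
    Each span becomes one sentinel in the input; the target lists each sentinel
    followed by the tokens it replaced, then a final closing sentinel."""
    inputs, targets = [], []
    prev = 0
    for n, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{n}>"
        inputs.extend(tokens[prev:start])
        inputs.append(sentinel)
        targets.append(sentinel)
        targets.extend(tokens[start:end])
        prev = end
    inputs.extend(tokens[prev:])
    targets.append(f"<extra_id_{len(spans)}>")
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
inputs, targets = span_corrupt(tokens, [(2, 4), (8, 9)])
print(" ".join(inputs))   # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(targets))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

The targets are much shorter than the input, which makes this objective cheaper to train than reconstructing the full sequence.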
Modern Optimization Techniques
KV-Cache
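In autoregressive generation, the Keys and Values of earlier tokens do not change when a new token is generated, so they can be cached and reused instead of being recomputed at every step: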
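```python
class CachedMultiHeadAttention(MultiHeadAttention):
    """Reuses the projections of MultiHeadAttention and adds a KV-cache.

    x holds only the newly generated token(s): (batch_size, new_len, d_model).
    No causal mask is applied, which is correct when new_len == 1.
    """

    def forward(self, x, kv_cache=None):
        batch_size, new_len, _ = x.shape

        def split(t):
            # (batch, len, d_model) -> (batch, n_heads, len, d_k)
            return t.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        Q = split(self.W_Q(x))
        K = split(self.W_K(x))
        V = split(self.W_V(x))

        if kv_cache is not None:
            # Append the new keys/values after the cached ones along the sequence axis
            K = torch.cat([kv_cache['K'], K], dim=-2)
            V = torch.cat([kv_cache['V'], V], dim=-2)
        new_cache = {'K': K, 'V': V}

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (batch, n_heads, new_len, d_k)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, new_len, self.d_model
        )
        return self.W_O(context), new_cache
```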
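Caching changes how much work is done, not the result. A pure-Python, single-head sketch with hypothetical toy weights: at each step the uncached version re-projects the whole prefix, while the cached version projects only the new token and appends its K and V. Real decoding would feed each step's predicted token back in; a fixed input sequence is assumed here to keep the example small:

```python
import math


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e * v[i] for e, v in zip(exps, values)) / z for i in range(len(values[0]))]


# Hypothetical toy projection weights (d_model = d_k = 2)
W_Q = [[0.5, -0.2], [0.1, 0.9]]
W_K = [[1.0, 0.3], [-0.4, 0.6]]
W_V = [[0.2, 0.7], [0.8, -0.5]]


def decode_without_cache(xs):
    outputs = []
    for t in range(len(xs)):
        # Recompute K and V for the whole prefix at every step
        keys = [matvec(W_K, x) for x in xs[:t + 1]]
        values = [matvec(W_V, x) for x in xs[:t + 1]]
        outputs.append(attend(matvec(W_Q, xs[t]), keys, values))
    return outputs


def decode_with_cache(xs):
    cache_k, cache_v, outputs = [], [], []
    for x in xs:
        # Project only the new token, then append it to the cache
        cache_k.append(matvec(W_K, x))
        cache_v.append(matvec(W_V, x))
        outputs.append(attend(matvec(W_Q, x), cache_k, cache_v))
    return outputs


xs = [[1.0, 0.0], [0.5, -1.0], [0.3, 0.8], [-0.6, 0.2]]
print(decode_without_cache(xs) == decode_with_cache(xs))
```

For n generated tokens, the uncached loop performs n(n+1)/2 K/V projections, while the cached loop performs n, at the price of keeping every K and V in memory.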
Flash Attention
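Flash Attention restructures the computation with tiling and an IO-aware design that minimizes reads and writes to GPU high-bandwidth memory. It still computes exact attention (no approximation), but because the full seq_len × seq_len score matrix is never materialized, memory use grows linearly rather than quadratically with sequence length, and attention runs roughly 2-4× faster, depending on hardware and sequence length: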
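```python
import math

from flash_attn import flash_attn_func

# q, k, v are assumed to exist already:
# shape (batch_size, seq_len, n_heads, head_dim), dtype fp16 or bf16, on a CUDA device
d_k = q.shape[-1]

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,                       # no attention dropout (e.g. at inference)
    causal=True,                         # causal mask for autoregressive models
    softmax_scale=1.0 / math.sqrt(d_k),  # the 1/sqrt(d_k) from the formula (also the default)
)
# output: (batch_size, seq_len, n_heads, head_dim)
```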
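The tiling works because softmax can be computed incrementally: process keys in blocks, keep a running maximum, a running normalizer, and an unnormalized output, and rescale the earlier partial results whenever the maximum grows. This pure-Python sketch shows that online-softmax recurrence for a single query (function names are hypothetical); it does not model the GPU memory hierarchy that the real kernel is designed around:

```python
import math


def attention_reference(q, keys, values):
    """Standard softmax attention for one query: all scores are materialized at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def attention_online(q, keys, values, block_size=2):
    """Same result, computed block by block with a running max, normalizer and output."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float('-inf')
    normalizer = 0.0
    acc = [0.0] * len(values[0])  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescales earlier blocks; 0.0 at start
        p = [math.exp(s - new_max) for s in scores]
        normalizer = normalizer * correction + sum(p)
        acc = [a * correction + sum(pj * v[i] for pj, v in zip(p, v_blk))
               for i, a in enumerate(acc)]
        running_max = new_max
    return [a / normalizer for a in acc]


q = [0.2, -0.5, 1.0, 0.3]
keys = [[0.1, 0.4, -0.3, 0.8], [1.2, -0.7, 0.5, 0.0], [-0.6, 0.9, 1.1, -0.2],
        [0.3, 0.3, 0.3, 0.3], [2.0, -1.0, 0.4, 0.6]]
values = [[1.0, 0.0, -1.0], [0.5, 2.0, 0.0], [-0.3, 0.7, 1.5],
          [0.0, -1.2, 0.8], [1.1, 0.4, -0.6]]
print(attention_reference(q, keys, values))
print(attention_online(q, keys, values, block_size=2))
```

Only one block of scores exists at a time, which is why the full score matrix never needs to be stored.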
GQA (Grouped Query Attention)
graph TB
subgraph "MHA (Multi-Head)"
A1[Q1 K1 V1] --- A2[Q2 K2 V2] --- A3[Q3 K3 V3] --- A4[Q4 K4 V4]
end
subgraph "MQA (Multi-Query)"
B1[Q1] --- B5[K V<br/>shared]
B2[Q2] --- B5
B3[Q3] --- B5
B4[Q4] --- B5
end
subgraph "GQA (Grouped-Query)"
C1[Q1 Q2] --- C5[K1 V1<br/>Group 1]
C3[Q3 Q4] --- C6[K2 V2<br/>Group 2]
end
style A1 fill:#e74c3c,color:#fff
style B5 fill:#3498db,color:#fff
style C5 fill:#2ecc71,color:#fff
style C6 fill:#2ecc71,color:#fff
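GQA strikes a balance between MHA (every head has its own K and V) and MQA (all heads share a single K and V): query heads are divided into groups, and each group shares one K/V head.
- Llama 2 70B uses 8 KV heads for 64 query heads, so each KV head serves 8 query heads
- This cuts KV-cache memory sharply (8× relative to MHA with 64 KV heads) while keeping quality close to MHA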
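The memory saving follows directly from the cache size, which is proportional to the number of KV heads. A pure-Python sketch (function names are hypothetical) using Llama 2 70B's configuration of 80 layers, head dimension 128 (8192 / 64), and a 4096-token context; fp16/bf16 storage and the consecutive-group head layout are assumptions:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """Size of the KV-cache: one K and one V tensor per layer (hence the factor 2).
    bytes_per_elem=2 assumes fp16/bf16 storage."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


def kv_head_for(query_head, n_q_heads, n_kv_heads):
    """Assumed GQA layout: consecutive groups of query heads share one KV head."""
    group_size = n_q_heads // n_kv_heads
    return query_head // group_size


GIB = 1024 ** 3
# Llama 2 70B: 80 layers, 64 query heads, head_dim = 8192 / 64 = 128, 4096-token context
for name, n_kv in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=80, n_kv_heads=n_kv, head_dim=128, seq_len=4096)
    print(f"{name}: {size / GIB:.2f} GiB per sequence")
```

In this configuration, a full-length sequence needs 10 GiB of cache with MHA but 1.25 GiB with GQA, which directly raises the batch size that fits on a GPU.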
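Summary
The Transformer's success comes from its flexible and powerful attention mechanism: Self-Attention lets every position exchange information with every other position, Multi-Head Attention learns several attention patterns in parallel, and positional encoding injects information about token order. Understanding these core components is the foundation for going deeper into LLMs.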
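Key optimization directions in modern Transformers:
- RoPE replaces absolute position encoding: it encodes relative position directly in the query-key dot product and, combined with scaling techniques, allows the context length to be extended
- SwiGLU replaces the ReLU FFN, improving model quality at a comparable parameter count
- GQA reduces KV-cache memory at inference time
- Flash Attention speeds up training and inference and lowers the memory cost of attention
- Pre-LN (LayerNorm applied before Attention and the FFN) trains more stably than Post-LN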