AI · #transformer#attention#deep-learning

A Deep Dive into the Transformer Architecture

2024.09.18 7 min 2.8k

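Introduction

Since it was introduced in the 2017 paper “Attention Is All You Need”, the Transformer architecture has reshaped natural language processing and much of deep learning. From BERT to GPT, and from Vision Transformer to Stable Diffusion, Transformer blocks now sit at the core of most state-of-the-art models. This article walks through each core component of the Transformer (Self-Attention, Multi-Head Attention, Positional Encoding, Layer Normalization) and compares the architectures of major variants such as GPT, BERT, and T5.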

Overall Transformer Architecture

graph TB
    subgraph "Encoder (N layers)"
        A[Input Embedding<br/>+ Position Encoding] --> B[Multi-Head<br/>Self-Attention]
        B --> C[Add & Norm]
        C --> D[Feed-Forward<br/>Network]
        D --> E[Add & Norm]
    end

    subgraph "Decoder (N layers)"
        F[Output Embedding<br/>+ Position Encoding] --> G[Masked Multi-Head<br/>Self-Attention]
        G --> H[Add & Norm]
        H --> I[Multi-Head<br/>Cross-Attention]
        E --> I
        I --> J[Add & Norm]
        J --> K[Feed-Forward<br/>Network]
        K --> L[Add & Norm]
    end

    L --> M[Linear + Softmax]
    M --> N[Output Probabilities]

    style B fill:#e74c3c,color:#fff
    style I fill:#3498db,color:#fff
    style D fill:#2ecc71,color:#fff

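The Self-Attention Mechanism

Self-Attention is the core of the Transformer: every position in the sequence can attend to every other position, so long-range dependencies are captured within a single layer.

Formula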

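Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

where:
Q = X * W_Q (Query: what am I looking for?)
K = X * W_K (Key: what do I contain?)
V = X * W_V (Value: what information do I carry?)
d_k = dimension of the Key vectors (dividing by sqrt(d_k) keeps the dot products from growing so large that softmax saturates)

graph LR
    X[Input X] --> Q["Q = X·W_Q"]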
    X --> K["K = X·W_K"]
    X --> V["V = X·W_V"]

    Q --> MM["QK^T"]
    K --> MM

    MM --> S["÷ √d_k"]
    S --> SM["Softmax"]
    SM --> MV["× V"]
    V --> MV

    MV --> O[Output]

    style Q fill:#e74c3c,color:#fff
    style K fill:#3498db,color:#fff
    style V fill:#2ecc71,color:#fff
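
To make the role of the sqrt(d_k) factor concrete, the sketch below estimates the variance of a dot product between two random vectors and compares softmax over raw and scaled scores. It is a plain-Python illustration rather than part of the article's implementation; the helper names `softmax` and `dot_product_variance`, the standard-normal inputs, and the sample scores are assumptions chosen for the demo.

```python
import math
import random

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot_product_variance(d_k, trials=2000, seed=0):
    """Empirical variance of q.k when q and k have i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return sum((d - mean) ** 2 for d in dots) / trials

d_k = 64
raw_var = dot_product_variance(d_k)   # grows roughly linearly with d_k
scaled_var = raw_var / d_k            # scaling scores by 1/sqrt(d_k) divides variance by d_k

# Assumed example scores whose magnitude is on the order of sqrt(d_k) = 8.
raw_scores = [8.0, 0.0, -8.0]
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

print("variance before / after scaling:", raw_var, scaled_var)
print("largest weight, raw scores:   ", max(softmax(raw_scores)))     # nearly one-hot
print("largest weight, scaled scores:", max(softmax(scaled_scores)))  # noticeably softer
```

With unit-variance entries the dot product has variance close to d_k, so dividing by sqrt(d_k) brings it back to about 1 and keeps softmax out of the near one-hot regime where gradients become very small.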

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: (batch_size, seq_len, seq_len) or None
        """
        Q = self.W_Q(x)  # (batch, seq_len, d_k)
        K = self.W_K(x)  # (batch, seq_len, d_k)
        V = self.W_V(x)  # (batch, seq_len, d_k)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (batch, seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        # attention_weights: (batch, seq_len, seq_len)

        output = torch.matmul(attention_weights, V)
        # output: (batch, seq_len, d_k)

        return output, attention_weights

Visualizing the Attention Matrix

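Input: "我 爱 自然 语言 处理" ("I love natural language processing")

Attention Matrix (after softmax; rows are queries, columns are keys):
          我     爱     自然   语言   处理
我       [0.15, 0.10, 0.25, 0.30, 0.20]
爱       [0.12, 0.08, 0.30, 0.28, 0.22]
自然     [0.10, 0.15, 0.20, 0.35, 0.20]
语言     [0.08, 0.12, 0.35, 0.25, 0.20]
处理     [0.10, 0.10, 0.30, 0.25, 0.25]

→ The row for "语言" (language) puts its highest weight (0.35) on "自然" (natural), reflecting the semantic link within "自然语言" (natural language).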
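
The matrix above can be checked in a few lines of plain Python: every row is a probability distribution over the keys, and the strongest target of a query is the largest entry in its row. The values are copied from the example; the helper name `strongest_target` is only illustrative.

```python
tokens = ["我", "爱", "自然", "语言", "处理"]

# Rows are queries, columns are keys (values copied from the example above).
attention = [
    [0.15, 0.10, 0.25, 0.30, 0.20],
    [0.12, 0.08, 0.30, 0.28, 0.22],
    [0.10, 0.15, 0.20, 0.35, 0.20],
    [0.08, 0.12, 0.35, 0.25, 0.20],
    [0.10, 0.10, 0.30, 0.25, 0.25],
]

def strongest_target(query_token):
    """Return the key token with the highest weight for a query, and that weight."""
    row = attention[tokens.index(query_token)]
    best = max(range(len(row)), key=lambda j: row[j])
    return tokens[best], row[best]

# Softmax output: each row sums to 1 (up to floating-point rounding).
row_sums = [sum(row) for row in attention]
print(row_sums)
print(strongest_target("语言"))
```

Because each row sums to 1, the output vector for a token is a weighted average of the Value vectors of all tokens.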

Multi-Head Attention

Multi-head attention lets the model learn attention patterns in several representation subspaces at once:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Linear projections
        Q = self.W_Q(x)  # (batch, seq_len, d_model)
        K = self.W_K(x)
        V = self.W_V(x)

        # Split into heads: (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention for each head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (batch, n_heads, seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        # context: (batch, n_heads, seq_len, d_k)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # Final linear projection
        output = self.W_O(context)
        return output

graph TB
    A[Input X] --> B1["Head 1<br/>syntactic relations"]
    A --> B2["Head 2<br/>semantic relations"]
    A --> B3["Head 3<br/>positional relations"]
    A --> B4["Head h<br/>other patterns"]

    B1 --> C[Concat]
    B2 --> C
    B3 --> C
    B4 --> C

    C --> D["W_O linear projection"]
    D --> E[Output]

    style B1 fill:#e74c3c,color:#fff
    style B2 fill:#3498db,color:#fff
    style B3 fill:#2ecc71,color:#fff
    style B4 fill:#f39c12,color:#000
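
One consequence of setting d_k = d_model / n_heads, as the class above does, is that the number of heads changes how the projections are sliced but not how many weights there are. The following sketch counts the weights head by head and shows the split/concat round trip on a plain list; the function names are hypothetical helpers, not part of the implementation above.

```python
def mha_parameter_count(d_model, n_heads):
    """Weights in W_Q, W_K, W_V, W_O (no biases), counted head by head."""
    assert d_model % n_heads == 0
    d_k = d_model // n_heads
    per_head_qkv = 3 * d_model * d_k   # one Q, K, and V slice per head
    output_proj = d_model * d_model    # W_O acts on the concatenated heads
    return n_heads * per_head_qkv + output_proj

def split_heads(vector, n_heads):
    """Cut one d_model-sized vector into n_heads contiguous chunks of size d_k."""
    d_k = len(vector) // n_heads
    return [vector[h * d_k:(h + 1) * d_k] for h in range(n_heads)]

def concat_heads(heads):
    """Inverse of split_heads: join the per-head chunks back together."""
    return [value for head in heads for value in head]

for n_heads in (1, 2, 4, 8):
    print(n_heads, mha_parameter_count(512, n_heads))

vector = list(range(8))
heads = split_heads(vector, 4)
print(heads, concat_heads(heads) == vector)
```

The count is 4 * d_model^2 for any valid head count, which is why adding heads in this formulation costs no extra parameters.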

Positional Encoding

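The Transformer has no recurrence, and self-attention by itself ignores token order, so position information must be injected separately. The original paper uses sinusoidal (sine/cosine) positional encodings: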

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions

        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (batch_size, seq_len, d_model)"""
        return x + self.pe[:, :x.size(1), :]
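
The same encoding can be written without tensors, which makes one property easy to verify: the encoding of position pos + k is obtained from the encoding of pos by rotating each (sin, cos) pair by a fixed angle that depends only on k. The sketch below is a minimal plain-Python version under the same formula; `sinusoidal_encoding` and `shift` are illustrative names, and an even `d_model` is assumed.

```python
import math

def sinusoidal_encoding(pos, d_model, base=10000.0):
    """Even index i holds sin(pos / base^(i/d_model)); index i+1 holds cos of the same angle."""
    pe = []
    for i in range(0, d_model, 2):
        angle = pos / (base ** (i / d_model))
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe

def shift(pe, k, d_model, base=10000.0):
    """Rotate each (sin, cos) pair of PE(pos) by the angle for offset k, giving PE(pos + k)."""
    out = []
    for i in range(0, d_model, 2):
        delta = k / (base ** (i / d_model))
        s, c = pe[i], pe[i + 1]
        out.extend([
            s * math.cos(delta) + c * math.sin(delta),  # sin(a + delta)
            c * math.cos(delta) - s * math.sin(delta),  # cos(a + delta)
        ])
    return out

pe_3 = sinusoidal_encoding(3, 8)
pe_7_direct = sinusoidal_encoding(7, 8)
pe_7_shifted = shift(pe_3, 4, 8)
print(max(abs(a - b) for a, b in zip(pe_7_direct, pe_7_shifted)))  # rounding error only
```

Since the shift is a linear map that does not depend on pos, a model can in principle learn to attend by relative offset from these absolute encodings.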

Positional Encoding Variants

graph TB
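    A[Positional encoding methods] --> B[Absolute positional encoding]
    A --> C[Relative positional encoding]
    A --> D[Rotary positional encoding RoPE]

    B --> B1[Sinusoidal<br/>original Transformer]
    B --> B2[Learned positional embeddings<br/>BERT/GPT]

    C --> C1[Relative position bias<br/>T5 / ALiBi]

    D --> D1[Llama / Qwen<br/>better suited to length extrapolation]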

    style B1 fill:#3498db,color:#fff
    style D1 fill:#2ecc71,color:#fff

RoPE (Rotary Position Embedding)

class RotaryPositionEmbedding(nn.Module):
    """RoPE: used in Llama, Qwen, Mistral."""

    def __init__(self, dim: int, max_seq_len: int = 8192, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos', emb.cos())
        self.register_buffer('sin', emb.sin())

    def forward(self, q, k, seq_len):
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]

        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)
        return q_rotated, k_rotated

    @staticmethod
    def _rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)
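
The property that makes RoPE attractive is that the attention score between a rotated query and a rotated key depends only on their relative offset. The sketch below re-implements the rotation on plain lists, using the same half-split layout as `_rotate_half`, and checks that property numerically. The function names and the sample vectors are assumptions for illustration.

```python
import math

def rope_angles(pos, dim, base=10000.0):
    """Angles pos * inv_freq, duplicated to match torch.cat([freqs, freqs])."""
    half = [pos / (base ** (i / dim)) for i in range(0, dim, 2)]
    return half + half

def rotate_half(x):
    """Map (x1, x2) to (-x2, x1), where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def apply_rope(x, pos, base=10000.0):
    """x * cos + rotate_half(x) * sin, evaluated at a single position."""
    angles = rope_angles(pos, len(x), base)
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]

def score(q, k, q_pos, k_pos):
    """Dot product between the query rotated to q_pos and the key rotated to k_pos."""
    q_rot = apply_rope(q, q_pos)
    k_rot = apply_rope(k, k_pos)
    return sum(a * b for a, b in zip(q_rot, k_rot))

q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
print(score(q, k, 3, 1))   # offset of 2
print(score(q, k, 10, 8))  # same offset, different absolute positions
```

Both calls use an offset of 2, so the two scores agree up to rounding even though the absolute positions differ.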

Feed-Forward Network

Each Transformer layer contains a two-layer, position-wise feed-forward network:

class FeedForward(nn.Module):
    """Original: FFN(x) = max(0, xW1 + b1)W2 + b2"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class SwiGLUFeedForward(nn.Module):
    """SwiGLU variant: used in Llama, Mistral."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))  # SiLU activation
        up = self.up_proj(x)
        return self.down_proj(gate * up)
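
The two classes above do not have the same size for a given `d_ff`: the SwiGLU variant has three weight matrices instead of two. A common way to keep the parameter budget comparable is to shrink the hidden size to about two thirds. The arithmetic below is a small sketch of that trade-off; the dimensions (`d_model = 512`, `d_ff = 4 * d_model`) are assumed for illustration, and the function names are hypothetical.

```python
def relu_ffn_params(d_model, d_ff):
    """Two nn.Linear layers with biases, as in FeedForward."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

def swiglu_ffn_params(d_model, d_ff):
    """Three bias-free projections (gate, up, down), as in SwiGLUFeedForward."""
    return 3 * d_model * d_ff

d_model = 512
d_ff = 4 * d_model             # 2048, the ratio used in the original Transformer
d_ff_swiglu = (2 * d_ff) // 3  # shrink the hidden size to offset the third matrix

print("ReLU FFN:             ", relu_ffn_params(d_model, d_ff))
print("SwiGLU, same d_ff:    ", swiglu_ffn_params(d_model, d_ff))
print("SwiGLU, 2/3 of d_ff:  ", swiglu_ffn_params(d_model, d_ff_swiglu))
```

With the reduced hidden size the two variants land within a fraction of a percent of each other, so any quality difference is not simply a matter of extra parameters.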

A Complete Transformer Layer

class TransformerEncoderLayer(nn.Module):
    """Encoder layer in Pre-LN form: LayerNorm runs before each sublayer,
    unlike the Post-LN "Add & Norm" order of the original architecture."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output = self.self_attn(self.norm1(x), mask)  # Pre-LN
        x = x + self.dropout(attn_output)

        # Feed-forward with residual connection
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_output)

        return x
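
The difference between the Pre-LN ordering used in this layer and the Post-LN ordering of the original architecture comes down to where the residual path sits relative to the normalization. The sketch below shows both orderings on plain lists with a simplified LayerNorm (no learned scale or shift); all function names are illustrative and the sublayer is a stand-in, not real attention.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned parameters)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def post_ln_block(x, sublayer):
    """Original ordering: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln_block(x, sublayer):
    """Pre-LN ordering: x + Sublayer(LayerNorm(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def zero_sublayer(x):
    """Stand-in sublayer that contributes nothing, to expose the residual path."""
    return [0.0 for _ in x]

x = [1.0, 2.0, 4.0, 9.0]
pre_out = pre_ln_block(x, zero_sublayer)
post_out = post_ln_block(x, zero_sublayer)
print(pre_out)   # the input passes through unchanged
print(post_out)  # the input is renormalized even though the sublayer did nothing
```

In the Pre-LN form the residual path is an unmodified identity, so the signal and its gradient can flow straight through a deep stack, which is one common explanation for its more stable training.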

Training Paradigms: Pre-training and Fine-tuning

graph TB
    A[Pre-training] --> B{Training objective}
    B --> C[Masked Language Model<br/>BERT: predict masked tokens]
    B --> D[Causal Language Model<br/>GPT: predict the next token]
    B --> E[Seq2Seq<br/>T5: text-to-text]

    F[Fine-tuning] --> G[Full fine-tuning]
    F --> H[LoRA / QLoRA]
    F --> I[Prompt Tuning]

    A --> F

    style C fill:#3498db,color:#fff
    style D fill:#e74c3c,color:#fff
    style E fill:#2ecc71,color:#fff

Masked Language Model (BERT)

# BERT pre-training objective
# Input: "我 [MASK] 自然 语言 [MASK]"
# Target: "我 爱 自然 语言 处理"

# 15% of tokens are selected for masking:
# - 80% replaced with [MASK]
# - 10% replaced with random token
# - 10% kept unchanged
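
The 15% / 80-10-10 rule described in the comments above can be sketched in a few lines of plain Python. This is an illustrative version under simplifying assumptions: a toy five-word vocabulary, token-level selection with an independent coin flip per token, and a hypothetical helper name `mask_tokens`.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, rng, select_prob=0.15):
    """Return (inputs, labels); labels[i] is None where no prediction is required."""
    inputs, labels = [], []
    for token in tokens:
        if rng.random() >= select_prob:
            inputs.append(token)   # not selected: left as-is, excluded from the loss
            labels.append(None)
            continue
        labels.append(token)       # selected: the model must recover this token
        roll = rng.random()
        if roll < 0.8:
            inputs.append(MASK)                # 80%: replaced with [MASK]
        elif roll < 0.9:
            inputs.append(rng.choice(vocab))   # 10%: replaced with a random token
        else:
            inputs.append(token)               # 10%: kept unchanged
    return inputs, labels

vocab = ["我", "爱", "自然", "语言", "处理"]
rng = random.Random(0)
tokens = [rng.choice(vocab) for _ in range(10000)]
inputs, labels = mask_tokens(tokens, vocab, rng)

selected = [i for i, label in enumerate(labels) if label is not None]
masked = sum(1 for i in selected if inputs[i] == MASK)
print("selected fraction:", len(selected) / len(tokens))
print("of which [MASK]:  ", masked / len(selected))
```

The loss is computed only at the selected positions; keeping some of them unchanged or randomly replaced means the model cannot rely on `[MASK]` being present whenever a prediction is needed.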

Causal Language Model (GPT)

# GPT pre-training objective
# Input: "我 爱 自然 语言"
# Target: "爱 自然 语言 处理"
# (Predict next token, with causal mask preventing looking ahead)

# Causal mask example for seq_len=4:
# [[1, 0, 0, 0],
# [1, 1, 0, 0],
# [1, 1, 1, 0],
# [1, 1, 1, 1]]
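
Both pieces of the causal objective, the lower-triangular mask and the one-token shift between input and target, are small enough to write out directly. The helpers below are a plain-Python sketch with illustrative names; the mask uses the same convention as the example above (1 means "may attend").

```python
def causal_mask(seq_len):
    """mask[i][j] == 1 when position i may attend to position j (j <= i)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def shift_for_next_token(tokens):
    """Inputs drop the last token; targets are the same sequence shifted left by one."""
    return tokens[:-1], tokens[1:]

for row in causal_mask(4):
    print(row)

inputs, targets = shift_for_next_token(["我", "爱", "自然", "语言", "处理"])
print(inputs)
print(targets)
```

Because of the mask, the prediction at each position can only use tokens at or before that position, so all positions of one sequence can be trained in parallel without leaking the answer.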

Comparing Architecture Variants

graph LR
    subgraph "Encoder-Only (BERT)"
        A1[Input] --> B1["Bidirectional<br/>Self-Attention"]
        B1 --> C1["[CLS] Token<br/>classification / similarity"]
    end

    subgraph "Decoder-Only (GPT)"
        A2[Input] --> B2["Causal (Masked)<br/>Self-Attention"]
        B2 --> C2["Next Token<br/>Prediction"]
    end

    subgraph "Encoder-Decoder (T5)"
        A3[Input] --> B3["Bidirectional<br/>Encoder"]
        B3 --> D3["Cross-Attention<br/>Decoder"]
        D3 --> C3["Seq2Seq<br/>Output"]
    end

    style B1 fill:#3498db,color:#fff
    style B2 fill:#e74c3c,color:#fff
    style B3 fill:#2ecc71,color:#fff
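    style D3 fill:#f39c12,color:#000

| Architecture | Representative models | Training objective | Typical tasks |
| --- | --- | --- | --- |
| Encoder-Only | BERT, RoBERTa | MLM (+ NSP in BERT) | Classification, NER, similarity |
| Decoder-Only | GPT, Llama, Qwen | CLM (Next Token) | Text generation, dialogue |
| Encoder-Decoder | T5, BART | Span corruption (T5), denoising (BART) | Translation, summarization, QA |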

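Modern Optimization Techniques

KV-Cache

During autoregressive generation, the Keys and Values of earlier tokens do not change, so they can be cached instead of recomputed each time a new token is generated: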

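class CachedMultiHeadAttention(MultiHeadAttention):
    """Reuses W_Q, W_K, W_V, W_O, n_heads and d_k from MultiHeadAttention above."""

    def forward(self, x, kv_cache=None):
        # x holds only the new tokens: (batch, new_len, d_model)
        batch_size, new_len, _ = x.shape

        def split_heads(t):
            # (batch, new_len, d_model) -> (batch, n_heads, new_len, d_k)
            return t.view(batch_size, new_len, self.n_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_Q(x))
        K = split_heads(self.W_K(x))
        V = split_heads(self.W_V(x))

        if kv_cache is not None:
            # Append new KV to cache along the sequence axis
            K = torch.cat([kv_cache['K'], K], dim=-2)
            V = torch.cat([kv_cache['V'], V], dim=-2)

        new_cache = {'K': K, 'V': V}

        # Only compute attention for new tokens. No causal mask is applied here,
        # so this is correct when new_len == 1 (one token per decoding step).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (batch, n_heads, new_len, d_k)

        # Concatenate heads and apply the output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, new_len, self.d_model)
        return self.W_O(context), new_cache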
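
A quick way to convince oneself that caching is lossless is to compare it against full recomputation on a toy example. The sketch below does this for a single head on plain lists, treating the query, key, and value vectors as already projected; the function names and the sample vectors are assumptions for illustration.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    """Attention output for one query over the given keys and values."""
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def full_causal_attention(qs, ks, vs):
    """Recompute every position from scratch; position t sees keys 0..t."""
    return [attend(qs[t], ks[:t + 1], vs[:t + 1]) for t in range(len(qs))]

def cached_decode(qs, ks, vs):
    """Process one token per step, appending its key and value to a growing cache."""
    cache_k, cache_v, outputs = [], [], []
    for q, k, v in zip(qs, ks, vs):
        cache_k.append(k)
        cache_v.append(v)
        outputs.append(attend(q, cache_k, cache_v))
    return outputs

qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
ks = [[1.0, 1.0], [0.0, 1.0], [2.0, -1.0]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

full = full_causal_attention(qs, ks, vs)
cached = cached_decode(qs, ks, vs)
print(full[-1])
print(cached[-1])
```

Each decoding step touches only one new query, so the per-step cost grows linearly with the sequence length instead of quadratically, at the price of keeping all past keys and values in memory.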

Flash Attention

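Flash Attention uses tiled computation and an IO-aware algorithm to compute exact attention (no approximation) while making the attention computation roughly 2 to 4 times faster: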

# Using Flash Attention 2 (requires the flash-attn package and a CUDA GPU)
import math
from flash_attn import flash_attn_func

# q, k, v: (batch, seq_len, n_heads, d_k), fp16 or bf16, on the GPU
# Significantly faster and more memory efficient
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    causal=True,  # For decoder
    softmax_scale=1.0 / math.sqrt(d_k),
)

# Key optimizations:
# 1. Tiling: process attention in blocks that fit in SRAM
# 2. Recomputation: trade compute for memory in backward pass
# 3. Kernel fusion: fuse softmax and matmul into one kernel
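
The tiling in point 1 relies on being able to compute softmax incrementally, block by block, without ever holding the full row of scores. The sketch below illustrates that running-max rescaling idea for a single query with scalar values. It is a plain-Python illustration of the principle only: the real kernels operate on matrix tiles in GPU SRAM, and the function names and block size here are assumptions.

```python
import math

def attention_direct(scores, values):
    """Reference: materialize the full softmax, then take the weighted sum."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return sum(e * v for e, v in zip(exps, values)) / total

def attention_tiled(scores, values, block_size=2):
    """Process scores block by block, keeping only a running max, sum, and accumulator."""
    running_max = float("-inf")
    running_sum = 0.0
    accumulator = 0.0
    for start in range(0, len(scores), block_size):
        block_scores = scores[start:start + block_size]
        block_values = values[start:start + block_size]
        new_max = max(running_max, max(block_scores))
        # Rescale everything accumulated so far to the new reference maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        accumulator *= correction
        for s, v in zip(block_scores, block_values):
            weight = math.exp(s - new_max)
            running_sum += weight
            accumulator += weight * v
        running_max = new_max
    return accumulator / running_sum

scores = [1.0, 3.0, -2.0, 0.5, 4.0]
values = [10.0, 20.0, 30.0, 40.0, 50.0]
print(attention_direct(scores, values))
print(attention_tiled(scores, values, block_size=2))
```

The two results agree up to rounding for any block size, which is why the tiled algorithm is exact rather than an approximation.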

GQA (Grouped Query Attention)

graph TB
    subgraph "MHA (Multi-Head)"
        A1[Q1 K1 V1] --- A2[Q2 K2 V2] --- A3[Q3 K3 V3] --- A4[Q4 K4 V4]
    end

    subgraph "MQA (Multi-Query)"
        B1[Q1] --- B5[K V<br/>shared]
        B2[Q2] --- B5
        B3[Q3] --- B5
        B4[Q4] --- B5
    end

    subgraph "GQA (Grouped-Query)"
        C1[Q1 Q2] --- C5[K1 V1<br/>Group 1]
        C3[Q3 Q4] --- C6[K2 V2<br/>Group 2]
    end

    style A1 fill:#e74c3c,color:#fff
    style B5 fill:#3498db,color:#fff
    style C5 fill:#2ecc71,color:#fff
    style C6 fill:#2ecc71,color:#fff

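GQA strikes a balance between MHA (every head has its own KV) and MQA (all heads share one KV):

  • Llama 2 70B uses 8 KV heads and 64 Query heads
  • KV-Cache memory drops substantially (8x here, relative to 64 KV heads) while quality stays close to MHA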
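
The memory saving can be estimated with simple arithmetic. The sketch below uses the head counts quoted above; the layer count (80), head dimension (128), and 2-byte (fp16) storage are assumptions added for the calculation, and both function names are hypothetical.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Keys and values (factor 2) stored for every layer and every KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value

def kv_head_for_query_head(query_head, n_query_heads, n_kv_heads):
    """Consecutive query heads form a group that shares one KV head."""
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size

# Head counts are from the text; 80 layers, head_dim 128, and fp16 are assumptions.
mha = kv_cache_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128)
gqa = kv_cache_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)

print("MHA bytes per token:", mha)
print("GQA bytes per token:", gqa)
print("reduction factor:   ", mha // gqa)
print("query heads 0, 7, 8 use KV heads:",
      [kv_head_for_query_head(h, 64, 8) for h in (0, 7, 8)])
```

The reduction factor equals the ratio of query heads to KV heads, independent of the assumed layer count, head dimension, or precision.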

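Summary

The Transformer's success comes from its flexible and powerful attention mechanism: Self-Attention provides global information exchange, Multi-Head Attention learns several attention patterns in parallel, and positional encoding injects token-order information. Understanding these core components is the foundation for going deeper into LLMs.

Key optimization directions in modern Transformers:

  1. RoPE replaces absolute positional encoding and is better suited to extending context length
  2. SwiGLU replaces the ReLU FFN and tends to improve model quality
  3. GQA reduces KV-Cache memory at inference time
  4. Flash Attention speeds up training and inference
  5. Pre-LN (LayerNorm applied before Attention) trains more stably than Post-LN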
Author: zt
Published: 2024-09-18
Length: 2.8k characters · 7 min
License: CC BY-SA 4.0