我第一次接触 Transformer 时,记住的是架构图,但始终没真正弄明白数据是怎么在 Attention 里流动的。最近重新读了 llm.c 和 LLMs-from-scratch,才意识到,理解 Attention 最有效的方式不是反复看图,而是盯着输入张量如何一步步变成输出张量。下面我不讲历史,只讲计算。(如果你希望先看例子,请跳到后面举例说明部分)
我们先明确一下 Attention 层的输入和输出分别是什么
- 输入:多个 token 组成的张量,输入为 X∈RT×dmodel
- T:序列长度,也就是 token 个数
- dmodel:每个 token 的表示维度
- 输出:对序列中的每个 token,计算它应该从自己以及其他相关 token 聚合多少信息,并形成新的表示
对单个 token 而言,会得到三个向量:Query、Key、Value;当把整个序列放在一起看,就对应三个矩阵 Q,K,V。
- Query:我该关注谁
- Key:我提供什么特征供别人匹配
- Value:如果别人关注我,我实际提供什么信息
Q=XWQ,K=XWK,V=XWV
其中:
- WQ∈Rdmodel×dk
- WK∈Rdmodel×dk
- WV∈Rdmodel×dv
注意力分数到底怎么算,对第 i 个 token 和第 j 个 token,它们的匹配分数是:
sij=qi⋅kj
把所有 token 两两计算,就得到分数矩阵:
S=QKT
当 dk 较大时,点积的数值方差会增大,softmax 更容易变得过于尖锐,因此通常要除以 dk 进行缩放。
S=dkQKT
如果把 Q 的第 i 行记作 qi,把 K 的第 j 行记作 kj,那么矩阵 QKT 的第 i,j 个元素就是 qi⋅kj。也就是说,分数矩阵的第 i 行表示:第 i 个 token 对整个序列中所有 token 的关注程度。
对每一行做 softmax:
A=softmax(dkQKT+M)
这里 Aij 表示:第 i 个 token 应该从第 j 个 token 读取多少信息。
最后用这些权重对 V 做加权求和:
O=AV
其中 O∈RT×dv。
因果掩码自注意力
在语言模型生成场景中,第 i 个 token 不应该看到未来 token 的信息,因此在分数矩阵 S 上加一个上三角 mask。被遮住的位置设为一个极小值,如 −∞,再做 softmax。
A=softmax(dkQKT+M)
其中 M 是 mask 矩阵,未来位置为 −∞,其他位置为 0。
这样 softmax 之后,未来 token 对应的权重就会变成 0,因此每个位置只能关注自己及之前的位置。
单头 attention 只在一个表示子空间里做信息聚合。多头 attention 的做法是,用多组不同的投影矩阵,把输入映射到多个不同的子空间中,分别计算 attention,再把各头结果拼接起来。
对第 h 个头:
Q(h)=XWQ(h),K(h)=XWK(h),V(h)=XWV(h)
每个头独立计算:
headh=softmax(dkQ(h)K(h)T)V(h)
最后:
MultiHead(X)=Concat(head1,…,headH)WO
其中 WO 把拼接后的结果映射回 dmodel 维。
我将选取 因果掩码自注意力 来说明整个计算过程。
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
[0.55, 0.87, 0.66],
[0.57, 0.85, 0.64],
[0.22, 0.58, 0.33],
[0.77, 0.25, 0.10],
[0.05, 0.80, 0.55]]
)
这里一共有 6 个 token,每个 token 是一个 3 维向量,因此:
- 序列长度 T=6
- 输入维度 din=3
为了把注意力集中在 attention 本身的计算流程上,这里先固定一组最简单的投影矩阵:
WQ=WK=WV=I3
也就是说,Query、Key、Value 都直接等于输入本身:
Q=X,K=X,V=X
因此:
W_Q = torch.eye(3)
W_K = torch.eye(3)
W_V = torch.eye(3)
Q = inputs @ W_Q
K = inputs @ W_K
V = inputs @ W_V
所以:
Q=K=V=0.430.550.570.220.770.050.150.870.850.580.250.800.890.660.640.330.100.55
因为这里 dk=3,所以缩放后的注意力分数矩阵为:
S=3QKT
对应代码:
scores = Q @ K.T / (K.shape[-1] ** 0.5)
数值上,约等于:
S≈0.5770.5510.5440.2740.2640.3640.5510.8630.8520.4870.4080.6270.5440.8520.8410.4790.4130.6120.2740.4870.4790.2850.2010.3790.2640.4080.4130.2010.3840.1690.3640.6270.6120.3790.1690.546
这个矩阵的第 i 行表示:第 i 个 token 对所有 token 的原始匹配分数。
比如第 2 个 token(journey)对第 1 个 token(Your)的分数是:
s21=3x(2)⋅x(1)=30.55⋅0.43+0.87⋅0.15+0.66⋅0.89≈0.551
而它对自己的分数是:
s22=3x(2)⋅x(2)≈0.863
这说明第 2 个 token 当前更“匹配”自己,而不是第 1 个 token。
由于是语言模型中的 causal self-attention,第 i 个 token 不能看到未来位置,所以要加上一个上三角 mask:
M=000000−∞00000−∞−∞0000−∞−∞−∞000−∞−∞−∞−∞00−∞−∞−∞−∞−∞0
对应代码:
T = scores.shape[0]
mask = torch.triu(torch.ones(T, T), diagonal=1)
scores = scores.masked_fill(mask.bool(), float("-inf"))
加上 mask 后,分数矩阵变成:
Smasked≈0.5770.5510.5440.2740.2640.364−∞0.8630.8520.4870.4080.627−∞−∞0.8410.4790.4130.612−∞−∞−∞0.2850.2010.379−∞−∞−∞−∞0.3840.169−∞−∞−∞−∞−∞0.546
A=softmax(Smasked)
对应代码:
attn_weights = torch.softmax(scores, dim=-1)
得到的注意力权重大约是:
A≈1.0000.4230.2700.2230.1860.15100.5770.3670.2760.2150.197000.3630.2740.2160.1940000.2260.1740.15300000.2100.124000000.181
这个矩阵就是真正的“读信息比例”。
例如第 2 行:
[0.423, 0.577, 0, 0, 0, 0]
表示第 2 个 token 在更新自己时:
- 从第 1 个 token 读取约 42.3% 的信息
- 从自己读取约 57.7% 的信息
- 完全不能读第 3 到第 6 个 token,因为它们属于未来位置
最后计算:
O=AV
对应代码:
context_vec = attn_weights @ V
结果约为:
O≈0.4300.4990.5250.4540.5210.4220.1500.5660.6680.6380.5510.6230.8900.7570.7150.6310.5240.551
其中第 2 个 token 的输出向量可以展开写成:
o2=0.423x(1)+0.577x(2)
代入数值:
o2=0.4230.430.150.89+0.5770.550.870.66≈0.4990.5660.757
这就说明:attention 的输出,本质上是“当前位置对过去若干 token 的 Value 做加权平均”。
- 第一,第 1 个 token 只能看到自己,所以它的输出和输入完全一样:
o1=x(1)
[0.151, 0.197, 0.194, 0.153, 0.124, 0.181]
- 说明它会综合前面所有 token,再形成自己的新表示。
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
[0.55, 0.87, 0.66],
[0.57, 0.85, 0.64],
[0.22, 0.58, 0.33],
[0.77, 0.25, 0.10],
[0.05, 0.80, 0.55]]
)
W_Q = torch.eye(3)
W_K = torch.eye(3)
W_V = torch.eye(3)
Q = inputs @ W_Q
K = inputs @ W_K
V = inputs @ W_V
scores = Q @ K.T / (K.shape[-1] ** 0.5)
T = scores.shape[0]
mask = torch.triu(torch.ones(T, T), diagonal=1)
scores = scores.masked_fill(mask.bool(), float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)
context_vec = attn_weights @ V
print("Q =\n", Q)
print("K =\n", K)
print("V =\n", V)
print("attention scores =\n", scores)
print("attention weights =\n", attn_weights)
print("context vectors =\n", context_vec)