揭秘大模型的魔法:實現帶有可訓練權重的多頭自注意力機制

大家好,我是寫代碼的中年人!
自注意力(Self-Attention)是大模型里最常讓人“眼花”的魔術道具:看起來只是一堆矩陣乘法和 softmax,可是組合起來就能學到“句子里誰重要、誰次要”的規則,甚至能學到某些頭只盯標點、某些頭專盯主謂關系。
今天我想把這塊魔術板拆開來給你看個究竟:如何把單頭注意力改成多頭注意力,讓每個頭能學會自己的注意力分布。
01、回顧單頭自注意力機制
假設你在開會,桌上有一堆文件,你想找跟“項目進度”相關的內容。
你心里有個問題(Query):“項目進度在哪兒?
”每份文件上有個標簽(Key),寫著它的主題,比如“預算”“進度”“人員”。
你會先挑出標簽里跟“進度”相關的文件(匹配),然后重點看這些文件的內容(Value),最后把這些內容總結成你的理解。
自注意力就像是給每個詞都做了一次這樣的“信息篩選和總結”,讓每個詞都能根據上下文更好地表達自己。
02、理解多頭自注意力機制
繼續用開會的場景:
桌上還是那堆文件(代表句子里的詞),但現在你不是一個人干活,而是找了3個助手(假設3頭注意力)。每個助手都有自己的“專長”,他們會從不同的角度問問題、匹配標簽和提取內容。
每個頭獨立工作(多視角篩選):
頭1(進度專家):他的問題(Query)是“進度怎么樣?”他只關注標簽里跟“進度”“時間表”相關的文件,忽略其他。挑出匹配的文件后,他總結出一份“進度報告”。
頭2(預算專家):他的問題是“預算超支了嗎?”他匹配標簽里的“預算”“開銷”,然后從那些文件的內容里提煉“預算分析”。
頭3(風險專家):問題是“有什么隱患?”他找“風險”“問題”相關的標簽,輸出一份“風險評估”。
每個頭都像單頭注意力一樣:生成自己的問題、鑰匙和內容,計算匹配度,加權總結。但他們用的“眼鏡”不同(在機器里,這通過不同的線性變換實現),所以捕捉的信息側重點不一樣。
把多頭結果合起來(綜合決策):
一旦每個頭都給出自己的總結,你就把這些報告拼在一起(或簡單平均一下),形成一份完整的“項目概覽”。現在,你的理解不只是“進度”,而是進度+預算+風險的全方位視圖。萬一某個頭漏了什么,其他頭能補上,確保沒死角。
03、用代碼實現多頭自注意力機制
# ONE
我們使用水滸傳的內容進行演示,使用前三回各 100 字的文本,并按“字”切分成模型可用的格式。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# ====== 準備水滸傳真實語料 ======
raw_texts = [
"話說大宋仁宗天子在位,嘉祐三年三月三日五更三點,天子駕坐紫宸殿,受百官朝賀。但見:祥雲迷鳳閣,瑞氣罩龍樓。含煙御柳拂旌旗,帶露宮花迎劍戟。天香影裏,玉簪珠履聚丹墀。仙樂聲中,繡襖錦衣扶御駕。珍珠廉卷,黃金殿上現金輿。鳳尾扇開,白玉階前停寶輦。隱隱凈鞭三下響,層層文武兩班齊。",
"那高俅在臨淮州,因得了赦宥罪犯,思鄉要回東京。這柳世權卻和東京城里金梁橋下開生藥鋪的董將士是親戚,寫了一封書札,收拾些人事盤纏,赍發高俅回東京,投奔董將士家過活。",
"話說當時史進道:「卻怎生是好?」朱武等三個頭領跪下答道:「哥哥,你是乾淨的人,休為我等連累了大郎。可把索來綁縛我三個,出去請賞,免得負累了你不好看?!?
]
# ====== 按字切分 ======
def char_tokenize(text):
return [ch for ch in text if ch.strip()] # 去掉空格、換行
sentences = [char_tokenize(t) for t in raw_texts]
# 構建詞表
vocab = {}
for sent in sentences:
for ch in sent:
if ch not in vocab:
vocab[ch] = len(vocab)
# ====== 轉成索引形式并做 padding ======
max_len = max(len(s) for s in sentences)
PAD_TOKEN = "<PAD>"
vocab[PAD_TOKEN] = len(vocab)
input_ids = []
for sent in sentences:
ids = [vocab[ch] for ch in sent]
# padding
ids += [vocab[PAD_TOKEN]] * (max_len - len(ids))
input_ids.append(ids)
input_ids = torch.tensor(input_ids) # (batch_size, seq_len)
# ====== 多頭自注意力模塊 ======
class MultiHeadSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim 必須能整除 num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = dropout
self.last_attn_weights = None # 保存最后一次注意力權重 (batch, heads, seq, seq)
def forward(self, x):
B, T, C = x.size()
Q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
self.last_attn_weights = attn_weights.detach() # (B, heads, T, T)
out = torch.matmul(attn_weights, V)
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.out_proj(out)
return out
# ====== 模型訓練 ======
embed_dim = 32
num_heads = 4
vocab_size = len(vocab)
embedding = nn.Embedding(vocab_size, embed_dim)
model = MultiHeadSelfAttention(embed_dim, num_heads)
criterion = nn.MSELoss()
optimizer = optim.Adam(list(model.parameters()) + list(embedding.parameters()), lr=1e-3)
epochs = 200
for epoch in range(epochs):
model.train()
x = embedding(input_ids)
target = x.clone()
out = model(x)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1:3d}, Loss: {loss.item():.6f}")
# ====== 可視化注意力熱圖 ======
for idx, sent in enumerate(sentences):
attn = model.last_attn_weights[idx] # (heads, seq, seq)
sent_len = len(sent)
for head in range(num_heads):
plt.figure(figsize=(8, 6))
plt.imshow(attn[head, :sent_len, :sent_len].numpy(), cmap='viridis')
plt.title(f"第{idx+1}句 第{head+1}頭 注意力矩陣")
plt.xticks(ticks=np.arange(sent_len), labels=sent, rotatinotallow=90)
plt.yticks(ticks=np.arange(sent_len), labels=sent)
plt.xlabel("Key (字)")
plt.ylabel("Query (字)")
plt.colorbar(label="Attention Strength")
for i in range(sent_len):
for j in range(sent_len):
plt.text(j, i, f"{attn[head, i, j]:.2f}", ha="center", va="center", color="white", fnotallow=6)
plt.tight_layout()
plt.savefig(f"attention_sentence{idx+1}_head{head+1}.png")
plt.close()
print("注意力熱圖已保存。")











這些多頭自注意力(Multi-Head Self-Attention)的熱圖,其實是一個“誰在關注誰”的可視化工具,用來直觀展示模型在處理文本時的注意力分布。
熱圖上的顏色:
橫軸(Key):表示句子中被關注的字,
縱軸(Query):表示當前在思考的字,
顏色深淺:表示注意力強度,越亮的地方代表這個 Query 在計算時更關注這個 Key。
例如,如果“宋”字在看“天”字時顏色很亮,說明模型覺得“天”這個字對理解“宋”有重要信息。因為是古文,有時模型會捕捉到常見的修辭搭配,比如“天子”“鳳閣”,這時候相鄰的字之間注意力會很高。
為什么會有多張圖:
每一行熱圖對應一句文本(水滸前三回的一個片段)
每句話會畫多個頭的熱圖:
多頭機制的設計就是讓不同的頭學習到不同的關注模式
舉個例子:
Head 1 可能更多關注相鄰的字(局部模式)
Head 2 可能更關注句首或特定關鍵詞(全局模式)
Head 3 可能專注某個語法結構
Head 4 可能專注韻律、排比等古文特性
多頭機制就像多雙眼睛,從不同角度觀察同一句話。
舉個大家都能理解的例子:
學生(Query):舉手發言
老師(Attention):環顧四周,看看應該關注哪個學生(Key)
不同的老師(Head)關注點不同:一個老師喜歡看前排學生(局部依賴)一個老師總是看坐在角落的安靜同學(遠距離依賴)還有老師會特別注意那些名字里有“天”“龍”這些關鍵字的學生
(關鍵觸發詞)顏色越亮,表示老師對這個學生說的話越感興趣。
結束語
回到開頭我們的問題:多頭自注意力到底在看什么?通過水滸傳這樣真實、結構獨特的古文片段,我們不僅看到了模型如何在字與字之間建立聯系,還直觀感受了不同“注意力頭”各自的關注模式。有人關注近鄰字,有人專注關鍵字,有人把目光投向整句的節奏與意境。
這就像課堂上不同的老師一樣——他們的視角不同,但共同構成了對整篇文章的完整理解。這種可視化,不只是為了“看個熱鬧”,而是把模型內部的決策過程攤開給人看,讓深度學習的“黑箱”多了一點可解釋性。
至此,我們用水滸的詩意古文,讓多頭自注意力的數學公式“活”了起來。接下來,我們將整合所有已學過的文章,去實現一個生成模型。


































