揭秘大模型的魔法:RoPE(旋轉位置編碼)是怎么讓 AI 記住“前后左右”的?

大家好,我是寫代碼的中年人!
本章我們介紹RoPE的概念,為后續(xù)的模型實現(xiàn)打好基礎知識。
在大模型中,怎么知道一個詞在句子里是第幾個?這就像是:“你看一本書的時候,怎么知道這一句話是在第一頁還是最后一頁?”
嗯……要是沒有“位置感”,模型看到的只是一個個孤立的詞向量,就像你看一本書時只看到拆散的字母。
早期 Transformer 用的是正弦位置編碼(Sinusoidal Positional Encoding),相當于在每個單詞的向量里混入一點“坐標信號”,讓模型知道它在句子里的位置。
但這玩意有個小問題:它是直接加在詞向量上的,像在紙條上貼標簽模型能用,但對長文本、超長依賴不太友好,而且它的“旋律”是固定的,靈活性有限,后來人們發(fā)明了 RoPE(Rotary Position Embedding),中文名“旋轉位置編碼”,它干的事情很酷:不是直接給你貼標簽,而是給你旋轉一個角度。
01、什么是RoPE
比如我們在組織一場廣場舞,舞者(詞向量)站在一個圓圈上,每兩人一組。音樂響起時,每組舞者按自己的節(jié)奏旋轉:
第一組慢轉(低頻,0°、10°、20°……),第二組稍快(30°、60°、90°……),以此類推。
模型通過比較兩組舞者的相對旋轉角度(相位差),就能知道他們離得多遠,角度差越大,距離越遠。旋轉是連續(xù)的,即使新舞者加入(序列變長),他們也能按同樣規(guī)則旋轉,模型無需重新學習,就能判斷新舞者的位置。
這就是RoPE的厲害之處:通過旋轉的相位編碼位置,高效捕捉相對距離,還能適應超長序列!
02、RoPE數(shù)學公式
在自注意力里,Q(查詢向量)和 K(鍵向量)要做點積:

RoPE 的思路是:在做這個點積前,先把 Q、K 各自旋轉一下:

這里的 R(θ) 是一個二維旋轉矩陣(其實是對向量的每一對分量旋轉):

而 θ 跟位置 p 有關,比如:

α是不同維度的頻率參數(shù),就像正弦編碼那樣安排。
結果是什么?
相對位置關系自然保留:Q 和 K 的旋轉差值就是它們的相對位置
外推能力強:沒見過的長序列也能用,因為旋轉是周期性的
03、用代碼實現(xiàn)RoPE
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import jieba
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文字體
plt.rcParams['axes.unicode_minus'] = False
# ===== 準備《水滸傳》樣本文本 =====
text_samples = [
"""張?zhí)鞄熎盱烈撸樘菊`走妖魔。話說大宋天子仁宗皇帝在位年間,
京師瘟疫流行,百姓多有染病。天子召張?zhí)鞄熑雽m祈禳,命洪太尉押送香火,
不料誤開封印,放出妖魔。""",
"""王教頭私走延安府,九紋龍大鬧史家村。史進自幼好武,學成十八般武藝,
因打死惡霸,被官府緝拿。王進教頭見勢不妙,離開東京前往延安府,
途經(jīng)史家村。""",
"""史大郎夜走華陰縣,魯提轄拳打鎮(zhèn)關西。史進與魯達結義,路遇鎮(zhèn)關西鄭屠,
見其欺壓婦女,魯達憤然出手,三拳打死鄭屠,遂落草為寇。"""
]
# ===== 中文分詞 =====
def tokenize_texts(text_list):
tokenized = []
for t in text_list:
words = list(jieba.cut(t))
words = [w.strip() for w in words if w.strip()]
tokenized.append(words)
return tokenized
sentences = tokenize_texts(text_samples)
# ===== 構建詞表 =====
vocab = {}
for sent in sentences:
for w in sent:
if w not in vocab:
vocab[w] = len(vocab)
vocab["<PAD>"] = len(vocab)
vocab_size = len(vocab)
embed_dim = 32
seq_len = max(len(s) for s in sentences)
# 將句子轉為索引,并pad
def encode_sentences(sentences, vocab, seq_len):
data = []
for s in sentences:
idxs = [vocab[w] for w in s]
if len(idxs) < seq_len:
idxs += [vocab["<PAD>"]] * (seq_len - len(idxs))
data.append(idxs)
return torch.tensor(data)
input_ids = encode_sentences(sentences, vocab, seq_len)
# ===== RoPE實現(xiàn) =====
def apply_rope(x):
"""
支持輸入維度:
- (B, T, D) 或
- (B, T, H, D)
返回相同形狀,且對最后一維做 RoPE(要求 D 為偶數(shù))
"""
orig_shape = x.shape
if len(orig_shape) == 3:
# (B, T, D) -> 轉為 (B, T, 1, D) 方便統(tǒng)一處理
x = x.unsqueeze(2)
squeezed = True
else:
squeezed = False
# 形狀為 (B, T, H, D)
# 現(xiàn)在 x.shape = (B, T, H, D)
bsz, seqlen, nheads, head_dim = x.shape
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
device = x.device
dtype = x.dtype
half = head_dim // 2
# theta: (half,)
theta = 10000 ** (-torch.arange(0, half, device=device, dtype=dtype) / half) # (half,)
# seq positions: (seqlen,)
seq_idx = torch.arange(seqlen, device=device, dtype=dtype) # (seqlen,)
# freqs: (seqlen, half)
freqs = torch.einsum('n,d->nd', seq_idx, theta)
cos = freqs.cos().view(1, seqlen, 1, half) # (1, T, 1, half)
sin = freqs.sin().view(1, seqlen, 1, half) # (1, T, 1, half)
x1 = x[..., :half] # (B, T, H, half)
x2 = x[..., half:] # (B, T, H, half)
x_rotated = torch.cat([x1 * cos - x2 * sin,
x1 * sin + x2 * cos], dim=-1) # (B, T, H, D)
if squeezed:
x_rotated = x_rotated.squeeze(2) # back to (B, T, D)
return x_rotated
# ===== 多頭注意力 with RoPE =====
class MultiHeadSelfAttentionRoPE(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout = dropout
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.last_attn_weights = None
def forward(self, x):
B, T, C = x.size()
q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim)
v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim)
# 應用 RoPE
q = apply_rope(q)
k = apply_rope(k)
# 注意力計算
attn_scores = torch.einsum('bthd,bshd->bhts', q, k) / (self.head_dim ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
self.last_attn_weights = attn_weights.detach()
out = torch.einsum('bhts,bshd->bthd', attn_weights, v)
out = out.reshape(B, T, C)
return self.out_proj(out)
# ===== 模型訓練 =====
embedding = nn.Embedding(vocab_size, embed_dim)
model = MultiHeadSelfAttentionRoPE(embed_dim, num_heads=4, dropout=0.1)
criterion = nn.MSELoss()
optimizer = optim.Adam(list(model.parameters()) + list(embedding.parameters()), lr=1e-3)
epochs = 200
for epoch in range(epochs):
model.train()
x = embedding(input_ids)
target = x.clone()
out = model(x)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 50 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.6f}")
# ===== 注意力熱圖可視化 =====
def plot_attention(attn, sentence_tokens, filename):
heads = attn.shape[0]
fig, axes = plt.subplots(1, heads, figsize=(4*heads, 4))
if heads == 1:
axes = [axes]
for h in range(heads):
ax = axes[h]
attn_head = attn[h].numpy()
im = ax.imshow(attn_head, cmap='viridis')
ax.set_xticks(np.arange(len(sentence_tokens)))
ax.set_yticks(np.arange(len(sentence_tokens)))
ax.set_xticklabels(sentence_tokens, rotatinotallow=90)
ax.set_yticklabels(sentence_tokens)
ax.set_title(f"Head {h+1}")
fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.savefig(filename)
plt.close()
model.eval()
with torch.no_grad():
x = embedding(input_ids)
_ = model(x)
attn_weights = model.last_attn_weights # (batch, heads, seq, seq)
for i, tokens in enumerate(sentences):
attn = attn_weights[i]
plot_attention(attn.cpu(), tokens, f"rope_attention_sentence{i+1}.png")
print("RoPE多頭注意力熱圖已生成,文件名為 rope_attention_sentenceX.png")


結束語
經(jīng)過這次實戰(zhàn),我們不僅從《水滸傳》的古文中“偷”來了一點文學氣息,還把它喂進了現(xiàn)代的多頭自注意力網(wǎng)絡里,加上 RoPE 旋轉位置編碼,讓模型在捕捉長距離依賴關系時不再“迷路”。
通過可視化注意力熱圖,我們能直觀看到詞與詞之間的微妙聯(lián)系,就像在顯微鏡下觀察一場無聲的對話。這一切的意義,不僅僅是跑通了一段代碼,更是把理論、實現(xiàn)與效果驗證串成了一條完整的鏈條。
接下來,我們繼續(xù)探索更多高級技巧!技術的世界沒有終點,只有下一段旅程。

































