The Annotated Transformer注釋加量版,讀懂代碼就真的懂了Transformer 原創(chuàng)
本文是在The Annotated Transformer這篇文章基礎(chǔ)上的二次加工。
1.給代碼加了更詳細(xì)的注釋。
2.輸出詳細(xì)日志跟蹤數(shù)據(jù)。
原文地址:https://nlp.seas.harvard.edu/annotated-transformer/
或者后臺(tái)回復(fù)taf獲取pdf下載鏈接。
The Andnotated Transformer
Attention is All You Need

- v2022: Austin Huang, Suraj Subramanian, Jonathan Sum, Khalid Almubarak, and Stella Biderman.
- Original: Sasha Rush
閱讀方法
由于原文內(nèi)容過長(zhǎng),我沒有把原文拷貝過來,閱讀本文時(shí),請(qǐng)打開原文鏈接或者我添加注釋的notebook。
1、給代碼加了更詳細(xì)的注釋。
原文基于pytorch從0開始復(fù)現(xiàn)了transformer模型,我在原文代碼基礎(chǔ)上追加了更詳細(xì)的注釋,代碼可以在下面鏈接找到。https://github.com/AIDajiangtang/annotated-transformer/blob/master/AnnotatedTransformer_comment.ipynb
另外,我還在模型結(jié)構(gòu)上加了注釋,我將代碼中重要的類名或者函數(shù)名標(biāo)注在Transforner結(jié)構(gòu)的圖片上,閱讀代碼時(shí)請(qǐng)結(jié)合圖片上的名稱,這樣有助于快速理解代碼。

2、輸出日志跟蹤數(shù)據(jù)。
原文提供了一個(gè)訓(xùn)練德譯英模型的代碼,我在此基礎(chǔ)上加了一些日志,打印數(shù)據(jù)的維度來輔助對(duì)Transformer的理解。

我將按照?qǐng)D片上標(biāo)注數(shù)字順序來跟蹤數(shù)據(jù)。
原始論文中,Transformer是一種Encoder-Decoder架構(gòu),左邊是Encoder,用于提取源語(yǔ)言的表征,右邊是Decoder,根據(jù)表征結(jié)合目標(biāo)語(yǔ)言語(yǔ)法生成目標(biāo)語(yǔ)言。
先從Encoder這邊開始。
0、Inputs:
假設(shè)batch size為2,所以每個(gè)batch包含兩個(gè)樣本,每個(gè)樣本由(德語(yǔ),英語(yǔ))文本對(duì)組成。
[
('Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.', 'A large group of young adults are crammed into an area for entertainment.'),
('Zwei Arbeiter stellen Laternen auf.', 'Two workers working on putting up lanterns.')
](batch size的意義:模型每次都是基于batch size個(gè)樣本的損失來更新參數(shù),batch size需要根據(jù)內(nèi)存,顯存大小確定)
對(duì)于Encoder而言,它只需要源語(yǔ)言,也就是德語(yǔ)。
'Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.'
'Zwei Arbeiter stellen Laternen auf'
1、Embedding:
1.1.先將文本轉(zhuǎn)換成tokens,并添加起始和結(jié)束符token。
(load_tokenizers函數(shù),
tokenize函數(shù),
build_vocabulary函數(shù)}
["<s>", "</s>", "<blank>", "<unk>"]起始符token id:0,結(jié)束符token id:1,padding token id:2
'Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.'的tokens如下
torch.Size([11])
tensor([ 0, 14, 176, 38, 683, 7, 6, 116, 7147, 4, 1],
device='cuda:0')(通過結(jié)果看是基于詞的tokenization方法)
1.2.因?yàn)槲谋鹃L(zhǎng)度不一致,通過padding的方式將序列長(zhǎng)度統(tǒng)一為72。
{collate_batch函數(shù)}
(padding不是必須的,只是出于方便和效率考慮,72是個(gè)經(jīng)驗(yàn)值,通過對(duì)訓(xùn)練數(shù)據(jù)的統(tǒng)計(jì)得出)
'Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.'padding后的tokens如下
torch.Size([72])
[tensor([ 0, 14, 176, 38, 683, 7, 6, 116, 7147, 4, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
device='cuda:0')]一個(gè)batch下有兩個(gè)樣本,對(duì)另一個(gè)樣本的德語(yǔ)進(jìn)行同樣的轉(zhuǎn)換最終得到編碼器輸入:X,維度[2, 72]。
在訓(xùn)練過程中,無論是計(jì)算注意力還是交叉注意力,每個(gè)樣本是相互獨(dú)立的,所以可以將一個(gè)batch下所有數(shù)據(jù)組織成矩陣的形式輸入到模型進(jìn)行并行計(jì)算。
1.3.最后將上一步的tokens通過一個(gè)Embedding線性層轉(zhuǎn)換成詞嵌入,設(shè)置d_model=512,所以詞嵌入維度為512。
{Embeddings類}
Embedding層輸入就是前面的X;維度是torch.Size([2, 72])。
Embedding層的輸出維度是torch.Size([2, 72,512]),也就是每個(gè)token id都被轉(zhuǎn)換成512維的向量。
tensor([[[-0.6267, -0.0099, 0.3444, ..., 0.5949, -0.4107, -0.6037],
[ 0.4183, -0.1788, -0.3128, ..., 0.5363, -0.5519, 0.4621],
[ 0.4645, -0.2748, -0.4109, ..., -0.6270, 0.4595, -0.4259],
...,
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066],
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066],
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066]],
[[-0.6267, -0.0099, 0.3444, ..., 0.5949, -0.4107, -0.6037],
[-0.2121, 0.4323, -0.0869, ..., 0.1337, -0.2679, -0.4689],
[ 0.0751, -0.1048, -0.1263, ..., -0.5541, -0.4463, 0.5209],
...,
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066],
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066],
[-0.1489, 0.6431, -0.0301, ..., -0.0163, 0.4261, 0.3066]]],
device='cuda:0', grad_fn=<MulBackward0>)(Embedding過程相當(dāng)于用512個(gè)屬性值表示單詞的語(yǔ)義信息,經(jīng)過每個(gè)EncoderLayner時(shí)屬性值會(huì)被修改,使其充分吸收上下文信息,屬性越多,能表示的語(yǔ)音信息越豐富,但計(jì)算量和參數(shù)也會(huì)增加)
2、PositionalEncoding
{PositionalEncoding類}
在計(jì)算注意力分?jǐn)?shù)時(shí),如果調(diào)整單詞的位置,注意力的輸出結(jié)果不變,也就是自注意力這種計(jì)算方式?jīng)]有考慮單詞的位置信息。
所以需要通過一個(gè)額外的位置編碼,位置編碼與詞嵌入維度相同,也是512維向量,最后與詞嵌入相加。
前面Embedding層輸出維度torch.Size([2, 72, 512]),將其與位置編碼相加,輸出也是torch.Size([2, 72, 512])。
(位置編碼可以通過訓(xùn)練方法得到,也可以采用固定計(jì)算方式,本例采用固定計(jì)算方式)
所有樣本共用同一個(gè)位置編碼,本例序列長(zhǎng)度為72,可以提前計(jì)算好位置編碼備用。

pos表示位置,第一個(gè)詞位置是0,第二個(gè)詞位置是1....本例中就是0-71。
對(duì)于512維向量,偶數(shù)位置和奇數(shù)位置的值分別用上面兩個(gè)公式計(jì)算。
tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,
0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 8.2186e-01, ..., 1.0000e+00,
1.0366e-04, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 9.3641e-01, ..., 1.0000e+00,
2.0733e-04, 1.0000e+00],
...,
[-8.9793e-01, 4.4014e-01, 3.6763e-01, ..., 9.9997e-01,
7.0490e-03, 9.9998e-01],
[-1.1478e-01, 9.9339e-01, -5.5487e-01, ..., 9.9997e-01,
7.1527e-03, 9.9997e-01],
[ 7.7389e-01, 6.3332e-01, -9.9984e-01, ..., 9.9997e-01,
7.2564e-03, 9.9997e-01]]], device='cuda:0')可視化出來就是下面效果。

(上圖每一行都是一個(gè)位置編碼向量,一共生成50個(gè)位置編碼,每個(gè)位置編碼是128維向量,而本例需要生成72個(gè),每個(gè)512維)
3.MultiHeadedAttention
{MultiHeadedAttention類,
attention函數(shù)}

MultiHeadedAttention類的輸入是query, key, value,維度都是torch.Size([2, 72, 512]),其實(shí)他們的內(nèi)容也是一樣的,就是上一步輸出的Embedding+位置編碼。
然后query, key, value分別經(jīng)過一個(gè)獨(dú)立的線性層,線性層的維度[512, 512],兩個(gè)樣本的[72, 512]分別與[512, 512]矩陣乘法,所以線性層的輸出維度仍是[2, 72, 512],最后經(jīng)過reshape和轉(zhuǎn)置將[2, 72, 512]轉(zhuǎn)換成torch.Size([2, 8, 72, 64]),8代表有8個(gè)頭,其實(shí)就是將512轉(zhuǎn)換成了8*64來實(shí)現(xiàn)多頭注意力機(jī)制。
(雖然是8個(gè)頭,但與一個(gè)頭的情況相比,參數(shù)并沒有增加)
接下來計(jì)算單個(gè)頭的注意力,Attention函數(shù)的輸入query, key, value的維度都是torch.Size([2, 8, 72, 64]),注意力分?jǐn)?shù)矩陣維度torch.Size([2, 8, 72, 72]),輸出torch.Size([2, 8, 72, 64])。

最后將多個(gè)頭的輸出拼接在一起,也就是通過reshape和轉(zhuǎn)置將torch.Size([2, 8, 72, 72])轉(zhuǎn)換成[2, 72, 512],最后經(jīng)過一個(gè)[512, 512]的線性層輸出[2, 72, 512]。
4、SublayerConnection
{SublayerConnection類}
將多頭注意力的輸出經(jīng)過層歸一化和輸入進(jìn)行殘差鏈接,不改變維度,輸入輸出都是[2, 72, 512]。
5、PositionwiseFeedForward
{PositionwiseFeedForward類}

這其實(shí)是一個(gè)MLP層,輸入維度512,隱藏層維度2048,輸出層維度512,也就是2*72個(gè)tokens并行與[512, 2048]矩陣乘升維至[2, 72, 2048],然后再與矩陣[2048,512]乘恢復(fù)到原來維度[2, 72, 512]。最后再經(jīng)過層歸一化和殘差鏈接。
6、EncoderLayer
{EncoderLayer類}
將3,4,5重復(fù)6次,這里需要注意下,這6個(gè)EncoderLayer只是結(jié)構(gòu)一致,但參數(shù)是獨(dú)立的,原始的Embedding經(jīng)過6個(gè)EncoderLayer后維度是不變的,仍然是[2, 72, 512],只不過內(nèi)容被改變了。
7、LaynerNorm
{LayerNorm類}
為了計(jì)算穩(wěn)定,整個(gè)Encoder的輸出會(huì)再次經(jīng)過層歸一化處理,然后輸入到Decoder層作為key和value,維度仍然是[2, 72, 512]。
Encoder把key和value傳遞給Decoder,它的使命就算完成了。剩下的就是根據(jù)那邊的損失等著更新參數(shù)了。
讓我們來到Decoder這邊。
0、Inputs:
[
('Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.', 'A large group of young adults are crammed into an area for entertainment.'),
('Zwei Arbeiter stellen Laternen auf.', 'Two workers working on putting up lanterns.')
]對(duì)于Decoder,除了Encoder的key和value,還要有query,這個(gè)query就是目標(biāo)語(yǔ)言,也就是英語(yǔ)。
'A large group of young adults are crammed into an area for entertainment.'
'Two workers working on putting up lanterns.'
1、Embedding
Decoder和Encoder的Embedding幾乎一致,也是先轉(zhuǎn)換成tokens。
'A large group of young adults are crammed into an area for entertainment.'->tokens
torch.Size([16])
tensor([ 0, 6, 62, 39, 13, 25, 348, 17, 5318, 71, 28, 179,
55, 4285, 5, 1], device='cuda:0')然后進(jìn)行padding。
'A large group of young adults are crammed into an area for entertainment.'->padding tokens
torch.Size([72])
[tensor([ 0, 6, 62, 39, 13, 25, 348, 17, 5318, 71, 28, 179,
55, 4285, 5, 1, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
device='cuda:0')]對(duì)另一個(gè)樣本進(jìn)行同樣的操作得到編碼器的輸入Y,維度[2, 72]。
最后將其轉(zhuǎn)換成Embedding,維度是torch.Size([2, 72, 512])。
但有一點(diǎn)需要注意。
Decoder在訓(xùn)練時(shí)輸入的是整個(gè)batch的英語(yǔ)文本,也就是torch.Size([2, 72, 512])。
但在訓(xùn)練過程中預(yù)測(cè)當(dāng)前token的輸出時(shí),為了讓其只能看到當(dāng)前以及之前位置的輸入,避免看到后面的內(nèi)容,需要采用遮罩的方式,也就是要構(gòu)造一個(gè)mask。
torch.Size([2, 72, 72])
tensor([[[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False]],
[[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False]]], device='cuda:0')
2、PositionalEncoding
與Encoder一樣,輸入輸出都是[2, 72, 512]
3、MultiHeadedAttention
Decoder中的DecoderLayner有兩個(gè)MultiHeadedAttention,第一個(gè)是Mask MultiHeadedAttention,與Encoder中的計(jì)算一致,只不過使用了上一步計(jì)算的Mask。
另一個(gè)MultiHeadedAttention中的key和value來自Encoder,我們稱之為交叉注意力,與自注意力要區(qū)分開,query來自前一層的輸出,維度都是[2, 72, 512]。
4,5,9,7和Encoder都是一樣的。
同樣輸入Embedding經(jīng)過6個(gè)DecoderLayner后維度不變[2, 72, 512]。
4、Generator
{Generator類}
這其實(shí)是一個(gè)沒有隱藏層的MLP,輸入維度512,輸出維度vocab,2*72個(gè)token的Embedding與矩陣[512,vocab]相乘,輸出[2, 72, vocab],vocab為詞表的單詞個(gè)數(shù),本例中英語(yǔ)單詞個(gè)數(shù)為6291。經(jīng)過softmax后輸出一個(gè)概率分布,最大概率對(duì)應(yīng)的位置的詞就是模型預(yù)測(cè)的下一個(gè)詞。
這樣就得到了Decoder的最終輸出,輸出可以是[2, 72],里面是英語(yǔ)詞表下的id。也可以是[2, 72, vocab]直接輸出概率分布,輸出形式不同,損失函數(shù)也是不同的。
對(duì)于其中一個(gè)樣本,訓(xùn)練過程中Decoder的輸入是:
'A large group of young adults are crammed into an area for entertainment.'
torch.Size([72])
[tensor([ 0, 6, 62, 39, 13, 25, 348, 17, 5318, 71, 28, 179,
55, 4285, 5, 1, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
device='cuda:0')]如果想更新參數(shù)就必須計(jì)算損失,計(jì)算損失就必須有標(biāo)簽,那標(biāo)簽是什么?
對(duì)于Decoder,輸入也是輸出,標(biāo)簽就是將輸入向左移動(dòng)了一位:
[tensor([ 6, 62, 39, 13, 25, 348, 17, 5318, 71, 28, 179,
55, 4285, 5, 1, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,2],
device='cuda:0')]也就是起始符0對(duì)應(yīng)的標(biāo)簽是A:6,輸入A對(duì)應(yīng)的標(biāo)簽是large:62,Decoder輸出維度[2, 72],標(biāo)簽維度也是[2, 72],最后通過均方誤差計(jì)算損失,或者輸出概率分布,通過KL損失函數(shù)計(jì)算損失來更新Decoder和Encoder的參數(shù)。
再?gòu)?qiáng)調(diào)一下,整個(gè)batch下所有數(shù)據(jù)是一起輸入到模型的,也就是通過將數(shù)據(jù)組織成矩陣實(shí)現(xiàn)了整個(gè)batch的數(shù)據(jù)并行計(jì)算。
訓(xùn)練完成后,就可以用它進(jìn)行德譯英翻譯了。
假設(shè)輸入這么一句德語(yǔ)。
'Eine gro?e Gruppe Jugendlicher in einem kleinen Unterhaltungsbereich.'
德語(yǔ)先經(jīng)過Encoder進(jìn)行并行編碼,輸出[1, 72, 512]作為Decoder的value和key。
在推理過程中就Deocder就不能并行計(jì)算了,只能自回歸的方式每次前向計(jì)算只產(chǎn)生一個(gè)token。
剛開始只有一個(gè)起始符token 0輸入到Deocder。
[tensor([ 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
device='cuda:0')]decoder輸出6,將6加到0后面再次輸入到decoder。
[tensor([ 0, 6, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
device='cuda:0')]decoder輸出62,以此類推,直到輸出終止符token 1。
人的大腦在學(xué)習(xí)復(fù)雜事物時(shí),往往習(xí)慣使用一種整體到細(xì)節(jié),抽象到具體的漸進(jìn)的方式。
雖然我在作者的源代碼添加了更多的注釋和維度信息,但它仍然是細(xì)節(jié),為了更好地理解大模型的工作原理,我建議先閱讀我之前的圖解和動(dòng)畫Transformer系列,以次獲得對(duì)Transformer有一個(gè)高層次的認(rèn)知。
另外,如果你如果弄明白了Encoder-Decoder架構(gòu),那么就能輕松搞懂GPT和BERT了,因?yàn)樗鼈円粋€(gè)只用了Encoder,另一個(gè)只用了Decoder。
本文轉(zhuǎn)載自公眾號(hào)人工智能大講堂

















