在RTX 4090被限制的時(shí)代下,讓大模型使用RLHF更高效的方法來(lái)了

- 論文鏈接:https://arxiv.org/abs/2310.10505
- 作者:李子牛,許天,張雨舜,俞揚(yáng),孫若愚,羅智泉
- 機(jī)構(gòu):香港中文大學(xué)(深圳),深圳市大數(shù)據(jù)研究院,南京大學(xué),南棲仙策
- 開源代碼:https://github.com/liziniu/ReMax
如未額外說(shuō)明,所有圖片來(lái)自于論文。
背景
今年,以 ChatGPT 為首的大語(yǔ)言模型(Large Language Models, LLMs) 在各個(gè)方面大放光彩,由此引發(fā)了學(xué)術(shù)界和商業(yè)界對(duì) GPU 等計(jì)算資源的需求劇增。

左圖來(lái)自 DALL?E3,右圖來(lái)自 DALL?E3
比如監(jiān)督訓(xùn)練地調(diào)優(yōu) (supervised fine-tuning, SFT) 一個(gè) Llama2-7B 的模型,需要消耗 80GB 以上的內(nèi)存。而這往往不夠,為了和人類對(duì)齊(alignment),大語(yǔ)言模型還要經(jīng)過(guò) RLHF (reinforcement learning from human feedback) 的訓(xùn)練。RLHF 的 GPU 消耗往往是 SFT 的 2 倍以上,訓(xùn)練時(shí)間更能達(dá)到 6 倍以上。
近日,美國(guó)政府宣布限制英偉達(dá) GPU 產(chǎn)品 H100, H800等進(jìn)入中國(guó)市場(chǎng)。這項(xiàng)條款無(wú)疑為中國(guó)發(fā)展大語(yǔ)言模型(LLMs) 和人工智能增添了很多阻力。減小 RLHF 的訓(xùn)練成本(GPU 消耗和訓(xùn)練時(shí)間)對(duì) LLMs 的發(fā)展非常重要。
動(dòng)機(jī)
RLHF 包含三個(gè)階段:
1. 監(jiān)督式地調(diào)優(yōu)(Supervised Fine-Tuning, SFT)。
2. 從對(duì)比數(shù)據(jù)中學(xué)習(xí)獎(jiǎng)勵(lì)模型(reward model)。
3. 利用強(qiáng)化學(xué)習(xí)(RL)算法來(lái)最大化獎(jiǎng)勵(lì)。

圖片來(lái)源自 InstructGPT 論文
我們發(fā)現(xiàn) RLHF 的主要計(jì)算開銷來(lái)源于第三階段(獎(jiǎng)勵(lì)最大化)。這一點(diǎn)可以從 DeepSpeed-Chat 的報(bào)告里看到,第三階段的訓(xùn)練時(shí)間是前兩個(gè)階段時(shí)間總和的 4 倍以上。而且,根據(jù)我們的經(jīng)驗(yàn),第三階段的 GPU 消耗是前兩階段的 2 倍以上。

圖片來(lái)自 DeepSpeed-Chat 技術(shù)報(bào)告
目前 RLHF 第 3 階段的主要計(jì)算瓶頸是什么?
我們發(fā)現(xiàn)該階段的計(jì)算瓶頸主要來(lái)源用來(lái)目前使用的 RL 算法:PPO 算法。PPO 算法是用來(lái)解決普適 RL 問(wèn)題的最流行的算法之一,有非常多成功的案例。我們?cè)谶@里省略 PPO 的技術(shù)細(xì)節(jié),著重介紹 PPO 的一個(gè)關(guān)鍵組件:價(jià)值模型 (The value model)。價(jià)值模型是一個(gè)需要被訓(xùn)練的神經(jīng)網(wǎng)絡(luò),能夠有效地估計(jì)給定策略的預(yù)期長(zhǎng)期回報(bào)。盡管價(jià)值模型為 PPO 帶來(lái)了良好的性能,但它在 RLHF 任務(wù)中也引入了沉重的計(jì)算開銷。例如,為了更好地與人類偏好對(duì)齊,PPO 中的價(jià)值模型通常與 LLM 大小相似,這使存儲(chǔ)需求翻了一番。此外,價(jià)值模型的訓(xùn)練需要存儲(chǔ)其梯度、激活和優(yōu)化器狀態(tài),這進(jìn)一步增加了近 4 倍的 GPU 存儲(chǔ)需求。總結(jié)來(lái)說(shuō),PPO 和它的價(jià)值模型(以及其訓(xùn)練相關(guān)部分)已成為 RLHF 獎(jiǎng)勵(lì)最大化階段的主要計(jì)算障礙。

相比 PPO,ReMax 是輕量級(jí)算法
思路
是否有可能找到比 PPO 更適配 RLHF 的算法?
我們得出的答案是肯定的。這是因?yàn)?PPO 和價(jià)值模型是為通用 RL 問(wèn)題設(shè)計(jì)的,而不是針對(duì)像 RLHF 這樣的特定問(wèn)題(RLHF 只是 RL 問(wèn)題中的一個(gè)子類)。有趣的是,我們發(fā)現(xiàn) RLHF 具有三個(gè)在 PPO 中未使用的重要結(jié)構(gòu):
1. 快速模擬(fast simulation): 軌跡(即 LLM 中的整個(gè)響應(yīng))可以在很短的時(shí)間內(nèi)迅速執(zhí)行(小于 1s),幾乎沒有時(shí)間開銷。
2. 確定性轉(zhuǎn)移(deterministic transitions):上下文確定性依賴于過(guò)去的標(biāo)記和當(dāng)前生成的標(biāo)記。
3. 軌跡級(jí)獎(jiǎng)勵(lì)(trajectory-level rewards):獎(jiǎng)勵(lì)模型只在響應(yīng)完成時(shí)提供一個(gè)獎(jiǎng)賞值。
通過(guò)這三個(gè)觀察,我們不難發(fā)現(xiàn) value model 在 RLHF 的問(wèn)題中是 “冗余” 的。這是因?yàn)?value model 設(shè)計(jì)的初衷是為了隨機(jī)環(huán)境下的樣本效率和慢仿真環(huán)境的計(jì)算效率。然而這在 RLHF 中是不需要的。

ReMax 是針對(duì) RLHF 設(shè)計(jì)的算法,PPO 則是為通用 RL 設(shè)計(jì)的算法
方法
ReMax
ReMax 算法基于一個(gè)古老的策略梯度算法 REINFORCE,REINFORCE 使用的策略梯度估計(jì)器如下圖所示:

REINFORCE 梯度估計(jì)器
REINFORCE可以在計(jì)算層面利用好RLHF任務(wù)的三個(gè)性質(zhì),因?yàn)镽EINFORCE直接利用一個(gè)響應(yīng)的獎(jiǎng)勵(lì)來(lái)進(jìn)行優(yōu)化,不需要像一般的RL算法一樣需要知道中間步驟的獎(jiǎng)勵(lì)和值函數(shù)。然而,由于策略的隨機(jī)性, REINFORCE梯度估計(jì)器存在高方差問(wèn)題(在Richard Sutton的RL書里有指出),這一問(wèn)題會(huì)影響模型訓(xùn)練的有效性,因此REINFORCE在RLHF任務(wù)中的效果較差,見下面兩張圖片。

REINFORCE 的計(jì)算代價(jià)小,但性能差

REINFORCE 的(隨機(jī))梯度值遠(yuǎn)遠(yuǎn)大于 ReMax
為解決這一問(wèn)題,ReMax 使用貪婪生成的回答(greedy response)的獎(jiǎng)勵(lì)作為基準(zhǔn)值(baseline value)來(lái)構(gòu)建梯度估計(jì)器,具體公式如下:

ReMax 梯度估計(jì)器
注意到,貪婪回復(fù)的獎(jiǎng)勵(lì)
可以看作為期望獎(jiǎng)勵(lì)
的好的近似。在理想情形下(
),對(duì)于隨機(jī)變量
,
,因此我們能夠期望估計(jì)器
具有更小的方差。
下圖展示了 ReMax 的算法流程,紅色方框中的是核心算法改變。

ReMax 算法流程
理論保證
我們證明了 ReMax 使用的梯度估計(jì)器仍然是真實(shí)策略梯度的一個(gè)無(wú)偏估計(jì)器。
詳細(xì)理論介紹見論文。
算法優(yōu)點(diǎn)
- ReMax 的核心部分可以用 6 行代碼來(lái)實(shí)現(xiàn)。相比之下,PPO 要額外引入重要性采樣(importance sampling),廣義優(yōu)勢(shì)估計(jì)(generalized advantage estimation,GAE),價(jià)值模型學(xué)習(xí)等額外模塊。
- ReMax 的超參數(shù)很少。相比之下,PPO 有額外的超參數(shù),例如重要性采樣剪切閾值(importance sampling clipping ratio)、GAE 系數(shù)、價(jià)值模型學(xué)習(xí)率,離策略訓(xùn)練輪次(off-policy training epoch)等,這些超參數(shù)都需要花大量時(shí)間去調(diào)優(yōu)。
- ReMax 能理論上節(jié)省約 50% 內(nèi)存。相比于 PPO,ReMax 成功移除了所有和價(jià)值模型相關(guān)的部件,大大減小了內(nèi)存開銷。通過(guò)計(jì)算,我們發(fā)現(xiàn)相比于 PPO,ReMax 能節(jié)省約 50% 內(nèi)存。
效果
有效性
- ReMax 可以像 PPO 一樣有效地最大化獎(jiǎng)勵(lì)

在 OPT-1.3B 上,ReMax 可以有效地最大化獎(jiǎng)勵(lì)

在 OPT-1.3B 上,ReMax 的訓(xùn)練非常穩(wěn)定
- 在 GPT-4 評(píng)估下(LIMA Test Questions),ReMax 得到的策略比 SFT 和 PPO 會(huì)更好

GPT4 打分顯示 ReMax 得到的模型會(huì)更好
高效性
- ReMax 能節(jié)省近 50% 的 GPU 內(nèi)存。ReMax 移除掉了價(jià)值模型和它的訓(xùn)練部分(梯度,優(yōu)化器,激活值),從而極大節(jié)省了 GPU 內(nèi)存需求。考慮 Llama2-7B,PPO 無(wú)法在 8xA100-40GB 的機(jī)器上跑起來(lái),但是 ReMax 可以。

在 Llama2-7B 上,ReMax 可以節(jié)省近 50% 的 GPU 內(nèi)存
- ReMax 能加快 2 倍的訓(xùn)練速度。在每一輪中,ReMax 調(diào)用 2 次生成(generation),1 次反向傳播(backpropagation);而 PPO 使用 1 次生成,2 次反向傳播。對(duì)于大模型而言,生成會(huì)比反向傳播的時(shí)間小,從而 ReMax 可以實(shí)現(xiàn)理論上接近 2 倍的訓(xùn)練加速。
通用性
除了 RLHF 任務(wù),作為一個(gè) RL 算法,ReMax 對(duì)于經(jīng)典的 NLP 任務(wù)也適用。本文考慮了在 GPT-2 上進(jìn)行一個(gè)電影評(píng)論續(xù)寫的任務(wù),這里獎(jiǎng)勵(lì)模型不是從對(duì)比數(shù)據(jù)學(xué)習(xí)的。實(shí)驗(yàn)觀測(cè)到,ReMax 可以實(shí)現(xiàn) 2.2 倍的訓(xùn)練加速和 60% 的 GPU 內(nèi)存節(jié)省。

在經(jīng)典的 NLP 任務(wù)(文本續(xù)寫)上,ReMax 相比 PPO 實(shí)現(xiàn)了 2.2 倍加速
總結(jié)
最后,我們從實(shí)驗(yàn)中簡(jiǎn)要總結(jié)了 ReMax 相對(duì)于 PPO 的主要優(yōu)勢(shì)。
- 更簡(jiǎn)單的實(shí)現(xiàn): ReMax 的核心部分 6 行代碼即可實(shí)現(xiàn)。這與 PPO 中的眾多復(fù)雜的代碼構(gòu)建塊形成鮮明對(duì)比。
- 更少的內(nèi)存開銷:由于移除了價(jià)值模型及其全部訓(xùn)練組件,相比 PPO,ReMax 節(jié)省了大約 50% 的 GPU 內(nèi)存。
- 更少的超參數(shù): ReMax 成功移除了所有和價(jià)值模型訓(xùn)練相關(guān)的超參數(shù),其中包括:GAE 系數(shù)、價(jià)值模型學(xué)習(xí)率、重要性采樣時(shí)期、小批量(mini-batch)大小。這些超參數(shù)往往對(duì)問(wèn)題敏感且難以調(diào)整。我們相信 ReMax 對(duì) RLHF 研究者更加友好。
- 更快的訓(xùn)練速度:在 GPT2(137M)的實(shí)驗(yàn)中,我們觀察到 ReMax 在真實(shí)運(yùn)行時(shí)間方面相比于 PPO 有 2.2 倍的加速。加速來(lái)自 ReMax 每次迭代中較少的計(jì)算開銷。通過(guò)我們的計(jì)算,該加速優(yōu)勢(shì)在更大的模型上也能維持(假設(shè)在足夠大的內(nèi)存下 PPO 可以被成功部署)。
- 優(yōu)異的性能:如前所示,ReMax在中等規(guī)模實(shí)驗(yàn)中與PPO實(shí)現(xiàn)了相當(dāng)?shù)男阅埽⑶矣袝r(shí)甚至超越它(可能是由于 ReMax 更容易找到合適的超參數(shù))。我們推測(cè)這種良好的性能可以拓展到更大規(guī)模的模型中。





























