Meta發表的將系統2模型蒸餾至系統1模型

一、結論寫在前面
論文標題:Distilling System 2 into System 1
論文鏈接:??https://arxiv.org/pdf/2407.06023v2??
LLMs在推理過程中可以額外消耗計算資源來生成中間思維,這有助于產生更好的最終響應。自思維鏈以來,已經提出了許多此類系統2技術,例如重述與響應(Rephrase and Respond )、系統2注意力(System 2 Attention)和分支-解決-合并(Branch-Solve-Merge)。
論文研究了自監督方法(self-supervised),將系統2技術的高質量輸出“編譯”(蒸餾,distill)回LLM生成中,而不需要中間推理token序列,因為這種推理已經被蒸餾到系統1中。
論文進行了跨4種不同System 2 LLM方法和5種不同任務的實驗。論文發現,論文的方法能夠在多種環境下將System 2推理蒸餾為System 1,有時甚至能超越System 2教師模型的效果。此外,這些預測現在以極低的計算成本生成。例如,論文在處理偏見觀點或無關信息的任務(System 2注意力)、澄清和改進某些推理任務的響應(重述與回應)以及對LLM進行細粒度評估(分支-解決-合并)方面看到了成功的蒸餾。
然而,論文也表明并非所有任務都能蒸餾到System 1,特別是需要鏈式思維的復雜數學推理任務。這一點在人類中也得到了體現,有些任務沒有刻意的System 2推理是無法執行的。
二、論文的簡單介紹
2.1 論文的背景
人類 System 1 System 1推理被描述為能夠識別模式、快速做出判斷以及理解簡單或熟悉的符號。例如,它用于識別常見的交通標志、識別人臉或關聯基本符號與特定情緒或想法。
人類 System 2 對于復雜的問題解決或例如抽象符號(如代數方程或邏輯陳述)的操作,System 2推理被認為是必要的。在心理學中,自動性概念描述了行為變得如此熟練以至于可以在幾乎沒有意識思考的情況下執行,例如駕駛熟悉的路線。一般來說,人類被認為使用程序記憶將特定任務整合到記憶中,通過實踐學習,以便之后無需意識就能執行。無意識能力概念被歸類為學習的后期階段。最初,一個人認識到自己的無能,并有意學習一項技能,直到獲得有意識的能力。最終目標是在無需意識思考的情況下使用它,這時它被稱為,用通俗的話說,“第二天性”。
模型 System 1 論文將直接輸出響應而不產生中間輸出的神經網絡稱為系統1模型。盡管如此,這類網絡在其層中仍可計算中間的潛在表征,然后輸出響應。由于這些狀態以向量形式表示,它們通常編碼分布式知識而非離散決策,并且難以直接處理復雜的符號推理任務,這與人類系統1推理存在的問題類似。盡管如此,許多任務可以直接通過這種方式成功解決,無需中間生成(Radford et al., 2019)。
模型 System 2 同一個無法執行復雜多步驟計算的語言模型,在要求其通過少樣本提示或監督訓練生成中間步驟到“草稿板”上時,能夠完成這些任務。鏈式思維推理已被證明可以通過零樣本提示、監督訓練或少量樣本方法從大型語言模型中引發。大型語言模型的預訓練使得這種推理能夠融入模型中,因為訓練語料庫中包含了人類編寫的離散符號(文本)的推理步驟。這類系統2模型方法輸出離散的token,有利于進行連續正確的邏輯推理步驟——但顯然,如果推理生成錯誤,則存在缺點。錯誤的離散決策難以恢復,與可能更容易建模分布的潛在向量推理不同。
生成中間思考過程允許模型進行推理和規劃,以成功完成任務或響應指令。論文將這種深思熟慮的思考稱為系統2推理,這一概念源自Sloman(1996)和Kahneman(2011)對人類的描述,后來也被應用于人工智能模型。在系統2推理中,消耗大量認知資源來處理復雜問題和重要決策。因此,在標準的大型語言模型(LLMs)中,論文將系統1定義為直接應用Transformer來根據輸入生成響應,而不生成中間token。論文將系統2定義為任何生成中間token的方法,包括執行搜索或多次提示,然后最終生成響應的方法。
目前已提出了一系列這樣的系統2技術,其中包括思維鏈(Chain-of-Thought)、思維樹(Tree-of-Thoughts)、思維圖(Graph-of-Thoughts)、分支-解決-合并(Branch-Solve-Merge)、系統2注意力(System 2 Attention)、重述和回應(Rephrase and Respond)等等。許多這些方法通過顯式推理被證明能產生更準確的結果,但通常會以更高的推理成本和響應延遲為代價。由于后者的原因,許多這些方法并未在生產系統中使用,生產系統主要使用系統1生成。

圖1:系統2蒸餾概覽。通過在未token數據上運行系統2方法(如分支-求解-合并(BSM))收集過濾后的訓練樣本,這些方法利用額外計算產生更高質量的輸出。然后將這些目標蒸餾到標準(系統1)語言模型中
對于人類而言,心理學中將技能從有意識(系統2)轉移到自動(系統1)的過程被稱為自動性,并利用程序性記憶。例如,首次駕車上班時,人們可能會耗費大量意識努力進行規劃和決策以到達目的地。經過多次重復這條路線后,駕駛過程便“編譯”為潛意識(Charlton and Starkey, 2013)。同樣,像打網球這樣的運動可以變得“習以為?!?。
論文探索了一種類似的技術應用于AI模型。論文的方法以無監督方式進行這種編譯,論文稱之為系統2蒸餾,給定一組未token樣本。對于每個樣本,論文應用給定的系統2方法,然后以無監督方式衡量預測質量。例如,對于具有唯一答案的任務,論文采用自一致性(self-consistency),多次采樣。對于系統2足夠一致的樣本,論文假設此結果應被蒸餾,并將其添加到蒸餾池中。隨后,論文微調系統1以匹配系統2方法在收集的樣本池上的預測,但不生成中間步驟。圖1展示了將系統2蒸餾為系統1的整體過程。
?2.2 將系統2蒸餾至系統1
2.2.1 設置:系統1與系統2模型?
給定輸入 論文x論文,本工作考慮單一模型的情景,即大型語言模型(LLM),該模型具備兩種響應模式:
(i) 系統1:直接生成輸出 論文y論文。這是通過前向傳播底層自回歸神經網絡(Transformer)的各層以生成輸出token來實現的。
(ii) 系統2:論文將系統2模型定義為利用底層Transformer在生成最終響應token之前生成任意類型的中間輸出token 論文z論文 的方法。這可能包括多次調用(提示)。
更正式地,論文將一個System 2模型S視為一個函數,該函數接受一個LLM 和輸入x,并可能多次調用LLM以使用特定算法生成中間token,然后返回一個輸出論文y:

System 2方法可能涉及多個提示、分支、迭代和搜索,同時利用LLM生成中間結果以進行進一步處理。相比之下,一個System 1模型僅考慮原始輸入x,并直接調用LLM生成輸出y:

有許多現有的System 2模型實例。思維鏈提示僅需要單個LLM提示,但仍輸出中間生成內容,然后給出最終響應,通常用于數學和其他推理任務)。
諸如System 2 Attention和Rephrase and Respond(等方法需要兩次調用LLM,在前者中,第一次調用用于關注上下文并消除偏見,而在后者中用于擴展問題。第二次調用則用于根據中間生成內容最終回答問題。某些方法更為復雜,例如Branch-Solve-Merge(,它通過LLM生成計劃,該計劃分支成多個LLM調用,直到最終階段合并結果。
論文將對上述四種方法進行實驗,但還有許多其他System 2方法,例如Tree-of-Thoughts、Graph-of-Thoughts等。
2.2.2 方法:系統2蒸餾
許多系統2方法本質上在推理時由于多次提示調用和生成中間token而顯著較慢。系統2蒸餾的目標是將所有推理從S_II蒸餾回S_I,以便語言模型的直接輸出p_θ( x)得到改進。論文假設模型可以訪問未token的輸入t,從中它可以學習,類似于人類如何在無監督的情況下學習程序記憶。對于基于語言的任務,通??梢栽L問遵循指令的提示(輸入),因為它們可以由人類收集,例如發布的1M Wild-Chat交互,其中提供了輸入但正確標簽未知。因此,這是一個現實的設置。
所提出方法的第一步是使用系統2模型在未token的輸入t上生成響應:

這些響應可以直接用作微調系統1模型的系統2蒸餾目標。然而,它們受到噪聲的影響:其中一些響應可能是高質量的,而其他可能是低質量或不正確的。對于涉及短響應且通常具有唯一正確(但未知)答案的短形式QA和推理任務,論文因此考慮一個無監督的篩選步驟,以嘗試提高訓練數據質量。論文考慮兩種變體,兩者都依賴于一致性標準:
?輸出自一致性:論文總共采樣S_II(x^ i ; p_θ) N次,并接受多數投票的響應;如果沒有多數勝出者,論文丟棄該示例。
?輸入擾動下的自一致性:論文以輸出不應改變的方式擾動輸入w,例如改變提示中多項選擇項的順序,并為每個擾動計算S_I;如果輸出不一致,論文丟棄該示例。
隨后,論文得到合成數據集(X_S_II , Y_S_II),其中 論文X_S_II是X的過濾子集,目標為Y_S_II)。最后一步是使用這個蒸餾的訓練集對具有參數pθ的大型語言模型(LLM)進行有監督的微調。論文通常從當前狀態pθ初始化模型,并繼續使用新數據集進行訓練。
微調后,論文獲得一個 LLM p_θ,這是一個系統1模型,預計其輸出和性能提升與評估的系統2模型相似。
?2.3 實驗
2.3.1 訓練與評估設置?
論文使用 Llama-2-70B-chat作為所有實驗的基礎模型。論文需要一個足夠強大的基礎模型,使其能作為系統2模型表現出色,同時具有可微調的開源權重,因此選擇了此模型。論文考慮了幾種系統2方法,包括重述與回應(RaR)、系統2注意力(S2A)、分支-解決-合并(BSM)和思維鏈(CoT),重點關注每種方法已展示出強大性能的任務。對于系統1,論文使用指令調優的基礎模型進行零樣本推理,作為標準基線。論文報告每個任務的特定指標,以及“#Tokens”指標,該指標衡量評估集中每個輸入生成的平均token數量。對于系統2方法,這包括中間token生成和最終輸出token生成。
2.3.2 重述與回應蒸餾(Rephrase and Respond Distillation)
重述與回應(RaR)是一種系統2方法,首先提示語言模型對原始問題進行進一步闡述的重述,然后基于重述的問題生成回應,旨在提供更優質的輸出。作者介紹了兩種方法,1步RaR和2步RaR,后者涉及兩個單獨的提示,而不是像前者那樣的組合提示,具體提示見附錄A.1。他們發現2步RaR在幾個對基線LLM具有挑戰性的推理任務上顯著提高了性能。論文考慮了原文中表現良好的兩個任務:最后一個字母連接任務和硬幣翻轉推理。然后評估是否可能蒸餾這種系統2方法。
蒸餾數據集 論文為RaR構建了系統2蒸餾數據集,利用輸出的自一致性。對于每個輸入,論文對最后一個字母任務進行八次采樣迭代,并對硬幣翻轉任務的每個階段進行八次采樣迭代。然后,論文通過多數表決來確定最終輸出。
2.3.2.1 最后一個字母拼接任務(Last letter Concatenation Task)
此任務側重于符號推理,要求模型拼接給定單詞的最后一個字母。例如,指令:“取Edgar Bob中單詞的最后一個字母并拼接它們。”正如Deng等人(2023a)所示,此任務從RaR方法的應用中獲益顯著。論文通過隨機選擇1200個獨特的英語單詞來編譯數據集。利用這些單詞,論文分別為訓練、驗證和測試構建了200個樣本。
結果 總體結果見表1。基準系統1模型(Llama-2-70B-chat)達到30.0%的準確率,被1步和2步RaR的系統2方法(分別為39.5%和44.5%)超越。通過論文的無監督技術將2步RaR方法蒸餾回系統1 Llama-2-70B-chat模型,論文實現了驚人的98.0%準確率。與零樣本聊天模型相比,該模型能有效學習如何解決此任務。重述并回應的蒸餾有效繼承了系統2和系統1的優勢。它在保持系統2的準確性優勢的同時,推理成本與系統1相當(見生成token數量)。
分析與消融實驗 為了評估論文利用輸出自一致性的無監督篩選步驟的有效性和必要性,論文通過創建一個不應用自一致性過濾器的蒸餾數據集進行了消融研究。當論文在這個未經過濾的數據集上使用相同的設置對System 2模型進行了蒸餾,其精確匹配準確率達到了87.5%(過濾版本為98%)。這一比較突顯了一致性過濾的關鍵作用。盡管如此,在兩種情況下,構建訓練數據確實比零樣本性能有所提升。論文還嘗試使用相同的過濾技術對System 1預測進行蒸餾,結果準確率較低,為69.5%。

表1:重述并回應的系統2蒸餾:硬幣翻轉和最后一個字母拼接任務。論文報告精確匹配(EM)測試準確率和生成(中間和輸出)token數量
2.3.2.2 硬幣翻轉推理任務?
這一符號推理任務在研究中經常被測試,包括在Wei等人(2022)和Deng等人(2023a)的研究中。它涉及從已知初始位置開始,經過一系列自然語言描述的翻轉后,確定硬幣的最終面(正面或反面),例如“一枚硬幣正面朝上。Roxas沒有翻轉硬幣。Schneiderman沒有翻轉硬幣。硬幣還是正面朝上嗎?”Deng等人(2023a)表明,即使是強大的語言模型也無法成功完成這一任務,而應用RaR方法則能提高它們的性能。該任務有20k個訓練示例(無標簽,用于無監督學習),3.33k個驗證示例和1.33k個測試示例。
結果 總體結果見表1。Llama-2-70B-chat(零樣本)在該任務上的成功率為56.1%,而1-Step和2-Step RaR的成功率分別為58.59%和77.2%。因此,論文僅在2-Step方法中看到了顯著的改進。通過論文的無監督技術將2-Step RaR蒸餾回System 1 Llama-2-70B-chat,成功率為75.69%。因此,論文發現論文的蒸餾System 2模型提供了與System 2(2 Step RaR)相當的性能,但無需執行LLM程序。

表2:System 2注意力蒸餾:TriviaQA任務,報告有偏和無偏評估集的準確率
分析與消融實驗 Deng等(2023a)的RaR方法包含了提示工程技巧,例如在原始查詢后附加"Flip意味著反轉?;卮鹗腔蚍駟栴}"等短語,這已被證明可以提高模型性能。遵循他們的方法,論文使用不同的提示評估了模型性能,見表6。當使用"Flip意味著反轉"和"Flip意味著反轉。回答是或否問題"等提示測試Llama-2-70B-chat模型(系統1)時,論文觀察到性能顯著提升,從56.11%提高到66.84%。這突顯了提示選擇在優化系統1模型性能中的關鍵作用。然而,這種對提示工程的依賴也代表了一個局限性,需要額外的人力投入。
論文還嘗試對系統1模型進行蒸餾,但得到了較差的性能。在這種情況下,論文同樣觀察到不同提示下性能的波動。相比之下,蒸餾后的系統2模型在各種提示下表現出一致的性能,對提示變化的敏感度較低。這種一致性表明,對于蒸餾后的系統2模型,可能不需要進行大量的提示工程。
2.3.3 系統 2 注意力蒸餾
Weston 和 Sukhbaatar 在 2023 年提出了系統 2 注意力(S2A),這是一種有助于減少模型推理缺陷的方法,如依賴輸入中的偏見信息或關注無關上下文。S2A 是一種兩階段推理方法,第一階段重寫輸入,使其不包含如偏見或無關上下文等不期望的信息,第二階段關注重寫后的較短上下文(與 Rak 擴展上下文相反),參見圖 6。在本研究中,論文驗證了將 S2A 蒸餾到系統 1 的可行性。特別地,論文關注了 SycophancyEval 問答任務(Sharma 等人,2023),該任務的輸入中包含已知會損害大語言模型(LLM)性能的偏見信息。論文使用了來自 SycophancyEval 的 6668 個示例作為未token訓練數據,以及 400 個示例用于評估,后者被分為偏見輸入(350 個)和無偏見輸入(50 個)。
蒸餾數據 論文使用通用自一致性(USC)(Chen et al., 2023)來篩選高質量的目標。具體而言,論文采樣20個生成結果,然后利用Llama-70B-chat模型配合USC提示(如圖12所示)來組合一個自一致性(多數)的最終答案,該答案作為蒸餾目標。
結果 結果如表2所示,報告了3個隨機種子的平均準確率?;€(系統1)LLM在偏見部分的準確率較低,正如預期,因為其容易受到偏見輸入的影響。S2A顯著提升了偏見輸入的性能。系統2蒸餾顯示出與系統2方法相似的強勁性能。然而,與基線和S2A模型相比,平均使用的token數量有顯著減少。這是因為偏見輸入往往使基線LLM生成更多的輸出token,而S2A還需要生成中間token。圖11展示了一個代表性示例。最后,論文通過報告不使用USC的結果(最后一行),顯示后者提供的結果較差,從而表明使用USC進行蒸餾對整體結果的重要性。這突出了在微調過程中使用的蒸餾數據質量的重要性。
2.3.4 分支-解決-合并蒸餾
分支-解決-合并(BSM)(Saha et al., 2023)由三個模塊組成:分支、解決和合并。這些模塊協同工作,將任務分解為多個并行子任務,每個子任務由特定提示引導。BSM在LLM作為評判者的情境中已被證明有效,如圖14所示。該方法首先提示語言模型列出針對特定用戶查詢定制的評估指標(分支)。隨后,LLM被查詢以基于每個指標獨立并行地評估響應(解決)。最后,來自每個分支的分數被平均以得出一個全面的評估決策(合并)。值得注意的是,這種方法的推理成本是傳統(系統1)LLM評估方法的5-6倍,使其實用性大打折扣。論文評估了蒸餾BSM的可行性,旨在保留其優勢的同時降低計算成本。

表3 系統 2 分支-解決-合并 (BSM) 的蒸餾:Open Assistant (OASST2) 和 MT-bench 對 LLM 作為判斷者的評估。系統 2 BSM 的蒸餾優于 BSM 本身,甚至優于 GPT4 作為判斷者,盡管使用的是 Llama-2-70B-chat。蒸餾后的 BSM 具有更高的人類一致性(一致性),更少的位置偏差,并且不一致樣本的百分比為 9.1%
蒸餾數據 遵循 Yuan 等人 (2024) 和 Li 等人 (2023b) 的方法,論文使用了 Open Assistant Dataset v2 (OASST2) (Kopf 等人, 2024) 的第一輪和僅限英語的數據。論文使用 OASST2 訓練集中的查詢及其兩個候選響應作為輸入(總共 19,672 個樣本)。論文通過輸入擾動下的自一致性來確保蒸餾數據的質量。具體來說,由于需要判斷兩個響應,論文對每個樣本進行兩次 BSM 評估——一次按原始順序,一次按交換順序。無論順序如何,獲勝的響應應保持一致。論文過濾掉在響應順序交換時未能產生一致獲勝者的樣本。
評估 論文在兩個流行的基準上評估論文的模型,即 OASST2 驗證集和 MT-bench (Zheng 等人, 2024)。OASST2 驗證集包含 273 個樣本,僅限于第一輪和英語語言。對響應對的評估在原始順序和交換順序下進行。由于論文的蒸餾模型是在 OASST2 訓練集上訓練的,OASST2 驗證集作為分布內評估集,而 MT-bench 則更具分布外特性。MT-bench 是一個流行的基準,評估 LLM 作為有用 AI 助手對話時對其他 LLM 響應的判斷。它包含來自 8 個不同領域的指令,例如寫作、推理、數學、編碼等。
遵循 Zheng 等人 (2024) 的方法,論文評估了模型投票與人類專家投票之間的一致性。LLM 作為判斷者的一個已知局限是位置偏差,即語言模型 (LLM) 傾向于偏好某些位置而非其他位置。這種偏差在改變評估提示中響應的位置時,常常導致模型做出不同的決策。為了量化這一點,論文不僅測量一致性,還計算不一致樣本的百分比以評估位置偏差。
OASST2評估結果 表3提供了在OASST2數據集上的結果。與基線(系統1)大型語言模型相比,思維鏈(CoT)方法通過提高一致性和降低不一致率來改善性能(參見附錄中的提示)。雖然BSM表現優于CoT,但這是以增加推理時間(#To-kens)為代價的。值得注意的是,論文蒸餾的系統2 BSM模型僅需生成四個token,仍然優于CoT和BSM。此外,論文基于Llama-2-70B-chat的蒸餾模型超過了GPT-4-0125-preview,實現了更高的人類一致性和更大的連貫性。
MT-Bench評估結果 表3也提供了在MT-bench上的結果,該測試作為分布外測試。結果與OASST2評估的結果相呼應。思維鏈(CoT)和BSM都提高了模型性能,但代價是顯著增加的推理成本。論文的蒸餾BSM模型不僅實現了更高的人類一致性和更低的不一致率,而且需要的計算資源更少。盡管論文的模型在一致性上略遜于最先進的GPT-4-0125-preview模型,但它僅基于Llama-2-70B-chat在OASST2上的未標注數據進行訓練。盡管如此,它在連貫性上更優,且在輸出token方面推理成本低廉。

圖2:MT-bench上LM評判與人類偏好之間的一致性,按評估類別劃分

表3:GSM8k測試集準確率。多數投票中的投票數k表示為收集預測答案的投票而采樣的候選數量。在這種情況下,系統2的CoT蒸餾效果不佳
按類別分析 在此,論文進一步按類別分析MT-Bench結果中的一致性。圖2展示了按類別的一致性。論文觀察到,與基礎模型(Llama-2-70B-Chat)相比,CoT在所有類別上提高了一致性。BSM優于CoT,而論文的蒸餾BSM甚至優于BSM。盡管蒸餾BSM在所有類別上相較于基線取得了優越的性能,但在推理、編碼和提取方面仍落后于GPT-4-0125-preview。然而,在寫作、數學和STEM方面,它超過了GPT-4-0125-preview。
2.3.5 思維鏈蒸餾
思維鏈(CoT)已被證明是提高LLM推理能力的有效方法,例如解決研究生數學問題。LLM生成中間token,這些token是推理(思維)的步驟(鏈),然后產生最終答案。論文考慮了該方法的兩個變體:(i)少樣本CoT,即從訓練集中提供多個[問題,CoT,答案]示例作為上下文,隨后是問題;(ii)零樣本,即在提示中除了問題外還添加了“一步一步”思考的明確指令,詳見附錄圖10。
蒸餾數據 論文使用CoT為GSM8k訓練集中的問題(論文認為這些是無標簽的,由Cobbe等人,2021年提出)生成答案,采用K=10的多數投票方法。由此產生的蒸餾訓練集包含7461個[問題, 答案]對,即不包含任何中間推理步驟。為了分析目的計算的自監督目標準確率為56.81%。
評估 論文在GSM8k測試集上使用不同K值的多數投票方法計算并報告評估準確率。與之前的實驗類似,論文報告每種方法預測的平均token數。請注意,論文在進行多數投票時計算所有生成token的平均值,以觀察K值的增加如何影響推理成本。論文考慮了幾個基線:系統1和系統2(CoT)方法在零樣本或8樣本輸入上下文中進行評估。需要注意的是,系統2在8樣本情況下意味著在少量樣本輸入中提供了CoT,而系統1則意味著少量樣本示例包含問題和答案,但沒有CoT。
結果 評估結果如表3所示。首先,正如預期,使用CoT方法帶來了改進:將其作為少樣本上下文的一部分或作為提示模板中的指令的一部分時,這種方法有所幫助。這些改進伴隨著推理成本的增加:與System 1方法相比,使用CoT方法預測的序列長度顯著增加。其次,論文的System 2蒸餾方法在各種解碼超參數下表現不佳。GSM8k任務(數學問題)所需的推理類型與論文在此工作中考慮的其他任務截然不同。這突顯了System 2蒸餾的非平凡性:所提出的蒸餾算法在許多情況下有效,但并非總是如此。這為未來的研究留下了空間,以闡明在何種具體情況下應用蒸餾,以及何時不應應用,或許可以采用類似于人類的方法。
本文轉載自 ??AI帝國??,作者: 無影寺

















