重磅!Thinking Machines開山之作:大模型輸出隨機的根本原因被揪出,并開源終結方案
要理解AI,先要理解它何以不確定,由OpenAI前CTO Mira Murati創辦的Thinking Machines 開山之作來了,剛剛,Thinking Machines Lab 宣布正式上線技術研究博客:連接主義。開篇就是萬字技術雄文《擊敗LLM推理中的非確定性(Defeating Nondeterminism in LLM Inference)》。

為什么叫“連接主義”?這其實是個致敬梗。來自上世紀80年代,當時研究神經網絡的那群先驅者們,就是用這個詞來命名探索人造大腦的領域的,Thinking Machines非常良心,沒有學CloseAI,堅持開源,以后大家可以在聯結主義博客里看到各種各樣的話題,從硬核的底層計算,到好玩的提示詞技巧,應有盡有,另外北大校友,OpenAI 前安全副總裁,Thinking Machines 聯合創始人LiLian Weng 爆料,公司第一個旗艦產品命名為“連接機器”。

LLM有個共同的毛病就是同一個問題,再次詢問時可能給出不同的回答,沒有可復現性,這好像不是個問題,大家都習以為常了。
Thinking Machines Lab 的這篇文章對LLM不確定性進行了研究,發現幾乎所有 LLM 推理端點之所以非確定性,主要原因在于負載(并因此導致批次大小)的非確定性變化!
這個洞見瞬間顛覆了大家習以為常的視角。原來,困擾我們的隨機性,并非源于計算核心的瑕疵,而是源于系統在應對動態負載時架構上的妥協。你的每一次請求,其結果都在被其他成千上萬個并發請求無形地塑造著,純粹是數學上的問題。
文章中有一句話點出:“從推理服務器的角度,它是確定的;但從用戶的角度,它是非確定的。” 這就像生活里那些看似公平的規則,當環境變量一變,體驗到的卻是另一回事。AI并沒有撒謊,只是我們忽略了系統背后復雜的運行邏輯。
更重要的是,文章沒有停留在揭示問題,而是給出了系統性的解法:讓核心算子實現批次不變性,從RMSNorm到矩陣乘法,再到注意力機制,逐步重塑內核的確定性。最終,Thinking Machines不僅在實驗中實現了千次生成結果的完全一致,還為未來的在線策略強化學習打開了大門。
blog:https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
可復現性是科學進步的基石。然而,要從大型語言模型中獲得可復現的結果卻異常困難。
例如,人們可能會觀察到,多次向 ChatGPT 提出相同的問題會得到不同的結果。這本身并不奇怪,因為從語言模型獲取結果涉及采樣過程,該過程將模型的輸出轉換為概率分布,并據此概率性地選擇一個詞元(token)。
但更令人驚訝的可能是,即使將溫度(temperature)降至 01(從而使采樣在理論上是確定性的),LLM API 在實踐中仍然不是確定性的。即便使用像 vLLM 或 SGLang 這樣的開源推理庫在自己的硬件上運行推理,采樣過程仍然不是確定性的。
但為什么 LLM 推理引擎不是確定性的呢?一個常見的假設是,浮點數的非結合性(non-associativity)與并發執行共同導致了非確定性,具體結果取決于哪個并發核心先完成計算。本文將此稱為 LLM 推理非確定性的“并發+浮點數”假說。例如,一篇近期的 arXiv 預印本文章寫道:
GPU 中的浮點數運算表現出非結合性,即 (a + b) + c ≠ a + (b + c),這是由有限精度和舍入誤差造成的。這一特性直接影響 Transformer 架構中注意力分數和 logits 的計算,其中跨多線程的并行操作會因執行順序的不同而產生不同的結果。
盡管這個假說不完全錯誤,但它并未揭示全貌。例如,即使在 GPU 上,對相同數據重復運行相同的矩陣乘法,也總會得到逐位元(bitwise)相同的結果。這里確實使用了浮點數,GPU 也確實存在大量并發。那為什么在這個測試中看不到非確定性呢?
# python
A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
assert (torch.mm(A, B) - ref).abs().max().item() == 0要理解 LLM 推理非確定性的真正原因,必須深入探究
不幸的是,即便是定義 LLM 推理的“確定性”也很困難。也許令人困惑的是,以下陳述全部同時成立:
1.GPU 上的某些核函數(kernels)是非確定性的
2.然而,語言模型前向傳播(forward pass)中使用的所有核函數都是確定性的
3.此外,LLM 推理服務器(如 vLLM)的前向傳播也可以聲稱是確定性的
4.盡管如此,從任何使用該推理服務器的用戶的角度來看,結果都是非確定性的
在本文中,將解釋為什么“并發+浮點數”假說未能抓住要點,揭示 LLM 推理非確定性背后的真正元兇,并說明如何擊敗非確定性,在 LLM 推理中獲得真正可復現的結果。
原罪:浮點數的非結合性
在討論非確定性之前,有必要先解釋一下為什么會出現數值差異。畢竟,通常認為機器學習模型是遵循交換律或結合律等結構規則的數學函數。難道不應該有一個“數學上正確”的結果,并且機器學習庫就應該提供這個結果嗎?
罪魁禍首是浮點數的非結合性。也就是說,對于浮點數而言:
(a + b) + c ≠ a + (b + c)
# python
(0.1 + 1e20) - 1e20
>>> 0
0.1 + (1e20 - 1e20)
>>> 0.1具有諷刺意味的是,打破結合律正是浮點數有用的原因。
浮點數之所以有用,是因為它們允許一種動態的精度水平。為了便于解釋,這里將使用基數 10(而不是二進制),其中浮點數的格式為 尾數 * 10^指數。同時,假設尾數有 3 位數字,指數有 1 位數字。
例如,對于數值 3450,可以精確表示為 3.45 * 103。也可以表示像 0.486 這樣小得多的值,即 4.86 * 10?1。通過這種方式,浮點數可以同時表示非常小和非常大的值。在科學領域,可以說浮點數允許我們保持恒定數量的“有效數字”。
如果將兩個具有相同指數的浮點數相加,其過程類似于整數加法。例如,123 (1.23 * 102) + 456 (4.56 * 102) 的結果是 579 (5.79 * 102)。
但當兩個指數不同的浮點數相加時,比如 1230 和 23.4,會發生什么呢?在這種情況下,精確結果是 1253.4。然而,一次只能保持 3 位數的精度。因此,浮點數加法會丟棄最后兩位數字,得到 1.25 * 103(即 1250)。

圖1:需要 3 位精度來表示 1230,3 位精度來表示 23.4。然而,將這兩個數相加得到的結果需要 5 位精度來表示(1253.4)。浮點數格式必須丟棄末尾的 34。在某種意義上,這相當于在相加前將原始的 23.4 四舍五入到 20.0。
然而,到了這一步,信息已經被破壞了。請注意,每當將兩個具有不同尺度(即不同指數)的浮點數相加時,都可能發生這種情況。而將不同指數的浮點數相加是常有的事。事實上,如果能保證永遠不需要不同的指數,那直接使用整數就可以了。
換句話說,每次以不同順序對浮點數進行求和,都可能得到一個完全不同的結果。舉個極端的例子,根據求和順序的不同,對下面這個數組求和可能會產生 102 種不同的結果。
# python
import random
vals = [1e-10, 1e-5, 1e-2, 1]
vals = vals + [-v for v in vals]
results = []
random.seed(42)
for _ in range(10000):
random.shuffle(vals)
results.append(sum(vals))
results = sorted(set(results))
print(f"There are {len(results)} unique results: {results}")
# 輸出:
# There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ...]盡管這是產生非相同輸出的根本原因,但它并沒有直接回答非確定性來自何處。它沒能幫助我們理解為什么浮點值會以不同順序相加,這種情況何時發生,以及如何避免。
答案在于核函數(kernels)的實現方式。
為什么核函數不總按相同順序進行加法運算?
如上所述,關于核函數以不同順序進行加法運算的一個常見解釋是“并發+浮點數”假說。該假說認為,如果并發線程的完成順序是非確定性的,并且累加順序依賴于線程完成的順序(例如使用原子加法),那么累加順序也將是非確定性的。
令人困惑的是,盡管這確實會導致非確定性核函數,但并發(以及原子加法)最終與 LLM 推理的非確定性完全無關!為了解釋真正的罪魁禍首是什么,首先來了解一下為什么現代 GPU 核函數很少需要原子加法。
何時需要原子加法?
通常,GPU 會在許多核心(即 SMs)上并發地啟動一個程序。由于這些核心之間沒有內在的同步機制,如果核心之間需要通信,就會帶來挑戰。例如,如果所有核心都必須對同一個元素進行累加,可以使用“原子加法”(atomic add,有時也稱為“fetch-and-add”)。原子加法是非確定性的——結果累加的順序完全取決于哪個核心先完成。
具體來說,想象一下用 100 個核心對一個 100 元素的向量進行規約(reduction,例如 torch.sum())。雖然可以并行加載所有 100 個元素,但最終必須將它們規約成一個單一元素。實現這一目標的一種方法是使用某種“原子加法”原語,硬件保證所有加法都會被處理,但不保證處理順序。

圖2,原子加法確保每個核心的貢獻都會反映在最終的總和中。然而,它不保證貢獻被添加的順序。這個順序完全取決于哪個核心先完成,這是一個非確定性的屬性。因此,多次執行同一個并行程序可能會導致非確定性的輸出。
這通常就是人們所說的非確定性——用完全相同的輸入兩次執行同一個核函數,卻得到了不同的輸出。這被稱為逐次運行非確定性(run-to-run nondeterminism),即兩次運行完全相同的 Python 腳本(具有完全相同的依賴項),但得到的結果不同。
盡管并發的原子加法確實會使核函數非確定性,但絕大多數核函數并不需要原子加法。事實上,在 LLM 的典型前向傳播中,通常一個原子加法都沒有。
這可能令人驚訝,因為并行化規約可以從原子加法中受益。原子加法最終之所以非必需,主要有兩個原因:
1.通常在“批次”維度上有足夠的并行性,因此不需要在規約維度上進行并行化。例如,假設不是規約一個 100 維的向量,而是在并行地規約 500 個向量。在這種情況下,可以在每個核心中規約一整個向量,讓每個核心處理不同的向量。
2.隨著時間的推移,大多數神經網絡庫都采用了各種策略來實現確定性,同時不犧牲性能。例如,可以執行“分裂式”(或樹形)規約,將 100 個元素的規約分解為五個 20 元素的規約(從而實現五路并行)。然后,為了合并剩余的五個元素,可以執行一個單獨的“清理”規約(這個規約不是并行的,但因為它處理的元素足夠少所以成本很低),或者利用一個信號量(semaphore)(它能確保每個并發的線程塊以確定性的順序進行累加)。
由于這兩個因素,對于絕大多數神經網絡操作來說,避免原子加法帶來的性能損失可以忽略不計。
仍有一些常見的操作在避免原子加法時會有顯著的性能損失。例如,PyTorch 中的 scatter_add (a[b] += c)。然而,在 LLM 中唯一常用到的是 FlashAttention 的反向傳播。
然而,LLM 的前向傳播不涉及任何需要原子加法的操作。因此,LLM 的前向傳播實際上是逐次運行確定性的。

圖3,從推理服務器的角度來看,它是確定性的。給定完全相同的用戶請求,它將總是提供相同的確定性輸出。
維基百科寫道:確定性算法是指,對于一個特定的輸入,它將總是產生相同的輸出。”在這種情況下,給定完全相同的輸入(即推理服務器正在處理的完全相同的請求),前向傳播總是會產生完全相同的輸出。
然而,前向傳播本身是確定性的,并不足以保證包含它的系統是確定性的。例如,如果請求的輸出依賴于并行的用戶請求(例如 batch-norm),那會怎樣?由于每個單獨的請求無法知道并行的其他請求會是什么,從它們的角度來看,整個 LLM 推理也是非確定性的!
事實證明,請求的輸出確實依賴于并行的用戶請求。這并不是因為在批次之間泄露了信息——而是因為前向傳播缺乏 “批次不變性” (batch invariance),導致請求的輸出依賴于前向傳播的批次大小。
批次不變性與確定性
為了解釋批次不變性,讓我們簡化系統,只關注矩陣乘法。可以假設所有矩陣乘法的實現都是逐次運行確定性的。然而,它們并不是批次不變的。換句話說,當批次大小改變時,批次中的每個元素都可能得到不同的結果。
從數學角度來看,這是一個相當不尋常的屬性。矩陣乘法在批次中的每個元素上應該是獨立的——批次中的其他元素以及批次的大小都不應影響特定元素的計算結果。
然而,根據經驗可以觀察到,這并非事實。
# python
import torch
torch.set_default_device('cuda')
B = 2048
D = 4096
a = torch.linspace(-1000, 1000, B*D).reshape(B, D)
b = torch.linspace(-1000, 1000, D*D).reshape(D, D)
# 通過取批次的第一個元素進行矩陣向量乘法
out1 = torch.mm(a[:1], b)
# 進行矩陣矩陣乘法,然后取批次的第一個元素
out2 = torch.mm(a, b)[:1]
print((out1 - out2).abs().max()) # tensor(1669.2500, device='cuda:0')注意,這是逐次運行確定性的。如果多次運行此腳本,它將確定性地返回相同的結果。
然而,當一個非批次不變的核函數被用作更大推理系統的一部分時,該系統可能變得非確定性。當向一個推理端點發出查詢時,服務器的負載量從用戶的角度來看實際上是“非確定性”的。負載決定了核函數運行時的批次大小,從而改變了每個獨立請求的最終結果!

圖4,盡管推理服務器本身可以聲稱是確定性的,但對于單個用戶而言,情況則不同。從單個用戶的角度來看,其他并發用戶不是系統的輸入,而是系統的非確定性屬性。這使得 LLM 推理從每個用戶的角度來看是非確定性的。
如果將核函數不具有不變性的某個屬性(即批次大小)與該屬性的非確定性(即服務器負載)組合在一起,就會得到一個非確定性的系統。
換句話說,幾乎所有 LLM 推理端點之所以非確定性,主要原因在于負載(并因此導致批次大小)的非確定性變化! 這種非確定性并非 GPU 所獨有——由 CPU 或 TPU 提供的 LLM 推理端點也會有這個非確定性來源。
因此,如果想在推理服務器中避免非確定性,就必須在核函數中實現批次不變性。為了理解如何實現這一點,首先來看一下為什么核函數一開始就不具備批次不變性。
如何使核函數具備批次不變性?
為了使 Transformer 實現批次不變,必須使其每個核函數都具備批次不變性。幸運的是,可以假設每個逐點(pointwise)操作都是批次不變的。因此,只需要關注涉及規約的 3 個操作——RMSNorm、矩陣乘法和注意力。
方便的是,它們也按難度遞增的順序排列。每一個都需要一些額外的考慮才能在保持合理性能的同時實現批次不變性。先從 RMSNorm 開始。
批次不變的 RMSNorm

圖5:數據并行 RMSNorm,理想情況下,我們希望在并行化策略中避免核心之間的通信。實現這一點的一種方法是為每個批次元素分配一個核心,從而保證每個規約完全在單個核心內完成。這就是所謂的“數據并行”策略,因為我們只是沿著一個不需要通信的維度進行并行化。在這個例子中,有四行和四個核心,正好飽和了我們的核心。
RMSNorm 可以實現為:
# python
# x: [batch_size, hidden_dim]
# weight: [hidden_dim]
def rms_norm(x, weight):
return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight批次不變性的要求是,每個元素的規約順序必須固定,與核函數的批次大小無關。注意,這并不意味著必須總是使用相同的規約策略。例如,如果改變規約的元素數量,即使規約策略改變,仍然可以是批次不變的。
因此,只有當批次大小影響到規約策略時,才會破壞批次不變性。
讓我們看一下 RMSNorm 的標準并行策略。通常,并行算法受益于最小化核心間的通信。因此,可以從一個策略開始,即為每個批次元素分配一個核心,如上圖所示。
增加批次大小并不會影響規約策略;如果 200 的批次大小為核函數提供了足夠的并行度,那么 2000 的批次大小也絕對能提供足夠的并行度。

圖6:更大批次的數據并行 RMSNorm,將數據并行策略擴展到更大的批次是相當直接的——與其讓每個核心處理一行,不如讓每個核心順序處理多行。這保留了批次不變性,因為每個批次元素的規約策略保持不變。
另一方面,減小批次大小可能會帶來挑戰。因為為每個批次元素分配一個核心,減小批次大小最終會導致核心數量多于批次元素,從而使一些核心處于空閑狀態。
遇到這種情況時,一個優秀的核函數工程師會選擇前一節提到的解決方案之一(原子加法或分裂式規約),以保持良好的并行性和性能。不幸的是,這會改變規約策略,從而使該核函數不具備批次不變性。

圖7:分裂式規約 RMSNorm,如果批次大小很小,數據并行策略可能不再有足夠的并行度來飽和核心。在這種情況下,將一個規約“分裂”到多個核心上可能更有效,從而充分利用 GPU。然而,這會失去批次不變性,因為不再以相同的順序對每個元素進行規約。
最簡單的解決方案是完全忽略這些情況。這并非完全不合理——小的批次大小意味著核函數可能執行得很快,因此性能下降可能不是災難性的。
如果非要優化這種情況,一種方法是始終使用一種即使在非常小的批次大小下也具有足夠并行性的規約策略。這樣的規約策略對于較大的批次大小會導致過多的并行性,但能讓我們在整個尺寸范圍內實現可觀(但非峰值)的性能。
批次不變的矩陣乘法

圖8:數據并行矩陣乘法,與 RMSNorm 類似,矩陣乘法的標準并行策略是“數據并行”策略,將整個規約保持在一個核心內。
可以將矩陣乘法看作是一個逐點操作后跟一個規約。然后,如果通過將輸出分塊為多個瓦片(tile)來并行化矩陣乘法,就得到了一個類似的“數據并行”核函數策略,它將每個規約保持在一個核心內。
也與 RMSNorm 類似,批次維度(M 和 N)可能變得太小,迫使我們在規約維度(K)上進行分裂。盡管有兩個“批次”維度,矩陣乘法也需要每個核心有更多的工作量才能有效利用張量核心(tensorcores)。例如,如果有一個 [1024, K] x [K, 1024] 的矩陣乘法和一個標準的 [128, 128] 2D 瓦片大小,數據并行策略只能將這個矩陣乘法分解到 64 個核心中,不足以飽和 GPU。
在矩陣乘法中沿規約維度進行分裂被稱為 Split-K 矩陣乘法。就像 RMSNorm 一樣,使用這種策略會破壞批次不變性1?。

圖9:Split-K 矩陣乘法,如果批次維度相當小,可能沒有足夠的并行性,需要進行 split-k 矩陣乘法。在這個例子中,我們將每個規約分裂到兩個核心上,它們會分別累加,然后在最后合并它們的結果。然而,將每個規約分裂到兩個核心上,使我們仍然可以利用八個核心。
矩陣乘法還有一個額外的復雜性——張量核心指令。對于規約,可以一次只操作一行,而高效的矩陣乘法核函數必須一次操作一整個“瓦片”。
每個張量核心指令(例如 wgmma.mma_async.sync.aligned.m64n128k16)內部可能有不同的規約順序。使用不同張量核心指令的一個原因可能是批次大小非常小。例如,如果使用一個操作 256 長度瓦片的張量核心 PTX 指令,但批次大小只有 32,那么幾乎浪費了所有的計算!在批次大小為 1 時,最快的核函數通常根本不使用張量核心。

圖10:填充的張量核心指令,如果批次大小太小,可能會出現我們甚至無法在輸出中容納一個 2D 瓦片的情況。在這種情況下,最有效的方法是切換到更小的張量核心指令或完全放棄張量核心!然而,這兩種選擇都會使核函數不具備批次不變性。
因此,確保矩陣乘法批次不變性的最簡單方法是編譯一個核函數配置,并將其用于所有形狀。雖然會損失一些性能,但這在 LLM 推理中通常不是災難性的。特別是,split-k 在 M 和 N 都很小時最需要,而幸運的是,在我們的案例中,N(即模型維度)通常非常大!

圖11:性能對比圖,盡管獲得了批次不變性,與 cuBLAS 相比,性能僅損失約 20%。注意,這也不是一個優化的 Triton 核函數(例如,沒有 TMA)。然而,性能圖中的一些模式說明了批次不變性要求在何處損失了性能。首先,請注意在非常小的批次大小時,由于指令過大和并行性不足,性能損失顯著。其次,隨著批次大小增加,存在一個“鋸齒狀”模式,這是由量化效應(瓦片和波次)引起的,通常通過改變瓦片大小來緩解。
批次不變的注意力
在為矩陣乘法獲得批次不變性之后,注意力引入了兩個額外的復雜問題——恰如其名,因為它包含兩個矩陣乘法。
1.與 RMSNorm 和矩陣乘法僅在特征維度上進行規約不同,現在需要在特征維度和序列維度上進行規約
2.由于上述原因,注意力必須處理各種影響序列處理方式的推理優化(如分塊預填充、前綴緩存等)
因此,要在 LLM 推理中實現確定性,數值計算必須對一次處理多少個請求以及每個請求在推理引擎中如何被切分保持不變。
首先,來看一下注意力的標準并行策略,該策略首次在 FlashAttention2 中引入。與 RMSNorm 和 Matmul 類似,默認策略是數據并行策略。由于規約是在鍵/值(key/value)張量上進行的,數據并行策略只能在查詢(query)張量上進行并行化。

圖12:FlashAttention2 策略,沿著 Q 進行并行化,同時沿著 K/V 進行規約。這意味著整個規約可以保持在單個核心內,使其成為另一種數據并行策略。
例如,根據推理引擎的選擇,一個序列可能會被分成幾個部分進行處理(如在分塊預填充中),或者可能一次性全部處理(如果預填充沒有被分割)。為了實現“批次不變性”,一個給定詞元的規約順序必須不依賴于其序列中同時處理的其他詞元的數量。如果在 KV 緩存中的 K/V 值與當前正在處理的詞元的 K/V 值分開進行規約(如 vLLM 的 Triton 注意力核函數中那樣),這是無法實現的。例如,在處理序列中的第 1000 個查詢詞元時,無論 KV 緩存中有 0 個詞元(預填充)還是 999 個詞元(解碼),規約順序都必須相同。

圖13:帶 KV 緩存的 FlashAttention,為什么顯式地分開處理 KV 緩存和當前的 KV 值會破壞批次不變性,原因有些微妙,與“邊界條件”有關。具體來說,假設塊大小為 32,但 KV 緩存中當前有 80 個元素。然后,我們計算另外 48 個未緩存的元素。在這種情況下,需要三個塊(兩個完整塊和一個掩碼塊)來計算“P cache”,以及另外兩個塊(一個完整塊和一個掩碼塊)來計算“P”。因此,總共需要五個塊來計算規約,而我們總共只有四個塊的元素(即 128)需要計算,這肯定會改變我們的規約順序。
要解決這個問題,可以在注意力核函數本身之前更新 KV 緩存和頁表,確保無論處理多少詞元,鍵和值始終以一致的方式布局。
有了這個額外的細節(以及前一節提到的所有內容,如一致的瓦片大小),就能夠實現一個批次不變的注意力!
然而,這里有一個重要問題。與矩陣乘法不同,在 LLM 推理中看到的注意力形狀通常確實需要一個分裂式規約核函數,通常稱為 Split-KV 或 FlashDecoding。這是因為如果不沿規約維度進行并行化,就只能沿批次維度、頭維度和“查詢長度”維度進行并行化。在注意力的解碼階段,查詢長度非常小,因此除非有非常大的批次大小,否則通常無法飽和 GPU。
不幸的是,像 RMSNorm 和 Matmuls 那樣忽略這種情況并不容易。例如,如果有一個非常長的 KV 緩存,注意力核函數可能需要很長時間,盡管只處理一個請求。

圖14:固定數量的 Split-KV 策略(即 FlashDecode),如果查詢長度變得非常小(如解碼期間),核函數中的并行度可能會非常低。在這些情況下,需要再次沿規約維度進行分裂——這次是 KV 維度。如何沿 KV 維度分裂的典型策略是計算需要多少并行度,然后將 KV 維度平均劃分。例如,如果 KV 長度為 1000,需要 4 個分裂,每個核心將處理 250 個元素。
這也不幸地破壞了批次不變性,因為精確的規約策略取決于在任何給定請求中正在處理的序列查詢詞元的數量。
此外,常用于注意力的分裂式規約策略也給批次不變性帶來了挑戰。例如,FlashInfer 的平衡調度算法會選擇能夠飽和所有 GPU 核心的最大分裂尺寸,從而使規約策略不具備“批次不變性”。然而,與 RMSNorm/Matmuls 不同,僅僅選擇一個固定的分裂數量而不考慮批次大小是不夠的。
相反,為了實現批次不變性,必須采用“固定分裂尺寸”策略。換句話說,不是固定分裂的數量,而是固定每個分裂的尺寸,最終得到可變數量的分裂。通過這種方式,可以保證無論正在處理多少詞元,總是執行相同的規約順序。

圖15:固定尺寸的 Split-KV 策略,此策略與前一個策略的唯一區別在于,分裂現在是“固定尺寸”的。例如,如果 KV 長度為 1000,與其將其分成四個長度為 250 的均勻分裂,不如將其分成三個固定尺寸長度為 256 的分裂和一個長度為 232 的分裂。這使我們能夠保留批次不變性,因為規約策略不再依賴于一次處理多少查詢詞元!
實現
通過利用 vLLM 的 FlexAttention 后端以及 torch.Library,提供了一個在 vLLM 之上進行確定性推理的演示。通過 torch.Library,能夠以非侵入性的方式替換掉大部分相關的 PyTorch 操作符。可以在 thinking-machines-lab/batch-invariant-ops 找到“批次不變”核函數庫,以及在“確定性”模式下運行的 vLLM 示例
地址:https://github.com/thinking-machines-lab/batch_invariant_ops
實驗
完成結果的非確定性有多嚴重?
使用 Qwen/Qwen3-235B-A22B-Instruct-2507 模型,在溫度為 0 的情況下,以提示“Tell me about Richard Feynman”采樣 1000 次補全,每次生成 1000 個詞元。令人驚訝的是,產生了 80 個獨特的補全結果,其中最常見的出現了 78 次。
觀察補全結果的差異之處,可以看到它們在前 102 個詞元上實際上是相同的!分歧的首次出現是在第 103 個詞元處。所有補全都生成了序列“Feynman was born on May 11, 1918, in”。然而,其中 992 個補全接著生成了“Queens, New York”,而 8 個補全生成了“New York City”。
另一方面,當啟用批次不變的核函數時,所有 1000 個補全結果都是相同的。這正是從采樣器中數學上期望得到的結果,但沒有批次不變的核函數是無法實現確定性結果的。
性能
這里沒有在優化批次不變核函數的性能上投入大量精力。然而,還是進行了一些實驗來驗證其性能仍然可用。
將設置一個 API 服務器,用一個 GPU 運行 Qwen-3-8B,并請求 1000 個序列,輸出長度在 90 到 110 之間。
配置 | 時間 (秒) |
vLLM 默認 | 26 |
未優化的確定性 vLLM | 55 |
+ 改進的注意力核函數 | 42 |
大部分的性能下降來自于 vLLM 中 FlexAttention 集成尚未經過高度優化。盡管如此,可以看到性能并非災難性的。
真正的同策略強化學習
正如研究人員所指出的,訓練和推理之間的不同數值計算,不知不覺地將同策略(on-policy)強化學習變成了異策略(off-policy)強化學習。
當然,如果甚至無法在兩次相同的推理請求之間獲得逐位元相同的結果,那么在訓練和推理之間獲得逐位元相同的結果是不可能的。確定性推理使我們也能修改訓練棧,以獲得采樣和訓練之間逐位元相同的結果,從而實現真正的同策略強化學習。
在一個 RLVR 設置中,在 Bigmath 上進行了實驗,RL 策略由 Qwen 2.5-VL instruct 8B 初始化,最大 rollout 長度為 4096。
如果在沒有異策略校正(即重要性權重)的情況下進行訓練,獎勵會在訓練中途崩潰,而添加異策略校正項則允許訓練順利進行。但是,如果在采樣器和訓練器之間實現了逐位元相同的結果,就完全處于同策略(即 0 KL 散度)狀態,并且也能順利訓練。
還可以繪制采樣器和訓練器之間 logprobs 的 KL 散度圖,其中 3 次運行的行為有顯著不同。當使用重要性權重運行時,它保持在 0.001 左右,偶爾出現尖峰。然而,不使用重要性權重運行最終會導致 KL 散度出現尖峰,大約在獎勵崩潰的同一時間。當然,當運行“真正的同策略 RL”時,KL 散度保持在 0,表明訓練策略和采樣策略之間沒有分歧。

圖16,注意,沒有重要性權重的運行在第 318 步附近有一個顯著的損失尖峰,并且這伴隨著 logprobs 的 KL 散度出現相應的尖峰。與此同時,無論是使用異策略校正還是運行“真正的同策略”都允許 RL 繼續順利進行。顯示“真正的同策略”的藍線不是一個 bug——它就是一條平坦的 0 線。
結論
現代軟件系統包含許多抽象層。在機器學習中,當遇到非確定性和微小的數值差異時,很容易將它們掩蓋過去。畢竟,系統已經是概率性的了,多一點非確定性又有什么關系呢?在失敗的單元測試中提高 atol/rtol 的容忍度又有什么問題呢?訓練器和采樣器之間 logprobs 的差異可能不是一個真正的 bug,對吧?
本文反對這種失敗主義的態度。通過一些努力,可以理解非確定性的根本原因,甚至解決它們!希望這篇博客文章能夠為社區提供一個堅實的理解,關于如何解決推理系統中的非確定性問題,并激勵其他人去完全理解他們的系統。
source:https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/




























