Pretraining a 17M-Parameter GPT Model from Scratch

Hi everyone, I'm 寫代碼的中年人!
Today we'll pretrain a model on open-source Chinese data. Follow the steps below to build your own pretrained model from scratch.
All code and data resources for this article are available at:
https://github.com/ColinAIAPP/MoiraiLM
01. What Is a Pretrained Model?
A pretrained model is simply a model that has already been trained on massive amounts of data. It has learned the basic patterns, structure, and semantics of language, and can then be applied to all kinds of downstream tasks such as writing, translation, question answering, classification, and code generation.
So what exactly does "pretraining" learn? Taking language models (LLMs) as an example: the pretraining task is usually to predict the next word (token).
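To make the objective concrete, here is a minimal PyTorch sketch of next-token prediction: each position is trained, via cross-entropy loss, to predict the token that follows it (the tensors below are toy values, not output from a real model):

import torch
import torch.nn.functional as F

vocab_size = 16000
tokens = torch.tensor([[5, 17, 42, 9]])    # (batch=1, seq_len=4), toy token IDs
logits = torch.randn(1, 3, vocab_size)     # model predictions for positions 0..2
targets = tokens[:, 1:]                    # position t is trained to predict token t+1
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(f"next-token loss: {loss.item():.4f}")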
Next, we'll implement a 17M-parameter pretrained model step by step.
02. Data Preparation
The first requirement for building a language model is a high-quality data source. For Chinese tasks, the open Chinese Wikipedia dataset is an ideal starting point. It contains millions of encyclopedia entries covering history, culture, science, and more, amounting to several GB of plain text. It is open and free, and the latest XML dump can be downloaded from Wikipedia's official dumps page.
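For example, the dump used later in this article can be fetched like this (a sample command; the dump date is an assumption here — check the dumps page for the latest version):
wget https://dumps.wikimedia.org/zhwiki/20250920/zhwiki-20250920-pages-articles-multistream.xml.bz2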
To unpack and process this file, we use the wikiextractor tool.
Install it with:
pip install wikiextractor
Run the extraction:
python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
Here zhwiki-20250920-pages-articles-multistream.xml.bz2 is the name of the dump file. Sample output:
INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
...
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
...
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)
03. Cleaning the Data
The previous step left the extracted files in the output directory; next we clean them into plain text. You can first sanity-check the extraction as shown below.
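A quick way to inspect the extraction (a minimal sketch; AA/wiki_00 is WikiExtractor's usual first output file — adjust the path to your actual layout):

import json

# Each line of an extracted file is one article as a JSON object
# with "title", "text", "url", and "id" fields.
with open("extracted_wiki_zh/AA/wiki_00", "r", encoding="utf-8") as f:
    article = json.loads(f.readline())
print(article["title"])
print(article["text"][:200])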
Note:
This step produces the file data/cleaned_wiki_full.txt
import os
import json
import logging
import argparse
import re
from tqdm import tqdm
# 配置日志記錄
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300
def clean_text(text: str) -> str:
"""
對(duì)文本進(jìn)行深度清洗。
移除維基百科特有的格式標(biāo)記、參考文獻(xiàn)、HTML標(biāo)簽、日期和數(shù)字等。
"""
# 移除維基鏈接 [[link|display]] 或 [[link]]
text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)
# 移除參考文獻(xiàn)標(biāo)記 [1], [2], [ref], 等
text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)
# 移除HTML標(biāo)簽
text = re.sub(r'<[^>]+>', '', text)
# 移除日期格式 (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy 等)
text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)
# 移除年份 (1000-2999)
text = re.sub(r'\b[12]\d{3}\b', '', text)
# 移除純數(shù)字(包括小數(shù))
text = re.sub(r'\b\d+\.?\d*\b', '', text)
# 移除重復(fù)的空白字符(但保留單個(gè)空格)
text = re.sub(r' +', ' ', text)
# 移除行首尾空白
text = text.strip()
return text
def process_extracted_wiki(extracted_dir: str,
output_file: str,
min_line_length: int = 20,
min_article_length: int = 200):
"""
讀取WikiExtractor輸出的JSON文件,提取、清洗文本并保存到單個(gè)文件中。
參數(shù):
extracted_dir: WikiExtractor輸出的目錄路徑
output_file: 最終合并的純文本文件路徑
min_line_length: 單行文本最小長(zhǎng)度,用于過(guò)濾噪音(默認(rèn): 20)
min_article_length: 文章最小長(zhǎng)度,用于過(guò)濾短文章(默認(rèn): 200)
"""
if not os.path.isdir(extracted_dir):
logging.error(f"輸入的目錄不存在: {extracted_dir}")
return
total_articles = 0
skipped_articles = 0
# 第一次遍歷:獲取所有需要處理的文件列表
file_list = []
for root, dirs, files in os.walk(extracted_dir):
for file_name in files:
# 僅處理 WikiExtractor 生成的以 'wiki_' 開(kāi)頭的文件
if file_name.startswith('wiki_'):
file_list.append(os.path.join(root, file_name))
total_files = len(file_list)
logging.info(f"找到 {total_files} 個(gè)文件等待處理。")
if total_files == 0:
logging.warning(f"目錄 {extracted_dir} 中未找到任何 'wiki_' 文件。請(qǐng)檢查路徑。")
return
# 第二次遍歷:處理文件并寫入輸出
with open(output_file, 'w', encoding='utf-8') as f_out:
# 使用 tqdm 包裝文件列表,顯示處理進(jìn)度
        for file_path in tqdm(file_list, desc="正在提取維基文本"):
try:
with open(file_path, 'r', encoding='utf-8') as f_in:
for line_num, line in enumerate(f_in, 1):
try:
article = json.loads(line)
text_content = article.get('text', '').strip()
# --- 文本清洗和過(guò)濾 ---
# 1. 過(guò)濾掉過(guò)短的文章,它們通常是噪音或重定向頁(yè)
if len(text_content) < min_article_length:
skipped_articles += 1
continue
# 2. 按行處理文本,過(guò)濾短行和額外的空白
# 保留行結(jié)構(gòu),而不是將所有行連接成一個(gè)長(zhǎng)句子
cleaned_lines = []
for text_line in text_content.split('\n'):
text_line = clean_text(text_line)
# 只保留足夠長(zhǎng)的行
if len(text_line) >= min_line_length:
cleaned_lines.append(text_line)
# 使用換行符連接各行,保留段落結(jié)構(gòu)
final_text = '\n'.join(cleaned_lines)
# 最終檢查:確保清洗后的文本仍然足夠長(zhǎng)
if final_text and len(final_text) >= min_article_length:
# 文章之間用兩個(gè)換行符分隔
f_out.write(final_text + '\n\n')
total_articles += 1
else:
skipped_articles += 1
except json.JSONDecodeError:
logging.warning(f"無(wú)法解析 JSON,文件: {file_path},行號(hào): {line_num}")
except Exception as e:
logging.error(f"處理文件 {file_path} 第 {line_num} 行時(shí)出錯(cuò): {e}")
except Exception as e:
logging.error(f"打開(kāi)文件 {file_path} 時(shí)出錯(cuò): {e}")
logging.info(f" 所有維基百科文本已成功提取并清洗。")
logging.info(f" 總文章數(shù): {total_articles}")
logging.info(f" 跳過(guò)文章數(shù): {skipped_articles}")
logging.info(f" 文件已保存到: {output_file}")
def main():
parser = argparse.ArgumentParser(
        description="Extract and clean plain text from WikiExtractor JSON output.",
formatter_class=argparse.RawTextHelpFormatter
)
# 位置參數(shù) 1: 輸入目錄
parser.add_argument(
"extracted_directory",
type=str,
help="WikiExtractor 輸出的目錄路徑 (e.g., extracted_wiki_zh)"
)
# 位置參數(shù) 2: 輸出文件
parser.add_argument(
"output_filename",
type=str,
help="最終合并的純文本文件路徑 (e.g., cleaned_wiki.txt)"
)
# 可選參數(shù): 最小行長(zhǎng)
parser.add_argument(
"--min_line_length",
type=int,
default=20,
help="文章中單行文本必須達(dá)到的最小長(zhǎng)度,用于過(guò)濾噪音。默認(rèn)值: 20"
)
# 可選參數(shù): 最小文章長(zhǎng)度
parser.add_argument(
"--min_article_length",
type=int,
default=200,
help="文章最小長(zhǎng)度,用于過(guò)濾短文章和重定向頁(yè)。默認(rèn)值: 200"
)
args = parser.parse_args()
process_extracted_wiki(
args.extracted_directory,
args.output_filename,
args.min_line_length,
args.min_article_length
)
if __name__ == "__main__":
    main()
2025-10-01 11:10:58,772 - INFO - 找到 5 個(gè)文件等待處理。
正在提取維基文本: 100%|██████████| 5/5 [00:33<00:00, 6.78s/it]
2025-10-01 11:11:32,681 - INFO - 所有維基百科文本已成功提取。總文章數(shù): 628093。文件已保存到 data/cleaned_wiki_full.txt
04. Training the Tokenizer
We use SentencePiece to train the tokenizer. Here we train a 16k vocabulary; you could also train a 32k one. The code and process are as follows:
Note:
This step produces the files
workdir/spm_wiki_16k.model
workdir/spm_wiki_16k.vocab
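Once training completes, the tokenizer can be loaded and exercised in a few lines (a minimal usage sketch, assuming the file paths above):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print(sp.get_piece_size())                       # vocabulary size, ~16000
print(sp.encode("今天天氣不錯(cuò)", out_type=str))   # subword pieces
print(sp.encode("今天天氣不錯(cuò)", out_type=int))   # token IDs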
import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm
# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000
def get_corpus_size(input_file: str) -> int:
"""計(jì)算語(yǔ)料的總行數(shù)和文件大小"""
try:
file_size_bytes = os.path.getsize(input_file)
file_size_mb = file_size_bytes / (1024 * 1024)
print(f"語(yǔ)料文件大小: {file_size_mb:.2f} MB")
# 計(jì)算行數(shù)和總字符數(shù)
line_count = 0
total_chars = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="統(tǒng)計(jì)語(yǔ)料信息"):
line_count += 1
total_chars += len(line)
print(f"語(yǔ)料總行數(shù) (文章數(shù)): {line_count}")
print(f"總字符數(shù): {total_chars:,}")
print(f"平均每行字符數(shù): {total_chars / line_count:.1f}")
return file_size_bytes
except Exception as e:
print(f"警告:無(wú)法計(jì)算文件大小或行數(shù):{e}")
return 0
def train_spm_model(input_file: str,
model_prefix: str,
vocab_size: int,
model_type: str = 'bpe',
character_coverage: float = 0.9995):
"""
訓(xùn)練一個(gè)SentencePiece分詞器模型。
參數(shù):
input_file: 訓(xùn)練語(yǔ)料文件路徑
model_prefix: 輸出模型文件的前綴
vocab_size: 詞匯表大小
model_type: 分詞算法類型 ('bpe', 'unigram', 'char', 'word')
character_coverage: 字符覆蓋率 (0-1,通常 0.995-0.9995)
"""
if not os.path.exists(input_file):
print(f"錯(cuò)誤:輸入語(yǔ)料文件未找到:{input_file}")
sys.exit(1)
# 確保輸出目錄存在
output_dir = os.path.dirname(model_prefix)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
print(f"已創(chuàng)建輸出目錄: {output_dir}")
# 打印語(yǔ)料規(guī)模信息
print("\n=== 語(yǔ)料分析 ===")
get_corpus_size(input_file)
# 構(gòu)建訓(xùn)練參數(shù)
# 對(duì)于 1.5GB 語(yǔ)料,建議啟用 train_extremely_large_corpus=True 加速
train_params = {
'input': input_file,
'model_prefix': model_prefix,
'vocab_size': vocab_size,
'model_type': model_type,
'character_coverage': character_coverage,
'num_threads': 32, # 增加到32(最大化CPU利用)
'bos_id': 0,
'eos_id': 1,
'unk_id': 2,
'pad_id': -1,
'normalization_rule_name': 'identity',
'input_sentence_size': 2000000, # 5000000, # 增加到500萬(wàn)句子采樣
'train_extremely_large_corpus': True, # 必須啟用
'shuffle_input_sentence': True,
'seed_sentencepiece_size': 2000000, # 添加種子句子大小
'hard_vocab_limit': False, # 允許超過(guò)目標(biāo)詞匯量以獲得更好質(zhì)量
}
print("\n=== SentencePiece 訓(xùn)練參數(shù) ===")
for key, value in train_params.items():
print(f" {key}: {value}")
print("=" * 35)
print("\n正在訓(xùn)練 SentencePiece 模型...")
print(" (請(qǐng)稍候,進(jìn)度由 SentencePiece 輸出)\n")
try:
# 執(zhí)行訓(xùn)練
spm.SentencePieceTrainer.train(**train_params)
print("\n分詞器模型訓(xùn)練完成!")
print(f" 模型文件: {model_prefix}.model")
print(f" 詞匯表文件: {model_prefix}.vocab")
# 驗(yàn)證模型是否成功創(chuàng)建
if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
print(f"\n模型文件大小: {model_size_kb:.2f} KB")
# 加載模型進(jìn)行快速測(cè)試
print("\n進(jìn)行快速測(cè)試...")
sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
test_text = "這是一個(gè)分詞測(cè)試句子。"
tokens = sp.encode(test_text, out_type=str)
ids = sp.encode(test_text, out_type=int)
print(f" 測(cè)試文本: {test_text}")
print(f" 分詞結(jié)果: {tokens}")
print(f" Token IDs: {ids}")
else:
print("\n警告:模型文件生成失敗,請(qǐng)檢查輸入數(shù)據(jù)或參數(shù)")
except Exception as e:
print(f"\n訓(xùn)練過(guò)程出錯(cuò): {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
        description="Train a SentencePiece tokenizer model.",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"input_file",
type=str,
help="訓(xùn)練語(yǔ)料的路徑 (e.g., data/cleaned_wiki_full.txt)"
)
parser.add_argument(
"model_prefix",
type=str,
help="訓(xùn)練模型文件的輸出前綴 (e.g., workdir/spm_wiki)"
)
parser.add_argument(
"vocab_size",
type=int,
help="詞匯表大小 (e.g., 32000)"
)
parser.add_argument(
"--model_type",
type=str,
default='bpe',
choices=['bpe', 'unigram', 'char', 'word'],
help="分詞算法類型 (默認(rèn): bpe)"
)
parser.add_argument(
"--character_coverage",
type=float,
default=0.9995,
help="字符覆蓋率,范圍 [0-1]。對(duì)于小詞表(8K),建議用0.99或更小"
)
args = parser.parse_args()
print("\n" + "="*50)
print("SentencePiece 分詞器訓(xùn)練程序")
print("="*50)
print(f"輸入語(yǔ)料: {args.input_file}")
print(f"輸出模型前綴: {args.model_prefix}")
print(f"詞匯表大小: {args.vocab_size}")
print(f"分詞算法: {args.model_type}")
print(f"字符覆蓋率: {args.character_coverage}")
print("="*50 + "\n")
train_spm_model(
args.input_file,
args.model_prefix,
args.vocab_size,
args.model_type,
args.character_coverage
)
if __name__ == "__main__":
    main()
開(kāi)始訓(xùn)練SentencePiece分詞器...
輸入語(yǔ)料: data/cleaned_wiki_full.txt
輸出模型前綴: workdir/spm_wiki_16k
詞匯表大小: 16000
語(yǔ)料文件大小: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
語(yǔ)料總行數(shù) (文章數(shù)): 1256186
--- SentencePiece 訓(xùn)練參數(shù) ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1
------------------------------
正在啟動(dòng)訓(xùn)練... 請(qǐng)注意觀察 SentencePiece 自身的進(jìn)度輸出。
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with :
trainer_spec {
input: data/cleaned_wiki_full.txt
input_format:
model_prefix: workdir/spm_colinai_16000
model_type: BPE
vocab_size: 16000
self_test_sample_size: 0
character_coverage: 0.9995
input_sentence_size: 0
shuffle_input_sentence: 1
seed_sentencepiece_size: 1000000
shrinking_factor: 0.75
max_sentence_length: 4192
num_threads: 16
num_sub_iterations: 2
max_sentencepiece_length: 16
split_by_unicode_script: 1
split_by_number: 1
split_by_whitespace: 1
split_digits: 0
pretokenization_delimiter:
treat_whitespace_as_suffix: 0
allow_whitespace_only_pieces: 0
required_chars:
byte_fallback: 0
vocabulary_output_piece_score: 1
train_extremely_large_corpus: 0
seed_sentencepieces_file:
hard_vocab_limit: 1
use_all_vocab: 0
unk_id: 2
bos_id: 0
eos_id: 1
pad_id: -1
unk_piece: <unk>
bos_piece: <s>
eos_piece: </s>
pad_piece: <pad>
unk_surface: ?
enable_differential_privacy: 0
differential_privacy_noise_level: 0
differential_privacy_clipping_threshold: 0
}
normalizer_spec {
name: nmt_nfkc
add_dummy_prefix: 1
remove_extra_whitespaces: 1
escape_whitespaces: 1
normalization_rule_tsv:
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
.....
05. Converting Raw Text to Token ID Sequences
When preparing to train a large language model, converting the massive text corpus into a numeric format the model can consume is a crucial step. This step encodes the raw text into a sequence of integer token IDs. To avoid loading the entire large file into memory at once, the script reads the corpus in chunks of configurable size. All token IDs are then collected into a compact torch.int32 PyTorch tensor and stored directly as a .pt file. This gives the data an efficient format that a PyTorch DataLoader can load quickly, and the script also reports key statistics and performs integrity checks, making it a stable, high-throughput preprocessing step for building an LLM dataset.
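Once the script below has run, a quick round-trip check confirms the saved tensor is usable (a sketch assuming the 16k tokenizer and the output path used in this walkthrough):

import torch
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
ids = torch.load("workdir/wiki_tokens.pt")
print(ids.shape, ids.dtype)             # e.g. torch.Size([N]) torch.int32
print(sp.decode(ids[:50].tolist()))     # decode the first 50 tokens back to text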
import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os
import numpy as np
# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt
def preprocess(sp_model_path: str,
corpus_path: str,
output_path: str,
chunk_size_mb: int = 50):
"""
分塊讀取語(yǔ)料,編碼為 Token ID,并保存為 PyTorch 文件。
參數(shù):
sp_model_path: SentencePiece 模型文件路徑
corpus_path: 輸入語(yǔ)料文件路徑
output_path: 輸出 .pt 文件路徑
chunk_size_mb: 每次處理的文本大小(MB),默認(rèn) 50MB
"""
# 驗(yàn)證文件存在
if not os.path.exists(sp_model_path):
print(f"錯(cuò)誤:分詞器模型文件未找到: {sp_model_path}")
sys.exit(1)
if not os.path.exists(corpus_path):
print(f"錯(cuò)誤:語(yǔ)料文件未找到: {corpus_path}")
sys.exit(1)
# 加載分詞器
try:
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
vocab_size = sp.get_piece_size()
print(f" 分詞器加載成功")
print(f" 詞匯表大小: {vocab_size}")
print(f" 特殊 Token: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
except Exception as e:
print(f"加載分詞器失敗: {e}")
sys.exit(1)
# 確保輸出目錄存在
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
print(f"\n 開(kāi)始處理語(yǔ)料...")
print(f" 輸入文件: {corpus_path}")
print(f" 輸出文件: {output_path}")
print(f" 塊大小: {chunk_size_mb} MB\n")
# 計(jì)算總大小用于進(jìn)度條
total_bytes = os.path.getsize(corpus_path)
chunk_size_bytes = chunk_size_mb * 1024 * 1024
token_ids = []
tokens_processed = 0
chunks_processed = 0
try:
with open(corpus_path, 'r', encoding='utf-8') as f:
with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="? 編碼語(yǔ)料") as pbar:
while True:
chunk = f.read(chunk_size_bytes)
if not chunk:
break
# 直接編碼(cleaned_wiki_full.txt 已經(jīng)過(guò)清洗)
ids = sp.encode(chunk, out_type=int)
token_ids.extend(ids)
# 更新進(jìn)度條
bytes_read = len(chunk.encode('utf-8'))
pbar.update(bytes_read)
tokens_processed += len(ids)
chunks_processed += 1
# 定期顯示進(jìn)度信息
if chunks_processed % 10 == 0:
pbar.set_postfix({
'chunks': chunks_processed,
'tokens': f'{tokens_processed:,}'
})
print(f"\n 編碼完成")
print(f" 處理塊數(shù): {chunks_processed}")
print(f" 總 Token 數(shù): {tokens_processed:,}")
# 轉(zhuǎn)換為 PyTorch 張量
print(f"\n轉(zhuǎn)換為張量并保存...")
final_tensor = torch.tensor(token_ids, dtype=torch.int32)
print(f" 張量形狀: {final_tensor.shape}")
print(f" 張量大小: {final_tensor.numel():,}")
print(f" 數(shù)據(jù)類型: {final_tensor.dtype}")
print(f" 占用內(nèi)存: {final_tensor.numel() * 4 / (1024**3):.2f} GB")
# 驗(yàn)證 Token ID 范圍
min_id = final_tensor.min().item()
max_id = final_tensor.max().item()
print(f" Token ID 范圍: [{min_id}, {max_id}]")
if max_id >= vocab_size or min_id < 0:
print(f" 警告: 檢測(cè)到越界 Token ID!")
print(f" 詞匯表大小: {vocab_size}")
print(f" 最大 ID: {max_id}")
# 保存張量
torch.save(final_tensor, output_path)
file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
print(f"\nToken ID 已保存到 {output_path}")
print(f" 文件大小: {file_size_mb:.2f} MB")
# 驗(yàn)證保存的文件
print(f"\n驗(yàn)證保存的文件...")
loaded_tensor = torch.load(output_path)
print(f" 加載成功,形狀: {loaded_tensor.shape}")
print(f" 是否相同: {torch.equal(final_tensor, loaded_tensor)}")
print(f"\n? 預(yù)處理完成!")
except Exception as e:
print(f"\n處理過(guò)程中出錯(cuò): {e}")
import traceback
traceback.print_exc()
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
        description="Convert a cleaned text corpus into a binary token-ID file.",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"model_path",
type=str,
help="SentencePiece 模型文件路徑 (e.g., workdir/spm_wiki.model)"
)
parser.add_argument(
"corpus_path",
type=str,
help="輸入語(yǔ)料文件路徑 (e.g., data/cleaned_wiki_full.txt)"
)
parser.add_argument(
"output_path",
type=str,
help="輸出 Token ID 文件路徑 (e.g., workdir/wiki_tokens.pt)"
)
parser.add_argument(
"--chunk_size",
type=int,
default=50,
help="每次處理的文本大小(MB),默認(rèn) 50MB。更大的塊更快,但占用更多內(nèi)存。"
)
args = parser.parse_args()
print("\n" + "="*60)
print("數(shù)據(jù)預(yù)處理程序 - 文本到 Token ID")
print("="*60)
print(f"SentencePiece 模型: {args.model_path}")
print(f"輸入語(yǔ)料: {args.corpus_path}")
print(f"輸出文件: {args.output_path}")
print(f"塊大小: {args.chunk_size} MB")
print("="*60 + "\n")
preprocess(
args.model_path,
args.corpus_path,
args.output_path,
args.chunk_size
)
if __name__ == "__main__":
    main()
06. Pretraining the Model
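Before diving into the full training script below, a back-of-the-envelope count shows why this configuration lands near 17M parameters (a rough sketch that mirrors the Config class below; biases and LayerNorms are ignored, and the LM head is weight-tied to the embedding):

vocab, dim, layers, ffn = 16000, 384, 5, 384 * 4

emb = vocab * dim                        # token embedding (tied with lm_head)
attn = dim * (3 * dim) + dim * dim       # qkv projection + output projection
ffwd = dim * (ffn * 2) + ffn * dim       # GLU (doubled hidden) + down projection
total = emb + layers * (attn + ffwd)
print(f"~{total / 1e6:.1f}M parameters")   # prints ~17.9M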
"""
GPT 高性能訓(xùn)練腳本
"""
from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm
# ==================== 配置參數(shù) ====================
class Config:
BLOCK_SIZE = 512 #256
BATCH_SIZE = 32 #64
GRAD_ACCUM_STEPS = 4 #1
MODEL_DIM = 384 #256
N_LAYERS = 5 #2
NUM_HEADS = 6 #4
HEAD_DIM = MODEL_DIM // NUM_HEADS
FFN_DIM = MODEL_DIM * 4
VOCAB_SIZE = None
EPOCHS = 1
MAX_STEPS = 10000 # 此處根據(jù)自己的硬件和時(shí)間定義步數(shù)
WARMUP_STEPS = 500
LR = 1e-4
MIN_LR = 1e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
DROPOUT = 0.1
CHECKPOINT_EVERY = 5000
LOG_EVERY = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "./checkpoints"
LATEST_CHECKPOINT = "latest_checkpoint.pth"
NUM_WORKERS = 8
SEED = 42
# 啟用 bfloat16 (推薦用于現(xiàn)代 GPU)
DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
CFG = Config()
if CFG.DEVICE == 'cuda':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
# 檢查是否使用了 bfloat16
if CFG.DTYPE == torch.bfloat16:
print("使用 bfloat16 混合精度 (推薦)")
else:
print("使用 float16 混合精度")
# ==================== 工具函數(shù) ====================
def print_gpu_memory():
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024**3)
reserved = torch.cuda.memory_reserved() / (1024**3)
print(f"GPU顯存: {allocated:.2f}GB / {reserved:.2f}GB")
def set_seed(seed: int):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed(CFG.SEED)
# ==================== 數(shù)據(jù)集 ====================
class TextDataset(Dataset):
def __init__(self, token_ids: torch.Tensor, block_size: int):
self.ids = token_ids.long()
self.block_size = block_size
def __len__(self):
return max(0, self.ids.size(0) - self.block_size)
def __getitem__(self, idx):
x = self.ids[idx: idx + self.block_size]
y = self.ids[idx + 1: idx + 1 + self.block_size]
return x, y
# ==================== RoPE 位置編碼 ====================
class RotaryPositionalEmbedding(nn.Module):
"""RoPE 實(shí)現(xiàn)"""
def __init__(self, head_dim: int, max_seq_len: int = 2048):
super().__init__()
self.head_dim = head_dim
assert head_dim % 2 == 0, "head_dim must be even"
# 基頻:theta_i = 10000^(-2i/d)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len = max_seq_len
self._seq_len_cached = max_seq_len
self._cos_cached = None
self._sin_cached = None
self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
if seq_len == self._seq_len_cached and self._cos_cached is not None:
return
# m: (seq_len,), theta_i: (head_dim//2,)
m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", m, self.inv_freq) # (seq_len, head_dim//2)
# 構(gòu)造完整的旋轉(zhuǎn)矩陣(每個(gè)復(fù)數(shù)對(duì)重復(fù))
emb = torch.cat([freqs, freqs], dim=-1) # (seq_len, head_dim)
cos = emb.cos()[None, None, :, :] # (1, 1, seq_len, head_dim)
sin = emb.sin()[None, None, :, :] # (1, 1, seq_len, head_dim)
self._cos_cached = cos
self._sin_cached = sin
self._seq_len_cached = seq_len
def forward(self, seq_len: int, device: Optional[torch.device] = None):
if device is None:
device = self.inv_freq.device
self._update_cos_sin_cache(seq_len, device=device)
return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""應(yīng)用RoPE旋轉(zhuǎn)"""
# x: (B, H, T, D), cos/sin: (1, 1, T, D)
# 使用(x, y) -> (x*cos-y*sin, x*sin+y*cos)
return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""將向量旋轉(zhuǎn)90度"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# ==================== Flash Attention ====================
class FlashAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, T, C = x.shape
assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"
qkv = self.qkv(x)
qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # (B, H, T, D)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 應(yīng)用RoPE
cos, sin = self.rope(T, device=x.device)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
# 注意力計(jì)算
# 注意:這里如果使用 torch.nn.functional.scaled_dot_product_attention 配合 torch.compile 會(huì)更快
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if causal_mask is not None:
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
return self.out_proj(out)
# ==================== 前饋網(wǎng)絡(luò) ====================
class GLU(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.linear = nn.Linear(in_dim, out_dim * 2)
def forward(self, x):
x, gates = self.linear(x).chunk(2, dim=-1)
return x * torch.nn.functional.silu(gates)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
GLU(dim, hidden_dim),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# ==================== Transformer Block ====================
class TransformerBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
self.ln2 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, ffn_dim, dropout)
def forward(self, x, causal_mask=None):
x = x + self.attn(self.ln1(x), causal_mask)
x = x + self.ff(self.ln2(x))
return x
# ==================== GPT 模型(已移除 pos_emb) ====================
class GPTModel(nn.Module):
def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
tie_weights: bool = True):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, dim)
# self.pos_emb = nn.Embedding(block_size, dim) # 移除:與 RoPE 沖突
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
])
self.ln_final = nn.LayerNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
if tie_weights:
self.lm_head.weight = self.token_emb.weight
self.block_size = block_size
self.apply(self._init_weights)
n_params = sum(p.numel() for p in self.parameters())
print(f"模型參數(shù): {n_params/1e6:.2f}M")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(self, idx):
B, T = idx.shape
assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"
token_emb = self.token_emb(idx)
x = self.dropout(token_emb) # token embedding
causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
for block in self.blocks:
x = block(x, causal_mask)
x = self.ln_final(x)
logits = self.lm_head(x)
return logits
# ==================== 檢查點(diǎn)管理 ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
state = {
'step': step,
'loss': loss,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'config': config_dict,
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
}
if scaler is not None and hasattr(scaler, "state_dict"):
state['scaler_state_dict'] = scaler.state_dict()
if lr_scheduler is not None:
state['lr_scheduler_state_dict'] = {
'current_step': lr_scheduler.current_step,
'warmup_steps': lr_scheduler.warmup_steps,
'total_steps': lr_scheduler.total_steps,
'base_lr': lr_scheduler.base_lr,
'min_lr': lr_scheduler.min_lr,
}
torch.save(state, checkpoint_path)
try:
with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2)
except Exception:
pass
print(f" 檢查點(diǎn)已保存: {checkpoint_path} (step {step}, loss {loss:.4f})")
def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
if not os.path.exists(checkpoint_path):
return None
    checkpoint = torch.load(checkpoint_path, map_location=CFG.DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
try:
scaler.load_state_dict(checkpoint['scaler_state_dict'])
except Exception as e:
print(f"無(wú)法恢復(fù)scaler: {e}")
if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
try:
sched_state = checkpoint['lr_scheduler_state_dict']
lr_scheduler.current_step = sched_state['current_step']
lr_scheduler.warmup_steps = sched_state['warmup_steps']
lr_scheduler.total_steps = sched_state['total_steps']
lr_scheduler.base_lr = sched_state['base_lr']
lr_scheduler.min_lr = sched_state['min_lr']
except Exception as e:
print(f"無(wú)法恢復(fù)lr_scheduler: {e}")
torch.set_rng_state(checkpoint['torch_rng_state'])
if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
print(f"檢查點(diǎn)已加載: {checkpoint_path}")
print(f" Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
return checkpoint['step']
# ==================== 學(xué)習(xí)率調(diào)度器 ====================
class WarmupCosineScheduler:
def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
self.optimizer = optimizer
self.warmup_steps = max(0, int(warmup_steps))
self.total_steps = max(1, int(total_steps))
self.base_lr = base_lr
self.min_lr = min_lr
self.current_step = 0
def get_lr(self, step: int = None) -> float:
"""計(jì)算給定step的學(xué)習(xí)率(不修改optimizer)"""
if step is None:
step = self.current_step
if step < self.warmup_steps and self.warmup_steps > 0:
return self.base_lr * (step / float(self.warmup_steps))
else:
denom = max(1, (self.total_steps - self.warmup_steps))
progress = (step - self.warmup_steps) / denom
progress = min(1.0, max(0.0, progress))
return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
def step(self):
"""執(zhí)行一次步長(zhǎng)更新"""
lr = self.get_lr(self.current_step)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.current_step += 1
return lr
# ==================== 訓(xùn)練循環(huán) ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
# 檢測(cè)fused優(yōu)化器支持
fused = False
try:
fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
except Exception:
fused = False
optimizer = torch.optim.AdamW(
model.parameters(),
lr=CFG.LR,
betas=(0.9, 0.95),
weight_decay=CFG.WEIGHT_DECAY,
fused=fused
)
# 使用配置中的 DTYPE
scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
loss_fn = nn.CrossEntropyLoss()
total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)
model.train()
start_step = 0
best_loss = float('inf')
checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
if resume and os.path.exists(checkpoint_path):
loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
if loaded_step is not None:
start_step = loaded_step
global_step = start_step
grad_accum_counter = 0
accumulated_loss = 0.0
print("\n" + "="*60)
print("開(kāi)始訓(xùn)練...")
print("="*60)
print_gpu_memory()
print()
# 自動(dòng)選擇是否需要 scaler.scale()
use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)
for epoch in range(epochs):
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)
num_batches = 0
last_lr = None
for batch_idx, (xb, yb) in enumerate(pbar):
# 跳過(guò)已訓(xùn)練的批次 (如果從中間恢復(fù))
if global_step > start_step and batch_idx < (start_step % len(train_loader)):
continue
xb = xb.to(CFG.DEVICE, non_blocking=True)
yb = yb.to(CFG.DEVICE, non_blocking=True)
with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
logits = model(xb)
loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
loss_item = loss.item()
loss = loss / CFG.GRAD_ACCUM_STEPS
if use_scaler:
scaler.scale(loss).backward()
else:
loss.backward()
grad_accum_counter += 1
accumulated_loss += loss_item
num_batches += 1
# 這里的 global_step 計(jì)數(shù)是基于數(shù)據(jù)批次的,而不是優(yōu)化器步數(shù),用于日志和檢查點(diǎn)
# 真正的優(yōu)化器步數(shù)會(huì)在下面更新
# 梯度累積:達(dá)到閾值時(shí)執(zhí)行優(yōu)化步驟
if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:
# 優(yōu)化器步進(jìn) (這是真正的 global_step 增長(zhǎng)點(diǎn))
lr_scheduler.step() # 先更新 LR
global_step += 1 # 只有進(jìn)行了一次優(yōu)化器步進(jìn),才算一個(gè) global_step
if use_scaler:
scaler.unscale_(optimizer)
# 梯度裁剪 (在 unscale 后或非 AMP 模式下)
torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
grad_accum_counter = 0
last_lr = lr_scheduler.get_lr(global_step) # 獲取當(dāng)前步的LR
# 日志輸出
if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
# accumulated_loss 是累積的原始損失, num_batches 是累積的批次數(shù)
avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
pbar.set_postfix({
'step': global_step,
'loss': f'{avg_loss:.4f}',
'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
})
# 重置累積值以便計(jì)算下一個(gè) LOG_EVERY 間隔的平均損失
accumulated_loss = 0.0
num_batches = 0
# 保存檢查點(diǎn)
if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
# 使用上一個(gè)日志點(diǎn)計(jì)算的 avg_loss
current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item
config_dict = {
'vocab_size': CFG.VOCAB_SIZE,
'block_size': CFG.BLOCK_SIZE,
'model_dim': CFG.MODEL_DIM,
'n_layers': CFG.N_LAYERS,
'num_heads': CFG.NUM_HEADS,
'created_at': datetime.now().isoformat()
}
save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
torch.cuda.empty_cache()
if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
break
# 處理 epoch 結(jié)束時(shí)剩余的梯度 (如果 grad_accum_counter > 0)
if grad_accum_counter > 0:
if use_scaler:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
global_step += 1
grad_accum_counter = 0
# 此時(shí) pbar.total_loss 已累積
if num_batches > 0:
final_avg_loss = accumulated_loss / num_batches
else:
final_avg_loss = float('inf')
if final_avg_loss < best_loss:
best_loss = final_avg_loss
best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
torch.save(model.state_dict(), best_path)
print(f"最佳模型已保存 (loss: {best_loss:.4f})")
print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")
if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
break
print("\n訓(xùn)練完成!")
# ==================== 主函數(shù) ====================
def main():
if len(sys.argv) < 4:
print("用法: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
sys.exit(1)
sp_model_path, token_file_path, out_path = sys.argv[1:4]
resume = "--resume" in sys.argv
if not os.path.exists(token_file_path):
print(f" Token文件不存在: {token_file_path}")
sys.exit(1)
# 檢查 CFG.DTYPE 是否為 bfloat16 但環(huán)境不支持
if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
print("警告: bfloat16 不受當(dāng)前 CUDA 設(shè)備支持,自動(dòng)回退到 float16。")
CFG.DTYPE = torch.float16
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
CFG.VOCAB_SIZE = sp.get_piece_size()
print("="*60)
print("GPT 語(yǔ)言模型訓(xùn)練")
print("="*60)
print(f"分詞器: {sp_model_path}")
print(f"Token文件: {token_file_path}")
print(f"輸出模型: {out_path}")
print(f"設(shè)備: {CFG.DEVICE}")
print(f"\n模型配置:")
print(f" - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
print(f" - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
print(f" - MODEL_DIM: {CFG.MODEL_DIM}")
print(f" - N_LAYERS: {CFG.N_LAYERS}")
print(f" - NUM_HEADS: {CFG.NUM_HEADS}")
print(f"\n訓(xùn)練配置:")
print(f" - BATCH_SIZE: {CFG.BATCH_SIZE}")
print(f" - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
print(f" - 有效BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
print(f" - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
print("="*60)
print(f"\n加載Token文件: {token_file_path}")
ids = torch.load(token_file_path)
print(f"已加載 {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")
dataset = TextDataset(ids, CFG.BLOCK_SIZE)
del ids
torch.cuda.empty_cache()
# 改進(jìn):?jiǎn)⒂?shuffle=True 進(jìn)行預(yù)訓(xùn)練
num_workers = CFG.NUM_WORKERS
try:
train_loader = DataLoader(
dataset,
batch_size=CFG.BATCH_SIZE,
shuffle=True, # 啟用 Shuffle
pin_memory=(CFG.DEVICE == "cuda"),
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False
)
except Exception as e:
print(f"DataLoader錯(cuò)誤: {e}, 改用num_workers=0")
train_loader = DataLoader(
dataset,
batch_size=CFG.BATCH_SIZE,
shuffle=True,
pin_memory=(CFG.DEVICE == "cuda"),
num_workers=0
)
model = GPTModel(
CFG.VOCAB_SIZE,
CFG.BLOCK_SIZE,
dim=CFG.MODEL_DIM,
num_layers=CFG.N_LAYERS,
num_heads=CFG.NUM_HEADS,
ffn_dim=CFG.FFN_DIM,
dropout=CFG.DROPOUT
).to(CFG.DEVICE)
# 嘗試編譯(容錯(cuò))
try:
model = torch.compile(model, mode='reduce-overhead')
print("已啟用 torch.compile() 加速")
except Exception as e:
print(f"跳過(guò) torch.compile(): {e}")
train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)
torch.save(model.state_dict(), out_path)
print(f"\n最終模型已保存到 {out_path}")
print_gpu_memory()
if __name__ == "__main__":
    main()
07. Testing Model Inference
import torch
from torch import nn
import sentencepiece as spm
from typing import Optional
# ==================== 配置參數(shù) (必須與訓(xùn)練時(shí)一致) ====================
# 使用與訓(xùn)練腳本中完全相同的配置
class Config:
BLOCK_SIZE = 512
# 模型尺寸參數(shù) (必須與訓(xùn)練時(shí)一致)
MODEL_DIM = 384
N_LAYERS = 5
NUM_HEADS = 6
HEAD_DIM = MODEL_DIM // NUM_HEADS
FFN_DIM = MODEL_DIM * 4
VOCAB_SIZE = None
# 推理設(shè)置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 推理通常使用 float32 獲得最佳兼容性和精度
DTYPE = torch.float32
CFG = Config()
# ==================== RoPE 位置編碼 (與訓(xùn)練腳本保持一致) ====================
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, head_dim: int, max_seq_len: int = 2048):
super().__init__()
self.head_dim = head_dim
assert head_dim % 2 == 0, "head_dim must be even"
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len = max_seq_len
self._seq_len_cached = max_seq_len
self._cos_cached = None
self._sin_cached = None
self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
if seq_len == self._seq_len_cached and self._cos_cached is not None:
return
m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", m, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]
self._cos_cached = cos
self._sin_cached = sin
self._seq_len_cached = seq_len
def forward(self, seq_len: int, device: Optional[torch.device] = None):
if device is None:
device = self.inv_freq.device
self._update_cos_sin_cache(seq_len, device=device)
return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# ==================== Attention, FFN, Block, Model (與訓(xùn)練腳本保持一致) ====================
class FlashAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
# 推理時(shí)通常不使用 Dropout,但模型結(jié)構(gòu)需要保持一致
self.attn_dropout = nn.Dropout(attn_dropout)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, T, C = x.shape
qkv = self.qkv(x)
qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # (B, H, T, D)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
cos, sin = self.rope(T, device=x.device)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# 注意:在推理時(shí),通常使用 KV-Cache,這里簡(jiǎn)化為完整計(jì)算
if T > 1: # 僅在序列長(zhǎng)度大于 1 時(shí)應(yīng)用 mask
causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
# 推理時(shí)禁用 dropout
# attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
# 必須保持與訓(xùn)練腳本中完全相同的 nn.Sequential 結(jié)構(gòu)
self.net = nn.Sequential(
GLU(dim, hidden_dim),
nn.Dropout(dropout), # net.1: Dropout (必須保留,占位)
nn.Linear(hidden_dim, dim), # net.2: Linear (與訓(xùn)練時(shí)一致)
nn.Dropout(dropout), # net.3: Dropout (必須保留,占位)
)
def forward(self, x):
# 在推理時(shí), model.eval() 會(huì)自動(dòng)禁用所有 nn.Dropout 層,但結(jié)構(gòu)不變
return self.net(x)
# 確保 GLU 的定義如下(與訓(xùn)練時(shí)一致):
class GLU(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
# GLU 內(nèi)部只有一個(gè) nn.Linear
self.linear = nn.Linear(in_dim, out_dim * 2)
def forward(self, x):
x, gates = self.linear(x).chunk(2, dim=-1)
return x * torch.nn.functional.silu(gates)
class TransformerBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
self.ln2 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, ffn_dim, dropout)
def forward(self, x, causal_mask=None):
x = x + self.attn(self.ln1(x), causal_mask)
x = x + self.ff(self.ln2(x))
return x
class GPTModel(nn.Module):
def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0, # 推理時(shí) dropout=0
tie_weights: bool = True):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, dim)
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
])
self.ln_final = nn.LayerNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
if tie_weights:
self.lm_head.weight = self.token_emb.weight
self.block_size = block_size
def forward(self, idx):
B, T = idx.shape
token_emb = self.token_emb(idx)
x = token_emb # 推理時(shí)不使用 dropout
causal_mask = None # Attention 模塊內(nèi)部處理 Causal Mask
for block in self.blocks:
x = block(x, causal_mask)
x = self.ln_final(x)
logits = self.lm_head(x)
return logits
# ==================== 推理和生成函數(shù) ====================
@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor,
prompt: str, max_new_tokens: int, temperature: float = 0.8,
top_k: int = 50):
model.eval()
device = CFG.DEVICE
# 1. 編碼輸入
input_ids = sp.encode_as_ids(prompt)
if not input_ids:
return "無(wú)法編碼輸入。"
# 將輸入轉(zhuǎn)換為模型期望的格式 (B, T)
x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
# 2. 循環(huán)生成
for _ in range(max_new_tokens):
# 裁剪輸入以適應(yīng)模型的 BLOCK_SIZE
# 在實(shí)際部署中,這里應(yīng)該使用 KV Cache,但此處簡(jiǎn)化為完整前向傳播
idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]
# 獲取 logits
logits = model(idx_cond)
# 只取最后一個(gè)時(shí)間步的 logits
logits = logits[:, -1, :]
# 應(yīng)用溫度縮放
logits = logits / temperature
# 3. Top-K 采樣
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# 計(jì)算概率并采樣
probs = torch.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# 4. 停止條件
# 檢查是否生成了 EOS token (假設(shè) </s> 是 ID 3, 請(qǐng)根據(jù)您的分詞器調(diào)整)
# 默認(rèn)使用 SentencePiece 的 <eos> ID
if idx_next.item() == sp.eos_id():
break
# 將新生成的 token 添加到序列中
x = torch.cat((x, idx_next), dim=1)
# 檢查是否達(dá)到最大序列長(zhǎng)度 (防止溢出)
if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
break
# 5. 解碼輸出
output_ids = x[0].tolist()
# 查找輸入 prompt 的長(zhǎng)度,只解碼新生成的 token
start_index = len(input_ids)
return sp.decode_ids(output_ids[start_index:])
# ==================== 主執(zhí)行函數(shù) ====================
def main_infer(sp_model_path: str, model_weights_path: str):
print("="*50)
print(f"GPT 模型推理模式")
print(f"設(shè)備: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
print("="*50)
# 1. 加載分詞器
try:
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
CFG.VOCAB_SIZE = sp.get_piece_size()
print(f"加載分詞器成功,VOCAB_SIZE: {CFG.VOCAB_SIZE}")
except Exception as e:
print(f"無(wú)法加載分詞器模型 {sp_model_path}: {e}")
return
# 2. 實(shí)例化模型
model = GPTModel(
vocab_size=CFG.VOCAB_SIZE,
block_size=CFG.BLOCK_SIZE,
dim=CFG.MODEL_DIM,
num_layers=CFG.N_LAYERS,
num_heads=CFG.NUM_HEADS,
ffn_dim=CFG.FFN_DIM,
dropout=0.0 # 推理時(shí)設(shè)置 dropout 為 0
).to(CFG.DEVICE).to(CFG.DTYPE)
# 3. 加載權(quán)重
try:
# 檢查是否是 torch.compile 后的狀態(tài)字典
        weights = torch.load(model_weights_path, map_location=CFG.DEVICE)
# 如果權(quán)重是 DDP 或 torch.compile 包裝后的,需要解包
if any(k.startswith('_orig_mod.') for k in weights.keys()):
weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}
model.load_state_dict(weights, strict=True)
print(f"成功加載模型權(quán)重: {model_weights_path}")
except Exception as e:
print(f"無(wú)法加載或匹配模型權(quán)重: {e}")
# 如果加載失敗,打印預(yù)期鍵和實(shí)際鍵,方便調(diào)試
# print("\n--- 預(yù)期模型鍵 (部分) ---")
# print(list(model.state_dict().keys())[:5])
# print("\n--- 載入權(quán)重鍵 (部分) ---")
# print(list(weights.keys())[:5])
return
# 4. 進(jìn)入交互循環(huán)
print("\n--- 進(jìn)入交互模式 ---")
print(f"輸入 'exit' 或 'quit' 退出。")
print(f"輸入 'config' 查看當(dāng)前生成參數(shù)。")
print("----------------------")
max_tokens = 100
temperature = 0.8
top_k = 50
while True:
try:
prompt = input(">>> 輸入提示詞: ")
if prompt.lower() in ['exit', 'quit']:
break
if prompt.lower() == 'config':
print(f" Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
new_max = input(" 設(shè)置 Max Tokens (回車跳過(guò)): ")
new_temp = input(" 設(shè)置 Temperature (回車跳過(guò)): ")
new_k = input(" 設(shè)置 Top K (回車跳過(guò)): ")
if new_max: max_tokens = int(new_max)
if new_temp: temperature = float(new_temp)
if new_k: top_k = int(new_k)
continue
if not prompt.strip():
continue
print("生成中...")
# 執(zhí)行生成
output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)
print(f"--- 模型回復(fù) ---\n{output.strip()}")
print("----------------")
except KeyboardInterrupt:
print("\n退出生成...")
break
except Exception as e:
print(f"發(fā)生錯(cuò)誤: {e}")
if __name__ == "__main__":
import sys
if len(sys.argv) != 3:
print("用法: python infer.py <spm模型路徑> <模型權(quán)重文件路徑>")
# 示例用法 (請(qǐng)根據(jù)您的實(shí)際文件路徑修改):
# python infer.py tokenizer.model final_model.pth
sys.exit(1)
sp_model_path = sys.argv[1]
model_weights_path = sys.argv[2]
    main_infer(sp_model_path, model_weights_path)
We can see that the model can more or less predict the next token after our input; because the parameter count and number of training steps are small, the output is still largely gibberish!
Summary
In this article we walked through data preparation, data cleaning, tokenizer training, model training, and inference. Run the code step by step and you will end up with your own 17M-parameter small model. Next time we'll scale up the parameters, continue training, and then do supervised fine-tuning.