
Pretraining a 17M-Parameter GPT Model from Scratch


Hi everyone, I'm 寫代碼的中年人 (the middle-aged coder)!

Today we will pretrain a model on open Chinese data. Follow the steps below to build your own pretrained model from scratch.

All code and data resources for this article:

https://github.com/ColinAIAPP/MoiraiLM

01. What a Pretrained Model Is

A pretrained model is simply a model that has already been trained on a massive corpus. It has absorbed the basic regularities, structure, and semantics of language, and can then be applied to all kinds of downstream tasks: writing, translation, question answering, classification, code generation, and so on.

So what does "pretraining" actually learn? For a language model (LLM), the pretraining task is usually to predict the next token.
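To make this concrete, here is a minimal sketch of the next-token objective (illustrative only, not this article's training code): the inputs are a token sequence, the targets are the same sequence shifted by one position, and the loss is cross-entropy over the vocabulary.

import torch
import torch.nn.functional as F

# Toy next-token prediction: the model sees [t0 .. t_{n-1}] and must predict [t1 .. t_n].
ids = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])   # a tiny "sentence" of token ids
x, y = ids[:-1], ids[1:]                        # inputs and one-step-shifted targets

vocab_size = 16
logits = torch.randn(len(x), vocab_size)        # stand-in for a model's output, shape (T, V)
loss = F.cross_entropy(logits, y)               # the pretraining loss
print(f"next-token loss: {loss.item():.4f}")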

Next, let's build a 17M-parameter pretrained model step by step.

02. Preparing the Data

The first requirement for building a language model is a high-quality data source. For Chinese, the open Chinese Wikipedia dataset is an ideal starting point. It contains millions of encyclopedia entries covering history, culture, science and technology, and more, amounting to several GB of plain text. It is open and free, and the latest XML dump can be downloaded from Wikipedia's official dump page.

To unpack and process this dump we use the wikiextractor tool.

Install it with:

pip install wikiextractor

Extraction command:

python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
Here zhwiki-20250920-pages-articles-multistream.xml.bz2 is the dump file name; -b 1G splits the output into files of roughly 1 GB each, and --json emits one JSON object per article.
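With --json, each output file contains one article per line as a JSON object, roughly of the following shape (illustrative; the exact field set depends on the wikiextractor version):

{"id": "...", "revid": "...", "url": "...", "title": "...", "text": "..."}

A sample of the extraction log: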

INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
...
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
...
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)

03. Cleaning the Data

The extraction step leaves us with a directory of JSON files; next we clean the text out of them.

Note:

This step produces the file data/cleaned_wiki_full.txt

import os
import json
import logging
import argparse
import re
from tqdm import tqdm


# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s')


# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300




def clean_text(text: str) -> str:
    """
    Deep-clean a line of text.
    Removes wiki-specific markup, reference marks, HTML tags, dates, and numbers.
    """
    # Remove wiki links: [[link|display]] or [[link]]
    text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)

    # Remove reference marks: [1], [2], [ref], etc.
    text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)

    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Remove date formats (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy, etc.)
    text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)

    # Remove years (1000-2999)
    text = re.sub(r'\b[12]\d{3}\b', '', text)

    # Remove bare numbers (including decimals)
    text = re.sub(r'\b\d+\.?\d*\b', '', text)

    # Collapse runs of spaces (but keep single spaces)
    text = re.sub(r' +', ' ', text)

    # Strip leading/trailing whitespace
    text = text.strip()

    return text




def process_extracted_wiki(extracted_dir: str, 
                          output_file: str, 
                          min_line_length: int = 20, 
                          min_article_length: int = 200):
    """
    Read the JSON files produced by WikiExtractor, extract and clean the text,
    and write everything to a single output file.

    Args:
        extracted_dir: directory produced by WikiExtractor
        output_file: path of the merged plain-text output file
        min_line_length: minimum length of a single line, to filter noise (default: 20)
        min_article_length: minimum article length, to filter short articles (default: 200)
    """
    if not os.path.isdir(extracted_dir):
        logging.error(f"Input directory does not exist: {extracted_dir}")
        return

    total_articles = 0
    skipped_articles = 0

    # First pass: collect the list of files to process
    file_list = []
    for root, dirs, files in os.walk(extracted_dir):
        for file_name in files:
            # Only process files starting with 'wiki_' (the WikiExtractor naming scheme)
            if file_name.startswith('wiki_'):
                file_list.append(os.path.join(root, file_name))

    total_files = len(file_list)
    logging.info(f"Found {total_files} files to process.")

    if total_files == 0:
        logging.warning(f"No 'wiki_' files found in {extracted_dir}. Please check the path.")
        return

    # Second pass: process the files and write the output
    with open(output_file, 'w', encoding='utf-8') as f_out:
        # Wrap the file list in tqdm to show progress
        for file_path in tqdm(file_list, desc="Extracting wiki text"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f_in:
                    for line_num, line in enumerate(f_in, 1):
                        try:
                            article = json.loads(line)
                            text_content = article.get('text', '').strip()

                            # --- Cleaning and filtering ---

                            # 1. Drop very short articles; they are usually noise or redirect pages
                            if len(text_content) < min_article_length:
                                skipped_articles += 1
                                continue

                            # 2. Clean line by line, dropping short lines and extra whitespace.
                            # Keep the line structure instead of joining everything into one long sentence.
                            cleaned_lines = []
                            for text_line in text_content.split('\n'):
                                text_line = clean_text(text_line)
                                # Keep only sufficiently long lines
                                if len(text_line) >= min_line_length:
                                    cleaned_lines.append(text_line)

                            # Join lines with newlines to preserve paragraph structure
                            final_text = '\n'.join(cleaned_lines)

                            # Final check: make sure the cleaned text is still long enough
                            if final_text and len(final_text) >= min_article_length:
                                # Separate articles with two newlines
                                f_out.write(final_text + '\n\n')
                                total_articles += 1
                            else:
                                skipped_articles += 1

                        except json.JSONDecodeError:
                            logging.warning(f"Failed to parse JSON in {file_path}, line {line_num}")
                        except Exception as e:
                            logging.error(f"Error processing {file_path} line {line_num}: {e}")

            except Exception as e:
                logging.error(f"Error opening {file_path}: {e}")

    logging.info("All Wikipedia text extracted and cleaned.")
    logging.info(f"   Articles kept: {total_articles}")
    logging.info(f"   Articles skipped: {skipped_articles}")
    logging.info(f"   Output saved to: {output_file}")




def main():
    parser = argparse.ArgumentParser(
        description="Extract and clean plain text from WikiExtractor JSON output.",
        formatter_class=argparse.RawTextHelpFormatter
    )

    # Positional argument 1: input directory
    parser.add_argument(
        "extracted_directory",
        type=str,
        help="Directory produced by WikiExtractor (e.g., extracted_wiki_zh)"
    )

    # Positional argument 2: output file
    parser.add_argument(
        "output_filename",
        type=str,
        help="Path of the merged plain-text output file (e.g., cleaned_wiki.txt)"
    )

    # Optional: minimum line length
    parser.add_argument(
        "--min_line_length",
        type=int,
        default=20,
        help="Minimum length a single line must reach, to filter noise. Default: 20"
    )

    # Optional: minimum article length
    parser.add_argument(
        "--min_article_length",
        type=int,
        default=200,
        help="Minimum article length, to filter short articles and redirect pages. Default: 200"
    )

    args = parser.parse_args()

    process_extracted_wiki(
        args.extracted_directory, 
        args.output_filename, 
        args.min_line_length,
        args.min_article_length
    )


if __name__ == "__main__":
    main()
2025-10-01 11:10:58,772 - INFO - Found 5 files to process.
Extracting wiki text: 100%|██████████| 5/5 [00:33<00:00,  6.78s/it]
2025-10-01 11:11:32,681 - INFO - All Wikipedia text extracted. Articles kept: 628093. Output saved to data/cleaned_wiki_full.txt
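As a quick sanity check of the cleaning rules, you can call clean_text directly (assuming the script above is saved as scripts/clean_wiki_text.py and is importable):

from clean_wiki_text import clean_text

raw = "[[数学|數(shù)學(xué)]]是研究数量、结构的学科<b>注</b>,可追溯到 1900-01-01。[1]"
print(clean_text(raw))
# The wiki link is unwrapped to its display text; the HTML tags, reference mark,
# date, and year are stripped; and surrounding whitespace is trimmed.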

04. Training the Tokenizer

We train the tokenizer with SentencePiece. Here we train a 16k vocabulary; you can also train a 32k one. The code and process follow below.

Note:

This step produces the files

workdir/spm_wiki_16k.model

workdir/spm_wiki_16k.vocab
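Given the script's command-line interface, the 16k run from this article would be launched roughly like this (inferred from the argument parser below; adjust paths to your layout):

python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki_16k 16000

The training script: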

import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm


# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000




def get_corpus_size(input_file: str) -> int:
    """Report the corpus line count and file size."""
    try:
        file_size_bytes = os.path.getsize(input_file)
        file_size_mb = file_size_bytes / (1024 * 1024)
        print(f"Corpus file size: {file_size_mb:.2f} MB")

        # Count lines and total characters
        line_count = 0
        total_chars = 0
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Scanning corpus"):
                line_count += 1
                total_chars += len(line)

        print(f"Total lines (articles): {line_count}")
        print(f"Total characters: {total_chars:,}")
        print(f"Average characters per line: {total_chars / line_count:.1f}")
        return file_size_bytes
    except Exception as e:
        print(f"Warning: could not compute file size or line count: {e}")
        return 0




def train_spm_model(input_file: str, 
                    model_prefix: str, 
                    vocab_size: int,
                    model_type: str = 'bpe',
                    character_coverage: float = 0.9995):
    """
    Train a SentencePiece tokenizer model.

    Args:
        input_file: training corpus path
        model_prefix: prefix for the output model files
        vocab_size: vocabulary size
        model_type: tokenization algorithm ('bpe', 'unigram', 'char', 'word')
        character_coverage: character coverage (0-1, typically 0.995-0.9995)
    """
    if not os.path.exists(input_file):
        print(f"Error: input corpus file not found: {input_file}")
        sys.exit(1)

    # Make sure the output directory exists
    output_dir = os.path.dirname(model_prefix)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    # Print corpus statistics
    print("\n=== Corpus analysis ===")
    get_corpus_size(input_file)

    # Build the training parameters.
    # For a ~1.5GB corpus, enabling train_extremely_large_corpus=True is recommended.
    train_params = {
        'input': input_file,
        'model_prefix': model_prefix,
        'vocab_size': vocab_size,
        'model_type': model_type,
        'character_coverage': character_coverage,
        'num_threads': 32,        # raised to 32 to maximize CPU utilization
        'bos_id': 0,
        'eos_id': 1,
        'unk_id': 2,
        'pad_id': -1,
        'normalization_rule_name': 'identity',
        'input_sentence_size': 2000000,          # sample 2M sentences
        'train_extremely_large_corpus': True,    # required for large corpora
        'shuffle_input_sentence': True,
        'seed_sentencepiece_size': 2000000,      # seed sentence count
        'hard_vocab_limit': False,               # allow exceeding the target vocab for better quality
    }

    print("\n=== SentencePiece training parameters ===")
    for key, value in train_params.items():
        print(f"  {key}: {value}")
    print("=" * 35)

    print("\nTraining the SentencePiece model...")
    print("   (please wait; progress is reported by SentencePiece itself)\n")

    try:
        # Run the training
        spm.SentencePieceTrainer.train(**train_params)

        print("\nTokenizer training finished!")
        print(f"   Model file: {model_prefix}.model")
        print(f"   Vocabulary file: {model_prefix}.vocab")

        # Verify that the model files were created
        if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
            model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
            print(f"\nModel file size: {model_size_kb:.2f} KB")

            # Load the model for a quick test
            print("\nRunning a quick test...")
            sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")

            test_text = "這是一個(gè)分詞測(cè)試句子。"
            tokens = sp.encode(test_text, out_type=str)
            ids = sp.encode(test_text, out_type=int)

            print(f" Test text: {test_text}")
            print(f" Pieces: {tokens}")
            print(f" Token IDs: {ids}")
        else:
            print("\nWarning: model files were not created; check the input data or parameters")

    except Exception as e:
        print(f"\nTraining failed: {e}")
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        description="Train a tokenizer model with SentencePiece.",
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument(
        "input_file",
        type=str,
        help="Path to the training corpus (e.g., data/cleaned_wiki_full.txt)"
    )

    parser.add_argument(
        "model_prefix",
        type=str,
        help="Output prefix for the model files (e.g., workdir/spm_wiki)"
    )

    parser.add_argument(
        "vocab_size",
        type=int,
        help="Vocabulary size (e.g., 32000)"
    )

    parser.add_argument(
        "--model_type",
        type=str,
        default='bpe',
        choices=['bpe', 'unigram', 'char', 'word'],
        help="Tokenization algorithm (default: bpe)"
    )

    parser.add_argument(
        "--character_coverage",
        type=float,
        default=0.9995,
        help="Character coverage in [0-1]. For small vocabularies (8K), 0.99 or lower is recommended"
    )

    args = parser.parse_args()

    print("\n" + "="*50)
    print("SentencePiece tokenizer training")
    print("="*50)
    print(f"Input corpus: {args.input_file}")
    print(f"Output model prefix: {args.model_prefix}")
    print(f"Vocabulary size: {args.vocab_size}")
    print(f"Algorithm: {args.model_type}")
    print(f"Character coverage: {args.character_coverage}")
    print("="*50 + "\n")

    train_spm_model(
        args.input_file, 
        args.model_prefix, 
        args.vocab_size,
        args.model_type,
        args.character_coverage
    )


if __name__ == "__main__":
    main()
Starting SentencePiece tokenizer training...
   Input corpus: data/cleaned_wiki_full.txt
   Output model prefix: workdir/spm_wiki_16k
   Vocabulary size: 16000
   Corpus file size: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
   Total lines (articles): 1256186


--- SentencePiece training parameters ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1 
------------------------------


Starting training... watch SentencePiece's own progress output.
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1 
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: data/cleaned_wiki_full.txt
  input_format: 
  model_prefix: workdir/spm_colinai_16000
  model_type: BPE
  vocab_size: 16000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 0
  eos_id: 1
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv: 
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
.....
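Once training finishes, the tokenizer can be loaded and exercised with a few lines (a minimal sketch using the sentencepiece API; the model path follows the note above):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print(sp.get_piece_size())                         # vocabulary size, ~16000
pieces = sp.encode("今天天氣不錯(cuò)", out_type=str)    # subword pieces
ids = sp.encode("今天天氣不錯(cuò)", out_type=int)      # token ids
print(pieces, ids)
print(sp.decode(ids))                              # round-trips back to the text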

05. Converting Raw Text to Token ID Sequences

When preparing to train a large language model, the massive text corpus must be converted into a numeric format the model can consume. In this step we encode the raw text into sequences of integer token IDs. To avoid loading the whole file into memory at once, the script reads the corpus in chunks of a configurable size. All token IDs are then collected into a compact torch.int32 PyTorch tensor and stored as a .pt file. This format is efficient for a PyTorch DataLoader to read later, and the script also reports key statistics and performs integrity checks, making it a stable, high-performance preprocessing step for building an LLM dataset.
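The int32 choice matters for memory: each token costs 4 bytes, so a corpus of N tokens occupies 4N bytes. A back-of-the-envelope check (the token count here is a made-up example, not this corpus's actual size):

n_tokens = 400_000_000                   # hypothetical corpus size in tokens
print(n_tokens * 4 / 1024**3, "GiB")     # int32: ~1.49 GiB; int64 would double this

The preprocessing script: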

import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os
import numpy as np


# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt




def preprocess(sp_model_path: str, 
               corpus_path: str, 
               output_path: str,
               chunk_size_mb: int = 50):
    """
    Read the corpus in chunks, encode it to token IDs, and save a PyTorch file.

    Args:
        sp_model_path: SentencePiece model file path
        corpus_path: input corpus file path
        output_path: output .pt file path
        chunk_size_mb: text chunk size per read in MB (default 50)
    """
    # Validate input files
    if not os.path.exists(sp_model_path):
        print(f"Error: tokenizer model file not found: {sp_model_path}")
        sys.exit(1)

    if not os.path.exists(corpus_path):
        print(f"Error: corpus file not found: {corpus_path}")
        sys.exit(1)

    # Load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        vocab_size = sp.get_piece_size()
        print(f"   Tokenizer loaded")
        print(f"   Vocabulary size: {vocab_size}")
        print(f"   Special tokens: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    print(f"\n Processing corpus...")
    print(f"   Input file: {corpus_path}")
    print(f"   Output file: {output_path}")
    print(f"   Chunk size: {chunk_size_mb} MB\n")


    # Total size, for the progress bar
    total_bytes = os.path.getsize(corpus_path)
    chunk_size_bytes = chunk_size_mb * 1024 * 1024

    token_ids = []
    tokens_processed = 0
    chunks_processed = 0

    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="Encoding corpus") as pbar:


                while True:
                    chunk = f.read(chunk_size_bytes)
                    if not chunk:
                        break

                    # Encode directly (cleaned_wiki_full.txt is already cleaned)
                    ids = sp.encode(chunk, out_type=int)
                    token_ids.extend(ids)

                    # Update the progress bar
                    bytes_read = len(chunk.encode('utf-8'))
                    pbar.update(bytes_read)

                    tokens_processed += len(ids)
                    chunks_processed += 1

                    # Periodically report progress
                    if chunks_processed % 10 == 0:
                        pbar.set_postfix({
                            'chunks': chunks_processed,
                            'tokens': f'{tokens_processed:,}'
                        })

        print(f"\n Encoding finished")
        print(f"   Chunks processed: {chunks_processed}")
        print(f"   Total tokens: {tokens_processed:,}")


        # Convert to a PyTorch tensor
        print(f"\nConverting to a tensor and saving...")
        final_tensor = torch.tensor(token_ids, dtype=torch.int32)

        print(f"   Tensor shape: {final_tensor.shape}")
        print(f"   Tensor size: {final_tensor.numel():,}")
        print(f"   Dtype: {final_tensor.dtype}")
        print(f"   Memory: {final_tensor.numel() * 4 / (1024**3):.2f} GB")

        # Validate the token ID range
        min_id = final_tensor.min().item()
        max_id = final_tensor.max().item()
        print(f"   Token ID range: [{min_id}, {max_id}]")

        if max_id >= vocab_size or min_id < 0:
            print(f"   Warning: out-of-range token IDs detected!")
            print(f"   Vocabulary size: {vocab_size}")
            print(f"   Max ID: {max_id}")

        # Save the tensor
        torch.save(final_tensor, output_path)
        file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
        print(f"\nToken IDs saved to {output_path}")
        print(f"   File size: {file_size_mb:.2f} MB")

        # Verify the saved file
        print(f"\nVerifying the saved file...")
        loaded_tensor = torch.load(output_path)
        print(f"   Loaded, shape: {loaded_tensor.shape}")
        print(f"   Identical: {torch.equal(final_tensor, loaded_tensor)}")

        print(f"\nPreprocessing complete!")


    except Exception as e:
        print(f"\nError during processing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        description="Convert a cleaned text corpus to a binary token-ID file.",
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument(
        "model_path",
        type=str,
        help="SentencePiece model file path (e.g., workdir/spm_wiki.model)"
    )

    parser.add_argument(
        "corpus_path",
        type=str,
        help="Input corpus file path (e.g., data/cleaned_wiki_full.txt)"
    )

    parser.add_argument(
        "output_path",
        type=str,
        help="Output token-ID file path (e.g., workdir/wiki_tokens.pt)"
    )

    parser.add_argument(
        "--chunk_size",
        type=int,
        default=50,
        help="Text chunk size per read in MB (default 50). Larger chunks are faster but use more memory."
    )

    args = parser.parse_args()

    print("\n" + "="*60)
    print("Data preprocessing - text to token IDs")
    print("="*60)
    print(f"SentencePiece model: {args.model_path}")
    print(f"Input corpus: {args.corpus_path}")
    print(f"Output file: {args.output_path}")
    print(f"Chunk size: {args.chunk_size} MB")
    print("="*60 + "\n")

    preprocess(
        args.model_path,
        args.corpus_path,
        args.output_path,
        args.chunk_size
    )


if __name__ == "__main__":
    main()
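After the script runs, a quick way to verify the artifact is to load it and decode a few tokens (a sketch; the paths follow the commands above):

import torch
import sentencepiece as spm

ids = torch.load("workdir/wiki_tokens.pt")
sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print(ids.shape, ids.dtype)              # expect a 1-D torch.int32 tensor
print(sp.decode(ids[:50].tolist()))      # first 50 tokens decoded back to text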

06. Pretraining the Model

"""
GPT high-performance training script
"""


from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm


# ==================== Configuration ====================
class Config:
    BLOCK_SIZE = 512 #256
    BATCH_SIZE = 32 #64
    GRAD_ACCUM_STEPS = 4 #1
    MODEL_DIM = 384 #256
    N_LAYERS = 5 #2
    NUM_HEADS = 6 #4
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4
    VOCAB_SIZE = None

    EPOCHS = 1
    MAX_STEPS = 10000 # set the step budget to match your hardware and time
    WARMUP_STEPS = 500
    LR = 1e-4
    MIN_LR = 1e-5
    WEIGHT_DECAY = 0.01
    GRAD_CLIP = 1.0
    DROPOUT = 0.1

    CHECKPOINT_EVERY = 5000
    LOG_EVERY = 100

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    CHECKPOINT_DIR = "./checkpoints"
    LATEST_CHECKPOINT = "latest_checkpoint.pth"

    NUM_WORKERS = 8
    SEED = 42

    # Use bfloat16 when available (recommended on modern GPUs)
    DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16


CFG = Config()


if CFG.DEVICE == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()
    # Report which mixed-precision dtype is in use
    if CFG.DTYPE == torch.bfloat16:
        print("Using bfloat16 mixed precision (recommended)")
    else:
        print("Using float16 mixed precision")


# ==================== Utilities ====================
def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        reserved = torch.cuda.memory_reserved() / (1024**3)
        print(f"GPU memory: {allocated:.2f}GB / {reserved:.2f}GB")


def set_seed(seed: int):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(CFG.SEED)


# ==================== Dataset ====================
class TextDataset(Dataset):
    def __init__(self, token_ids: torch.Tensor, block_size: int):
        self.ids = token_ids.long()
        self.block_size = block_size


    def __len__(self):
        return max(0, self.ids.size(0) - self.block_size)


    def __getitem__(self, idx):
        x = self.ids[idx: idx + self.block_size]
        y = self.ids[idx + 1: idx + 1 + self.block_size]
        return x, y


# ==================== RoPE positional encoding ====================
class RotaryPositionalEmbedding(nn.Module):
    """RoPE implementation"""
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"

        # Base frequencies: theta_i = 10000^(-2i/d)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return


        # m: (seq_len,), theta_i: (head_dim//2,)
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)  # (seq_len, head_dim//2)


        # Build the full rotation table (each complex pair is repeated)
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, head_dim)


        cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim)


        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the RoPE rotation"""
    # x: (B, H, T, D), cos/sin: (1, 1, T, D)
    # Rotation: (x, y) -> (x*cos - y*sin, x*sin + y*cos)
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate each (x1, x2) pair by 90 degrees: (x1, x2) -> (-x2, x1)"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Flash Attention ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5


        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"


        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        # Apply RoPE
        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # Attention
        # Note: torch.nn.functional.scaled_dot_product_attention plus torch.compile would be faster here
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if causal_mask is not None:
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


# ==================== Feed-forward network ====================
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim * 2)


    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


    def forward(self, x):
        return self.net(x)


# ==================== Transformer Block ====================
class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


# ==================== GPT model (learned pos_emb removed) ====================
class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        # self.pos_emb = nn.Embedding(block_size, dim)  # removed: conflicts with RoPE
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size
        self.apply(self._init_weights)


        n_params = sum(p.numel() for p in self.parameters())
        print(f"Model parameters: {n_params/1e6:.2f}M")


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"


        token_emb = self.token_emb(idx)


       
        x = self.dropout(token_emb) # token embedding


        causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits


# ==================== Checkpoint management ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
    os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    state = {
        'step': step,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }


    if scaler is not None and hasattr(scaler, "state_dict"):
        state['scaler_state_dict'] = scaler.state_dict()


    if lr_scheduler is not None:
        state['lr_scheduler_state_dict'] = {
            'current_step': lr_scheduler.current_step,
            'warmup_steps': lr_scheduler.warmup_steps,
            'total_steps': lr_scheduler.total_steps,
            'base_lr': lr_scheduler.base_lr,
            'min_lr': lr_scheduler.min_lr,
        }


    torch.save(state, checkpoint_path)


    try:
        with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)
    except Exception:
        pass


    print(f" Checkpoint saved: {checkpoint_path} (step {step}, loss {loss:.4f})")


def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
    if not os.path.exists(checkpoint_path):
        return None


    checkpoint = torch.load(checkpoint_path, map_location=CFG.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


    if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
        try:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        except Exception as e:
            print(f"Could not restore scaler: {e}")


    if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
        try:
            sched_state = checkpoint['lr_scheduler_state_dict']
            lr_scheduler.current_step = sched_state['current_step']
            lr_scheduler.warmup_steps = sched_state['warmup_steps']
            lr_scheduler.total_steps = sched_state['total_steps']
            lr_scheduler.base_lr = sched_state['base_lr']
            lr_scheduler.min_lr = sched_state['min_lr']
        except Exception as e:
            print(f"Could not restore lr_scheduler: {e}")


    torch.set_rng_state(checkpoint['torch_rng_state'])
    if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])


    print(f"Checkpoint loaded: {checkpoint_path}")
    print(f"    Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
    return checkpoint['step']


# ==================== Learning-rate scheduler ====================
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
        self.optimizer = optimizer
        self.warmup_steps = max(0, int(warmup_steps))
        self.total_steps = max(1, int(total_steps))
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.current_step = 0


    def get_lr(self, step: int = None) -> float:
        """Compute the learning rate for a given step (without touching the optimizer)"""
        if step is None:
            step = self.current_step


        if step < self.warmup_steps and self.warmup_steps > 0:
            return self.base_lr * (step / float(self.warmup_steps))
        else:
            denom = max(1, (self.total_steps - self.warmup_steps))
            progress = (step - self.warmup_steps) / denom
            progress = min(1.0, max(0.0, progress))
            return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


    def step(self):
        """Apply one scheduler step"""
        lr = self.get_lr(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr


# ==================== Training loop ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
    # Detect fused-optimizer support
    fused = False
    try:
        fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
    except Exception:
        fused = False


    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CFG.LR,
        betas=(0.9, 0.95),
        weight_decay=CFG.WEIGHT_DECAY,
        fused=fused
    )


    # Use the DTYPE from the config
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
    loss_fn = nn.CrossEntropyLoss()


    total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
    lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)


    model.train()
    start_step = 0
    best_loss = float('inf')


    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    if resume and os.path.exists(checkpoint_path):
        loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
        if loaded_step is not None:
            start_step = loaded_step


    global_step = start_step
    grad_accum_counter = 0
    accumulated_loss = 0.0


    print("\n" + "="*60)
    print("Starting training...")
    print("="*60)
    print_gpu_memory()
    print()


    # Decide whether scaler.scale() is needed
    use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)


    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)
        num_batches = 0
        last_lr = None


        for batch_idx, (xb, yb) in enumerate(pbar):
            # Skip batches that were already trained (when resuming mid-epoch)
            if global_step > start_step and batch_idx < (start_step % len(train_loader)):
                 continue


            xb = xb.to(CFG.DEVICE, non_blocking=True)
            yb = yb.to(CFG.DEVICE, non_blocking=True)


            with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
                loss_item = loss.item()
                loss = loss / CFG.GRAD_ACCUM_STEPS


            if use_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()


            grad_accum_counter += 1
            accumulated_loss += loss_item
            num_batches += 1
            # global_step counts optimizer steps (one per completed gradient-accumulation
            # cycle) and drives the logging and checkpointing below.


            # Gradient accumulation: take an optimizer step once the threshold is reached
            if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:

                # Optimizer step (this is where global_step really advances)
                lr_scheduler.step()  # update the LR first
                global_step += 1  # one optimizer step = one global_step


                if use_scaler:
                    scaler.unscale_(optimizer)


                # Gradient clipping (after unscale, or directly in non-AMP mode)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()


                optimizer.zero_grad()
                grad_accum_counter = 0
                last_lr = lr_scheduler.get_lr(global_step)  # LR for the current step


            # Logging
            if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
                # accumulated_loss is the summed raw loss; num_batches the batch count
                avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
                pbar.set_postfix({
                    'step': global_step,
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
                })
                # Reset the accumulators for the next LOG_EVERY window
                accumulated_loss = 0.0
                num_batches = 0




            # Save a checkpoint
            if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
                # Use the average loss since the last logging point
                current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item


                config_dict = {
                    'vocab_size': CFG.VOCAB_SIZE,
                    'block_size': CFG.BLOCK_SIZE,
                    'model_dim': CFG.MODEL_DIM,
                    'n_layers': CFG.N_LAYERS,
                    'num_heads': CFG.NUM_HEADS,
                    'created_at': datetime.now().isoformat()
                }
                save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
                torch.cuda.empty_cache()


            if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
                break


        # Flush any leftover gradients at the end of the epoch (if grad_accum_counter > 0)
        if grad_accum_counter > 0:
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()


            optimizer.zero_grad()
            lr_scheduler.step()
            global_step += 1
            grad_accum_counter = 0




        # Average loss over the batches since the last logging reset
        if num_batches > 0:
            final_avg_loss = accumulated_loss / num_batches
        else:
            final_avg_loss = float('inf')




        if final_avg_loss < best_loss:
            best_loss = final_avg_loss
            best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_path)
            print(f"Best model saved (loss: {best_loss:.4f})")

        print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")


        if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
            break


    print("\nTraining finished!")


# ==================== Main ====================
def main():
    if len(sys.argv) < 4:
        print("Usage: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
        sys.exit(1)


    sp_model_path, token_file_path, out_path = sys.argv[1:4]
    resume = "--resume" in sys.argv


    if not os.path.exists(token_file_path):
        print(f" Token file not found: {token_file_path}")
        sys.exit(1)

    # Fall back if CFG.DTYPE is bfloat16 but the device does not support it
    if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("Warning: bfloat16 is not supported on this CUDA device; falling back to float16.")
        CFG.DTYPE = torch.float16


    sp = spm.SentencePieceProcessor(model_file=sp_model_path)
    CFG.VOCAB_SIZE = sp.get_piece_size()


    print("="*60)
    print("GPT language model training")
    print("="*60)
    print(f"Tokenizer: {sp_model_path}")
    print(f"Token file: {token_file_path}")
    print(f"Output model: {out_path}")
    print(f"Device: {CFG.DEVICE}")
    print(f"\nModel configuration:")
    print(f"    - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    print(f"    - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
    print(f"    - MODEL_DIM: {CFG.MODEL_DIM}")
    print(f"    - N_LAYERS: {CFG.N_LAYERS}")
    print(f"    - NUM_HEADS: {CFG.NUM_HEADS}")
    print(f"\nTraining configuration:")
    print(f"    - BATCH_SIZE: {CFG.BATCH_SIZE}")
    print(f"    - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
    print(f"    - Effective BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
    print(f"    - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
    print("="*60)

    print(f"\nLoading token file: {token_file_path}")
    ids = torch.load(token_file_path)
    print(f"Loaded {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")


    dataset = TextDataset(ids, CFG.BLOCK_SIZE)
    del ids
    torch.cuda.empty_cache()


    # Improvement: enable shuffle=True for pretraining
    num_workers = CFG.NUM_WORKERS
    try:
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,  # enable shuffling
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=num_workers,
            persistent_workers=True if num_workers > 0 else False
        )
    except Exception as e:
        print(f"DataLoader error: {e}, falling back to num_workers=0")
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=0
        )


    model = GPTModel(
        CFG.VOCAB_SIZE,
        CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=CFG.DROPOUT
    ).to(CFG.DEVICE)


    # Try to compile the model (with a fallback)
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print("torch.compile() enabled")
    except Exception as e:
        print(f"Skipping torch.compile(): {e}")


    train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)


    torch.save(model.state_dict(), out_path)
    print(f"\nFinal model saved to {out_path}")
    print_gpu_memory()


if __name__ == "__main__":
    main()
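Per the usage string in main, training is launched with the tokenizer model, the token file, and an output path, e.g. python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth (add --resume to continue from the latest checkpoint). It is also worth knowing how much data the default budget covers; a quick calculation from the Config values:

tokens_per_step = 32 * 4 * 512      # BATCH_SIZE * GRAD_ACCUM_STEPS * BLOCK_SIZE = 65,536
print(tokens_per_step * 10_000)     # ≈ 6.6e8 tokens seen over MAX_STEPS = 10,000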

07. Testing the Model with Inference

import torch
from torch import nn
import sentencepiece as spm
from typing import Optional


# ==================== Configuration (must match training) ====================
# Use exactly the same configuration as in the training script
class Config:
    BLOCK_SIZE = 512
    # Model size parameters (must match training)
    MODEL_DIM = 384
    N_LAYERS = 5
    NUM_HEADS = 6
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4 
    VOCAB_SIZE = None

    # Inference settings
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # float32 usually gives the best compatibility and precision for inference
    DTYPE = torch.float32 


CFG = Config()


# ==================== RoPE positional encoding (same as the training script) ====================
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Attention, FFN, Block, Model (same as the training script) ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Dropout is unused at inference time, but the module must stay for structural parity
        self.attn_dropout = nn.Dropout(attn_dropout) 
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)


        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Note: real inference would normally use a KV cache; this is a full forward pass for simplicity
        if T > 1:  # only apply the mask when the sequence is longer than 1
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        # dropout disabled at inference time
        # attn = self.attn_dropout(attn) 
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Must keep exactly the same nn.Sequential structure as the training script
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),        # net.1: Dropout (kept as a placeholder)
            nn.Linear(hidden_dim, dim), # net.2: Linear (same as training)
            nn.Dropout(dropout),        # net.3: Dropout (kept as a placeholder)
        )

    def forward(self, x):
        # model.eval() disables all nn.Dropout layers at inference; the structure is unchanged
        return self.net(x)


# GLU must be defined exactly as at training time:
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # GLU contains a single nn.Linear
        self.linear = nn.Linear(in_dim, out_dim * 2)

    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0,  # dropout=0 at inference
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size


    def forward(self, idx):
        B, T = idx.shape
        token_emb = self.token_emb(idx)
        x = token_emb  # no dropout at inference time

        causal_mask = None  # the attention module builds the causal mask itself
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits




# ==================== Inference and generation ====================


@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor, 
                  prompt: str, max_new_tokens: int, temperature: float = 0.8, 
                  top_k: int = 50):


    model.eval()
    device = CFG.DEVICE

    # 1. Encode the prompt
    input_ids = sp.encode_as_ids(prompt)
    if not input_ids:
        return "Could not encode the input."

    # Shape the input as the model expects: (B, T)
    x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)


    # 2. Generation loop
    for _ in range(max_new_tokens):
        # Crop the input to fit the model's BLOCK_SIZE
        # In real deployments this is where a KV cache would go; here we simplify to a full forward pass
        idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]


        # Get the logits
        logits = model(idx_cond)


        # Keep only the logits of the last time step
        logits = logits[:, -1, :] 


        # Apply temperature scaling
        logits = logits / temperature


        # 3. Top-k sampling
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')


        # Compute probabilities and sample
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)


        # 4. Stop condition
        # Check whether an EOS token was generated (e.g. </s> might be ID 3; adjust to your tokenizer)
        # By default we use SentencePiece's <eos> ID
        if idx_next.item() == sp.eos_id():
            break


        # Append the newly generated token to the sequence
        x = torch.cat((x, idx_next), dim=1)


        # Stop once the sequence reaches the maximum length (overflow guard)
        if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
            break


    # 5. Decode the output
    output_ids = x[0].tolist()
    # Use the prompt length so that only the newly generated tokens are decoded
    start_index = len(input_ids)


    return sp.decode_ids(output_ids[start_index:])
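Outside the interactive loop below, the function can also be called directly. A sketch reusing the example file names from the usage hint at the bottom of the script (tokenizer.model and final_model.pth are placeholders for your trained artifacts):

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
model = GPTModel(vocab_size=sp.get_piece_size(), block_size=CFG.BLOCK_SIZE).to(CFG.DEVICE).to(CFG.DTYPE)
model.load_state_dict(torch.load("final_model.pth", map_location=CFG.DEVICE))
print(generate_text(model, sp, "北京是", max_new_tokens=50, temperature=0.8, top_k=50))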




# ==================== Main entry point ====================


def main_infer(sp_model_path: str, model_weights_path: str):
    print("="*50)
    print("GPT model inference mode")
    print(f"Device: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
    print("="*50)


    # 1. Load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        CFG.VOCAB_SIZE = sp.get_piece_size()
        print(f"Tokenizer loaded, VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    except Exception as e:
        print(f"Failed to load tokenizer model {sp_model_path}: {e}")
        return


    # 2. Instantiate the model
    model = GPTModel(
        vocab_size=CFG.VOCAB_SIZE,
        block_size=CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=0.0 # set dropout to 0 at inference
    ).to(CFG.DEVICE).to(CFG.DTYPE)


    # 3. Load the weights
    try:
        # The state dict may come from a torch.compile'd model
        weights = torch.load(model_weights_path, map_location=CFG.DEVICE)


        # Unwrap the keys if the weights were saved from a DDP- or torch.compile-wrapped model
        if any(k.startswith('_orig_mod.') for k in weights.keys()):
            weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}


        model.load_state_dict(weights, strict=True)
        print(f"Model weights loaded successfully: {model_weights_path}")
    except Exception as e:
        print(f"Failed to load or match model weights: {e}")
        # If loading fails, print the expected vs. loaded keys to ease debugging
        # print("\n--- Expected model keys (partial) ---")
        # print(list(model.state_dict().keys())[:5])
        # print("\n--- Loaded weight keys (partial) ---")
        # print(list(weights.keys())[:5])
        return


    # 4. Enter the interactive loop
    print("\n--- Entering interactive mode ---")
    print("Type 'exit' or 'quit' to quit.")
    print("Type 'config' to view the current generation parameters.")
    print("----------------------")


    max_tokens = 100
    temperature = 0.8
    top_k = 50


    while True:
        try:
            prompt = input(">>> Enter a prompt: ")


            if prompt.lower() in ['exit', 'quit']:
                break


            if prompt.lower() == 'config':
                print(f"  Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
                new_max = input("  Set Max Tokens (Enter to skip): ")
                new_temp = input("  Set Temperature (Enter to skip): ")
                new_k = input("  Set Top K (Enter to skip): ")


                if new_max: max_tokens = int(new_max)
                if new_temp: temperature = float(new_temp)
                if new_k: top_k = int(new_k)
                continue


            if not prompt.strip():
                continue


            print("生成中...")


            # Run generation
            output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)


            print(f"--- 模型回復(fù) ---\n{output.strip()}")
            print("----------------")


        except KeyboardInterrupt:
            print("\nExiting generation...")
            break
        except Exception as e:
            print(f"Error: {e}")




if __name__ == "__main__":
    import sys


    if len(sys.argv) != 3:
        print("Usage: python infer.py <spm model path> <model weights path>")
        # Example usage (adjust to your actual file paths):
        # python infer.py tokenizer.model final_model.pth
        sys.exit(1)


    sp_model_path = sys.argv[1]
    model_weights_path = sys.argv[2]


    main_infer(sp_model_path, model_weights_path)

We can see that the model does roughly learn to predict the next token of the input, but because both the parameter count and the number of training steps are small, the output is largely incoherent.

Summary

In this post we went through data preparation, data cleaning, tokenizer training, model training, and inference. Run the code step by step and you will end up with a small 17M-parameter model. In a follow-up we will scale up the parameters, train further, and then apply supervised fine-tuning.
