跳转至

指令微调(SFT)

SFT(Supervised Fine-Tuning,监督微调)是把基座模型变成"听话的助手"的核心步骤。本篇我们使用 HuggingFace 的 TRL 库,它是专门为 SFT/DPO/PPO 设计的高级封装,比手写 Trainer 简洁很多。

1. SFT 的本质

预训练模型只会"接龙"——给一段文字它会续写。SFT 教会模型:

  • 看到问题(role: user)应该回答(role: assistant
  • 遵守 system 设定的规则
  • 按照训练数据中的风格、格式回答

技术上,SFT 就是普通的语言模型训练,但只对回答部分计算 loss——这是关键。

2. Loss Mask 的重要性

考虑这个训练样本:

<user>什么是 Python?</user>
<assistant>Python 是一种高级编程语言。</assistant>

如果像普通 LM 训练那样计算所有 token 的 loss,模型也会学习生成"什么是 Python?"这样的提问,这不是我们要的。

正确做法:将问题部分的 labels 设为 -100,只让回答部分参与 loss 计算。

-100 是 PyTorch CrossEntropyLoss 的默认忽略值。

3. 使用 TRL 的 SFTTrainer

TRL 的 SFTTrainer 自动处理 loss mask、对话模板等繁琐细节。安装:

pip install trl transformers datasets accelerate peft bitsandbytes

4. 完整 SFT 训练脚本

新建 train_sft.py

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
DATA_PATH = "sft_data.jsonl"
OUTPUT_DIR = "./output_sft"

# === 加载模型(QLoRA)===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)

# === LoRA 配置 ===
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# === 加载数据集(messages 格式)===
# 数据格式示例:
# {"messages": [
#   {"role": "system", "content": "你是一位友善的助手"},
#   {"role": "user", "content": "你好"},
#   {"role": "assistant", "content": "你好!"}
# ]}
dataset = load_dataset("json", data_files=DATA_PATH, split="train")
dataset = dataset.train_test_split(test_size=0.05, seed=42)

# === SFT 训练参数 ===
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    report_to="tensorboard",

    # SFT 专属参数
    max_seq_length=2048,                # 最大序列长度
    packing=False,                      # 是否将多个短样本打包到同一序列(提速但会跨样本)
    dataset_text_field=None,            # messages 格式不需要
)

# === 创建 Trainer 并训练 ===
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)

trainer.train()
trainer.save_model(OUTPUT_DIR)

启动:

python train_sft.py

SFTTrainer 的强大之处:

  • 自动应用对话模板:识别 messages 字段并调用 apply_chat_template
  • 自动 loss mask:只对 assistant 回复计算 loss
  • 自动数据 collator:动态 padding
  • 自动多轮支持:处理多轮对话样本

5. 三种数据格式支持

SFTTrainer 支持三种输入格式:

Messages 格式(推荐)

{
  "messages": [
    {"role": "system", "content": "..."},
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}

Prompt-Completion 格式

{
  "prompt": "什么是 Python?",
  "completion": "Python 是一种高级编程语言。"
}

纯文本格式

{"text": "用户:你好\n助手:你好!"}

需要在 SFTConfig 中指定 dataset_text_field="text"

6. Loss Mask 实现细节

如果你不用 TRL,想自己实现 SFT loss mask:

def preprocess_with_mask(example, tokenizer, max_length=2048):
    """SFT 数据预处理,只对 assistant 部分计算 loss"""
    messages = example["messages"]

    input_ids = []
    labels = []

    for i, msg in enumerate(messages):
        # 分别对每条消息编码
        single_msg = [msg]
        is_last_user = (
            i + 1 < len(messages) and messages[i + 1]["role"] == "assistant"
        )

        text = tokenizer.apply_chat_template(
            single_msg,
            tokenize=False,
            add_generation_prompt=False,
        )
        ids = tokenizer.encode(text, add_special_tokens=False)
        input_ids.extend(ids)

        if msg["role"] == "assistant":
            # 对 assistant 部分计算 loss
            labels.extend(ids)
        else:
            # system / user 部分不计算 loss
            labels.extend([-100] * len(ids))

    # 截断
    input_ids = input_ids[:max_length]
    labels = labels[:max_length]

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }

实际项目推荐直接用 TRL,逻辑更稳定。

7. 多轮对话训练

多轮对话样本会让模型学会维持对话上下文。同一条样本中所有的 assistant 轮次都会参与 loss 计算

{
  "messages": [
    {"role": "user", "content": "我叫小明"},
    {"role": "assistant", "content": "你好,小明!"},
    {"role": "user", "content": "你还记得我叫什么?"},
    {"role": "assistant", "content": "你叫小明。"}
  ]
}

这种数据特别有助于模型学会长期记忆和上下文理解

8. SFT 调参经验

学习率

方法 推荐学习率
全量 SFT 1e-5 ~ 5e-5
LoRA / QLoRA SFT 1e-4 ~ 5e-4

Epoch

  • 指令跟随任务:1-3 epoch 通常足够
  • 小数据集(< 1K):可以训 5-10 epoch
  • 大数据集(> 100K):1 epoch 可能就够,避免过拟合

Packing

packing=True 会将多个短样本拼接到同一个序列中,训练速度提升 2-3 倍。但有两个注意事项:

  1. 需要正确处理 attention mask,避免跨样本"串味"
  2. 多轮对话样本不建议打包

9. 评估训练效果

简单的人工测试:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
LORA_DIR = "./output_sft"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, quantization_config=bnb_config, device_map="auto"
)
model = PeftModel.from_pretrained(model, LORA_DIR)
model.eval()

test_questions = [
    "你好,介绍一下你自己。",
    "用 Python 写一个二分查找。",
    "如何评价《百年孤独》?",
]

for q in test_questions:
    messages = [{"role": "user", "content": q}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    print(f"Q: {q}\nA: {response}\n{'-'*80}")

下一篇会讲更系统的评估方法。

10. 常见坑

坑 1:忘了用 -Instruct 版本

如果你的目标是对话模型,强烈建议从带 -Instruct / -Chat 后缀的版本继续 SFT,不要从 base 版本开始。base 版本对话能力是 0,需要海量数据才能从零训出来。

坑 2:数据格式错误

messages 中必须用 role: system/user/assistant,不能是其他角色名。role 顺序必须是 user → assistant 交替。

坑 3:System Prompt 不一致

如果训练数据里 80% 用同一个 system prompt,模型会"绑定"到这个 prompt。换其他 system 时效果可能下降。

坑 4:训练后模型变蠢

通常是过拟合或灾难性遗忘。解决:

  • 减小 epoch
  • 在数据中混入通用对话样本(10-20%)
  • 减小 LoRA rlora_alpha

总结

  • SFT 的核心是只对 assistant 回答计算 loss
  • TRL 的 SFTTrainer 自动处理对话模板、loss mask、动态 padding
  • 推荐 messages 格式输入,自动应用模型的 chat template
  • LoRA 学习率比全量大 10 倍,1-3 epoch 通常足够
  • -Instruct 版本开始 SFT,从 base 开始要海量数据
  • packing=True 可以大幅提速,但多轮对话不建议打包

评论