全参数微调方法

什么是全参数微调?

全参数微调(Full Fine-tuning)是指对预训练模型的所有参数进行再训练的微调方法。与PEFT参数高效微调不同,这种方法会更新模型的每一个权重参数。

核心特点

优势

  • 适应性最强:能够充分适应目标任务的特点
  • 效果最佳:通常能获得最好的任务性能
  • 灵活性高:可以进行深度的模型改造
  • 理论简单:直接的端到端训练方式

劣势

  • 显存需求大:需要A100、V100等高端GPU
  • 训练耗时长:所有参数都需要更新
  • 部署复杂:需要保存完整的模型权重
  • 过拟合风险:在小数据集上容易过拟合

适用场景

推荐使用全参数微调的情况

  1. 充足的计算资源:有高端GPU集群支持
  2. 大规模数据集:有足够的标注数据防止过拟合
  3. 极致性能要求:对模型效果有最高要求
  4. 领域差异巨大:目标领域与预训练数据差异很大

不推荐使用的情况

  1. 资源受限:只有消费级GPU或云端资源有限
  2. 小数据集:标注数据不足1000条
  3. 快速实验:需要快速验证想法的场景
  4. 多任务部署:需要同时支持多个任务

技术实现要点

学习率设置

# 通常使用较小的学习率
learning_rate = 1e-5  # 比预训练时小1-2个数量级
 
# 可以使用不同层的差异化学习率
optimizer = AdamW([
    {'params': model.embeddings.parameters(), 'lr': 1e-6},
    {'params': model.encoder.parameters(), 'lr': 5e-6},
    {'params': model.classifier.parameters(), 'lr': 1e-5}
])

梯度累积

# 当batch size受限时使用梯度累积
accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss = model(batch)
    loss = loss / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

权重衰减

# 防止过拟合的正则化
optimizer = AdamW(
    model.parameters(),
    lr=1e-5,
    weight_decay=0.01  # L2正则化
)

训练策略

渐进式解冻(Progressive Unfreezing)

# 第一阶段:只训练分类头
for param in model.base_model.parameters():
    param.requires_grad = False
 
# 第二阶段:解冻最后几层
for param in model.base_model.encoder.layer[-2:].parameters():
    param.requires_grad = True
 
# 第三阶段:全参数微调
for param in model.parameters():
    param.requires_grad = True

差异化学习率(Discriminative Learning Rates)

不同层使用不同的学习率:

  • 底层:较小学习率(保持预训练特征)
  • 中层:中等学习率
  • 顶层:较大学习率(快速适应新任务)

显存优化技巧

混合精度训练

from torch.cuda.amp import autocast, GradScaler
 
scaler = GradScaler()
 
with autocast():
    outputs = model(inputs)
    loss = criterion(outputs, labels)
 
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

梯度检查点

# 牺牲计算时间换取显存
model.gradient_checkpointing_enable()

DeepSpeed ZeRO

# 使用DeepSpeed进行显存优化
from deepspeed import zero
# 配置ZeRO Stage 2或Stage 3

与其他方法的对比

方法参数更新量显存需求训练时间效果部署复杂度
全参数微调100%很高最佳
PEFT参数高效微调1-5%良好
Prompt工程0%最低一般最低

最佳实践建议

数据准备

  1. 数据质量:确保标注数据的准确性
  2. 数据平衡:避免类别不平衡问题
  3. 数据增强:通过回译、同义词替换等方式扩充数据

训练监控

  1. 早停机制:监控验证集性能,防止过拟合
  2. 学习率调度:使用余弦退火或线性衰减
  3. 梯度裁剪:防止梯度爆炸

模型保存

# 保存完整模型
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss,
}, 'full_finetuned_model.pth')

相关概念