全参数微调方法
什么是全参数微调?
全参数微调(Full Fine-tuning)是指对预训练模型的所有参数进行再训练的微调方法。与PEFT参数高效微调不同,这种方法会更新模型的每一个权重参数。
核心特点
优势
- 适应性最强:能够充分适应目标任务的特点
- 效果最佳:通常能获得最好的任务性能
- 灵活性高:可以进行深度的模型改造
- 理论简单:直接的端到端训练方式
劣势
- 显存需求大:需要A100、V100等高端GPU
- 训练耗时长:所有参数都需要更新
- 部署复杂:需要保存完整的模型权重
- 过拟合风险:在小数据集上容易过拟合
适用场景
推荐使用全参数微调的情况
- 充足的计算资源:有高端GPU集群支持
- 大规模数据集:有足够的标注数据防止过拟合
- 极致性能要求:对模型效果有最高要求
- 领域差异巨大:目标领域与预训练数据差异很大
不推荐使用的情况
- 资源受限:只有消费级GPU或云端资源有限
- 小数据集:标注数据不足1000条
- 快速实验:需要快速验证想法的场景
- 多任务部署:需要同时支持多个任务
技术实现要点
学习率设置
# 通常使用较小的学习率
learning_rate = 1e-5 # 比预训练时小1-2个数量级
# 可以使用不同层的差异化学习率
optimizer = AdamW([
{'params': model.embeddings.parameters(), 'lr': 1e-6},
{'params': model.encoder.parameters(), 'lr': 5e-6},
{'params': model.classifier.parameters(), 'lr': 1e-5}
])
梯度累积
# 当batch size受限时使用梯度累积
accumulation_steps = 4
for i, batch in enumerate(dataloader):
loss = model(batch)
loss = loss / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
权重衰减
# 防止过拟合的正则化
optimizer = AdamW(
model.parameters(),
lr=1e-5,
weight_decay=0.01 # L2正则化
)
训练策略
渐进式解冻(Progressive Unfreezing)
# 第一阶段:只训练分类头
for param in model.base_model.parameters():
param.requires_grad = False
# 第二阶段:解冻最后几层
for param in model.base_model.encoder.layer[-2:].parameters():
param.requires_grad = True
# 第三阶段:全参数微调
for param in model.parameters():
param.requires_grad = True
差异化学习率(Discriminative Learning Rates)
不同层使用不同的学习率:
- 底层:较小学习率(保持预训练特征)
- 中层:中等学习率
- 顶层:较大学习率(快速适应新任务)
显存优化技巧
混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
梯度检查点
# 牺牲计算时间换取显存
model.gradient_checkpointing_enable()
DeepSpeed ZeRO
# 使用DeepSpeed进行显存优化
from deepspeed import zero
# 配置ZeRO Stage 2或Stage 3
与其他方法的对比
方法 | 参数更新量 | 显存需求 | 训练时间 | 效果 | 部署复杂度 |
---|---|---|---|---|---|
全参数微调 | 100% | 很高 | 长 | 最佳 | 高 |
PEFT参数高效微调 | 1-5% | 低 | 短 | 良好 | 低 |
Prompt工程 | 0% | 最低 | 无 | 一般 | 最低 |
最佳实践建议
数据准备
- 数据质量:确保标注数据的准确性
- 数据平衡:避免类别不平衡问题
- 数据增强:通过回译、同义词替换等方式扩充数据
训练监控
- 早停机制:监控验证集性能,防止过拟合
- 学习率调度:使用余弦退火或线性衰减
- 梯度裁剪:防止梯度爆炸
模型保存
# 保存完整模型
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
}, 'full_finetuned_model.pth')