Prefix Tuning
What is Prefix Tuning?
Prefix Tuning is a parameter-efficient fine-tuning method that steers a model's behavior by prepending learnable virtual tokens (a prefix) to the input sequence, without modifying any parameters of the pretrained model.
Core Idea
The Virtual Prefix Concept
Original input: [x1, x2, x3, ..., xn]
With Prefix Tuning: [P1, P2, ..., Pk, x1, x2, x3, ..., xn]
Here P1 through Pk are learnable virtual tokens, and k is typically much smaller than n.
Difference from Prompts
- Prompt: describes the task in natural language, using discrete tokens from the vocabulary
- Prefix: represents the task as learnable continuous vectors (see the short illustration below)
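A minimal sketch of the distinction, assuming a standard Hugging Face tokenizer (the model name is only an example): a prompt is a sequence of discrete token ids drawn from the vocabulary, while a prefix is a free continuous tensor optimized directly by gradient descent.
import torch
import torch.nn as nn
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# A prompt: discrete token ids, each tied to a real vocabulary entry
prompt_ids = tokenizer("Classify the sentiment:", return_tensors="pt")["input_ids"]
# A prefix: ten continuous vectors with no corresponding words, learned end-to-end
prefix = nn.Parameter(torch.randn(10, 768))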
Technical Architecture
Basic Implementation
import torch
import torch.nn as nn
class PrefixTuning(nn.Module):
def __init__(self,
prefix_length=10,
hidden_size=768,
num_layers=12,
num_heads=12):
super().__init__()
self.prefix_length = prefix_length
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
        # Per-layer prefix parameters (one key block and one value block per layer)
self.prefix_embeddings = nn.Parameter(
torch.randn(num_layers, 2, prefix_length, hidden_size)
)
        # the 2 corresponds to the key and value slots
        # MLP reparameterization (optional)
self.prefix_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, num_layers * 2 * hidden_size)
)
def get_prefix_states(self, batch_size):
"""获取prefix的key和value状态"""
# 方法1:直接使用参数
prefix_states = self.prefix_embeddings.unsqueeze(0).expand(
batch_size, -1, -1, -1, -1
)
        # Option 2: generate the states through the MLP (reparameterization is
        # often more stable to train); in practice the MLP input would be a learned
        # embedding rather than fresh random noise drawn on every call
        # prefix_input = torch.randn(batch_size, self.prefix_length, self.hidden_size)
        # prefix_output = self.prefix_mlp(prefix_input)  # [batch, prefix_len, layers * 2 * hidden]
        # prefix_states = prefix_output.view(
        #     batch_size, self.prefix_length, self.num_layers, 2, self.hidden_size
        # ).permute(0, 2, 3, 1, 4)
return prefix_states
def forward(self, input_ids, attention_mask=None):
batch_size = input_ids.size(0)
        # Get the prefix key/value states
        prefix_states = self.get_prefix_states(batch_size)
        # Extend the attention mask to cover the prefix positions
if attention_mask is not None:
prefix_attention = torch.ones(
batch_size, self.prefix_length,
device=attention_mask.device,
dtype=attention_mask.dtype
)
attention_mask = torch.cat([prefix_attention, attention_mask], dim=1)
return prefix_states, attention_mask
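A quick shape check for the module above, with illustrative values; it exercises the tensors that the rest of this section relies on.
prefix_module = PrefixTuning(prefix_length=10, hidden_size=768, num_layers=12, num_heads=12)
input_ids = torch.randint(0, 30000, (4, 32))           # batch of 4, sequence length 32
attention_mask = torch.ones(4, 32, dtype=torch.long)

prefix_states, extended_mask = prefix_module(input_ids, attention_mask)
print(prefix_states.shape)   # torch.Size([4, 12, 2, 10, 768])
print(extended_mask.shape)   # torch.Size([4, 42]) -> 10 prefix positions + 32 real tokens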
Integrating with a Transformer
class TransformerWithPrefix(nn.Module):
def __init__(self, base_model, prefix_config):
super().__init__()
self.base_model = base_model
self.prefix_tuning = PrefixTuning(**prefix_config)
        # Freeze the base model
for param in self.base_model.parameters():
param.requires_grad = False
def forward(self, input_ids, attention_mask=None, **kwargs):
        # Get the prefix states and the extended attention mask
prefix_states, extended_attention_mask = self.prefix_tuning(
input_ids, attention_mask
)
        # Run the forward pass with the prefix injected into the attention layers
return self.forward_with_prefix(
input_ids,
extended_attention_mask,
prefix_states,
**kwargs
)
def forward_with_prefix(self, input_ids, attention_mask, prefix_states, **kwargs):
        # Compute the input embeddings (assumes a BERT-style backbone that exposes
        # base_model.embeddings and base_model.encoder.layer)
        inputs_embeds = self.base_model.embeddings(input_ids)
        # Run the encoder layer by layer
hidden_states = inputs_embeds
for layer_idx, layer in enumerate(self.base_model.encoder.layer):
            # This layer's prefix key and value states
layer_prefix_states = prefix_states[:, layer_idx] # [batch, 2, prefix_len, hidden]
prefix_key = layer_prefix_states[:, 0] # [batch, prefix_len, hidden]
prefix_value = layer_prefix_states[:, 1] # [batch, prefix_len, hidden]
            # Run the layer with the prefix prepended to its keys/values (see the sketch below)
hidden_states = self.layer_forward_with_prefix(
layer, hidden_states, attention_mask, prefix_key, prefix_value
)
return hidden_states
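The wrapper above calls layer_forward_with_prefix without defining it. The following is a minimal sketch of what such a method could look like, assuming a BERT-style encoder layer from Hugging Face transformers (layer.attention.self exposes the query/key/value projections and head sizes; layer.attention.output, layer.intermediate and layer.output implement the rest of the block). It is an illustration of the mechanism, not a reference implementation.
import math
import torch

def layer_forward_with_prefix(self, layer, hidden_states, attention_mask,
                              prefix_key, prefix_value):
    """Run one encoder layer with the prefix prepended to its keys and values."""
    attn = layer.attention.self
    batch_size, seq_len, hidden = hidden_states.shape
    num_heads, head_dim = attn.num_attention_heads, attn.attention_head_size

    def split_heads(x):
        # [batch, len, hidden] -> [batch, heads, len, head_dim]
        return x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

    # Project the real tokens as usual
    q = split_heads(attn.query(hidden_states))
    k = split_heads(attn.key(hidden_states))
    v = split_heads(attn.value(hidden_states))

    # Prepend the learned prefix to the key/value sequences
    k = torch.cat([split_heads(prefix_key), k], dim=2)
    v = torch.cat([split_heads(prefix_value), v], dim=2)

    # Scaled dot-product attention over the extended keys
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # attention_mask: [batch, prefix_len + seq_len], 1 = attend, 0 = ignore
        scores = scores + (1.0 - attention_mask[:, None, None, :].float()) * -1e4
    probs = torch.softmax(scores, dim=-1)
    context = torch.matmul(probs, v).transpose(1, 2).reshape(batch_size, seq_len, hidden)

    # Remaining sub-layers of the Transformer block
    attn_output = layer.attention.output(context, hidden_states)
    intermediate = layer.intermediate(attn_output)
    return layer.output(intermediate, attn_output)

# Attach the sketch to the wrapper defined above
TransformerWithPrefix.layer_forward_with_prefix = layer_forward_with_prefix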
Implementation Variants
P-Tuning v1 learns continuous prompts only at the input embedding layer, while P-Tuning v2 (like Prefix Tuning) injects learned key/value vectors into every layer.
P-Tuning v1
class PTuning(nn.Module):
def __init__(self, vocab_size, hidden_size, prefix_length=10):
super().__init__()
self.prefix_length = prefix_length
        # Learnable prompt token embeddings
self.prompt_embeddings = nn.Parameter(
torch.randn(prefix_length, hidden_size)
)
        # LSTM encoder (optional): models dependencies between prompt tokens
self.lstm = nn.LSTM(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=2,
bidirectional=True,
batch_first=True
)
        # Projection head
self.mlp = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size)
)
def forward(self, batch_size):
        # Expand to the batch dimension
prompt_embeds = self.prompt_embeddings.unsqueeze(0).expand(
batch_size, -1, -1
)
        # Encode with the LSTM
        lstm_output, _ = self.lstm(prompt_embeds)
        # Project through the MLP
prompt_embeds = self.mlp(lstm_output)
return prompt_embeds
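A hedged usage sketch for P-Tuning v1: the encoded prompt embeddings are prepended to the word embeddings of the real tokens and fed to the backbone via inputs_embeds. The model and tokenizer names below are only examples.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")
p_tuning = PTuning(vocab_size=bert.config.vocab_size,
                   hidden_size=bert.config.hidden_size, prefix_length=10)

enc = tokenizer(["prefix tuning is parameter efficient"], return_tensors="pt")
word_embeds = bert.embeddings.word_embeddings(enc["input_ids"])   # [1, seq_len, hidden]
prompt_embeds = p_tuning(batch_size=word_embeds.size(0))          # [1, 10, hidden]
inputs_embeds = torch.cat([prompt_embeds, word_embeds], dim=1)

prompt_mask = torch.ones(1, p_tuning.prefix_length, dtype=enc["attention_mask"].dtype)
attention_mask = torch.cat([prompt_mask, enc["attention_mask"]], dim=1)
outputs = bert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)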
P-Tuning v2
class PTuningV2(nn.Module):
def __init__(self,
prefix_length=10,
hidden_size=768,
num_layers=12,
prefix_projection=True):
super().__init__()
self.prefix_length = prefix_length
self.num_layers = num_layers
self.prefix_projection = prefix_projection
if prefix_projection:
            # Reparameterize through an MLP projection
self.embedding_size = 512
self.trans = nn.Sequential(
nn.Linear(self.embedding_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, num_layers * 2 * hidden_size)
)
self.prefix_tokens = nn.Parameter(
torch.randn(prefix_length, self.embedding_size)
)
else:
            # Optimize the full prefix tensor directly
self.prefix_tokens = nn.Parameter(
torch.randn(num_layers, 2, prefix_length, hidden_size)
)
def forward(self, batch_size):
if self.prefix_projection:
            # Generate the per-layer prefix through the MLP
prefix_tokens = self.trans(self.prefix_tokens) # [prefix_length, num_layers * 2 * hidden_size]
prefix_tokens = prefix_tokens.view(
self.prefix_length, self.num_layers, 2, -1
).permute(1, 2, 0, 3) # [num_layers, 2, prefix_length, hidden_size]
else:
prefix_tokens = self.prefix_tokens
        # Expand to the batch dimension
prefix_tokens = prefix_tokens.unsqueeze(0).expand(
batch_size, -1, -1, -1, -1
)
return prefix_tokens
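P-Tuning v2, like Prefix Tuning, injects the learned vectors as per-layer key/value states. With Hugging Face decoder-only models one convenient route is the past_key_values interface; the helper below is a hedged sketch of the reshaping involved (num_heads describes the backbone and is an assumption, and the attention mask must also be extended by prefix_length, as in PrefixTuning.forward above).
def to_past_key_values(prefix_states, num_heads):
    """Reshape [batch, layers, 2, prefix_len, hidden] into a tuple of per-layer
    (key, value) pairs shaped [batch, heads, prefix_len, head_dim]."""
    batch, num_layers, _, prefix_len, hidden = prefix_states.shape
    head_dim = hidden // num_heads
    kv = prefix_states.view(batch, num_layers, 2, prefix_len, num_heads, head_dim)
    kv = kv.permute(1, 2, 0, 4, 3, 5)  # [layers, 2, batch, heads, prefix_len, head_dim]
    return tuple((layer_kv[0], layer_kv[1]) for layer_kv in kv)

# Illustrative use with a GPT-2 style backbone (model and inputs assumed to exist):
# past_key_values = to_past_key_values(ptuning_v2(batch_size=input_ids.size(0)), num_heads=12)
# outputs = gpt2(input_ids, past_key_values=past_key_values, attention_mask=extended_mask)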
Training Strategies
Basic Training Workflow
from transformers import Trainer, TrainingArguments
def train_prefix_tuning(model, train_dataset, eval_dataset):
    # Train only the prefix parameters
for name, param in model.named_parameters():
if "prefix" not in name:
param.requires_grad = False
else:
param.requires_grad = True
    # Training configuration
    training_args = TrainingArguments(
        output_dir="./prefix_output",
        num_train_epochs=10,  # Prefix Tuning usually needs more epochs than full fine-tuning
        per_device_train_batch_size=8,
        learning_rate=5e-4,  # a relatively large learning rate works well for prefixes
warmup_ratio=0.1,
weight_decay=0.01,
logging_steps=50,
save_strategy="epoch",
evaluation_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
return model
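Before launching training it is worth sanity-checking that only the prefix parameters remain trainable. The small helper below is illustrative, not part of any library.
def count_trainable_parameters(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.3f}%)")
    return trainable, total

# count_trainable_parameters(model)  # should report well under 1% for Prefix Tuning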
Multi-Task Training
class MultiTaskPrefixModel(nn.Module):
def __init__(self, base_model, task_configs):
super().__init__()
self.base_model = base_model
self.task_prefixes = nn.ModuleDict()
        # Create an independent prefix module for each task
for task_name, config in task_configs.items():
self.task_prefixes[task_name] = PrefixTuning(**config)
def forward(self, input_ids, task_name, attention_mask=None):
        # Use the task-specific prefix
prefix_states, extended_attention_mask = self.task_prefixes[task_name](
input_ids, attention_mask
)
        # forward_with_prefix is assumed to be the same helper used by TransformerWithPrefix
        return self.forward_with_prefix(
input_ids, extended_attention_mask, prefix_states
)
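An illustrative way to use the multi-task wrapper: every task gets its own prefix while the frozen backbone is shared, so switching tasks at inference time is just a dictionary lookup. The task names and configurations below are made up for the example, and forward_with_prefix is assumed to be the helper shown for TransformerWithPrefix above.
task_configs = {
    "sentiment": {"prefix_length": 10, "hidden_size": 768, "num_layers": 12, "num_heads": 12},
    "topic": {"prefix_length": 20, "hidden_size": 768, "num_layers": 12, "num_heads": 12},
}
multi_task_model = MultiTaskPrefixModel(base_model, task_configs)

# Switching tasks only switches which prefix is injected:
# out_a = multi_task_model(input_ids, task_name="sentiment", attention_mask=mask)
# out_b = multi_task_model(input_ids, task_name="topic", attention_mask=mask)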
Optimization Techniques
Initialization Strategy
def initialize_prefix_embeddings(prefix_embeddings, vocab_embeddings, init_text):
"""使用真实词汇初始化prefix"""
# 将初始化文本转换为token
init_tokens = tokenizer(init_text, return_tensors="pt")["input_ids"][0]
    # Copy the corresponding word embeddings into the prefix
with torch.no_grad():
for i, token_id in enumerate(init_tokens[:prefix_embeddings.size(0)]):
prefix_embeddings[i] = vocab_embeddings[token_id]
# Example usage
init_text = "The following is a helpful assistant response:"
initialize_prefix_embeddings(
    model.prefix_tuning.prefix_embeddings[0, 0],  # the key slot of the first layer
model.base_model.embeddings.word_embeddings.weight,
init_text
)
Adaptive Prefix Length
class AdaptivePrefixTuning(nn.Module):
def __init__(self, max_prefix_length=20, hidden_size=768):
super().__init__()
self.max_prefix_length = max_prefix_length
        # A learnable pool of prefix vectors
self.prefix_pool = nn.Parameter(
torch.randn(max_prefix_length, hidden_size)
)
        # Length controller (a toy heuristic: a real implementation would condition
        # on a pooled hidden state and make the length choice differentiable)
        self.length_controller = nn.Linear(1, 1)
    def forward(self, input_ids):
        batch_size = input_ids.size(0)
        # Dynamically choose a prefix length from a crude summary of the input
        input_repr = input_ids.float().mean(dim=1, keepdim=True)  # [batch, 1]
        length_logit = self.length_controller(input_repr)
        prefix_length = int((torch.sigmoid(length_logit).mean() * self.max_prefix_length).item())
        prefix_length = max(1, prefix_length)  # keep at least one prefix token
        # Take a prefix of the chosen length from the pool
selected_prefix = self.prefix_pool[:prefix_length]
return selected_prefix.unsqueeze(0).expand(batch_size, -1, -1)
Application Scenarios
Text Generation
def generation_with_prefix(model, input_text, prefix_text="Generate a story:"):
    # Assumes the wrapper exposes a set_prefix() helper that (re)initializes the
    # prefix, e.g. from text as in initialize_prefix_embeddings above
    model.set_prefix(prefix_text)
    # Generate text
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=200,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
Classification Tasks
class PrefixClassificationModel(nn.Module):
def __init__(self, base_model, num_classes, prefix_length=10):
super().__init__()
self.base_model = base_model
        self.prefix_tuning = PrefixTuning(prefix_length=prefix_length,
                                          hidden_size=base_model.config.hidden_size,
                                          num_layers=base_model.config.num_hidden_layers,
                                          num_heads=base_model.config.num_attention_heads)
self.classifier = nn.Linear(base_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask=None, labels=None):
        # Inject the prefix
prefix_states, extended_attention_mask = self.prefix_tuning(
input_ids, attention_mask
)
        # Contextual representations (assumes the backbone accepts prefix_states,
        # e.g. a model wrapped like TransformerWithPrefix above)
        outputs = self.base_model(
input_ids=input_ids,
attention_mask=extended_attention_mask,
prefix_states=prefix_states
)
        # Classification head
pooled_output = outputs.last_hidden_state[:, 0] # [CLS] token
logits = self.classifier(pooled_output)
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {"loss": loss, "logits": logits}
return {"logits": logits}
Performance Analysis
Parameter Efficiency Comparison
def analyze_parameter_efficiency():
"""分析不同方法的参数效率"""
model_size = 110_000_000 # 110M参数的模型
methods = {
"Full Fine-tuning": {
"trainable_params": model_size,
"percentage": 100.0
},
"LoRA (r=16)": {
"trainable_params": 294_912, # 约0.27%
"percentage": 0.27
},
"Adapter": {
"trainable_params": 896_000, # 约0.81%
"percentage": 0.81
},
"Prefix Tuning (len=10)": {
"trainable_params": 61_440, # 约0.056%
"percentage": 0.056
}
}
return methods
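For reference, the Prefix Tuning figure above can be reproduced from the configuration used throughout this section (12 layers, hidden size 768, prefix length 10): each layer stores one key block and one value block of shape [prefix_length, hidden_size].
num_layers, prefix_length, hidden_size = 12, 10, 768
prefix_params = num_layers * 2 * prefix_length * hidden_size   # key and value per layer
print(prefix_params)                        # 184320
print(100 * prefix_params / 110_000_000)    # ~0.17 (percent of a 110M-parameter model)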
Task Performance Comparison
def benchmark_prefix_tuning():
"""Prefix Tuning性能基准"""
results = {
"短文本分类": {
"Full Fine-tuning": 0.92,
"Prefix Tuning": 0.89,
"相对性能": "96.7%"
},
"文本生成": {
"Full Fine-tuning": 0.85,
"Prefix Tuning": 0.82,
"相对性能": "96.5%"
},
"长文本理解": {
"Full Fine-tuning": 0.88,
"Prefix Tuning": 0.83,
"相对性能": "94.3%"
}
}
return results
Strengths and Limitations
Strengths
- Very few parameters: typically only a fraction of a percent of the model's parameters are trained
- Stable training: tends to train more stably than some other PEFT methods
- No architecture changes: the base model and its architecture stay untouched
- Fast task switching: swapping prefixes switches tasks without reloading the backbone
Limitations
- Weaker on long inputs: performance on long sequences lags other methods
- Limited for complex tasks: less capacity to adapt to demanding tasks
- Length sensitivity: the prefix length needs careful tuning
- Inference overhead: the effective sequence length grows by the prefix length
Best Practices
Choosing the Prefix Length
def choose_prefix_length(task_type, sequence_length):
"""根据任务类型和序列长度选择prefix长度"""
if task_type == "classification":
if sequence_length < 128:
return 5
elif sequence_length < 512:
return 10
else:
return 20
elif task_type == "generation":
if sequence_length < 256:
return 10
elif sequence_length < 1024:
return 20
else:
return 50
    else:  # other task types
return min(20, sequence_length // 10)
Training Tips
def prefix_training_tips():
"""Prefix Tuning训练技巧"""
return {
"学习率": "5e-4到1e-3,比其他方法大",
"训练轮数": "10-20轮,需要更多训练",
"初始化": "使用相关文本初始化prefix",
"长度调优": "从小到大逐步调整prefix长度",
"正则化": "适当的dropout防止过拟合",
"验证策略": "密切监控验证集性能"
}
Related Concepts
- PEFT (parameter-efficient fine-tuning) - the family of techniques Prefix Tuning belongs to
- LoRA - another parameter-efficient method
- Adapter tuning - another parameter-efficient method
- Prompt engineering - the related prompting technique
- Fine-tuning strategy selection guide - guidance on choosing among methods