Prefix Tuning

What is Prefix Tuning?

Prefix Tuning is a parameter-efficient fine-tuning (PEFT) method that steers model behavior by prepending learnable virtual tokens (a prefix) to the input sequence, without modifying any parameters of the pretrained model.

Core Idea

The Virtual Prefix Concept

Original input:  [x1, x2, x3, ..., xn]
Prefix Tuning:   [P1, P2, ..., Pk, x1, x2, x3, ..., xn]

Here, P1 through Pk are learnable virtual tokens, and k is typically much smaller than n.
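
To make this concrete, here is a minimal sketch (with made-up sizes; the batch size, k, and n are arbitrary) that prepends k learnable prefix vectors to already-embedded inputs. Prefix Tuning proper injects the prefix into every layer's key/value states, as the implementation below shows:

import torch
import torch.nn as nn

hidden_size, k, n = 768, 10, 32                             # hypothetical sizes
prefix = nn.Parameter(torch.randn(k, hidden_size))          # learnable P1..Pk
input_embeds = torch.randn(2, n, hidden_size)               # embedded x1..xn for a batch of 2

# Prepend the prefix to every example in the batch
prefix_batch = prefix.unsqueeze(0).expand(input_embeds.size(0), -1, -1)
extended = torch.cat([prefix_batch, input_embeds], dim=1)   # [2, k + n, hidden_size]
print(extended.shape)                                       # torch.Size([2, 42, 768])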

Difference from Prompts

  • Prompt: describes the task in natural language
  • Prefix: represents the task as learnable continuous vectors

Technical Architecture

Basic Implementation

import torch
import torch.nn as nn
 
class PrefixTuning(nn.Module):
    def __init__(self, 
                 prefix_length=10,
                 hidden_size=768,
                 num_layers=12,
                 num_heads=12):
        super().__init__()
        
        self.prefix_length = prefix_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        
        # Per-layer prefix parameters
        self.prefix_embeddings = nn.Parameter(
            torch.randn(num_layers, 2, prefix_length, hidden_size)
        )
        # the dimension of size 2 holds the key and value states
        
        # MLP reparameterization (optional)
        self.prefix_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, num_layers * 2 * hidden_size)
        )
    
    def get_prefix_states(self, batch_size):
        """获取prefix的key和value状态"""
        
        # Option 1: use the parameters directly
        prefix_states = self.prefix_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1, -1, -1
        )
        
        # Option 2: reparameterize through the MLP (often more stable to train).
        # In that case the MLP input should be a small learnable embedding table
        # rather than fresh random noise, and its output reshaped to
        # [batch, num_layers, 2, prefix_length, hidden_size];
        # PTuningV2 below shows a complete version of this approach.
        
        return prefix_states
    
    def forward(self, input_ids, attention_mask=None):
        batch_size = input_ids.size(0)
        
        # Get the prefix states
        prefix_states = self.get_prefix_states(batch_size)
        
        # Extend the attention mask to cover the prefix positions
        if attention_mask is not None:
            prefix_attention = torch.ones(
                batch_size, self.prefix_length,
                device=attention_mask.device,
                dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([prefix_attention, attention_mask], dim=1)
        
        return prefix_states, attention_mask
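
A quick shape check of the PrefixTuning module above, using a dummy batch (the vocabulary size 30522 is just a placeholder):

prefix_module = PrefixTuning(prefix_length=10, hidden_size=768, num_layers=12)
input_ids = torch.randint(0, 30522, (4, 32))           # dummy batch of 4 sequences of length 32
attention_mask = torch.ones(4, 32, dtype=torch.long)

prefix_states, extended_mask = prefix_module(input_ids, attention_mask)
print(prefix_states.shape)  # torch.Size([4, 12, 2, 10, 768]) -> [batch, layer, K/V, prefix_len, hidden]
print(extended_mask.shape)  # torch.Size([4, 42]) -> prefix positions prepended to the mask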

Integration with a Transformer

class TransformerWithPrefix(nn.Module):
    def __init__(self, base_model, prefix_config):
        super().__init__()
        self.base_model = base_model
        self.prefix_tuning = PrefixTuning(**prefix_config)
        
        # Freeze the base model
        for param in self.base_model.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Get the prefix states
        prefix_states, extended_attention_mask = self.prefix_tuning(
            input_ids, attention_mask
        )
        
        # Run the forward pass with prefix-aware attention
        return self.forward_with_prefix(
            input_ids, 
            extended_attention_mask, 
            prefix_states,
            **kwargs
        )
    
    def forward_with_prefix(self, input_ids, attention_mask, prefix_states, **kwargs):
        # Get the input embeddings
        inputs_embeds = self.base_model.embeddings(input_ids)
        
        # Forward pass layer by layer
        hidden_states = inputs_embeds
        
        for layer_idx, layer in enumerate(self.base_model.encoder.layer):
            # Prefix key and value for the current layer
            layer_prefix_states = prefix_states[:, layer_idx]  # [batch, 2, prefix_len, hidden]
            prefix_key = layer_prefix_states[:, 0]  # [batch, prefix_len, hidden]
            prefix_value = layer_prefix_states[:, 1]  # [batch, prefix_len, hidden]
            
            # Modify the attention computation to include the prefix (sketched after this class)
            hidden_states = self.layer_forward_with_prefix(
                layer, hidden_states, attention_mask, prefix_key, prefix_value
            )
        
        return hidden_states
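
The layer_forward_with_prefix call above is where the prefix actually enters the computation: the prefix key/value states are concatenated in front of the keys and values derived from the hidden states, while the queries are left unchanged. How this is wired into a particular base model depends on its layer API; the following is only a minimal single-head sketch of the attention step, with w_q, w_k, w_v standing in for the layer's projection matrices:

import math
import torch

def attention_with_prefix(hidden_states, attention_mask,
                          prefix_key, prefix_value,
                          w_q, w_k, w_v):
    """Single-head sketch: prepend prefix key/value states to the attention keys/values."""
    q = hidden_states @ w_q                      # [batch, seq_len, hidden]
    k = hidden_states @ w_k
    v = hidden_states @ w_v

    # The prefix adds extra key/value positions, but no extra queries
    k = torch.cat([prefix_key, k], dim=1)        # [batch, prefix_len + seq_len, hidden]
    v = torch.cat([prefix_value, v], dim=1)

    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))  # [batch, seq_len, prefix_len + seq_len]
    if attention_mask is not None:
        # attention_mask is the extended mask that already covers the prefix positions
        scores = scores.masked_fill(attention_mask[:, None, :] == 0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return probs @ v                             # [batch, seq_len, hidden]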

Implementation Variants

P-Tuning v1

class PTuning(nn.Module):
    def __init__(self, vocab_size, hidden_size, prefix_length=10):
        super().__init__()
        self.prefix_length = prefix_length
        
        # Learnable prompt token embeddings
        self.prompt_embeddings = nn.Parameter(
            torch.randn(prefix_length, hidden_size)
        )
        
        # LSTM encoder (optional)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )
        
        # Projection MLP (input is 2 * hidden_size because the LSTM is bidirectional)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
    
    def forward(self, batch_size):
        # Expand to the batch dimension
        prompt_embeds = self.prompt_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1
        )
        
        # Encode with the LSTM
        lstm_output, _ = self.lstm(prompt_embeds)
        
        # Project back to hidden_size with the MLP
        prompt_embeds = self.mlp(lstm_output)
        
        return prompt_embeds
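
A quick shape check of the P-Tuning v1 module above (vocab_size is a placeholder and is not used inside the module):

ptuning = PTuning(vocab_size=30522, hidden_size=768, prefix_length=10)
prompt_embeds = ptuning(batch_size=4)
print(prompt_embeds.shape)  # torch.Size([4, 10, 768]) -> ready to prepend to the input embeddings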

P-Tuning v2

class PTuningV2(nn.Module):
    def __init__(self, 
                 prefix_length=10,
                 hidden_size=768,
                 num_layers=12,
                 prefix_projection=True):
        super().__init__()
        
        self.prefix_length = prefix_length
        self.num_layers = num_layers
        self.prefix_projection = prefix_projection
        
        if prefix_projection:
            # Reparameterize through an MLP projection
            self.embedding_size = 512
            self.trans = nn.Sequential(
                nn.Linear(self.embedding_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, num_layers * 2 * hidden_size)
            )
            self.prefix_tokens = nn.Parameter(
                torch.randn(prefix_length, self.embedding_size)
            )
        else:
            # Optimize the prefix parameters directly
            self.prefix_tokens = nn.Parameter(
                torch.randn(num_layers, 2, prefix_length, hidden_size)
            )
    
    def forward(self, batch_size):
        if self.prefix_projection:
            # Generate the prefix through the MLP
            prefix_tokens = self.trans(self.prefix_tokens)  # [prefix_length, num_layers * 2 * hidden_size]
            prefix_tokens = prefix_tokens.view(
                self.prefix_length, self.num_layers, 2, -1
            ).permute(1, 2, 0, 3)  # [num_layers, 2, prefix_length, hidden_size]
        else:
            prefix_tokens = self.prefix_tokens
        
        # Expand to the batch dimension
        prefix_tokens = prefix_tokens.unsqueeze(0).expand(
            batch_size, -1, -1, -1, -1
        )
        
        return prefix_tokens
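
The corresponding shape check for P-Tuning v2 with the MLP projection enabled; note that the output layout matches PrefixTuning above:

ptuning_v2 = PTuningV2(prefix_length=10, hidden_size=768, num_layers=12, prefix_projection=True)
prefix = ptuning_v2(batch_size=4)
print(prefix.shape)  # torch.Size([4, 12, 2, 10, 768]) -> same layout as PrefixTuning above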

Training Strategy

Basic Training Loop

from transformers import Trainer, TrainingArguments
 
def train_prefix_tuning(model, train_dataset, eval_dataset):
    # Train only the prefix parameters
    for name, param in model.named_parameters():
        if "prefix" not in name:
            param.requires_grad = False
        else:
            param.requires_grad = True
    
    # Training configuration
    training_args = TrainingArguments(
        output_dir="./prefix_output",
        num_train_epochs=10,  # Prefix Tuning usually needs more epochs
        per_device_train_batch_size=8,
        learning_rate=5e-4,   # relatively large learning rate
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_steps=50,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    
    trainer.train()
    return model

Multi-Task Training

class MultiTaskPrefixModel(nn.Module):
    def __init__(self, base_model, task_configs):
        super().__init__()
        self.base_model = base_model
        self.task_prefixes = nn.ModuleDict()
        
        # Create an independent prefix for each task
        for task_name, config in task_configs.items():
            self.task_prefixes[task_name] = PrefixTuning(**config)
    
    def forward(self, input_ids, task_name, attention_mask=None):
        # Use the task-specific prefix
        prefix_states, extended_attention_mask = self.task_prefixes[task_name](
            input_ids, attention_mask
        )
        
        # forward_with_prefix is assumed to follow the same prefix-aware logic as TransformerWithPrefix above
        return self.forward_with_prefix(
            input_ids, extended_attention_mask, prefix_states
        )

Optimization Tips

Initialization Strategy

def initialize_prefix_embeddings(prefix_embeddings, vocab_embeddings, init_text):
    """使用真实词汇初始化prefix"""
    
    # Tokenize the initialization text
    init_tokens = tokenizer(init_text, return_tensors="pt")["input_ids"][0]
    
    # Copy the corresponding word embeddings into the prefix
    with torch.no_grad():
        for i, token_id in enumerate(init_tokens[:prefix_embeddings.size(0)]):
            prefix_embeddings[i] = vocab_embeddings[token_id]
 
# Usage example (assumes `tokenizer`, `model`, and its word embeddings are already loaded)
init_text = "The following is a helpful assistant response:"
initialize_prefix_embeddings(
    model.prefix_tuning.prefix_embeddings[0, 0],  # the key prefix of the first layer
    model.base_model.embeddings.word_embeddings.weight,
    init_text
)

Adaptive Prefix Length

class AdaptivePrefixTuning(nn.Module):
    def __init__(self, max_prefix_length=20, hidden_size=768):
        super().__init__()
        self.max_prefix_length = max_prefix_length
        
        # Learnable pool of prefix vectors
        self.prefix_pool = nn.Parameter(
            torch.randn(max_prefix_length, hidden_size)
        )
        
        # Length controller: maps a scalar summary of the input to a length score
        self.length_controller = nn.Linear(1, 1)
        
    def forward(self, input_ids):
        batch_size = input_ids.size(0)
        
        # Dynamically choose the prefix length from a crude summary of the input
        input_repr = input_ids.float().mean(dim=1, keepdim=True)  # [batch, 1], simplified representation
        length_logit = self.length_controller(input_repr)
        ratio = torch.sigmoid(length_logit).mean()  # one shared ratio for the whole batch
        prefix_length = max(1, int(ratio.item() * self.max_prefix_length))  # at least one token
        
        # Select a prefix of that length from the pool
        selected_prefix = self.prefix_pool[:prefix_length]
        
        return selected_prefix.unsqueeze(0).expand(batch_size, -1, -1)
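
Example usage of the adaptive variant; in this sketch the chosen length is shared across the whole batch, and the input_ids are random placeholders:

adaptive = AdaptivePrefixTuning(max_prefix_length=20, hidden_size=768)
input_ids = torch.randint(0, 30522, (4, 32))
prefix = adaptive(input_ids)
print(prefix.shape)  # torch.Size([4, L, 768]) with 1 <= L <= 20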

Application Scenarios

Text Generation

def generation_with_prefix(model, input_text, prefix_text="Generate a story:"):
    # Set the generation prefix (assumes the model exposes a set_prefix helper)
    model.set_prefix(prefix_text)
    
    # Generate text
    inputs = tokenizer(input_text, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Classification

class PrefixClassificationModel(nn.Module):
    def __init__(self, base_model, num_classes, prefix_length=10):
        super().__init__()
        self.base_model = base_model
        self.prefix_tuning = PrefixTuning(prefix_length=prefix_length)
        self.classifier = nn.Linear(base_model.config.hidden_size, num_classes)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Apply the prefix
        prefix_states, extended_attention_mask = self.prefix_tuning(
            input_ids, attention_mask
        )
        
        # Get contextual representations
        # (assumes the base model's forward accepts prefix_states, as in TransformerWithPrefix above)
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=extended_attention_mask,
            prefix_states=prefix_states
        )
        
        # Classification head
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token
        logits = self.classifier(pooled_output)
        
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        
        return {"logits": logits}

Performance Analysis

Parameter Efficiency Comparison

def analyze_parameter_efficiency():
    """分析不同方法的参数效率"""
    
    model_size = 110_000_000  # a 110M-parameter model (BERT-base scale)
    
    methods = {
        "Full Fine-tuning": {
            "trainable_params": model_size,
            "percentage": 100.0
        },
        "LoRA (r=16)": {
            "trainable_params": 294_912,  # 约0.27%
            "percentage": 0.27
        },
        "Adapter": {
            "trainable_params": 896_000,  # 约0.81%
            "percentage": 0.81
        },
        "Prefix Tuning (len=10)": {
            "trainable_params": 61_440,   # 约0.056%
            "percentage": 0.056
        }
    }
    
    return methods
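
The Prefix Tuning entry follows directly from the configuration used throughout this page (12 layers, hidden size 768, one key and one value prefix per layer), relative to the 110M-parameter model assumed above:

num_layers, prefix_length, hidden_size = 12, 10, 768
prefix_params = num_layers * 2 * prefix_length * hidden_size
print(prefix_params)                          # 184320
print(f"{prefix_params / 110_000_000:.2%}")   # 0.17%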

Task Performance Comparison

def benchmark_prefix_tuning():
    """Prefix Tuning性能基准"""
    
    results = {
        "Short-text classification": {
            "Full Fine-tuning": 0.92,
            "Prefix Tuning": 0.89,
            "Relative performance": "96.7%"
        },
        "Text generation": {
            "Full Fine-tuning": 0.85,
            "Prefix Tuning": 0.82,
            "Relative performance": "96.5%"
        },
        "Long-text understanding": {
            "Full Fine-tuning": 0.88,
            "Prefix Tuning": 0.83,
            "Relative performance": "94.3%"
        }
    }
    
    return results

Strengths and Limitations

Strengths

  1. Very few parameters: typically only about 0.01%-0.1% of the model's parameters are trained
  2. Stable training: relatively stable compared with other PEFT methods
  3. Task-agnostic: no modification of the model architecture is required
  4. Fast task switching: prefixes for different tasks can be swapped quickly

Limitations

  1. Weaker on long inputs: less effective than other methods on long sequences
  2. Limited for complex tasks: restricted capacity to adapt to more complex tasks
  3. Sensitive to prefix length: the prefix length needs careful tuning
  4. Inference overhead: the prefix increases the effective sequence length seen by attention

Best Practices

Choosing the Prefix Length

def choose_prefix_length(task_type, sequence_length):
    """根据任务类型和序列长度选择prefix长度"""
    
    if task_type == "classification":
        if sequence_length < 128:
            return 5
        elif sequence_length < 512:
            return 10
        else:
            return 20
    
    elif task_type == "generation":
        if sequence_length < 256:
            return 10
        elif sequence_length < 1024:
            return 20
        else:
            return 50
    
    else:  # other tasks
        return min(20, sequence_length // 10)

Training Tips

def prefix_training_tips():
    """Practical tips for training with Prefix Tuning."""
    
    return {
        "learning rate": "5e-4 to 1e-3, larger than for other methods",
        "epochs": "10-20 epochs; needs more training than full fine-tuning",
        "initialization": "initialize the prefix from task-related text",
        "length tuning": "increase the prefix length gradually, starting small",
        "regularization": "moderate dropout to prevent overfitting",
        "validation": "monitor validation-set performance closely"
    }

Related Concepts