PPO Fine-Tuning

What Is PPO?

PPO (Proximal Policy Optimization) is a policy-gradient reinforcement learning algorithm. In large language model fine-tuning it is used to optimize the policy model so that its outputs better align with the reward function. PPO is the core algorithm of the third stage of RLHF fine-tuning.

Core Principles

Policy Gradient Methods

PPO belongs to the family of policy gradient methods: it directly optimizes the policy π(a|s), with the objective of maximizing the expected reward:

J(θ) = E[R(τ)] = E[Σ_t r(s_t, a_t)]
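
To make this concrete, the gradient of J(θ) is usually estimated with the score-function (REINFORCE) trick: weight each sampled action's log-probability by the observed return and ascend that objective. The following is a minimal, self-contained PyTorch sketch of this estimator; the toy policy, states, and returns are made up purely for illustration.

import torch
import torch.nn as nn

# Toy policy: maps a 4-dim state to logits over 3 discrete actions (illustration only)
policy = nn.Linear(4, 3)
optimizer = torch.optim.SGD(policy.parameters(), lr=1e-2)

states = torch.randn(8, 4)            # a small batch of states
dist = torch.distributions.Categorical(logits=policy(states))
actions = dist.sample()
returns = torch.randn(8)              # stand-in for R(τ); a real setup sums discounted rewards

# Policy-gradient loss: maximizing E[log π(a|s) * R] is minimizing its negative
loss = -(dist.log_prob(actions) * returns).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()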

Importance Sampling

PPO uses importance sampling so that data collected under the old policy can be reused for several update epochs:

L(θ) = E[r_t(θ) * A_t]
where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)

Clipped Objective

To prevent excessively large policy updates, PPO introduces a clipping mechanism:

L^CLIP(θ) = E[min(r_t(θ)A_t, clip(r_t(θ), 1-ε, 1+ε)A_t)]
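
Taking the minimum of the clipped and unclipped terms means the objective never benefits from pushing the ratio outside [1-ε, 1+ε], which is what keeps updates "proximal". A minimal sketch of this computation in PyTorch, using placeholder tensors in place of the log-probabilities and advantages that the trainer below obtains from rollouts:

import torch

clip_ratio = 0.2  # ε

# Placeholder values; in practice these come from the rollout buffer
old_log_probs = torch.tensor([-1.2, -0.8, -2.0])   # log π_old(a_t|s_t), recorded at sampling time
new_log_probs = torch.tensor([-1.0, -0.9, -1.5])   # log π_θ(a_t|s_t), recomputed under the current policy
advantages = torch.tensor([0.5, -0.3, 1.2])

ratio = torch.exp(new_log_probs - old_log_probs)                           # r_t(θ)
unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages

# Pessimistic (min) surrogate, negated because optimizers minimize
ppo_loss = -torch.min(unclipped, clipped).mean()
print(ppo_loss)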

Implementation

Core PPO Algorithm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import copy  # used by LanguageModelPPO below to snapshot a frozen reference model
 
class PPOAgent:
    def __init__(self, policy_model, value_model, clip_ratio=0.2, 
                 lr_policy=1e-6, lr_value=1e-5, gamma=0.99, lam=0.95):
        """
        PPO智能体
        
        Args:
            policy_model: 策略模型(语言模型)
            value_model: 价值模型(可选,用于估计状态价值)
            clip_ratio: 裁剪比率ε
            lr_policy: 策略学习率
            lr_value: 价值学习率
            gamma: 折扣因子
            lam: GAE参数
        """
        self.policy_model = policy_model
        self.value_model = value_model
        self.clip_ratio = clip_ratio
        self.gamma = gamma
        self.lam = lam
        
        # Optimizers
        self.policy_optimizer = torch.optim.AdamW(
            policy_model.parameters(), lr=lr_policy
        )
        
        if value_model:
            self.value_optimizer = torch.optim.AdamW(
                value_model.parameters(), lr=lr_value
            )
    
    def get_action_and_value(self, state, action=None):
        """获取动作和价值"""
        
        # Forward pass through the policy model
        outputs = self.policy_model(input_ids=state)
        logits = outputs.logits
        
        # Build the action distribution from the logits
        dist = Categorical(logits=logits)
        
        if action is None:
            # Sample an action
            action = dist.sample()
        
        # Compute the log-probability and entropy
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        
        # Value estimate
        if self.value_model:
            value = self.value_model(state)
        else:
            value = torch.zeros_like(log_prob)
        
        return action, log_prob, entropy, value
    
    def compute_gae(self, rewards, values, dones):
        """计算广义优势估计(GAE)"""
        
        advantages = []
        gae = 0
        
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]
            
            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.lam * (1 - dones[t]) * gae
            advantages.insert(0, gae)
        
        return torch.tensor(advantages, dtype=torch.float32)
    
    def update(self, states, actions, old_log_probs, rewards, dones, 
               num_epochs=4, batch_size=64):
        """PPO更新"""
        
        # Compute values, advantages, and returns (no gradients needed)
        with torch.no_grad():
            _, _, _, values = self.get_action_and_value(states)
            advantages = self.compute_gae(rewards, values, dones)
            returns = advantages + values
            
            # Normalize advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Several optimization epochs over the same rollout
        for epoch in range(num_epochs):
            # Shuffle the data
            indices = torch.randperm(len(states))
            
            for start in range(0, len(states), batch_size):
                end = start + batch_size
                batch_indices = indices[start:end]
                
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]
                
                # Log-probabilities and values under the current policy
                _, new_log_probs, entropy, new_values = self.get_action_and_value(
                    batch_states, batch_actions
                )
                
                # Importance ratio r_t(θ)
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                
                # Policy loss (clipped objective)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(
                    ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                ) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # Value loss
                if self.value_model:
                    value_loss = F.mse_loss(new_values, batch_returns)
                else:
                    value_loss = torch.tensor(0.0)
                
                # Entropy bonus (encourages exploration)
                entropy_loss = -entropy.mean()
                
                # Total loss
                total_loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
                
                # Update the policy (and value function)
                self.policy_optimizer.zero_grad()
                if self.value_model:
                    self.value_optimizer.zero_grad()
                
                total_loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), 0.5)
                if self.value_model:
                    torch.nn.utils.clip_grad_norm_(self.value_model.parameters(), 0.5)
                
                self.policy_optimizer.step()
                if self.value_model:
                    self.value_optimizer.step()
        
        return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "total_loss": total_loss.item()
        }
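
To sanity-check the GAE recursion in compute_gae above, here is a tiny standalone worked example with made-up rewards and value estimates (γ=0.99, λ=0.95, matching the constructor defaults):

import torch

gamma, lam = 0.99, 0.95
rewards = [1.0, 0.0, 2.0]     # made-up per-step rewards
values  = [0.5, 0.4, 0.8]     # made-up V(s_t) estimates
dones   = [0, 0, 1]           # the episode ends at the last step

advantages, gae = [], 0.0
for t in reversed(range(len(rewards))):
    next_value = 0.0 if t == len(rewards) - 1 else values[t + 1]
    delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
    gae = delta + gamma * lam * (1 - dones[t]) * gae
    advantages.insert(0, gae)

print(torch.tensor(advantages))                           # advantage estimates A_t
print(torch.tensor(advantages) + torch.tensor(values))    # the corresponding returns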

Adapting PPO to Language Models

class LanguageModelPPO:
    def __init__(self, model, tokenizer, reward_model, ref_model=None):
        """语言模型PPO训练器"""
        
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
        # Use a frozen copy as the reference when none is given; freezing the policy
        # model itself (which `ref_model or model` would cause below) would stop training.
        self.ref_model = ref_model if ref_model is not None else copy.deepcopy(model)
        
        # PPO hyperparameters
        self.clip_ratio = 0.2
        self.kl_coef = 0.1
        self.entropy_coef = 0.01
        
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
        
        # Freeze the reference model and the reward model
        for param in self.ref_model.parameters():
            param.requires_grad = False
        for param in self.reward_model.parameters():
            param.requires_grad = False
    
    def generate_response(self, prompt, max_length=256, temperature=0.7):
        """生成回答并记录log概率"""
        
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_length = inputs["input_ids"].size(1)
        
        # Generate the sequence
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )
        
        # Extract the generated continuation
        generated_ids = outputs.sequences[0][input_length:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        # Per-step log-probabilities from the sampling scores
        logits = torch.stack(outputs.scores, dim=1)  # [1, seq_len, vocab_size]
        log_probs = F.log_softmax(logits, dim=-1)
        
        # Gather the log-probability of each generated token
        token_log_probs = torch.gather(
            log_probs, -1, generated_ids.unsqueeze(0).unsqueeze(-1)
        ).squeeze(-1)
        
        return response, token_log_probs, generated_ids
    
    def compute_reward(self, prompt, response):
        """计算奖励分数"""
        
        full_text = f"{prompt}\n{response}"
        inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True)
        
        with torch.no_grad():
            reward = self.reward_model(**inputs)
        
        return reward.item()
    
    def compute_kl_divergence(self, prompt, response):
        """计算与参考模型的KL散度"""
        
        full_text = f"{prompt}\n{response}"
        inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True)
        
        # Logits from the current model
        current_outputs = self.model(**inputs)
        current_logits = current_outputs.logits
        
        # Logits from the reference model
        with torch.no_grad():
            ref_outputs = self.ref_model(**inputs)
            ref_logits = ref_outputs.logits
        
        # Token-level KL(π_current || π_ref), averaged over positions
        current_log_probs = F.log_softmax(current_logits, dim=-1)
        ref_log_probs = F.log_softmax(ref_logits, dim=-1)
        kl_div = (
            current_log_probs.exp() * (current_log_probs - ref_log_probs)
        ).sum(dim=-1).mean()
        
        return kl_div
    
    def ppo_step(self, prompts, responses, old_log_probs, rewards):
        """执行PPO更新步骤"""
        
        self.model.train()
        total_loss = 0
        
        for prompt, response, old_log_prob, reward in zip(
            prompts, responses, old_log_probs, rewards
        ):
            # Recompute log-probabilities under the current policy
            full_text = f"{prompt}\n{response}"
            inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True)
            
            outputs = self.model(**inputs)
            logits = outputs.logits
            
            # Log-probabilities of the response tokens only. This assumes the
            # prompt's tokenization is a prefix of the full text's tokenization.
            log_probs = F.log_softmax(logits, dim=-1)
            prompt_length = len(self.tokenizer(prompt)["input_ids"])
            response_ids = inputs["input_ids"][0, prompt_length:]
            
            # Shift by one: the logits at position t predict token t+1
            response_log_probs = log_probs[0, prompt_length - 1:-1]
            current_log_prob = torch.gather(
                response_log_probs, -1, response_ids.unsqueeze(-1)
            ).squeeze(-1).sum()
            
            # Importance ratio
            ratio = torch.exp(current_log_prob - old_log_prob)
            
            # Advantage (simplified here to the raw reward, with no baseline or GAE)
            advantage = reward
            
            # PPO clipped loss
            surr1 = ratio * advantage
            surr2 = torch.clamp(
                ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
            ) * advantage
            
            policy_loss = -torch.min(surr1, surr2)
            
            # KL-divergence penalty
            kl_penalty = self.compute_kl_divergence(prompt, response)
            
            # Entropy bonus (encourages diversity); averaged per token rather than
            # summed over the whole sequence and vocabulary
            entropy = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
            
            # Total loss
            loss = policy_loss + self.kl_coef * kl_penalty - self.entropy_coef * entropy
            total_loss += loss
        
        # Backpropagation
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        
        return total_loss.item()
    
    def train_step(self, prompts):
        """完整的PPO训练步骤"""
        
        # 1. Generate responses
        responses = []
        old_log_probs = []
        
        for prompt in prompts:
            response, log_probs, _ = self.generate_response(prompt)
            responses.append(response)
            old_log_probs.append(log_probs.sum().item())
        
        # 2. Compute rewards
        rewards = []
        for prompt, response in zip(prompts, responses):
            reward = self.compute_reward(prompt, response)
            rewards.append(reward)
        
        # 3. PPO update
        loss = self.ppo_step(prompts, responses, old_log_probs, rewards)
        
        return {
            "loss": loss,
            "mean_reward": np.mean(rewards),
            "responses": responses
        }
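
A minimal driver loop for the trainer above. The model name and the toy reward model below are assumptions for illustration only: the class expects any causal LM with a matching tokenizer, plus a reward model that returns a scalar tensor when called on tokenized text.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

class ToyRewardModel(nn.Module):
    """Stand-in reward model that returns a scalar, as compute_reward expects."""
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Toy heuristic: slightly reward longer texts (illustration only)
        return input_ids.size(1) / 100.0 * torch.ones(())

policy = AutoModelForCausalLM.from_pretrained("gpt2")      # assumed checkpoint
ref = AutoModelForCausalLM.from_pretrained("gpt2")         # frozen reference copy
tokenizer = AutoTokenizer.from_pretrained("gpt2")

trainer = LanguageModelPPO(policy, tokenizer, ToyRewardModel(), ref_model=ref)

prompts = [
    "Explain the difference between PPO and vanilla policy gradients.",
    "Summarize the idea of importance sampling in one sentence.",
]

for step in range(10):
    stats = trainer.train_step(prompts)
    print(step, stats["loss"], stats["mean_reward"])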

Advanced Techniques

Adaptive KL Penalty

class AdaptiveKLPPO(LanguageModelPPO):
    def __init__(self, *args, target_kl=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_kl = target_kl
        self.kl_coef = 0.1
        self.adaptive_kl = True
    
    def update_kl_coef(self, current_kl):
        """自适应调整KL系数"""
        
        if self.adaptive_kl:
            if current_kl < self.target_kl / 1.5:
                # KL is too small: relax the penalty
                self.kl_coef *= 0.5
            elif current_kl > self.target_kl * 1.5:
                # KL is too large: strengthen the penalty
                self.kl_coef *= 2.0
            
            # Keep the coefficient in a sane range
            self.kl_coef = np.clip(self.kl_coef, 0.001, 1.0)
    
    def train_step(self, prompts):
        """带自适应KL的训练步骤"""
        
        result = super().train_step(prompts)
        
        # Average KL divergence over the batch
        kl_divs = []
        for prompt, response in zip(prompts, result["responses"]):
            kl_div = self.compute_kl_divergence(prompt, response)
            kl_divs.append(kl_div.item())
        
        mean_kl = np.mean(kl_divs)
        self.update_kl_coef(mean_kl)
        
        result["mean_kl"] = mean_kl
        result["kl_coef"] = self.kl_coef
        
        return result
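
As a quick sanity check of this schedule, the standalone snippet below (with made-up per-step KL readings) shows the coefficient halving while KL undershoots the target and doubling once it overshoots:

import numpy as np

target_kl, kl_coef = 0.01, 0.1
for current_kl in [0.002, 0.002, 0.030, 0.030, 0.008]:   # made-up KL readings
    if current_kl < target_kl / 1.5:
        kl_coef *= 0.5
    elif current_kl > target_kl * 1.5:
        kl_coef *= 2.0
    kl_coef = float(np.clip(kl_coef, 0.001, 1.0))
    print(f"kl={current_kl:.3f} -> kl_coef={kl_coef:.3f}")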

Experience Replay

class PPOWithExperienceReplay(LanguageModelPPO):
    def __init__(self, *args, buffer_size=10000, **kwargs):
        # Inherit generation, reward scoring, and ppo_step from LanguageModelPPO
        super().__init__(*args, **kwargs)
        self.buffer_size = buffer_size
        self.experience_buffer = []
    
    def store_experience(self, prompt, response, reward, log_prob):
        """存储经验"""
        
        experience = {
            "prompt": prompt,
            "response": response,
            "reward": reward,
            "log_prob": log_prob
        }
        
        self.experience_buffer.append(experience)
        
        # Keep the buffer at its maximum size (drop the oldest entry)
        if len(self.experience_buffer) > self.buffer_size:
            self.experience_buffer.pop(0)
    
    def sample_batch(self, batch_size=32):
        """采样批次数据"""
        
        if len(self.experience_buffer) < batch_size:
            return self.experience_buffer
        
        indices = np.random.choice(
            len(self.experience_buffer), 
            batch_size, 
            replace=False
        )
        
        return [self.experience_buffer[i] for i in indices]
    
    def replay_update(self, batch_size=32, num_updates=4):
        """Update the policy from replayed experience."""
        
        losses = []
        for _ in range(num_updates):
            batch = self.sample_batch(batch_size)
            
            if len(batch) == 0:
                continue
            
            prompts = [exp["prompt"] for exp in batch]
            responses = [exp["response"] for exp in batch]
            rewards = [exp["reward"] for exp in batch]
            old_log_probs = [exp["log_prob"] for exp in batch]
            
            # Run a PPO update on the replayed batch (ppo_step is inherited from LanguageModelPPO)
            losses.append(self.ppo_step(prompts, responses, old_log_probs, rewards))
        
        return float(np.mean(losses)) if losses else 0.0

Evaluation Metrics

PPO-Specific Metrics

def evaluate_ppo_training(model, tokenizer, eval_prompts, reward_model):
    """Evaluate the effect of PPO training."""
    
    metrics = {
        "mean_reward": 0,
        "reward_std": 0,
        "policy_entropy": 0,
        "response_length": 0
    }
    
    rewards = []
    entropies = []
    lengths = []
    
    for prompt in eval_prompts:
        # Generate a response
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_length=256,
            do_sample=True,
            temperature=0.7,
            return_dict_in_generate=True,
            output_scores=True
        )
        
        # Decode only the generated continuation, not the prompt
        response = tokenizer.decode(
            outputs.sequences[0][inputs["input_ids"].size(1):],
            skip_special_tokens=True
        )
        
        # Compute the reward
        full_text = f"{prompt}\n{response}"
        reward_inputs = tokenizer(full_text, return_tensors="pt")
        reward = reward_model(**reward_inputs).item()
        rewards.append(reward)
        
        # Compute the sampling entropy
        logits = torch.stack(outputs.scores, dim=1)
        probs = F.softmax(logits, dim=-1)
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
        entropies.append(entropy.item())
        
        # Record the response length (in words)
        lengths.append(len(response.split()))
    
    metrics["mean_reward"] = np.mean(rewards)
    metrics["reward_std"] = np.std(rewards)
    metrics["policy_entropy"] = np.mean(entropies)
    metrics["response_length"] = np.mean(lengths)
    
    return metrics

Training Stability Monitoring

class PPOStabilityMonitor:
    def __init__(self, window_size=100):
        self.window_size = window_size
        self.metrics_history = {
            "rewards": [],
            "policy_loss": [],
            "kl_divergence": [],
            "clip_fraction": []
        }
    
    def update(self, metrics):
        """更新监控指标"""
        for key, value in metrics.items():
            if key in self.metrics_history:
                self.metrics_history[key].append(value)
                
                # Keep a fixed-size sliding window
                if len(self.metrics_history[key]) > self.window_size:
                    self.metrics_history[key].pop(0)
    
    def check_stability(self):
        """Check training stability and report potential issues."""
        
        issues = []
        
        # Check reward volatility
        if len(self.metrics_history["rewards"]) > 10:
            recent_rewards = self.metrics_history["rewards"][-10:]
            if np.std(recent_rewards) > np.mean(recent_rewards) * 0.5:
                issues.append("Reward is fluctuating heavily")
        
        # Check KL divergence
        if len(self.metrics_history["kl_divergence"]) > 5:
            recent_kl = self.metrics_history["kl_divergence"][-5:]
            if np.mean(recent_kl) > 0.1:
                issues.append("KL divergence is large; the policy may be drifting from the reference model")
        
        # Check clip fraction
        if len(self.metrics_history["clip_fraction"]) > 5:
            recent_clip = self.metrics_history["clip_fraction"][-5:]
            if np.mean(recent_clip) > 0.8:
                issues.append("Clip fraction is very high; the learning rate may be too large")
        
        return issues
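
The monitor expects a clip_fraction entry, which the trainers in this article do not log; a common way to derive it (an addition here, not part of the original trainers) is the share of samples whose importance ratio falls outside the clipping band:

import torch

def compute_clip_fraction(new_log_probs, old_log_probs, clip_ratio=0.2):
    """Fraction of samples whose ratio r_t(θ) was clipped in this update."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped = (ratio < 1 - clip_ratio) | (ratio > 1 + clip_ratio)
    return clipped.float().mean().item()

# Feed the value into the stability monitor alongside the other metrics
monitor = PPOStabilityMonitor()
monitor.update({
    "rewards": 0.7,
    "policy_loss": 0.12,
    "kl_divergence": 0.02,
    "clip_fraction": compute_clip_fraction(
        torch.tensor([-1.0, -0.9]), torch.tensor([-1.2, -0.8])
    ),
})
print(monitor.check_stability())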

Strengths and Limitations

Strengths

  1. Stable training: the clipping mechanism prevents overly large policy updates
  2. Sample efficiency: importance sampling lets each rollout be reused for several update epochs
  3. Simple to implement: relatively simple compared with other RL algorithms
  4. Broadly applicable: works for a wide range of reinforcement learning tasks
  5. Theoretical grounding: approximates trust-region methods, which come with monotonic-improvement guarantees

Limitations

  1. Hyperparameter sensitivity: the clip ratio and other parameters need careful tuning
  2. Local optima: may get stuck in locally optimal policies
  3. Reward design: relies on a well-designed reward function
  4. Computational cost: requires multiple forward passes per update
  5. Distribution shift: prolonged training can drift away from the initial distribution

Best Practices

Hyperparameter Settings

def ppo_hyperparameter_guide():
    """Guidelines for setting PPO hyperparameters."""
    
    return {
        "clip_ratio": {
            "range": "0.1 - 0.3",
            "recommended": "0.2",
            "notes": "Clipping ratio; controls how far each policy update can move"
        },
        
        "learning_rate": {
            "range": "1e-7 - 1e-5",
            "recommended": "1e-6",
            "notes": "Learning rate; language models usually need very small values"
        },
        
        "kl_coef": {
            "range": "0.01 - 0.5",
            "recommended": "0.1",
            "notes": "KL coefficient; controls divergence from the reference model"
        },
        
        "entropy_coef": {
            "range": "0.001 - 0.1",
            "recommended": "0.01",
            "notes": "Entropy coefficient; encourages exploration"
        },
        
        "num_epochs": {
            "range": "3 - 10",
            "recommended": "4",
            "notes": "Number of update epochs per batch of rollouts"
        }
    }

Training Tips

def ppo_training_tips():
    """Practical PPO training tips."""
    
    return {
        "gradient_clipping": "Clip gradients to prevent gradient explosions",
        "lr_schedule": "Use cosine annealing or linear decay",
        "batch_size": "Adjust to available GPU memory; usually small",
        "experience_replay": "Optionally use experience replay to improve efficiency",
        "early_stopping": "Monitor KL divergence and stop training in time",
        "checkpointing": "Save model checkpoints regularly",
        "metric_monitoring": "Closely monitor reward, KL divergence, and other metrics"
    }

Related Concepts