Multi-Task Joint Fine-Tuning

Overview

Multi-task joint fine-tuning trains several related tasks in a single model, letting the tasks share knowledge to improve overall performance. It is especially well suited to settings where the tasks are correlated, such as fine-tuning for a generative task and a classification task at the same time.

Advantages of Multi-Task Learning

Knowledge Sharing

  • Shared representations: low-level features are shared across tasks
  • Regularization effect: multi-task training acts as an implicit regularizer
  • Data efficiency: the data from every task is put to use

Performance Gains

  • Generalization: multi-task learning improves the model's ability to generalize
  • Robustness: less overfitting to any single task
  • Resource efficiency: one model serves multiple tasks

Multi-Task Architecture Design

Shared Encoder + Multiple Task Heads

import torch.nn as nn
 
class MultiTaskModel(nn.Module):
    def __init__(self, base_model, task_configs):
        super().__init__()
        self.base_model = base_model
        self.task_heads = nn.ModuleDict()
        
        for task_name, config in task_configs.items():
            if config['type'] == 'classification':
                self.task_heads[task_name] = nn.Linear(
                    base_model.config.hidden_size, 
                    config['num_classes']
                )
            elif config['type'] == 'generation':
                self.task_heads[task_name] = nn.Linear(
                    base_model.config.hidden_size,
                    base_model.config.vocab_size
                )
    
    def forward(self, input_ids, attention_mask, task_name):
        # Shared encoder
        outputs = self.base_model(input_ids, attention_mask)
        hidden_states = outputs.last_hidden_state
        
        # Task-specific head (applied to every token's hidden state here;
        # classification setups usually pool first, e.g. take the [CLS] position)
        if task_name in self.task_heads:
            task_output = self.task_heads[task_name](hidden_states)
            return task_output
        else:
            raise ValueError(f"Unknown task: {task_name}")
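
A minimal usage sketch, assuming a Hugging Face encoder (bert-base-chinese here); the task name and config below are hypothetical:

from transformers import AutoModel, AutoTokenizer

base = AutoModel.from_pretrained('bert-base-chinese')
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
model = MultiTaskModel(base, {
    'sentiment': {'type': 'classification', 'num_classes': 3}
})

inputs = tokenizer('The delivery was fast and well packaged', return_tensors='pt')
logits = model(inputs['input_ids'], inputs['attention_mask'], 'sentiment')
print(logits.shape)  # (1, seq_len, 3): one prediction per token, as coded above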

Task-Specific Layer Design

class TaskSpecificLayers(nn.Module):
    def __init__(self, hidden_size, task_configs):
        super().__init__()
        self.task_layers = nn.ModuleDict()
        
        for task_name, config in task_configs.items():
            # Each task gets its own dedicated layers
            self.task_layers[task_name] = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, config['output_size'])
            )
    
    def forward(self, shared_features, task_name):
        return self.task_layers[task_name](shared_features)
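
A quick shape check with random features; the hidden size and task config below are illustrative:

import torch

layers = TaskSpecificLayers(hidden_size=768, task_configs={'intent': {'output_size': 10}})
shared = torch.randn(4, 768)  # a batch of pooled encoder features
print(layers(shared, 'intent').shape)  # torch.Size([4, 10])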

Loss Function Design

Weighted Multi-Task Loss

import torch.nn.functional as F

def multi_task_loss(outputs, targets, task_weights):
    total_loss = 0
    task_losses = {}
    
    for task_name, output in outputs.items():
        if task_name == 'classification':
            task_loss = F.cross_entropy(output, targets[task_name])
        elif task_name == 'generation':
            task_loss = F.cross_entropy(
                output.view(-1, output.size(-1)), 
                targets[task_name].view(-1)
            )
        elif task_name == 'regression':
            task_loss = F.mse_loss(output, targets[task_name])
        else:
            raise ValueError(f"Unknown task: {task_name}")
        
        task_losses[task_name] = task_loss
        total_loss += task_weights[task_name] * task_loss
    
    return total_loss, task_losses
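
A smoke test with dummy tensors (shapes and weights are illustrative):

import torch

outputs = {'classification': torch.randn(4, 3), 'regression': torch.randn(4, 1)}
targets = {'classification': torch.tensor([0, 2, 1, 0]), 'regression': torch.randn(4, 1)}
weights = {'classification': 1.0, 'regression': 0.5}

total_loss, task_losses = multi_task_loss(outputs, targets, weights)
print(total_loss.item(), {k: v.item() for k, v in task_losses.items()})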

Dynamic Weight Adjustment

import torch
import torch.nn.functional as F

class DynamicWeightAveraging:
    def __init__(self, num_tasks, temperature=2.0):
        self.num_tasks = num_tasks
        self.temperature = temperature
        self.task_losses_history = []
    
    def update_weights(self, task_losses):
        self.task_losses_history.append(task_losses)
        
        if len(self.task_losses_history) < 2:
            # Start with equal weights until a loss history exists
            return {task: 1.0 for task in task_losses.keys()}
        
        # Rate of change of each task's loss relative to the previous step
        prev_losses = self.task_losses_history[-2]
        loss_ratios = {}
        
        for task in task_losses.keys():
            ratio = task_losses[task] / prev_losses[task]
            loss_ratios[task] = ratio
        
        # Convert the ratios into weights via a temperature-scaled softmax:
        # tasks whose loss decays more slowly receive larger weights
        weights = {}
        ratio_values = list(loss_ratios.values())
        softmax_weights = F.softmax(torch.tensor(ratio_values) / self.temperature, dim=0)
        
        for i, task in enumerate(loss_ratios.keys()):
            weights[task] = softmax_weights[i].item()
        
        return weights
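
This scheme is in the spirit of Dynamic Weight Averaging (Liu et al., 2019), though here the weights sum to 1 rather than to the number of tasks. A two-step example with illustrative loss values:

dwa = DynamicWeightAveraging(num_tasks=2)
print(dwa.update_weights({'cls': 1.0, 'gen': 2.0}))  # first call: equal weights
print(dwa.update_weights({'cls': 0.5, 'gen': 1.9}))  # 'gen' decays slower (ratio 0.95 vs 0.5), so it gets ~0.56 vs ~0.44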

Data Handling Strategies

Mixed-Batch Sampling

import random

class MultiTaskDataLoader:
    def __init__(self, task_dataloaders, sampling_strategy='round_robin'):
        self.task_dataloaders = task_dataloaders
        self.sampling_strategy = sampling_strategy
        self.task_iterators = {
            task: iter(dataloader) 
            for task, dataloader in task_dataloaders.items()
        }
    
    def __iter__(self):
        if self.sampling_strategy == 'round_robin':
            return self._round_robin_sampling()
        elif self.sampling_strategy == 'proportional':
            return self._proportional_sampling()
        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")
    
    def _round_robin_sampling(self):
        # Cycle through tasks in a fixed order; one pass ends only when
        # every task's dataloader has been exhausted once
        task_names = list(self.task_dataloaders.keys())
        exhausted = set()
        task_idx = 0
        
        while len(exhausted) < len(task_names):
            task_name = task_names[task_idx]
            task_idx = (task_idx + 1) % len(task_names)
            if task_name in exhausted:
                continue
            try:
                batch = next(self.task_iterators[task_name])
                batch['task_name'] = task_name
                yield batch
            except StopIteration:
                # Re-initialize the iterator for the next epoch
                self.task_iterators[task_name] = iter(self.task_dataloaders[task_name])
                exhausted.add(task_name)
    
    def _proportional_sampling(self):
        # Draw each batch from a task with probability proportional to its size
        task_names = list(self.task_dataloaders.keys())
        sizes = [len(self.task_dataloaders[t]) for t in task_names]
        
        for _ in range(sum(sizes)):
            task_name = random.choices(task_names, weights=sizes, k=1)[0]
            try:
                batch = next(self.task_iterators[task_name])
            except StopIteration:
                self.task_iterators[task_name] = iter(self.task_dataloaders[task_name])
                batch = next(self.task_iterators[task_name])
            batch['task_name'] = task_name
            yield batch
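
Wiring it up, assuming each task's dataset yields dict examples so the default collate produces dict batches (the data below is synthetic):

import torch
from torch.utils.data import DataLoader

cls_data = [{'input_ids': torch.randint(0, 100, (16,)), 'labels': torch.tensor(1)} for _ in range(8)]
gen_data = [{'input_ids': torch.randint(0, 100, (16,))} for _ in range(16)]

loader = MultiTaskDataLoader({
    'classification': DataLoader(cls_data, batch_size=4),
    'generation': DataLoader(gen_data, batch_size=4)
})

for batch in loader:
    print(batch['task_name'], batch['input_ids'].shape)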

Unifying Task Data Formats

def unify_data_format(batch, task_name):
    """统一不同任务的数据格式"""
    unified_batch = {
        'input_ids': batch['input_ids'],
        'attention_mask': batch['attention_mask'],
        'task_name': task_name
    }
    
    if task_name == 'classification':
        unified_batch['labels'] = batch['labels']
    elif task_name == 'generation':
        unified_batch['labels'] = batch['input_ids']  # for the generation task the labels are the inputs (Hugging Face causal LMs shift them internally)
    elif task_name == 'ner':
        unified_batch['labels'] = batch['ner_labels']
    
    return unified_batch

Case Study: An Intelligent Customer-Service System

Task Definitions

# Define several related tasks
task_configs = {
    'intent_classification': {
        'type': 'classification',
        'num_classes': 10,  # number of intent classes
        'weight': 1.0
    },
    'sentiment_analysis': {
        'type': 'classification', 
        'num_classes': 3,   # positive, negative, neutral
        'weight': 0.8
    },
    'response_generation': {
        'type': 'generation',
        'vocab_size': 21128,  # vocabulary size of bert-base-chinese
        'weight': 1.2
    },
    'entity_extraction': {
        'type': 'sequence_labeling',
        'num_labels': 9,    # BIO tagging scheme
        'weight': 0.9
    }
}

Training Loop

from torch.optim import AdamW

def train_multi_task_model(model, multi_task_dataloader, task_configs, num_epochs=3):
    optimizer = AdamW(model.parameters(), lr=2e-5)
    weight_averager = DynamicWeightAveraging(len(task_configs))
    # Start from equal weights; they are re-estimated after every epoch
    task_weights = {task: 1.0 for task in task_configs.keys()}
    
    for epoch in range(num_epochs):
        epoch_losses = {task: 0.0 for task in task_configs.keys()}
        epoch_counts = {task: 0 for task in task_configs.keys()}
        
        for batch in multi_task_dataloader:
            task_name = batch['task_name']
            
            # Forward pass through the shared encoder and the task head
            outputs = model(
                batch['input_ids'], 
                batch['attention_mask'], 
                task_name
            )
            
            # compute_task_loss dispatches to the appropriate loss for the task
            # (e.g., the single-task case of multi_task_loss above)
            raw_loss = compute_task_loss(outputs, batch['labels'], task_name)
            epoch_losses[task_name] += raw_loss.item()
            epoch_counts[task_name] += 1
            
            # Scale by the current task weight, then update
            task_loss = task_weights[task_name] * raw_loss
            task_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        # Re-estimate task weights from the per-task average losses
        avg_losses = {task: epoch_losses[task] / max(epoch_counts[task], 1)
                      for task in task_configs.keys()}
        task_weights = weight_averager.update_weights(avg_losses)
        
        print(f"Epoch {epoch}, Losses: {avg_losses}, Weights: {task_weights}")

Multimodal Fine-Tuning

A Joint Text-Image Model

import torch
import torch.nn as nn

class MultiModalModel(nn.Module):
    def __init__(self, text_encoder, image_encoder, fusion_dim, num_classes, vocab_size):
        super().__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        
        # Project each modality into the shared fusion space
        self.text_projection = nn.Linear(text_encoder.config.hidden_size, fusion_dim)
        self.image_projection = nn.Linear(image_encoder.config.hidden_size, fusion_dim)
        
        # Multi-task heads; concatenating two modalities doubles the input dimension
        self.classification_head = nn.Linear(fusion_dim * 2, num_classes)
        self.generation_head = nn.Linear(fusion_dim * 2, vocab_size)
    
    def forward(self, text_inputs, image_inputs, task_type):
        # Encode text and image
        text_features = self.text_encoder(**text_inputs).pooler_output
        image_features = self.image_encoder(image_inputs).pooler_output
        
        # Project into the common space
        text_proj = self.text_projection(text_features)
        image_proj = self.image_projection(image_features)
        
        # Modality fusion (simple concatenation here)
        fused_features = torch.cat([text_proj, image_proj], dim=-1)
        
        # Task-specific output
        if task_type == 'classification':
            return self.classification_head(fused_features)
        elif task_type == 'generation':
            return self.generation_head(fused_features)
        raise ValueError(f"Unknown task type: {task_type}")
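
One possible wiring, assuming Hugging Face BertModel and ViTModel as the two encoders (both expose pooler_output); the checkpoint names and dimensions are illustrative:

from transformers import BertModel, ViTModel

text_enc = BertModel.from_pretrained('bert-base-chinese')
image_enc = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

model = MultiModalModel(text_enc, image_enc, fusion_dim=512,
                        num_classes=10, vocab_size=21128)

At inference time, text_inputs comes from the tokenizer and image_inputs is the processor's pixel_values tensor, as produced by process_multimodal_data below.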

Multimodal Data Processing

from PIL import Image

def process_multimodal_data(text, image_path, tokenizer, image_processor):
    # Tokenize the text
    text_inputs = tokenizer(
        text, 
        return_tensors='pt', 
        padding=True, 
        truncation=True
    )
    
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')  # normalize the color mode for the processor
    image_inputs = image_processor(image, return_tensors='pt')
    
    return {
        'text_inputs': text_inputs,
        'image_inputs': image_inputs
    }

Evaluation Strategy

Task-Specific Evaluation

import torch
from sklearn.metrics import accuracy_score, f1_score

def evaluate_multi_task_model(model, eval_dataloaders, task_configs):
    model.eval()
    task_metrics = {}
    
    for task_name, eval_dataloader in eval_dataloaders.items():
        task_predictions = []
        task_labels = []
        
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(
                    batch['input_ids'], 
                    batch['attention_mask'], 
                    task_name
                )
                
                if task_configs[task_name]['type'] == 'classification':
                    predictions = torch.argmax(outputs, dim=-1)
                    task_predictions.extend(predictions.cpu().tolist())
                    task_labels.extend(batch['labels'].cpu().tolist())
        
        # Compute task-specific metrics
        if task_configs[task_name]['type'] == 'classification':
            accuracy = accuracy_score(task_labels, task_predictions)
            f1 = f1_score(task_labels, task_predictions, average='weighted')
            task_metrics[task_name] = {'accuracy': accuracy, 'f1': f1}
    
    return task_metrics
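
The loop above only scores classification heads. For the generation task, a common proxy is token-level perplexity; here is a minimal sketch under the same batch format (evaluate_generation_task is a hypothetical helper that reuses the document's convention of unshifted labels):

import math
import torch
import torch.nn.functional as F

def evaluate_generation_task(model, eval_dataloader, task_name='response_generation'):
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in eval_dataloader:
        with torch.no_grad():
            logits = model(batch['input_ids'], batch['attention_mask'], task_name)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                batch['labels'].view(-1)
            )
        total_loss += loss.item()
        num_batches += 1
    return math.exp(total_loss / num_batches)  # perplexity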

Optimization Techniques

Handling Gradient Conflicts

import torch

def resolve_gradient_conflicts(model, task_losses, task_weights):
    """Handle gradient conflicts between tasks."""
    # Compute each task's gradient separately
    task_gradients = {}
    
    for task_name, loss in task_losses.items():
        model.zero_grad()
        loss.backward(retain_graph=True)
        
        # Collect the gradients into one flat vector
        task_grad = []
        for param in model.parameters():
            if param.grad is not None:
                task_grad.append(param.grad.clone().flatten())
        
        task_gradients[task_name] = torch.cat(task_grad)
    
    # A full solution would use PCGrad or a similar method (see the sketch below);
    # simplified here to a weighted average
    final_gradient = torch.zeros_like(next(iter(task_gradients.values())))
    
    for task_name, grad in task_gradients.items():
        final_gradient += task_weights[task_name] * grad
    
    # Write the combined gradient back into the parameters
    param_idx = 0
    for param in model.parameters():
        if param.grad is not None:
            param_size = param.numel()
            param.grad = final_gradient[param_idx:param_idx + param_size].view(param.shape)
            param_idx += param_size
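
The weighted average above sidesteps conflicts rather than resolving them. A minimal sketch of the actual PCGrad projection step (Yu et al., "Gradient Surgery for Multi-Task Learning"): when two task gradients point in conflicting directions, project one onto the normal plane of the other before averaging (the original also randomizes the projection order):

import torch

def pcgrad_combine(task_gradients):
    """Combine flattened per-task gradients with PCGrad-style projection."""
    grads = list(task_gradients.values())
    projected = []
    for g_i in grads:
        g = g_i.clone()
        for g_j in grads:
            if g_j is g_i:
                continue
            dot = torch.dot(g, g_j)
            if dot < 0:  # conflicting directions: remove the component along g_j
                g = g - (dot / (g_j.norm() ** 2 + 1e-12)) * g_j
        projected.append(g)
    return torch.stack(projected).mean(dim=0)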

Related Concepts