微调数据准备
概述
数据准备是微调成功的关键环节,包括数据收集、格式化、质量控制和Prompt工程等步骤。高质量的数据能显著提升微调效果,而数据格式的统一性则确保训练过程的稳定性。
数据格式标准化
JSONL格式(推荐)
{"instruction": "请解释什么是机器学习", "input": "", "output": "机器学习是一种人工智能技术..."}
{"instruction": "翻译以下英文", "input": "Hello world", "output": "你好世界"}
{"instruction": "分析这段文本的情感", "input": "今天天气真好", "output": "正面情感"}
对话格式
{
"conversations": [
{"role": "user", "content": "你好,我想了解Python"},
{"role": "assistant", "content": "你好!Python是一种编程语言..."},
{"role": "user", "content": "它有什么特点?"},
{"role": "assistant", "content": "Python具有以下特点:1. 语法简洁..."}
]
}
分类任务格式
{"text": "我想查询订单状态", "label": "订单查询", "label_id": 0}
{"text": "这个产品质量有问题", "label": "投诉建议", "label_id": 1}
Prompt工程最佳实践
指令设计原则
1. 明确性
# 好的指令
"请用三句话总结以下文章的主要内容:"
# 不好的指令
"总结一下"
2. 一致性
# 保持格式一致
template = """任务:{task_type}
输入:{input_text}
要求:{requirements}
输出:"""
3. 完整性
# 包含必要的上下文信息
prompt = """你是一个专业的客服助手。请根据用户的问题提供准确、友好的回答。
用户问题:{user_question}
回答:"""
元提示(Meta Prompt)设计
角色定义
role_prompts = {
"customer_service": "你是一个专业的客服代表,需要耐心、友好地解决用户问题。",
"technical_expert": "你是一个技术专家,请提供准确、详细的技术解答。",
"creative_writer": "你是一个创意写手,请发挥想象力创作有趣的内容。"
}
输出格式控制
format_prompt = """请按照以下格式回答:
1. 问题分析:[分析用户问题的核心]
2. 解决方案:[提供具体的解决步骤]
3. 注意事项:[相关的注意事项或建议]
"""
风格指导
style_prompt = """回答风格要求:
- 语言简洁明了
- 避免使用专业术语
- 提供具体的例子
- 保持友好的语调
"""
数据质量控制
数据清洗流程
1. 格式验证
def validate_data_format(data_item):
"""验证数据格式是否正确"""
required_fields = ['instruction', 'output']
for field in required_fields:
if field not in data_item:
return False, f"Missing field: {field}"
if not isinstance(data_item['instruction'], str):
return False, "Instruction must be string"
if len(data_item['output'].strip()) == 0:
return False, "Output cannot be empty"
return True, "Valid"
# 使用示例
for item in dataset:
is_valid, message = validate_data_format(item)
if not is_valid:
print(f"Invalid data: {message}")
2. 内容质量检查
def check_content_quality(instruction, output):
"""检查内容质量"""
issues = []
# 检查长度
if len(instruction) < 10:
issues.append("Instruction too short")
if len(output) < 20:
issues.append("Output too short")
# 检查重复
if instruction.lower() in output.lower():
issues.append("Output contains instruction")
# 检查语言一致性
if detect_language(instruction) != detect_language(output):
issues.append("Language mismatch")
return issues
def detect_language(text):
"""简单的语言检测"""
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
english_chars = len([c for c in text if c.isalpha() and ord(c) < 128])
if chinese_chars > english_chars:
return 'zh'
else:
return 'en'
3. 去重处理
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
def remove_duplicates(dataset, similarity_threshold=0.9):
"""基于相似度的去重"""
instructions = [item['instruction'] for item in dataset]
# 计算TF-IDF向量
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(instructions)
# 计算相似度矩阵
similarity_matrix = cosine_similarity(tfidf_matrix)
# 标记重复项
to_remove = set()
for i in range(len(dataset)):
for j in range(i+1, len(dataset)):
if similarity_matrix[i][j] > similarity_threshold:
to_remove.add(j)
# 移除重复项
filtered_dataset = [item for i, item in enumerate(dataset) if i not in to_remove]
print(f"Removed {len(to_remove)} duplicates")
return filtered_dataset
数据增强技术
1. 同义词替换
import random
def synonym_replacement(text, n=1):
"""同义词替换数据增强"""
synonyms = {
"好": ["不错", "优秀", "棒"],
"问题": ["疑问", "困难", "难题"],
"解决": ["处理", "解答", "搞定"]
}
words = text.split()
new_words = words.copy()
for _ in range(n):
random_word_idx = random.randint(0, len(words)-1)
random_word = words[random_word_idx]
if random_word in synonyms:
synonym = random.choice(synonyms[random_word])
new_words[random_word_idx] = synonym
return ' '.join(new_words)
2. 回译增强
from transformers import pipeline
def back_translation_augment(text, intermediate_lang='en'):
"""回译数据增强"""
# 中文 -> 英文
zh_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
en_text = zh_to_en(text)[0]['translation_text']
# 英文 -> 中文
en_to_zh = pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh")
back_translated = en_to_zh(en_text)[0]['translation_text']
return back_translated
# 使用示例
original_text = "今天天气很好"
augmented_text = back_translation_augment(original_text)
print(f"原文: {original_text}")
print(f"回译: {augmented_text}")
3. 模板变换
def template_variation(instruction, output):
"""通过模板变换生成新样本"""
templates = [
"请{action}:{content}",
"帮我{action}:{content}",
"能否{action}:{content}",
"如何{action}:{content}"
]
# 提取动作和内容
if "请" in instruction:
action = "解释" # 简化处理
content = instruction.replace("请解释", "").strip()
variations = []
for template in templates:
new_instruction = template.format(action=action, content=content)
variations.append({
"instruction": new_instruction,
"output": output
})
return variations
return []
少样本数据生成
使用GPT生成训练数据
import openai
def generate_training_data(domain, num_samples=100):
"""使用GPT生成训练数据"""
prompt_template = """请生成{num}个关于{domain}的问答对,格式如下:
问题:[用户可能问的问题]
回答:[专业、准确的回答]
要求:
1. 问题要多样化,覆盖不同场景
2. 回答要专业、准确、有用
3. 语言要自然、流畅
示例:
问题:什么是机器学习?
回答:机器学习是人工智能的一个分支,通过算法让计算机从数据中学习模式和规律...
请开始生成:"""
prompt = prompt_template.format(num=num_samples, domain=domain)
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}],
max_tokens=2000,
temperature=0.7
)
generated_text = response.choices[0].message.content
# 解析生成的问答对
qa_pairs = parse_generated_qa(generated_text)
return qa_pairs
def parse_generated_qa(text):
"""解析GPT生成的问答对"""
qa_pairs = []
lines = text.split('\n')
current_question = None
current_answer = None
for line in lines:
line = line.strip()
if line.startswith('问题:'):
current_question = line.replace('问题:', '').strip()
elif line.startswith('回答:'):
current_answer = line.replace('回答:', '').strip()
if current_question and current_answer:
qa_pairs.append({
'instruction': current_question,
'output': current_answer
})
current_question = None
current_answer = None
return qa_pairs
数据集划分策略
训练/验证/测试集划分
from sklearn.model_selection import train_test_split
import random
def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
"""数据集划分"""
assert train_ratio + val_ratio + test_ratio == 1.0
# 随机打乱
random.shuffle(dataset)
total_size = len(dataset)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
train_data = dataset[:train_size]
val_data = dataset[train_size:train_size + val_size]
test_data = dataset[train_size + val_size:]
return train_data, val_data, test_data
# 分层采样(适用于分类任务)
def stratified_split(dataset, label_key='label'):
"""分层采样划分"""
from collections import defaultdict
# 按标签分组
label_groups = defaultdict(list)
for item in dataset:
label_groups[item[label_key]].append(item)
train_data, val_data, test_data = [], [], []
# 对每个标签进行划分
for label, items in label_groups.items():
train_items, val_items, test_items = split_dataset(items)
train_data.extend(train_items)
val_data.extend(val_items)
test_data.extend(test_items)
# 重新打乱
random.shuffle(train_data)
random.shuffle(val_data)
random.shuffle(test_data)
return train_data, val_data, test_data
数据预处理工具
文本清洗
import re
def clean_text(text):
"""文本清洗"""
# 移除多余空格
text = re.sub(r'\s+', ' ', text)
# 移除特殊字符(保留中文、英文、数字、基本标点)
text = re.sub(r'[^\u4e00-\u9fff\w\s.,!?;:()[\]{}"\'-]', '', text)
# 统一标点符号
text = text.replace(',', ',').replace('。', '.').replace('?', '?').replace('!', '!')
return text.strip()
def normalize_format(dataset):
"""格式标准化"""
normalized_data = []
for item in dataset:
normalized_item = {}
# 清洗文本
if 'instruction' in item:
normalized_item['instruction'] = clean_text(item['instruction'])
if 'input' in item:
normalized_item['input'] = clean_text(item['input'])
if 'output' in item:
normalized_item['output'] = clean_text(item['output'])
# 保留其他字段
for key, value in item.items():
if key not in ['instruction', 'input', 'output']:
normalized_item[key] = value
normalized_data.append(normalized_item)
return normalized_data
数据质量评估
统计分析
def analyze_dataset(dataset):
"""数据集统计分析"""
stats = {
'total_samples': len(dataset),
'avg_instruction_length': 0,
'avg_output_length': 0,
'instruction_length_distribution': [],
'output_length_distribution': []
}
instruction_lengths = []
output_lengths = []
for item in dataset:
if 'instruction' in item:
length = len(item['instruction'])
instruction_lengths.append(length)
if 'output' in item:
length = len(item['output'])
output_lengths.append(length)
if instruction_lengths:
stats['avg_instruction_length'] = sum(instruction_lengths) / len(instruction_lengths)
stats['instruction_length_distribution'] = {
'min': min(instruction_lengths),
'max': max(instruction_lengths),
'median': sorted(instruction_lengths)[len(instruction_lengths)//2]
}
if output_lengths:
stats['avg_output_length'] = sum(output_lengths) / len(output_lengths)
stats['output_length_distribution'] = {
'min': min(output_lengths),
'max': max(output_lengths),
'median': sorted(output_lengths)[len(output_lengths)//2]
}
return stats
# 使用示例
dataset_stats = analyze_dataset(training_data)
print(f"数据集大小: {dataset_stats['total_samples']}")
print(f"平均指令长度: {dataset_stats['avg_instruction_length']:.1f}")
print(f"平均输出长度: {dataset_stats['avg_output_length']:.1f}")