- Published on
Supervised Fine-tuning Deep Dive: Building Your First Instruction-Following Model
- Authors
- Name
- Jared Chung
Supervised Fine-tuning (SFT) is the foundation of most LLM customization efforts. It's where you teach a pre-trained model to follow specific instructions, adopt particular response styles, or excel at domain-specific tasks. This post will walk you through implementing SFT from scratch, with real code and practical insights.
Understanding Supervised Fine-tuning
What Makes SFT Different
Unlike pre-training (which learns from raw text) or reinforcement learning (which learns from rewards), SFT learns from explicit input-output pairs:
# Pre-training learns from sequences like:
"The capital of France is Paris. It is known for..."
# SFT learns from structured examples like:
{
"instruction": "What is the capital of France?",
"output": "The capital of France is Paris, a city renowned for its culture, art, and history."
}
This structure teaches the model to understand the relationship between what humans ask and how they expect responses to be formatted.
The SFT Learning Process
During SFT, the model learns to maximize the probability of generating the correct output tokens given the input:
# Simplified loss calculation
loss = -log(P(output_tokens | input_tokens))
# In practice, this becomes:
for token in output_sequence:
loss += -log(P(token | previous_context))
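In PyTorch terms, this is the standard causal cross-entropy with the usual one-token shift. Here's a minimal, self-contained sketch (the tensor shapes are illustrative, not tied to any particular model):

```python
import torch
import torch.nn.functional as F

def sft_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Causal LM loss: logits at position t predict the token at t+1.

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    Positions labeled -100 are ignored (the HuggingFace masking convention).
    """
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```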
Data Preparation for SFT
Dataset Selection and Quality
The quality of your training data directly determines your model's capabilities. Here's how to evaluate and prepare datasets:
# data_quality_checker.py
import json
import re
from collections import Counter
import numpy as np
class DataQualityAnalyzer:
def __init__(self, dataset_path):
with open(dataset_path, 'r') as f:
self.data = json.load(f)
def analyze_lengths(self):
"""Analyze instruction and response lengths"""
instruction_lengths = [len(item['instruction'].split()) for item in self.data]
output_lengths = [len(item['output'].split()) for item in self.data]
stats = {
'instruction_stats': {
'mean': np.mean(instruction_lengths),
'median': np.median(instruction_lengths),
'std': np.std(instruction_lengths)
},
'output_stats': {
'mean': np.mean(output_lengths),
'median': np.median(output_lengths),
'std': np.std(output_lengths)
}
}
return stats
def check_format_consistency(self):
"""Check for formatting issues"""
issues = []
required_fields = ['instruction', 'output']
for i, item in enumerate(self.data):
# Check required fields
missing_fields = [field for field in required_fields if field not in item]
if missing_fields:
issues.append(f"Row {i}: Missing fields {missing_fields}")
# Check for empty content
if item.get('instruction', '').strip() == '':
issues.append(f"Row {i}: Empty instruction")
if item.get('output', '').strip() == '':
issues.append(f"Row {i}: Empty output")
return issues
def analyze_task_distribution(self):
"""Analyze the distribution of different task types"""
# Simple heuristic-based task classification
task_patterns = {
            'question_answering': r'\?|\b(what|who|when|where|why|how)\b',
            'explanation': r'\b(explain|describe|define)\b',
            'generation': r'\b(write|create|generate|compose)\b',
            'analysis': r'\b(analyze|compare|evaluate|assess)\b'
}
task_counts = Counter()
for item in self.data:
instruction = item['instruction'].lower()
for task_type, pattern in task_patterns.items():
if re.search(pattern, instruction):
task_counts[task_type] += 1
break
else:
task_counts['other'] += 1
return task_counts
# Usage
analyzer = DataQualityAnalyzer('your_dataset.json')
length_stats = analyzer.analyze_lengths()
format_issues = analyzer.check_format_consistency()
task_distribution = analyzer.analyze_task_distribution()
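Once the checks flag problems, it's worth filtering before training. A quick, hypothetical cleanup pass (the thresholds here are arbitrary, so tune them for your data):

```python
def filter_dataset(data, min_output_words=3):
    """Drop rows with empty instructions or very short outputs."""
    cleaned = []
    for item in data:
        if not item.get('instruction', '').strip():
            continue
        if len(item.get('output', '').strip().split()) < min_output_words:
            continue
        cleaned.append(item)
    return cleaned

cleaned_data = filter_dataset(analyzer.data)
print(f"Kept {len(cleaned_data)} of {len(analyzer.data)} examples")
```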
Data Preprocessing Pipeline
Create a robust preprocessing pipeline:
# data_preprocessor.py
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
import re
class SFTDataProcessor:
def __init__(self, tokenizer_name, max_length=2048):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.max_length = max_length
# Add special tokens if needed
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def format_instruction(self, instruction, input_text="", output=""):
"""Format data using Alpaca-style template"""
if input_text:
prompt = f"Below is an instruction that describes a task, paired with an input. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
else:
prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
return prompt, output
def tokenize_function(self, examples):
"""Tokenize the formatted examples"""
formatted_examples = []
for i in range(len(examples['instruction'])):
prompt, response = self.format_instruction(
examples['instruction'][i],
examples.get('input', [''] * len(examples['instruction']))[i],
examples['output'][i]
)
# Create full text for training
full_text = prompt + response + self.tokenizer.eos_token
formatted_examples.append(full_text)
# Tokenize
tokenized = self.tokenizer(
formatted_examples,
truncation=True,
padding=False,
max_length=self.max_length,
return_tensors=None
)
        # Create labels (same as input_ids for causal LM). Note this also
        # computes loss on the prompt tokens; see the masking variant below.
        tokenized['labels'] = tokenized['input_ids'].copy()
return tokenized
def process_dataset(self, raw_data):
"""Process raw data into training format"""
# Convert to HuggingFace Dataset
dataset = Dataset.from_list(raw_data)
# Apply tokenization
tokenized_dataset = dataset.map(
self.tokenize_function,
batched=True,
remove_columns=dataset.column_names
)
return tokenized_dataset
# Usage example
processor = SFTDataProcessor("microsoft/DialoGPT-medium")
processed_dataset = processor.process_dataset(your_data)
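One common refinement: the tokenize_function above computes loss on the prompt tokens as well as the response. Many SFT recipes mask the prompt so the model is only penalized on what it should generate. Here's a sketch of an alternative method for SFTDataProcessor that does this; pair it with a collator that preserves labels, such as DataCollatorForSeq2Seq, since DataCollatorForLanguageModeling rebuilds labels from input_ids:

```python
def tokenize_with_prompt_masking(self, examples):
    """Variant of tokenize_function that masks prompt tokens with -100."""
    input_ids_batch, labels_batch = [], []
    for i in range(len(examples['instruction'])):
        prompt, response = self.format_instruction(
            examples['instruction'][i],
            examples.get('input', [''] * len(examples['instruction']))[i],
            examples['output'][i]
        )
        prompt_ids = self.tokenizer(prompt, add_special_tokens=False)['input_ids']
        response_ids = self.tokenizer(
            response + self.tokenizer.eos_token, add_special_tokens=False
        )['input_ids']
        input_ids = (prompt_ids + response_ids)[:self.max_length]
        # -100 is ignored by the loss, so only response tokens are trained on
        labels = ([-100] * len(prompt_ids) + response_ids)[:self.max_length]
        input_ids_batch.append(input_ids)
        labels_batch.append(labels)
    return {'input_ids': input_ids_batch, 'labels': labels_batch}
```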
Training Implementation
Complete Training Script
Here's a complete training script with the essential components:
# train_sft.py
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import load_from_disk
import wandb
import json
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class SFTConfig:
# Model settings
model_name: str = "microsoft/DialoGPT-medium"
max_length: int = 2048
# Training settings
output_dir: str = "./sft_results"
num_train_epochs: int = 3
per_device_train_batch_size: int = 4
per_device_eval_batch_size: int = 4
gradient_accumulation_steps: int = 4
learning_rate: float = 2e-5
weight_decay: float = 0.01
warmup_ratio: float = 0.1
# Optimization settings
fp16: bool = True
gradient_checkpointing: bool = True
dataloader_num_workers: int = 4
# Logging and evaluation
logging_steps: int = 50
eval_steps: int = 500
save_steps: int = 500
eval_accumulation_steps: int = 1
# Data settings
data_path: str = "./processed_data"
# Experiment tracking
wandb_project: Optional[str] = "sft-experiments"
run_name: Optional[str] = None
class SFTTrainer:
def __init__(self, config: SFTConfig):
self.config = config
self.setup_model_and_tokenizer()
self.setup_data()
def setup_model_and_tokenizer(self):
"""Initialize model and tokenizer"""
print(f"Loading model: {self.config.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.model_name,
use_fast=True
)
# Ensure we have a pad token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype=torch.float16 if self.config.fp16 else torch.float32,
device_map="auto"
)
# Resize token embeddings if needed
self.model.resize_token_embeddings(len(self.tokenizer))
def setup_data(self):
"""Load and prepare datasets"""
print(f"Loading data from: {self.config.data_path}")
self.dataset = load_from_disk(self.config.data_path)
print(f"Train samples: {len(self.dataset['train'])}")
print(f"Eval samples: {len(self.dataset['validation'])}")
def create_training_arguments(self):
"""Create training arguments"""
return TrainingArguments(
output_dir=self.config.output_dir,
num_train_epochs=self.config.num_train_epochs,
per_device_train_batch_size=self.config.per_device_train_batch_size,
per_device_eval_batch_size=self.config.per_device_eval_batch_size,
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
learning_rate=self.config.learning_rate,
weight_decay=self.config.weight_decay,
warmup_ratio=self.config.warmup_ratio,
# Optimization
fp16=self.config.fp16,
gradient_checkpointing=self.config.gradient_checkpointing,
dataloader_num_workers=self.config.dataloader_num_workers,
# Logging and saving
logging_steps=self.config.logging_steps,
eval_steps=self.config.eval_steps,
save_steps=self.config.save_steps,
evaluation_strategy="steps",
save_strategy="steps",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
# Experiment tracking
            report_to="wandb" if self.config.wandb_project else "none",
run_name=self.config.run_name,
# Misc
remove_unused_columns=False,
dataloader_pin_memory=True,
)
def train(self):
"""Main training loop"""
# Setup wandb
if self.config.wandb_project:
wandb.init(
project=self.config.wandb_project,
name=self.config.run_name,
config=self.config.__dict__
)
# Create trainer
training_args = self.create_training_arguments()
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False, # We're doing causal LM, not masked LM
pad_to_multiple_of=8 if self.config.fp16 else None
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=self.dataset['train'],
eval_dataset=self.dataset['validation'],
tokenizer=self.tokenizer,
data_collator=data_collator,
)
# Train the model
print("Starting training...")
trainer.train()
# Save final model
print("Saving final model...")
trainer.save_model()
self.tokenizer.save_pretrained(self.config.output_dir)
# Save config
with open(os.path.join(self.config.output_dir, "training_config.json"), "w") as f:
json.dump(self.config.__dict__, f, indent=2)
print(f"Training completed! Model saved to {self.config.output_dir}")
# Training script execution
if __name__ == "__main__":
config = SFTConfig(
model_name="microsoft/DialoGPT-medium",
data_path="./processed_data",
output_dir="./sft_model",
num_train_epochs=3,
per_device_train_batch_size=2, # Adjust based on your GPU
gradient_accumulation_steps=8, # Effective batch size = 2 * 8 = 16
learning_rate=2e-5,
wandb_project="llm-sft-tutorial",
run_name="medium-model-tutorial"
)
trainer = SFTTrainer(config)
trainer.train()
Hyperparameter Optimization
Learning Rate and Schedule
Learning rate is crucial for SFT success:
# Learning rate finder
import torch
def find_optimal_lr(model, train_dataloader, start_lr=1e-8, end_lr=10, num_steps=100):
"""Simple learning rate finder"""
optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, gamma=(end_lr/start_lr)**(1/num_steps)
)
losses = []
lrs = []
model.train()
for step, batch in enumerate(train_dataloader):
if step >= num_steps:
break
        # Forward pass (move the batch to the model's device first)
        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch)
loss = outputs.loss
losses.append(loss.item())
lrs.append(optimizer.param_groups[0]['lr'])
# Backward pass
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
return lrs, losses
# Recommended learning rates by model size
recommended_lr = {
"small (< 1B params)": 5e-5,
"medium (1-7B params)": 2e-5,
"large (7-30B params)": 1e-5,
"very_large (> 30B params)": 5e-6
}
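The schedule matters as much as the peak value. The Trainer handles this internally (via lr_scheduler_type plus warmup_ratio), but for a custom loop, a cosine schedule with 10% warmup looks like this, with a placeholder step count:

```python
import torch
from transformers import get_scheduler

# Assumes `model` is already loaded; the step count is a placeholder
num_training_steps = 1000  # epochs * steps_per_epoch in practice
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup
    num_training_steps=num_training_steps,
)
```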
Batch Size and Gradient Accumulation
Balance memory usage with training stability:
def calculate_effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_gpus=1):
"""Calculate the effective batch size"""
return per_device_batch_size * gradient_accumulation_steps * num_gpus
# Example configurations for different GPU setups
gpu_configs = {
"RTX 3070 (8GB)": {
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 16,
"effective_batch_size": 16
},
"RTX 4090 (24GB)": {
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 8,
"effective_batch_size": 32
},
"A100 (80GB)": {
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 4,
"effective_batch_size": 32
}
}
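Going the other direction, picking accumulation steps to hit a target effective batch size, is simple arithmetic; a small hypothetical helper:

```python
def accumulation_steps_for(target_effective, per_device_batch_size, num_gpus=1):
    """Solve for gradient_accumulation_steps given a target effective batch size."""
    steps, remainder = divmod(target_effective, per_device_batch_size * num_gpus)
    if remainder:
        raise ValueError("target must be divisible by per-device batch size * GPUs")
    return steps

# e.g. an 8GB card running batch size 1, aiming for an effective batch of 16
print(accumulation_steps_for(16, per_device_batch_size=1))  # -> 16
```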
Monitoring and Evaluation
Real-time Monitoring
Track key metrics during training:
# custom_callbacks.py
from transformers import TrainerCallback
import wandb
import numpy as np
class DetailedLoggingCallback(TrainerCallback):
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
"""Enhanced logging with additional metrics"""
        if logs:
            # The Trainer logs the running training loss under 'loss'
            if 'loss' in logs:
                logs['perplexity'] = np.exp(logs['loss'])
            # Log gradient norms (may be zero if grads were already cleared)
            if model is not None:
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                logs['grad_norm'] = total_norm ** 0.5
# Log to wandb if available
if wandb.run:
wandb.log(logs, step=state.global_step)
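Wiring the callback in is one line. For example, the Trainer construction in the training script above could become:

```python
# Inside SFTTrainer.train(), alongside the existing arguments
trainer = Trainer(
    model=self.model,
    args=training_args,
    train_dataset=self.dataset['train'],
    eval_dataset=self.dataset['validation'],
    tokenizer=self.tokenizer,
    data_collator=data_collator,
    callbacks=[DetailedLoggingCallback()],  # register the custom callback
)
```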
Evaluation During Training
Implement meaningful evaluation metrics:
# evaluation.py
from transformers import pipeline
import torch
import numpy as np
class SFTEvaluator:
def __init__(self, model_path, tokenizer_path):
self.generator = pipeline(
"text-generation",
model=model_path,
tokenizer=tokenizer_path,
torch_dtype=torch.float16,
device_map="auto"
)
def evaluate_responses(self, test_instructions, max_length=512):
"""Generate responses for test instructions"""
results = []
for instruction in test_instructions:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
response = self.generator(
prompt,
max_length=max_length,
temperature=0.7,
do_sample=True,
pad_token_id=self.generator.tokenizer.eos_token_id
)[0]['generated_text']
            # The pipeline output includes the prompt; strip it by length
            generated_response = response[len(prompt):].strip()
results.append({
'instruction': instruction,
'response': generated_response
})
return results
def calculate_response_quality_metrics(self, results):
"""Calculate basic quality metrics"""
metrics = {
'avg_response_length': np.mean([len(r['response'].split()) for r in results]),
'responses_with_repetition': sum(1 for r in results if self.has_repetition(r['response'])),
'empty_responses': sum(1 for r in results if len(r['response'].strip()) == 0)
}
return metrics
    def has_repetition(self, text):
        """Simple repetition detection: flags any repeated 3-word phrase"""
        words = text.split()
        if len(words) < 10:
            return False
        # Check every 3-word window for a second occurrence
        for i in range(len(words) - 2):
            phrase = ' '.join(words[i:i + 3])
if text.count(phrase) > 1:
return True
return False
# Usage during training
test_instructions = [
"Explain the concept of machine learning in simple terms.",
"Write a short story about a robot learning to paint.",
"What are the benefits of renewable energy?"
]
evaluator = SFTEvaluator("./sft_model", "./sft_model")
eval_results = evaluator.evaluate_responses(test_instructions)
quality_metrics = evaluator.calculate_response_quality_metrics(eval_results)
Common Issues and Solutions
Overfitting
Signs and solutions for overfitting:
def detect_overfitting(train_losses, eval_losses, patience=5):
"""Detect overfitting patterns"""
if len(eval_losses) < patience + 1:
return False
# Check if eval loss has been increasing for 'patience' steps
recent_eval = eval_losses[-patience:]
return all(recent_eval[i] >= recent_eval[i-1] for i in range(1, len(recent_eval)))
# Solutions for overfitting:
overfitting_solutions = {
"reduce_learning_rate": "Lower learning rate by 2-5x",
"increase_weight_decay": "Use weight_decay=0.01-0.1",
"early_stopping": "Stop when eval loss stops improving",
"more_data": "Add more diverse training examples",
"shorter_training": "Reduce number of epochs"
}
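For the early-stopping fix, transformers ships a ready-made callback. Since the training script above already sets load_best_model_at_end=True and metric_for_best_model="eval_loss", adding it is a one-liner; a sketch, with the patience value as an assumption:

```python
from transformers import EarlyStoppingCallback

# Stop if eval_loss fails to improve for 3 consecutive evaluations
trainer = Trainer(
    model=model,
    args=training_args,  # requires load_best_model_at_end=True
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
```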
Training Instability
Monitor for training instabilities:
import numpy as np

def check_training_stability(losses, window=50):
"""Check for training instabilities"""
if len(losses) < window:
return True
recent_losses = losses[-window:]
# Check for exploding gradients
if any(loss > 10 * np.median(recent_losses) for loss in recent_losses[-10:]):
return False
# Check for loss oscillations
oscillations = sum(1 for i in range(1, len(recent_losses))
if (recent_losses[i] - recent_losses[i-1]) *
(recent_losses[i-1] - recent_losses[max(0, i-2)]) < 0)
if oscillations > window * 0.7: # Too many direction changes
return False
return True
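The first lever against instability is usually gradient clipping. The Trainer already applies it through the max_grad_norm argument of TrainingArguments (default 1.0); the manual-loop equivalent is a single call before the optimizer step:

```python
import torch

# Inside a manual training loop (model, optimizer, batch assumed defined)
outputs = model(**batch)
loss = outputs.loss
loss.backward()
# Rescale gradients so their global L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```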
Best Practices Summary
Training Strategy
- Start small: Begin with a subset of your data to validate the pipeline
- Monitor closely: Watch training and validation losses carefully
- Save checkpoints: Regular saves prevent loss of progress
- Use validation sets: Always hold out data for evaluation
Data Quality
- Clean formatting: Consistent templates and formatting
- Diverse examples: Cover all expected use cases
- Quality over quantity: Better to have fewer high-quality examples
- Regular audits: Periodically review your training data
Hyperparameter Selection
- Learning rate: Start with 2e-5 for most models
- Batch size: Balance memory constraints with stability
- Epochs: Usually 1-5 epochs is sufficient for SFT
- Warmup: Use 10% of total steps for warmup
The next post in this series will explore Parameter-Efficient Fine-tuning with LoRA and QLoRA, showing how to achieve similar results with dramatically reduced memory requirements and training time.
With the foundation of supervised fine-tuning under your belt, you're ready to tackle more advanced techniques that make LLM customization accessible even with limited computational resources.