Supervised Fine-tuning Deep Dive: Building Your First Instruction-Following Model

Jared Chung
Supervised Fine-tuning (SFT) is the foundation of most LLM customization efforts. It's where you teach a pre-trained model to follow specific instructions, adopt particular response styles, or excel at domain-specific tasks. This post will walk you through implementing SFT from scratch, with real code and practical insights.

Understanding Supervised Fine-tuning

What Makes SFT Different

Unlike pre-training (which learns from raw text) or reinforcement learning (which learns from rewards), SFT learns from explicit input-output pairs:

# Pre-training learns from sequences like:
"The capital of France is Paris. It is known for..."

# SFT learns from structured examples like:
{
    "instruction": "What is the capital of France?",
    "output": "The capital of France is Paris, a city renowned for its culture, art, and history."
}

This structure teaches the model to understand the relationship between what humans ask and how they expect responses to be formatted.

The SFT Learning Process

During SFT, the model learns to maximize the probability of generating the correct output tokens given the input:

# Simplified loss calculation
loss = -log(P(output_tokens | input_tokens))

# In practice, this becomes:
for token in output_sequence:
    loss += -log(P(token | previous_context))
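
Concretely, this is the standard causal language modeling cross-entropy. A minimal PyTorch sketch of the computation (the shift follows the usual convention that position t predicts the token at position t+1; the function name is illustrative):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    """Token-level cross-entropy for a causal LM.

    logits: (batch, seq_len, vocab_size) model outputs
    labels: (batch, seq_len) target token ids; positions set to -100 are ignored
    """
    # Shift so that the logits at position t are scored against the token at t+1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # allows prompt/padding tokens to be masked out
    )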

Data Preparation for SFT

Dataset Selection and Quality

The quality of your training data directly determines your model's capabilities. Here's how to evaluate and prepare datasets:

# data_quality_checker.py
import json
import re
from collections import Counter
import numpy as np

class DataQualityAnalyzer:
    def __init__(self, dataset_path):
        with open(dataset_path, 'r') as f:
            self.data = json.load(f)
    
    def analyze_lengths(self):
        """Analyze instruction and response lengths"""
        instruction_lengths = [len(item['instruction'].split()) for item in self.data]
        output_lengths = [len(item['output'].split()) for item in self.data]
        
        stats = {
            'instruction_stats': {
                'mean': np.mean(instruction_lengths),
                'median': np.median(instruction_lengths),
                'std': np.std(instruction_lengths)
            },
            'output_stats': {
                'mean': np.mean(output_lengths),
                'median': np.median(output_lengths),
                'std': np.std(output_lengths)
            }
        }
        
        return stats
    
    def check_format_consistency(self):
        """Check for formatting issues"""
        issues = []
        required_fields = ['instruction', 'output']
        
        for i, item in enumerate(self.data):
            # Check required fields
            missing_fields = [field for field in required_fields if field not in item]
            if missing_fields:
                issues.append(f"Row {i}: Missing fields {missing_fields}")
            
            # Check for empty content
            if item.get('instruction', '').strip() == '':
                issues.append(f"Row {i}: Empty instruction")
            if item.get('output', '').strip() == '':
                issues.append(f"Row {i}: Empty output")
        
        return issues
    
    def analyze_task_distribution(self):
        """Analyze the distribution of different task types"""
        # Simple heuristic-based task classification
        task_patterns = {
            'question_answering': r'\?|\b(what|who|when|where|why|how)\b',
            'explanation': r'\b(explain|describe|define)\b',
            'generation': r'\b(write|create|generate|compose)\b',
            'analysis': r'\b(analyze|compare|evaluate|assess)\b'
        }
        
        task_counts = Counter()
        
        for item in self.data:
            instruction = item['instruction'].lower()
            for task_type, pattern in task_patterns.items():
                if re.search(pattern, instruction):
                    task_counts[task_type] += 1
                    break
            else:
                task_counts['other'] += 1
        
        return task_counts

# Usage
analyzer = DataQualityAnalyzer('your_dataset.json')
length_stats = analyzer.analyze_lengths()
format_issues = analyzer.check_format_consistency()
task_distribution = analyzer.analyze_task_distribution()
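
Once the analyzer has surfaced problems, act on them before training. A simple cleanup pass along these lines works well (field names match the format above; the minimum-length threshold is an arbitrary starting point):

def clean_dataset(data, min_output_words=3):
    """Drop rows with empty fields, very short outputs, or exact duplicates."""
    seen = set()
    cleaned = []
    for item in data:
        instruction = item.get('instruction', '').strip()
        output = item.get('output', '').strip()
        if not instruction or len(output.split()) < min_output_words:
            continue
        key = (instruction, output)
        if key in seen:  # exact-duplicate filter
            continue
        seen.add(key)
        cleaned.append(item)
    return cleaned

cleaned_data = clean_dataset(analyzer.data)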

Data Preprocessing Pipeline

Create a robust preprocessing pipeline:

# data_preprocessor.py
from transformers import AutoTokenizer
from datasets import Dataset
import re

class SFTDataProcessor:
    def __init__(self, tokenizer_name, max_length=2048):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
        
        # Add special tokens if needed
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def format_instruction(self, instruction, input_text="", output=""):
        """Format data using Alpaca-style template"""
        if input_text:
            prompt = f"Below is an instruction that describes a task, paired with an input. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
        
        return prompt, output
    
    def tokenize_function(self, examples):
        """Tokenize the formatted examples"""
        formatted_examples = []
        
        for i in range(len(examples['instruction'])):
            prompt, response = self.format_instruction(
                examples['instruction'][i],
                examples.get('input', [''] * len(examples['instruction']))[i],
                examples['output'][i]
            )
            
            # Create full text for training
            full_text = prompt + response + self.tokenizer.eos_token
            formatted_examples.append(full_text)
        
        # Tokenize
        tokenized = self.tokenizer(
            formatted_examples,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors=None
        )
        
        # Create labels (same as input_ids for causal LM)
        tokenized['labels'] = tokenized['input_ids'].copy()
        
        return tokenized
    
    def process_dataset(self, raw_data):
        """Process raw data into training format"""
        # Convert to HuggingFace Dataset
        dataset = Dataset.from_list(raw_data)
        
        # Apply tokenization
        tokenized_dataset = dataset.map(
            self.tokenize_function,
            batched=True,
            remove_columns=dataset.column_names
        )
        
        return tokenized_dataset

# Usage example
processor = SFTDataProcessor("microsoft/DialoGPT-medium")
processed_dataset = processor.process_dataset(your_data)
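
One subtlety in tokenize_function above: copying input_ids straight into labels means the model is also trained to reproduce the prompt text. That works, but a common refinement is to mask the prompt tokens with -100 so the loss covers only the response. A minimal sketch of the idea (a standalone function; the names are illustrative):

def tokenize_with_prompt_masking(tokenizer, prompt, response, max_length=2048):
    """Tokenize prompt + response, masking prompt tokens out of the loss."""
    prompt_ids = tokenizer(prompt, add_special_tokens=False)['input_ids']
    full_ids = tokenizer(
        prompt + response + tokenizer.eos_token,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False
    )['input_ids']

    # Note: some tokenizers merge tokens across the prompt/response boundary,
    # so the mask can be off by a token; in practice this is usually acceptable.
    labels = full_ids.copy()
    prompt_len = min(len(prompt_ids), len(labels))
    labels[:prompt_len] = [-100] * prompt_len  # -100 is ignored by the loss

    return {'input_ids': full_ids, 'labels': labels}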

Training Implementation

Complete Training Script

Here's a production-ready training script with all the essential components:

# train_sft.py
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_from_disk
import wandb
import json
import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class SFTConfig:
    # Model settings
    model_name: str = "microsoft/DialoGPT-medium"
    max_length: int = 2048
    
    # Training settings
    output_dir: str = "./sft_results"
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    
    # Optimization settings
    fp16: bool = True
    gradient_checkpointing: bool = True
    dataloader_num_workers: int = 4
    
    # Logging and evaluation
    logging_steps: int = 50
    eval_steps: int = 500
    save_steps: int = 500
    eval_accumulation_steps: int = 1
    
    # Data settings
    data_path: str = "./processed_data"
    
    # Experiment tracking
    wandb_project: Optional[str] = "sft-experiments"
    run_name: Optional[str] = None

class SFTTrainer:
    def __init__(self, config: SFTConfig):
        self.config = config
        self.setup_model_and_tokenizer()
        self.setup_data()
        
    def setup_model_and_tokenizer(self):
        """Initialize model and tokenizer"""
        print(f"Loading model: {self.config.model_name}")
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            use_fast=True
        )
        
        # Ensure we have a pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.float16 if self.config.fp16 else torch.float32,
            device_map="auto"
        )
        
        # Resize token embeddings if needed
        self.model.resize_token_embeddings(len(self.tokenizer))
        
    def setup_data(self):
        """Load and prepare datasets"""
        print(f"Loading data from: {self.config.data_path}")
        
        self.dataset = load_from_disk(self.config.data_path)
        
        print(f"Train samples: {len(self.dataset['train'])}")
        print(f"Eval samples: {len(self.dataset['validation'])}")
        
    def create_training_arguments(self):
        """Create training arguments"""
        return TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_train_epochs,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            warmup_ratio=self.config.warmup_ratio,
            
            # Optimization
            fp16=self.config.fp16,
            gradient_checkpointing=self.config.gradient_checkpointing,
            dataloader_num_workers=self.config.dataloader_num_workers,
            
            # Logging and saving
            logging_steps=self.config.logging_steps,
            eval_steps=self.config.eval_steps,
            save_steps=self.config.save_steps,
            evaluation_strategy="steps",
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            
            # Experiment tracking
            report_to="wandb" if self.config.wandb_project else None,
            run_name=self.config.run_name,
            
            # Misc
            remove_unused_columns=False,
            dataloader_pin_memory=True,
        )
    
    def train(self):
        """Main training loop"""
        # Setup wandb
        if self.config.wandb_project:
            wandb.init(
                project=self.config.wandb_project,
                name=self.config.run_name,
                config=self.config.__dict__
            )
        
        # Create trainer
        training_args = self.create_training_arguments()
        
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
            pad_to_multiple_of=8 if self.config.fp16 else None
        )
        
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset['train'],
            eval_dataset=self.dataset['validation'],
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )
        
        # Train the model
        print("Starting training...")
        trainer.train()
        
        # Save final model
        print("Saving final model...")
        trainer.save_model()
        self.tokenizer.save_pretrained(self.config.output_dir)
        
        # Save config
        with open(os.path.join(self.config.output_dir, "training_config.json"), "w") as f:
            json.dump(self.config.__dict__, f, indent=2)
        
        print(f"Training completed! Model saved to {self.config.output_dir}")

# Training script execution
if __name__ == "__main__":
    config = SFTConfig(
        model_name="microsoft/DialoGPT-medium",
        data_path="./processed_data",
        output_dir="./sft_model",
        num_train_epochs=3,
        per_device_train_batch_size=2,  # Adjust based on your GPU
        gradient_accumulation_steps=8,   # Effective batch size = 2 * 8 = 16
        learning_rate=2e-5,
        wandb_project="llm-sft-tutorial",
        run_name="medium-model-tutorial"
    )
    
    trainer = SFTTrainer(config)
    trainer.train()

Hyperparameter Optimization

Learning Rate and Schedule

Learning rate is crucial for SFT success:

# Learning rate finder
import torch

def find_optimal_lr(model, train_dataloader, start_lr=1e-8, end_lr=10, num_steps=100):
    """Simple learning rate finder"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=(end_lr/start_lr)**(1/num_steps)
    )
    
    losses = []
    lrs = []
    
    model.train()
    for step, batch in enumerate(train_dataloader):
        if step >= num_steps:
            break
            
        # Move the batch to the model's device, then run the forward pass
        batch = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])
        
        # Backward pass
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    
    return lrs, losses

# Recommended learning rates by model size
recommended_lr = {
    "small (< 1B params)": 5e-5,
    "medium (1-7B params)": 2e-5,
    "large (7-30B params)": 1e-5,
    "very_large (> 30B params)": 5e-6
}
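
To read the finder's output, plot loss against learning rate on a log scale and pick a value roughly an order of magnitude below the point where the loss starts climbing. A usage sketch (train_dataloader is assumed to yield tokenized batches):

import matplotlib.pyplot as plt

lrs, losses = find_optimal_lr(model, train_dataloader)

plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel('Learning rate')
plt.ylabel('Loss')
plt.title('Learning rate range test')
plt.show()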

Batch Size and Gradient Accumulation

Balance memory usage with training stability:

def calculate_effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Calculate the effective batch size"""
    return per_device_batch_size * gradient_accumulation_steps * num_gpus

# Example configurations for different GPU setups
gpu_configs = {
    "RTX 3070 (8GB)": {
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 16,
        "effective_batch_size": 16
    },
    "RTX 4090 (24GB)": {
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 8,
        "effective_batch_size": 32
    },
    "A100 (80GB)": {
        "per_device_train_batch_size": 8,
        "gradient_accumulation_steps": 4,
        "effective_batch_size": 32
    }
}
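
Given a target effective batch size and whatever per-device batch size fits in memory, the accumulation steps fall out directly. A small helper mirroring the table above (illustrative):

def accumulation_steps_for(target_effective, per_device_batch_size, num_gpus=1):
    """Solve for gradient_accumulation_steps given a target effective batch size."""
    return max(1, target_effective // (per_device_batch_size * num_gpus))

# e.g. an 8GB card that only fits batch size 1 still reaches an effective 16
assert accumulation_steps_for(16, per_device_batch_size=1) == 16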

Monitoring and Evaluation

Real-time Monitoring

Track key metrics during training:

# custom_callbacks.py
from transformers import TrainerCallback
import wandb
import numpy as np

class DetailedLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Enhanced logging with additional metrics"""
        if logs:
            # The Trainer logs the running training loss under the key 'loss'
            if 'loss' in logs:
                logs['perplexity'] = np.exp(logs['loss'])

            # Log gradient norms (zero if gradients have already been cleared)
            if model is not None:
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                logs['grad_norm'] = total_norm ** 0.5
            
            # Log to wandb if available
            if wandb.run:
                wandb.log(logs, step=state.global_step)
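
To activate the callback, register it when constructing the Trainer (this slots into the train() method of train_sft.py above):

trainer = Trainer(
    model=self.model,
    args=training_args,
    train_dataset=self.dataset['train'],
    eval_dataset=self.dataset['validation'],
    tokenizer=self.tokenizer,
    data_collator=data_collator,
    callbacks=[DetailedLoggingCallback()],  # adds perplexity and grad-norm logging
)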

Evaluation During Training

Implement meaningful evaluation metrics:

# evaluation.py
from transformers import pipeline
import torch
import numpy as np

class SFTEvaluator:
    def __init__(self, model_path, tokenizer_path):
        self.generator = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=tokenizer_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
    
    def evaluate_responses(self, test_instructions, max_length=512):
        """Generate responses for test instructions"""
        results = []
        
        for instruction in test_instructions:
            prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
            
            response = self.generator(
                prompt,
                max_length=max_length,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.generator.tokenizer.eos_token_id
            )[0]['generated_text']
            
            # Extract just the response part
            response_start = response.find("### Response:\n") + len("### Response:\n")
            generated_response = response[response_start:].strip()
            
            results.append({
                'instruction': instruction,
                'response': generated_response
            })
        
        return results
    
    def calculate_response_quality_metrics(self, results):
        """Calculate basic quality metrics"""
        metrics = {
            'avg_response_length': np.mean([len(r['response'].split()) for r in results]),
            'responses_with_repetition': sum(1 for r in results if self.has_repetition(r['response'])),
            'empty_responses': sum(1 for r in results if len(r['response'].strip()) == 0)
        }
        
        return metrics
    
    def has_repetition(self, text):
        """Simple repetition detection based on repeated 3-word phrases"""
        words = text.split()
        if len(words) < 10:
            return False
        
        # Check for repeated 3-word phrases
        for i in range(len(words) - 2):
            phrase = ' '.join(words[i:i + 3])
            if text.count(phrase) > 1:
                return True
        
        return False

# Usage during training
test_instructions = [
    "Explain the concept of machine learning in simple terms.",
    "Write a short story about a robot learning to paint.",
    "What are the benefits of renewable energy?"
]

evaluator = SFTEvaluator("./sft_model", "./sft_model")
eval_results = evaluator.evaluate_responses(test_instructions)
quality_metrics = evaluator.calculate_response_quality_metrics(eval_results)
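
These heuristics catch gross failures like empty or repetitive outputs. For a single scalar to track across checkpoints, eval-set perplexity derived from trainer.evaluate() is the standard complement (assuming the Trainer from train_sft.py is in scope):

import numpy as np

eval_metrics = trainer.evaluate()
print(f"Eval loss: {eval_metrics['eval_loss']:.3f} | "
      f"Perplexity: {np.exp(eval_metrics['eval_loss']):.2f}")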

Common Issues and Solutions

Overfitting

Signs and solutions for overfitting:

def detect_overfitting(train_losses, eval_losses, patience=5):
    """Detect overfitting patterns"""
    if len(eval_losses) < patience + 1:
        return False
    
    # Check if eval loss has been increasing for 'patience' steps
    recent_eval = eval_losses[-patience:]
    return all(recent_eval[i] >= recent_eval[i-1] for i in range(1, len(recent_eval)))

# Solutions for overfitting:
overfitting_solutions = {
    "reduce_learning_rate": "Lower learning rate by 2-5x",
    "increase_weight_decay": "Use weight_decay=0.01-0.1",
    "early_stopping": "Stop when eval loss stops improving",
    "more_data": "Add more diverse training examples",
    "shorter_training": "Reduce number of epochs"
}
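
Early stopping in particular is cheap to wire up with the built-in transformers callback; it relies on the load_best_model_at_end and metric_for_best_model settings already present in the training arguments above:

from transformers import EarlyStoppingCallback

trainer = Trainer(
    model=model,
    args=training_args,  # requires load_best_model_at_end=True and periodic eval
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)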

Training Instability

Monitor for training instabilities:

import numpy as np

def check_training_stability(losses, window=50):
    """Check for training instabilities"""
    if len(losses) < window:
        return True
    
    recent_losses = losses[-window:]
    
    # Check for exploding gradients
    if any(loss > 10 * np.median(recent_losses) for loss in recent_losses[-10:]):
        return False
    
    # Check for loss oscillations
    oscillations = sum(1 for i in range(1, len(recent_losses)) 
                      if (recent_losses[i] - recent_losses[i-1]) * 
                         (recent_losses[i-1] - recent_losses[max(0, i-2)]) < 0)
    
    if oscillations > window * 0.7:  # Too many direction changes
        return False
    
    return True
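
If the checker flags instability, the usual first levers are gradient clipping (exposed by TrainingArguments as max_grad_norm, which defaults to 1.0) and a lower learning rate:

training_args = TrainingArguments(
    output_dir="./sft_results",
    max_grad_norm=0.5,   # clip gradients harder than the 1.0 default
    learning_rate=1e-5,  # halving the LR often tames oscillations
)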

Best Practices Summary

Training Strategy

  1. Start small: Begin with a subset of your data to validate the pipeline
  2. Monitor closely: Watch training and validation losses carefully
  3. Save checkpoints: Regular saves prevent loss of progress
  4. Use validation sets: Always hold out data for evaluation

Data Quality

  1. Clean formatting: Consistent templates and formatting
  2. Diverse examples: Cover all expected use cases
  3. Quality over quantity: Better to have fewer high-quality examples
  4. Regular audits: Periodically review your training data

Hyperparameter Selection

  1. Learning rate: Start with 2e-5 for most models
  2. Batch size: Balance memory constraints with stability
  3. Epochs: Usually 1-5 epochs is sufficient for SFT
  4. Warmup: Use 10% of total steps for warmup (sketched below)
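
When training through the Trainer, warmup_ratio=0.1 (as in SFTConfig above) handles this automatically. In a manual training loop, the equivalent is transformers.get_scheduler; a minimal sketch, assuming the total step count is known up front:

import torch
from transformers import get_scheduler

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_training_steps = 1000  # e.g. len(train_dataloader) * num_epochs

lr_scheduler = get_scheduler(
    "linear",                # linear decay after a linear warmup
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup
    num_training_steps=num_training_steps,
)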

The next post in this series will explore Parameter-Efficient Fine-tuning with LoRA and QLoRA, showing how to achieve similar results with dramatically reduced memory requirements and training time.

With the foundation of supervised fine-tuning under your belt, you're ready to tackle more advanced techniques that make LLM customization accessible even with limited computational resources.