Overview
This notebook is a comprehensive, production-ready guide to fine-tuning Large Language Models (LLMs) with Supervised Fine-Tuning (SFT) followed by Direct Preference Optimization (DPO), using the Axolotl framework.
Why SFT → DPO?
- SFT teaches the model the task format and basic capabilities
- DPO refines the model to prefer better responses over worse ones
- This two-stage approach is more stable than direct DPO on base models
- DPO is simpler than RLHF (no reward model, no PPO complexity)
Technical Architecture
Base Model (e.g., Llama-3.1-8B)
↓
SFT Training (Instruction Following)
↓
SFT Model Checkpoint
↓
DPO Training (Preference Alignment)
↓
Final Aligned Model
Part 1: Environment Setup
First, we’ll install Axolotl and its dependencies. Axolotl is a powerful framework that handles:
- Data preprocessing and formatting
- Model loading with quantization support
- Training with various optimizations (LoRA, Flash Attention, etc.)
- Evaluation and checkpointing
# Install Axolotl and dependencies
# Note: This uses the latest Axolotl with Flash Attention 2 support
!pip install -q -U axolotl[flash-attn,deepspeed] accelerate transformers datasets bitsandbytes peft
# Import required libraries
import torch
import yaml
import os
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
Output:
PyTorch version: 2.5.1
CUDA available: False
GPU: None
Part 2: Understanding the Data Formats
SFT Data Format
For SFT, we need instruction-response pairs. Common formats:
- ChatML format: A widely adopted multi-turn standard, popularized by OpenAI and supported by most open-source training frameworks
- Alpaca format: Legacy single-turn instruction-input-output (largely superseded)
- ShareGPT format: Older multi-turn format (being phased out in favor of ChatML)
ChatML is the modern standard with better support and cleaner structure:
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris."}
]
}
DPO Data Format
DPO requires triplets:
- Prompt: The input question/instruction
- Chosen: The preferred response
- Rejected: The less preferred response
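For concreteness, a single DPO record after preprocessing looks roughly like this (the content below is made up for illustration, not taken from a real dataset):
# Illustrative DPO triplet (hypothetical content)
dpo_example = {
    "prompt": "Explain why the sky is blue.",
    "chosen": "Sunlight scatters off air molecules, and shorter blue wavelengths scatter the most, so the sky looks blue.",
    "rejected": "The sky is blue because it reflects the color of the ocean.",
}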
Let’s load datasets from HuggingFace:
# SFT Dataset: High-quality instruction-following data
# Using the OpenOrca dataset - GPT-4 and GPT-3.5 augmented instruction-response pairs (several million in total)
sft_dataset = load_dataset("Open-Orca/OpenOrca", split="train[:5000]") # Using 5k for demo
print("\n=== SFT Dataset Example ===")
print(sft_dataset[0])
print(f"\nSFT Dataset size: {len(sft_dataset):,}")
# DPO Dataset: Preference pairs
# Using Anthropic HH-RLHF dataset with human preferences
dpo_dataset = load_dataset("Anthropic/hh-rlhf", split="train[:5000]")
print("\n=== DPO Dataset Example ===")
print("Keys:", dpo_dataset.column_names)
print("\nChosen:", dpo_dataset[0]['chosen'][:200])
print("\nRejected:", dpo_dataset[0]['rejected'][:200])
print(f"\nDPO Dataset size: {len(dpo_dataset):,}")
Output:
=== SFT Dataset Example ===
{'id': 'niv.242684', 'system_prompt': '', 'question': "You will be given a definition of a task first...", 'response': '[\n ["AFC Ajax (amateurs)", "has ground", "Sportpark De Toekomst"],\n ["Ajax Youth Academy", "plays at", "Sportpark De Toekomst"]\n]'}
SFT Dataset size: 5,000
=== DPO Dataset Example ===
Keys: ['chosen', 'rejected']
Chosen: H: What are some cuss words in english?
A: Here's an incomplete list...
Rejected: H: What are some cuss words in english?
A: Here's an incomplete list...
DPO Dataset size: 5,000
Part 3: Data Preprocessing
Why Preprocessing Matters
- Format Standardization: Different datasets have different structures
- Tokenization: Convert text to tokens the model understands
- Special Tokens: Add chat templates, EOS tokens, etc.
- Length Filtering: Remove too-long sequences to save memory
Axolotl handles most of this, but we need to prepare our data correctly:
# Convert OpenOrca to ChatML format (what Axolotl expects)
def convert_to_chatml_format(example):
"""
Convert OpenOrca format to ChatML conversation format.
ChatML format is the modern standard with a simple structure:
{
"messages": [
{"role": "system", "content": "..."},
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
Benefits of ChatML:
- Widely supported by modern chat models and open-source training frameworks
- Better handling of multi-turn conversations
- Standardized across the industry
"""
messages = []
# Add system message if present
if example.get('system_prompt') and len(example['system_prompt'].strip()) > 0:
messages.append({
'role': 'system',
'content': example['system_prompt']
})
# Add user question
messages.append({
'role': 'user',
'content': example['question']
})
# Add assistant response
messages.append({
'role': 'assistant',
'content': example['response']
})
return {'messages': messages}
# Apply conversion
sft_dataset_formatted = sft_dataset.map(
convert_to_chatml_format,
remove_columns=sft_dataset.column_names
)
print("\n=== Formatted SFT Example (ChatML) ===")
import json
print(json.dumps(sft_dataset_formatted[0], indent=2))
Output:
{
"messages": [
{
"content": "You will be given a definition of a task first...",
"role": "user"
},
{
"content": "[\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\n]",
"role": "assistant"
}
]
}
# Process DPO data - ensure it has required fields: prompt, chosen, rejected
def process_dpo_example(example):
"""
Extract prompt, chosen, and rejected from Anthropic HH-RLHF format.
The dataset uses special tokens:
- \n\nHuman: marks user messages
- \n\nAssistant: marks assistant messages
"""
# Split on assistant marker to separate prompt from response
chosen_parts = example['chosen'].split('\n\nAssistant:')
rejected_parts = example['rejected'].split('\n\nAssistant:')
return {
'prompt': chosen_parts[0], # Everything before first assistant response
'chosen': chosen_parts[1].strip() if len(chosen_parts) > 1 else '',
'rejected': rejected_parts[1].strip() if len(rejected_parts) > 1 else ''
}
dpo_dataset_formatted = dpo_dataset.map(
process_dpo_example,
remove_columns=dpo_dataset.column_names
)
# Filter out any empty examples
dpo_dataset_formatted = dpo_dataset_formatted.filter(
lambda x: len(x['chosen']) > 0 and len(x['rejected']) > 0
)
print("\n=== Formatted DPO Example ===")
print(f"Prompt: {dpo_dataset_formatted[0]['prompt'][:150]}...")
print(f"\nChosen: {dpo_dataset_formatted[0]['chosen'][:150]}...")
print(f"\nRejected: {dpo_dataset_formatted[0]['rejected'][:150]}...")
# Save datasets to disk for Axolotl to load
os.makedirs('data', exist_ok=True)
sft_dataset_formatted.to_json('data/sft_data.jsonl')
dpo_dataset_formatted.to_json('data/dpo_data.jsonl')
print("✓ Datasets saved to data/ directory")
Part 4: Supervised Fine-Tuning (SFT) Configuration
Understanding the SFT Config
Key components:
- Base Model: We’ll use meta-llama/Llama-3.2-1B (small for demo, use 8B+ for production)
- LoRA: Low-Rank Adaptation, an efficient fine-tuning method (see the parameter-count sketch after this list)
- Only trains small adapter weights (~0.1-1% of model params)
- Rank (r): Higher = more capacity but slower (8-64 typical)
- Alpha: Scaling factor, usually 2x rank
- Training Params:
- Learning rate: 2e-5 is a reasonable starting point for LoRA (LoRA typically tolerates higher rates than full fine-tuning)
- Batch size: Trade-off between memory and training speed
- Gradient accumulation: Simulate larger batches with limited memory
- Memory Optimization:
- Flash Attention 2: 2-4x faster attention
- Gradient checkpointing: Trade compute for memory
- BF16: Better than FP16 for training stability
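To make the adapter-size claim above concrete, here is a small sketch that builds the same LoRA configuration with peft and prints the trainable-parameter fraction (this downloads the 1B base model; access to the gated Llama repo is assumed):
# Sketch: how many parameters does this LoRA setup actually train?
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
lora_cfg = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base, lora_cfg)
peft_model.print_trainable_parameters()  # expect on the order of 1-2% trainable for these settings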
# SFT Configuration for Axolotl
sft_config = {
# Base model configuration
'base_model': 'meta-llama/Llama-3.2-1B', # Use 'meta-llama/Llama-3.1-8B' for production
'model_type': 'LlamaForCausalLM',
    'tokenizer_type': 'AutoTokenizer',  # Llama-3 models ship a fast tokenizer; the legacy LlamaTokenizer class does not apply
# Trust remote code (needed for some models)
'trust_remote_code': True,
# Data configuration - using modern ChatML format
'datasets': [
{
'path': 'data/sft_data.jsonl',
'type': 'chat_template', # Use native chat template
'chat_template': 'chatml',
'field_messages': 'messages', # Field containing the message list
'message_field_role': 'role', # Field for role (system/user/assistant)
'message_field_content': 'content', # Field for message content
}
],
# Data processing
'dataset_prepared_path': 'last_run_prepared',
'val_set_size': 0.05, # 5% for validation
'sequence_len': 2048, # Max sequence length
# LoRA configuration
'adapter': 'lora',
'lora_r': 32, # Rank - controls adapter capacity
'lora_alpha': 64, # Scaling factor (typically 2x rank)
'lora_dropout': 0.05, # Regularization
'lora_target_modules': [ # Which layers to adapt
'q_proj',
'k_proj',
'v_proj',
'o_proj',
'gate_proj',
'up_proj',
'down_proj',
],
# Training hyperparameters
'num_epochs': 1, # Single epoch often sufficient with good data
'micro_batch_size': 2, # Per-device batch size
'gradient_accumulation_steps': 4, # Effective batch = 2 * 4 = 8
'learning_rate': 2e-5,
'lr_scheduler': 'cosine', # Cosine annealing
'warmup_steps': 100, # LR warmup for stability
# Optimizer
'optimizer': 'adamw_torch',
'weight_decay': 0.01,
'adam_beta2': 0.95, # Good for LLMs
'adam_epsilon': 1e-8,
# Memory optimization
'flash_attention': True, # Use Flash Attention 2
'gradient_checkpointing': True, # Save memory at cost of compute
'bf16': True, # BFloat16 training
'fp16': False, # Don't use FP16 if using BF16
'tf32': True, # TensorFloat32 for A100 GPUs
# Logging and evaluation
'logging_steps': 10,
'eval_steps': 50,
'save_steps': 100,
'save_total_limit': 3, # Keep only 3 checkpoints
'output_dir': './sft_output',
# Special tokens
'special_tokens': {
'pad_token': '<|pad|>',
'eos_token': '<|end_of_text|>',
},
}
# Save config
with open('sft_config.yml', 'w') as f:
yaml.dump(sft_config, f, default_flow_style=False)
print("✓ SFT config saved to sft_config.yml")
print(f"\nKey settings:")
print(f" Model: {sft_config['base_model']}")
print(f" Data format: ChatML (modern standard)")
print(f" LoRA rank: {sft_config['lora_r']}")
print(f" Effective batch size: {sft_config['micro_batch_size'] * sft_config['gradient_accumulation_steps']}")
print(f" Learning rate: {sft_config['learning_rate']}")
Output:
✓ SFT config saved to sft_config.yml
Key settings:
Model: meta-llama/Llama-3.2-1B
Data format: ChatML (modern standard)
LoRA rank: 32
Effective batch size: 8
Learning rate: 2e-05
Part 5: Running SFT Training
Training Process
Axolotl will:
- Load and tokenize the dataset
- Initialize model with LoRA adapters
- Train with gradient accumulation and checkpointing
- Save best checkpoints
- Merge LoRA weights back into base model (optional)
Memory Requirements
For Llama-3.2-1B with LoRA:
- ~4-6 GB VRAM
For Llama-3.1-8B with LoRA:
- ~16-20 GB VRAM (single A100/A6000)
- Can use 4-bit quantization for ~10 GB VRAM
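Before launching, it is worth checking available VRAM against the figures above. A minimal sketch (it only reports device totals, not what training will actually allocate):
# Quick VRAM sanity check
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GB total")
else:
    print("No CUDA device found - training on CPU is not practical")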
# Run SFT training with Axolotl
# Note: In a real scenario, run this via command line for better logging:
# accelerate launch -m axolotl.cli.train sft_config.yml
!accelerate launch -m axolotl.cli.train sft_config.yml
Understanding SFT Training Dynamics
What to monitor:
- Training Loss: Should decrease smoothly
- If it plateaus early: increase learning rate or model capacity
- If it’s unstable: decrease learning rate or increase warmup
- Validation Loss: Should track training loss
- If it diverges: you’re overfitting, reduce epochs or add regularization
- Gradient Norm: Should be stable (0.5-2.0 range)
- If exploding: decrease LR or enable gradient clipping
- Learning Rate Schedule: Cosine decay from peak to 0
- Warmup prevents initial instability
- Decay helps convergence
Expected results after SFT:
- Model understands instruction format
- Can generate coherent, on-topic responses
- May not always prefer “better” responses (that’s what DPO fixes)
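One way to inspect the curves described above after (or during) a run is to read the trainer_state.json that standard Hugging Face Trainer checkpoints contain; a sketch assuming Axolotl writes that layout under ./sft_output:
# Sketch: print loss / LR / grad-norm history from the latest checkpoint
import json
from pathlib import Path

ckpt = max(Path("./sft_output").glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]))
state = json.loads((ckpt / "trainer_state.json").read_text())

for entry in state["log_history"]:
    if "loss" in entry:  # training steps
        print(f"step {entry['step']:>5}  loss {entry['loss']:.4f}  "
              f"lr {entry.get('learning_rate', 0):.2e}  grad_norm {entry.get('grad_norm', 'n/a')}")
    elif "eval_loss" in entry:  # evaluation steps
        print(f"step {entry['step']:>5}  eval_loss {entry['eval_loss']:.4f}")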
# Load the SFT model for testing
from transformers import pipeline
# Merge LoRA weights into base model (optional, for easier deployment)
!python -m axolotl.cli.merge_lora sft_config.yml --lora_model_dir="./sft_output"
# Load merged model
sft_model_path = "./sft_output/merged"
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
model = AutoModelForCausalLM.from_pretrained(
sft_model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Create text generation pipeline
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=128
)
# Test the SFT model - format the prompt with the chat template the model was trained on
# (assumes the saved tokenizer carries the ChatML chat template)
test_prompt = "What is the capital of France?"
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": test_prompt}],
    tokenize=False,
    add_generation_prompt=True,
)
response = pipe(chat_prompt)[0]['generated_text']
print("=== SFT Model Test ===")
print(f"Prompt: {test_prompt}")
print(f"Response: {response[len(chat_prompt):]}")
Part 6: Direct Preference Optimization (DPO) Configuration
Understanding DPO
DPO is an elegant alternative to RLHF that directly optimizes the model to prefer chosen responses over rejected ones.
Key insight: Instead of training a reward model and using RL (PPO), DPO formulates preference learning as a classification problem.
The DPO loss function:
\[L_{DPO} = -\log \sigma\left(\beta \left[\left(\log \pi_\theta(y_w|x) - \log \pi_{ref}(y_w|x)\right) - \left(\log \pi_\theta(y_l|x) - \log \pi_{ref}(y_l|x)\right)\right]\right)\]

Where:
- $\pi_\theta$: Current model being trained
- $\pi_{ref}$: Reference model (frozen SFT checkpoint)
- $y_w$: Chosen (winner) response
- $y_l$: Rejected (loser) response
- $\beta$: Temperature parameter (controls strength of optimization)
Why this works:
- Increases probability of chosen responses
- Decreases probability of rejected responses
- KL penalty (via reference model) prevents mode collapse
- No separate reward model needed!
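To make the loss concrete, here is a minimal PyTorch sketch of the objective above, operating on per-sequence summed log-probabilities (Axolotl/TRL's internal implementation handles batching, masking, and logging, but the math is the same):
# Minimal DPO loss sketch: inputs are per-sequence log-probabilities of each response given the prompt
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implied rewards: how far the policy has moved from the reference on each response
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Logistic loss on the reward margin (the bracketed term above, scaled by beta)
    loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
    return loss, chosen_rewards.detach(), rejected_rewards.detach()

# Toy usage with random log-probs for a batch of 4 preference pairs
loss, cw, rw = dpo_loss(torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4))
print(f"toy DPO loss: {loss.item():.4f}")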
DPO Configuration Differences from SFT
- Base model: Use SFT checkpoint (not base model)
- Reference model: Keep frozen copy of SFT model
- Beta: Controls optimization strength (0.1-0.5 typical)
- Learning rate: Usually lower than SFT (1e-6 to 1e-5)
- Dataset: Requires (prompt, chosen, rejected) triplets
# DPO Configuration
dpo_config = {
# Base model is now our SFT checkpoint
'base_model': './sft_output/merged',
'model_type': 'LlamaForCausalLM',
    'tokenizer_type': 'AutoTokenizer',
'trust_remote_code': True,
    # DPO-specific: enable DPO training via Axolotl's RL trainer.
    # With LoRA adapters, the frozen SFT weights (adapters disabled) act as the reference model.
    'rl': 'dpo',
# Data configuration - DPO format
'datasets': [
{
'path': 'data/dpo_data.jsonl',
'type': 'dpo', # DPO dataset type
}
],
'val_set_size': 0.05,
'sequence_len': 2048,
# LoRA configuration (same as SFT)
'adapter': 'lora',
'lora_r': 32,
'lora_alpha': 64,
'lora_dropout': 0.05,
'lora_target_modules': [
'q_proj', 'k_proj', 'v_proj', 'o_proj',
'gate_proj', 'up_proj', 'down_proj',
],
# DPO-specific hyperparameters
'dpo_beta': 0.1, # KL penalty coefficient
# Lower beta = more aggressive optimization, higher = more conservative
# Start with 0.1, increase if model degrades
# Training hyperparameters (more conservative than SFT)
'num_epochs': 1,
'micro_batch_size': 1, # DPO needs 2x memory (stores chosen + rejected)
'gradient_accumulation_steps': 8, # Compensate for smaller batch
'learning_rate': 5e-6, # Lower than SFT to avoid forgetting
'lr_scheduler': 'cosine',
'warmup_steps': 50,
# Optimizer
'optimizer': 'adamw_torch',
'weight_decay': 0.01,
'adam_beta2': 0.95,
'adam_epsilon': 1e-8,
# Memory optimization (same as SFT)
'flash_attention': True,
'gradient_checkpointing': True,
'bf16': True,
'fp16': False,
'tf32': True,
# Logging and evaluation
'logging_steps': 10,
'eval_steps': 50,
'save_steps': 100,
'save_total_limit': 3,
'output_dir': './dpo_output',
# DPO evaluation metric
'eval_metric': 'dpo_accuracy', # % of times model prefers chosen over rejected
}
# Save config
with open('dpo_config.yml', 'w') as f:
yaml.dump(dpo_config, f, default_flow_style=False)
print("✓ DPO config saved to dpo_config.yml")
print(f"\nKey settings:")
print(f" Base (SFT) model: {dpo_config['base_model']}")
print(f" DPO beta: {dpo_config['dpo_beta']}")
print(f" Learning rate: {dpo_config['learning_rate']} (lower than SFT)")
print(f" Effective batch size: {dpo_config['micro_batch_size'] * dpo_config['gradient_accumulation_steps']}")
Output:
✓ DPO config saved to dpo_config.yml
Key settings:
Base (SFT) model: ./sft_output/merged
DPO beta: 0.1
Learning rate: 5e-06 (lower than SFT)
Effective batch size: 8
Part 7: Running DPO Training
What Happens During DPO Training
- Forward pass on chosen: Compute log probabilities for preferred response
- Forward pass on rejected: Compute log probabilities for non-preferred response
- Reference model: Compute log probs from frozen SFT model (for KL penalty)
- DPO loss: Optimize to increase margin between chosen and rejected
- Backward pass: Update only the policy model (reference stays frozen)
Expected Memory Usage
DPO requires ~2x memory of SFT because:
- Must process both chosen and rejected sequences
- Must load reference model (can share weights with policy to save memory)
For Llama-3.2-1B:
- ~8-10 GB VRAM
For Llama-3.1-8B:
- ~24-32 GB VRAM (use 4-bit quant for 16GB)
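If you need 4-bit quantization to fit, the usual route in Axolotl is QLoRA. A hedged config fragment (these keys reflect common Axolotl usage; verify names against the current docs) that re-saves the YAML before launching:
# Hedged QLoRA overrides for tighter VRAM budgets (verify against current Axolotl docs)
qlora_overrides = {
    'load_in_4bit': True,   # quantize base weights to 4-bit via bitsandbytes
    'adapter': 'qlora',     # QLoRA-style adapters on top of the quantized base
}
dpo_config.update(qlora_overrides)
with open('dpo_config.yml', 'w') as f:
    yaml.dump(dpo_config, f, default_flow_style=False)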
# Run DPO training
!accelerate launch -m axolotl.cli.train dpo_config.yml
Understanding DPO Training Metrics
Key metrics to monitor:
- DPO Loss: Should decrease steadily
- Starts near 0.69 (log 2, the value when the model has no preference either way)
- Good final value: 0.2-0.4
- If it goes to 0: might be overfitting
- Reward Margin: Difference in implied rewards between chosen and rejected (see the sketch after this list)
- Should increase during training
- Measures separation in model’s preferences
- Accuracy: How often model assigns higher probability to chosen response
- Starts at ~50% (random)
- Good final value: 70-85%
- If 100%: likely overfitting
- KL Divergence: Drift from reference model
- Should be small (< 10)
- If too high: model forgetting SFT capabilities
- Increase beta if KL is too high
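The accuracy and reward-margin numbers above can be computed directly from the implied rewards returned by the earlier dpo_loss sketch (this mirrors, but is not, Axolotl's own logging):
# Sketch: preference accuracy and reward margin from implied rewards
import torch

def dpo_metrics(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
    accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
    reward_margin = (chosen_rewards - rejected_rewards).mean().item()
    return {"accuracy": accuracy, "reward_margin": reward_margin}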
Common issues:
- Reward hacking: Model exploits preference function, degrades quality
- Solution: Lower learning rate, increase beta, add more diverse data
- Catastrophic forgetting: Model loses SFT capabilities
- Solution: Lower learning rate, stronger KL penalty (higher beta)
- Slow convergence: Training takes too long
- Solution: Increase learning rate, clean data for stronger signal
# Merge DPO LoRA weights
!python -m axolotl.cli.merge_lora dpo_config.yml --lora_model_dir="./dpo_output"
print("✓ DPO model merged and ready")
Summary and Key Takeaways
What We Accomplished
- Set up Axolotl framework for efficient LLM training
- Prepared datasets for both SFT and DPO
- Fine-tuned a base model with SFT (instruction following)
- Further aligned with DPO (preference learning)
Best Practices
Data Quality:
- SFT: Use diverse, high-quality instruction-response pairs
- DPO: Ensure clear preference signal (chosen should be notably better)
- Filter out low-quality, contradictory, or harmful examples
Training:
- Start with smaller learning rates and scale up if needed
- Monitor validation metrics - don’t just optimize training loss
- Use gradient checkpointing and Flash Attention for large models
- DPO should use lower LR than SFT to avoid catastrophic forgetting
Memory Optimization:
- LoRA for efficient fine-tuning (a fraction of full fine-tuning memory, since gradients and optimizer states are only kept for the small adapter weights)
- Flash Attention 2 for 2-4x speedup and memory savings
- Gradient checkpointing trades extra compute (roughly 20-30% slower steps) for substantially lower activation memory
- BF16 for stability, better than FP16 for training
Model Evaluation:
- Use held-out validation set (5-10% of data)
- Monitor multiple metrics: loss, perplexity, task-specific metrics
- For DPO: track accuracy, reward margin, and KL divergence
- Test with qualitative examples throughout training
Production Considerations
Scaling to Larger Models:
- Llama-3.1-8B: 16-24 GB VRAM with LoRA
- Llama-3.1-70B: Use multi-GPU with DeepSpeed ZeRO-3 or 4-bit quantization
- For very large models: Consider QLoRA (4-bit quantized base + LoRA)
Deployment:
- Merge LoRA weights into base model for simpler serving
- Use vLLM or TensorRT-LLM for optimized inference (see the sketch after this list)
- Quantize to INT8/INT4 for production if latency/cost critical
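A minimal vLLM sketch for offline batch inference with the merged DPO model (assumes vLLM is installed and a CUDA GPU is available):
# Sketch: offline batch inference with vLLM on the merged DPO model
from vllm import LLM, SamplingParams

llm = LLM(model="./dpo_output/merged", dtype="bfloat16")
params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)

for out in llm.generate(["What is the capital of France?"], params):
    print(out.outputs[0].text)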
Data Pipeline:
- Invest in high-quality data curation
- Use GPT-4/Claude for synthetic data generation if needed
- Regularly audit for bias, toxicity, and factual errors
- Version control your datasets
Final Notes
This notebook provided a complete, production-ready pipeline for SFT + DPO fine-tuning. The techniques here scale from 1B to 70B+ parameter models with appropriate hardware.
Key Resources:
- Axolotl Documentation: https://github.com/OpenAccess-AI-Collective/axolotl
- DPO Paper: https://arxiv.org/abs/2305.18290
- LoRA Paper: https://arxiv.org/abs/2106.09685
- Flash Attention: https://arxiv.org/abs/2205.14135
Next Steps:
- Experiment with different hyperparameters (learning rate, LoRA rank, DPO beta)
- Try other alignment methods (PPO, REINFORCE, Constitutional AI)
- Implement evaluation harnesses (MT-Bench, AlpacaEval)
- Deploy with inference optimization (vLLM, quantization)