LoRA Fine-Tuning in Python — Deep Dive
LoRA fine-tuning at production scale requires understanding rank selection theory, training dynamics, quantized variants, and deployment strategies. This guide covers the full lifecycle from dataset preparation through training to serving merged and unmerged LoRA models.
Mathematical foundation
Given a pretrained weight matrix W₀ ∈ ℝ^(d×k), LoRA constrains the weight update ΔW to a low-rank decomposition:
ΔW = B·A where B ∈ ℝ^(d×r), A ∈ ℝ^(r×k)
The forward pass becomes: h = W₀·x + (α/r)·B·A·x
Matrix A is initialized with Gaussian random values; B is initialized to zero. This means the LoRA output starts at zero and grows during training, preventing the sudden disruption that random initialization would cause.
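The scaling factor α/r normalizes the LoRA contribution. With α held fixed, doubling the rank halves the scaling factor, which keeps the overall magnitude of the update roughly stable. As a result, you rarely need to retune the learning rate when you change r, which is why the original LoRA paper sets α once and leaves it alone. The popular α = 2r convention (used in the examples below) works differently: it pins the scale at α/r = 2 for every rank. It is a convenient, well-tested default, but it trades away that rank normalization in exchange for a constant multiplier.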
Rank selection strategy
Rank determines the expressiveness of the adaptation:
# Rank analysis utility
def estimate_rank_requirements(dataset_size: int, task_complexity: str) -> int:
    """Suggest a starting LoRA rank from the task category and dataset size.

    Unrecognized task categories fall back to rank 16. Small datasets cap
    the suggestion to limit overfitting:
    - fewer than 100 examples: at most 8;
    - fewer than 1000 examples: at most 32.
    """
    suggested = {
        "style_transfer": 4,        # Visual style changes
        "subject_learning": 8,      # DreamBooth-style subject
        "domain_adaptation": 16,    # New domain knowledge
        "behavior_change": 32,      # Significant behavioral shift
        "instruction_tuning": 64,   # Complex instruction following
    }.get(task_complexity, 16)

    # Guard clauses: tighter ceilings for smaller datasets.
    if dataset_size < 100:
        return min(suggested, 8)
    if dataset_size < 1000:
        return min(suggested, 32)
    return suggested
Empirically, ranks above 64 rarely improve results and often cause overfitting. The sweet spot for most image generation tasks is 8–16; for language models, 16–32.
Training Stable Diffusion LoRA with diffusers
Dataset preparation
from datasets import Dataset
from PIL import Image
import os
def prepare_dataset(image_dir: str, prompt: str, extensions=(".png", ".jpg", ".jpeg")):
    """Build a captioned image dataset from the image files in ``image_dir``.

    Every regular file whose lower-cased name ends with one of ``extensions``
    is decoded, converted to RGB, and paired with the same caption
    ``prompt``. Files are visited in sorted name order, so the row order is
    reproducible. The scan is not recursive.

    Args:
        image_dir: Directory to scan.
        prompt: Caption attached to every image (e.g. a DreamBooth instance
            prompt such as "a photo of sks dog").
        extensions: Accepted file suffixes, compared case-insensitively.

    Returns:
        A ``datasets.Dataset`` with an ``"image"`` column (RGB PIL image) and
        a ``"text"`` column (the caption).

    Raises:
        ValueError: If no matching image files are found.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    rows = []
    for name in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, name)
        # Skip non-matching names and sub-directories that happen to be named "*.png".
        if not (name.lower().endswith(suffixes) and os.path.isfile(path)):
            continue
        # `with` closes the file handle deterministically.
        # convert() returns a fully decoded, independent copy, so it stays valid after close.
        with Image.open(path) as img:
            rows.append({"image": img.convert("RGB"), "text": prompt})
    if not rows:
        # An empty dataset would otherwise make training silently do nothing.
        raise ValueError(f"no images with extensions {suffixes} found in {image_dir!r}")
    return Dataset.from_list(rows)
Training script
import torch
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from peft import LoraConfig
from transformers import CLIPTextModel, CLIPTokenizer
from torch.utils.data import DataLoader
def _sd_pixel_values(image, resolution: int):
    """Center-crop a PIL image to a square, resize it, and scale it to [-1, 1].

    Returns a float32 tensor of shape (3, resolution, resolution).
    """
    image = image.convert("RGB")
    width, height = image.size
    side = min(width, height)
    left, top = (width - side) // 2, (height - side) // 2
    image = image.crop((left, top, left + side, top + side)).resize((resolution, resolution))
    pixels = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
    pixels = pixels.view(resolution, resolution, 3).permute(2, 0, 1).float()
    # Stable Diffusion VAE training inputs are normalized from [0, 255] to [-1, 1].
    return pixels / 127.5 - 1.0


def _collate_sd_batch(examples: list, resolution: int) -> dict:
    """Collate dataset rows into ``{"pixel_values": (B, 3, H, W) float tensor, "text": [str]}``.

    Rows that already carry ``"pixel_values"`` are used as-is. Rows with a
    PIL ``"image"`` are cropped, resized and normalized by ``_sd_pixel_values``.
    """
    pixel_values = [
        torch.as_tensor(row["pixel_values"]) if "pixel_values" in row
        else _sd_pixel_values(row["image"], resolution)
        for row in examples
    ]
    return {
        "pixel_values": torch.stack(pixel_values).float(),
        "text": [row["text"] for row in examples],
    }


def train_sd_lora(
    model_id: str,
    dataset,
    output_dir: str,
    rank: int = 8,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 1,
    resolution: int = 512,
):
    """Train a LoRA adapter on the U-Net of a Stable Diffusion checkpoint.

    The VAE, text encoder and base U-Net weights are frozen. Only LoRA
    matrices are trained; they use rank ``rank`` and alpha ``2 * rank`` and
    are injected into the attention projections ``to_q``, ``to_k``, ``to_v``
    and ``to_out.0``. Training runs on CUDA.

    The loss target follows the scheduler's ``prediction_type``:
    - ``epsilon``: the target is the noise;
    - ``v_prediction``: the target is the velocity.

    The adapter is saved to ``output_dir`` in diffusers LoRA format.

    Args:
        model_id: Hub id or local path of a Stable Diffusion checkpoint.
        dataset: Indexable dataset. Each row has a ``"text"`` caption plus
            either a PIL ``"image"`` or a ready ``"pixel_values"`` tensor
            (3 x H x W, scaled to [-1, 1]).
        output_dir: Directory the LoRA weights are written to.
        rank: LoRA rank r.
        epochs: Number of passes over ``dataset``.
        lr: AdamW learning rate.
        batch_size: Images per optimizer step.
        resolution: Side length that raw ``"image"`` rows are center-cropped
            and resized to.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``epsilon`` nor ``v_prediction``.
    """
    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict

    # Load components
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"unsupported prediction_type: {prediction_type!r}")

    # Freeze base model
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # Add LoRA to U-Net; the injected adapter parameters are the only trainable ones.
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank * 2,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
    )
    unet.add_adapter(lora_config)

    # Trainable LoRA weights stay in fp32; the frozen encoders run in fp16.
    unet.to("cuda", dtype=torch.float32)
    vae.to("cuda", dtype=torch.float16)
    text_encoder.to("cuda", dtype=torch.float16)

    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=1e-2)

    # The default collate cannot batch PIL images, so build pixel tensors explicitly.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda rows: _collate_sd_batch(rows, resolution),
    )

    unet.train()
    for _epoch in range(epochs):
        for batch in loader:
            with torch.no_grad():
                # Encode images to latent space
                latents = vae.encode(
                    batch["pixel_values"].to("cuda", dtype=torch.float16)
                ).latent_dist.sample() * vae.config.scaling_factor
                # Pad/truncate to the encoder's fixed context length, matching
                # how the inference pipeline tokenizes prompts. Without
                # truncation, prompts over the limit would crash.
                tokens = tokenizer(
                    batch["text"],
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                encoder_hidden = text_encoder(tokens.input_ids.to("cuda"))[0]

            # Sample noise and timesteps
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device="cuda",
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # The regression target must match what the checkpoint was trained to predict.
            if prediction_type == "epsilon":
                target = noise
            else:
                target = noise_scheduler.get_velocity(latents, noise, timesteps)

            model_pred = unet(
                noisy_latents.float(), timesteps, encoder_hidden.float()
            ).sample
            loss = torch.nn.functional.mse_loss(model_pred, target.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Save LoRA weights only, in the diffusers LoRA file format.
    unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=unet_lora_layers,
        safe_serialization=True,
    )
QLoRA: quantized base + LoRA training
QLoRA combines 4-bit quantization of the base model with LoRA training, enabling fine-tuning of 70B parameter models on a single 48GB GPU:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Quantization settings for the frozen base weights:
# - weights are stored as 4-bit NF4;
# - matmuls are computed in bfloat16;
# - double quant also quantizes the quantization constants.
# (`torch` comes from the import in the training script above.)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# device_map="auto" dispatches layers across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
# Readies the quantized model for training. Per the PEFT docs, this upcasts
# non-quantized layers such as norms and, by default, enables gradient
# checkpointing.
model = prepare_model_for_kbit_training(model)
# r=64, lora_alpha=128 -> scale alpha/r = 2.
# Adapters go on the attention (q/k/v/o) and MLP (gate/up/down) projections.
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",  # bias terms are not trained
    task_type="CAUSAL_LM",
)
# Wrap the model with the LoRA adapter; the 4-bit base weights remain frozen.
model = get_peft_model(model, lora_config)
QLoRA uses NF4 (Normal Float 4-bit) quantization, which distributes quantization levels according to the normal distribution of neural network weights, preserving more information than uniform 4-bit quantization.
DreamBooth + LoRA
DreamBooth learns a specific subject from 3–10 images. Combined with LoRA, it becomes practical on consumer hardware:
# Using the diffusers SDXL DreamBooth LoRA training script
# (SDXL checkpoints need train_dreambooth_lora_sdxl.py; train_dreambooth_lora.py targets SD 1.x/2.x)
# accelerate launch train_dreambooth_lora_sdxl.py \
# --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
# --instance_data_dir="./my_dog_photos" \
# --instance_prompt="a photo of sks dog" \
# --output_dir="./dreambooth_lora" \
# --resolution=1024 \
# --train_batch_size=1 \
# --gradient_accumulation_steps=4 \
# --learning_rate=1e-4 \
# --lr_scheduler="constant" \
# --max_train_steps=500 \
# --rank=8
# Key DreamBooth-specific settings:
# - "sks" is a rare token used as the subject identifier
# - Prior preservation loss helps prevent catastrophic forgetting of the broader class ("dog").
#   It is NOT enabled by the command above. To turn it on, add --with_prior_preservation,
#   --class_data_dir=..., --class_prompt="a photo of dog" and --prior_loss_weight=1.0
# - 500 steps is often sufficient for a single subject
Merging strategies
Permanent merge
Merge LoRA weights into the base model for zero-overhead inference:
from peft import PeftModel

# "base_model_id" and "lora_adapter_path" are placeholders for the base
# checkpoint the adapter was trained on and the saved adapter directory.
# NOTE(review): no torch_dtype is passed, so the default dtype is used.
# For a QLoRA-trained adapter, load the base un-quantized (e.g. fp16/bf16)
# here rather than in 4-bit; confirm against your PEFT version.
base_model = AutoModelForCausalLM.from_pretrained("base_model_id")
# Attach the trained adapter on top of the base weights.
peft_model = PeftModel.from_pretrained(base_model, "lora_adapter_path")
# Fold W0 + (alpha/r)*B*A into the plain weights and remove the LoRA modules,
# so inference runs without the extra adapter matmuls.
merged = peft_model.merge_and_unload()
merged.save_pretrained("merged_model")
Weighted merge of multiple LoRAs
from peft import PeftModel, set_peft_model_state_dict
import copy
def merge_loras(base_model, lora_paths: list, weights: list):
    """Merge several LoRA adapters into ``base_model`` as a weighted sum of their updates.

    The result carries W0 + sum_i weights[i] * (alpha_i / r_i) * B_i @ A_i, baked
    into plain weights. A weight of 1.0 applies an adapter exactly as trained.
    Adapters may have different ranks.

    ``base_model`` is modified in place: the LoRA layers are injected into it
    and then merged.

    Args:
        base_model: The (un-wrapped) transformers model the adapters were trained on.
        lora_paths: Adapter directories or Hub ids, at least one.
        weights: One multiplier per adapter, in the same order as ``lora_paths``.

    Returns:
        The merged model with all LoRA modules removed.

    Raises:
        ValueError: If ``lora_paths`` is empty or its length differs from ``weights``.
    """
    if not lora_paths:
        raise ValueError("lora_paths must contain at least one adapter")
    if len(lora_paths) != len(weights):
        raise ValueError(
            f"got {len(lora_paths)} adapters but {len(weights)} weights"
        )

    names = [f"lora_{i}" for i in range(len(lora_paths))]
    # Load every adapter side by side on one wrapper instead of deep-copying
    # the full base model once per adapter.
    model = PeftModel.from_pretrained(base_model, lora_paths[0], adapter_name=names[0])
    for name, path in zip(names[1:], lora_paths[1:]):
        model.load_adapter(path, adapter_name=name)

    # combination_type="cat" concatenates the adapters' A/B factors, so the
    # merged update equals the weighted sum of the individual B_i @ A_i.
    # Summing the A and B factors separately (a naive state-dict merge)
    # instead yields cross terms B_i @ A_j and squared weights.
    model.add_weighted_adapter(
        adapters=names,
        weights=[float(w) for w in weights],
        adapter_name="merged",
        combination_type="cat",
    )
    model.set_adapter("merged")
    return model.merge_and_unload()
Training diagnostics
Learning rate and loss monitoring
from torch.utils.tensorboard import SummaryWriter
# One shared TensorBoard writer for the whole training run.
writer = SummaryWriter("runs/lora_training")


def log_training_step(step, loss, lr, grad_norm, peft_model=None, weight_log_every: int = 1):
    """Log one training step's metrics to TensorBoard through the module-level ``writer``.

    Args:
        step: Global step used as the TensorBoard x-axis.
        loss: Training loss, logged as ``train/loss``.
        lr: Current learning rate, logged as ``train/lr``.
        grad_norm: Gradient norm, logged as ``train/grad_norm``.
        peft_model: Model whose trainable LoRA parameter norms are logged as
            ``weights/<parameter name>``. When omitted, the module-level
            ``model`` global is used.
        weight_log_every: Log the per-parameter norms only on steps that are
            a multiple of this value. Each norm's ``.item()`` forces a
            device-to-host sync. Values <= 0 disable weight logging. The
            default of 1 logs every step.
    """
    writer.add_scalar("train/loss", loss, step)
    writer.add_scalar("train/lr", lr, step)
    writer.add_scalar("train/grad_norm", grad_norm, step)

    if weight_log_every <= 0 or step % weight_log_every:
        return
    # Fall back to the global `model` so existing four-argument calls keep working.
    target = model if peft_model is None else peft_model
    # LoRA-specific: track adapter weight norms
    for name, param in target.named_parameters():
        if "lora" in name and param.requires_grad:
            # detach(): computing a norm for logging should not record autograd history.
            writer.add_scalar(f"weights/{name}", param.detach().norm().item(), step)
Overfitting detection
LoRA is prone to overfitting on small datasets. Signs include:
- Training loss drops below 0.01 rapidly
- Generated outputs look like exact copies of training images
- Validation loss diverges from training loss
Mitigations: reduce rank, increase dropout, use prior preservation loss (DreamBooth), or augment training data.
Deployment patterns
Adapter hot-swapping in production
class LoRAServer:
    """Serve one causal-LM base model with hot-swappable LoRA adapters.

    Every sub-directory of ``adapter_dir`` is registered as an adapter named
    after that directory. An adapter's weights are read from disk the first
    time it is requested and then kept resident. Later requests only switch
    the active adapter, and requests without an adapter run with all
    adapters disabled.

    NOTE(review): ``generate`` mutates shared model state (the active adapter
    and per-layer scaling), so concurrent calls on one instance are not safe.
    """

    def __init__(self, base_model_id: str, adapter_dir: str):
        from transformers import AutoTokenizer

        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        # Adapter name -> directory path; weights are not loaded here.
        self.adapters = {}
        # PeftModel wrapping base_model, created lazily on the first adapter request.
        self._peft_model = None
        self._load_adapters(adapter_dir)

    def _load_adapters(self, adapter_dir: str):
        """Register each sub-directory of ``adapter_dir`` as an adapter path."""
        for name in sorted(os.listdir(adapter_dir)):
            path = os.path.join(adapter_dir, name)
            if os.path.isdir(path):
                self.adapters[name] = path

    def tokenize(self, prompt: str):
        """Tokenize ``prompt`` into PyTorch tensors on the base model's device."""
        return self.tokenizer(prompt, return_tensors="pt").to(self.base_model.device)

    def _activate(self, adapter: str):
        """Make ``adapter`` the active adapter, loading its weights on first use."""
        path = self.adapters[adapter]
        if self._peft_model is None:
            self._peft_model = PeftModel.from_pretrained(
                self.base_model, path, adapter_name=adapter
            )
        elif adapter not in self._peft_model.peft_config:
            self._peft_model.load_adapter(path, adapter_name=adapter)
        self._peft_model.set_adapter(adapter)
        return self._peft_model

    def generate(self, prompt: str, adapter: str = None, weight: float = 1.0):
        """Generate up to 256 new tokens for ``prompt``.

        Args:
            prompt: Input text.
            adapter: Name of a registered adapter. ``None`` or an unknown
                name serves the plain base model.
            weight: Multiplier on the adapter's trained alpha/r scale
                (1.0 = as trained, 0.0 = no adapter effect). It applies only
                to this call and is ignored when no adapter is used.

        Returns:
            The token ids returned by ``generate``.
        """
        inputs = self.tokenize(prompt)

        if not adapter or adapter not in self.adapters:
            if self._peft_model is None:
                return self.base_model.generate(**inputs, max_new_tokens=256)
            # LoRA layers were injected into base_model in place, so they
            # must be switched off explicitly to get base-model output.
            with self._peft_model.disable_adapter():
                return self._peft_model.generate(**inputs, max_new_tokens=256)

        model = self._activate(adapter)
        # Scale adapter influence relative to its trained scale, not by
        # overwriting it. Restore the trained values afterwards so the weight
        # does not leak into later requests.
        lora_layers = [
            module for module in model.modules()
            if isinstance(getattr(module, "scaling", None), dict)
            and adapter in module.scaling
        ]
        trained_scales = [layer.scaling[adapter] for layer in lora_layers]
        try:
            for layer, scale in zip(lora_layers, trained_scales):
                layer.scaling[adapter] = scale * weight
            return model.generate(**inputs, max_new_tokens=256)
        finally:
            for layer, scale in zip(lora_layers, trained_scales):
                layer.scaling[adapter] = scale
One thing to remember: Effective LoRA training combines proper rank sizing (typically 8–16 for images and 16–32 for text, rarely above 64), careful learning rate selection (roughly 5e-5 to 1e-4), and overfitting awareness. The real production value comes from hot-swappable adapters that let one base model serve dozens of specialized use cases.
See Also
- Diffusion Models Stable Diffusion and DALL-E don't 'draw' your images — they gradually denoise a scrambled mess until a picture emerges. Here's the surprisingly simple idea behind it.
- Python Controlnet Image Control Find out how ControlNet lets you boss around an AI artist by giving it sketches, poses, and outlines to follow.
- Python Gan Training Patterns Learn how two neural networks compete like an art forger and a detective to create incredibly realistic fake images.
- Python Image Generation Pipelines Discover how Python chains together multiple steps to turn your ideas into polished AI-generated images, like a factory assembly line for pictures.
- Python Image Inpainting Learn how Python can magically fill in missing parts of a photo, like erasing something and having the picture fix itself.