Python Crop Disease Detection — Deep Dive

System architecture

A production crop disease detection system has three stages: a training pipeline (offline, GPU-intensive), an inference service (online, latency-sensitive), and a feedback loop (continuous improvement from user corrections).

Training: Curated Dataset → Augmentation → Fine-tune CNN → Validate → Export ONNX
Inference: Photo Upload → Preprocess → Leaf Segmentation → Classification → Result + Confidence
Feedback: User Correction → Review Queue → Retrain Trigger

Dataset preparation

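High-quality training data is the bottleneck. Raw disease image datasets need substantial cleaning (removing duplicates, mislabeled and off-topic images), and the curated set then needs augmentation and class balancing before training: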

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from pathlib import Path
import numpy as np
from collections import Counter

def create_balanced_loader(
    dataset_root: Path,
    img_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 4,
) -> DataLoader:
    """Create a class-balanced, augmented training DataLoader.

    ``dataset_root`` is read with ``torchvision.datasets.ImageFolder`` (one
    sub-directory per class). Every image goes through a random augmentation
    pipeline, and samples are drawn with replacement with probability
    inversely proportional to their class frequency.

    Args:
        dataset_root: Root directory with one sub-folder per class.
        img_size: Side length of the square crops fed to the network.
        batch_size: Number of images per batch.
        num_workers: Worker processes used for loading. Use 0 in notebooks
            or on platforms where multiprocess data loading is problematic.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches, where images are
        normalised with ImageNet mean/std.
    """
    from torchvision.datasets import ImageFolder

    # Training-time augmentation only: build a separate loader without these
    # random transforms for validation/test data.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1
        ),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        # ImageNet statistics, matching the pretrained backbone.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    dataset = ImageFolder(str(dataset_root), transform=train_transform)

    # Weight each sample by 1 / (size of its class) so every class carries
    # the same total weight and is drawn about equally often.
    class_counts = Counter(dataset.targets)
    weights = [1.0 / class_counts[t] for t in dataset.targets]
    # replacement=True lets minority-class images be drawn several times
    # per epoch.
    sampler = WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

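Class imbalance is severe in disease datasets — healthy leaves might outnumber any single disease 10:1. Weighting each sample by the inverse of its class frequency makes every class appear about equally often per epoch. This prevents the model from defaulting to “healthy” for everything.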

Transfer learning with EfficientNet

EfficientNet offers one of the best accuracy-per-FLOP trade-offs among standard ImageNet backbones, which makes it a strong default for mobile deployment:

import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

def build_disease_classifier(
    num_classes: int,
    freeze_backbone: bool = True,
    *,
    trainable_blocks: int = 2,
):
    """Build an EfficientNet-B0 classifier for crop diseases.

    Loads ImageNet-pretrained EfficientNet-B0 weights and replaces the
    stock classifier with a small two-layer MLP head that outputs
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        freeze_backbone: If True, freeze ``model.features`` except for its
            last ``trainable_blocks`` entries. If False, every parameter is
            left trainable.
        trainable_blocks: Number of trailing ``model.features`` entries kept
            trainable when ``freeze_backbone`` is True. Pass 0 for a strictly
            head-only first training phase; the default of 2 preserves the
            original behaviour.

    Returns:
        The modified torchvision EfficientNet-B0 model.

    Raises:
        ValueError: If ``num_classes`` < 1 or ``trainable_blocks`` < 0.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if trainable_blocks < 0:
        raise ValueError(
            f"trainable_blocks must be >= 0, got {trainable_blocks}"
        )

    model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

    if freeze_backbone:
        for param in model.features.parameters():
            param.requires_grad = False
        # Guard the 0 case explicitly: features[-0:] is the *whole*
        # Sequential and would silently unfreeze the entire backbone.
        if trainable_blocks > 0:
            for param in model.features[-trainable_blocks:].parameters():
                param.requires_grad = True

    # Reuse the input width of the stock head's Linear layer (index 1).
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(256, num_classes)
    )

    return model

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: Classifier whose output has class scores along dim 1.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Loss function. It is assumed to average over the batch
            (``reduction="mean"``), because each batch loss is re-weighted by
            the batch size to obtain a per-sample mean over the epoch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean loss per sample, fraction of correct predictions)``.
        Both are computed from outputs produced *before* each optimizer step.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")
    return loss_sum / n_seen, n_correct / n_seen

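A two-phase training strategy works well. First, train only the classifier head (backbone frozen) for about 5 epochs at a relatively high learning rate (e.g. 1e-2). Then unfreeze the last few backbone blocks and continue training them together with the head at a lower learning rate (e.g. 1e-4). Note that `build_disease_classifier(freeze_backbone=True)` already leaves the last two `model.features` blocks trainable. A strictly head-only first phase therefore also requires freezing `model.features[-2:]`.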

Leaf segmentation as preprocessing

Segmenting the leaf from the background before classification can substantially improve performance on field photos, where soil, hands, and neighboring plants would otherwise distract the model:

import cv2
import numpy as np

def segment_leaf(image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Separate the leaf from the background in a BGR photo.

    A coarse mask is built by HSV colour thresholding (green leaf tissue plus
    brown/yellow necrotic tissue) and cleaned with morphological close/open.
    It is then refined with GrabCut, initialised from that mask.

    Args:
        image: BGR image, e.g. as returned by ``cv2.imread``. It is expected
            to be 8-bit, 3-channel, which ``cv2.grabCut`` requires.

    Returns:
        ``(segmented, mask)``. ``segmented`` is ``image`` with non-leaf
        pixels set to 0. ``mask`` is a uint8 array that is 255 on leaf
        pixels and 0 elsewhere.
    """
    # GrabCut initialises 5-component colour models per region with k-means,
    # which needs at least that many samples in both foreground and background.
    min_grabcut_samples = 5

    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Green-ish band (covers most healthy and diseased leaf tissue).
    # OpenCV hue runs 0-179, so 15-95 spans yellow-green through cyan-green.
    lower_green = np.array([15, 20, 20])
    upper_green = np.array([95, 255, 255])
    mask_green = cv2.inRange(hsv, lower_green, upper_green)

    # Brown/yellow band (covers necrotic tissue).
    lower_brown = np.array([5, 30, 30])
    upper_brown = np.array([25, 255, 255])
    mask_brown = cv2.inRange(hsv, lower_brown, upper_brown)

    mask = cv2.bitwise_or(mask_green, mask_brown)

    # Close fills small holes inside the leaf; open removes small specks of
    # leaf-coloured background.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    n_foreground = cv2.countNonZero(mask)
    n_background = mask.size - n_foreground
    if n_foreground < min_grabcut_samples or n_background < min_grabcut_samples:
        # GrabCut raises cv2.error when either region is (nearly) empty, for
        # example when the leaf fills the whole frame or no leaf-coloured
        # pixels were found. Fall back to the colour mask (already 0/255).
        final_mask = mask
    else:
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # Colour-mask pixels are "probably foreground"; everything else is
        # treated as definite background.
        gc_mask = np.where(mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype(np.uint8)

        cv2.grabCut(
            image, gc_mask, None, bgd_model, fgd_model, 3, cv2.GC_INIT_WITH_MASK
        )
        final_mask = np.where(
            (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0
        ).astype(np.uint8)

    segmented = cv2.bitwise_and(image, image, mask=final_mask)
    return segmented, final_mask

For higher-quality segmentation, U-Net trained on leaf boundary annotations outperforms color thresholding, especially with complex backgrounds.

Attention visualization with Grad-CAM

Grad-CAM shows which regions of the image influenced the model’s decision — critical for building trust with agronomists:

import torch
import torch.nn.functional as F

class GradCAM:
    """Gradient-weighted Class Activation Mapping for one layer of a model.

    Hooks ``target_layer`` to capture its forward activations and the
    gradient of a chosen class score with respect to them. The heatmap is
    the ReLU of the gradient-weighted sum of activation channels, upsampled
    to the input's spatial size and min-max normalised to [0, 1].

    Call :meth:`remove` when done to detach the hooks from the layer.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached. Without them, the
        # hooks live as long as the layer and stack up if GradCAM is
        # constructed repeatedly for the same layer.
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, inputs, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from the target layer (safe to call twice)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor: torch.Tensor, class_idx: int) -> np.ndarray:
        """Return a heatmap in [0, 1] explaining logit ``class_idx``.

        Args:
            input_tensor: Model input with spatial dims at positions 2 and 3.
                Only the first sample's score is back-propagated, so pass a
                batch of one.
            class_idx: Index of the output logit to explain.

        Returns:
            A CPU NumPy array of the heatmap at the input's spatial size.

        Raises:
            RuntimeError: If the hooks captured nothing (e.g. after
                :meth:`remove`, or if the target layer is not used in the
                forward pass).
        """
        # Note: this switches the model to eval mode and leaves it there.
        self.model.eval()
        # Clear captures from any previous call so stale tensors are never used.
        self.gradients = None
        self.activations = None
        output = self.model(input_tensor)

        self.model.zero_grad()
        output[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; was "
                "remove() called or is target_layer outside the forward pass?"
            )

        # Per-channel weights are the spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = F.interpolate(
            cam, size=input_tensor.shape[2:], mode="bilinear", align_corners=False
        )
        # .cpu() so this also works when the model runs on a GPU.
        cam = cam.squeeze().cpu().numpy()
        return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

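When the heatmap highlights the lesion area, agronomists are more likely to trust the prediction. When it highlights the background or other regions irrelevant to the disease, the model may have learned spurious correlations. An example is associating a disease with the kind of background typical of its training photos.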

Mobile deployment with ONNX

Converting to ONNX and quantizing enables inference on smartphones:

import torch.onnx

def export_to_onnx(model, num_classes: int, output_path: str, img_size: int = 224):
    """Export a trained model to ONNX format for mobile deployment.

    Args:
        model: Trained ``nn.Module``. It is switched to eval mode and left
            there.
        num_classes: Not used by the export; kept for API compatibility.
        output_path: Destination path of the ``.onnx`` file.
        img_size: Height and width of the square dummy input used for
            tracing. It should match the resolution the model is served at.
            Only the batch axis is exported as dynamic; spatial axes are
            fixed to this size.
    """
    model.eval()
    # Build the dummy input on the model's own device: a CPU tensor fed to a
    # model that still lives on the GPU fails with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, img_size, img_size, device=device)

    torch.onnx.export(
        model, dummy_input, output_path,
        input_names=["image"],
        output_names=["prediction"],
        dynamic_axes={"image": {0: "batch"}, "prediction": {0: "batch"}},
        opset_version=13
    )

def quantize_onnx(input_path: str, output_path: str):
    """Write a dynamically quantized copy of an ONNX model.

    Reads the model at ``input_path`` and saves, at ``output_path``, a
    version whose weights are stored as unsigned 8-bit integers
    (``QuantType.QUInt8``), produced by ONNX Runtime's ``quantize_dynamic``.
    The input file is not modified.
    """
    # Imported lazily so onnxruntime is only required when quantizing.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    weight_dtype = QuantType.QUInt8
    quantize_dynamic(input_path, output_path, weight_type=weight_dtype)

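Quantizing EfficientNet-B0's ~5.3M weights from 32-bit floats to 8-bit integers shrinks the model from roughly 21 MB to roughly 5–6 MB. On mid-range Android phones, inference typically takes a few hundred milliseconds or less (often under 200 ms, depending on the device and runtime). For iOS, converting to Core ML with coremltools lets the model run on Apple's CPU, GPU (via Metal), and Neural Engine.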

Multi-disease and severity estimation

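Production systems need to handle co-occurring diseases and severity grading. The sketch below treats each output as an independent sigmoid score. That only makes sense for a model trained as a multi-label classifier (e.g. with `BCEWithLogitsLoss` on multi-hot labels), not for the single-label, cross-entropy model trained above. Its severity grade is also derived from model confidence, which is only a placeholder: real severity grading needs its own labels, such as the percentage of diseased leaf area.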

# Output order of the reference multi-label model; override per model via
# the ``disease_names`` argument of ``multi_label_prediction``.
DEFAULT_DISEASE_NAMES: tuple[str, ...] = (
    "early_blight", "late_blight", "leaf_mold",
    "septoria_spot", "spider_mites", "bacterial_spot",
    "mosaic_virus", "yellow_curl", "healthy",
)


def multi_label_prediction(
    model,
    image_tensor: torch.Tensor,
    threshold: float = 0.5,
    disease_names: tuple[str, ...] = DEFAULT_DISEASE_NAMES,
) -> list[dict]:
    """Score each condition independently and report those above a threshold.

    A sigmoid is applied to every logit, treating each output as an
    independent detector. This is only meaningful for a model trained as a
    multi-label classifier (e.g. with ``BCEWithLogitsLoss`` on multi-hot
    targets).

    Args:
        model: Module mapping ``image_tensor`` to one logit per entry of
            ``disease_names``.
        image_tensor: Input for a single image (a batch of one).
        threshold: Probability strictly above which a condition is reported.
        disease_names: Label of each model output, in output order. Any
            sequence of strings works.

    Returns:
        One dict per reported condition with keys ``"disease"``,
        ``"confidence"`` (probability rounded to 3 places) and
        ``"severity"``. Severity comes from ``classify_severity`` for
        diseases and is ``None`` for ``"healthy"``. If nothing clears the
        threshold, a single ``"healthy"`` entry is returned whose confidence
        is ``1 - max(disease probabilities)``.

    Raises:
        ValueError: If the model does not produce exactly one output per
            name, e.g. for a batch larger than one or a mismatched label list.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor)
    probs = torch.sigmoid(logits).reshape(-1).tolist()

    # zip() would silently truncate on a mismatch, mislabelling outputs.
    if len(probs) != len(disease_names):
        raise ValueError(
            f"model produced {len(probs)} outputs but {len(disease_names)} "
            "disease names were given; pass a single image and matching names"
        )

    results = []
    for name, prob in zip(disease_names, probs):
        if prob > threshold:
            results.append({
                "disease": name,
                "confidence": round(prob, 3),
                # A severity grade makes no sense for healthy tissue.
                "severity": None if name == "healthy" else classify_severity(prob),
            })
    if results:
        return results

    # Nothing cleared the threshold. Still report "healthy", but derive the
    # confidence from the model instead of a fixed constant: the stronger
    # the best disease score, the less certain "healthy" is.
    disease_probs = [p for n, p in zip(disease_names, probs) if n != "healthy"]
    healthy_confidence = 1.0 - max(disease_probs, default=0.0)
    return [{
        "disease": "healthy",
        "confidence": round(healthy_confidence, 3),
        "severity": None,
    }]

def classify_severity(confidence: float) -> str:
    """Bucket a model probability into a coarse severity label.

    Values strictly above 0.85 map to "severe", values strictly above 0.6
    map to "moderate", and everything else (including NaN) maps to "mild".
    Note that this grades the model's confidence, not the visible extent of
    the disease.
    """
    cutoffs = ((0.85, "severe"), (0.6, "moderate"))
    for lower_bound, label in cutoffs:
        if confidence > lower_bound:
            return label
    return "mild"

Tradeoffs and limitations

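Lab vs. field accuracy: Expect large accuracy drops — often on the order of 15–25 percentage points — when moving from benchmark images (uniform backgrounds, controlled lighting) to field photos. Budget for field-specific data collection and continuous model updates.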

Generalization across regions: A model trained on European wheat diseases may miss Asian wheat rust variants. Regional fine-tuning with local disease photos is necessary.

Speed vs. accuracy: EfficientNet-B0 (5.3M parameters) runs on phones; EfficientNet-B7 (66M parameters) is more accurate but requires server-side inference. The choice depends on connectivity — offline-first apps need smaller models.

Feedback loops: The most impactful long-term investment is building a correction mechanism where agronomists flag wrong predictions, creating a growing labeled dataset for retraining.

One thing to remember: Production crop disease detection requires not just model training but a complete pipeline — leaf segmentation, balanced training, Grad-CAM verification, mobile export, and continuous improvement from field feedback — to bridge the gap between benchmark scores and real-farm utility.

Tags: python, agriculture, computer-vision, machine-learning, deep-learning

See Also