Convolutional neural networks (CNNs) are the standard for computer vision. This tutorial implements the same CIFAR-10 image classification CNN in both frameworks to understand their different approaches.
CIFAR-10 contains 60,000 color images of size 32×32 in 10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck. A classic benchmark for validating CNN architectures.
Definition: A neural network architecture specialized for images. It uses convolutional layers to extract spatial features hierarchically: simple edges → textures → objects → abstract concepts.
Purpose: Process images efficiently by exploiting their local 2D spatial structure, unlike classical networks that treat pixels independently.
Why here: CNNs are state-of-the-art for computer vision. A traditional MLP on raw pixels ignores 2D spatial structure: it must learn the same pattern detector (e.g., a 3x3 edge filter) independently at every position, which is inefficient and generalizes poorly.
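The efficiency claim can be checked with back-of-the-envelope arithmetic. The sketch below compares a fully connected layer on a flattened 32×32×3 image against the first Conv2D(32, (3, 3)) layer used later in this tutorial (the conv figure matches what `model.summary()` reports for that layer):

```python
# Parameter count: fully connected layer vs. convolution on a 32x32x3 image.
# Illustrative arithmetic only.
dense_params = (32 * 32 * 3) * 32 + 32   # one weight per pixel per unit, plus 32 biases
conv_params = (3 * 3 * 3) * 32 + 32      # 32 shared 3x3x3 filters, plus 32 biases
print(dense_params)  # 98336
print(conv_params)   # 896
```

Roughly a 100x reduction, and the conv filters detect their pattern at every position instead of one.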
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
print(f"TensorFlow {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")
# ── Loading and preprocessing data ──
(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
# Normalize: [0, 255] → [0, 1]
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
# One-hot encoding of labels
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
print(f"Train: {X_train.shape} | Test: {X_test.shape}")
# ── Data Augmentation ──
data_augmentation = keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomTranslation(0.1, 0.1),
])
# ── CNN Architecture ──
def build_model():
    inputs = keras.Input(shape=(32, 32, 3))
    # Augmentation layers are active only during training
    x = data_augmentation(inputs)

    # Block 1: extract basic features (edges, textures, corners)
    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.25)(x)

    # Block 2: more complex features
    x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.25)(x)

    # Block 3: abstract features
    x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling2D()(x)  # Far fewer parameters than Flatten + Dense

    # Classification head
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation='softmax')(x)
    return keras.Model(inputs, outputs)

model = build_model()
model.summary()
model = build_model()
model.summary()
# ── Compilation ──
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
Definition: A convolution applies a filter (small 3×3 or 5×5 kernel) across an image by sliding it position by position. At each position, the sum of element-wise products is calculated (discrete convolution). A feature map is the collection of all these sums: it represents the "detection" of a pattern everywhere in the image.
Purpose: Extract local patterns (horizontal/vertical edges, corners, textures) from raw pixels. The 32 filters in parallel create 32 feature maps, each detecting a different pattern.
Why here: Convolution exploits the 2D spatial structure of images. Instead of one neuron per pixel (millions of parameters), we use a few shared filters (hundreds of parameters) that detect patterns everywhere; this is efficient and generalizes well.
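To make the sliding-window mechanics concrete, here is a minimal NumPy sketch of a single-channel convolution (illustrative only; Keras and PyTorch use heavily optimized kernels, and deep-learning "convolution" is technically cross-correlation since the kernel is not flipped):

```python
import numpy as np

def conv2d_single(image, kernel):
    """Valid cross-correlation of one 2D image with one 2D kernel (no padding, stride 1)."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            # Element-wise product of the window and the kernel, then sum
            out[i, j] = np.sum(image[i:i+kh, j:j+kw] * kernel)
    return out

# A vertical-edge detector on a tiny image: left half dark, right half bright.
img = np.zeros((5, 5)); img[:, 3:] = 1.0
kernel = np.array([[-1.0, 0.0, 1.0]] * 3)
fmap = conv2d_single(img, kernel)
print(fmap)  # strong response in the columns where the edge sits, zero elsewhere
```

Each of the 32 filters in the first Conv2D layer produces one such feature map, so the layer's output has 32 channels.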
Definition: Technique that normalizes activations at each layer to mean zero and unit variance, independently for each channel/feature map. During training, normalization uses the current batch statistics; during inference, exponentially smoothed statistics (running mean/variance) are used.
Purpose: Accelerate training, allow higher learning rates, and improve generalization.
Why here: Without BatchNormalization, training deep CNNs (>10 layers) is unstable: activations become either too large or too small. With BatchNormalization, training often converges 2-3x faster and is much more stable.
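The training-time normalization is just per-channel standardization over the batch, plus a learned scale (gamma) and shift (beta). A minimal NumPy sketch (illustrative; it omits the running statistics used at inference):

```python
import numpy as np

def batchnorm_train(x, gamma, beta, eps=1e-5):
    """Batch normalization of a (batch, channels) activation matrix, training mode."""
    mean = x.mean(axis=0)                    # per-channel batch mean
    var = x.var(axis=0)                      # per-channel batch variance
    x_hat = (x - mean) / np.sqrt(var + eps)  # zero mean, unit variance per channel
    return gamma * x_hat + beta              # learned scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 32))  # 64 samples, 32 channels
y = batchnorm_train(x, gamma=np.ones(32), beta=np.zeros(32))
print(y.mean(axis=0).round(6))  # ~0 per channel
print(y.std(axis=0).round(3))   # ~1 per channel
```

For conv feature maps, the statistics are computed over the batch and both spatial dimensions, one pair per channel.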
Definition: Operation that reduces spatial dimension by downsampling by a factor (usually 2). MaxPooling(2,2) divides the image into 2ร2 windows and takes the maximum of each window.
Purpose: Reduce spatial complexity and increase the receptive field of subsequent layers.
Why here: Pooling reduces computational cost and memory, and forces the network to learn features invariant to small translations.
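The operation itself is tiny. A NumPy sketch of 2×2 max pooling with stride 2 (illustrative; assumes even height and width, as holds for the 32×32 and 16×16 maps in this model):

```python
import numpy as np

def maxpool2x2(x):
    """2x2 max pooling with stride 2 on a (H, W) map; H and W assumed even."""
    h, w = x.shape
    # Split into 2x2 windows, then take the max of each window
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

x = np.arange(16.0).reshape(4, 4)
print(maxpool2x2(x))
# [[ 5.  7.]
#  [13. 15.]]
```

Note that shifting the input by one pixel often leaves the window maxima unchanged, which is where the small-translation invariance comes from.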
Definition: Regularization technique that randomly "turns off" a fraction of neurons during training (e.g., Dropout(0.5) zeroes 50% of activations at each forward pass). Keras and PyTorch implement "inverted" dropout: surviving activations are scaled up by 1/(1 - rate) during training, so at inference all neurons are simply active with no rescaling. (The original formulation instead rescaled weights by (1 - rate) at inference; the two are equivalent in expectation.)
Purpose: Prevent co-adaptation of neurons: no neuron can rely on its neighbors always being active. This forces the network to learn redundant and robust features.
Why here: Dropout is a simple but very effective regularizer that reduces overfitting without increasing model complexity. It's particularly useful for fully connected (Dense) layers.
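A minimal NumPy sketch of inverted dropout (illustrative, not the framework implementation), showing that the expected activation is preserved during training:

```python
import numpy as np

def dropout_train(x, rate, rng):
    """Inverted dropout, training mode: zero units with probability `rate`,
    scale survivors by 1/(1 - rate) so the expected activation is unchanged."""
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

rng = np.random.default_rng(42)
x = np.ones((10000,))
y = dropout_train(x, rate=0.5, rng=rng)
print((y == 0).mean())  # ~0.5: about half the units are dropped
print(y.mean())         # ~1.0: expected activation preserved
```

Because the expectation is preserved at train time, inference can use the network as-is with dropout simply disabled.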
# ── Callbacks ──
callbacks = [
    # Stop if validation loss shows no improvement for 10 epochs
    keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    # Halve the learning rate when validation loss plateaus
    keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6),
    # Save the best model seen so far
    keras.callbacks.ModelCheckpoint('best_model.keras', save_best_only=True, monitor='val_accuracy'),
]
# ── Training ──
# batch_size=64: process 64 images per gradient step (GPU parallelism)
# validation_split=0.2: hold out 20% of the training data for validation
history = model.fit(
    X_train, y_train,
    epochs=50,
    batch_size=64,
    validation_split=0.2,
    callbacks=callbacks,
    verbose=1
)
# ── Evaluation ──
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f"\nTest accuracy: {test_acc*100:.2f}%")
Definition: A function executed automatically at specific points during training (end of epoch, end of batch, etc.). Callbacks let you monitor, modify, or stop training without changing the main loop.
Purpose: Automate monitoring and control of training.
Why here: Callbacks like EarlyStopping and ModelCheckpoint prevent overfitting and save the best model automatically; without them, you would have to monitor metrics manually and decide when to stop.
Definition: An epoch is one complete pass over the entire training dataset. A batch is a small subset of data processed in one optimization iteration (forward + backward). With 50,000 training images and batch_size=64, there are ~781 iterations (batches) per epoch.
Purpose: Divide data to enable stochastic optimization (SGD) and GPU parallelization.
Why here: Understanding epochs and batches is crucial for interpreting training. Each batch produces one gradient estimate and one weight update; after one pass over all batches, one epoch is complete. (Here, validation_split=0.2 leaves 40,000 training images, i.e., 625 batches per epoch.) Increasing batch_size reduces gradient noise but uses more memory; decreasing it adds noise and typically calls for a lower learning rate.
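The batches-per-epoch arithmetic for the configuration above can be spelled out explicitly (illustrative; Keras does this bookkeeping internally):

```python
import math

train_images = 50_000
batch_size = 64
val_split = 0.2

# Keras carves the validation set out of X_train before batching
effective_train = int(train_images * (1 - val_split))      # 40,000 images actually trained on
steps_per_epoch = math.ceil(effective_train / batch_size)  # 625 weight updates per epoch
print(effective_train, steps_per_epoch)  # 40000 625
```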
Definition: Hyperparameter that controls the size of weight updates at each iteration: weights = weights - learning_rate × gradient. Too high and training oscillates or diverges; too low and it converges very slowly.
Purpose: Balance convergence speed and stability.
Why here: The learning rate is often the single most important hyperparameter. Here we use 0.001 (1e-3) with the Adam optimizer, which adapts the step size per parameter. ReduceLROnPlateau lowers the learning rate when the monitored metric stagnates, a common strategy in practice.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Device: {device}")
# ── Transforms and augmentation ──
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],  # per-channel CIFAR-10 statistics
        std=[0.2023, 0.1994, 0.2010]
    )
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
# ── Datasets and DataLoaders ──
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)
Definition: PyTorch class that iterates over a dataset in batches, with automatic batching, shuffling, and parallel data loading via multiple CPU worker processes.
Purpose: Efficiently manage data loading, augmentation, and preparation during training without bottlenecks.
Why here: shuffle=True randomly reorders the training data each epoch, which is important for stochastic optimization (SGD). num_workers worker processes run data loading and augmentation on the CPU while the GPU processes the current batch, preventing the input pipeline from becoming the bottleneck.
# ── PyTorch CNN Architecture ──
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global Average Pooling
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 10)  # raw logits: CrossEntropyLoss applies softmax internally
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)
model = CNN().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
# ── PyTorch Training Loop ──
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

for epoch in range(50):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    scheduler.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/50 - Loss: {train_loss:.4f} | Acc: {train_acc*100:.2f}%")

torch.save(model.state_dict(), 'cnn_cifar10.pth')
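Unlike the Keras section, the PyTorch side has no test-set evaluation yet. A sketch of an evaluation loop mirroring `train_epoch` (the smoke test below uses a stand-in linear model and random data so it runs without downloading CIFAR-10; on the real model you would call `evaluate(model, test_loader, device)`):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, device):
    """Test-set evaluation: eval mode disables Dropout and uses BatchNorm's
    running statistics; no_grad skips gradient bookkeeping for speed."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Smoke test with a stand-in model and random data (not CIFAR-10)
model = nn.Linear(8, 10)
data = TensorDataset(torch.randn(128, 8), torch.randint(0, 10, (128,)))
acc = evaluate(model, DataLoader(data, batch_size=32), torch.device('cpu'))
print(f"Accuracy: {acc*100:.2f}%")  # around chance (~10%) for an untrained model
```

Forgetting `model.eval()` is a classic PyTorch bug: Dropout stays active and BatchNorm keeps using batch statistics, silently degrading test accuracy.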