Architecture Design Principles
Context
You're designing a neural network architecture or debugging why your network isn't learning. Common mistakes:
- Ignoring inductive biases: Using MLP for images (should use CNN)
- Over-engineering: Using Transformer for 100 samples (should use linear regression)
- No skip connections: 50-layer plain network fails (should use ResNet)
- Wrong depth-width balance: 100 layers × 8 channels bottlenecks capacity
- Ignoring constraints: 1.5B parameter model doesn't fit 24GB GPU
This skill provides principled architecture design: match structure to problem, respect constraints, avoid over-engineering.
Core Principle: Inductive Biases
Inductive bias = assumptions baked into architecture about problem structure
Key insight: The right inductive bias makes learning dramatically easier. Wrong bias makes learning impossible.
What are Inductive Biases?
# Example: Image classification
# MLP (no inductive bias):
# - Treats each pixel independently
# - No concept of "spatial locality" or "translation"
# - Must learn from scratch that nearby pixels are related
# - Learns "cat at position (10,10)" and "cat at (50,50)" separately
# Parameters: 150M, Accuracy: 75%
# CNN (strong inductive bias):
# - Assumes spatial locality (nearby pixels related)
# - Assumes translation invariance (cat is cat anywhere)
# - Shares filters across spatial positions
# - Hierarchical feature learning (edges → textures → objects)
# Parameters: 11M, Accuracy: 95%
# CNN's inductive bias: ~14× fewer parameters, 20 points higher accuracy!
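To see the parameter gap concretely, here is a minimal count with real layers (the layer sizes are illustrative assumptions, not the exact models above):

import torch.nn as nn

# One fully connected hidden layer on a flattened 224×224×3 image:
# every output neuron connects to every pixel.
fc = nn.Linear(224 * 224 * 3, 1024)
print(sum(p.numel() for p in fc.parameters()))   # ≈ 154M parameters

# One 3×3 conv layer: the same filter is shared across all positions.
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(sum(p.numel() for p in conv.parameters())) # 1,792 parameters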
Principle: Match your architecture's inductive biases to your problem's structure.
Architecture Families and Their Inductive Biases
1. Fully Connected (MLP)
Inductive bias: None (general-purpose)
Structure:
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
When to use:
- ✅ Tabular data (independent features)
- ✅ Small datasets (< 10,000 samples)
- ✅ Baseline / proof of concept
When NOT to use:
- ❌ Images (use CNN)
- ❌ Sequences (use RNN/Transformer)
- ❌ Graphs (use GNN)
Strengths:
- Simple and interpretable
- Fast training
- Works for any input type (flattened)
Weaknesses:
- No structural assumptions (must learn everything from data)
- Parameter explosion (input_size × hidden_size can be huge)
- Doesn't leverage problem structure
2. Convolutional Neural Networks (CNN)
Inductive bias: Spatial locality + Translation invariance
Structure:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 56 * 56, 1000)  # 56×56 spatial size after two pools

    def forward(self, x):                       # x: (batch, 3, 224, 224)
        x = self.pool(F.relu(self.conv1(x)))    # 112×112
        x = self.pool(F.relu(self.conv2(x)))    # 56×56
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
Inductive biases:
- Local connectivity: Neurons see only nearby pixels (spatial locality)
- Translation invariance: Same filter slides across image (parameter sharing)
- Hierarchical features: Stack layers to build complex features from simple ones
When to use:
- ✅ Images (classification, detection, segmentation)
- ✅ Spatial data (maps, medical scans)
- ✅ Any grid-structured data
When NOT to use:
- ❌ Sequences with long-range dependencies (use Transformer)
- ❌ Graphs (irregular structure, use GNN)
- ❌ Tabular data (no spatial structure)
Strengths:
- Parameter efficient (filter sharing)
- Translation invariant (cat anywhere = cat)
- Hierarchical feature learning
Weaknesses:
- Fixed receptive field (limited by kernel size)
- Not suitable for variable-length inputs
- Requires grid structure
3. Recurrent Neural Networks (RNN/LSTM)
Inductive bias: Temporal dependencies
Structure:
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # Use the last layer's final hidden state
        output = self.fc(h_n[-1])
        return output
Inductive bias: Sequential processing (earlier elements influence later elements)
When to use:
- ✅ Time series (stock prices, sensor data)
- ✅ Short sequences (< 100 timesteps)
- ✅ Online processing (process one timestep at a time)
When NOT to use:
- ❌ Long sequences (> 1000 timesteps, use Transformer)
- ❌ Non-sequential data (images, tabular)
- ❌ When parallel processing needed (use Transformer)
Strengths:
- Natural for sequential data
- Constant memory (doesn't grow with sequence length)
- Online processing capability
Weaknesses:
- Slow (sequential, can't parallelize)
- Vanishing gradients (long sequences)
- Struggles with long-range dependencies
4. Transformers
Inductive bias: Minimal (self-attention is general-purpose)
Structure:
import torch.nn as nn

class SimpleTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, num_classes):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, num_heads, batch_first=True),
            num_layers
        )
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.encoder(x)
        # Global average pooling over the sequence dimension
        x = x.mean(dim=1)
        return self.fc(x)
Inductive bias: Minimal (learns relationships from data via attention)
When to use:
- ✅ Long sequences (> 100 tokens)
- ✅ Language (text, code)
- ✅ Large datasets (> 100k samples)
- ✅ When relationships are complex and data-dependent
When NOT to use:
- ❌ Small datasets (< 10k samples, use RNN or MLP)
- ❌ Strong structural priors available (images → CNN)
- ❌ Very long sequences (> 16k tokens, use sparse attention)
- ❌ Low-latency requirements (RNN faster)
Strengths:
- Parallel processing (fast training)
- Long-range dependencies (attention)
- State-of-the-art for language
Weaknesses:
- Quadratic complexity O(n²) with sequence length
- Requires large datasets (weak inductive bias)
- High memory usage
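The quadratic-complexity weakness above is easy to quantify. A rough sketch of the attention-matrix memory alone (per head, per layer, ignoring batch size; FP32 assumed):

# Attention scores: seq_len × seq_len entries × 4 bytes (FP32)
for seq_len in [512, 2048, 8192, 32768]:
    mb = seq_len ** 2 * 4 / 1e6
    print(f"seq_len={seq_len}: attention matrix ≈ {mb:,.0f} MB")
# 512 → 1 MB, 2048 → 17 MB, 8192 → 268 MB, 32768 → 4,295 MB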
5. Graph Neural Networks (GNN)
Inductive bias: Message passing over graph structure
Structure:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # requires PyTorch Geometric

class SimpleGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        # x: node features (num_nodes, input_dim)
        # edge_index: graph structure (2, num_edges)
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
Inductive bias: Nodes influenced by neighbors (message passing)
When to use:
- ✅ Graph data (social networks, molecules, knowledge graphs)
- ✅ Irregular connectivity (different # of neighbors per node)
- ✅ Relational reasoning
When NOT to use:
- ❌ Grid data (images → CNN)
- ❌ Sequences (text → Transformer)
- ❌ If graph structure doesn't help (test MLP baseline first!)
Strengths:
- Handles irregular structure
- Permutation invariant
- Natural for relational data
Weaknesses:
- Requires meaningful graph structure
- Over-smoothing (too many layers)
- Scalability challenges (large graphs)
Decision Tree: Architecture Selection
START
|
├─ Is data grid-structured (images)?
│ ├─ YES → Use CNN
│ │ └─ ResNet (general), EfficientNet (mobile), ViT (very large datasets)
│ └─ NO → Continue
│
├─ Is data sequential (text, time series)?
│ ├─ YES → Check sequence length
│ │ ├─ < 100 timesteps → LSTM/GRU
│ │ ├─ 100-4000 tokens → Transformer
│ │ └─ > 4000 tokens → Sparse Transformer (Longformer)
│ └─ NO → Continue
│
├─ Is data graph-structured (molecules, social networks)?
│ ├─ YES → Check if structure helps
│ │ ├─ Test MLP baseline first
│ │ └─ If structure helps → GNN (GCN, GraphSAGE, GAT)
│ └─ NO → Continue
│
└─ Is data tabular (independent features)?
└─ YES → Start simple
├─ < 1000 samples → Linear / Ridge regression
├─ 1000-100k samples → Small MLP (2-3 layers)
└─ > 100k samples → Larger MLP or Gradient Boosting (XGBoost)
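The same tree as a hypothetical helper function (the name and thresholds mirror the tree above; adjust them to your setting):

def choose_architecture(data_type, seq_len=None, num_samples=None):
    if data_type == "image":
        return "CNN (ResNet; EfficientNet for mobile; ViT for very large datasets)"
    if data_type == "sequence":
        if seq_len < 100:
            return "LSTM/GRU"
        if seq_len <= 4000:
            return "Transformer"
        return "Sparse Transformer (e.g., Longformer)"
    if data_type == "graph":
        return "GNN (test an MLP baseline first)"
    if data_type == "tabular":
        if num_samples < 1000:
            return "Linear / Ridge regression"
        if num_samples <= 100_000:
            return "Small MLP (2-3 layers)"
        return "Larger MLP or gradient boosting (XGBoost)"
    raise ValueError(f"unknown data type: {data_type}")

print(choose_architecture("sequence", seq_len=2000))  # Transformer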
Principle: Start Simple, Add Complexity Only When Needed
Occam's Razor: Simplest model that solves the problem is best.
Progression:
# Step 1: Linear baseline (ALWAYS start here!)
model = nn.Linear(input_size, num_classes)
# Train and evaluate

# Step 2: IF linear insufficient, add small MLP
if linear_accuracy < target:
    model = nn.Sequential(
        nn.Linear(input_size, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes)
    )

# Step 3: IF small MLP insufficient, add depth/width
if mlp_accuracy < target:
    model = nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )

# Step 4: IF simple models fail, use specialized architecture
if simple_models_fail:
    # Images → CNN
    # Sequences → RNN/Transformer
    # Graphs → GNN
    pass

# NEVER skip to Step 4 without testing Steps 1-3!
Why Start Simple?
- Faster iteration: Linear model trains in seconds, Transformer in hours
- Baseline: Know if complexity helps (compare complex vs simple)
- Occam's Razor: Simple model generalizes better (less overfitting)
- Debugging: Easy to verify simple model works correctly
Example: House Price Prediction
# Dataset: 1000 samples, 20 features
# WRONG: Start with Transformer
model = HugeTransformer(20, 512, 6, 1) # 10M parameters
# Result: Overfits (10M params / 1000 samples = 10,000:1 ratio!)
# RIGHT: Start simple
# Step 1: Linear
model = nn.Linear(20, 1) # 21 parameters
# Trains in 1 second, achieves R² = 0.85 (good!)
# Conclusion: Linear sufficient, stop here. No need for Transformer!
Rule: Add complexity only when simple models demonstrably fail.
Principle: Deep Networks Need Skip Connections
Problem: Plain networks > 10 layers suffer from vanishing gradients and degradation.
Vanishing Gradients:
# Gradient flow in plain 50-layer network:
gradient_layer_1 = gradient_output × (∂L50/∂L49) × (∂L49/∂L48) × ... × (∂L2/∂L1)
# Each term < 1 (due to activations):
# If each ≈ 0.9, then: 0.9^50 ≈ 0.005 (shrinks ~200×)
# If each ≈ 0.8, then: 0.8^50 ≈ 0.00001 (vanishes!)
# Result: Early layers don't learn (gradients too small)
Degradation:
# Empirical observation (ResNet paper):
20-layer plain network: 85% accuracy
56-layer plain network: 78% accuracy # WORSE with more layers!
# This is NOT overfitting (training accuracy also drops)
# This is optimization difficulty
Solution: Skip Connections (Residual Networks)
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x            # Save input
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity    # Skip connection!
        out = F.relu(out)
        return out
Why skip connections work:
# Gradient flow with skip connections:
∂loss/∂x = ∂loss/∂out × (1 + ∂F/∂x)
# ↑
# Always flows! ("+1" term)
# Even if ∂F/∂x ≈ 0, gradient flows through identity path
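A minimal experiment to verify this yourself: a 50-layer tanh MLP with and without identity skips (exact magnitudes vary with initialization and seed):

import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 50, 64
layers = nn.ModuleList(
    [nn.Sequential(nn.Linear(width, width), nn.Tanh()) for _ in range(depth)]
)
x = torch.randn(8, width)

for residual in [False, True]:
    for p in layers.parameters():
        p.grad = None  # reset gradients between runs
    h = x
    for layer in layers:
        h = h + layer(h) if residual else layer(h)
    h.sum().backward()
    g = layers[0][0].weight.grad.abs().mean().item()
    print(f"residual={residual}: first-layer gradient ≈ {g:.2e}")
# Expect the plain stack's first-layer gradient to be orders of magnitude smaller.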
Results:
# Without skip connections:
20-layer plain: 85% accuracy
50-layer plain: 78% accuracy (worse!)
# With skip connections (ResNet):
20-layer ResNet: 87% accuracy
50-layer ResNet: 92% accuracy (better!)
152-layer ResNet: 95% accuracy (even better!)
Rule: For networks > 10 layers, ALWAYS use skip connections.
Skip Connection Variants:
1. Residual (ResNet):
out = x + F(x) # Add input to output
2. Dense (DenseNet):
out = torch.cat([x, F(x)], dim=1) # Concatenate input and output
3. Highway:
gate = sigmoid(W_gate @ x)
out = gate * F(x) + (1 - gate) * x # Learned gating
Most common: Residual (simple, effective)
Principle: Balance Depth and Width
Depth = # of layers
Width = # of channels/neurons per layer
Capacity Formula:
# Approximate capacity (for CNNs):
capacity ≈ depth × width²
# Why width²?
# Each layer: input_channels × output_channels × kernel_size²
# Doubling width → 4× parameters per layer
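Quick check of the width² term with real layers (layer sizes are illustrative):

import torch.nn as nn

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(nn.Conv2d(64, 64, 3, padding=1)))    # 36,928
print(count(nn.Conv2d(128, 128, 3, padding=1)))  # 147,584 → doubling width ≈ 4× params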
Trade-offs:
Too deep, too narrow:
# 100 layers × 8 channels
# Problems:
# - Information bottleneck (8 channels can't represent complex features)
# - Harder to optimize (more layers)
# - Slow inference (100 layers sequential)
# Example:
model = VeryDeepNarrow(num_layers=100, channels=8)
# Result: 60% accuracy (bottleneck!)
Too shallow, too wide:
# 2 layers × 1024 channels
# Problems:
# - Under-utilizes depth (no hierarchical features)
# - Memory explosion (1024 × 1024 = 1M parameters per layer!)
# Example:
model = VeryWideShallow(num_layers=2, channels=1024)
# Result: 70% accuracy (doesn't leverage depth)
Balanced:
# 18 layers, gradually increasing width: 64 → 128 → 256 → 512
# Benefits:
# - Hierarchical features (depth)
# - Sufficient capacity (width)
# - Good optimization (not too deep)
# Example (ResNet-18):
model = ResNet18()
# Layers: 18, Channels: 64-512 (average ~200)
# Result: 95% accuracy (optimal balance!)
Standard Patterns:
# CNNs: Gradually increase channels as spatial dims decrease
# Input: 224×224×3
# Layer 1: 224×224×64 (same spatial size, more channels)
# Layer 2: 112×112×128 (half spatial, double channels)
# Layer 3: 56×56×256 (half spatial, double channels)
# Layer 4: 28×28×512 (half spatial, double channels)
# Why? Compensate for spatial information loss with channel information
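A minimal sketch of this pattern (strides and channel counts are illustrative, not a published model):

import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 64, 3, stride=1, padding=1),     # 224×224, 64 channels
    nn.Conv2d(64, 128, 3, stride=2, padding=1),   # 112×112, 128 channels
    nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 56×56, 256 channels
    nn.Conv2d(256, 512, 3, stride=2, padding=1),  # 28×28, 512 channels
)
x = torch.randn(1, 3, 224, 224)
print(stem(x).shape)  # torch.Size([1, 512, 28, 28])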
Rule: Balance depth and width. Standard pattern: 12-50 layers, 64-512 channels.
Principle: Match Capacity to Data Size
Capacity = # of learnable parameters
Parameter Budget:
# Rule of thumb: parameters should be 0.01-0.1× dataset size
# Example 1: MNIST (60,000 images)
# Budget: 600 - 6,000 parameters
# Simple CNN: 60,000 parameters (10× the budget) → Works, but might overfit
# LeNet: 60,000 parameters → Classic, works well
# Example 2: ImageNet (1.2M images)
# Budget: 12,000 - 120,000 parameters
# ResNet-50: 25M parameters (200× the budget) → Works (aggressive augmentation helps)
# Example 3: Tabular (100 samples, 20 features)
# Budget: 1 - 10 parameters
# Linear: 21 parameters → Perfect fit!
# MLP: 1,000 parameters → Overfits horribly
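A small helper for this sanity check (the function name is ours, not a library API):

import torch.nn as nn

def param_data_ratio(model, num_samples):
    num_params = sum(p.numel() for p in model.parameters())
    ratio = num_params / num_samples
    print(f"{num_params:,} params / {num_samples:,} samples = {ratio:.2f}× dataset size")
    return ratio

param_data_ratio(nn.Linear(20, 1), 100)  # 21 params / 100 samples = 0.21× dataset size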
Overfitting Detection:
# Training accuracy >> Validation accuracy (gap > 5%)
train_acc = 99%, val_acc = 70% # 29% gap → OVERFITTING!
# Solutions:
# 1. Reduce model capacity (fewer layers/channels)
# 2. Add regularization (dropout, weight decay)
# 3. Collect more data
# 4. Data augmentation
# Order: Try (1) first (simplest), then (2), then (3)/(4)
Underfitting Detection:
# Training accuracy < target (model too simple)
train_acc = 60%, val_acc = 58% # Both low → UNDERFITTING!
# Solutions:
# 1. Increase model capacity (more layers/channels)
# 2. Train longer
# 3. Reduce regularization
# Order: Try (2) first (cheapest), then (1), then (3)
Rule: Match parameters to data size. Start small, increase capacity only if underfitting.
Principle: Design for Compute Constraints
Constraints:
- Memory: Model + gradients + optimizer states < GPU VRAM
- Latency: Inference time < requirement (e.g., < 100ms for real-time)
- Throughput: Samples/second > requirement
Memory Budget:
# Memory calculation (training):
# 1. Model parameters (FP32): params × 4 bytes
# 2. Gradients: params × 4 bytes
# 3. Optimizer states (Adam): params × 8 bytes (2× weights)
# 4. Activations: Σ over all layers of batch_size × channels × H × W × 4 bytes
# Example: ResNet-50
params = 25M
memory_params = 25M × 4 = 100 MB
memory_gradients = 100 MB
memory_optimizer = 200 MB
# Activations dominate for CNNs (summed over all ~50 layers):
# roughly 100 MB per image at 224×224 in FP32
memory_activations ≈ 32 × 100 MB = 3.2 GB (batch=32)
# Total (batch=32): 0.1 + 0.1 + 0.2 + 3.2 ≈ 3.6 GB
# Fits on an 8GB GPU (but not comfortably on 4GB)
# Example: GPT-3 (175B parameters)
memory_params = 175B × 4 = 700 GB
memory_total = 700 + 700 + 1400 = 2800 GB = 2.8 TB!
# Requires 35×A100 (80GB each)
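A back-of-envelope estimator matching the arithmetic above (a sketch assuming FP32 weights and Adam; activation bytes must be estimated or measured per architecture):

def training_memory_gb(num_params, activation_bytes=0, bytes_per_param=4):
    weights = num_params * bytes_per_param
    gradients = num_params * bytes_per_param
    adam_states = num_params * bytes_per_param * 2  # two moment buffers
    return (weights + gradients + adam_states + activation_bytes) / 1e9

# ResNet-50-scale model, assuming ~100 MB of activations per image, batch=32:
print(f"{training_memory_gb(25e6, activation_bytes=32 * 100e6):.1f} GB")  # ≈ 3.6 GB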
Rule: Calculate memory before training. Don't design models that don't fit.
Latency Budget:
# Inference latency = # operations / throughput
# Example: Mobile app (< 100ms latency requirement)
# ResNet-50:
# Operations: 4B FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 4B / 10G = 0.4 seconds (FAILS!)
# MobileNetV2:
# Operations: 300M FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 300M / 10G = 0.03 seconds = 30ms (PASSES!)
# Solution: Use efficient architectures (MobileNet, EfficientNet) for mobile
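Don't just estimate from FLOPs — measure. A rough CPU check, assuming torchvision is installed (absolute numbers vary widely by machine):

import time
import torch
from torchvision.models import mobilenet_v2, resnet50

for ctor in [resnet50, mobilenet_v2]:
    model = ctor(weights=None).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        model(x)  # warm-up run
        start = time.perf_counter()
        for _ in range(10):
            model(x)
    ms = (time.perf_counter() - start) / 10 * 1000
    print(f"{ctor.__name__}: ≈ {ms:.0f} ms per image")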
Rule: Measure latency. Use efficient architectures if latency-constrained.
Common Architectural Patterns
1. Bottleneck (ResNet)
Structure:
# Standard: 3×3 conv (256 channels) → 3×3 conv (256 channels)
# Parameters: 2 × (256 × 256 × 3 × 3) ≈ 1.18M
# Bottleneck: 1×1 (256→64) → 3×3 (64→64) → 1×1 (64→256)
# Parameters: 256×64 + 64×64×3×3 + 64×256 = 16K + 37K + 16K ≈ 69K
# Reduction: 1.18M → 69K (≈17× fewer!)
Purpose: Reduce parameters while maintaining capacity
When to use: Deep networks (> 50 layers) where parameters are a concern
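Verifying the arithmetic above with real layers (bias terms omitted, as in the counts above):

import torch.nn as nn

def count(block):
    return sum(p.numel() for p in block.parameters())

standard = nn.Sequential(
    nn.Conv2d(256, 256, 3, padding=1, bias=False),
    nn.Conv2d(256, 256, 3, padding=1, bias=False),
)
bottleneck = nn.Sequential(
    nn.Conv2d(256, 64, 1, bias=False),
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    nn.Conv2d(64, 256, 1, bias=False),
)
print(count(standard))    # 1,179,648
print(count(bottleneck))  # 69,632 → ≈17× fewer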
2. Inverted Bottleneck (MobileNetV2)
Structure:
# Bottleneck (ResNet): Wide → Narrow → Wide (256 → 64 → 256)
# Inverted: Narrow → Wide → Narrow (64 → 256 → 64)
# Why? Efficient for mobile (depthwise separable convolutions)
Purpose: Maximize efficiency (FLOPs per parameter)
When to use: Mobile/edge deployment
3. Multi-scale Features (Inception)
Structure:
# Parallel branches with different kernel sizes:
# Branch 1: 1×1 conv
# Branch 2: 3×3 conv
# Branch 3: 5×5 conv
# Branch 4: 3×3 max pool
# Concatenate all branches
# Captures features at multiple scales simultaneously
Purpose: Capture multi-scale patterns
When to use: When features exist at multiple scales (object detection)
4. Attention (Transformers, SE-Net)
Structure:
# Squeeze-and-Excitation (SE) block:
# 1. Global average pooling (spatial → channel descriptor)
# 2. FC layer (bottleneck)
# 3. FC layer (restore channels)
# 4. Sigmoid (attention weights)
# 5. Multiply input channels by attention weights
# Result: Emphasize important channels, suppress irrelevant
Purpose: Learn importance of features (channels or positions)
When to use: When not all features are equally important
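A minimal SE block following the five steps above (reduction=16 is a common default; this is an illustration, not the exact SE-Net code):

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):                 # x: (batch, C, H, W)
        s = x.mean(dim=(2, 3))            # 1. squeeze: global average pool
        s = torch.relu(self.fc1(s))       # 2. bottleneck FC
        s = torch.sigmoid(self.fc2(s))    # 3-4. restore channels + attention weights
        return x * s[:, :, None, None]    # 5. reweight input channels

x = torch.randn(2, 64, 32, 32)
print(SEBlock(64)(x).shape)  # torch.Size([2, 64, 32, 32])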
Debugging Architectures
Problem 1: Network doesn't learn (loss stays constant)
Diagnosis:
# Check gradient flow
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_mean={param.grad.mean():.6f}, grad_std={param.grad.std():.6f}")
# Vanishing: grad_mean ≈ 0, grad_std ≈ 0 → Add skip connections
# Exploding: grad_mean > 1, grad_std > 1 → Gradient clipping or lower LR
Solutions:
- Add skip connections (ResNet)
- Check initialization (Xavier or He initialization)
- Lower learning rate
- Check data preprocessing (normalized inputs?)
Problem 2: Overfitting (train >> val)
Diagnosis:
train_acc = 99%, val_acc = 70% # 29% gap → Overfitting
# Check parameter/data ratio:
num_params = sum(p.numel() for p in model.parameters())
data_size = len(train_dataset)
ratio = num_params / data_size
# If ratio > 1: Model has more parameters than data points!
Solutions (in order):
- Reduce capacity (fewer layers/channels)
- Add dropout / weight decay
- Data augmentation
- Collect more data
Problem 3: Underfitting (train and val both low)
Diagnosis:
train_acc = 65%, val_acc = 63% # Both low → Underfitting
# Model too simple for task complexity
Solutions (in order):
- Train longer (more epochs)
- Increase capacity (more layers/channels)
- Reduce regularization (lower dropout/weight decay)
- Check learning rate (too low?)
Problem 4: Slow training
Diagnosis:
# Profile forward/backward pass
import time

start = time.time()
loss = criterion(model(inputs), targets)
forward_time = time.time() - start

start = time.time()
loss.backward()
backward_time = time.time() - start
# (On GPU, call torch.cuda.synchronize() before reading timers,
#  since CUDA kernels run asynchronously.)

# If backward_time >> forward_time: Gradient computation bottleneck
Solutions:
- Use mixed precision (FP16)
- Reduce batch size (if memory-bound)
- Use gradient accumulation (simulate large batch)
- Simplify architecture (fewer layers)
Design Checklist
Before finalizing an architecture:
☐ Match inductive bias to problem
- Images → CNN
- Sequences → RNN/Transformer
- Graphs → GNN
- Tabular → MLP
☐ Start simple, add complexity only when needed
- Test linear baseline first
- Add complexity incrementally
- Compare performance at each step
☐ Use skip connections for deep networks (> 10 layers)
- ResNet for CNNs
- Pre-norm for Transformers
- Gradient flow is critical
☐ Balance depth and width
- Not too deep and narrow (bottleneck)
- Not too shallow and wide (under-utilizes depth)
- Standard: 12-50 layers, 64-512 channels
☐ Match capacity to data size
- Parameters ≈ 0.01-0.1× dataset size
- Monitor train/val gap (overfitting indicator)
☐ Respect compute constraints
- Memory: Model + gradients + optimizer + activations < VRAM
- Latency: Inference time < requirement
- Use efficient architectures if constrained (MobileNet, EfficientNet)
☐ Verify gradient flow
- Check gradients in early layers (should be non-zero)
- Use skip connections if vanishing
☐ Benchmark against baselines
- Compare to simple model (linear, small MLP)
- Ensure complexity adds value (e.g., > 5% improvement over the baseline)
Anti-Patterns
Anti-pattern 1: "Architecture X is state-of-the-art, so I'll use it"
Wrong:
# Transformer is SOTA for NLP, so use for tabular data (100 samples)
model = HugeTransformer(...) # 10M parameters
# Result: Overfits horribly (100 samples / 10M params = 0.00001 ratio!)
Right:
# Match architecture to problem AND data size
# Tabular + small data → Linear or small MLP
model = nn.Linear(20, 1) # 21 parameters (appropriate!)
Anti-pattern 2: "More layers = better"
Wrong:
# 100-layer plain network (no skip connections)
layers = []
for i in range(100):
    layers.append(nn.Conv2d(64, 64, 3, padding=1))
# Result: Doesn't train (vanishing gradients)
Right:
# 50-layer ResNet (with skip connections)
# Each block: out = x + F(x) # Skip connection
# Result: Trains well, high accuracy
Anti-pattern 3: "Deeper + narrower = efficient"
Wrong:
# 100 layers × 8 channels = information bottleneck
model = VeryDeepNarrow(100, 8)
# Result: 60% accuracy (8 channels insufficient)
Right:
# 18 layers, 64-512 channels (balanced)
model = ResNet18() # Balanced depth and width
# Result: 95% accuracy
Anti-pattern 4: "Ignore constraints, optimize later"
Wrong:
# Design 1.5B parameter model for 24GB GPU
model = HugeModel(1.5e9)
# Result: OOM (out of memory), can't train
Right:
# Calculate memory first:
# 1.5B params × 4 bytes = 6GB (weights)
# + 6GB (gradients) + 12GB (Adam) + 8GB (activations) = 32GB
# > 24GB → Doesn't fit!
# Design for hardware:
model = ReasonableSizeModel(200e6) # 200M parameters (fits!)
Anti-pattern 5: "Hyperparameters will fix architectural problems"
Wrong:
# Architecture: MLP for images (wrong inductive bias)
# Response: "Just tune learning rate!"
for lr in [0.1, 0.01, 0.001, 0.0001]:
    train(model, lr=lr)
# Result: All fail (architecture is wrong!)
Right:
# Fix architecture first (use CNN for images)
model = ResNet18() # Correct inductive bias
# Then tune hyperparameters
Summary
Core principles:
- Inductive bias: Match architecture to problem structure (CNN for images, RNN/Transformer for sequences, GNN for graphs)
- Occam's Razor: Start simple (linear, small MLP), add complexity only when needed
- Skip connections: Use for networks > 10 layers (ResNet, DenseNet)
- Depth-width balance: Not too deep+narrow (bottleneck) or too shallow+wide (under-utilizes depth)
- Capacity: Match parameters to data size (0.01-0.1× dataset size)
- Constraints: Design for available memory, latency, throughput
Decision framework:
- Images → CNN (ResNet, EfficientNet)
- Short sequences → LSTM
- Long sequences → Transformer
- Graphs → GNN (test if structure helps first!)
- Tabular → Linear or small MLP
Key insight: Architecture design is about matching structural assumptions to problem structure, not about using the "best" or "most complex" model. Simple models often win.
When in doubt: Start with the simplest model that could plausibly work. Add complexity only when you have evidence it helps.