# PyTorch nn.Module Design Patterns ## Overview **Core Principle:** nn.Module is not just a container for forward passes. It's PyTorch's contract for model serialization, device management, parameter enumeration, and inspection. Follow conventions or face subtle bugs during scaling, deployment, and debugging. Poor module design manifests as: state dict corruption, DDP failures, hook memory leaks, initialization fragility, and un-inspectable architectures. These bugs are silent until production. Design modules correctly from the start using PyTorch's established patterns. ## When to Use **Use this skill when:** - Implementing custom nn.Module subclasses - Adding forward/backward hooks for feature extraction or debugging - Designing modular architectures with swappable components - Implementing custom weight initialization strategies - Building reusable model components (blocks, layers, heads) - Encountering state dict issues, DDP failures, or hook problems **Don't use when:** - Simple model composition (stack existing modules) - Training loop issues (use training-optimization) - Memory debugging unrelated to modules (use tensor-operations-and-memory) **Symptoms triggering this skill:** - "State dict keys don't match after loading" - "DDP not syncing gradients properly" - "Hooks causing memory leaks" - "Can't move model to device" - "Model serialization breaks after changes" - "Need to extract intermediate features" - "Want to make architecture more modular" ## Expert Module Design Patterns ### Pattern 1: Always Use nn.Module, Never None **Problem:** Conditional module assignment using `None` breaks PyTorch's module contract. ```python # ❌ WRONG: Conditional None assignment class ResNetBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1) self.bn1 = nn.BatchNorm2d(out_channels) # PROBLEM: Using None for conditional skip connection if stride != 1 or in_channels != out_channels: self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0) else: self.skip = None # ❌ Breaks module enumeration! def forward(self, x): out = self.bn1(self.conv1(x)) # Conditional check needed if self.skip is not None: x = self.skip(x) return F.relu(out + x) ``` **Why this breaks:** - `model.parameters()` and `model.named_modules()` skip None attributes - `.to(device)` doesn't move None, causes device mismatch bugs - `state_dict()` saving/loading becomes inconsistent - DDP/model parallel don't handle None modules correctly - Can't inspect architecture: `for name, module in model.named_modules()` **✅ CORRECT: Use nn.Identity() for no-op** ```python class ResNetBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1) self.bn1 = nn.BatchNorm2d(out_channels) # ✅ ALWAYS assign an nn.Module subclass if stride != 1 or in_channels != out_channels: self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0) else: self.skip = nn.Identity() # ✅ No-op module def forward(self, x): out = self.bn1(self.conv1(x)) # No conditional needed! 
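        # self.skip is always a module here: either the 1x1 projection conv
        # (when the shape changes) or nn.Identity(), so it can be called
        # unconditionally and the module hierarchy stays the same in every config.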
x = self.skip(x) # Identity passes through unchanged return F.relu(out + x) ``` **Why this works:** - `nn.Identity()` passes input unchanged (no-op) - Consistent module hierarchy across all code paths - Device movement works: `.to(device)` works on Identity too - State dict consistent: Identity has no parameters but is tracked - DDP handles Identity correctly - Architecture inspection works: `model.skip` always exists **Rule:** Never assign `None` to `self.*` for modules. Use `nn.Identity()` for no-ops. ### Pattern 2: Functional vs Module Operations - When to Use Each **Core Question:** When should you use `F.relu(x)` vs `self.relu = nn.ReLU()`? **Decision Framework:** | Use Functional (`F.*`) When | Use Module (`nn.*`) When | |------------------------------|--------------------------| | Simple, stateless operations | Need to hook the operation | | Performance critical paths | Need to inspect/modify later | | Operations in complex control flow | Want clear module hierarchy | | One-off computations | Operation has learnable parameters | | Loss functions | Activation functions you might swap | **Example: When functional is fine** ```python class SimpleBlock(nn.Module): def __init__(self, channels): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 1, 1) self.bn = nn.BatchNorm2d(channels) # No need to store ReLU as module for simple blocks def forward(self, x): x = self.conv(x) x = self.bn(x) return F.relu(x) # ✅ Fine for simple cases ``` **Example: When module storage matters** ```python class FeatureExtractorBlock(nn.Module): def __init__(self, channels, activation='relu'): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 1, 1) self.bn = nn.BatchNorm2d(channels) # ✅ Store as module for flexibility and inspection if activation == 'relu': self.activation = nn.ReLU() elif activation == 'gelu': self.activation = nn.GELU() else: self.activation = nn.Identity() def forward(self, x): x = self.conv(x) x = self.bn(x) return self.activation(x) # ✅ Can hook, swap, inspect ``` **Why storing as module matters:** 1. **Hooks**: Can only hook module operations, not functional ```python # ✅ Can register hook model.layer1.activation.register_forward_hook(hook_fn) # ❌ Can't hook F.relu() calls ``` 2. **Inspection**: Module hierarchy shows architecture ```python for name, module in model.named_modules(): print(f"{name}: {module}") # With nn.ReLU: "layer1.activation: ReLU()" # With F.relu: activation not shown ``` 3. **Modification**: Can swap modules after creation ```python # ✅ Can replace activation model.layer1.activation = nn.GELU() # ❌ Can't modify F.relu() usage without code changes ``` 4. **Quantization**: Quantization tools trace module operations ```python # ✅ Quantization sees nn.ReLU quantized = torch.quantization.quantize_dynamic(model) # ❌ F.relu() not traced by quantization ``` **Pattern to follow:** - Simple internal blocks: Functional is fine - Top-level operations you might modify: Use modules - When building reusable components: Use modules - When unsure: Use modules (negligible overhead) ### Pattern 3: Modular Design with Substitutable Components **Problem:** Hardcoding architecture choices makes variants difficult. 
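As an aside that builds on Pattern 2: when components are stored as modules, a variant can even be retrofitted after construction. A minimal sketch, assuming a hypothetical `swap_activations` helper and a model whose activations are registered submodules:

```python
import torch.nn as nn

def swap_activations(model: nn.Module, old=nn.ReLU, new=nn.GELU):
    """Recursively replace every `old` activation submodule with a fresh `new` instance."""
    for name, child in model.named_children():
        if isinstance(child, old):
            setattr(model, name, new())        # re-registers under the same attribute name
        else:
            swap_activations(child, old, new)  # recurse into nested submodules
    return model
```

This kind of swap only works because the activations are module attributes; the hardcoded block below cannot take advantage of it without a rewrite.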
```python
# ❌ WRONG: Hardcoded architecture
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Hardcoded: ReLU, BatchNorm, specific conv config
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return x
```

**Problem:** To use LayerNorm or GELU, you must copy-paste the block and create a new class.

**✅ CORRECT: Modular design with substitutable components**

```python
class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        norm_layer=nn.BatchNorm2d,  # ✅ Substitutable
        activation=nn.ReLU,         # ✅ Substitutable
        bias=True
    ):
        super().__init__()

        # Use provided norm and activation
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm1 = norm_layer(out_channels)
        self.act1 = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm2 = norm_layer(out_channels)
        self.act2 = activation()

    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.act2(self.norm2(self.conv2(x)))
        return x

# Usage examples:

# Standard: BatchNorm + ReLU
block1 = EncoderBlock(64, 128)

# GroupNorm + GELU (wrap constructors that need more than the channel count;
# plain nn.LayerNorm expects channels-last input, so it would need a custom wrapper here)
block2 = EncoderBlock(64, 128, norm_layer=lambda c: nn.GroupNorm(8, c), activation=nn.GELU)

# No normalization
block3 = EncoderBlock(64, 128, norm_layer=nn.Identity, activation=nn.ReLU)
```

**Advanced: Flexible normalization for different dimensions**

```python
class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        norm_layer=None,   # If None, use default BatchNorm2d
        activation=None    # If None, use default ReLU
    ):
        super().__init__()

        # Set defaults
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation is None:
            activation = nn.ReLU

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # Accept a class, functools.partial, or lambda (anything that builds a module
        # from a channel count) -- or an already-constructed nn.Module instance
        self.norm1 = norm_layer if isinstance(norm_layer, nn.Module) else norm_layer(out_channels)
        self.act1 = activation if isinstance(activation, nn.Module) else activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.norm2 = norm_layer if isinstance(norm_layer, nn.Module) else norm_layer(out_channels)
        self.act2 = activation if isinstance(activation, nn.Module) else activation()

    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.act2(self.norm2(self.conv2(x)))
        return x
```

**Benefits:**
- One class supports many architectural variants
- Easy to experiment: swap LayerNorm, GroupNorm, GELU, etc.
- Code reuse without duplication
- Matches PyTorch's own design (e.g., ResNet's `norm_layer` parameter)

**Pattern:** Accept layer constructors as arguments, not hardcoded classes.

### Pattern 4: Proper State Management and `__init__` Structure

**Core principle:** `__init__` defines the module's structure, `forward` defines computation.

```python
class WellStructuredModule(nn.Module):
    """
    Template for well-structured PyTorch modules.
    """

    def __init__(self, config):
        # 1. ALWAYS call super().__init__() first
        super().__init__()

        # 2. Store configuration (for reproducibility/serialization)
        self.config = config

        # 3. Initialize all submodules (parameters registered automatically)
        self._build_layers()

        # 4. Initialize weights AFTER building layers
        self.reset_parameters()

    def _build_layers(self):
        """
        Separate method for building layers (cleaner __init__).
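        Every attribute assigned here is an nn.Module, so its parameters are
        registered the moment the attribute is set; no extra bookkeeping is needed.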
""" self.encoder = nn.Sequential( nn.Linear(self.config.input_dim, self.config.hidden_dim), nn.ReLU(), nn.Linear(self.config.hidden_dim, self.config.hidden_dim) ) self.decoder = nn.Linear(self.config.hidden_dim, self.config.output_dim) # ✅ Use nn.Identity() for conditional modules if self.config.use_skip: self.skip = nn.Linear(self.config.input_dim, self.config.output_dim) else: self.skip = nn.Identity() def reset_parameters(self): """ Custom initialization following PyTorch convention. This method can be called to re-initialize the module: - After creation - When loading partial checkpoints - For training experiments """ for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: # ✅ Check before accessing nn.init.zeros_(module.bias) def forward(self, x): """ Forward pass - pure computation, no module construction. """ # ❌ NEVER create modules here! # ❌ NEVER assign self.* here! encoded = self.encoder(x) decoded = self.decoder(encoded) skip = self.skip(x) return decoded + skip ``` **Critical rules:** 1. **Never create modules in `forward()`** ```python # ❌ WRONG def forward(self, x): self.temp_layer = nn.Linear(10, 10) # ❌ Created during forward! return self.temp_layer(x) ``` **Why:** Parameters not registered, DDP breaks, state dict inconsistent. 2. **Never use `self.*` for intermediate results** ```python # ❌ WRONG def forward(self, x): self.intermediate = self.encoder(x) # ❌ Storing as attribute! return self.decoder(self.intermediate) ``` **Why:** Retains computation graph, memory leak, not thread-safe. 3. **All modules defined in `__init__`** ```python # ✅ CORRECT def __init__(self): super().__init__() self.encoder = nn.Linear(10, 10) # ✅ Defined in __init__ def forward(self, x): intermediate = self.encoder(x) # ✅ Local variable return self.decoder(intermediate) ``` ## Hook Management Best Practices ### Pattern 5: Forward Hooks for Feature Extraction **Problem:** Naive hook usage causes memory leaks and handle management issues. ```python # ❌ WRONG: Multiple problems import torch import torch.nn as nn features = {} # ❌ Global state def get_features(name): def hook(module, input, output): features[name] = output # ❌ Retains computation graph! return hook model = nn.Sequential(...) # ❌ No handle stored, can't remove model[2].register_forward_hook(get_features('layer2')) with torch.no_grad(): output = model(x) # features now contains tensors with gradients (even in no_grad context!) ``` **Why this breaks:** 1. **Hooks run outside `torch.no_grad()` context**: Hook is called by autograd machinery, not your code 2. **Global state**: Not thread-safe, can't have multiple concurrent extractions 3. **No cleanup**: Hooks persist forever, can't remove 4. **Memory leak**: Retained outputs keep computation graph alive **✅ CORRECT: Encapsulated hook handler with proper cleanup** ```python class FeatureExtractor: """ Proper feature extraction using forward hooks. 
Example: extractor = FeatureExtractor(model, layers=['layer2', 'layer3']) with extractor: output = model(input) features = extractor.features # Dict of detached tensors """ def __init__(self, model, layers): self.model = model self.layers = layers self.features = {} self.handles = [] # ✅ Store handles for cleanup def _make_hook(self, name): def hook(module, input, output): # ✅ CRITICAL: Detach and optionally clone self.features[name] = output.detach() # For inputs that might be modified in-place, use .clone(): # self.features[name] = output.detach().clone() return hook def __enter__(self): """Register hooks when entering context.""" self.features.clear() for name, module in self.model.named_modules(): if name in self.layers: handle = module.register_forward_hook(self._make_hook(name)) self.handles.append(handle) # ✅ Store handle return self def __exit__(self, *args): """Clean up hooks when exiting context.""" # ✅ CRITICAL: Remove all hooks for handle in self.handles: handle.remove() self.handles.clear() # Usage model = resnet50() extractor = FeatureExtractor(model, layers=['layer2', 'layer3', 'layer4']) with extractor: output = model(input_tensor) # Features extracted and hooks cleaned up pyramid_features = [ extractor.features['layer2'], extractor.features['layer3'], extractor.features['layer4'] ] ``` **Key points:** - ✅ Encapsulated in class (no global state) - ✅ `output.detach()` breaks gradient tracking (prevents memory leak) - ✅ Hook handles stored and removed (no persistent hooks) - ✅ Context manager ensures cleanup even if error occurs - ✅ Thread-safe (each extractor has own state) ### Pattern 6: When to Detach vs Clone in Hooks **Question:** Should hooks detach, clone, or both? **Decision framework:** ```python def hook(module, input, output): # Decision tree: # 1. Just reading output, no modifications? self.features[name] = output.detach() # ✅ Sufficient # 2. Output might be modified in-place later? self.features[name] = output.detach().clone() # ✅ Safer # 3. Need gradients for analysis (rare)? self.features[name] = output # ⚠️ Dangerous, ensure short lifetime ``` **Example: When clone matters** ```python # Scenario: In-place operations after hook class Model(nn.Module): def forward(self, x): x = self.layer1(x) # Hook here x = self.layer2(x) x += 10 # ❌ In-place modification! return x # ❌ WRONG: Detach without clone def hook(module, input, output): features['layer1'] = output.detach() # Still shares memory! # After forward pass: # features['layer1'] has been modified by x += 10! # ✅ CORRECT: Clone to get independent copy def hook(module, input, output): features['layer1'] = output.detach().clone() # Independent copy ``` **Rule of thumb:** - **Detach only**: Reading features for analysis, no in-place ops - **Detach + clone**: Features might be modified, or unsure - **Neither**: Only if you need gradients (rare, risky) ### Pattern 7: Backward Hooks for Gradient Inspection **Use case:** Debugging gradient flow, detecting vanishing/exploding gradients. ```python class GradientInspector: """ Inspect gradients during backward pass. 
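    Statistics come from full backward hooks, so they are populated only after
    loss.backward() runs; nothing is recorded during the forward pass, and each
    gradient is detached before its summary statistics are computed.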
Example: inspector = GradientInspector(model, layers=['layer1', 'layer2']) with inspector: output = model(input) loss.backward() # Check gradient statistics for name, stats in inspector.grad_stats.items(): print(f"{name}: mean={stats['mean']:.4f}, std={stats['std']:.4f}") """ def __init__(self, model, layers): self.model = model self.layers = layers self.grad_stats = {} self.handles = [] def _make_hook(self, name): def hook(module, grad_input, grad_output): # grad_output: gradients w.r.t. outputs (from upstream) # grad_input: gradients w.r.t. inputs (to downstream) # Check grad_output (most common) if grad_output[0] is not None: grad = grad_output[0].detach() self.grad_stats[name] = { 'mean': grad.abs().mean().item(), 'std': grad.std().item(), 'max': grad.abs().max().item(), 'min': grad.abs().min().item(), } return hook def __enter__(self): self.grad_stats.clear() for name, module in self.model.named_modules(): if name in self.layers: handle = module.register_full_backward_hook(self._make_hook(name)) self.handles.append(handle) return self def __exit__(self, *args): for handle in self.handles: handle.remove() self.handles.clear() # Usage for gradient debugging model = MyModel() inspector = GradientInspector(model, layers=['encoder.layer1', 'decoder.layer1']) with inspector: output = model(input) loss = criterion(output, target) loss.backward() # Check for vanishing/exploding gradients for name, stats in inspector.grad_stats.items(): if stats['mean'] < 1e-7: print(f"⚠️ Vanishing gradient in {name}") if stats['mean'] > 100: print(f"⚠️ Exploding gradient in {name}") ``` **Critical differences from forward hooks:** - **Backward hooks run during `.backward()`**: Not during forward pass - **Receive gradient tensors**: Not activations - **Used for gradient analysis**: Not feature extraction ### Pattern 8: Hook Handle Management Patterns **Never do this:** ```python # ❌ WRONG: No handle stored model.layer.register_forward_hook(hook_fn) # Hook persists forever, can't remove! ``` **Three patterns for handle management:** **Pattern A: Context manager (recommended for temporary hooks)** ```python class HookManager: def __init__(self, module, hook_fn): self.module = module self.hook_fn = hook_fn self.handle = None def __enter__(self): self.handle = self.module.register_forward_hook(self.hook_fn) return self def __exit__(self, *args): if self.handle: self.handle.remove() # Usage with HookManager(model.layer1, my_hook): output = model(input) # Hook automatically removed ``` **Pattern B: Explicit cleanup (for long-lived hooks)** ```python class Model(nn.Module): def __init__(self): super().__init__() self.layer = nn.Linear(10, 10) self.hook_handle = self.layer.register_forward_hook(self._debug_hook) def _debug_hook(self, module, input, output): print(f"Output shape: {output.shape}") def remove_hooks(self): """Explicit cleanup method.""" if self.hook_handle: self.hook_handle.remove() self.hook_handle = None # Usage model = Model() # ... use model ... 
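# The hook registered in __init__ stays attached for the module's entire
# lifetime unless it is removed explicitly, so call the cleanup method when done.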
model.remove_hooks() # Clean up before saving or finishing ``` **Pattern C: List of handles (multiple hooks)** ```python class Model(nn.Module): def __init__(self): super().__init__() self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)]) self.hook_handles = [] def register_debug_hooks(self): """Register hooks on all layers.""" for i, layer in enumerate(self.layers): handle = layer.register_forward_hook( lambda m, inp, out, idx=i: print(f"Layer {idx}: {out.shape}") ) self.hook_handles.append(handle) def remove_all_hooks(self): """Remove all registered hooks.""" for handle in self.hook_handles: handle.remove() self.hook_handles.clear() ``` **Critical rule:** Every `register_*_hook()` call MUST have corresponding `handle.remove()`. ## Weight Initialization Patterns ### Pattern 9: The `reset_parameters()` Convention **PyTorch convention:** Custom initialization goes in `reset_parameters()`, called from `__init__`. ```python # ❌ WRONG: Initialization in __init__ after submodule creation class CustomModule(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear1 = nn.Linear(in_dim, out_dim) self.linear2 = nn.Linear(out_dim, out_dim) # ❌ Initializing here is fragile nn.init.xavier_uniform_(self.linear1.weight) nn.init.xavier_uniform_(self.linear2.weight) # What if linear has bias=False? This crashes: nn.init.zeros_(self.linear1.bias) # ❌ AttributeError if bias=False ``` **Problems:** 1. Happens AFTER nn.Linear's own `reset_parameters()` (already initialized) 2. Can't re-initialize later: `model.reset_parameters()` won't work 3. Fragile: assumes bias exists 4. Violates PyTorch convention **✅ CORRECT: Define `reset_parameters()` method** ```python class CustomModule(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear1 = nn.Linear(in_dim, out_dim) self.linear2 = nn.Linear(out_dim, out_dim) # ✅ Call reset_parameters at end of __init__ self.reset_parameters() def reset_parameters(self): """ Initialize module parameters. Following PyTorch convention, this method: - Can be called to re-initialize the module - Is called automatically at end of __init__ - Allows for custom initialization strategies """ # ✅ Defensive: check if bias exists nn.init.xavier_uniform_(self.linear1.weight) if self.linear1.bias is not None: nn.init.zeros_(self.linear1.bias) nn.init.xavier_uniform_(self.linear2.weight) if self.linear2.bias is not None: nn.init.zeros_(self.linear2.bias) def forward(self, x): return self.linear2(F.relu(self.linear1(x))) # Benefits: # 1. Can re-initialize: model.reset_parameters() # 2. Defensive checks for optional bias # 3. Follows PyTorch convention # 4. Clear separation: __init__ defines structure, reset_parameters initializes ``` ### Pattern 10: Hierarchical Initialization **Pattern:** When modules contain submodules, iterate through hierarchy. ```python class ComplexModel(nn.Module): def __init__(self, config): super().__init__() self.encoder = nn.Sequential( nn.Linear(config.input_dim, config.hidden_dim), nn.ReLU(), nn.Linear(config.hidden_dim, config.hidden_dim) ) self.attention = nn.MultiheadAttention(config.hidden_dim, config.num_heads) self.decoder = nn.Linear(config.hidden_dim, config.output_dim) self.reset_parameters() def reset_parameters(self): """ Initialize all submodules hierarchically. 
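        Type-based rules run first for every submodule; layer-specific overrides
        (such as the decoder's smaller gain) run afterwards so they take precedence.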
""" # Method 1: Iterate through all modules for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): # MultiheadAttention has its own reset_parameters() # Option: Call it or customize module._reset_parameters() # Call internal reset # Method 2: Specific initialization for specific layers # Override general initialization for decoder nn.init.xavier_uniform_(self.decoder.weight, gain=0.5) ``` **Two strategies:** 1. **Uniform initialization**: Iterate all modules, apply same rules ```python for module in self.modules(): if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.kaiming_normal_(module.weight) ``` 2. **Layered initialization**: Different rules for different components ```python def reset_parameters(self): # Encoder: Xavier for module in self.encoder.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) # Decoder: Xavier with small gain nn.init.xavier_uniform_(self.decoder.weight, gain=0.5) ``` **Defensive checks:** ```python def reset_parameters(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) # ✅ Always check for bias if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.BatchNorm2d): # BatchNorm has weight and bias, but different semantics nn.init.ones_(module.weight) nn.init.zeros_(module.bias) ``` ### Pattern 11: Initialization with Learnable Parameters **Use case:** Custom parameters that need special initialization. ```python class AttentionWithTemperature(nn.Module): def __init__(self, d_model, d_k): super().__init__() self.d_k = d_k self.query = nn.Linear(d_model, d_k) self.key = nn.Linear(d_model, d_k) self.value = nn.Linear(d_model, d_k) self.output = nn.Linear(d_k, d_model) # ✅ Learnable temperature parameter # Initialize to 1/sqrt(d_k), but make it learnable self.temperature = nn.Parameter(torch.ones(1)) self.reset_parameters() def reset_parameters(self): """Initialize all parameters.""" # Standard initialization for linear layers for linear in [self.query, self.key, self.value]: nn.init.xavier_uniform_(linear.weight) if linear.bias is not None: nn.init.zeros_(linear.bias) # Output projection with smaller gain nn.init.xavier_uniform_(self.output.weight, gain=0.5) if self.output.bias is not None: nn.init.zeros_(self.output.bias) # ✅ Custom parameter initialization nn.init.constant_(self.temperature, 1.0 / math.sqrt(self.d_k)) def forward(self, x): q = self.query(x) k = self.key(x) v = self.value(x) scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature attn = torch.softmax(scores, dim=-1) out = torch.matmul(attn, v) return self.output(out) ``` **Key points:** - Custom parameters defined with `nn.Parameter()` - Initialized in `reset_parameters()` like other parameters - Can use `nn.init.*` functions on parameters ## Common Pitfalls ### Consolidated Pitfalls Table | # | Pitfall | Symptom | Root Cause | Fix | |---|---------|---------|------------|-----| | 1 | Using `self.x = None` for conditional modules | State dict inconsistent, DDP fails, can't move to device | None not an nn.Module | Use `nn.Identity()` | | 2 | Using functional ops when hooks/inspection needed | Can't hook activations, architecture invisible | Functional bypasses module hierarchy | Store as `self.activation = nn.ReLU()` | | 3 | Hooks retaining computation graphs | Memory leak during feature extraction | Hook doesn't detach 
outputs | Use `output.detach()` in hook | | 4 | No hook handle cleanup | Hooks persist, memory leak, unexpected behavior | Handles not stored/removed | Store handles, call `handle.remove()` | | 5 | Global state in hook closures | Not thread-safe, coupling issues | Mutable global variables | Encapsulate in class | | 6 | Initialization in `__init__` instead of `reset_parameters()` | Can't re-initialize, fragile timing | Violates PyTorch convention | Define `reset_parameters()` | | 7 | Accessing bias without checking existence | Crashes with AttributeError | Assumes bias always exists | Check `if module.bias is not None:` | | 8 | Creating modules in `forward()` | Parameters not registered, DDP breaks | Modules must be in `__init__` | Move to `__init__`, use local vars | | 9 | Storing intermediate results as `self.*` | Memory leak, not thread-safe | Retains computation graph | Use local variables only | | 10 | Not using context managers for hooks | Hooks not cleaned up on error | Missing try/finally | Use `__enter__`/`__exit__` pattern | ### Pitfall 1: Conditional None Assignment ```python # ❌ WRONG class Block(nn.Module): def __init__(self, use_skip): super().__init__() self.layer = nn.Linear(10, 10) self.skip = nn.Linear(10, 10) if use_skip else None # ❌ def forward(self, x): out = self.layer(x) if self.skip is not None: out = out + self.skip(x) return out # ✅ CORRECT class Block(nn.Module): def __init__(self, use_skip): super().__init__() self.layer = nn.Linear(10, 10) self.skip = nn.Linear(10, 10) if use_skip else nn.Identity() # ✅ def forward(self, x): out = self.layer(x) out = out + self.skip(x) # No conditional needed return out ``` **Symptom:** State dict keys mismatch, DDP synchronization failures **Fix:** Always use `nn.Identity()` for no-op modules ### Pitfall 2: Functional Ops Preventing Hooks ```python # ❌ WRONG: Can't hook ReLU class Encoder(nn.Module): def __init__(self, dim): super().__init__() self.linear = nn.Linear(dim, dim) def forward(self, x): x = self.linear(x) return F.relu(x) # ❌ Can't hook this! # Can't do this: # encoder.relu.register_forward_hook(hook) # AttributeError! # ✅ CORRECT: Hookable activation class Encoder(nn.Module): def __init__(self, dim): super().__init__() self.linear = nn.Linear(dim, dim) self.relu = nn.ReLU() # ✅ Stored as module def forward(self, x): x = self.linear(x) return self.relu(x) # ✅ Now can hook: encoder.relu.register_forward_hook(hook) ``` **Symptom:** Can't register hooks on operations **Fix:** Store operations as modules when you need inspection/hooks ### Pitfall 3: Hook Memory Leak ```python # ❌ WRONG: Hook retains graph features = {} def hook(module, input, output): features['layer'] = output # ❌ Retains computation graph! model.layer.register_forward_hook(hook) with torch.no_grad(): output = model(input) # features['layer'] STILL has gradients! # ✅ CORRECT: Detach in hook def hook(module, input, output): features['layer'] = output.detach() # ✅ Breaks graph # Even better: Clone if might be modified def hook(module, input, output): features['layer'] = output.detach().clone() # ✅ Independent copy ``` **Symptom:** Memory grows during feature extraction even with `torch.no_grad()` **Fix:** Always `.detach()` in hooks (and `.clone()` if needed) ### Pitfall 4: Missing Hook Cleanup ```python # ❌ WRONG: No handle management model.layer.register_forward_hook(my_hook) # Hook persists forever, can't remove! 
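# register_forward_hook() returns a RemovableHandle; discarding it means the
# registration can never be undone for the lifetime of the module.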
# ✅ CORRECT: Store and clean up handle class HookManager: def __init__(self): self.handle = None def register(self, module, hook): self.handle = module.register_forward_hook(hook) def cleanup(self): if self.handle: self.handle.remove() manager = HookManager() manager.register(model.layer, my_hook) # ... use model ... manager.cleanup() # ✅ Remove hook ``` **Symptom:** Hooks persist, unexpected behavior, memory leaks **Fix:** Always store handles and call `.remove()` ### Pitfall 5: Initialization Timing ```python # ❌ WRONG: Init in __init__ (fragile) class MyModule(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) # Already initialized! # This works but is fragile: nn.init.xavier_uniform_(self.linear.weight) # Overwrites default init # ✅ CORRECT: Init in reset_parameters() class MyModule(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) self.reset_parameters() # ✅ Clear separation def reset_parameters(self): nn.init.xavier_uniform_(self.linear.weight) if self.linear.bias is not None: # ✅ Defensive nn.init.zeros_(self.linear.bias) ``` **Symptom:** Can't re-initialize, crashes on bias=False **Fix:** Define `reset_parameters()`, call from `__init__` ## Red Flags - Stop and Reconsider **If you catch yourself doing ANY of these, STOP and follow patterns:** | Red Flag Action | Reality | What to Do Instead | |-----------------|---------|-------------------| | "I'll assign None to this module attribute" | Breaks PyTorch's module contract | Use `nn.Identity()` | | "F.relu() is simpler than nn.ReLU()" | True, but prevents inspection/hooks | Use module if you might need hooks | | "I'll store hook output directly" | Retains computation graph | Always `.detach()` first | | "I don't need to store the hook handle" | Can't remove hook later | Always store handles | | "I'll just initialize in __init__" | Can't re-initialize later | Use `reset_parameters()` | | "Bias always exists, right?" | No! `bias=False` is common | Check `if bias is not None:` | | "I'll save intermediate results as self.*" | Memory leak, not thread-safe | Use local variables only | | "I'll create this module in forward()" | Parameters not registered | All modules in `__init__` | **Critical rule:** Follow PyTorch conventions or face subtle bugs in production. ## Complete Example: Well-Designed ResNet Block ```python import torch import torch.nn as nn import math class ResNetBlock(nn.Module): """ Well-designed ResNet block following all best practices. 
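    Each design choice below maps to one of the patterns described earlier in
    this document.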
    Features:
    - Substitutable norm and activation layers
    - Proper use of nn.Identity() for skip connections
    - Hook-friendly (all operations are modules)
    - Correct initialization via reset_parameters()
    - Defensive bias checking
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        norm_layer=nn.BatchNorm2d,
        activation=nn.ReLU,
        bias=False  # Usually False with BatchNorm
    ):
        super().__init__()

        # Store config for potential serialization
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

        # Main path: conv -> norm -> activation -> conv -> norm
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=bias)
        self.norm1 = norm_layer(out_channels)
        self.act1 = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm2 = norm_layer(out_channels)

        # Skip connection (dimension matching)
        # ✅ CRITICAL: Use nn.Identity(), never None
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=bias),
                norm_layer(out_channels)
            )
        else:
            self.skip = nn.Identity()

        # Final activation (applied after residual addition)
        self.act2 = activation()

        # ✅ Initialize weights following convention
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize weights using He initialization (good for ReLU).
        """
        # Iterate through all conv layers
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # He initialization for ReLU
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                # ✅ Defensive: check bias exists
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm standard initialization
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        Forward pass: residual connection with skip path.

        Note: All operations are modules, so can be hooked or modified.
        """
        # Main path
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)  # ✅ Module, not F.relu()
        out = self.conv2(out)
        out = self.norm2(out)

        # Skip connection (always works, no conditional)
        skip = self.skip(x)  # ✅ Identity passes through if no projection needed

        # Residual addition and final activation
        out = out + skip
        out = self.act2(out)  # ✅ Module, not F.relu()

        return out

# Usage examples:

# Standard ResNet block
block1 = ResNetBlock(64, 128, stride=2)

# GroupNorm + GELU variant (wrap constructors that need arguments beyond the channel count)
block2 = ResNetBlock(64, 128, norm_layer=lambda c: nn.GroupNorm(8, c), activation=nn.GELU)

# Can hook any operation:
handle = block1.act1.register_forward_hook(lambda m, i, o: print(f"ReLU output shape: {o.shape}"))

# Can re-initialize:
block1.reset_parameters()

# Can inspect architecture:
for name, module in block1.named_modules():
    print(f"{name}: {module}")
```

**Why this design is robust:**

1. ✅ No None assignments (uses `nn.Identity()`)
2. ✅ All operations are modules (hookable)
3. ✅ Substitutable components (norm, activation)
4. ✅ Proper initialization (`reset_parameters()`)
5. ✅ Defensive bias checking
6. ✅ Clear module hierarchy
7. ✅ Configuration stored (reproducibility)
8. ✅ No magic numbers or hardcoded choices

## Edge Cases and Advanced Scenarios

### Edge Case 1: Dynamic Module Lists (nn.ModuleList)

**Scenario:** Need variable number of layers based on config.

```python
# ❌ WRONG: Using Python list for modules
class DynamicModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.layers = []  # ❌ Python list, parameters not registered!
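        # Appending to a plain Python list below never registers the layers,
        # so they never show up in parameters(), state_dict(), or .to(device).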
for i in range(num_layers): self.layers.append(nn.Linear(10, 10)) def forward(self, x): for layer in self.layers: x = layer(x) return x # model.parameters() is empty! DDP breaks! # ✅ CORRECT: Use nn.ModuleList class DynamicModel(nn.Module): def __init__(self, num_layers): super().__init__() self.layers = nn.ModuleList([ # ✅ Registers all parameters nn.Linear(10, 10) for _ in range(num_layers) ]) def forward(self, x): for layer in self.layers: x = layer(x) return x ``` **Rule:** Use `nn.ModuleList` for lists of modules, `nn.ModuleDict` for dicts. ### Edge Case 2: Hooks on nn.Sequential **Problem:** Hooking specific layers inside nn.Sequential. ```python model = nn.Sequential( nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 20), nn.ReLU(), nn.Linear(20, 10) ) # ❌ WRONG: Can't access by name easily # model.layer2.register_forward_hook(hook) # AttributeError # ✅ CORRECT: Access by index handle = model[2].register_forward_hook(hook) # Third layer (Linear 20->20) # ✅ BETTER: Use named modules for name, module in model.named_modules(): if isinstance(module, nn.Linear): print(f"Hooking {name}") module.register_forward_hook(hook) ``` **Best practice:** For hookable models, use explicit named attributes instead of Sequential: ```python class HookableModel(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Linear(10, 20) self.act1 = nn.ReLU() self.layer2 = nn.Linear(20, 20) # ✅ Named, easy to hook self.act2 = nn.ReLU() self.layer3 = nn.Linear(20, 10) def forward(self, x): x = self.act1(self.layer1(x)) x = self.act2(self.layer2(x)) return self.layer3(x) # Easy to hook specific layers: model.layer2.register_forward_hook(hook) ``` ### Edge Case 3: Hooks with In-Place Operations **Problem:** In-place operations modify hooked tensors. ```python class ModelWithInPlace(nn.Module): def forward(self, x): x = self.layer1(x) # Hook here x += 10 # ❌ In-place modification! x = self.layer2(x) return x # Hook only using detach(): def hook(module, input, output): features['layer1'] = output.detach() # ❌ Still shares memory! # After forward pass, features['layer1'] has been modified! # ✅ CORRECT: Detach AND clone def hook(module, input, output): features['layer1'] = output.detach().clone() # ✅ Independent copy ``` **Decision tree for hooks:** ``` Is output modified in-place later? ├─ Yes → Use .detach().clone() └─ No → Use .detach() (sufficient) Need gradients for analysis? ├─ Yes → Don't detach (but ensure short lifetime!) └─ No → Detach (prevents memory leak) ``` ### Edge Case 4: Partial State Dict Loading **Scenario:** Loading checkpoint with different architecture. ```python # Original model class ModelV1(nn.Module): def __init__(self): super().__init__() self.encoder = nn.Linear(10, 20) self.decoder = nn.Linear(20, 10) # New model with additional layer class ModelV2(nn.Module): def __init__(self): super().__init__() self.encoder = nn.Linear(10, 20) self.middle = nn.Linear(20, 20) # New layer! self.decoder = nn.Linear(20, 10) self.reset_parameters() def reset_parameters(self): # ✅ Initialize all layers for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) # Load V1 checkpoint into V2 model model_v2 = ModelV2() checkpoint = torch.load('model_v1.pth') # ✅ Use strict=False for partial loading model_v2.load_state_dict(checkpoint, strict=False) # ✅ Re-initialize new layers only model_v2.middle.reset_parameters() # New layer needs init ``` **Pattern:** When loading partial checkpoints: 1. 
Load with `strict=False` 2. Check which keys are missing/unexpected 3. Re-initialize only new layers (not loaded ones) ### Edge Case 5: Hook Removal During Forward Pass **Problem:** Removing hooks while iterating causes issues. ```python class Model(nn.Module): def __init__(self): super().__init__() self.layer = nn.Linear(10, 10) self.hook_handles = [] def add_temporary_hook(self): def hook(module, input, output): print("Hook called!") # ❌ WRONG: Removing handle inside hook for h in self.hook_handles: h.remove() # Dangerous during iteration! handle = self.layer.register_forward_hook(hook) self.hook_handles.append(handle) # ✅ CORRECT: Flag for removal, remove after forward pass class Model(nn.Module): def __init__(self): super().__init__() self.layer = nn.Linear(10, 10) self.hook_handles = [] self.hooks_to_remove = [] def add_temporary_hook(self): def hook(module, input, output): print("Hook called!") # ✅ Flag for removal self.hooks_to_remove.append(handle) handle = self.layer.register_forward_hook(hook) self.hook_handles.append(handle) def cleanup_hooks(self): """Call after forward pass""" for handle in self.hooks_to_remove: handle.remove() self.hook_handles.remove(handle) self.hooks_to_remove.clear() ``` **Rule:** Never modify hook handles during forward pass. Flag for removal and clean up after. ### Edge Case 6: Custom Modules with Buffers **Pattern:** Buffers are non-parameter tensors that should be saved/moved with model. ```python class RunningStatsModule(nn.Module): def __init__(self, num_features): super().__init__() # ❌ WRONG: Just store as attribute self.running_mean = torch.zeros(num_features) # Not registered! # ✅ CORRECT: Register as buffer self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) # Parameters (learnable) self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) def forward(self, x): # Update running stats (in training mode) if self.training: mean = x.mean(dim=0) var = x.var(dim=0) # ✅ In-place update of buffers self.running_mean.mul_(0.9).add_(mean, alpha=0.1) self.running_var.mul_(0.9).add_(var, alpha=0.1) # Normalize using running stats normalized = (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5) return normalized * self.weight + self.bias # Buffers are moved with model: model = RunningStatsModule(10) model.cuda() # ✅ running_mean and running_var moved to GPU # Buffers are saved in state_dict: torch.save(model.state_dict(), 'model.pth') # ✅ Includes buffers ``` **When to use buffers:** - Running statistics (BatchNorm-style) - Fixed embeddings (not updated by optimizer) - Positional encodings (not learned) - Masks or indices **Rule:** Use `register_buffer()` for tensors that aren't parameters but should be saved/moved. ## Common Rationalizations (Don't Do These) | Excuse | Reality | Correct Approach | |--------|---------|------------------| | "User wants quick solution, I'll use None" | Quick becomes slow when DDP breaks | Always use nn.Identity(), same speed | | "It's just a prototype, proper patterns later" | Prototype becomes production, tech debt compounds | Build correctly from start, no extra time | | "F.relu() is more Pythonic/simpler" | True, but prevents hooks and modification | Use nn.ReLU() if any chance of needing hooks | | "I'll fix initialization in training loop" | Defeats purpose of reset_parameters() | Put in reset_parameters(), 5 extra lines | | "Bias is almost always there" | False! 
Many models use bias=False | Check if bias is not None, always | | "Hooks are advanced, user won't use them" | Until they need debugging or feature extraction | Design hookable from start, no cost | | "I'll clean up hooks manually later" | Later never comes, memory leaks persist | Context manager takes 10 lines, bulletproof | | "This module is simple, no need for modularity" | Simple modules get extended and reused | Substitutable components from start | | "State dict loading always matches architecture" | False! Checkpoints get reused across versions | Implement reset_parameters() for partial loads | | "In-place ops are fine, I'll remember detach+clone" | Won't remember under pressure | Document decision in code, add comment | **Critical insight:** "Shortcuts for simplicity" become "bugs in production." Proper patterns take seconds more, prevent hours of debugging. ## Decision Frameworks ### Framework 1: Module vs Functional Operations **Question:** Should I use `nn.ReLU()` or `F.relu()`? ``` Will you ever need to: ├─ Register hooks on this operation? → Use nn.ReLU() ├─ Inspect architecture (model.named_modules())? → Use nn.ReLU() ├─ Swap activation (ReLU→GELU)? → Use nn.ReLU() ├─ Use quantization? → Use nn.ReLU() └─ None of above AND performance critical? → F.relu() acceptable ``` **Default:** When in doubt, use module version. Performance difference negligible. ### Framework 2: Hook Detachment Strategy **Question:** In my hook, should I use `detach()`, `detach().clone()`, or neither? ``` Do you need gradients for analysis? ├─ Yes → Don't detach (but ensure short lifetime!) └─ No → Continue... Will the output be modified in-place later? ├─ Yes → Use .detach().clone() ├─ Unsure → Use .detach().clone() (safer) └─ No → Use .detach() (sufficient) ``` **Example decision:** ```python # Scenario: Extract features for visualization (no gradients needed, no in-place) def hook(module, input, output): return output.detach() # ✅ Sufficient # Scenario: Extract features, model has in-place ops (x += y) def hook(module, input, output): return output.detach().clone() # ✅ Necessary # Scenario: Gradient analysis (rare!) def hook(module, input, output): return output # ⚠️ Keep gradients, but ensure short lifetime ``` ### Framework 3: Initialization Strategy Selection **Question:** Which initialization should I use? ``` Activation function? ├─ ReLU family → Kaiming (He) initialization ├─ Tanh/Sigmoid → Xavier (Glorot) initialization ├─ GELU/Swish → Xavier or Kaiming (experiment) └─ None/Linear → Xavier Layer type? ├─ Conv → Usually Kaiming with mode='fan_out' ├─ Linear → Kaiming or Xavier depending on activation ├─ Embedding → Normal(0, 1) or Xavier └─ LSTM/GRU → Xavier for gates Special considerations? ├─ ResNet-style → Last layer of block: small gain (e.g., 0.5) ├─ Transformer → Xavier uniform, specific scale for embeddings ├─ GAN → Careful initialization critical (see paper) └─ Pre-trained → Don't re-initialize! 
Load checkpoint ``` **Code example:** ```python def reset_parameters(self): for module in self.modules(): if isinstance(module, nn.Conv2d): # ReLU activation → Kaiming nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): # Check what activation follows (from self.config or hardcoded) if self.activation == 'relu': nn.init.kaiming_uniform_(module.weight, nonlinearity='relu') else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) ``` ### Framework 4: When to Use Buffers vs Parameters vs Attributes **Decision tree:** ``` Is it a tensor that needs to be saved with the model? └─ No → Regular attribute (self.x = value) └─ Yes → Continue... Should it be updated by optimizer? └─ Yes → nn.Parameter() └─ No → Continue... Should it move with model (.to(device))? └─ Yes → register_buffer() └─ No → Regular attribute Examples: - Model weights → nn.Parameter() - Running statistics (BatchNorm) → register_buffer() - Configuration dict → Regular attribute - Fixed positional encoding → register_buffer() - Dropout probability → Regular attribute - Learnable temperature → nn.Parameter() ``` ## Pressure Testing Scenarios ### Scenario 1: Time Pressure **User:** "I need this module quickly, just make it work." **Agent thought:** "I'll use None and functional ops, faster to write." **Reality:** Taking 30 seconds more to use nn.Identity() and nn.ReLU() prevents hours of debugging DDP issues. **Correct response:** Apply patterns anyway. They're not slower to write once familiar. ### Scenario 2: "Simple" Module **User:** "This is a simple block, don't overcomplicate it." **Agent thought:** "I'll hardcode ReLU and BatchNorm, it's just a prototype." **Reality:** Prototypes become production. Making activation/norm substitutable takes one extra line. **Correct response:** Design modularly from the start. "Simple" doesn't mean "brittle." ### Scenario 3: Existing Codebase **User:** "The existing code uses None for optional modules." **Agent thought:** "I should match existing style for consistency." **Reality:** Existing code may have bugs. Improving patterns is better than perpetuating anti-patterns. **Correct response:** Use correct patterns. Offer to refactor existing code if user wants. ### Scenario 4: "Just Getting Started" **User:** "I'm just experimenting, I'll clean it up later." **Agent thought:** "Proper patterns can wait until it works." **Reality:** Later never comes. Or worse, you can't iterate quickly because of accumulated tech debt. **Correct response:** Proper patterns don't slow down experimentation. They enable faster iteration. ## Red Flags Checklist Before writing `__init__` or `forward`, check yourself: ### Module Definition Red Flags - [ ] Am I assigning `None` to a module attribute? - **FIX:** Use `nn.Identity()` - [ ] Am I using functional ops (F.relu) without considering hooks? - **ASK:** Will this ever need inspection/modification? - [ ] Am I hardcoding architecture choices (ReLU, BatchNorm)? - **FIX:** Make them substitutable parameters - [ ] Am I creating modules in `forward()`? - **FIX:** All modules in `__init__` ### Hook Usage Red Flags - [ ] Am I storing hook output without detaching? - **FIX:** Use `.detach()` or `.detach().clone()` - [ ] Am I registering hooks without storing handles? - **FIX:** Store handles, clean up in `__exit__` - [ ] Am I using global variables in hook closures? 
- **FIX:** Encapsulate in a class - [ ] Am I modifying hook handles during forward pass? - **FIX:** Flag for removal, clean up after ### Initialization Red Flags - [ ] Am I initializing weights in `__init__`? - **FIX:** Define `reset_parameters()`, call from `__init__` - [ ] Am I accessing `.bias` without checking if it exists? - **FIX:** Check `if module.bias is not None:` - [ ] Am I using one initialization for all layers? - **ASK:** Should different layers have different strategies? ### State Management Red Flags - [ ] Am I storing intermediate results as `self.*`? - **FIX:** Use local variables only - [ ] Am I using Python list for modules? - **FIX:** Use `nn.ModuleList` - [ ] Do I have tensors that should be buffers but aren't? - **FIX:** Use `register_buffer()` **If ANY red flag is true, STOP and apply the pattern before proceeding.** ## Quick Reference Cards ### Card 1: Module Design Checklist ``` ✓ super().__init__() called first ✓ All modules defined in __init__ (not forward) ✓ No None assignments (use nn.Identity()) ✓ Substitutable components (norm_layer, activation args) ✓ reset_parameters() defined and called ✓ Defensive checks (if bias is not None) ✓ Buffers registered (register_buffer()) ✓ No self.* assignments in forward() ``` ### Card 2: Hook Checklist ``` ✓ Hook detaches output (.detach() or .detach().clone()) ✓ Hook handles stored in list ✓ Context manager for cleanup (__enter__/__exit__) ✓ No global state mutation ✓ Error handling (try/except in hook) ✓ Documented whether hook modifies output ``` ### Card 3: Initialization Checklist ``` ✓ reset_parameters() method defined ✓ Called from __init__ ✓ Iterates through modules or layers ✓ Checks if bias is not None ✓ Uses appropriate init strategy (Kaiming/Xavier) ✓ Documents why this initialization ✓ Can be called to re-initialize ``` ## References **PyTorch Documentation:** - nn.Module: https://pytorch.org/docs/stable/notes/modules.html - Hooks: https://pytorch.org/docs/stable/notes/modules.html#module-hooks - Initialization: https://pytorch.org/docs/stable/nn.init.html **Related Skills:** - tensor-operations-and-memory (memory management) - debugging-techniques (using hooks for debugging) - distributed-training-strategies (DDP-compatible module design) - checkpointing-and-reproducibility (state dict best practices)