## Overview

This skill teaches you to implement custom autograd functions in PyTorch using `torch.autograd.Function`. You'll learn when custom functions are necessary, how to implement forward and backward passes correctly, mandatory numerical verification with gradcheck, and common pitfalls that break gradient computation.

## When This Skill Applies

Use `torch.autograd.Function` when you encounter these situations:

### Symptoms That Require Custom Functions

1. **Custom operations not in PyTorch**: Implementing novel mathematical operations (special activations, custom loss components, domain-specific transformations)
2. **Wrapping external code**: Interfacing with C++/CUDA kernels, third-party libraries, or compiled extensions that PyTorch doesn't know about
3. **Custom gradient behavior**: Need non-standard gradient computation (gradient clipping in backward, sparsification, quantization, gradient routing)
4. **Memory optimization**: Implementing gradient checkpointing, fused operations, or selective materialization to reduce memory footprint
5. **Interfacing with non-differentiable code**: Wrapping operations that aren't naturally differentiable but have known gradient behavior

### When NOT to Use Custom Functions

Don't use `torch.autograd.Function` when:

- ❌ **Operation composable from existing ops**: If you can write it with standard PyTorch operations, autograd handles gradients automatically
- ❌ **Simple function wrapping**: Just wrapping `torch.nn.functional` operations gains nothing
- ❌ **Standard gradient computation**: No custom behavior needed - use regular PyTorch
- ❌ **Avoiding the learning curve**: Custom functions aren't "advanced only" - use them when the situations above apply, not to sidestep learning standard autograd

**Example**:

```python
# DON'T: Unnecessary custom function
class MyAdd(Function):  # ❌ Pointless wrapper
    @staticmethod
    def forward(ctx, a, b):
        return a + b

    @staticmethod
    def backward(ctx, grad):
        return grad, grad

# DO: Use PyTorch's autograd
output = a + b  # ✅ Autograd handles this correctly

# DON'T: Reimplement existing operations
class MyReLU(Function):  # ❌ Use torch.nn.functional.relu
    ...

# DO: Use built-in operations
output = torch.relu(input)  # ✅ Efficient and correct
```

## Core Pattern: Complete Function Implementation

### Basic Template

Every custom autograd function follows this pattern:

```python
import torch
from torch.autograd import Function

class MyCustomFunction(Function):
    """
    Custom autograd function template.

    Implements forward pass (computation) and backward pass (gradient).
    Context object (ctx) saves data between forward and backward.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """
        Forward pass: compute output from inputs.

        Args:
            ctx: Context object for saving tensors/data
            input: First input tensor
            weight: Second input tensor
            bias: Optional bias tensor

        Returns:
            output: Result tensor
        """
        # Save tensors needed for backward pass
        ctx.save_for_backward(input, weight, bias)

        # Save non-tensor data as ctx attributes
        # ctx.some_value = some_non_tensor_data

        # Compute forward pass
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: compute gradients using chain rule.

        Args:
            ctx: Context object with saved data
            grad_output: Gradient of loss w.r.t. output (dL/dY)

        Returns:
            Tuple of gradients for each input to forward():
            (grad_input, grad_weight, grad_bias)
            Must return one gradient per forward() argument (None if not needed)
        """
        # Retrieve saved tensors
        input, weight, bias = ctx.saved_tensors

        # Initialize gradients to None
        grad_input = grad_weight = grad_bias = None

        # Compute gradients only if needed (efficiency)
        if ctx.needs_input_grad[0]:  # Check if input needs gradient
            grad_input = grad_output.mm(weight)  # dL/dX = dL/dY @ W
        if ctx.needs_input_grad[1]:  # Check if weight needs gradient
            grad_weight = grad_output.t().mm(input)  # dL/dW = dL/dY.T @ X
        if bias is not None and ctx.needs_input_grad[2]:  # Check if bias needs gradient
            grad_bias = grad_output.sum(0)  # Sum over batch dimension

        # Return one gradient per forward() argument, excluding ctx:
        # forward(ctx, input, weight, bias) -> (grad_input, grad_weight, grad_bias)
        return grad_input, grad_weight, grad_bias

# Use the custom function
my_func = MyCustomFunction.apply  # Get callable
output = my_func(input_tensor, weight_tensor, bias_tensor)
loss = output.sum()
loss.backward()  # Calls MyCustomFunction.backward()
```

### Critical Rules

**Rule 1: Return gradient for EACH forward input**

```python
# forward signature: forward(ctx, a, b, c, d=None)
# backward must return: (grad_a, grad_b, grad_c, grad_d)

# ✅ CORRECT: 4 inputs → 4 gradient returns
def backward(ctx, grad_output):
    return grad_a, grad_b, grad_c, grad_d

# ❌ WRONG: Missing gradients
def backward(ctx, grad_output):
    return grad_a, grad_b  # Only 2 - will crash!

# ✅ CORRECT: Use None for unused gradients
def backward(ctx, grad_output):
    return grad_a, None, grad_c, None  # b and d don't need gradients
```

**Rule 2: Gradient shape must match input shape exactly**

```python
# If input.shape = (32, 128)
# Then grad_input.shape MUST BE (32, 128)

# ✅ CORRECT: Shapes match
assert grad_input.shape == input.shape

# ❌ WRONG: Shape mismatch causes runtime error
grad_input = some_computation()  # Shape (32, 64) - WRONG!
```

**Rule 3: Check needs_input_grad before computing**

```python
# Efficiency optimization: skip gradient computation if not needed

# ✅ CORRECT: Check before computing
if ctx.needs_input_grad[0]:
    grad_input = expensive_gradient_computation(...)
else:
    grad_input = None

# ❌ WASTEFUL: Always compute (slow)
grad_input = expensive_gradient_computation(...)  # Even if not needed
```

## Context Object (ctx) Rules

The context object `ctx` is how you pass data from forward to backward. Use it correctly or break everything.

### Rule 1: Use save_for_backward() for Tensors ONLY

```python
# ✅ CORRECT: Save tensors with save_for_backward()
@staticmethod
def forward(ctx, input, weight):
    ctx.save_for_backward(input, weight)  # Both are tensors
    return input @ weight

# ❌ WRONG: Saving tensors as attributes
@staticmethod
def forward(ctx, input, weight):
    ctx.input = input    # Breaks memory tracking
    ctx.weight = weight  # Breaks memory tracking
    return input @ weight

# Why it matters: save_for_backward() properly tracks tensor versions
# and memory. Attribute assignment doesn't, leading to bugs/crashes.
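# A minimal sketch of the failure mode (SomeFunction is hypothetical, for
# illustration only). save_for_backward() records the tensor's version
# counter, so a later in-place mutation is caught; attribute saving is not:
#
#   x = torch.randn(3, requires_grad=True)
#   h = x * 1.0                 # non-leaf tensor
#   y = SomeFunction.apply(h)   # forward saves h
#   h.add_(1.0)                 # in-place mutation bumps h's version counter
#   y.sum().backward()
#   # save_for_backward(h): RuntimeError "... modified by an inplace operation"
#   # ctx.h = h:             silently computes gradients from mutated values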
```

### Rule 2: Save Non-Tensor Data as Attributes

```python
# ✅ CORRECT: Non-tensor data as attributes
@staticmethod
def forward(ctx, input, kernel_size, stride):
    ctx.save_for_backward(input)   # Tensor
    ctx.kernel_size = kernel_size  # Integer - use attribute
    ctx.stride = stride            # Integer - use attribute
    return some_operation(input, kernel_size, stride)

# ❌ WRONG: Trying to save non-tensors with save_for_backward()
@staticmethod
def forward(ctx, input, kernel_size, stride):
    ctx.save_for_backward(input, kernel_size, stride)  # TypeError!
    # kernel_size and stride are ints, not tensors
```

### Rule 3: Access saved_tensors Only in Backward

```python
# ✅ CORRECT: Access in backward
@staticmethod
def backward(ctx, grad_output):
    input, weight = ctx.saved_tensors  # Available here
    grad_input = grad_output @ weight.t()
    return grad_input, None

# ❌ WRONG: Access in forward
@staticmethod
def forward(ctx, input, weight):
    ctx.save_for_backward(input, weight)
    output = input @ weight
    saved = ctx.saved_tensors  # AttributeError! Not available in forward
    return output
```

### Rule 4: Never Modify Saved Tensors

```python
# ❌ WRONG: Modifying saved tensor
@staticmethod
def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    input = input * 2  # Creates new tensor - OK
    input *= 2         # IN-PLACE modification - BREAKS AUTOGRAD!
    grad_input = compute_gradient(input, grad_output)
    return grad_input

# ✅ CORRECT: Don't modify, or clone first
@staticmethod
def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    input_scaled = input * 2  # New tensor - safe
    grad_input = compute_gradient(input_scaled, grad_output)
    return grad_input
```

### Complete ctx Example

```python
class CompleteCtxExample(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride=1, training=True):
        # Save tensors with save_for_backward()
        ctx.save_for_backward(input, weight, bias)

        # Save non-tensor data as attributes
        ctx.stride = stride      # int
        ctx.training = training  # bool

        # Compute forward
        output = some_computation(input, weight, bias, stride, training)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, weight, bias = ctx.saved_tensors

        # Retrieve non-tensor data
        stride = ctx.stride
        training = ctx.training

        # Compute gradients
        grad_input = None
        grad_weight = None
        grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = compute_input_gradient(grad_output, weight, stride)
        if ctx.needs_input_grad[1]:
            grad_weight = compute_weight_gradient(grad_output, input, stride)
        if ctx.needs_input_grad[2]:
            grad_bias = compute_bias_gradient(grad_output)

        # Return gradients (None for stride and training - they're not tensors)
        return grad_input, grad_weight, grad_bias, None, None
```

## Gradient Computation Patterns

Understanding common gradient patterns helps implement backward() correctly.
### Pattern 1: Element-wise Operations

Forward: `y = f(x)` (element-wise function)
Backward: `grad_x = grad_y * f'(x)` (element-wise multiply)

```python
# Example: Custom ReLU
class CustomReLU(Function):
    @staticmethod
    def forward(ctx, input):
        # Save input to compute derivative in backward
        ctx.save_for_backward(input)
        # ReLU: max(0, x)
        output = input.clamp(min=0)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Derivative: 1 if x > 0, else 0
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Example: Custom Sigmoid
class CustomSigmoid(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.sigmoid(input)
        # Save output (more efficient than recomputing)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # Derivative: sigmoid(x) * (1 - sigmoid(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

# Example: Custom Tanh
class CustomTanh(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.tanh(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # Derivative: 1 - tanh^2(x)
        grad_input = grad_output * (1 - output * output)
        return grad_input
```

### Pattern 2: Matrix Operations (Chain Rule)

Forward: `Y = X @ W` (matrix multiply)
Backward: Apply chain rule with matrix transpose

```python
class CustomLinear(Function):
    @staticmethod
    def forward(ctx, input, weight):
        # Save both tensors for backward
        ctx.save_for_backward(input, weight)
        # Forward: Y = X @ W
        output = input.mm(weight.t())  # (batch, in) @ (out, in).T = (batch, out)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        # Gradient w.r.t. input: dL/dX = dL/dY @ W
        # Shapes: (batch, out) @ (out, in) = (batch, in) ✓
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)

        # Gradient w.r.t. weight: dL/dW = dL/dY.T @ X
        # Then transpose to match weight shape
        # Shapes: (out, batch) @ (batch, in) = (out, in) ✓
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)

        return grad_input, grad_weight

# More complex: Matrix multiply with both inputs requiring gradients
class CustomMatmul(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        # Forward: C = A @ B
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None

        # Gradient w.r.t. A: dL/dA = dL/dC @ B.T
        if ctx.needs_input_grad[0]:
            grad_a = torch.matmul(grad_output, b.transpose(-2, -1))

        # Gradient w.r.t. B: dL/dB = A.T @ dL/dC
        if ctx.needs_input_grad[1]:
            grad_b = torch.matmul(a.transpose(-2, -1), grad_output)

        return grad_a, grad_b
```

### Pattern 3: Broadcasting Operations

Forward: Operation with broadcasting (e.g., adding bias)
Backward: Sum over broadcasted dimensions

```python
class CustomBiasAdd(Function):
    @staticmethod
    def forward(ctx, input, bias):
        # input: (batch, channels, height, width)
        # bias: (channels,)
        ctx.save_for_backward(input, bias)
        # Broadcasting adds bias to each channel
        output = input + bias.view(1, -1, 1, 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        grad_input = grad_bias = None

        # Gradient w.r.t. input: just pass through
        if ctx.needs_input_grad[0]:
            grad_input = grad_output

        # Gradient w.r.t. bias: sum over broadcasted dimensions
        # grad_output: (batch, channels, height, width)
        # grad_bias should be: (channels,)
        # Sum over batch (0), height (2), width (3)
        if ctx.needs_input_grad[1]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_bias

# General broadcasting pattern
class CustomBroadcastOp(Function):
    @staticmethod
    def forward(ctx, input, param):
        # Save shapes to determine broadcast dimensions
        ctx.input_shape = input.shape
        ctx.param_shape = param.shape
        ctx.save_for_backward(input, param)
        # Some operation with broadcasting
        output = input * param  # param broadcasts to input shape
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, param = ctx.saved_tensors
        grad_input = grad_param = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output * param

        if ctx.needs_input_grad[1]:
            # Sum grad_output over dimensions that were broadcasted
            grad_param = grad_output * input
            # Find which dimensions were broadcasted
            # Sum over those dimensions to match param shape
            ndim_diff = len(ctx.input_shape) - len(ctx.param_shape)
            for i in range(ndim_diff):
                grad_param = grad_param.sum(0)  # Sum leading dimensions
            for i, (input_dim, param_dim) in enumerate(
                zip(ctx.input_shape[ndim_diff:], ctx.param_shape)
            ):
                if param_dim == 1 and input_dim > 1:
                    grad_param = grad_param.sum(i, keepdim=True)

        return grad_input, grad_param
```

### Pattern 4: Reduction Operations

Forward: Reduce dimensions (sum, mean, max, etc.)
Backward: Expand gradient back to original shape

```python
class CustomSum(Function):
    @staticmethod
    def forward(ctx, input, dim, keepdim=False):
        ctx.input_shape = input.shape
        ctx.dim = dim
        ctx.keepdim = keepdim
        # Sum along dimension
        output = input.sum(dim=dim, keepdim=keepdim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of sum: distribute grad_output to all elements
        grad_input = grad_output

        # Expand back to original shape
        if not ctx.keepdim:
            # Add back the reduced dimension
            grad_input = grad_input.unsqueeze(ctx.dim)

        # Expand to original shape (broadcasts the gradient)
        grad_input = grad_input.expand(ctx.input_shape)

        return grad_input, None, None

class CustomMean(Function):
    @staticmethod
    def forward(ctx, input, dim, keepdim=False):
        ctx.input_shape = input.shape
        ctx.dim = dim
        ctx.keepdim = keepdim
        output = input.mean(dim=dim, keepdim=keepdim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of mean: distribute evenly to all elements
        grad_input = grad_output

        if not ctx.keepdim:
            grad_input = grad_input.unsqueeze(ctx.dim)

        grad_input = grad_input.expand(ctx.input_shape)

        # Divide by number of elements that were averaged
        n = ctx.input_shape[ctx.dim]
        grad_input = grad_input / n

        return grad_input, None, None

class CustomMax(Function):
    @staticmethod
    def forward(ctx, input, dim, keepdim=False):
        # Save both max values and indices
        output, indices = input.max(dim=dim, keepdim=keepdim)
        ctx.save_for_backward(indices)
        ctx.input_shape = input.shape
        ctx.dim = dim
        ctx.keepdim = keepdim
        return output

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors

        # Only the maximum element gets gradient
        # (match dtype as well as device, or double-precision gradcheck breaks)
        grad_input = torch.zeros(ctx.input_shape,
                                 device=grad_output.device,
                                 dtype=grad_output.dtype)

        if not ctx.keepdim:
            grad_output = grad_output.unsqueeze(ctx.dim)
            indices = indices.unsqueeze(ctx.dim)

        # Scatter gradient to max indices
        grad_input.scatter_(ctx.dim, indices, grad_output)

        return grad_input, None, None
```
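Reduction backwards are easy to get subtly wrong (a forgotten `keepdim` branch, or the `1/n` factor in mean). A quick sanity-check sketch for the classes above, comparing against PyTorch's built-ins and running gradcheck:

```python
import torch
from torch.autograd import gradcheck

x = torch.randn(4, 6, dtype=torch.double, requires_grad=True)

# Analytical-vs-numerical check for the reduction patterns above
assert gradcheck(lambda t: CustomSum.apply(t, 1, False), (x,), eps=1e-6, atol=1e-4)
assert gradcheck(lambda t: CustomMean.apply(t, 1, True), (x,), eps=1e-6, atol=1e-4)

# Forward values should match the built-ins exactly
assert torch.allclose(CustomSum.apply(x, 1, False), x.sum(dim=1))
assert torch.allclose(CustomMean.apply(x, 1, True), x.mean(dim=1, keepdim=True))
```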
### Pattern 5: Convolution-like Operations

Complex operations that involve multiple dimensions and strides.

```python
class CustomConv1d(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0):
        # Use PyTorch's conv for forward
        output = torch.nn.functional.conv1d(
            input, weight, bias, stride, padding
        )

        # Save what's needed for backward
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding

        grad_input = grad_weight = grad_bias = None

        # Gradient w.r.t. input: convolve grad_output with weight
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv1d_input(
                input.shape, weight, grad_output, stride, padding
            )

        # Gradient w.r.t. weight: convolve input with grad_output
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv1d_weight(
                input, weight.shape, grad_output, stride, padding
            )

        # Gradient w.r.t. bias: sum grad_output over batch and spatial dims
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2))

        return grad_input, grad_weight, grad_bias, None, None
```

## Numerical Gradient Verification (MANDATORY)

**NEVER skip this step.** `gradcheck` verifies your gradient computation is correct by comparing analytical gradients (your backward()) against numerical gradients (finite differences).

### Why gradcheck is Mandatory

```python
# You implement backward() and it "looks right"
class MyFunction(Function):
    @staticmethod
    def backward(ctx, grad_output):
        # Gradient formula looks correct mathematically
        grad = some_computation()
        return grad

# But:
# ❌ Transpose is wrong
# ❌ Shape doesn't match
# ❌ Sign is flipped
# ❌ Missing a term
# ❌ Broadcasting is incorrect

# These bugs are invisible without gradcheck!
# Your model trains but produces wrong results.
# Debugging takes days without knowing gradients are wrong.
```
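To make the point concrete, here is a deliberately broken toy function (`BuggyScale`, invented for illustration): the usual smoke test runs without any error, and only gradcheck exposes the wrong constant:

```python
import torch
from torch.autograd import Function, gradcheck

class BuggyScale(Function):
    @staticmethod
    def forward(ctx, input):
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 3  # BUG: should be grad_output * 2

x = torch.randn(5, dtype=torch.double, requires_grad=True)
BuggyScale.apply(x).sum().backward()  # runs "fine": no error, wrong gradients
print(x.grad)                         # all 3s instead of 2s

gradcheck(BuggyScale.apply, (x,), eps=1e-6, atol=1e-4)  # raises RuntimeError
```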
### Basic gradcheck Usage

```python
import torch
from torch.autograd import gradcheck

def test_my_function():
    """Test custom function with gradcheck."""
    # Create test inputs with requires_grad=True
    # Use double precision for numerical stability
    input = torch.randn(20, 20, dtype=torch.double, requires_grad=True)
    weight = torch.randn(30, 20, dtype=torch.double, requires_grad=True)

    # Run gradcheck
    test = gradcheck(
        MyCustomFunction.apply,  # Your function
        (input, weight),         # Tuple of inputs
        eps=1e-6,                # Finite difference step size
        atol=1e-4,               # Absolute tolerance
        rtol=1e-3,               # Relative tolerance
        raise_exception=True     # Raise error on failure (recommended)
    )

    if test:
        print("✅ Gradient check PASSED!")
    else:
        print("❌ Gradient check FAILED!")
        raise AssertionError("Gradient check failed")

# Run before using your function
test_my_function()
```

### Complete gradcheck Pattern

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

class MyCustomFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

def test_my_custom_function():
    """Comprehensive gradient testing."""
    # Test 1: Basic gradcheck
    print("Test 1: Basic gradient check...")
    input = torch.randn(10, 5, dtype=torch.double, requires_grad=True)
    weight = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
    bias = torch.randn(3, dtype=torch.double, requires_grad=True)

    assert gradcheck(
        MyCustomFunction.apply,
        (input, weight, bias),
        eps=1e-6, atol=1e-4, raise_exception=True
    ), "Basic gradcheck failed"
    print("✅ Basic gradcheck passed")

    # Test 2: Without bias (optional parameter)
    print("Test 2: Gradient check without bias...")
    assert gradcheck(
        MyCustomFunction.apply,
        (input, weight, None),
        eps=1e-6, atol=1e-4, raise_exception=True
    ), "Gradcheck without bias failed"
    print("✅ Gradcheck without bias passed")

    # Test 3: Different input shapes
    print("Test 3: Different input shapes...")
    input_large = torch.randn(50, 20, dtype=torch.double, requires_grad=True)
    weight_large = torch.randn(10, 20, dtype=torch.double, requires_grad=True)
    bias_large = torch.randn(10, dtype=torch.double, requires_grad=True)

    assert gradcheck(
        MyCustomFunction.apply,
        (input_large, weight_large, bias_large),
        eps=1e-6, atol=1e-4, raise_exception=True
    ), "Gradcheck with large inputs failed"
    print("✅ Gradcheck with different shapes passed")

    # Test 4: Second-order gradients (if needed)
    print("Test 4: Second-order gradient check...")
    try:
        assert gradgradcheck(
            MyCustomFunction.apply,
            (input, weight, bias),
            eps=1e-6, atol=1e-4, raise_exception=True
        ), "Second-order gradcheck failed"
        print("✅ Second-order gradcheck passed")
    except NotImplementedError:
        print("⚠️ Second-order gradients not implemented (OK if not needed)")

    print("\n🎉 All gradient checks passed!")

# ALWAYS run this before using your function in training
test_my_custom_function()
```
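In a real project these checks belong in the test suite so they run on every change; a minimal pytest sketch, assuming the `MyCustomFunction` defined above is importable:

```python
import pytest
import torch
from torch.autograd import gradcheck

@pytest.mark.parametrize("batch,in_f,out_f", [(4, 3, 2), (16, 8, 5)])
def test_gradients(batch, in_f, out_f):
    # Double precision keeps finite differences numerically trustworthy
    input = torch.randn(batch, in_f, dtype=torch.double, requires_grad=True)
    weight = torch.randn(out_f, in_f, dtype=torch.double, requires_grad=True)
    bias = torch.randn(out_f, dtype=torch.double, requires_grad=True)
    assert gradcheck(MyCustomFunction.apply, (input, weight, bias),
                     eps=1e-6, atol=1e-4)
```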
### gradcheck Parameters

```python
gradcheck(
    func,                      # Your Function.apply
    inputs,                    # Tuple of input tensors
    eps=1e-6,                  # Finite difference step: f(x+eps) - f(x-eps)
    atol=1e-5,                 # Absolute tolerance for comparison
    rtol=1e-3,                 # Relative tolerance for comparison
    raise_exception=True,      # Raise on failure (recommended for testing)
    check_sparse_nnz=False,    # Check sparse tensor non-zeros
    nondet_tol=0.0,            # Tolerance for non-deterministic operations
    check_undefined_grad=True, # Check that undefined grads are None
    check_grad_dtypes=True,    # Check gradient dtypes match
)

# Key insights:
# - Use double precision (dtype=torch.double) for numerical stability
# - eps=1e-6 is good default; smaller for more precision, larger for stability
# - atol/rtol balance: looser tolerances for complex operations
# - raise_exception=True catches bugs immediately in testing
```

### Debugging gradcheck Failures

```python
def debug_gradcheck():
    """Step-by-step debugging when gradcheck fails."""

    # Step 1: Check forward pass works
    print("Step 1: Verify forward pass...")
    input = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

    output = MyCustomFunction.apply(input, weight)
    print(f"Output shape: {output.shape}")
    print(f"Output contains NaN: {torch.isnan(output).any()}")
    print(f"Output contains Inf: {torch.isinf(output).any()}")
    assert output.shape == (5, 4), "Forward shape wrong"
    assert not torch.isnan(output).any(), "Forward produces NaN"

    # Step 2: Check backward runs without error
    print("\nStep 2: Verify backward runs...")
    loss = output.sum()
    loss.backward()
    print(f"Input grad shape: {input.grad.shape}")
    print(f"Weight grad shape: {weight.grad.shape}")
    assert input.grad.shape == input.shape, "Input gradient shape mismatch"
    assert weight.grad.shape == weight.shape, "Weight gradient shape mismatch"

    # Step 3: Check gradient magnitudes
    print("\nStep 3: Check gradient magnitudes...")
    print(f"Input grad: mean={input.grad.mean():.6f}, std={input.grad.std():.6f}")
    print(f"Weight grad: mean={weight.grad.mean():.6f}, std={weight.grad.std():.6f}")
    # Should be reasonable numbers (not 1e10 or 1e-20)

    # Step 4: Manual numerical gradient check for one element
    print("\nStep 4: Manual gradient check for one element...")
    input_test = torch.randn(3, 2, dtype=torch.double)
    weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    eps = 1e-6

    # Analytical gradient
    output = MyCustomFunction.apply(input_test, weight_test)
    loss = output.sum()
    loss.backward()
    analytical_grad = weight_test.grad[0, 0].item()

    # Numerical gradient (finite difference)
    weight_test_plus = weight_test.clone().detach()
    weight_test_plus[0, 0] += eps
    output_plus = MyCustomFunction.apply(input_test, weight_test_plus)
    loss_plus = output_plus.sum()

    weight_test_minus = weight_test.clone().detach()
    weight_test_minus[0, 0] -= eps
    output_minus = MyCustomFunction.apply(input_test, weight_test_minus)
    loss_minus = output_minus.sum()

    numerical_grad = (loss_plus - loss_minus) / (2 * eps)
    numerical_grad = numerical_grad.item()

    print(f"Analytical gradient: {analytical_grad:.10f}")
    print(f"Numerical gradient: {numerical_grad:.10f}")
    print(f"Difference: {abs(analytical_grad - numerical_grad):.10e}")

    if abs(analytical_grad - numerical_grad) > 1e-4:
        print("❌ Large difference - gradient implementation likely wrong")
    else:
        print("✅ Small difference - gradient likely correct")

    # Step 5: Run gradcheck with verbose output
    print("\nStep 5: Run gradcheck...")
    input_check = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
    weight_check = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

    try:
        result = gradcheck(
            MyCustomFunction.apply,
            (input_check, weight_check),
            eps=1e-6, atol=1e-4, raise_exception=True
        )
print("✅ gradcheck passed!") except RuntimeError as e: print(f"❌ gradcheck failed with error:\n{e}") # Error message shows which gradient failed and by how much debug_gradcheck() ``` ### Common gradcheck Failure Reasons ```python # Failure 1: Wrong gradient formula class WrongGradient(Function): @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors # ❌ WRONG: No transpose grad_input = grad_output @ weight # Should be weight.t() return grad_input, None # gradcheck fails: analytical ≠ numerical # Failure 2: Shape mismatch class WrongShape(Function): @staticmethod def backward(ctx, grad_output): # ❌ WRONG: Returns wrong shape return grad_output.sum(), None # Should be grad_output.shape == input.shape # gradcheck fails: shape error # Failure 3: In-place operation class InplaceOperation(Function): @staticmethod def backward(ctx, grad_output): grad_output[grad_output < 0] = 0 # ❌ IN-PLACE return grad_output # gradcheck fails: modified by inplace operation # Failure 4: Not using saved tensors correctly class WrongSaved(Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input.clone()) # Saved clone return input * 2 @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors # Using saved tensor is OK, but if logic depends on # input's original properties and clone loses them, fails return grad_output * 2 # May pass or fail depending on what was lost in clone # Failure 5: Forgot to return gradients for all inputs class MissingGradient(Function): @staticmethod def forward(ctx, input, weight, bias): # ... return output @staticmethod def backward(ctx, grad_output): # ❌ WRONG: Only returns 2 gradients for 3 inputs return grad_input, grad_weight # gradcheck fails: tuple length mismatch ``` ## Common Pitfalls ### Pitfall 1: In-Place Operations In-place operations modify tensors that other operations depend on, breaking autograd. ```python # ❌ WRONG: In-place operation in forward class InplaceForward(Function): @staticmethod def forward(ctx, input): input[input < 0] = 0 # IN-PLACE - breaks autograd! return input # ❌ WRONG: In-place operation in backward class InplaceBackward(Function): @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors grad_output[input < 0] = 0 # IN-PLACE - breaks autograd! return grad_output # ✅ CORRECT: Create new tensor class CorrectInplace(Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input) output = input.clone() # Create new tensor output[output < 0] = 0 # Modify copy, not original return output @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors grad_input = grad_output.clone() # Create new tensor grad_input[input < 0] = 0 # Modify copy return grad_input # Even better: Use non-in-place operations class BestInplace(Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input) return input.clamp(min=0) # Non-in-place @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors grad_input = grad_output * (input > 0).float() # Non-in-place return grad_input ``` ### Pitfall 2: Gradient Shape Mismatch Gradient must match input shape exactly. 
### Pitfall 2: Gradient Shape Mismatch

Gradient must match input shape exactly.

```python
# ❌ WRONG: Gradient shape doesn't match input
class WrongShape(Function):
    @staticmethod
    def forward(ctx, input):
        # input: (32, 128)
        ctx.save_for_backward(input)
        return input.sum()  # scalar

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # grad_output is scalar, but need (32, 128)
        return grad_output  # ❌ WRONG: scalar ≠ (32, 128)

# ✅ CORRECT: Expand to match input shape
class CorrectShape(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.input_shape = input.shape
        return input.sum()

    @staticmethod
    def backward(ctx, grad_output):
        # Expand scalar to input shape
        grad_input = grad_output.expand(ctx.input_shape)
        return grad_input

# Always verify shapes
def verify_shapes(ctx, grad_output, *grad_inputs):
    """Helper to verify gradient shapes."""
    for i, (grad, tensor) in enumerate(zip(grad_inputs, ctx.saved_tensors)):
        if grad is not None:
            assert grad.shape == tensor.shape, \
                f"Gradient {i} shape {grad.shape} != input shape {tensor.shape}"
```

### Pitfall 3: Not Checking needs_input_grad

Computing gradients when not needed wastes computation.

```python
# ❌ WASTEFUL: Always compute all gradients
class AlwaysCompute(Function):
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        # Compute all gradients even if not needed
        grad_input = expensive_computation_1(grad_output, weight)
        grad_weight = expensive_computation_2(grad_output, input)
        grad_bias = expensive_computation_3(grad_output)
        return grad_input, grad_weight, grad_bias

# ✅ EFFICIENT: Check needs_input_grad first
class EfficientCompute(Function):
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Only compute if needed
        if ctx.needs_input_grad[0]:
            grad_input = expensive_computation_1(grad_output, weight)
        if ctx.needs_input_grad[1]:
            grad_weight = expensive_computation_2(grad_output, input)
        if ctx.needs_input_grad[2]:
            grad_bias = expensive_computation_3(grad_output)

        return grad_input, grad_weight, grad_bias

# Example where it matters
def example_needs_input_grad():
    """Demonstrate needs_input_grad optimization."""
    input = torch.randn(100, 100, requires_grad=True)
    weight = torch.randn(100, 100, requires_grad=False)  # No gradient needed

    # Without check: computes grad_weight unnecessarily
    # With check: skips grad_weight computation (faster)
    output = MyFunction.apply(input, weight)
    loss = output.sum()
    loss.backward()

    # weight.grad is None because requires_grad=False
    assert weight.grad is None
```
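To see what autograd actually passes, you can print the tuple from inside backward; a small hedged sketch (`ShowNeeds` is illustrative, not from the text above):

```python
import torch
from torch.autograd import Function

class ShowNeeds(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        print(ctx.needs_input_grad)  # one bool per forward() input
        grad_a = grad_output * b if ctx.needs_input_grad[0] else None
        grad_b = grad_output * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)                      # requires_grad=False
ShowNeeds.apply(a, b).sum().backward()  # prints (True, False)
```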
### Pitfall 4: Using .data Instead of .detach()

`.data` bypasses autograd tracking incorrectly.

```python
# ❌ WRONG: Using .data
class UsingData(Function):
    @staticmethod
    def forward(ctx, input):
        # .data returns tensor without autograd tracking
        # But doesn't properly detach from computation graph
        ctx.save_for_backward(input.data)  # ❌ WRONG
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        input_data, = ctx.saved_tensors
        # May produce incorrect gradients
        return grad_output * 2

# ✅ CORRECT: Use .detach() or save normally
class UsingDetach(Function):
    @staticmethod
    def forward(ctx, input):
        # If you need to save without tracking gradient:
        # ctx.save_for_backward(input.detach())  # ✅ Properly detaches
        # Or just save normally if gradient tracking is OK:
        ctx.save_for_backward(input)  # ✅ Most common
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * 2

# When to detach in custom functions
class WhenToDetach(Function):
    @staticmethod
    def forward(ctx, input, target):
        # Save input normally (gradient needed)
        # Detach target (gradient not needed, just reference)
        ctx.save_for_backward(input, target.detach())
        loss = (input - target).pow(2).mean()
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # Compute gradient w.r.t. input only
        # (mean() divides by the element count, so the gradient must too)
        grad_input = 2 * (input - target) * grad_output / input.numel()
        return grad_input, None
```

### Pitfall 5: Modifying grad_output

Never modify grad_output in-place; it's used by other operations.

```python
# ❌ WRONG: Modifying grad_output in-place
class ModifyGradOutput(Function):
    @staticmethod
    def backward(ctx, grad_output):
        # ❌ WRONG: In-place modification
        grad_output *= 2
        return grad_output

# ✅ CORRECT: Create new tensor
class DontModifyGradOutput(Function):
    @staticmethod
    def backward(ctx, grad_output):
        # ✅ Create new tensor
        grad_input = grad_output * 2
        return grad_input

# Why it matters
def why_grad_output_matters():
    """Demonstrate why modifying grad_output breaks autograd."""
    # Consider: z = f(g(x))
    # Backward: dz/dx = dz/dg * dg/dx
    #
    # If f.backward() modifies grad_output (dz/dg),
    # then g.backward() receives wrong gradient!

    x = torch.randn(5, requires_grad=True)

    # g(x)
    y = x * 2

    # f(g(x)) - uses custom function that modifies grad_output
    z = BadFunction.apply(y)
    z.sum().backward()

    # x.grad is now WRONG because BadFunction modified grad_output
    # that was passed to y's backward
```

### Pitfall 6: Forgetting to Return None for Non-Tensor Arguments

Must return gradient for every forward() argument.

```python
# ❌ WRONG: Not enough return values
class NotEnoughReturns(Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        # 3 arguments (excluding ctx)
        ctx.save_for_backward(input)
        ctx.kernel_size = kernel_size
        ctx.stride = stride
        return some_operation(input, kernel_size, stride)

    @staticmethod
    def backward(ctx, grad_output):
        # ❌ WRONG: Only returns 1 value for 3 inputs
        return grad_output  # Crashes: expected 3 values

# ✅ CORRECT: Return gradient (or None) for each input
class EnoughReturns(Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        ctx.save_for_backward(input)
        ctx.kernel_size = kernel_size
        ctx.stride = stride
        return some_operation(input, kernel_size, stride)

    @staticmethod
    def backward(ctx, grad_output):
        # ✅ Return 3 values: grad for input, None for kernel_size, None for stride
        return grad_output, None, None

# Rule of thumb
def count_returns():
    """
    backward() must return one value per forward() argument (excluding ctx).
    forward(ctx, a, b, c, d=None)
    → backward must return (grad_a, grad_b, grad_c, grad_d)

    Use None for:
    - Non-tensor arguments (ints, strings, etc.)
    - Optional arguments that were None
    - Tensors that don't need gradients
    """
    pass
```

### Pitfall 7: Incorrect Broadcasting in Gradient

Gradient must account for broadcasting that occurred in forward.

```python
# ❌ WRONG: Doesn't handle broadcasting correctly
class WrongBroadcast(Function):
    @staticmethod
    def forward(ctx, input, weight):
        # input: (32, 64, 10, 10)
        # weight: (64, 1, 1) - broadcasts to (32, 64, 10, 10)
        ctx.save_for_backward(input, weight)
        output = input * weight  # Broadcasting happens
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output * weight  # ✅ This is fine
        # ❌ WRONG: grad_weight shape is (32, 64, 10, 10), should be (64, 1, 1)
        grad_weight = grad_output * input
        return grad_input, grad_weight  # Shape mismatch!

# ✅ CORRECT: Sum over broadcasted dimensions
class CorrectBroadcast(Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        ctx.input_shape = input.shape
        ctx.weight_shape = weight.shape
        output = input * weight
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output * weight

        # ✅ Sum over dimensions that were broadcasted
        grad_weight = grad_output * input
        # weight shape: (64, 1, 1), grad_weight current: (32, 64, 10, 10)
        # Sum over batch (0), height (2), width (3)
        grad_weight = grad_weight.sum(dim=(0, 2, 3), keepdim=True)
        # Now grad_weight shape: (1, 64, 1, 1)
        # Squeeze batch dimension: (64, 1, 1) ✅
        grad_weight = grad_weight.squeeze(0)

        return grad_input, grad_weight

# General broadcasting gradient pattern
def sum_to_shape(tensor, shape):
    """Sum tensor to match target shape (handles broadcasting)."""
    # Find dimensions that were added
    while tensor.dim() > len(shape):
        tensor = tensor.sum(0)
    # Find dimensions that were size 1 and got broadcasted
    for i, (t_dim, s_dim) in enumerate(zip(tensor.shape, shape)):
        if s_dim == 1 and t_dim > 1:
            tensor = tensor.sum(i, keepdim=True)
    return tensor

class GeneralBroadcast(Function):
    @staticmethod
    def backward(ctx, grad_output):
        input, param = ctx.saved_tensors
        grad_input = grad_output * param
        grad_param = grad_output * input
        # Sum to match param's original shape
        grad_param = sum_to_shape(grad_param, param.shape)
        return grad_input, grad_param
```

## Memory Efficiency Patterns

Custom functions enable memory optimizations not possible with standard autograd.

### Pattern 1: Gradient Checkpointing

Trade computation for memory by recomputing forward in backward.

```python
class CheckpointedFunction(Function):
    """
    Gradient checkpointing: Don't save activations, recompute in backward.

    Memory: O(1) instead of O(n) for n layers
    Time: 2x forward pass (once in forward, once in backward)
    """

    @staticmethod
    def forward(ctx, input, *args):
        # Save inputs and parameters, NOT intermediate activations
        ctx.save_for_backward(input, *args)

        # Compute forward pass (activations not saved)
        # For complex operations, this may compute many intermediate values
        output = expensive_computation(input, *args)

        # Intermediate activations are garbage collected
        # Saves memory!
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve inputs
        input, *args = ctx.saved_tensors

        # Recompute forward pass to get intermediate values
        # This time, track gradients
        with torch.enable_grad():
            # Detach input, then set requires_grad
            # (Required for computing gradients)
            input = input.detach().requires_grad_(True)

            # Recompute forward
            output = expensive_computation(input, *args)

        # Now compute gradients using autograd
        grad_input, = torch.autograd.grad(
            outputs=output,
            inputs=input,
            grad_outputs=grad_output,
            retain_graph=False
        )

        # Return gradient for input (and None for args if they're parameters)
        return (grad_input,) + (None,) * len(args)

# Example: Checkpointed Sequential Layers
class CheckpointedSequential(Function):
    @staticmethod
    def forward(ctx, input, *layers):
        """
        Forward through multiple layers without saving intermediate activations.

        Normal: Saves n-1 activations for n layers
        Checkpointed: Saves only input and parameters
        """
        ctx.layers = layers
        ctx.save_for_backward(input)

        # Forward through all layers
        output = input
        for layer in layers:
            output = layer(output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        layers = ctx.layers

        # Recompute forward to get intermediate activations
        with torch.enable_grad():
            input = input.detach().requires_grad_(True)

            # Forward pass again, this time tracking activations
            activations = [input]
            output = input
            for layer in layers:
                output = layer(output)
                activations.append(output)

            # Backward through layers
            grad = grad_output
            for i in reversed(range(len(layers))):
                # Compute gradient through layer i
                grad, = torch.autograd.grad(
                    outputs=activations[i+1],
                    inputs=activations[i],
                    grad_outputs=grad,
                    retain_graph=True
                )

        grad_input = grad

        # No gradients for layers (they're not inputs to forward)
        return (grad_input,) + (None,) * len(layers)

# PyTorch provides torch.utils.checkpoint.checkpoint for this
# But understanding the pattern helps for custom cases
```
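In practice, prefer the built-in utility over hand-rolling the pattern; a minimal usage sketch (the `use_reentrant` flag is available on recent PyTorch versions):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
x = torch.randn(32, 128, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of stored
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```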
""" @staticmethod def forward(ctx, input, weight, bias): # Compute intermediate values weighted = input @ weight activated = torch.relu(weighted) output = activated + bias # DON'T save everything: # ctx.save_for_backward(input, weight, bias, weighted, activated) # ❌ Saves 5 tensors # ✅ SAVE ONLY WHAT'S NEEDED: # Can recompute 'weighted' from input and weight # Can recompute 'activated' from weighted ctx.save_for_backward(input, weight, bias, activated) # ✅ Saves 4 tensors (or even fewer) # Or even more selective: # ctx.save_for_backward(input, weight, bias) # ctx.save_for_backward(activated > 0) # Save mask, not full tensor return output @staticmethod def backward(ctx, grad_output): input, weight, bias, activated = ctx.saved_tensors grad_input = grad_weight = grad_bias = None # Gradient through bias addition if ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) # Gradient through ReLU (need activation mask) grad_activated = grad_output.clone() grad_activated[activated <= 0] = 0 # Gradient through matmul if ctx.needs_input_grad[0]: grad_input = grad_activated @ weight.t() if ctx.needs_input_grad[1]: grad_weight = input.t() @ grad_activated return grad_input, grad_weight, grad_bias # Even more selective: Save boolean mask instead of full tensor class UltraSelective(Function): @staticmethod def forward(ctx, input, weight): output = torch.relu(input @ weight) # Instead of saving full 'output' tensor: # ctx.save_for_backward(input, weight, output) # Large memory # ✅ Save only boolean mask (1 bit per element vs 32 bits) ctx.save_for_backward(input, weight) ctx.relu_mask = (output > 0) # Boolean tensor (much smaller) return output @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors relu_mask = ctx.relu_mask # Use mask to apply ReLU gradient grad_weighted = grad_output * relu_mask.float() grad_input = grad_weighted @ weight.t() if ctx.needs_input_grad[0] else None grad_weight = input.t() @ grad_weighted if ctx.needs_input_grad[1] else None return grad_input, grad_weight ``` ### Pattern 3: Detaching Tensors That Don't Need Gradients Detach tensors that are used in forward but don't need gradients. ```python class DetachPattern(Function): """ Detach tensors that don't need gradient computation. """ @staticmethod def forward(ctx, input, target, weight): """ Compute loss between input and target. Target doesn't need gradients (it's labels). """ # Save input and weight (need gradients) ctx.save_for_backward(input, weight) # Detach target (doesn't need gradients) # This breaks the autograd connection, saving memory ctx.target = target.detach() # Compute weighted loss loss = ((input - target) ** 2 * weight).mean() return loss @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors target = ctx.target # Already detached # Compute gradients diff = input - target grad_input = None grad_weight = None if ctx.needs_input_grad[0]: grad_input = 2 * diff * weight * grad_output # No gradient for target (ctx.needs_input_grad[1] is False) if ctx.needs_input_grad[2]: grad_weight = (diff ** 2) * grad_output return grad_input, None, grad_weight # When to detach def detach_guidelines(): """ Detach tensors when: 1. They're labels/targets (no gradient needed) 2. They're constants (no gradient needed) 3. They're from non-differentiable sources 4. You explicitly don't want gradients to flow through them Don't detach when: 1. Gradient is needed for that tensor 2. 
    """
    pass
```

## Advanced Patterns

### Pattern 1: Double Backward (Second-Order Derivatives)

Support computing gradients of gradients.

```python
class DoubleBackwardFunction(Function):
    """
    Function that supports double backward for second-order derivatives.

    Example: Hessian computation, meta-learning, some regularization terms.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Forward: y = x^2
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # First derivative: dy/dx = 2x
        grad_input = 2 * input * grad_output
        # For double backward, need to return tensor that supports backward
        # Not just detached scalar
        return grad_input

    # For explicit double backward support (optional, autograd often handles it)
    @staticmethod
    def jvp(ctx, *grad_inputs):
        """
        Jacobian-vector product for forward-mode AD.
        Needed for some second-order derivative computations.
        """
        # Usually not needed; autograd handles it
        pass

# Test double backward
def test_double_backward():
    """Test that double backward works."""
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

    # First forward
    y = DoubleBackwardFunction.apply(x)

    # First backward: dy/dx
    grad_x, = torch.autograd.grad(y.sum(), x, create_graph=True)
    # grad_x = 2x = [2, 4, 6]

    # Second backward: d(dy/dx)/dx = d(2x)/dx = 2
    grad_grad_x, = torch.autograd.grad(grad_x.sum(), x)
    # grad_grad_x = [2, 2, 2]

    print(f"First derivative: {grad_x}")        # [2, 4, 6]
    print(f"Second derivative: {grad_grad_x}")  # [2, 2, 2]

    assert torch.allclose(grad_x, 2 * x)
    assert torch.allclose(grad_grad_x, torch.ones_like(x) * 2)
    print("✅ Double backward works!")

test_double_backward()

# Example: Function where double backward matters
class HessianVectorProduct(Function):
    """
    Efficiently compute Hessian-vector product: H @ v
    where H = ∇²f(x) is the Hessian matrix.

    Used in: second-order optimization, meta-learning.
    """

    @staticmethod
    def forward(ctx, input, vector):
        ctx.save_for_backward(input, vector)
        # Placeholder forward (actual computation in backward)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input, vector = ctx.saved_tensors
        # Compute gradient
        grad_input = grad_output
        # This gradient can be backpropagated through again
        # for Hessian computation
        return grad_input, None

# Using double backward for Hessian
def compute_hessian(f, x):
    """
    Compute Hessian of scalar function f at point x.

    Uses double backward: H[i,j] = ∂²f/∂x_i∂x_j
    """
    # First derivative
    grad_x, = torch.autograd.grad(f(x), x, create_graph=True)

    # Second derivative (Hessian)
    hessian = []
    for i in range(x.shape[0]):
        grad_grad_x, = torch.autograd.grad(
            grad_x[i], x, retain_graph=True
        )
        hessian.append(grad_grad_x)

    return torch.stack(hessian)
```
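A quick sanity check of `compute_hessian`: for f(x) = Σx², the Hessian is 2I, which the helper should reproduce exactly:

```python
x = torch.randn(3, requires_grad=True)
H = compute_hessian(lambda v: (v ** 2).sum(), x)  # helper defined above
print(torch.allclose(H, 2 * torch.eye(3)))  # True: Hessian of sum(x^2) is 2I
```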
""" return torch.clamp(grad, -1, 1) def gradient_noise_hook(grad): """Add noise to gradients (for regularization).""" noise = torch.randn_like(grad) * 0.01 return grad + noise def gradient_logging_hook(grad): """Log gradient statistics.""" print(f"Gradient: mean={grad.mean():.6f}, std={grad.std():.6f}, max={grad.abs().max():.6f}") return grad # Return unchanged # Using hooks def use_hooks(): """Example of using gradient hooks.""" input = torch.randn(10, 10, requires_grad=True) weight = torch.randn(10, 10, requires_grad=True) # Register hooks input.register_hook(gradient_clipping_hook) weight.register_hook(gradient_logging_hook) # Forward and backward output = MyFunction.apply(input, weight) loss = output.sum() loss.backward() # Hooks are applied during backward # input.grad is clipped to [-1, 1] # weight.grad statistics are logged # Hooks in custom functions class FunctionWithHook(Function): @staticmethod def forward(ctx, input, clip_grad=False): ctx.clip_grad = clip_grad ctx.save_for_backward(input) return input * 2 @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors grad_input = grad_output * 2 # Apply custom modification based on context if ctx.clip_grad: grad_input = torch.clamp(grad_input, -1, 1) return grad_input, None # Removing hooks def manage_hooks(): """Add and remove hooks.""" tensor = torch.randn(5, requires_grad=True) # Add hook (returns handle) hook_handle = tensor.register_hook(gradient_clipping_hook) # Use tensor loss = (tensor ** 2).sum() loss.backward() # Hook is applied # Remove hook hook_handle.remove() # Hook no longer applied in subsequent backwards tensor.grad.zero_() loss = (tensor ** 2).sum() loss.backward() # Hook NOT applied ``` ### Pattern 3: Custom Gradient for Part of Computation Stop gradients or customize them for specific operations. ```python class StopGradient(Function): """ Stop gradient flow (like tf.stop_gradient or tensor.detach()). Forward: pass through Backward: return None (no gradient) """ @staticmethod def forward(ctx, input): return input @staticmethod def backward(ctx, grad_output): # Don't pass gradient through return None # Usage: Stop gradient flow def stop_gradient_example(): x = torch.randn(5, requires_grad=True) # y doesn't get gradients from z y = StopGradient.apply(x) z = y ** 2 z.sum().backward() assert x.grad is None # No gradient flowed to x class StraightThroughEstimator(Function): """ Straight-through estimator for non-differentiable operations. Forward: non-differentiable operation (e.g., binarization) Backward: pretend it was identity (pass gradient through) Used for: quantization, binarization, discrete operations. """ @staticmethod def forward(ctx, input): # Non-differentiable: binarize to {-1, 1} return torch.sign(input) @staticmethod def backward(ctx, grad_output): # Pretend forward was identity: gradient passes through unchanged return grad_output # Usage: Train networks with binary weights def straight_through_example(): """ Train a network with binary weights using straight-through estimator. 
""" weight = torch.randn(10, 10, requires_grad=True) input = torch.randn(32, 10) # Binarize weight for forward pass binary_weight = StraightThroughEstimator.apply(weight) # binary_weight ∈ {-1, 1} # Use binary weight output = input @ binary_weight loss = output.sum() loss.backward() # weight.grad exists (even though sign() isn't differentiable) # Gradient passed through as if sign() was identity assert weight.grad is not None class CustomGradientScale(Function): """ Scale gradient by a factor without changing forward. Forward: pass through Backward: scale gradient by alpha Used for: gradient reversal layers (adversarial training), controlling gradient flow in different branches. """ @staticmethod def forward(ctx, input, alpha): ctx.alpha = alpha return input @staticmethod def backward(ctx, grad_output): # Scale gradient return grad_output * ctx.alpha, None # Usage: Gradient reversal layer def gradient_reversal_layer(input, alpha=-1.0): """ Reverses gradients (multiplies by -1). Used in domain adaptation for adversarial training. """ return CustomGradientScale.apply(input, alpha) ``` ## Complete Real-World Examples ### Example 1: Custom Swish Activation with Learnable Beta ```python import torch from torch.autograd import Function, gradcheck import torch.nn as nn class SwishFunction(Function): """ Custom Swish activation: f(x) = x * sigmoid(β * x) where β is a learnable parameter. Forward: y = x * σ(βx) Backward: dy/dx = σ(βx) + x * σ(βx) * (1 - σ(βx)) * β dy/dβ = x² * σ(βx) * (1 - σ(βx)) """ @staticmethod def forward(ctx, input, beta): """ Args: input: Input tensor beta: Learnable scaling parameter (scalar tensor) Returns: output: Swish activation output """ ctx.save_for_backward(input, beta) # Compute sigmoid(beta * input) sigmoid_beta_input = torch.sigmoid(beta * input) # Save for backward (more efficient than recomputing) ctx.save_for_backward(input, beta, sigmoid_beta_input) # f(x) = x * sigmoid(beta * x) output = input * sigmoid_beta_input return output @staticmethod def backward(ctx, grad_output): """ Compute gradients using chain rule. """ input, beta, sigmoid_beta_input = ctx.saved_tensors grad_input = grad_beta = None # Gradient w.r.t. input if ctx.needs_input_grad[0]: # d/dx[x * σ(βx)] = σ(βx) + x * σ'(βx) * β # where σ'(z) = σ(z) * (1 - σ(z)) sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input) grad_input = grad_output * ( sigmoid_beta_input + input * sigmoid_derivative * beta ) # Gradient w.r.t. beta if ctx.needs_input_grad[1]: # d/dβ[x * σ(βx)] = x * σ'(βx) * x = x² * σ(βx) * (1 - σ(βx)) sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input) grad_beta = grad_output * (input ** 2) * sigmoid_derivative # Sum over all elements (beta is scalar) grad_beta = grad_beta.sum() return grad_input, grad_beta class Swish(nn.Module): """ Swish activation module with learnable beta parameter. 
""" def __init__(self, beta=1.0): super().__init__() self.beta = nn.Parameter(torch.tensor(beta)) def forward(self, input): return SwishFunction.apply(input, self.beta) # Test with gradcheck def test_swish(): """Verify Swish implementation with gradcheck.""" print("Testing Swish activation...") # Test 1: Basic gradcheck input = torch.randn(10, 5, dtype=torch.double, requires_grad=True) beta = torch.tensor(1.0, dtype=torch.double, requires_grad=True) assert gradcheck( SwishFunction.apply, (input, beta), eps=1e-6, atol=1e-4, raise_exception=True ), "Swish gradcheck failed" print("✅ Basic gradcheck passed") # Test 2: Different beta values for beta_val in [0.5, 1.0, 2.0]: beta = torch.tensor(beta_val, dtype=torch.double, requires_grad=True) assert gradcheck( SwishFunction.apply, (input, beta), eps=1e-6, atol=1e-4, raise_exception=True ), f"Swish gradcheck failed for beta={beta_val}" print("✅ Multiple beta values passed") # Test 3: Use in module module = Swish(beta=1.0) input_single = torch.randn(32, 128, requires_grad=True) output = module(input_single) loss = output.sum() loss.backward() assert input_single.grad is not None assert module.beta.grad is not None print("✅ Module usage works") print("\n🎉 Swish activation fully tested!") test_swish() # Usage example def use_swish_in_model(): """Use Swish in a neural network.""" model = nn.Sequential( nn.Linear(784, 256), Swish(beta=1.0), # Learnable beta nn.Linear(256, 128), Swish(beta=1.0), nn.Linear(128, 10) ) # Train model... # Beta parameters will be learned along with weights ``` ### Example 2: Numerically Stable LogSumExp ```python class LogSumExp(Function): """ Numerically stable log-sum-exp operation. Forward: log(sum(exp(x_i))) Uses max trick: log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x)))) Backward: softmax(x_i) """ @staticmethod def forward(ctx, input, dim): """ Args: input: Input tensor dim: Dimension to sum over Returns: logsumexp: log(sum(exp(input))) along dim """ # Max trick for numerical stability max_val, _ = input.max(dim=dim, keepdim=True) input_shifted = input - max_val # Compute log-sum-exp sumexp = torch.exp(input_shifted).sum(dim=dim, keepdim=True) logsumexp = torch.log(sumexp) + max_val # Save softmax for backward softmax = torch.exp(input_shifted) / sumexp ctx.save_for_backward(softmax) ctx.dim = dim return logsumexp.squeeze(dim) @staticmethod def backward(ctx, grad_output): """ Gradient of log-sum-exp is softmax. """ softmax, = ctx.saved_tensors # Expand grad_output to match softmax shape grad_output_expanded = grad_output.unsqueeze(ctx.dim) # Gradient: softmax * grad_output grad_input = softmax * grad_output_expanded return grad_input, None # Test def test_logsumexp(): """Test LogSumExp implementation.""" print("Testing LogSumExp...") input = torch.randn(10, 5, dtype=torch.double, requires_grad=True) dim = 1 # gradcheck assert gradcheck( lambda x: LogSumExp.apply(x, dim), (input,), eps=1e-6, atol=1e-4, raise_exception=True ), "LogSumExp gradcheck failed" print("✅ LogSumExp gradcheck passed") # Compare with PyTorch's implementation custom_result = LogSumExp.apply(input, dim) torch_result = torch.logsumexp(input, dim=dim) assert torch.allclose(custom_result, torch_result, atol=1e-5), \ "LogSumExp doesn't match PyTorch" print("✅ LogSumExp matches PyTorch implementation") print("\n🎉 LogSumExp fully tested!") test_logsumexp() ``` ### Example 3: Fused Linear + ReLU (Memory Efficient) ```python class FusedLinearReLU(Function): """ Fused linear + ReLU operation. 
### Example 3: Fused Linear + ReLU (Memory Efficient)

```python
class FusedLinearReLU(Function):
    """
    Fused linear + ReLU operation.
    Saves memory by not materializing intermediate activations.

    Forward: ReLU(X @ W + b)
    Memory: Only saves mask (boolean) and input/weights, not intermediate
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """
        Args:
            input: (batch, in_features)
            weight: (out_features, in_features)
            bias: (out_features,)

        Returns:
            output: (batch, out_features)
        """
        # Compute linear
        linear_output = input.mm(weight.t())
        if bias is not None:
            linear_output += bias

        # Apply ReLU and save mask (not full tensor!)
        output = torch.relu(linear_output)
        relu_mask = (linear_output > 0)  # Boolean mask (1 byte/element vs 4-8 for floats)

        # Save only input, weight, bias, and mask
        # NOT saving linear_output (saves memory)
        ctx.save_for_backward(input, weight, bias)
        ctx.relu_mask = relu_mask

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward through ReLU and linear.
        """
        input, weight, bias = ctx.saved_tensors
        relu_mask = ctx.relu_mask

        grad_input = grad_weight = grad_bias = None

        # Gradient through ReLU (use mask)
        grad_linear = grad_output * relu_mask.float()

        # Gradient through linear
        if ctx.needs_input_grad[0]:
            grad_input = grad_linear.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_linear.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_linear.sum(0)

        return grad_input, grad_weight, grad_bias

# Test
def test_fused_linear_relu():
    """Test fused operation."""
    print("Testing FusedLinearReLU...")

    input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    weight = torch.randn(15, 10, dtype=torch.double, requires_grad=True)
    bias = torch.randn(15, dtype=torch.double, requires_grad=True)

    # gradcheck
    assert gradcheck(
        FusedLinearReLU.apply,
        (input, weight, bias),
        eps=1e-6, atol=1e-4, raise_exception=True
    ), "FusedLinearReLU gradcheck failed"
    print("✅ FusedLinearReLU gradcheck passed")

    # Compare with separate operations
    fused_output = FusedLinearReLU.apply(input, weight, bias)
    separate_output = torch.relu(input.mm(weight.t()) + bias)

    assert torch.allclose(fused_output, separate_output, atol=1e-5), \
        "Fused output doesn't match separate operations"
    print("✅ Fused matches separate operations")

    print("\n🎉 FusedLinearReLU fully tested!")

test_fused_linear_relu()
```

## Debugging Custom Functions

### Systematic Debugging Workflow

When your custom function has bugs, follow this workflow:

```python
def debug_custom_function():
    """
    Step-by-step debugging of custom autograd functions.
""" print("=== DEBUGGING CUSTOM AUTOGRAD FUNCTION ===\n") # Step 1: Verify forward pass works print("Step 1: Testing forward pass...") try: input = torch.randn(5, 3, dtype=torch.double) weight = torch.randn(4, 3, dtype=torch.double) output = MyCustomFunction.apply(input, weight) print(f"✅ Forward pass works") print(f" Input shape: {input.shape}") print(f" Weight shape: {weight.shape}") print(f" Output shape: {output.shape}") print(f" Output contains NaN: {torch.isnan(output).any()}") print(f" Output contains Inf: {torch.isinf(output).any()}") # Check expected shape expected_shape = (5, 4) # Based on your operation assert output.shape == expected_shape, \ f"Wrong output shape: {output.shape} != {expected_shape}" print(f" Output shape correct: {expected_shape}") except Exception as e: print(f"❌ Forward pass failed: {e}") return print() # Step 2: Verify backward runs without error print("Step 2: Testing backward pass...") try: input = torch.randn(5, 3, dtype=torch.double, requires_grad=True) weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True) output = MyCustomFunction.apply(input, weight) loss = output.sum() loss.backward() print(f"✅ Backward pass runs") print(f" Input grad shape: {input.grad.shape}") print(f" Weight grad shape: {weight.grad.shape}") # Check gradient shapes assert input.grad.shape == input.shape, \ f"Input gradient shape mismatch: {input.grad.shape} != {input.shape}" assert weight.grad.shape == weight.shape, \ f"Weight gradient shape mismatch: {weight.grad.shape} != {weight.shape}" print(f" Gradient shapes correct") # Check for NaN/Inf assert not torch.isnan(input.grad).any(), "Input gradient contains NaN" assert not torch.isnan(weight.grad).any(), "Weight gradient contains NaN" assert not torch.isinf(input.grad).any(), "Input gradient contains Inf" assert not torch.isinf(weight.grad).any(), "Weight gradient contains Inf" print(f" Gradients are finite") except Exception as e: print(f"❌ Backward pass failed: {e}") return print() # Step 3: Check gradient magnitudes print("Step 3: Checking gradient magnitudes...") print(f" Input grad - mean: {input.grad.mean():.6f}, std: {input.grad.std():.6f}, max: {input.grad.abs().max():.6f}") print(f" Weight grad - mean: {weight.grad.mean():.6f}, std: {weight.grad.std():.6f}, max: {weight.grad.abs().max():.6f}") # Reasonable gradient magnitudes (problem-dependent) if input.grad.abs().max() > 1e6: print(f" ⚠️ Input gradient very large (may indicate bug)") if input.grad.abs().max() < 1e-6: print(f" ⚠️ Input gradient very small (may indicate vanishing gradient)") print() # Step 4: Manual numerical gradient check for one element print("Step 4: Manual numerical gradient check...") input_test = torch.randn(3, 2, dtype=torch.double, requires_grad=True) weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True) # Analytical gradient output = MyCustomFunction.apply(input_test, weight_test) loss = output.sum() loss.backward() analytical_grad = weight_test.grad[0, 0].item() # Numerical gradient eps = 1e-6 weight_plus = weight_test.clone().detach() weight_plus[0, 0] += eps output_plus = MyCustomFunction.apply(input_test, weight_plus) loss_plus = output_plus.sum() weight_minus = weight_test.clone().detach() weight_minus[0, 0] -= eps output_minus = MyCustomFunction.apply(input_test, weight_minus) loss_minus = output_minus.sum() numerical_grad = ((loss_plus - loss_minus) / (2 * eps)).item() diff = abs(analytical_grad - numerical_grad) print(f" Analytical gradient: {analytical_grad:.10f}") print(f" Numerical gradient: 
{numerical_grad:.10f}") print(f" Absolute difference: {diff:.10e}") if diff < 1e-4: print(f" ✅ Small difference - gradient likely correct") else: print(f" ❌ Large difference - gradient likely WRONG") print() # Step 5: Full gradcheck print("Step 5: Running full gradcheck...") try: input_check = torch.randn(10, 5, dtype=torch.double, requires_grad=True) weight_check = torch.randn(8, 5, dtype=torch.double, requires_grad=True) result = gradcheck( MyCustomFunction.apply, (input_check, weight_check), eps=1e-6, atol=1e-4, raise_exception=True ) print(f" ✅ gradcheck PASSED!") print(f"\n🎉 All checks passed! Function is correct.") except RuntimeError as e: print(f" ❌ gradcheck FAILED") print(f" Error: {e}") print(f"\n Debug hints:") print(f" - Check gradient computation formulas") print(f" - Verify all transposes are correct") print(f" - Ensure shapes match everywhere") print(f" - Check for in-place operations") print(f" - Verify saved tensors are correct") # Run debugging debug_custom_function() ``` ### Common Error Messages and Solutions ```python """ ERROR 1: "one of the variables needed for gradient computation has been modified by an inplace operation" CAUSE: In-place operation in forward or backward SOLUTION: Replace in-place ops with non-in-place versions - Replace: tensor[mask] = value - With: tensor = tensor.clone(); tensor[mask] = value - Or use: tensor * mask instead of masking ERROR 2: "grad can be implicitly created only for scalar outputs" CAUSE: Calling .backward() on non-scalar tensor without grad_output SOLUTION: Either sum to scalar or provide grad_output - loss = output.sum(); loss.backward() - Or: output.backward(torch.ones_like(output)) ERROR 3: "Expected to get X gradient(s) for backward, but got Y" CAUSE: backward() returns wrong number of gradients SOLUTION: Return one gradient per forward() argument (excluding ctx) - forward(ctx, a, b, c) → backward must return (grad_a, grad_b, grad_c) - Use None for arguments that don't need gradients ERROR 4: "Sizes of tensors must match except in dimension X" CAUSE: Shape mismatch in gradient computation SOLUTION: Ensure grad_input.shape == input.shape - Print shapes to debug: print(f"grad shape: {grad.shape}, input shape: {input.shape}") - Handle broadcasting by summing over broadcasted dimensions ERROR 5: "RuntimeError: Function returned an invalid gradient at index X - got ... but expected shape ..." CAUSE: Gradient shape doesn't match input shape SOLUTION: Verify gradient shape matches input exactly - assert grad_input.shape == input.shape - Use .view(), .reshape(), .expand() to fix shape ERROR 6: "gradcheck failed" CAUSE: Analytical gradient ≠ numerical gradient SOLUTION: Debug gradient computation - Check math formulas (derivatives) - Verify transposes in matrix operations - Test manually with small tensors - Run debug_custom_function() above ERROR 7: "AttributeError: 'Context' object has no attribute 'saved_tensors'" CAUSE: Accessing saved_tensors in forward (only available in backward) SOLUTION: Only access saved_tensors in backward() - forward: ctx.save_for_backward(...) 
- backward: ctx.saved_tensors

ERROR 8: "TypeError: save_for_backward can only save variables, but argument X is of type Y"
CAUSE: Passing non-tensors to save_for_backward()
SOLUTION: Only save tensors with save_for_backward(), use attributes for others
- ctx.save_for_backward(tensor1, tensor2)  # Tensors only
- ctx.some_value = non_tensor_data  # Non-tensors as attributes
"""
```
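ERRORS 4 and 5 often trace back to broadcasting: if forward broadcasts an input, backward must sum the gradient over the broadcast dimensions to restore that input's exact shape. A minimal sketch (the `AddBias` function here is illustrative, not from the examples above):

```python
import torch
from torch.autograd import Function, gradcheck

class AddBias(Function):
    """y = input + bias, where bias (features,) broadcasts over the batch."""

    @staticmethod
    def forward(ctx, input, bias):
        # (batch, features) + (features,) broadcasts bias over the batch
        return input + bias

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output has shape (batch, features). The bias gradient must
        # match bias's shape (features,), so sum over the broadcast (batch)
        # dimension - returning grad_output unchanged would trigger the
        # "invalid gradient ... expected shape" error above.
        return grad_output, grad_output.sum(0)

input = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
bias = torch.randn(4, dtype=torch.double, requires_grad=True)
assert gradcheck(AddBias.apply, (input, bias), eps=1e-6, atol=1e-4)
print("✅ Broadcast gradient reduced to the right shape")
```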
## When NOT to Use Custom Functions

Recognize when custom functions are unnecessary:

```python
# DON'T: Wrapping simple PyTorch operations
class UnnecessaryAdd(Function):  # ❌ Pointless
    @staticmethod
    def forward(ctx, a, b):
        return a + b

    @staticmethod
    def backward(ctx, grad):
        return grad, grad

# DO: Use PyTorch directly
output = a + b  # ✅ Autograd handles this

# DON'T: Reimplementing existing activations
class UnnecessaryReLU(Function):  # ❌ Use torch.relu
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * (input > 0).float()

# DO: Use built-in
output = torch.relu(input)  # ✅ Optimized C++ implementation

# DON'T: Operations that compose from standard ops
class UnnecessaryGELU(Function):  # ❌ Can compose from existing ops
    @staticmethod
    def forward(ctx, input):
        # GELU = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        # This is composed entirely from standard ops
        ctx.save_for_backward(input)
        return 0.5 * input * (1 + torch.tanh(...))

    @staticmethod
    def backward(ctx, grad_output):
        # Complex gradient computation
        ...

# DO: Let autograd handle composition
import math

def gelu(input):  # ✅ Autograd computes gradients automatically
    return 0.5 * input * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (input + 0.044715 * input ** 3)
    ))

# Or even better: use built-in
output = torch.nn.functional.gelu(input)  # ✅ Most efficient

# WHEN TO USE: True custom operations
class ClippedGradientLinear(Function):  # ✅ Custom gradient behavior
    """
    Linear operation with clipped gradients.
    Can't compose from standard ops - requires custom backward.
    """
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Custom: clip gradients (can't do with composition)
        grad_input = torch.clamp(grad_output @ weight, -1, 1)
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight
```

## gradcheck Time Investment

**Reality check on "no time for gradcheck"**: under deadline pressure, this rationalization is powerful but wrong:

```python
# Time cost analysis
gradcheck_time = "< 1 second"     # Typical operation
debugging_time = "hours to days"  # Finding gradient bugs without gradcheck

# ROI calculation
time_investment = 1      # second
time_saved = 3600 * 24   # potentially a full day, in seconds
roi = time_saved / time_investment  # 86,400x return!
```

**Under deadline pressure, correct workflow**:
1. Implement Function correctly (5-10 minutes)
2. Run gradcheck (1 second)
3. Deploy with confidence

**NOT**:
1. Skip verification (saves 1 second)
2. Deploy immediately
3. Spend days debugging gradient bugs in production (costs hours/days)

**Time spent on gradcheck**: <1 second (negligible)
**Time saved from catching bugs early**: hours to days (enormous)

**The math is clear**: always run gradcheck, even under extreme time pressure.

## Multiple Custom Functions Workflow

**When implementing multiple functions**:

✅ **CORRECT: One-at-a-time**
1. Implement Function 1
2. Test Function 1 with gradcheck
3. Verify Function 1 passes
4. Move to Function 2
5. Repeat for each function

❌ **WRONG: Batch approach**
1. Implement all functions
2. Test all together
3. Debug a mess of overlapping bugs
4. Waste hours figuring out which function is broken
5. Fix bugs one at a time anyway (should have started here)

**Why one-at-a-time wins**:
- Bugs caught immediately (right after you wrote the code, while it's easy to debug)
- Know each function works before building on it
- Same total time, but bugs stay isolated (no "which function broke?" confusion)
- Build confidence incrementally
- Can use earlier functions while implementing later ones

**Example**:

```python
# Implementing 5 activations

# ✅ CORRECT
SwishBeta() → test → ✅ → Mish() → test → ✅ → GELU() → test → ✅ ...
# Each bug found immediately after writing that function

# ❌ WRONG
SwishBeta() + Mish() + GELU() + ELU() + SELU() → test all
# Bug found in one, but which? Have to debug all to find it
# 5x the debugging complexity
```

## Handling Approximate Gradients

**Special case**: external libraries with approximate gradients (finite differences, Monte Carlo estimates).

When wrapping external code with approximate gradients:

### Workflow for Approximate Gradients

```python
import torch
from torch.autograd import Function

class ApproximateGradientWrapper(Function):
    """
    Wrap external library with approximate gradients.

    Key insights:
    1. Standard gradcheck WILL fail (approximate ≠ analytical)
    2. But can still verify wrapper implementation
    3. Must quantify gradient quality
    4. Must assess if quality is acceptable
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Placeholder for the wrapped library's forward computation
        return external_library_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Library provides approximate gradient
        grad_input = external_library_gradient(input, grad_output)
        return grad_input

# Can't run standard gradcheck, but CAN verify:
def test_approximate_gradient_wrapper():
    """Verification workflow for approximate gradients."""

    # 1. Verify wrapper mechanics (forward/backward run)
    input = torch.randn(5, requires_grad=True)
    output = ApproximateGradientWrapper.apply(input)
    output.sum().backward()
    assert input.grad is not None
    print("✅ Wrapper mechanics work")

    # 2. Quantify gradient quality (compare with numerical)
    # Use double precision: float32 finite differences at eps=1e-6 are
    # dominated by rounding error
    input_test = torch.randn(3, dtype=torch.double, requires_grad=True)
    output = ApproximateGradientWrapper.apply(input_test)
    output.sum().backward()
    library_gradient = input_test.grad.clone()

    # Compute our own numerical gradient
    # (compute_numerical_gradient is your own central-difference helper,
    # not shown here)
    eps = 1e-6
    numerical_gradient = compute_numerical_gradient(input_test, eps)

    # Measure error
    error = (library_gradient - numerical_gradient).abs().max()
    print(f"Gradient error: {error:.6e}")

    # 3. Assess acceptability
    if error < 1e-3:
        print("✅ Approximate gradient high quality (good enough)")
    elif error < 1e-2:
        print("⚠️ Approximate gradient has noticeable error (may affect optimization)")
    else:
        print("❌ Approximate gradient poor quality (likely to cause issues)")
        raise ValueError("Gradient quality unacceptable")

    # 4. Document in code
    print("\n📝 Documentation:")
    print(f"   - Gradients are approximate (error: {error:.6e})")
    print("   - Standard gradcheck will fail (expected)")
    print("   - Wrapper implementation verified correct")
    print("   - Gradient quality measured and acceptable")

    return error

# Run verification
gradient_error = test_approximate_gradient_wrapper()
```

### Key Points for Approximate Gradients

1. **Standard gradcheck will fail** - This is expected (approximate ≠ analytical)
2. **Test wrapper implementation** - Verify mechanics (forward/backward run)
3.
**Quantify gradient quality** - Measure error vs numerical gradient 4. **Assess acceptability** - Is error tolerable for your use case? 5. **Document limitations** - Record gradient error magnitude in code/docs **Don't skip verification** - Adapt it to approximate case. **"Good enough" requires evidence** - Measure error, don't assume. ## Rationalization Resistance | Rationalization | Reality | Counter-Response | |----------------|---------|------------------| | "Autograd will figure it out" | Only for standard ops; custom ops need Function | Use Function for non-standard operations. Autograd needs explicit backward implementation. | | "Gradient looks mathematically correct" | Implementation bugs invisible without testing | Always run gradcheck. Math correctness ≠ implementation correctness. | | "gradcheck is slow, skip for speed" | Catches bugs early; debugging later costs more | gradcheck is fast (<1s). Finding gradient bugs without it takes hours/days. | | "No time for gradcheck, deadline NOW" | gradcheck takes <1s, debugging takes hours | 1 second now saves hours of production debugging. Always run gradcheck. | | "Too complex for me" | Pattern is standardized; template works | Follow template. Thousands of successful implementations exist. | | "In-place is more efficient" | Breaks autograd graph; causes crashes | Never use in-place in custom functions. Memory savings negligible, bugs catastrophic. | | "Shape will probably work out" | Must match exactly; no flexibility | Gradient shape must equal input shape exactly. Verify with assertions. | | "ctx details don't matter much" | Incorrect usage breaks everything | ctx.save_for_backward() is mandatory for tensors. Attributes break memory tracking. | | "My manual test is good enough" | Misses edge cases gradcheck catches | Manual tests catch obvious bugs. gradcheck catches subtle numerical errors. | | "Batch test all functions together" | Overlapping bugs hard to debug | Test one at a time. Bugs isolated immediately, same total time. | | "Don't need needs_input_grad check" | Wastes computation; slower training | Always check needs_input_grad. Free optimization, no downside. | | "Approximate gradients, can't verify" | Can verify wrapper and measure quality | Adapt verification. Test wrapper mechanics, quantify error, assess acceptability. | | "Second-order derivatives too advanced" | Same pattern as first-order | Not advanced. Same template, test with gradgradcheck. Accessible to all. | | "Can skip double backward support" | Breaks higher-order derivatives if needed | If you might need Hessian/meta-learning, support double backward from start. | | "Detach doesn't matter here" | Controls gradient flow; critical | Understand when to detach. Impacts what gets gradients. | | "I'll verify gradients during training" | Training metrics hide gradient bugs | Verify before training. Gradient bugs cause subtle issues (slow convergence, wrong behavior). | | "Test in production, faster iteration" | Production debugging catastrophic | Test before deployment. Production gradient bugs cause model failure. | ## Red Flags Checklist **Stop and verify if you see:** 1. ⚠️ **Not using torch.autograd.Function** for custom operations - Writing normal functions for truly custom ops (external code, novel math) 2. ⚠️ **No gradcheck before using** the function - Skipping numerical verification - "Testing during training" instead of proper gradcheck 3. 
⚠️ **In-place operations** in forward or backward
   - `tensor[mask] = value`, `tensor += other`, `tensor.mul_(other)`
   - Modifying `grad_output` in backward

4. ⚠️ **Wrong number of return values** from backward
   - Not returning gradient for each forward() input
   - Missing None for non-tensor arguments

5. ⚠️ **Not using ctx.save_for_backward()** for tensors
   - Saving tensors as `ctx.tensor = tensor` instead
   - Saving non-tensors with save_for_backward()

6. ⚠️ **Gradient shape doesn't match** input shape
   - Not verifying `grad_input.shape == input.shape`
   - Forgetting to expand/sum for broadcasting

7. ⚠️ **Missing needs_input_grad checks**
   - Always computing all gradients even when not needed
   - Not checking `ctx.needs_input_grad[i]` before expensive computation

8. ⚠️ **Using .data instead of .detach()**
   - Accessing the `.data` attribute
   - Not understanding the difference between .data and .detach() (see the sketch after this list)

9. ⚠️ **Accessing ctx.saved_tensors in forward**
   - Trying to use saved tensors before backward
   - Not understanding the ctx lifecycle

10. ⚠️ **Modifying saved tensors** in backward
    - In-place operations on tensors from ctx.saved_tensors
    - Breaking the gradient graph

11. ⚠️ **Ignoring gradcheck failures**
    - "Gradient close enough"
    - Not investigating why gradcheck failed

12. ⚠️ **Using custom Function for standard operations**
    - Reimplementing built-in operations unnecessarily
    - Not checking if PyTorch already provides it

13. ⚠️ **Batching implementation without incremental testing**
    - Implementing multiple functions before testing any
    - The "will test them all together" approach

14. ⚠️ **Skipping gradcheck under time pressure**
    - "Deadline is tight, verify later"
    - "No time for gradcheck"

15. ⚠️ **Assuming approximate gradients can't be verified**
    - "Library provides gradients, can't test"
    - Not measuring gradient quality

16. ⚠️ **Avoiding second-order derivatives due to perceived complexity**
    - "Too advanced for me"
    - Not attempting gradgradcheck

17. ⚠️ **Deploying to production without verification**
    - "Test with real data"
    - Skipping numerical verification
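Red flag 8 deserves a concrete demonstration, because `.data` fails silently while `.detach()` fails loudly. A minimal sketch (the `exp` example is illustrative; it assumes a modern PyTorch where saved tensors carry version counters):

```python
import torch

# .data and .detach() both return a tensor sharing storage with the
# original, but only .detach() participates in autograd's version
# tracking of saved tensors.

# Case 1: in-place edit through .data - silently WRONG gradient
x = torch.ones(3, requires_grad=True)
y = x.exp()          # autograd saves y to compute exp's backward
y.data.mul_(2)       # bypasses version tracking, corrupts the saved tensor
y.sum().backward()   # runs without complaint...
print(x.grad)        # ...but prints 2*e, not e - silently wrong

# Case 2: same edit through .detach() - loudly DETECTED
x2 = torch.ones(3, requires_grad=True)
y2 = x2.exp()
y2.detach().mul_(2)  # version counter is bumped
try:
    y2.sum().backward()
except RuntimeError as e:
    print(f"Caught: {e}")  # "modified by an inplace operation" error
```

A silently wrong gradient is far more dangerous than a crash, which is why `.detach()` is the correct choice.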
## Summary Checklist

Before deploying a custom autograd function:

- [ ] Verified custom Function is actually needed (can't compose from standard ops)
- [ ] Implemented forward() correctly (saves needed tensors/data)
- [ ] Used ctx.save_for_backward() for ALL tensors
- [ ] Saved non-tensor data as ctx attributes
- [ ] Implemented backward() with correct gradient formulas
- [ ] Verified gradient shape matches input shape exactly
- [ ] Returned gradient for EVERY forward() input (None if not needed)
- [ ] Added needs_input_grad checks for expensive computations
- [ ] Avoided ALL in-place operations
- [ ] **Ran gradcheck and verified it PASSES**
- [ ] Tested with different input shapes
- [ ] Tested with optional parameters (if any)
- [ ] Verified no NaN/Inf in outputs or gradients
- [ ] Checked gradient magnitudes are reasonable
- [ ] Tested double backward if needed (see the gradgradcheck sketch at the end of this skill)
- [ ] Added documentation explaining when to use this function
- [ ] Considered memory efficiency (detach, selective saving)

## Final Notes

**Custom autograd functions are powerful but require precision**:

1. Use them when truly needed (custom ops, external code, memory optimization, custom gradients)
2. Follow the template pattern (don't reinvent)
3. **Always run gradcheck** (non-negotiable)
4. Understand ctx rules (save_for_backward for tensors, attributes for non-tensors)
5. Verify shapes (gradient must match input)
6. Avoid in-place operations (they break autograd)
7. Check needs_input_grad (optimization)
8. Test thoroughly before using in training

**The gradient implementation is as important as the forward computation.** Bugs in backward() are silent and catastrophic - they cause models to learn the wrong thing. gradcheck is your safety net; never skip it.

When in doubt: implement, run gradcheck, debug until it passes, then use with confidence.
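Finally, the checklist item "Tested double backward if needed" can itself be automated with `torch.autograd.gradgradcheck`. A minimal sketch (the `Square` function is illustrative, not from the examples above); the key requirement is that `backward()` is built only from differentiable operations, so autograd can differentiate through it:

```python
import torch
from torch.autograd import Function, gradcheck, gradgradcheck

class Square(Function):
    """Toy function y = x * x whose backward supports double backward."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Composed only of differentiable ops, so autograd can
        # differentiate through backward() for second-order gradients
        return grad_output * 2 * input

x = torch.randn(6, dtype=torch.double, requires_grad=True)
assert gradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)      # first order
assert gradgradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)  # second order
print("✅ First- and second-order gradients verified")
```

If `gradgradcheck` fails while `gradcheck` passes, the usual culprit is a backward() that detaches, uses `.data`, or calls non-differentiable operations.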