Overview
This skill teaches you to implement custom autograd functions in PyTorch using torch.autograd.Function. You'll learn when custom functions are necessary, how to implement forward and backward passes correctly, mandatory numerical verification with gradcheck, and common pitfalls that break gradient computation.
When This Skill Applies
Use torch.autograd.Function when you encounter these situations:
Symptoms That Require Custom Functions
- Custom operations not in PyTorch: Implementing novel mathematical operations (special activations, custom loss components, domain-specific transformations)
- Wrapping external code: Interfacing with C++/CUDA kernels, third-party libraries, or compiled extensions that PyTorch doesn't know about
- Custom gradient behavior: Need non-standard gradient computation (gradient clipping in backward, sparsification, quantization, gradient routing)
- Memory optimization: Implementing gradient checkpointing, fused operations, or selective materialization to reduce memory footprint
- Interfacing with non-differentiable code: Wrapping operations that aren't naturally differentiable but have known gradient behavior
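To make the "custom gradient behavior" case concrete, here is a minimal sketch of a straight-through estimator: the forward pass rounds, and the backward pass treats the operation as the identity. The class name `RoundSTE` is a hypothetical name chosen for illustration; it does not come from this skill or from PyTorch.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical example: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        # torch.round has zero gradient almost everywhere, so plain autograd
        # would pass no learning signal through it.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pretend forward was the identity function.
        return grad_output


x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
# y holds the rounded values; x.grad is all ones because the gradient
# passed straight through the rounding step.
```

Nothing needs to be saved on `ctx` here because the backward pass does not depend on the input values.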
When NOT to Use Custom Functions
Don't use torch.autograd.Function when:
- ❌ Operation composable from existing ops: If you can write it with standard PyTorch operations, autograd handles gradients automatically
- ❌ Simple function wrapping: Just wrapping torch.nn.functional operations gains nothing
- ❌ Standard gradient computation: No custom behavior needed - use regular PyTorch
- ❌ Avoiding learning curve: Custom functions aren't "advanced only" - use when appropriate
Example:
# DON'T: Unnecessary custom function
class MyAdd(Function): # ❌ Pointless wrapper
@staticmethod
def forward(ctx, a, b):
return a + b
@staticmethod
def backward(ctx, grad):
return grad, grad
# DO: Use PyTorch's autograd
output = a + b # ✅ Autograd handles this correctly
# DON'T: Reimplement existing operations
class MyReLU(Function): # ❌ Use torch.nn.functional.relu
...
# DO: Use built-in operations
output = torch.relu(input) # ✅ Efficient and correct
Core Pattern: Complete Function Implementation
Basic Template
Every custom autograd function follows this pattern:
import torch
from torch.autograd import Function
class MyCustomFunction(Function):
"""
Custom autograd function template.
Implements forward pass (computation) and backward pass (gradient).
Context object (ctx) saves data between forward and backward.
"""
@staticmethod
def forward(ctx, input, weight, bias=None):
"""
Forward pass: compute output from inputs.
Args:
ctx: Context object for saving tensors/data
input: First input tensor
weight: Second input tensor
bias: Optional bias tensor
Returns:
output: Result tensor
"""
# Save tensors needed for backward pass
ctx.save_for_backward(input, weight, bias)
# Save non-tensor data as ctx attributes
# ctx.some_value = some_non_tensor_data
# Compute forward pass
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass: compute gradients using chain rule.
Args:
ctx: Context object with saved data
grad_output: Gradient of loss w.r.t. output (dL/dY)
Returns:
Tuple of gradients for each input to forward():
(grad_input, grad_weight, grad_bias)
Must return one gradient per forward() argument (None if not needed)
"""
# Retrieve saved tensors
input, weight, bias = ctx.saved_tensors
# Initialize gradients to None
grad_input = grad_weight = grad_bias = None
# Compute gradients only if needed (efficiency)
if ctx.needs_input_grad[0]: # Check if input needs gradient
grad_input = grad_output.mm(weight) # dL/dX = dL/dY @ W
if ctx.needs_input_grad[1]: # Check if weight needs gradient
grad_weight = grad_output.t().mm(input) # dL/dW = dL/dY.T @ X
if bias is not None and ctx.needs_input_grad[2]: # Check if bias needs gradient
grad_bias = grad_output.sum(0) # Sum over batch dimension
# Return one gradient per forward() argument, in the same order.
# ctx is not counted: it is passed implicitly and never gets a gradient.
return grad_input, grad_weight, grad_bias
# Use the custom function
my_func = MyCustomFunction.apply # Get callable
output = my_func(input_tensor, weight_tensor, bias_tensor)
loss = output.sum()
loss.backward() # Calls MyCustomFunction.backward()
Critical Rules
Rule 1: Return gradient for EACH forward input
# forward signature: forward(ctx, a, b, c, d=None)
# backward must return: (grad_a, grad_b, grad_c, grad_d)
# ✅ CORRECT: 4 inputs → 4 gradient returns
def backward(ctx, grad_output):
return grad_a, grad_b, grad_c, grad_d
# ❌ WRONG: Missing gradients
def backward(ctx, grad_output):
return grad_a, grad_b # Only 2 - will crash!
# ✅ CORRECT: Use None for unused gradients
def backward(ctx, grad_output):
return grad_a, None, grad_c, None # b and d don't need gradients
Rule 2: Gradient shape must match input shape exactly
# If input.shape = (32, 128)
# Then grad_input.shape MUST BE (32, 128)
# ✅ CORRECT: Shapes match
assert grad_input.shape == input.shape
# ❌ WRONG: Shape mismatch causes runtime error
grad_input = some_computation() # Shape (32, 64) - WRONG!
Rule 3: Check needs_input_grad before computing
# Efficiency optimization: skip gradient computation if not needed
# ✅ CORRECT: Check before computing
if ctx.needs_input_grad[0]:
grad_input = expensive_gradient_computation(...)
else:
grad_input = None
# ❌ WASTEFUL: Always compute (slow)
grad_input = expensive_gradient_computation(...) # Even if not needed
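The following sketch shows what `ctx.needs_input_grad` actually contains at backward time. `ProbeMul` and its `seen` attribute are hypothetical names used only to expose the tuple for inspection.

```python
import torch
from torch.autograd import Function


class ProbeMul(Function):
    """Hypothetical element-wise multiply that records needs_input_grad."""

    seen = None  # filled in by backward so it can be inspected afterwards

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        ProbeMul.seen = tuple(ctx.needs_input_grad)
        # Skip the work for any input that does not require a gradient
        grad_a = grad_output * b if ctx.needs_input_grad[0] else None
        grad_b = grad_output * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b


a = torch.randn(4, requires_grad=True)
b = torch.randn(4)  # requires_grad defaults to False
ProbeMul.apply(a, b).sum().backward()
# ProbeMul.seen is (True, False): one boolean per forward() argument
```

The tuple has one entry per `forward()` argument (excluding `ctx`), so its indices line up with the positions of the values returned from `backward()`.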
Context Object (ctx) Rules
The context object ctx is how you pass data from forward to backward. Misusing it leads to memory leaks, stale data, or silently wrong gradients.
Rule 1: Use save_for_backward() for Tensors ONLY
# ✅ CORRECT: Save tensors with save_for_backward()
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight) # Both are tensors
return input @ weight
# ❌ WRONG: Saving tensors as attributes
@staticmethod
def forward(ctx, input, weight):
ctx.input = input # Breaks memory tracking
ctx.weight = weight # Breaks memory tracking
return input @ weight
# Why it matters: save_for_backward() records each tensor's version, so autograd
# can detect in-place modification before backward, and it avoids reference
# cycles that leak memory. Attribute assignment gives you neither protection.
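As a hedged illustration of the version tracking mentioned above, the sketch below modifies a saved tensor in place between forward and backward. The `Square` class and the `detected` flag are assumptions made for this example. The expectation is that reading `ctx.saved_tensors` in backward raises a `RuntimeError` instead of silently using the modified values.

```python
import torch
from torch.autograd import Function


class Square(Function):
    """Hypothetical example: y = x * x with the input saved for backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # version check happens when this is read
        return 2 * x * grad_output


x = torch.randn(3, requires_grad=True)
a = x.clone()      # non-leaf tensor, so in-place edits are allowed on it
y = Square.apply(a)
a.mul_(2)          # in-place change AFTER the tensor was saved

try:
    y.sum().backward()
    detected = False
except RuntimeError:
    # autograd noticed that the saved tensor changed since forward
    detected = True
```

Had the tensor been stored as `ctx.x = x` instead, backward would have run with the doubled values and returned a wrong gradient without any error.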
Rule 2: Save Non-Tensor Data as Attributes
# ✅ CORRECT: Non-tensor data as attributes
@staticmethod
def forward(ctx, input, kernel_size, stride):
ctx.save_for_backward(input) # Tensor
ctx.kernel_size = kernel_size # Integer - use attribute
ctx.stride = stride # Integer - use attribute
return some_operation(input, kernel_size, stride)
# ❌ WRONG: Trying to save non-tensors with save_for_backward()
@staticmethod
def forward(ctx, input, kernel_size, stride):
ctx.save_for_backward(input, kernel_size, stride) # TypeError!
# kernel_size and stride are ints, not tensors
Rule 3: Access saved_tensors Only in Backward
# ✅ CORRECT: Access in backward
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors # Available here
grad_input = grad_output @ weight.t()
return grad_input, None
# ❌ WRONG: Access in forward
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
output = input @ weight
saved = ctx.saved_tensors # ❌ Saved tensors are meant to be read in backward(), not here
return output
Rule 4: Never Modify Saved Tensors
# ❌ WRONG: Modifying saved tensor
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
input *= 2 # IN-PLACE modification of the saved tensor - BREAKS AUTOGRAD!
grad_input = compute_gradient(input, grad_output)
return grad_input
# ✅ CORRECT: Don't modify, or clone first
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
input_scaled = input * 2 # New tensor - safe
grad_input = compute_gradient(input_scaled, grad_output)
return grad_input
Complete ctx Example
class CompleteCtxExample(Function):
@staticmethod
def forward(ctx, input, weight, bias, stride=1, training=True):
# Save tensors with save_for_backward()
ctx.save_for_backward(input, weight, bias)
# Save non-tensor data as attributes
ctx.stride = stride # int
ctx.training = training # bool
# Compute forward
output = some_computation(input, weight, bias, stride, training)
return output
@staticmethod
def backward(ctx, grad_output):
# Retrieve saved tensors
input, weight, bias = ctx.saved_tensors
# Retrieve non-tensor data
stride = ctx.stride
training = ctx.training
# Compute gradients
grad_input = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = compute_input_gradient(grad_output, weight, stride)
if ctx.needs_input_grad[1]:
grad_weight = compute_weight_gradient(grad_output, input, stride)
if ctx.needs_input_grad[2]:
grad_bias = compute_bias_gradient(grad_output)
# Return gradients (None for stride and training - they're not tensors)
return grad_input, grad_weight, grad_bias, None, None
Gradient Computation Patterns
Understanding common gradient patterns helps implement backward() correctly.
Pattern 1: Element-wise Operations
Forward: y = f(x) (element-wise function)
Backward: grad_x = grad_y * f'(x) (element-wise multiply)
# Example: Custom ReLU
class CustomReLU(Function):
@staticmethod
def forward(ctx, input):
# Save input to compute derivative in backward
ctx.save_for_backward(input)
# ReLU: max(0, x)
output = input.clamp(min=0)
return output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
# Derivative: 1 if x > 0, else 0
grad_input = grad_output.clone()
grad_input[input <= 0] = 0 # Zero gradient wherever x <= 0
return grad_input
# Example: Custom Sigmoid
class CustomSigmoid(Function):
@staticmethod
def forward(ctx, input):
output = torch.sigmoid(input)
# Save output (more efficient than recomputing)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
# Derivative: sigmoid(x) * (1 - sigmoid(x))
grad_input = grad_output * output * (1 - output)
return grad_input
# Example: Custom Tanh
class CustomTanh(Function):
@staticmethod
def forward(ctx, input):
output = torch.tanh(input)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
# Derivative: 1 - tanh^2(x)
grad_input = grad_output * (1 - output * output)
return grad_input
Pattern 2: Matrix Operations (Chain Rule)
Forward: Y = X @ W.T (matrix multiply; weight has shape (out, in))
Backward: Apply chain rule with matrix transpose
class CustomLinear(Function):
@staticmethod
def forward(ctx, input, weight):
# Save both tensors for backward
ctx.save_for_backward(input, weight)
# Forward: Y = X @ W.T
output = input.mm(weight.t()) # (batch, in) @ (out, in).T = (batch, out)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_weight = None
# Gradient w.r.t. input: dL/dX = dL/dY @ W
# Shapes: (batch, out) @ (out, in) = (batch, in) ✓
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
# Gradient w.r.t. weight: dL/dW = dL/dY.T @ X (already in weight's shape)
# Shapes: (out, batch) @ (batch, in) = (out, in) ✓
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
return grad_input, grad_weight
# More complex: Matrix multiply with both inputs requiring gradients
class CustomMatmul(Function):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
# Forward: C = A @ B
return torch.matmul(a, b)
@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_a = grad_b = None
# Gradient w.r.t. A: dL/dA = dL/dC @ B.T
if ctx.needs_input_grad[0]:
grad_a = torch.matmul(grad_output, b.transpose(-2, -1))
# Gradient w.r.t. B: dL/dB = A.T @ dL/dC
if ctx.needs_input_grad[1]:
grad_b = torch.matmul(a.transpose(-2, -1), grad_output)
return grad_a, grad_b
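One way to gain confidence in the two matmul formulas above is to compare them against what autograd computes for the plain `a @ b` expression. This is a small sketch; the tensor sizes and the names `grad_c`, `ref_a`, `manual_a` and so on are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
grad_c = torch.randn(4, 5, dtype=torch.double)  # stands in for dL/dC

# Reference gradients from PyTorch's own autograd
c = a @ b
ref_a, ref_b = torch.autograd.grad(c, (a, b), grad_outputs=grad_c)

# Hand-written formulas, as used in CustomMatmul.backward
manual_a = grad_c @ b.transpose(-2, -1)   # dL/dA = dL/dC @ B.T
manual_b = a.transpose(-2, -1) @ grad_c   # dL/dB = A.T @ dL/dC
```

A quick shape check is also a useful sanity test: `manual_a` must have the shape of `a`, and `manual_b` the shape of `b`.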
Pattern 3: Broadcasting Operations
Forward: Operation with broadcasting (e.g., adding bias)
Backward: Sum over broadcasted dimensions
class CustomBiasAdd(Function):
@staticmethod
def forward(ctx, input, bias):
# input: (batch, channels, height, width)
# bias: (channels,)
ctx.save_for_backward(input, bias)
# Broadcasting adds bias to each channel
output = input + bias.view(1, -1, 1, 1)
return output
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
grad_input = grad_bias = None
# Gradient w.r.t. input: just pass through
if ctx.needs_input_grad[0]:
grad_input = grad_output
# Gradient w.r.t. bias: sum over broadcasted dimensions
# grad_output: (batch, channels, height, width)
# grad_bias should be: (channels,)
# Sum over batch (0), height (2), width (3)
if ctx.needs_input_grad[1]:
grad_bias = grad_output.sum(dim=(0, 2, 3))
return grad_input, grad_bias
# General broadcasting pattern
class CustomBroadcastOp(Function):
@staticmethod
def forward(ctx, input, param):
# Save shapes to determine broadcast dimensions
ctx.input_shape = input.shape
ctx.param_shape = param.shape
ctx.save_for_backward(input, param)
# Some operation with broadcasting
output = input * param # param broadcasts to input shape
return output
@staticmethod
def backward(ctx, grad_output):
input, param = ctx.saved_tensors
grad_input = grad_param = None
if ctx.needs_input_grad[0]:
grad_input = grad_output * param
if ctx.needs_input_grad[1]:
# Sum grad_output over dimensions that were broadcasted
grad_param = grad_output * input
# Find which dimensions were broadcasted
# Sum over those dimensions to match param shape
ndim_diff = len(ctx.input_shape) - len(ctx.param_shape)
for i in range(ndim_diff):
grad_param = grad_param.sum(0) # Sum leading dimensions
for i, (input_dim, param_dim) in enumerate(
zip(ctx.input_shape[ndim_diff:], ctx.param_shape)
):
if param_dim == 1 and input_dim > 1:
grad_param = grad_param.sum(i, keepdim=True)
return grad_input, grad_param
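The reduction logic in the general pattern can be factored into a helper and checked against autograd. The function name `reduce_to_shape` is hypothetical, and the sketch assumes standard right-aligned broadcasting.

```python
import torch


def reduce_to_shape(grad, shape):
    """Sum `grad` over the dimensions that broadcasting expanded from `shape`."""
    # Leading dimensions that the parameter did not have at all
    extra = grad.dim() - len(shape)
    if extra > 0:
        grad = grad.sum(dim=tuple(range(extra)))
    # Dimensions where the parameter had size 1 but the gradient does not
    for i, size in enumerate(shape):
        if size == 1 and grad.shape[i] != 1:
            grad = grad.sum(dim=i, keepdim=True)
    return grad


x = torch.randn(2, 3, 4, dtype=torch.double)
p = torch.randn(3, 1, dtype=torch.double, requires_grad=True)
grad_out = torch.randn(2, 3, 4, dtype=torch.double)

# Reference: let autograd handle the broadcast multiply
(ref,) = torch.autograd.grad(x * p, p, grad_outputs=grad_out)

# Manual: local gradient, then reduce back to the parameter's shape
manual = reduce_to_shape(grad_out * x, p.shape)
```

PyTorch tensors also have a `sum_to_size` method that appears to perform the same kind of reduction; the explicit helper is shown here only to make the steps visible.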
Pattern 4: Reduction Operations
Forward: Reduce dimensions (sum, mean, max, etc.)
Backward: Expand gradient back to original shape
class CustomSum(Function):
@staticmethod
def forward(ctx, input, dim, keepdim=False):
ctx.input_shape = input.shape
ctx.dim = dim
ctx.keepdim = keepdim
# Sum along dimension
output = input.sum(dim=dim, keepdim=keepdim)
return output
@staticmethod
def backward(ctx, grad_output):
# Gradient of sum: distribute grad_output to all elements
grad_input = grad_output
# Expand back to original shape
if not ctx.keepdim:
# Add back the reduced dimension
grad_input = grad_input.unsqueeze(ctx.dim)
# Expand to original shape (broadcasts the gradient)
grad_input = grad_input.expand(ctx.input_shape)
return grad_input, None, None
class CustomMean(Function):
@staticmethod
def forward(ctx, input, dim, keepdim=False):
ctx.input_shape = input.shape
ctx.dim = dim
ctx.keepdim = keepdim
output = input.mean(dim=dim, keepdim=keepdim)
return output
@staticmethod
def backward(ctx, grad_output):
# Gradient of mean: distribute evenly to all elements
grad_input = grad_output
if not ctx.keepdim:
grad_input = grad_input.unsqueeze(ctx.dim)
grad_input = grad_input.expand(ctx.input_shape)
# Divide by number of elements that were averaged
n = ctx.input_shape[ctx.dim]
grad_input = grad_input / n
return grad_input, None, None
class CustomMax(Function):
@staticmethod
def forward(ctx, input, dim, keepdim=False):
# Save both max values and indices
output, indices = input.max(dim=dim, keepdim=keepdim)
ctx.save_for_backward(indices)
ctx.input_shape = input.shape
ctx.dim = dim
ctx.keepdim = keepdim
return output
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# Only the maximum element gets gradient
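# Match grad_output's dtype as well as its device, otherwise scatter_ fails
# for non-float32 inputs (for example the double tensors gradcheck needs)
grad_input = torch.zeros(ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device)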
if not ctx.keepdim:
grad_output = grad_output.unsqueeze(ctx.dim)
indices = indices.unsqueeze(ctx.dim)
# Scatter gradient to max indices
grad_input.scatter_(ctx.dim, indices, grad_output)
return grad_input, None, None
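The scatter-based max gradient can be verified with gradcheck. This is a simplified sketch under stated assumptions: `MaxAlongDim` is a hypothetical name, it always keeps the reduced dimension, and the input values are chosen to be well separated so that the finite-difference step cannot change which element is the maximum.

```python
import torch
from torch.autograd import Function, gradcheck


class MaxAlongDim(Function):
    """Hypothetical example: max over one dimension (keepdim=True)."""

    @staticmethod
    def forward(ctx, input, dim):
        output, indices = input.max(dim=dim, keepdim=True)
        ctx.save_for_backward(indices)
        ctx.input_shape = input.shape
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # dtype and device both follow grad_output
        grad_input = torch.zeros(
            ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device
        )
        # Only the arg-max positions receive gradient
        grad_input.scatter_(ctx.dim, indices, grad_output)
        return grad_input, None  # None for the integer `dim`


# Values 0..23 differ by at least 1, far more than gradcheck's eps
x = torch.arange(24, dtype=torch.double).reshape(4, 6).clone().requires_grad_()

# Wrap in a lambda so gradcheck only sees the tensor argument
ok = gradcheck(lambda t: MaxAlongDim.apply(t, 1), (x,), eps=1e-6, atol=1e-4)
```

With ties or near-ties in the input, max is not differentiable at that point and gradcheck can fail even for a correct implementation.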
Pattern 5: Convolution-like Operations
Complex operations that involve multiple dimensions and strides.
class CustomConv1d(Function):
@staticmethod
def forward(ctx, input, weight, bias=None, stride=1, padding=0):
# Use PyTorch's conv for forward
output = torch.nn.functional.conv1d(
input, weight, bias, stride, padding
)
# Save what's needed for backward
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.padding = padding
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
padding = ctx.padding
grad_input = grad_weight = grad_bias = None
# Gradient w.r.t. input: convolve grad_output with weight
if ctx.needs_input_grad[0]:
grad_input = torch.nn.grad.conv1d_input(
input.shape, weight, grad_output, stride, padding
)
# Gradient w.r.t. weight: convolve input with grad_output
if ctx.needs_input_grad[1]:
grad_weight = torch.nn.grad.conv1d_weight(
input, weight.shape, grad_output, stride, padding
)
# Gradient w.r.t. bias: sum grad_output over batch and spatial dims
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum((0, 2))
return grad_input, grad_weight, grad_bias, None, None
Numerical Gradient Verification (MANDATORY)
NEVER skip this step. gradcheck verifies your gradient computation is correct by comparing analytical gradients (your backward()) against numerical gradients (finite differences).
Why gradcheck is Mandatory
# You implement backward() and it "looks right"
class MyFunction(Function):
@staticmethod
def backward(ctx, grad_output):
# Gradient formula looks correct mathematically
grad = some_computation()
return grad
# But:
# ❌ Transpose is wrong
# ❌ Shape doesn't match
# ❌ Sign is flipped
# ❌ Missing a term
# ❌ Broadcasting is incorrect
# These bugs are invisible without gradcheck!
# Your model trains but produces wrong results.
# Debugging takes days without knowing gradients are wrong.
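The sketch below shows gradcheck catching one of the bugs listed above, a flipped sign. Both classes are hypothetical examples. `raise_exception=False` is used here only so the two results can be compared side by side.

```python
import torch
from torch.autograd import Function, gradcheck


class Cube(Function):
    """Correct: d(x^3)/dx = 3x^2."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 3 * x ** 2 * grad_output


class CubeWrongSign(Function):
    """Buggy on purpose: the sign of the gradient is flipped."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return -3 * x ** 2 * grad_output  # BUG: should be +3 * x ** 2


x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.double, requires_grad=True)
good_ok = gradcheck(Cube.apply, (x,), raise_exception=False)
bad_ok = gradcheck(CubeWrongSign.apply, (x,), raise_exception=False)
# good_ok is True, bad_ok is False
```

Both versions run forward and backward without any error, which is why a sign bug like this goes unnoticed unless the gradient is checked numerically.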
Basic gradcheck Usage
import torch
from torch.autograd import gradcheck
def test_my_function():
"""Test custom function with gradcheck."""
# Create test inputs with requires_grad=True
# Use double precision for numerical stability
input = torch.randn(20, 20, dtype=torch.double, requires_grad=True)
weight = torch.randn(30, 20, dtype=torch.double, requires_grad=True)
# Run gradcheck
test = gradcheck(
MyCustomFunction.apply, # Your function
(input, weight), # Tuple of inputs
eps=1e-6, # Finite difference step size
atol=1e-4, # Absolute tolerance
rtol=1e-3, # Relative tolerance
raise_exception=True # Raise error on failure (recommended)
)
if test:
print("✅ Gradient check PASSED!")
else:
print("❌ Gradient check FAILED!")
raise AssertionError("Gradient check failed")
# Run before using your function
test_my_function()
Complete gradcheck Pattern
import torch
from torch.autograd import Function, gradcheck, gradgradcheck
class MyCustomFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
def test_my_custom_function():
"""Comprehensive gradient testing."""
# Test 1: Basic gradcheck
print("Test 1: Basic gradient check...")
input = torch.randn(10, 5, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
bias = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(
MyCustomFunction.apply,
(input, weight, bias),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "Basic gradcheck failed"
print("✅ Basic gradcheck passed")
# Test 2: Without bias (optional parameter)
print("Test 2: Gradient check without bias...")
assert gradcheck(
MyCustomFunction.apply,
(input, weight, None),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "Gradcheck without bias failed"
print("✅ Gradcheck without bias passed")
# Test 3: Different input shapes
print("Test 3: Different input shapes...")
input_large = torch.randn(50, 20, dtype=torch.double, requires_grad=True)
weight_large = torch.randn(10, 20, dtype=torch.double, requires_grad=True)
bias_large = torch.randn(10, dtype=torch.double, requires_grad=True)
assert gradcheck(
MyCustomFunction.apply,
(input_large, weight_large, bias_large),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "Gradcheck with large inputs failed"
print("✅ Gradcheck with different shapes passed")
# Test 4: Second-order gradients (if needed)
print("Test 4: Second-order gradient check...")
try:
assert gradgradcheck(
MyCustomFunction.apply,
(input, weight, bias),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "Second-order gradcheck failed"
print("✅ Second-order gradcheck passed")
except NotImplementedError:
print("⚠️ Second-order gradients not implemented (OK if not needed)")
print("\n🎉 All gradient checks passed!")
# ALWAYS run this before using your function in training
test_my_custom_function()
gradcheck Parameters
gradcheck(
func, # Your Function.apply
inputs, # Tuple of input tensors
eps=1e-6, # Finite difference step: (f(x+eps) - f(x-eps)) / (2*eps)
atol=1e-5, # Absolute tolerance for comparison
rtol=1e-3, # Relative tolerance for comparison
raise_exception=True, # Raise on failure (recommended for testing)
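check_sparse_nnz=False, # Sparse input support; deprecated in favor of `masked` in newer PyTorch releases
nondet_tol=0.0, # Tolerance for non-deterministic operations
check_undefined_grad=True, # Check that undefined (None) output grads are handled and treated as zeros
check_grad_dtypes=True, # Check gradient dtypes match (the library default is False)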
)
# Key insights:
# - Use double precision (dtype=torch.double) for numerical stability
# - eps=1e-6 is good default; smaller for more precision, larger for stability
# - atol/rtol balance: looser tolerances for complex operations
# - raise_exception=True catches bugs immediately in testing
Debugging gradcheck Failures
def debug_gradcheck():
"""Step-by-step debugging when gradcheck fails."""
# Step 1: Check forward pass works
print("Step 1: Verify forward pass...")
input = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
output = MyCustomFunction.apply(input, weight)
print(f"Output shape: {output.shape}")
print(f"Output contains NaN: {torch.isnan(output).any()}")
print(f"Output contains Inf: {torch.isinf(output).any()}")
assert output.shape == (5, 4), "Forward shape wrong"
assert not torch.isnan(output).any(), "Forward produces NaN"
# Step 2: Check backward runs without error
print("\nStep 2: Verify backward runs...")
loss = output.sum()
loss.backward()
print(f"Input grad shape: {input.grad.shape}")
print(f"Weight grad shape: {weight.grad.shape}")
assert input.grad.shape == input.shape, "Input gradient shape mismatch"
assert weight.grad.shape == weight.shape, "Weight gradient shape mismatch"
# Step 3: Check gradient magnitudes
print("\nStep 3: Check gradient magnitudes...")
print(f"Input grad: mean={input.grad.mean():.6f}, std={input.grad.std():.6f}")
print(f"Weight grad: mean={weight.grad.mean():.6f}, std={weight.grad.std():.6f}")
# Should be reasonable numbers (not 1e10 or 1e-20)
# Step 4: Manual numerical gradient check for one element
print("\nStep 4: Manual gradient check for one element...")
input_test = torch.randn(3, 2, dtype=torch.double)
weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
eps = 1e-6
# Analytical gradient
output = MyCustomFunction.apply(input_test, weight_test)
loss = output.sum()
loss.backward()
analytical_grad = weight_test.grad[0, 0].item()
# Numerical gradient (finite difference)
weight_test_plus = weight_test.clone().detach()
weight_test_plus[0, 0] += eps
output_plus = MyCustomFunction.apply(input_test, weight_test_plus)
loss_plus = output_plus.sum()
weight_test_minus = weight_test.clone().detach()
weight_test_minus[0, 0] -= eps
output_minus = MyCustomFunction.apply(input_test, weight_test_minus)
loss_minus = output_minus.sum()
numerical_grad = (loss_plus - loss_minus) / (2 * eps)
numerical_grad = numerical_grad.item()
print(f"Analytical gradient: {analytical_grad:.10f}")
print(f"Numerical gradient: {numerical_grad:.10f}")
print(f"Difference: {abs(analytical_grad - numerical_grad):.10e}")
if abs(analytical_grad - numerical_grad) > 1e-4:
print("❌ Large difference - gradient implementation likely wrong")
else:
print("✅ Small difference - gradient likely correct")
# Step 5: Run gradcheck with verbose output
print("\nStep 5: Run gradcheck...")
input_check = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
weight_check = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
try:
result = gradcheck(
MyCustomFunction.apply,
(input_check, weight_check),
eps=1e-6,
atol=1e-4,
raise_exception=True
)
print("✅ gradcheck passed!")
except RuntimeError as e:
print(f"❌ gradcheck failed with error:\n{e}")
# Error message shows which gradient failed and by how much
debug_gradcheck()
Common gradcheck Failure Reasons
# Failure 1: Wrong gradient formula
class WrongGradient(Function):
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
# ❌ WRONG: No transpose
grad_input = grad_output @ weight # Should be weight.t() when forward is input @ weight
return grad_input, None
# gradcheck fails: analytical ≠ numerical
# Failure 2: Shape mismatch
class WrongShape(Function):
@staticmethod
def backward(ctx, grad_output):
# ❌ WRONG: Returns wrong shape
return grad_output.sum(), None # grad_input must have the same shape as input
# gradcheck fails: shape error
# Failure 3: In-place operation
class InplaceOperation(Function):
@staticmethod
def backward(ctx, grad_output):
grad_output[grad_output < 0] = 0 # ❌ IN-PLACE
return grad_output
# gradcheck fails: modified by inplace operation
# Failure 4: Not using saved tensors correctly
class WrongSaved(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input.clone()) # Saved clone
return input * 2
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
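# The saved clone holds the same values as input, so reading it is safe,
# but the extra copy costs memory and adds nothing.
# What matters is that this formula matches forward (input * 2).
return grad_output * 2
# Passes as written; fails as soon as the factor here and in forward disagree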
# Failure 5: Forgot to return gradients for all inputs
class MissingGradient(Function):
@staticmethod
def forward(ctx, input, weight, bias):
# ...
return output
@staticmethod
def backward(ctx, grad_output):
# ❌ WRONG: Only returns 2 gradients for 3 inputs
return grad_input, grad_weight
# gradcheck fails: tuple length mismatch
Common Pitfalls
Pitfall 1: In-Place Operations
In-place operations modify tensors that other operations depend on, breaking autograd.
# ❌ WRONG: In-place operation in forward
class InplaceForward(Function):
@staticmethod
def forward(ctx, input):
input[input < 0] = 0 # IN-PLACE on the caller's tensor without ctx.mark_dirty() - breaks autograd!
return input
# ❌ WRONG: In-place operation in backward
class InplaceBackward(Function):
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_output[input < 0] = 0 # IN-PLACE - breaks autograd!
return grad_output
# ✅ CORRECT: Create new tensor
class CorrectInplace(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = input.clone() # Create new tensor
output[output < 0] = 0 # Modify copy, not original
return output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone() # Create new tensor
grad_input[input <= 0] = 0 # Modify copy (zero gradient where input <= 0)
return grad_input
# Even better: Use non-in-place operations
class BestInplace(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0) # Non-in-place
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output * (input > 0).float() # Non-in-place
return grad_input
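When an in-place forward is genuinely required (for example to save memory), autograd has to be told about it with `ctx.mark_dirty()`. The following is a minimal sketch under stated assumptions: `ClampMinInplace` is a hypothetical name, and the function is applied to a clone because a leaf tensor that requires grad cannot be modified in place.

```python
import torch
from torch.autograd import Function


class ClampMinInplace(Function):
    """Hypothetical example: in-place ReLU that informs autograd via mark_dirty."""

    @staticmethod
    def forward(ctx, x):
        # Compute the mask BEFORE overwriting x, and save only the mask
        mask = x > 0
        ctx.save_for_backward(mask)
        x.clamp_(min=0)      # in-place modification of the input
        ctx.mark_dirty(x)    # tell autograd that x was modified
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
a = x.clone()                    # non-leaf copy that may be modified in place
out = ClampMinInplace.apply(a)
out.sum().backward()
# x itself is untouched; x.grad is 0 where x <= 0 and 1 elsewhere
```

The non-in-place versions shown above remain the simpler and safer choice; this pattern is only worth the extra care when the memory saving matters.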
Pitfall 2: Gradient Shape Mismatch
Gradient must match input shape exactly.
# ❌ WRONG: Gradient shape doesn't match input
class WrongShape(Function):
@staticmethod
def forward(ctx, input):
# input: (32, 128)
ctx.save_for_backward(input)
return input.sum() # scalar
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
# grad_output is scalar, but need (32, 128)
return grad_output # ❌ WRONG: scalar ≠ (32, 128)
# ✅ CORRECT: Expand to match input shape
class CorrectShape(Function):
@staticmethod
def forward(ctx, input):
ctx.input_shape = input.shape
return input.sum()
@staticmethod
def backward(ctx, grad_output):
# Expand scalar to input shape
grad_input = grad_output.expand(ctx.input_shape)
return grad_input
# Always verify shapes
def verify_shapes(ctx, grad_output, *grad_inputs):
"""Helper to verify gradient shapes."""
for i, (grad, tensor) in enumerate(zip(grad_inputs, ctx.saved_tensors)):
if grad is not None:
assert grad.shape == tensor.shape, \
f"Gradient {i} shape {grad.shape} != input shape {tensor.shape}"
Pitfall 3: Not Checking needs_input_grad
Computing gradients when not needed wastes computation.
# ❌ WASTEFUL: Always compute all gradients
class AlwaysCompute(Function):
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
# Compute all gradients even if not needed
grad_input = expensive_computation_1(grad_output, weight)
grad_weight = expensive_computation_2(grad_output, input)
grad_bias = expensive_computation_3(grad_output)
return grad_input, grad_weight, grad_bias
# ✅ EFFICIENT: Check needs_input_grad first
class EfficientCompute(Function):
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# Only compute if needed
if ctx.needs_input_grad[0]:
grad_input = expensive_computation_1(grad_output, weight)
if ctx.needs_input_grad[1]:
grad_weight = expensive_computation_2(grad_output, input)
if ctx.needs_input_grad[2]:
grad_bias = expensive_computation_3(grad_output)
return grad_input, grad_weight, grad_bias
# Example where it matters
def example_needs_input_grad():
"""Demonstrate needs_input_grad optimization."""
input = torch.randn(100, 100, requires_grad=True)
weight = torch.randn(100, 100, requires_grad=False) # No gradient needed
# Without check: computes grad_weight unnecessarily
# With check: skips grad_weight computation (faster)
output = MyFunction.apply(input, weight)
loss = output.sum()
loss.backward()
# weight.grad is None because requires_grad=False
assert weight.grad is None
Pitfall 4: Using .data Instead of .detach()
.data returns a tensor that shares storage with the original but is invisible to autograd, so in-place changes made through it go undetected.
# ❌ WRONG: Using .data
class UsingData(Function):
@staticmethod
def forward(ctx, input):
# .data returns a tensor without autograd tracking that still shares storage,
# and in-place changes made through it are not seen by autograd's version check
ctx.save_for_backward(input.data) # ❌ WRONG
return input * 2
@staticmethod
def backward(ctx, grad_output):
input_data, = ctx.saved_tensors
# May produce incorrect gradients
return grad_output * 2
# ✅ CORRECT: Use .detach() or save normally
class UsingDetach(Function):
@staticmethod
def forward(ctx, input):
# Usual case: save the input as-is
ctx.save_for_backward(input) # ✅ Most common
# If you need the values without autograd history, save input.detach()
# instead (never input.data). Call save_for_backward() only once per forward.
return input * 2
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * 2
# When to detach in custom functions
class WhenToDetach(Function):
@staticmethod
def forward(ctx, input, target):
# Save input normally (gradient needed)
# Detach target (gradient not needed, just reference)
ctx.save_for_backward(input, target.detach())
loss = (input - target).pow(2).mean()
return loss
@staticmethod
def backward(ctx, grad_output):
input, target = ctx.saved_tensors
# Compute gradient w.r.t. input only; mean() divides by the element count
grad_input = 2 * (input - target) / input.numel() * grad_output
return grad_input, None
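A mean-reduced loss is a common place to lose a factor of 1/N in the backward pass, so it is worth checking numerically. This sketch assumes `input` and `target` have the same shape (no broadcasting); `MSELossFn` is a hypothetical name.

```python
import torch
from torch.autograd import Function, gradcheck


class MSELossFn(Function):
    """Hypothetical example: mean squared error with a hand-written gradient."""

    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        return (input - target).pow(2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # d/dx mean((x - t)^2) = 2 * (x - t) / N
        grad_input = 2.0 * (input - target) / input.numel() * grad_output
        return grad_input, None  # no gradient for target


torch.manual_seed(0)
inp = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
tgt = torch.randn(4, 3, dtype=torch.double)

ok = gradcheck(MSELossFn.apply, (inp, tgt), eps=1e-6, atol=1e-4)

# Cross-check against the built-in loss
MSELossFn.apply(inp, tgt).backward()
(ref,) = torch.autograd.grad(torch.nn.functional.mse_loss(inp, tgt), inp)
```

Dropping the `/ input.numel()` term makes gradcheck fail, while training would still run and simply use gradients that are N times too large.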
Pitfall 5: Modifying grad_output
Never modify grad_output in-place; it's used by other operations.
# ❌ WRONG: Modifying grad_output in-place
class ModifyGradOutput(Function):
    @staticmethod
    def forward(ctx, input):
        return input * 2
    @staticmethod
    def backward(ctx, grad_output):
        # ❌ WRONG: In-place modification
        grad_output *= 2
        return grad_output
# ✅ CORRECT: Create new tensor
class DontModifyGradOutput(Function):
    @staticmethod
    def forward(ctx, input):
        return input * 2
    @staticmethod
    def backward(ctx, grad_output):
        # ✅ Create new tensor
        grad_input = grad_output * 2
        return grad_input
# Why it matters
def why_grad_output_matters():
    """Show where an in-place edit of grad_output can leak."""
    # backward() receives the upstream gradient dL/dz as grad_output.
    # That tensor may be the one the caller passed to backward(), or one
    # that autograd also hands to a sibling branch. Editing it in place
    # changes what they see.
    x = torch.randn(5, requires_grad=True)
    # g(x)
    y = x * 2
    # f(g(x)) - custom function that modifies grad_output in place
    z = ModifyGradOutput.apply(y)
    upstream = torch.ones_like(z)
    z.backward(upstream)
    # x.grad happens to be correct here (it only sees the returned value),
    # but `upstream` may now hold doubled values instead of ones, and any
    # other consumer of that tensor would compute with the wrong gradient
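One way to guard against this pitfall is to check that a backward pass leaves the gradient you supply unchanged. The helper below is a hypothetical sketch (`backward_leaves_grad_output_intact` and `DoubleGrad` are names made up for illustration). It assumes autograd passes the supplied tensor through to backward without copying, which is an implementation detail, so treat a passing result as a sanity check and not as proof.

```python
import torch
from torch.autograd import Function

class DoubleGrad(Function):
    """y = 2x, with a backward that leaves grad_output untouched."""
    @staticmethod
    def forward(ctx, input):
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2   # out-of-place: allocates a new tensor

def backward_leaves_grad_output_intact(fn, *inputs):
    """Run backward with an explicit gradient and report whether it was left unchanged."""
    out = fn(*inputs)
    grad = torch.ones_like(out)
    before = grad.clone()
    wanted = [t for t in inputs if t.requires_grad]
    torch.autograd.grad(out, wanted, grad_outputs=grad)
    return torch.equal(grad, before)

x = torch.randn(5, requires_grad=True)
ok = backward_leaves_grad_output_intact(DoubleGrad.apply, x)
print(ok)
```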
Pitfall 6: Forgetting to Return None for Non-Tensor Arguments
backward() must return one value for every forward() argument (excluding ctx): a gradient for tensors that need one, and None for everything else.
# ❌ WRONG: Not enough return values
class NotEnoughReturns(Function):
@staticmethod
def forward(ctx, input, kernel_size, stride):
# 3 arguments (excluding ctx)
ctx.save_for_backward(input)
ctx.kernel_size = kernel_size
ctx.stride = stride
return some_operation(input, kernel_size, stride)
@staticmethod
def backward(ctx, grad_output):
# ❌ WRONG: Only returns 1 value for 3 inputs
return grad_output # Crashes: expected 3 values
# ✅ CORRECT: Return gradient (or None) for each input
class EnoughReturns(Function):
@staticmethod
def forward(ctx, input, kernel_size, stride):
ctx.save_for_backward(input)
ctx.kernel_size = kernel_size
ctx.stride = stride
return some_operation(input, kernel_size, stride)
@staticmethod
def backward(ctx, grad_output):
# ✅ Return 3 values: grad for input, None for kernel_size, None for stride
return grad_output, None, None
# Rule of thumb
def count_returns():
"""
backward() must return one value per forward() argument (excluding ctx).
forward(ctx, a, b, c, d=None) → backward must return (grad_a, grad_b, grad_c, grad_d)
Use None for:
- Non-tensor arguments (ints, strings, etc.)
- Optional arguments that were None
- Tensors that don't need gradients
"""
pass
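Since `some_operation` above is a placeholder, here is a small runnable instance of the same rule. `ScaledPower` is a hypothetical function invented for this sketch: it takes one tensor and two non-tensor arguments, so its backward returns three values, with None in the two non-tensor slots.

```python
import torch
from torch.autograd import Function, gradcheck

class ScaledPower(Function):
    """y = scale * x**exponent; exponent (int) and scale (float) are non-tensor arguments."""
    @staticmethod
    def forward(ctx, input, exponent, scale):
        ctx.save_for_backward(input)   # tensors go through save_for_backward
        ctx.exponent = exponent        # plain Python values go on ctx
        ctx.scale = scale
        return scale * input ** exponent

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx [scale * x^n] = scale * n * x^(n-1)
        grad_input = grad_output * ctx.scale * ctx.exponent * input ** (ctx.exponent - 1)
        # One return value per forward() argument: input, exponent, scale
        return grad_input, None, None

x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
passed = gradcheck(lambda t: ScaledPower.apply(t, 3, 0.5), (x,), eps=1e-6, atol=1e-4)
print(passed)
```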
Pitfall 7: Incorrect Broadcasting in Gradient
Gradient must account for broadcasting that occurred in forward.
# ❌ WRONG: Doesn't handle broadcasting correctly
class WrongBroadcast(Function):
@staticmethod
def forward(ctx, input, weight):
# input: (32, 64, 10, 10)
# weight: (64, 1, 1) - broadcasts to (32, 64, 10, 10)
ctx.save_for_backward(input, weight)
output = input * weight # Broadcasting happens
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output * weight # ✅ This is fine
# ❌ WRONG: grad_weight shape is (32, 64, 10, 10), should be (64, 1, 1)
grad_weight = grad_output * input
return grad_input, grad_weight # Shape mismatch!
# ✅ CORRECT: Sum over broadcasted dimensions
class CorrectBroadcast(Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
ctx.input_shape = input.shape
ctx.weight_shape = weight.shape
output = input * weight
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output * weight
# ✅ Sum over dimensions that were broadcasted
grad_weight = grad_output * input
# weight shape: (64, 1, 1), grad_weight current: (32, 64, 10, 10)
# Sum over batch (0), height (2), width (3)
grad_weight = grad_weight.sum(dim=(0, 2, 3), keepdim=True)
# Now grad_weight shape: (1, 64, 1, 1)
# Squeeze batch dimension: (64, 1, 1) ✅
grad_weight = grad_weight.squeeze(0)
return grad_input, grad_weight
# General broadcasting gradient pattern
def sum_to_shape(tensor, shape):
"""Sum tensor to match target shape (handles broadcasting)."""
# Find dimensions that were added
while tensor.dim() > len(shape):
tensor = tensor.sum(0)
# Find dimensions that were size 1 and got broadcasted
for i, (t_dim, s_dim) in enumerate(zip(tensor.shape, shape)):
if s_dim == 1 and t_dim > 1:
tensor = tensor.sum(i, keepdim=True)
return tensor
class GeneralBroadcast(Function):
@staticmethod
def backward(ctx, grad_output):
input, param = ctx.saved_tensors
grad_input = grad_output * param
grad_param = grad_output * input
# Sum to match param's original shape
grad_param = sum_to_shape(grad_param, param.shape)
return grad_input, grad_param
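PyTorch also ships a built-in reduction for this, `Tensor.sum_to_size`, which can stand in for a hand-written `sum_to_shape`. The sketch below uses a hypothetical `BroadcastMul` function and small illustrative shapes to apply it to both gradients and verify the result with gradcheck.

```python
import torch
from torch.autograd import Function, gradcheck

class BroadcastMul(Function):
    """Elementwise product where either operand may have been broadcast."""
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input * weight

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Reduce each gradient back to the shape its operand had in forward
        grad_input = (grad_output * weight).sum_to_size(input.shape)
        grad_weight = (grad_output * input).sum_to_size(weight.shape)
        return grad_input, grad_weight

input = torch.randn(4, 3, 2, 2, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, 1, 1, dtype=torch.double, requires_grad=True)

passed = gradcheck(BroadcastMul.apply, (input, weight), eps=1e-6, atol=1e-4)

BroadcastMul.apply(input, weight).sum().backward()
print(passed, weight.grad.shape)
```

Reducing both gradients, not only the parameter's, keeps the function correct if the roles are swapped and it is the first argument that gets broadcast.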
Memory Efficiency Patterns
Custom functions give you direct control over what is kept for backward, which enables memory optimizations that autograd's default bookkeeping does not make for you.
Pattern 1: Gradient Checkpointing
Trade computation for memory by recomputing forward in backward.
class CheckpointedFunction(Function):
    """
    Gradient checkpointing: Don't save activations, recompute in backward.
    Memory: only the inputs are kept, not the intermediate activations
    Time: roughly one extra forward pass (once in forward, once in backward)
    Note: expensive_computation is a placeholder for your own function.
    """
    @staticmethod
    def forward(ctx, input, *args):
        # Save inputs and parameters, NOT intermediate activations
        # (args must be tensors to go through save_for_backward)
        ctx.save_for_backward(input, *args)
        # forward() runs with gradient tracking disabled, so the
        # intermediate values computed here are not kept for backward
        output = expensive_computation(input, *args)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve inputs
        input, *args = ctx.saved_tensors
        # Recompute forward pass, this time tracking gradients
        with torch.enable_grad():
            # Detach input, then set requires_grad
            # (gives the recomputed graph its own leaf to differentiate against)
            input = input.detach().requires_grad_(True)
            # Recompute forward
            output = expensive_computation(input, *args)
        # Now compute gradients using autograd
        grad_input, = torch.autograd.grad(
            outputs=output,
            inputs=input,
            grad_outputs=grad_output,
            retain_graph=False
        )
        # Return gradient for input; this sketch returns None for args,
        # so args receive no gradient even if they require one
        return (grad_input,) + (None,) * len(args)
# Example: Checkpointed Sequential Layers
class CheckpointedSequential(Function):
    @staticmethod
    def forward(ctx, input, *layers):
        """
        Forward through multiple layers without saving intermediate activations.
        Normal: Saves n-1 activations for n layers
        Checkpointed: Saves only the input (the layers hold their own parameters)
        """
        ctx.layers = layers
        ctx.save_for_backward(input)
        # Forward through all layers
        output = input
        for layer in layers:
            output = layer(output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        layers = ctx.layers
        # Recompute forward with gradient tracking enabled
        with torch.enable_grad():
            input = input.detach().requires_grad_(True)
            output = input
            for layer in layers:
                output = layer(output)
        # One backward pass through the recomputed graph. This accumulates
        # gradients into the layers' parameters (.grad) and into input.grad.
        torch.autograd.backward(output, grad_output)
        # The layers are non-tensor arguments of forward(), so each gets None.
        # Note: backward only runs if `input` requires grad, because the
        # layers are not tensors that autograd can track.
        return (input.grad,) + (None,) * len(layers)
# PyTorch provides torch.utils.checkpoint.checkpoint for this
# But understanding the pattern helps for custom cases
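For comparison, here is a minimal sketch of the built-in utility. The layer sizes and seed are arbitrary choices for illustration. It runs the same block with and without `torch.utils.checkpoint.checkpoint` and confirms that both input and parameter gradients agree.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
x = torch.randn(4, 16, requires_grad=True)

# Reference: ordinary forward/backward (intermediate activations are stored)
block(x).sum().backward()
ref_x_grad = x.grad.clone()
ref_w_grad = block[0].weight.grad.clone()

# Checkpointed: activations inside `block` are recomputed during backward
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()

print(torch.allclose(x.grad, ref_x_grad))
print(torch.allclose(block[0].weight.grad, ref_w_grad))
```

Reach for a hand-written checkpointing Function only when the built-in does not cover your case, for example when the recomputation itself needs custom logic.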
Pattern 2: Selective Saving
Only save what's needed for backward; recompute or omit the rest.
class SelectiveSaving(Function):
"""
Save only essential tensors; recompute others in backward.
"""
@staticmethod
def forward(ctx, input, weight, bias):
# Compute intermediate values
weighted = input @ weight
activated = torch.relu(weighted)
output = activated + bias
# DON'T save everything:
# ctx.save_for_backward(input, weight, bias, weighted, activated)
# ❌ Saves 5 tensors
# ✅ SAVE ONLY WHAT'S NEEDED:
# 'weighted' is not needed: the ReLU mask can be read from 'activated'
# 'bias' is not used in backward either; it is saved here only for symmetry
ctx.save_for_backward(input, weight, bias, activated)
# ✅ Saves 4 tensors
# Or even more selective, still in a single call
# (save_for_backward should be called at most once):
# ctx.save_for_backward(input, weight, activated > 0) # Save mask, not full tensor
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias, activated = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# Gradient through bias addition
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# Gradient through ReLU (need activation mask)
grad_activated = grad_output.clone()
grad_activated[activated <= 0] = 0
# Gradient through matmul
if ctx.needs_input_grad[0]:
grad_input = grad_activated @ weight.t()
if ctx.needs_input_grad[1]:
grad_weight = input.t() @ grad_activated
return grad_input, grad_weight, grad_bias
# Even more selective: Save boolean mask instead of full tensor
class UltraSelective(Function):
    @staticmethod
    def forward(ctx, input, weight):
        output = torch.relu(input @ weight)
        # Instead of saving full 'output' tensor:
        # ctx.save_for_backward(input, weight, output) # Large memory
        # ✅ Save only a boolean mask (1 byte per element vs 4 bytes for float32)
        relu_mask = output > 0
        # Tensors go through save_for_backward rather than ctx attributes
        ctx.save_for_backward(input, weight, relu_mask)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, relu_mask = ctx.saved_tensors
        # Use mask to apply ReLU gradient (the bool mask promotes to grad_output's dtype)
        grad_weighted = grad_output * relu_mask
        grad_input = grad_weighted @ weight.t() if ctx.needs_input_grad[0] else None
        grad_weight = input.t() @ grad_weighted if ctx.needs_input_grad[1] else None
        return grad_input, grad_weight
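To put a number on the saving, note that a `torch.bool` tensor occupies one byte per element, so replacing a float32 activation with its mask cuts that tensor's footprint to about a quarter. A quick check, with an arbitrary tensor size and assuming the default float32 dtype:

```python
import torch

activations = torch.relu(torch.randn(1024, 1024))   # float32 by default
mask = activations > 0                               # torch.bool

float_bytes = activations.element_size() * activations.numel()
mask_bytes = mask.element_size() * mask.numel()

print(activations.dtype, float_bytes)
print(mask.dtype, mask_bytes)
print(float_bytes // mask_bytes)   # ratio of the two footprints
```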
Pattern 3: Detaching Tensors That Don't Need Gradients
Detach tensors that are used in forward but don't need gradients.
class DetachPattern(Function):
    """
    Detach tensors that don't need gradient computation.
    """
    @staticmethod
    def forward(ctx, input, target, weight):
        """
        Compute loss between input and target.
        Target doesn't need gradients (it's labels).
        """
        # Detach target (doesn't need gradients). forward() already runs
        # without gradient tracking; detach() makes it explicit that the
        # saved target carries no graph
        target = target.detach()
        # Save all tensors in one call
        ctx.save_for_backward(input, weight, target)
        # Compute weighted loss
        loss = ((input - target) ** 2 * weight).mean()
        return loss
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, target = ctx.saved_tensors
        # Compute gradients
        diff = input - target
        # mean() divides by the element count
        # (assumes input, target, and weight share one shape)
        n = diff.numel()
        grad_input = None
        grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * 2 * diff * weight / n
        # No gradient for target
        if ctx.needs_input_grad[2]:
            grad_weight = grad_output * (diff ** 2) / n
        return grad_input, None, grad_weight
# When to detach
def detach_guidelines():
"""
Detach tensors when:
1. They're labels/targets (no gradient needed)
2. They're constants (no gradient needed)
3. They're from non-differentiable sources
4. You explicitly don't want gradients to flow through them
Don't detach when:
1. Gradient is needed for that tensor
2. Gradient will flow through that path
"""
pass
Advanced Patterns
Pattern 1: Double Backward (Second-Order Derivatives)
Support computing gradients of gradients.
class DoubleBackwardFunction(Function):
"""
Function that supports double backward for second-order derivatives.
Example: Hessian computation, meta-learning, some regularization terms.
"""
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
# Forward: y = x^2
return input ** 2
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
# First derivative: dy/dx = 2x
grad_input = 2 * input * grad_output
# For double backward, build grad_input from differentiable tensor ops
# (as above), not from detached values or Python scalars
return grad_input
# Note: jvp() is a separate hook for forward-mode AD. It is not required
# for double backward, so it is left undefined here; an empty stub would
# return None and break forward-mode use.
# Test double backward
def test_double_backward():
"""Test that double backward works."""
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# First forward
y = DoubleBackwardFunction.apply(x)
# First backward: dy/dx
grad_x, = torch.autograd.grad(y.sum(), x, create_graph=True)
# grad_x = 2x = [2, 4, 6]
# Second backward: d(dy/dx)/dx = d(2x)/dx = 2
grad_grad_x, = torch.autograd.grad(grad_x.sum(), x)
# grad_grad_x = [2, 2, 2]
print(f"First derivative: {grad_x}") # [2, 4, 6]
print(f"Second derivative: {grad_grad_x}") # [2, 2, 2]
assert torch.allclose(grad_x, 2 * x)
assert torch.allclose(grad_grad_x, torch.ones_like(x) * 2)
print("✅ Double backward works!")
test_double_backward()
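Beyond the hand-written test, `torch.autograd.gradgradcheck` verifies second-order gradients numerically in the same way `gradcheck` verifies first-order ones. A minimal sketch, where `Square` is a condensed restatement of the function above and the input size is arbitrary:

```python
import torch
from torch.autograd import Function, gradcheck, gradgradcheck

class Square(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Built from differentiable ops, so it can be differentiated again
        return 2 * input * grad_output

x = torch.randn(6, dtype=torch.double, requires_grad=True)
first_order_ok = gradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)
second_order_ok = gradgradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)
print(first_order_ok, second_order_ok)
```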
# Example: Function where double backward matters
class HessianVectorProduct(Function):
    """
    Skeleton for an op intended for Hessian-vector products: H @ v
    where H = ∇²f(x) is the Hessian matrix.
    As written it is an identity placeholder and does not compute H @ v;
    it only shows a backward that can itself be differentiated.
    Used in: second-order optimization, meta-learning.
    """
    @staticmethod
    def forward(ctx, input, vector):
        ctx.save_for_backward(input, vector)
        # Placeholder forward: returns the input unchanged
        return input
    @staticmethod
    def backward(ctx, grad_output):
        input, vector = ctx.saved_tensors  # unused in this placeholder
        # Identity gradient for input
        grad_input = grad_output
        # Because grad_input is built from grad_output without detaching,
        # it can be backpropagated through again
        return grad_input, None
# Using double backward for Hessian
def compute_hessian(f, x):
"""
Compute Hessian of scalar function f at point x.
Uses double backward: H[i,j] = ∂²f/∂x_i∂x_j
"""
# First derivative
grad_x, = torch.autograd.grad(f(x), x, create_graph=True)
# Second derivative (Hessian)
hessian = []
for i in range(x.shape[0]):
grad_grad_x, = torch.autograd.grad(
grad_x[i], x, retain_graph=True
)
hessian.append(grad_grad_x)
return torch.stack(hessian)
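When only the product H @ v is needed, the full Hessian never has to be built: differentiate the scalar (∇f · v) once more. The sketch below uses a hypothetical helper name and a test function chosen so the answer is easy to verify by hand. For f(x) = Σ x³ the Hessian is diag(6x).

```python
import torch

def hessian_vector_product(f, x, v):
    """Return H @ v for scalar-valued f at 1-D x, without materializing H."""
    # First derivative, keeping the graph so it can be differentiated again
    grad_x, = torch.autograd.grad(f(x), x, create_graph=True)
    # d/dx (grad_x . v) = H @ v
    hv, = torch.autograd.grad(grad_x @ v, x)
    return hv

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 0.0, -1.0])

hv = hessian_vector_product(lambda t: (t ** 3).sum(), x, v)
print(hv)   # 6 * x * v = [6, 0, -18]
```

This costs two backward passes regardless of the dimension of x, whereas the row-by-row Hessian above needs one pass per element.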
Pattern 2: Custom Backward Hooks
Modify gradients during backward pass without changing the function.
def gradient_clipping_hook(grad):
"""
Hook to clip gradients to [-1, 1].
Applied to tensor, not Function.
"""
return torch.clamp(grad, -1, 1)
def gradient_noise_hook(grad):
"""Add noise to gradients (for regularization)."""
noise = torch.randn_like(grad) * 0.01
return grad + noise
def gradient_logging_hook(grad):
"""Log gradient statistics."""
print(f"Gradient: mean={grad.mean():.6f}, std={grad.std():.6f}, max={grad.abs().max():.6f}")
return grad # Return unchanged
# Using hooks
def use_hooks():
"""Example of using gradient hooks."""
input = torch.randn(10, 10, requires_grad=True)
weight = torch.randn(10, 10, requires_grad=True)
# Register hooks
input.register_hook(gradient_clipping_hook)
weight.register_hook(gradient_logging_hook)
# Forward and backward
output = MyFunction.apply(input, weight)
loss = output.sum()
loss.backward()
# Hooks are applied during backward
# input.grad is clipped to [-1, 1]
# weight.grad statistics are logged
# Hooks in custom functions
class FunctionWithHook(Function):
@staticmethod
def forward(ctx, input, clip_grad=False):
ctx.clip_grad = clip_grad
ctx.save_for_backward(input)
return input * 2
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output * 2
# Apply custom modification based on context
if ctx.clip_grad:
grad_input = torch.clamp(grad_input, -1, 1)
return grad_input, None
# Removing hooks
def manage_hooks():
"""Add and remove hooks."""
tensor = torch.randn(5, requires_grad=True)
# Add hook (returns handle)
hook_handle = tensor.register_hook(gradient_clipping_hook)
# Use tensor
loss = (tensor ** 2).sum()
loss.backward()
# Hook is applied
# Remove hook
hook_handle.remove()
# Hook no longer applied in subsequent backwards
tensor.grad.zero_()
loss = (tensor ** 2).sum()
loss.backward()
# Hook NOT applied
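A compact, checkable version of the add/remove cycle, using values chosen for illustration so the clipped and unclipped gradients are easy to compare by hand:

```python
import torch

t = torch.tensor([0.1, 2.0, -3.0], requires_grad=True)

# With the hook: the raw gradient 2*t = [0.2, 4, -6] is clamped to [-1, 1]
handle = t.register_hook(lambda g: g.clamp(-1, 1))
(t ** 2).sum().backward()
clipped = t.grad.clone()

# After removal: the raw gradient reaches t.grad unchanged
handle.remove()
t.grad = None
(t ** 2).sum().backward()
unclipped = t.grad.clone()

print(clipped)    # [0.2, 1, -1]
print(unclipped)  # [0.2, 4, -6]
```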
Pattern 3: Custom Gradient for Part of Computation
Stop gradients or customize them for specific operations.
class StopGradient(Function):
"""
Stop gradient flow (like tf.stop_gradient or tensor.detach()).
Forward: pass through
Backward: return None (no gradient)
"""
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx, grad_output):
# Don't pass gradient through
return None
# Usage: Stop gradient flow
def stop_gradient_example():
x = torch.randn(5, requires_grad=True)
# y doesn't get gradients from z
y = StopGradient.apply(x)
z = y ** 2
z.sum().backward()
assert x.grad is None # No gradient flowed to x
class StraightThroughEstimator(Function):
"""
Straight-through estimator for non-differentiable operations.
Forward: non-differentiable operation (e.g., binarization)
Backward: pretend it was identity (pass gradient through)
Used for: quantization, binarization, discrete operations.
"""
@staticmethod
def forward(ctx, input):
# Non-differentiable: binarize to {-1, 1} (torch.sign maps exact zeros to 0)
return torch.sign(input)
@staticmethod
def backward(ctx, grad_output):
# Pretend forward was identity: gradient passes through unchanged
return grad_output
# Usage: Train networks with binary weights
def straight_through_example():
"""
Train a network with binary weights using straight-through estimator.
"""
weight = torch.randn(10, 10, requires_grad=True)
input = torch.randn(32, 10)
# Binarize weight for forward pass
binary_weight = StraightThroughEstimator.apply(weight)
# binary_weight ∈ {-1, 1}
# Use binary weight
output = input @ binary_weight
loss = output.sum()
loss.backward()
# weight.grad exists (even though sign() isn't differentiable)
# Gradient passed through as if sign() was identity
assert weight.grad is not None
class CustomGradientScale(Function):
"""
Scale gradient by a factor without changing forward.
Forward: pass through
Backward: scale gradient by alpha
Used for: gradient reversal layers (adversarial training),
controlling gradient flow in different branches.
"""
@staticmethod
def forward(ctx, input, alpha):
ctx.alpha = alpha
return input
@staticmethod
def backward(ctx, grad_output):
# Scale gradient
return grad_output * ctx.alpha, None
# Usage: Gradient reversal layer
def gradient_reversal_layer(input, alpha=-1.0):
"""
Reverses gradients (multiplies by -1).
Used in domain adaptation for adversarial training.
"""
return CustomGradientScale.apply(input, alpha)
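The effect of gradient reversal is easy to confirm numerically. In this sketch `GradScale` is a condensed stand-in for `CustomGradientScale`, and the tensors are illustrative: the forward values are unchanged while the gradient reaching the input has its sign flipped.

```python
import torch
from torch.autograd import Function

class GradScale(Function):
    @staticmethod
    def forward(ctx, input, alpha):
        ctx.alpha = alpha
        # Return a view of the input (a common idiom for identity forwards)
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for input is scaled; alpha is a non-tensor argument
        return grad_output * ctx.alpha, None

features = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
coeffs = torch.tensor([1.0, 2.0, 3.0])

reversed_features = GradScale.apply(features, -1.0)
loss = (reversed_features * coeffs).sum()
loss.backward()

print(reversed_features)  # same values as features
print(features.grad)      # -coeffs instead of +coeffs
```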
Complete Real-World Examples
Example 1: Custom Swish Activation with Learnable Beta
import torch
from torch.autograd import Function, gradcheck
import torch.nn as nn
class SwishFunction(Function):
"""
Custom Swish activation: f(x) = x * sigmoid(β * x)
where β is a learnable parameter.
Forward: y = x * σ(βx)
Backward: dy/dx = σ(βx) + x * σ(βx) * (1 - σ(βx)) * β
dy/dβ = x² * σ(βx) * (1 - σ(βx))
"""
@staticmethod
def forward(ctx, input, beta):
"""
Args:
input: Input tensor
beta: Learnable scaling parameter (scalar tensor)
Returns:
output: Swish activation output
"""
# Compute sigmoid(beta * input)
sigmoid_beta_input = torch.sigmoid(beta * input)
# Save for backward in a single call, including the sigmoid
# (cheaper than recomputing it in backward)
ctx.save_for_backward(input, beta, sigmoid_beta_input)
# f(x) = x * sigmoid(beta * x)
output = input * sigmoid_beta_input
return output
@staticmethod
def backward(ctx, grad_output):
"""
Compute gradients using chain rule.
"""
input, beta, sigmoid_beta_input = ctx.saved_tensors
grad_input = grad_beta = None
# Gradient w.r.t. input
if ctx.needs_input_grad[0]:
# d/dx[x * σ(βx)] = σ(βx) + x * σ'(βx) * β
# where σ'(z) = σ(z) * (1 - σ(z))
sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input)
grad_input = grad_output * (
sigmoid_beta_input + input * sigmoid_derivative * beta
)
# Gradient w.r.t. beta
if ctx.needs_input_grad[1]:
# d/dβ[x * σ(βx)] = x * σ'(βx) * x = x² * σ(βx) * (1 - σ(βx))
sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input)
grad_beta = grad_output * (input ** 2) * sigmoid_derivative
# Sum over all elements (beta is scalar)
grad_beta = grad_beta.sum()
return grad_input, grad_beta
class Swish(nn.Module):
"""
Swish activation module with learnable beta parameter.
"""
def __init__(self, beta=1.0):
super().__init__()
self.beta = nn.Parameter(torch.tensor(beta))
def forward(self, input):
return SwishFunction.apply(input, self.beta)
# Test with gradcheck
def test_swish():
"""Verify Swish implementation with gradcheck."""
print("Testing Swish activation...")
# Test 1: Basic gradcheck
input = torch.randn(10, 5, dtype=torch.double, requires_grad=True)
beta = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
assert gradcheck(
SwishFunction.apply,
(input, beta),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "Swish gradcheck failed"
print("✅ Basic gradcheck passed")
# Test 2: Different beta values
for beta_val in [0.5, 1.0, 2.0]:
beta = torch.tensor(beta_val, dtype=torch.double, requires_grad=True)
assert gradcheck(
SwishFunction.apply,
(input, beta),
eps=1e-6,
atol=1e-4,
raise_exception=True
), f"Swish gradcheck failed for beta={beta_val}"
print("✅ Multiple beta values passed")
# Test 3: Use in module
module = Swish(beta=1.0)
input_single = torch.randn(32, 128, requires_grad=True)
output = module(input_single)
loss = output.sum()
loss.backward()
assert input_single.grad is not None
assert module.beta.grad is not None
print("✅ Module usage works")
print("\n🎉 Swish activation fully tested!")
test_swish()
# Usage example
def use_swish_in_model():
"""Use Swish in a neural network."""
model = nn.Sequential(
nn.Linear(784, 256),
Swish(beta=1.0), # Learnable beta
nn.Linear(256, 128),
Swish(beta=1.0),
nn.Linear(128, 10)
)
# Train model...
# Beta parameters will be learned along with weights
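Because Swish can also be written with ordinary tensor ops, the hand-derived backward can be cross-checked against what autograd computes for the composite expression. In this sketch `SwishFn` is a condensed restatement of `SwishFunction` above, and `reference_swish` is a helper name introduced here for the comparison.

```python
import torch
from torch.autograd import Function

class SwishFn(Function):
    @staticmethod
    def forward(ctx, input, beta):
        sig = torch.sigmoid(beta * input)
        ctx.save_for_backward(input, beta, sig)
        return input * sig

    @staticmethod
    def backward(ctx, grad_output):
        input, beta, sig = ctx.saved_tensors
        dsig = sig * (1 - sig)
        grad_input = grad_output * (sig + input * dsig * beta)
        grad_beta = (grad_output * input ** 2 * dsig).sum()
        return grad_input, grad_beta

def reference_swish(input, beta):
    # Same math, differentiated by autograd
    return input * torch.sigmoid(beta * input)

x1 = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
b1 = torch.tensor(1.5, dtype=torch.double, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)
b2 = b1.detach().clone().requires_grad_(True)

SwishFn.apply(x1, b1).sum().backward()
reference_swish(x2, b2).sum().backward()

print(torch.allclose(x1.grad, x2.grad), torch.allclose(b1.grad, b2.grad))
```

A reference comparison like this complements gradcheck: it runs in any dtype and catches forward mismatches as well as gradient errors.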
Example 2: Numerically Stable LogSumExp
class LogSumExp(Function):
"""
Numerically stable log-sum-exp operation.
Forward: log(sum(exp(x_i)))
Uses max trick: log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x))))
Backward: softmax(x_i)
"""
@staticmethod
def forward(ctx, input, dim):
"""
Args:
input: Input tensor
dim: Dimension to sum over
Returns:
logsumexp: log(sum(exp(input))) along dim
"""
# Max trick for numerical stability
max_val, _ = input.max(dim=dim, keepdim=True)
input_shifted = input - max_val
# Compute log-sum-exp
sumexp = torch.exp(input_shifted).sum(dim=dim, keepdim=True)
logsumexp = torch.log(sumexp) + max_val
# Save softmax for backward
softmax = torch.exp(input_shifted) / sumexp
ctx.save_for_backward(softmax)
ctx.dim = dim
return logsumexp.squeeze(dim)
@staticmethod
def backward(ctx, grad_output):
"""
Gradient of log-sum-exp is softmax.
"""
softmax, = ctx.saved_tensors
# Expand grad_output to match softmax shape
grad_output_expanded = grad_output.unsqueeze(ctx.dim)
# Gradient: softmax * grad_output
grad_input = softmax * grad_output_expanded
return grad_input, None
# Test
def test_logsumexp():
"""Test LogSumExp implementation."""
print("Testing LogSumExp...")
input = torch.randn(10, 5, dtype=torch.double, requires_grad=True)
dim = 1
# gradcheck
assert gradcheck(
lambda x: LogSumExp.apply(x, dim),
(input,),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "LogSumExp gradcheck failed"
print("✅ LogSumExp gradcheck passed")
# Compare with PyTorch's implementation
custom_result = LogSumExp.apply(input, dim)
torch_result = torch.logsumexp(input, dim=dim)
assert torch.allclose(custom_result, torch_result, atol=1e-5), \
"LogSumExp doesn't match PyTorch"
print("✅ LogSumExp matches PyTorch implementation")
print("\n🎉 LogSumExp fully tested!")
test_logsumexp()
Example 3: Fused Linear + ReLU (Memory Efficient)
class FusedLinearReLU(Function):
    """
    Fused linear + ReLU operation.
    Saves memory by not keeping the pre-activation tensor for backward.
    Forward: ReLU(X @ W + b)
    Memory: Only saves mask (boolean) and input/weights, not intermediate
    """
    @staticmethod
    def forward(ctx, input, weight, bias):
        """
        Args:
            input: (batch, in_features)
            weight: (out_features, in_features)
            bias: (out_features,) or None
        Returns:
            output: (batch, out_features)
        """
        # Compute linear
        linear_output = input.mm(weight.t())
        if bias is not None:
            linear_output += bias
        # Apply ReLU and keep only the mask (not the full tensor!)
        output = torch.relu(linear_output)
        relu_mask = linear_output > 0  # Boolean (1 byte per element)
        # Save input, weight, bias, and mask in one call
        # NOT saving linear_output (saves memory)
        ctx.save_for_backward(input, weight, bias, relu_mask)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward through ReLU and linear.
        """
        input, weight, bias, relu_mask = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Gradient through ReLU (the bool mask promotes to grad_output's dtype)
        grad_linear = grad_output * relu_mask
        # Gradient through linear
        if ctx.needs_input_grad[0]:
            grad_input = grad_linear.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_linear.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_linear.sum(0)
        return grad_input, grad_weight, grad_bias
# Test
def test_fused_linear_relu():
"""Test fused operation."""
print("Testing FusedLinearReLU...")
input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
weight = torch.randn(15, 10, dtype=torch.double, requires_grad=True)
bias = torch.randn(15, dtype=torch.double, requires_grad=True)
# gradcheck
assert gradcheck(
FusedLinearReLU.apply,
(input, weight, bias),
eps=1e-6,
atol=1e-4,
raise_exception=True
), "FusedLinearReLU gradcheck failed"
print("✅ FusedLinearReLU gradcheck passed")
# Compare with separate operations
fused_output = FusedLinearReLU.apply(input, weight, bias)
separate_output = torch.relu(input.mm(weight.t()) + bias)
assert torch.allclose(fused_output, separate_output, atol=1e-5), \
"Fused output doesn't match separate operations"
print("✅ Fused matches separate operations")
print("\n🎉 FusedLinearReLU fully tested!")
test_fused_linear_relu()
Debugging Custom Functions
Systematic Debugging Workflow
When your custom function has bugs, follow this workflow:
def debug_custom_function():
"""
Step-by-step debugging of custom autograd functions.
"""
print("=== DEBUGGING CUSTOM AUTOGRAD FUNCTION ===\n")
# Step 1: Verify forward pass works
print("Step 1: Testing forward pass...")
try:
input = torch.randn(5, 3, dtype=torch.double)
weight = torch.randn(4, 3, dtype=torch.double)
output = MyCustomFunction.apply(input, weight)
print(f"✅ Forward pass works")
print(f" Input shape: {input.shape}")
print(f" Weight shape: {weight.shape}")
print(f" Output shape: {output.shape}")
print(f" Output contains NaN: {torch.isnan(output).any()}")
print(f" Output contains Inf: {torch.isinf(output).any()}")
# Check expected shape
expected_shape = (5, 4) # Based on your operation
assert output.shape == expected_shape, \
f"Wrong output shape: {output.shape} != {expected_shape}"
print(f" Output shape correct: {expected_shape}")
except Exception as e:
print(f"❌ Forward pass failed: {e}")
return
print()
# Step 2: Verify backward runs without error
print("Step 2: Testing backward pass...")
try:
input = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
output = MyCustomFunction.apply(input, weight)
loss = output.sum()
loss.backward()
print(f"✅ Backward pass runs")
print(f" Input grad shape: {input.grad.shape}")
print(f" Weight grad shape: {weight.grad.shape}")
# Check gradient shapes
assert input.grad.shape == input.shape, \
f"Input gradient shape mismatch: {input.grad.shape} != {input.shape}"
assert weight.grad.shape == weight.shape, \
f"Weight gradient shape mismatch: {weight.grad.shape} != {weight.shape}"
print(f" Gradient shapes correct")
# Check for NaN/Inf
assert not torch.isnan(input.grad).any(), "Input gradient contains NaN"
assert not torch.isnan(weight.grad).any(), "Weight gradient contains NaN"
assert not torch.isinf(input.grad).any(), "Input gradient contains Inf"
assert not torch.isinf(weight.grad).any(), "Weight gradient contains Inf"
print(f" Gradients are finite")
except Exception as e:
print(f"❌ Backward pass failed: {e}")
return
print()
# Step 3: Check gradient magnitudes
print("Step 3: Checking gradient magnitudes...")
print(f" Input grad - mean: {input.grad.mean():.6f}, std: {input.grad.std():.6f}, max: {input.grad.abs().max():.6f}")
print(f" Weight grad - mean: {weight.grad.mean():.6f}, std: {weight.grad.std():.6f}, max: {weight.grad.abs().max():.6f}")
# Reasonable gradient magnitudes (problem-dependent)
if input.grad.abs().max() > 1e6:
print(f" ⚠️ Input gradient very large (may indicate bug)")
if input.grad.abs().max() < 1e-6:
print(f" ⚠️ Input gradient very small (may indicate vanishing gradient)")
print()
# Step 4: Manual numerical gradient check for one element
print("Step 4: Manual numerical gradient check...")
input_test = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
# Analytical gradient
output = MyCustomFunction.apply(input_test, weight_test)
loss = output.sum()
loss.backward()
analytical_grad = weight_test.grad[0, 0].item()
# Numerical gradient
eps = 1e-6
weight_plus = weight_test.clone().detach()
weight_plus[0, 0] += eps
output_plus = MyCustomFunction.apply(input_test, weight_plus)
loss_plus = output_plus.sum()
weight_minus = weight_test.clone().detach()
weight_minus[0, 0] -= eps
output_minus = MyCustomFunction.apply(input_test, weight_minus)
loss_minus = output_minus.sum()
numerical_grad = ((loss_plus - loss_minus) / (2 * eps)).item()
diff = abs(analytical_grad - numerical_grad)
print(f" Analytical gradient: {analytical_grad:.10f}")
print(f" Numerical gradient: {numerical_grad:.10f}")
print(f" Absolute difference: {diff:.10e}")
if diff < 1e-4:
print(f" ✅ Small difference - gradient likely correct")
else:
print(f" ❌ Large difference - gradient likely WRONG")
print()
# Step 5: Full gradcheck
print("Step 5: Running full gradcheck...")
try:
input_check = torch.randn(10, 5, dtype=torch.double, requires_grad=True)
weight_check = torch.randn(8, 5, dtype=torch.double, requires_grad=True)
result = gradcheck(
MyCustomFunction.apply,
(input_check, weight_check),
eps=1e-6,
atol=1e-4,
raise_exception=True
)
print(f" ✅ gradcheck PASSED!")
print(f"\n🎉 All checks passed! Function is correct.")
except RuntimeError as e:
print(f" ❌ gradcheck FAILED")
print(f" Error: {e}")
print(f"\n Debug hints:")
print(f" - Check gradient computation formulas")
print(f" - Verify all transposes are correct")
print(f" - Ensure shapes match everywhere")
print(f" - Check for in-place operations")
print(f" - Verify saved tensors are correct")
# Run debugging
debug_custom_function()
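The workflow above assumes a `MyCustomFunction` that maps an input of shape (5, 3) and a weight of shape (4, 3) to an output of shape (5, 4). If you want to run it as-is, a hypothetical stand-in with those shapes could look like the following bias-free linear map. It is an assumption made for illustration and not necessarily the function you are debugging.

```python
import torch
from torch.autograd import Function, gradcheck

class MyCustomFunction(Function):
    """Hypothetical stand-in: output = input @ weight.T"""
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight        # (batch, in_features)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ input    # (out_features, in_features)
        return grad_input, grad_weight

input = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

output = MyCustomFunction.apply(input, weight)
passed = gradcheck(MyCustomFunction.apply, (input, weight), eps=1e-6, atol=1e-4)
print(output.shape, passed)
```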
Common Error Messages and Solutions
"""
ERROR 1: "one of the variables needed for gradient computation has been modified by an inplace operation"
CAUSE: In-place operation in forward or backward
SOLUTION: Replace in-place ops with non-in-place versions
- Replace: tensor[mask] = value
- With: tensor = tensor.clone(); tensor[mask] = value
- Or use: tensor * mask instead of masking
ERROR 2: "grad can be implicitly created only for scalar outputs"
CAUSE: Calling .backward() on non-scalar tensor without grad_output
SOLUTION: Either sum to scalar or provide grad_output
- loss = output.sum(); loss.backward()
- Or: output.backward(torch.ones_like(output))
ERROR 3: "function MyFunctionBackward returned an incorrect number of gradients (expected X, got Y)"
CAUSE: backward() returns wrong number of gradients
SOLUTION: Return one gradient per forward() argument (excluding ctx)
- forward(ctx, a, b, c) → backward must return (grad_a, grad_b, grad_c)
- Use None for arguments that don't need gradients
ERROR 4: "Sizes of tensors must match except in dimension X"
CAUSE: Shape mismatch in gradient computation
SOLUTION: Ensure grad_input.shape == input.shape
- Print shapes to debug: print(f"grad shape: {grad.shape}, input shape: {input.shape}")
- Handle broadcasting by summing over broadcasted dimensions
ERROR 5: "RuntimeError: Function returned an invalid gradient at index X - got ... but expected shape ..."
CAUSE: Gradient shape doesn't match input shape
SOLUTION: Verify gradient shape matches input exactly
- assert grad_input.shape == input.shape
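- If forward broadcast the input, reduce with .sum() over the broadcast dimensions; use .view()/.reshape() only when the element count already matches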
ERROR 6: "gradcheck failed"
CAUSE: Analytical gradient ≠ numerical gradient
SOLUTION: Debug gradient computation
- Check math formulas (derivatives)
- Verify transposes in matrix operations
- Test manually with small tensors
- Run debug_custom_function() above
ERROR 7: "AttributeError: 'Context' object has no attribute 'saved_tensors'"
CAUSE: Accessing saved_tensors in forward (only available in backward)
SOLUTION: Only access saved_tensors in backward()
- forward: ctx.save_for_backward(...)
- backward: ctx.saved_tensors
ERROR 8: "TypeError: save_for_backward can only save variables, but argument X is of type Y"
CAUSE: Passing non-tensors to save_for_backward()
SOLUTION: Only save tensors with save_for_backward(), use attributes for others
- ctx.save_for_backward(tensor1, tensor2) # Tensors only
- ctx.some_value = non_tensor_data # Non-tensors as attributes
"""
When NOT to Use Custom Functions
Recognize when custom functions are unnecessary:
import math

import torch
from torch.autograd import Function

# DON'T: Wrapping simple PyTorch operations
class UnnecessaryAdd(Function):  # ❌ Pointless
    @staticmethod
    def forward(ctx, a, b):
        return a + b

    @staticmethod
    def backward(ctx, grad):
        return grad, grad

# DO: Use PyTorch directly
output = a + b  # ✅ Autograd handles this

# DON'T: Reimplementing existing activations
class UnnecessaryReLU(Function):  # ❌ Use torch.relu
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * (input > 0).float()

# DO: Use built-in
output = torch.relu(input)  # ✅ Optimized C++ implementation

# DON'T: Operations that compose from standard ops
class UnnecessaryGELU(Function):  # ❌ Can compose from existing ops
    @staticmethod
    def forward(ctx, input):
        # GELU = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        # This is composed entirely from standard ops
        ctx.save_for_backward(input)
        return 0.5 * input * (1 + torch.tanh(...))

    @staticmethod
    def backward(ctx, grad_output):
        # Complex gradient computation
        ...

# DO: Let autograd handle composition
def gelu(input):  # ✅ Autograd computes gradients automatically
    return 0.5 * input * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (input + 0.044715 * input ** 3)
    ))

# Or even better: use built-in
output = torch.nn.functional.gelu(input)  # ✅ Most efficient

# WHEN TO USE: True custom operations
class ClippedGradientLinear(Function):  # ✅ Custom gradient behavior
    """
    Linear operation with clipped gradients.
    Can't compose from standard forward ops - requires custom backward.
    """
    @staticmethod
    def forward(ctx, input, weight):
        # Expects 2D input (batch, in_features) and weight (out_features, in_features)
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Custom: clip the gradient flowing back to input
        grad_input = torch.clamp(grad_output @ weight, -1, 1)
        # The weight gradient is the ordinary (unclipped) one
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight
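To see what the custom backward in ClippedGradientLinear actually changes, the sketch below repeats the class and compares its gradients with plain autograd on the same computation. The tensor sizes, the seed, and the factor of 100 are arbitrary choices made for illustration only.

```python
import torch
from torch.autograd import Function


class ClippedGradientLinear(Function):
    """Copy of the class above, repeated so this snippet runs on its own."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = torch.clamp(grad_output @ weight, -1, 1)
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight


torch.manual_seed(0)
x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)

# Scale the loss so the unclipped input gradient would be large.
(ClippedGradientLinear.apply(x, w) * 100.0).sum().backward()

# Reference: the same computation with ordinary autograd.
x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
((x_ref @ w_ref.t()) * 100.0).sum().backward()

print("max |x.grad| (clipped):  ", x.grad.abs().max().item())
print("max |x.grad| (reference):", x_ref.grad.abs().max().item())
print("weight grads match:", torch.allclose(w.grad, w_ref.grad))
```

Because this function returns a deliberately modified gradient for input, gradcheck would typically report a mismatch whenever clipping is active. For functions like this, verify the unclipped path (here, the weight gradient) against a reference and check the clipping bound separately.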
gradcheck Time Investment
Reality check on "no time for gradcheck":
When under deadline pressure, this rationalization is powerful but wrong:
# Time cost analysis
gradcheck_time = "< 1 second" # Typical operation
debugging_time = "hours to days" # Finding gradient bugs without gradcheck
# ROI calculation
time_investment = 1 # second
time_saved = 3600 * 24 # potentially days
roi = time_saved / time_investment # 86,400x return!
Under deadline pressure, correct workflow:
- Implement Function correctly (5-10 minutes)
- Run gradcheck (1 second)
- Deploy with confidence
NOT:
- Skip verification (saves 1 second)
- Deploy immediately
- Spend days debugging gradient bugs in production (costs hours/days)
Time spent on gradcheck: <1 second (negligible)
Time saved from catching bugs early: hours to days (enormous)
The math is clear: Always run gradcheck, even under extreme time pressure.
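If you want to confirm the cost on your own machine, a rough measurement looks like the sketch below. The Cube function is a made-up stand-in for your own Function, and the timing you see will depend on tensor size and hardware, so treat the printed number as indicative only.

```python
import time

import torch
from torch.autograd import Function, gradcheck


class Cube(Function):
    """Illustrative op (not from the text above): y = x ** 3."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2


# gradcheck needs float64 inputs to keep finite-difference noise low.
x = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)

start = time.perf_counter()
passed = gradcheck(Cube.apply, (x,), eps=1e-6, atol=1e-4)
elapsed = time.perf_counter() - start

print(f"gradcheck passed={passed} in {elapsed:.3f}s")
```

Keep the test tensors small: gradcheck perturbs inputs element by element, so its cost grows with the number of elements.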
Multiple Custom Functions Workflow
When implementing multiple functions:
✅ CORRECT: One-at-a-time
- Implement Function 1
- Test Function 1 with gradcheck
- Verify Function 1 passes
- Move to Function 2
- Repeat for each function
❌ WRONG: Batch approach
- Implement all functions
- Test all together
- Debug mess of overlapping bugs
- Waste hours figuring out which function is broken
- Fix bugs one at a time anyway (should have started here)
Why one-at-a-time wins:
- Bugs caught immediately (when you wrote the code, easy to debug)
- Know each function works before building on it
- Same total time, but bugs isolated (no "which function broke?" confusion)
- Build confidence incrementally
- Can use earlier functions while implementing later ones
Example:
# Implementing 5 activations
# ✅ CORRECT
SwishBeta() → test → ✅ → Mish() → test → ✅ → GELU() → test → ✅ ...
# Each bug found immediately after writing that function
# ❌ WRONG
SwishBeta() + Mish() + GELU() + ELU() + SELU() → test all
# Bug found in one, but which? Have to debug all to find it
# 5x the debugging complexity
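The one-at-a-time workflow can be sketched as a loop that stops at the first failing function. Both classes below are hypothetical: Swish is a correct implementation, and BuggySwish drops one term of the derivative on purpose so you can see the failure being pinned to a single function.

```python
import torch
from torch.autograd import Function, gradcheck


class Swish(Function):
    """y = x * sigmoid(x) with the full derivative."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        sig = torch.sigmoid(x)
        return grad_output * (sig + x * sig * (1 - sig))


class BuggySwish(Function):
    """Same forward, but backward forgets the x * sig * (1 - sig) term."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * torch.sigmoid(x)  # deliberate bug


def check(fn):
    """Run gradcheck on one Function and report the result."""
    x = torch.randn(6, dtype=torch.float64, requires_grad=True)
    ok = gradcheck(fn.apply, (x,), raise_exception=False)
    print(f"{fn.__name__}: {'pass' if ok else 'FAIL'}")
    return ok


torch.manual_seed(0)
results = {}
for fn in (Swish, BuggySwish):
    results[fn.__name__] = check(fn)
    if not results[fn.__name__]:
        break  # fix this function before writing the next one
```

Passing raise_exception=False makes gradcheck return False instead of raising, which is convenient for a loop like this. While debugging a single function, the default (raising) gives a more detailed message.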
Handling Approximate Gradients
Special case: External libraries with approximate gradients (finite differences, Monte Carlo estimates).
When wrapping external code with approximate gradients:
Workflow for Approximate Gradients
import torch
from torch.autograd import Function


class ApproximateGradientWrapper(Function):
    """
    Wrap external library with approximate gradients.

    Key insights:
    1. Standard gradcheck will likely fail at default tolerances (approximate ≠ analytical)
    2. But can still verify wrapper implementation
    3. Must quantify gradient quality
    4. Must assess if quality is acceptable
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # external_library_forward is a placeholder for the wrapped library call
        return external_library_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Library provides approximate gradient (placeholder call)
        grad_input = external_library_gradient(input, grad_output)
        return grad_input


# Standard gradcheck is of limited use here, but we CAN verify:
def test_approximate_gradient_wrapper():
    """Verification workflow for approximate gradients."""
    # 1. Verify wrapper mechanics (forward/backward run)
    input = torch.randn(5, requires_grad=True)
    output = ApproximateGradientWrapper.apply(input)
    output.sum().backward()
    assert input.grad is not None
    print("✅ Wrapper mechanics work")

    # 2. Quantify gradient quality (compare with numerical)
    # float64 keeps finite-difference round-off small at eps = 1e-6
    # (assumes the wrapped library accepts float64 input)
    input_test = torch.randn(3, dtype=torch.float64, requires_grad=True)
    output = ApproximateGradientWrapper.apply(input_test)
    output.sum().backward()
    library_gradient = input_test.grad.clone()

    # Compute our own numerical gradient
    # (compute_numerical_gradient is a placeholder finite-difference helper)
    eps = 1e-6
    numerical_gradient = compute_numerical_gradient(input_test, eps)

    # Measure error
    error = (library_gradient - numerical_gradient).abs().max().item()
    print(f"Gradient error: {error:.6e}")

    # 3. Assess acceptability (thresholds are use-case dependent)
    if error < 1e-3:
        print("✅ Approximate gradient high quality (good enough)")
    elif error < 1e-2:
        print("⚠️ Approximate gradient has noticeable error (may affect optimization)")
    else:
        print("❌ Approximate gradient poor quality (likely to cause issues)")
        raise ValueError("Gradient quality unacceptable")

    # 4. Document in code
    print("\n📝 Documentation:")
    print(f" - Gradients are approximate (error: {error:.6e})")
    print(" - Standard gradcheck may fail at default tolerances (expected)")
    print(" - Wrapper implementation verified correct")
    print(" - Gradient quality measured and acceptable")
    return error


# Run verification
gradient_error = test_approximate_gradient_wrapper()
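The workflow above calls a helper named compute_numerical_gradient without defining it. One possible shape for such a helper is a central-difference loop, sketched below. Note the assumptions: this version takes the function under test as an explicit first argument (the call above does not), differentiates fn(x).sum(), and is demonstrated on torch.sin rather than on a real external library.

```python
import torch


def compute_numerical_gradient(fn, x, eps=1e-6):
    """Central-difference gradient of fn(x).sum() with respect to x.

    Assumed helper, not a PyTorch API. Use float64 input; float32
    round-off dominates at eps this small.
    """
    x = x.detach().clone()          # work on a private copy, outside autograd
    grad = torch.zeros_like(x)
    flat_x = x.view(-1)             # views share storage with x and grad
    flat_g = grad.view(-1)
    for i in range(flat_x.numel()):
        original = flat_x[i].item()
        flat_x[i] = original + eps
        f_plus = fn(x).sum().item()
        flat_x[i] = original - eps
        f_minus = fn(x).sum().item()
        flat_x[i] = original        # restore before moving on
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return grad


# Sanity check on a function with a known derivative: d/dx sin(x) = cos(x).
x = torch.randn(3, dtype=torch.float64)
numeric = compute_numerical_gradient(torch.sin, x)
max_error = (numeric - torch.cos(x)).abs().max().item()
print(f"max error vs cos(x): {max_error:.2e}")
```

The loop evaluates fn twice per element, so it only suits small test tensors. For an external library, pass a callable that runs the library's forward on a tensor and returns a tensor.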
Key Points for Approximate Gradients
- Standard gradcheck will likely fail at default tolerances - This is expected (approximate ≠ analytical)
- Test wrapper implementation - Verify mechanics (forward/backward run)
- Quantify gradient quality - Measure error vs numerical gradient
- Assess acceptability - Is error tolerable for your use case?
- Document limitations - Record gradient error magnitude in code/docs
Don't skip verification - adapt it to the approximate case.
"Good enough" requires evidence - Measure error, don't assume.
Rationalization Resistance
| Rationalization | Reality | Counter-Response |
|---|---|---|
| "Autograd will figure it out" | Only for standard ops; custom ops need Function | Use Function for non-standard operations. Autograd needs explicit backward implementation. |
| "Gradient looks mathematically correct" | Implementation bugs invisible without testing | Always run gradcheck. Math correctness ≠ implementation correctness. |
| "gradcheck is slow, skip for speed" | Catches bugs early; debugging later costs more | gradcheck is fast (<1s). Finding gradient bugs without it takes hours/days. |
| "No time for gradcheck, deadline NOW" | gradcheck takes <1s, debugging takes hours | 1 second now saves hours of production debugging. Always run gradcheck. |
| "Too complex for me" | Pattern is standardized; template works | Follow template. Thousands of successful implementations exist. |
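| "In-place is more efficient" | Undeclared in-place edits corrupt saved tensors; cause errors or wrong gradients | Avoid in-place ops in custom functions. If an input must be modified in place, declare it with ctx.mark_dirty(). Memory savings usually small, bugs catastrophic. |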
| "Shape will probably work out" | Must match exactly; no flexibility | Gradient shape must equal input shape exactly. Verify with assertions. |
| "ctx details don't matter much" | Incorrect usage breaks everything | ctx.save_for_backward() is mandatory for tensors. Attributes break memory tracking. |
| "My manual test is good enough" | Misses edge cases gradcheck catches | Manual tests catch obvious bugs. gradcheck catches subtle numerical errors. |
| "Batch test all functions together" | Overlapping bugs hard to debug | Test one at a time. Bugs isolated immediately, same total time. |
| "Don't need needs_input_grad check" | Wastes computation; slower training | Always check needs_input_grad. Free optimization, no downside. |
| "Approximate gradients, can't verify" | Can verify wrapper and measure quality | Adapt verification. Test wrapper mechanics, quantify error, assess acceptability. |
| "Second-order derivatives too advanced" | Same pattern as first-order | Not advanced. Same template, test with gradgradcheck. Accessible to all. |
| "Can skip double backward support" | Breaks higher-order derivatives if needed | If you might need Hessian/meta-learning, support double backward from start. |
| "Detach doesn't matter here" | Controls gradient flow; critical | Understand when to detach. Impacts what gets gradients. |
| "I'll verify gradients during training" | Training metrics hide gradient bugs | Verify before training. Gradient bugs cause subtle issues (slow convergence, wrong behavior). |
| "Test in production, faster iteration" | Production debugging catastrophic | Test before deployment. Production gradient bugs cause model failure. |
Red Flags Checklist
Stop and verify if you see:
- ⚠️ Not using torch.autograd.Function for custom operations
  - Writing normal functions for truly custom ops (external code, novel math)
- ⚠️ No gradcheck before using the function
  - Skipping numerical verification
  - "Testing during training" instead of proper gradcheck
- ⚠️ In-place operations in forward or backward
  - tensor[mask] = value, tensor += other, tensor.mul_(other)
  - Modifying grad_output in backward
- ⚠️ Wrong number of return values from backward
  - Not returning gradient for each forward() input
  - Missing None for non-tensor arguments
- ⚠️ Not using ctx.save_for_backward() for tensors
  - Saving tensors as ctx.tensor = tensor instead
  - Saving non-tensors with save_for_backward()
- ⚠️ Gradient shape doesn't match input shape
  - Not verifying grad_input.shape == input.shape
  - Forgetting to expand/sum for broadcasting
- ⚠️ Missing needs_input_grad checks
  - Always computing all gradients even when not needed
  - Not checking ctx.needs_input_grad[i] before expensive computation
- ⚠️ Using .data instead of .detach()
  - Accessing .data attribute
  - Not understanding difference between .data and .detach()
- ⚠️ Accessing ctx.saved_tensors in forward
  - Trying to use saved tensors before backward
  - Not understanding ctx lifecycle
- ⚠️ Modifying saved tensors in backward
  - In-place operations on tensors from ctx.saved_tensors
  - Breaking gradient graph
- ⚠️ Ignoring gradcheck failures
  - "Gradient close enough"
  - Not investigating why gradcheck failed
- ⚠️ Using custom Function for standard operations
  - Reimplementing built-in operations unnecessarily
  - Not checking if PyTorch already provides it
- ⚠️ Batching implementation without incremental testing
  - Implementing multiple functions before testing any
  - "Will test them all together" approach
- ⚠️ Skipping gradcheck under time pressure
  - "Deadline is tight, verify later"
  - "No time for gradcheck"
- ⚠️ Assuming approximate gradients can't be verified
  - "Library provides gradients, can't test"
  - Not measuring gradient quality
- ⚠️ Avoiding second-order derivatives due to perceived complexity
  - "Too advanced for me"
  - Not attempting gradgradcheck
- ⚠️ Deploying to production without verification
  - "Test with real data"
  - Skipping numerical verification
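The .data red flag is easier to appreciate with a concrete case. The sketch below uses a built-in op (sigmoid, whose backward reuses its own output) rather than a custom Function, purely to keep it short. It modifies that output in place through .detach() and then through .data.

```python
import torch

# .detach(): the result shares storage AND version tracking with `out`,
# so autograd notices the in-place edit and refuses to run backward.
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()
out.detach().zero_()
detach_error = None
try:
    out.sum().backward()
except RuntimeError as err:
    detach_error = err
print("detach -> backward raised:", detach_error is not None)

# .data: shares storage but bypasses version tracking, so backward
# runs on the corrupted values without any complaint.
b = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out_b = b.sigmoid()
out_b.data.zero_()
out_b.sum().backward()
print("data -> b.grad:", b.grad)  # all zeros, not the true sigmoid derivative
```

The .detach() path fails loudly, which is what you want. The .data path produces a gradient that is silently wrong, which is why it belongs on this list.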
Summary Checklist
Before deploying a custom autograd function:
- Verified custom Function is actually needed (can't compose from standard ops)
- Implemented forward() correctly (saves needed tensors/data)
- Used ctx.save_for_backward() for ALL tensors
- Saved non-tensor data as ctx attributes
- Implemented backward() with correct gradient formulas
- Verified gradient shape matches input shape exactly
- Returned gradient for EVERY forward() input (None if not needed)
- Added needs_input_grad checks for expensive computations
- Avoided ALL in-place operations
- Ran gradcheck and verified it PASSES
- Tested with different input shapes
- Tested with optional parameters (if any)
- Verified no NaN/Inf in outputs or gradients
- Checked gradient magnitudes are reasonable
- Tested double backward if needed
- Added documentation explaining when to use this function
- Considered memory efficiency (detach, selective saving)
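Several of the checklist items (gradcheck, different input shapes, no NaN/Inf, shape match, double backward) can be bundled into one small helper. The sketch below is one way to do that. The helper name predeploy_checks and the SinFn example are made up for illustration, and the helper assumes a function of a single tensor argument.

```python
import torch
from torch.autograd import Function, gradcheck, gradgradcheck


class SinFn(Function):
    """Illustrative op: y = sin(x). Backward uses only differentiable
    torch ops on the saved input, so double backward works."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * torch.cos(x)


def predeploy_checks(fn, shapes):
    """Run a subset of the checklist for a single-input function."""
    for shape in shapes:
        x = torch.randn(*shape, dtype=torch.float64, requires_grad=True)
        assert gradcheck(fn, (x,)), f"gradcheck failed for shape {shape}"
        assert gradgradcheck(fn, (x,)), f"gradgradcheck failed for shape {shape}"
        out = fn(x)
        (grad,) = torch.autograd.grad(out.sum(), x)
        assert grad.shape == x.shape, "gradient shape must match input shape"
        assert torch.isfinite(out).all(), "non-finite values in output"
        assert torch.isfinite(grad).all(), "non-finite values in gradient"
    return True


all_passed = predeploy_checks(SinFn.apply, [(3,), (2, 4), (2, 3, 2)])
print("all checks passed:", all_passed)
```

Drop the gradgradcheck line if your function is not meant to support double backward. Random inputs will not cover edge cases such as zeros or very large values, so add those by hand where they matter.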
Final Notes
Custom autograd functions are powerful but require precision:
- Use them when truly needed (custom ops, external code, memory optimization, custom gradients)
- Follow the template pattern (don't reinvent)
- Always run gradcheck (non-negotiable)
- Understand ctx rules (save_for_backward for tensors, attributes for non-tensors)
- Verify shapes (gradient must match input)
- Avoid in-place operations (they break autograd)
- Check needs_input_grad (optimization)
- Test thoroughly before using in training
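The ctx, shape, and needs_input_grad rules in this list fit together in a few lines. The sketch below uses a hypothetical op, ScaledBroadcastMul, chosen only because it has a broadcast tensor argument and a non-tensor argument. It is an illustration of the rules, not a recommended op, since this particular computation composes from standard ops.

```python
import torch
from torch.autograd import Function, gradcheck


class ScaledBroadcastMul(Function):
    """Hypothetical op: out = a * b * scale, with b broadcast against a."""

    @staticmethod
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)   # tensors go through save_for_backward
        ctx.scale = scale             # plain Python float stored as an attribute
        return a * b * scale

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        if ctx.needs_input_grad[0]:
            # sum_to_size reduces over any dimensions that were broadcast
            grad_a = (grad_output * b * ctx.scale).sum_to_size(a.shape)
        if ctx.needs_input_grad[1]:
            grad_b = (grad_output * a * ctx.scale).sum_to_size(b.shape)
        # Three forward arguments -> three return values; None for the float.
        return grad_a, grad_b, None


a = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, dtype=torch.float64, requires_grad=True)
passed = gradcheck(lambda x, y: ScaledBroadcastMul.apply(x, y, 2.0), (a, b))
print("gradcheck passed:", passed)
```

Here b has shape (3,) while grad_output has shape (4, 3), so the gradient for b has to be summed over the broadcast dimension before it is returned.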
The gradient implementation is as important as the forward computation. Bugs in backward() are silent and catastrophic - they cause models to learn wrong things. gradcheck is your safety net; never skip it.
When in doubt: implement, run gradcheck, debug until it passes, then use with confidence.