18 KiB
PyTorch Geometric Transforms Reference
This document provides a comprehensive reference of all transforms available in torch_geometric.transforms.
Overview
Transforms modify Data or HeteroData objects before or during training. Apply them via:
# During dataset loading
dataset = MyDataset(root='/tmp', transform=MyTransform())
# Apply to individual data
transform = MyTransform()
data = transform(data)
# Compose multiple transforms
from torch_geometric.transforms import Compose
transform = Compose([Transform1(), Transform2(), Transform3()])
General Transforms
NormalizeFeatures
Purpose: Row-normalizes node features to sum to 1 Use case: Feature scaling, probability-like features
from torch_geometric.transforms import NormalizeFeatures
transform = NormalizeFeatures()
ToDevice
Purpose: Transfers data to specified device (CPU/GPU) Use case: GPU training, device management
from torch_geometric.transforms import ToDevice
transform = ToDevice('cuda')
RandomNodeSplit
Purpose: Creates train/val/test node masks
Use case: Node classification splits
Parameters: split='train_rest', num_splits, num_val, num_test
from torch_geometric.transforms import RandomNodeSplit
transform = RandomNodeSplit(num_val=0.1, num_test=0.2)
RandomLinkSplit
Purpose: Creates train/val/test edge splits
Use case: Link prediction
Parameters: num_val, num_test, is_undirected, split_labels
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(num_val=0.1, num_test=0.2)
IndexToMask
Purpose: Converts indices to boolean masks Use case: Data preprocessing
from torch_geometric.transforms import IndexToMask
transform = IndexToMask()
MaskToIndex
Purpose: Converts boolean masks to indices Use case: Data preprocessing
from torch_geometric.transforms import MaskToIndex
transform = MaskToIndex()
FixedPoints
Purpose: Samples a fixed number of points
Use case: Point cloud subsampling
Parameters: num, replace, allow_duplicates
from torch_geometric.transforms import FixedPoints
transform = FixedPoints(1024)
ToDense
Purpose: Converts to dense adjacency matrices Use case: Small graphs, dense operations
from torch_geometric.transforms import ToDense
transform = ToDense(num_nodes=100)
ToSparseTensor
Purpose: Converts edge_index to SparseTensor
Use case: Efficient sparse operations
Parameters: remove_edge_index, fill_cache
from torch_geometric.transforms import ToSparseTensor
transform = ToSparseTensor()
Graph Structure Transforms
ToUndirected
Purpose: Converts directed graph to undirected
Use case: Undirected graph algorithms
Parameters: reduce='add' (how to handle duplicate edges)
from torch_geometric.transforms import ToUndirected
transform = ToUndirected()
AddSelfLoops
Purpose: Adds self-loops to all nodes
Use case: GCN-style convolutions
Parameters: fill_value (edge attribute for self-loops)
from torch_geometric.transforms import AddSelfLoops
transform = AddSelfLoops()
RemoveSelfLoops
Purpose: Removes all self-loops Use case: Cleaning graph structure
from torch_geometric.transforms import RemoveSelfLoops
transform = RemoveSelfLoops()
RemoveIsolatedNodes
Purpose: Removes nodes without edges Use case: Graph cleaning
from torch_geometric.transforms import RemoveIsolatedNodes
transform = RemoveIsolatedNodes()
RemoveDuplicatedEdges
Purpose: Removes duplicate edges Use case: Graph cleaning
from torch_geometric.transforms import RemoveDuplicatedEdges
transform = RemoveDuplicatedEdges()
LargestConnectedComponents
Purpose: Keeps only the largest connected component
Use case: Focus on main graph structure
Parameters: num_components (how many components to keep)
from torch_geometric.transforms import LargestConnectedComponents
transform = LargestConnectedComponents(num_components=1)
KNNGraph
Purpose: Creates edges based on k-nearest neighbors
Use case: Point clouds, spatial data
Parameters: k, loop, force_undirected, flow
from torch_geometric.transforms import KNNGraph
transform = KNNGraph(k=6)
RadiusGraph
Purpose: Creates edges within a radius
Use case: Point clouds, spatial data
Parameters: r, loop, max_num_neighbors, flow
from torch_geometric.transforms import RadiusGraph
transform = RadiusGraph(r=0.1)
Delaunay
Purpose: Computes Delaunay triangulation Use case: 2D/3D spatial graphs
from torch_geometric.transforms import Delaunay
transform = Delaunay()
FaceToEdge
Purpose: Converts mesh faces to edges Use case: Mesh processing
from torch_geometric.transforms import FaceToEdge
transform = FaceToEdge()
LineGraph
Purpose: Converts graph to its line graph
Use case: Edge-centric analysis
Parameters: force_directed
from torch_geometric.transforms import LineGraph
transform = LineGraph()
GDC
Purpose: Graph Diffusion Convolution preprocessing
Use case: Improved message passing
Parameters: self_loop_weight, normalization_in, normalization_out, diffusion_kwargs
from torch_geometric.transforms import GDC
transform = GDC(self_loop_weight=1, normalization_in='sym',
diffusion_kwargs=dict(method='ppr', alpha=0.15))
SIGN
Purpose: Scalable Inception Graph Neural Networks preprocessing
Use case: Efficient multi-scale features
Parameters: K (number of hops)
from torch_geometric.transforms import SIGN
transform = SIGN(K=3)
Feature Transforms
OneHotDegree
Purpose: One-hot encodes node degree
Use case: Degree as feature
Parameters: max_degree, cat (concatenate with existing features)
from torch_geometric.transforms import OneHotDegree
transform = OneHotDegree(max_degree=100)
LocalDegreeProfile
Purpose: Appends local degree profile Use case: Structural node features
from torch_geometric.transforms import LocalDegreeProfile
transform = LocalDegreeProfile()
Constant
Purpose: Adds constant features to nodes
Use case: Featureless graphs
Parameters: value, cat
from torch_geometric.transforms import Constant
transform = Constant(value=1.0)
TargetIndegree
Purpose: Saves in-degree as target
Use case: Degree prediction
Parameters: norm, max_value
from torch_geometric.transforms import TargetIndegree
transform = TargetIndegree(norm=False)
AddRandomWalkPE
Purpose: Adds random walk positional encoding
Use case: Positional information
Parameters: walk_length, attr_name
from torch_geometric.transforms import AddRandomWalkPE
transform = AddRandomWalkPE(walk_length=20)
AddLaplacianEigenvectorPE
Purpose: Adds Laplacian eigenvector positional encoding
Use case: Spectral positional information
Parameters: k (number of eigenvectors), attr_name
from torch_geometric.transforms import AddLaplacianEigenvectorPE
transform = AddLaplacianEigenvectorPE(k=10)
AddMetaPaths
Purpose: Adds meta-path induced edges
Use case: Heterogeneous graphs
Parameters: metapaths, drop_orig_edges, drop_unconnected_nodes
from torch_geometric.transforms import AddMetaPaths
metapaths = [[('author', 'paper'), ('paper', 'author')]] # Co-authorship
transform = AddMetaPaths(metapaths)
SVDFeatureReduction
Purpose: Reduces feature dimensionality via SVD
Use case: Dimensionality reduction
Parameters: out_channels
from torch_geometric.transforms import SVDFeatureReduction
transform = SVDFeatureReduction(out_channels=64)
Vision/Spatial Transforms
Center
Purpose: Centers node positions Use case: Point cloud preprocessing
from torch_geometric.transforms import Center
transform = Center()
NormalizeScale
Purpose: Normalizes positions to unit sphere Use case: Point cloud normalization
from torch_geometric.transforms import NormalizeScale
transform = NormalizeScale()
NormalizeRotation
Purpose: Rotates to principal components
Use case: Rotation-invariant learning
Parameters: max_points
from torch_geometric.transforms import NormalizeRotation
transform = NormalizeRotation()
Distance
Purpose: Saves Euclidean distance as edge attribute
Use case: Spatial graphs
Parameters: norm, max_value, cat
from torch_geometric.transforms import Distance
transform = Distance(norm=False, cat=False)
Cartesian
Purpose: Saves relative Cartesian coordinates as edge attributes
Use case: Spatial relationships
Parameters: norm, max_value, cat
from torch_geometric.transforms import Cartesian
transform = Cartesian(norm=False)
Polar
Purpose: Saves polar coordinates as edge attributes
Use case: 2D spatial graphs
Parameters: norm, max_value, cat
from torch_geometric.transforms import Polar
transform = Polar(norm=False)
Spherical
Purpose: Saves spherical coordinates as edge attributes
Use case: 3D spatial graphs
Parameters: norm, max_value, cat
from torch_geometric.transforms import Spherical
transform = Spherical(norm=False)
LocalCartesian
Purpose: Saves coordinates in local coordinate system
Use case: Local spatial features
Parameters: norm, cat
from torch_geometric.transforms import LocalCartesian
transform = LocalCartesian()
PointPairFeatures
Purpose: Computes point pair features
Use case: 3D registration, correspondence
Parameters: cat
from torch_geometric.transforms import PointPairFeatures
transform = PointPairFeatures()
Data Augmentation
RandomJitter
Purpose: Randomly jitters node positions
Use case: Point cloud augmentation
Parameters: translate, scale
from torch_geometric.transforms import RandomJitter
transform = RandomJitter(0.01)
RandomFlip
Purpose: Randomly flips positions along axis
Use case: Geometric augmentation
Parameters: axis, p (probability)
from torch_geometric.transforms import RandomFlip
transform = RandomFlip(axis=0, p=0.5)
RandomScale
Purpose: Randomly scales positions
Use case: Scale augmentation
Parameters: scales (min, max)
from torch_geometric.transforms import RandomScale
transform = RandomScale((0.9, 1.1))
RandomRotate
Purpose: Randomly rotates positions
Use case: Rotation augmentation
Parameters: degrees (range), axis (rotation axis)
from torch_geometric.transforms import RandomRotate
transform = RandomRotate(degrees=15, axis=2)
RandomShear
Purpose: Randomly shears positions
Use case: Geometric augmentation
Parameters: shear (range)
from torch_geometric.transforms import RandomShear
transform = RandomShear(0.1)
RandomTranslate
Purpose: Randomly translates positions
Use case: Translation augmentation
Parameters: translate (range)
from torch_geometric.transforms import RandomTranslate
transform = RandomTranslate(0.1)
LinearTransformation
Purpose: Applies linear transformation matrix
Use case: Custom geometric transforms
Parameters: matrix
from torch_geometric.transforms import LinearTransformation
import torch
matrix = torch.eye(3)
transform = LinearTransformation(matrix)
Mesh Processing
SamplePoints
Purpose: Samples points uniformly from mesh
Use case: Mesh to point cloud conversion
Parameters: num, remove_faces, include_normals
from torch_geometric.transforms import SamplePoints
transform = SamplePoints(num=1024)
GenerateMeshNormals
Purpose: Generates face/vertex normals Use case: Mesh processing
from torch_geometric.transforms import GenerateMeshNormals
transform = GenerateMeshNormals()
FaceToEdge
Purpose: Converts mesh faces to edges
Use case: Mesh to graph conversion
Parameters: remove_faces
from torch_geometric.transforms import FaceToEdge
transform = FaceToEdge()
Sampling and Splitting
GridSampling
Purpose: Clusters points in voxel grid
Use case: Point cloud downsampling
Parameters: size (voxel size), start, end
from torch_geometric.transforms import GridSampling
transform = GridSampling(size=0.1)
FixedPoints
Purpose: Samples fixed number of points
Use case: Uniform point cloud size
Parameters: num, replace, allow_duplicates
from torch_geometric.transforms import FixedPoints
transform = FixedPoints(num=2048, replace=False)
RandomScale
Purpose: Randomly scales by sampling from range Use case: Scale augmentation (already listed above)
VirtualNode
Purpose: Adds a virtual node connected to all nodes Use case: Global information propagation
from torch_geometric.transforms import VirtualNode
transform = VirtualNode()
Specialized Transforms
ToSLIC
Purpose: Converts images to superpixel graphs (SLIC algorithm)
Use case: Image as graph
Parameters: num_segments, compactness, add_seg, add_img
from torch_geometric.transforms import ToSLIC
transform = ToSLIC(num_segments=75)
GCNNorm
Purpose: Applies GCN-style normalization to edges
Use case: Preprocessing for GCN
Parameters: add_self_loops
from torch_geometric.transforms import GCNNorm
transform = GCNNorm(add_self_loops=True)
LaplacianLambdaMax
Purpose: Computes largest Laplacian eigenvalue
Use case: ChebConv preprocessing
Parameters: normalization, is_undirected
from torch_geometric.transforms import LaplacianLambdaMax
transform = LaplacianLambdaMax(normalization='sym')
NormalizeRotation
Purpose: Rotates mesh/point cloud to align with principal axes
Use case: Canonical orientation
Parameters: max_points
from torch_geometric.transforms import NormalizeRotation
transform = NormalizeRotation()
Compose and Apply
Compose
Purpose: Chains multiple transforms Use case: Complex preprocessing pipelines
from torch_geometric.transforms import Compose
transform = Compose([
Center(),
NormalizeScale(),
KNNGraph(k=6),
Distance(norm=False),
])
BaseTransform
Purpose: Base class for custom transforms Use case: Implementing custom transforms
from torch_geometric.transforms import BaseTransform
class MyTransform(BaseTransform):
def __init__(self, param):
self.param = param
def __call__(self, data):
# Modify data
data.x = data.x * self.param
return data
Common Transform Combinations
Node Classification Preprocessing
transform = Compose([
NormalizeFeatures(),
RandomNodeSplit(num_val=0.1, num_test=0.2),
])
Point Cloud Processing
transform = Compose([
Center(),
NormalizeScale(),
RandomRotate(degrees=15, axis=2),
RandomJitter(0.01),
KNNGraph(k=6),
Distance(norm=False),
])
Mesh to Graph
transform = Compose([
FaceToEdge(remove_faces=True),
GenerateMeshNormals(),
Distance(norm=True),
])
Graph Structure Enhancement
transform = Compose([
ToUndirected(),
AddSelfLoops(),
RemoveIsolatedNodes(),
GCNNorm(),
])
Heterogeneous Graph Preprocessing
transform = Compose([
AddMetaPaths(metapaths=[
[('author', 'paper'), ('paper', 'author')],
[('author', 'paper'), ('paper', 'conference'), ('conference', 'paper'), ('paper', 'author')]
]),
RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.2),
])
Link Prediction
transform = Compose([
NormalizeFeatures(),
RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True),
])
Usage Tips
- Order matters: Apply structural transforms before feature transforms
- Caching: Some transforms (like GDC) are expensive—apply once
- Augmentation: Use Random* transforms during training only
- Compose sparingly: Too many transforms slow down data loading
- Custom transforms: Inherit from
BaseTransformfor custom logic - Pre-transforms: Apply expensive transforms once during dataset processing:
dataset = MyDataset(root='/tmp', pre_transform=ExpensiveTransform()) - Dynamic transforms: Apply cheap transforms during training:
dataset = MyDataset(root='/tmp', transform=CheapTransform())
Performance Considerations
Expensive transforms (apply as pre_transform):
- GDC
- SIGN
- KNNGraph (for large point clouds)
- AddLaplacianEigenvectorPE
- SVDFeatureReduction
Cheap transforms (apply as transform):
- NormalizeFeatures
- ToUndirected
- AddSelfLoops
- Random* augmentations
- ToDevice
Example:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import Compose, GDC, NormalizeFeatures
# Expensive preprocessing done once
pre_transform = GDC(
self_loop_weight=1,
normalization_in='sym',
diffusion_kwargs=dict(method='ppr', alpha=0.15)
)
# Cheap transform applied each time
transform = NormalizeFeatures()
dataset = Planetoid(
root='/tmp/Cora',
name='Cora',
pre_transform=pre_transform,
transform=transform
)