Initial commit

Zhongwei Li
2025-11-30 08:30:10 +08:00
commit f0bd18fb4e
824 changed files with 331919 additions and 0 deletions

---
name: torch-geometric
description: "Graph Neural Networks (PyG). Node/graph classification, link prediction, GCN, GAT, GraphSAGE, heterogeneous graphs, molecular property prediction, for geometric deep learning."
---
# PyTorch Geometric (PyG)
## Overview
PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.
## When to Use This Skill
This skill should be used when working with:
- **Graph-based machine learning**: Node classification, graph classification, link prediction
- **Molecular property prediction**: Drug discovery, chemical property prediction
- **Social network analysis**: Community detection, influence prediction
- **Citation networks**: Paper classification, recommendation systems
- **3D geometric data**: Point clouds, meshes, molecular structures
- **Heterogeneous graphs**: Multi-type nodes and edges (e.g., knowledge graphs)
- **Large-scale graph learning**: Neighbor sampling, distributed training
## Quick Start
### Installation
```bash
uv pip install torch_geometric
```
For additional dependencies (sparse operations, clustering):
```bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
```
### Basic Graph Creation
```python
import torch
from torch_geometric.data import Data
# Create a simple graph with 3 nodes
edge_index = torch.tensor([[0, 1, 1, 2], # source nodes
[1, 0, 2, 1]], dtype=torch.long) # target nodes
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # node features
data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
```
### Loading a Benchmark Dataset
```python
from torch_geometric.datasets import Planetoid
# Load Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # Get the first (and only) graph
print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
```
## Core Concepts
### Data Structure
PyG represents graphs using the `torch_geometric.data.Data` class with these key attributes:
- **`data.x`**: Node feature matrix `[num_nodes, num_node_features]`
- **`data.edge_index`**: Graph connectivity in COO format `[2, num_edges]`
- **`data.edge_attr`**: Edge feature matrix `[num_edges, num_edge_features]` (optional)
- **`data.y`**: Target labels for nodes or graphs
- **`data.pos`**: Node spatial positions `[num_nodes, num_dimensions]` (optional)
- **Custom attributes**: Can add any attribute (e.g., `data.train_mask`, `data.batch`)
**Important**: These attributes are not mandatory—extend Data objects with custom attributes as needed.
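As a minimal illustration (the attribute name `train_mask` is a convention, not a requirement):
```python
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
data = Data(x=torch.randn(4, 8), edge_index=edge_index)
data.train_mask = torch.tensor([True, True, False, False])  # custom attribute
print(data)  # Data(x=[4, 8], edge_index=[2, 3], train_mask=[4])
```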
### Edge Index Format
Edges are stored in COO (coordinate) format as a `[2, num_edges]` tensor:
- First row: source node indices
- Second row: target node indices
```python
# Edge list: (0→1), (1→0), (1→2), (2→1)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
```
### Mini-Batch Processing
PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:
- Adjacency matrices are stacked diagonally
- Node features are concatenated along the node dimension
- A `batch` vector maps each node to its source graph
- No padding needed—computationally efficient
```python
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
print(f"Batch size: {batch.num_graphs}")
print(f"Total nodes: {batch.num_nodes}")
# batch.batch maps nodes to graphs
```
## Building Graph Neural Networks
### Message Passing Paradigm
GNNs in PyG follow a neighborhood aggregation scheme:
1. Transform node features
2. Propagate messages along edges
3. Aggregate messages from neighbors
4. Update node representations
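In symbols, a single message-passing layer computes (the general scheme implemented by PyG's `MessagePassing` base class):

$$
\mathbf{x}_i' = \gamma\left(\mathbf{x}_i,\ \bigoplus_{j \in \mathcal{N}(i)} \phi\left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i}\right)\right)
$$

where $\phi$ is the message function, $\bigoplus$ a permutation-invariant aggregation (sum, mean, max), and $\gamma$ the update function.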
### Using Pre-Built Layers
PyG provides 40+ convolutional layers. Common ones include:
**GCNConv** (Graph Convolutional Network):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = GCNConv(num_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
```
**GATConv** (Graph Attention Network):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
```
**GraphSAGE**:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = SAGEConv(num_features, 64)
self.conv2 = SAGEConv(64, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
```
### Custom Message Passing Layers
For custom layers, inherit from `MessagePassing`:
```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class CustomConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "add", "mean", or "max"
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Add self-loops to adjacency matrix
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Transform node features
x = self.lin(x)
# Compute normalization
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Propagate messages
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j: features of source nodes
return norm.view(-1, 1) * x_j
```
Key methods:
- **`forward()`**: Main entry point
- **`message()`**: Constructs messages from source to target nodes
- **`aggregate()`**: Aggregates messages (usually don't override—set `aggr` parameter)
- **`update()`**: Updates node embeddings after aggregation
**Variable naming convention**: Appending `_i` or `_j` to tensor names automatically maps them to target or source nodes.
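For example, declaring both `x_i` and `x_j` in `message()` exposes target and source features side by side (a standalone sketch of an EdgeConv-style relative message, not part of the class above):
```python
def message(self, x_i, x_j):
    # x_i: features of target (central) nodes, x_j: features of source neighbors
    return x_j - x_i  # relative features, as used in EdgeConv-style layers
```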
## Working with Datasets
### Loading Built-in Datasets
PyG provides extensive benchmark datasets:
```python
# Citation networks (node classification)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora') # or 'CiteSeer', 'PubMed'
# Graph classification
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
# Molecular datasets
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')
# Large-scale datasets
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
```
Check `references/datasets_reference.md` for a comprehensive list.
### Creating Custom Datasets
For datasets that fit in memory, inherit from `InMemoryDataset`:
```python
from torch_geometric.data import InMemoryDataset, Data
import torch
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
self.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['my_data.csv'] # Files needed in raw_dir
@property
def processed_file_names(self):
return ['data.pt'] # Files in processed_dir
def download(self):
# Download raw data to self.raw_dir
pass
def process(self):
# Read data, create Data objects
data_list = []
# Example: Create a simple graph
edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
x = torch.randn(2, 16)
y = torch.tensor([0], dtype=torch.long)
data = Data(x=x, edge_index=edge_index, y=y)
data_list.append(data)
# Apply pre_filter and pre_transform
if self.pre_filter is not None:
data_list = [d for d in data_list if self.pre_filter(d)]
if self.pre_transform is not None:
data_list = [self.pre_transform(d) for d in data_list]
# Save processed data
self.save(data_list, self.processed_paths[0])
```
For large datasets that don't fit in memory, inherit from `Dataset` and implement `len()` and `get(idx)`.
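A minimal on-disk sketch (assuming `process()` saved one `Data` object per file; the `data_{i}.pt` naming and the count of 1000 are illustrative):
```python
import os.path as osp
import torch
from torch_geometric.data import Dataset

class MyLargeDataset(Dataset):
    @property
    def processed_file_names(self):
        # Illustrative: one file per sample, written by process()
        return [f'data_{i}.pt' for i in range(1000)]

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Load a single sample from disk on demand
        return torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
```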
### Loading Graphs from CSV
```python
import pandas as pd
import torch
from torch_geometric.data import Data
# Load nodes
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
# Load edges
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor([edges_df['source'].values,
edges_df['target'].values], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
```
## Training Workflows
### Node Classification (Single Graph)
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Create model
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Evaluation
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
```
### Graph Classification (Multiple Graphs)
```python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
class GraphClassifier(torch.nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = GCNConv(num_features, 64)
self.conv2 = GCNConv(64, 64)
self.lin = torch.nn.Linear(64, num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
# Global pooling (aggregate node features to graph-level)
x = global_mean_pool(x, batch)
x = self.lin(x)
return F.log_softmax(x, dim=1)
# Load dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training
model.train()
for epoch in range(100):
total_loss = 0
for batch in loader:
optimizer.zero_grad()
out = model(batch)
loss = F.nll_loss(out, batch.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
```
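A matching evaluation loop might look like this (assuming a held-out `test_loader` built the same way as `loader`):
```python
model.eval()
correct = 0
with torch.no_grad():
    for batch in test_loader:
        pred = model(batch).argmax(dim=1)
        correct += int((pred == batch.y).sum())
print(f'Test Accuracy: {correct / len(test_loader.dataset):.4f}')
```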
### Large-Scale Graphs with Neighbor Sampling
For large graphs, use `NeighborLoader` to sample subgraphs:
```python
from torch_geometric.loader import NeighborLoader
# Create a neighbor sampler
train_loader = NeighborLoader(
data,
num_neighbors=[25, 10], # Sample 25 neighbors for 1st hop, 10 for 2nd hop
batch_size=128,
input_nodes=data.train_mask,
)
# Training
model.train()
for batch in train_loader:
optimizer.zero_grad()
out = model(batch)
# Only compute loss on seed nodes (first batch_size nodes)
loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
loss.backward()
optimizer.step()
```
**Important**:
- Sampled subgraphs are directed, even when the input graph is undirected
- Node indices are relabeled to the range 0 to batch.num_nodes - 1
- Compute the loss only on seed-node predictions (the first `batch_size` nodes)
- Sampling beyond 2-3 hops is generally infeasible, since neighborhood size grows exponentially with depth
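For evaluation, one option is a second loader over every node (a sketch; `input_nodes=None` makes every node a seed node):
```python
eval_loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=256,
                             input_nodes=None)

model.eval()
correct = total = 0
with torch.no_grad():
    for batch in eval_loader:
        pred = model(batch)[:batch.batch_size].argmax(dim=1)
        correct += int((pred == batch.y[:batch.batch_size]).sum())
        total += batch.batch_size
print(f'Accuracy over all nodes: {correct / total:.4f}')
```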
## Advanced Features
### Heterogeneous Graphs
For graphs with multiple node and edge types, use `HeteroData`:
```python
from torch_geometric.data import HeteroData
data = HeteroData()
# Add node features for different types
data['paper'].x = torch.randn(100, 128) # 100 papers with 128 features
data['author'].x = torch.randn(200, 64) # 200 authors with 64 features
# Add edges for different types (source_type, edge_type, target_type);
# the first row indexes source nodes, the second row target nodes
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 200, (500,)),  # author (source) indices in [0, 200)
    torch.randint(0, 100, (500,)),  # paper (target) indices in [0, 100)
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)
```
Convert homogeneous models to heterogeneous:
```python
from torch_geometric.nn import to_hetero
# Define homogeneous model
model = GNN(...)
# Convert to heterogeneous
model = to_hetero(model, data.metadata(), aggr='sum')
# Use as normal
out = model(data.x_dict, data.edge_index_dict)
```
Or use `HeteroConv` for custom edge-type-specific operations:
```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
class HeteroGNN(torch.nn.Module):
def __init__(self, metadata):
super().__init__()
self.conv1 = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
}, aggr='sum')
self.conv2 = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(64, 32),
('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
}, aggr='sum')
def forward(self, x_dict, edge_index_dict):
x_dict = self.conv1(x_dict, edge_index_dict)
x_dict = {key: F.relu(x) for key, x in x_dict.items()}
x_dict = self.conv2(x_dict, edge_index_dict)
return x_dict
```
### Transforms
Apply transforms to modify graph structure or features:
```python
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose
# Single transform
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
# Compose multiple transforms
transform = Compose([
AddSelfLoops(),
NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
```
Common transforms:
- **Structure**: `ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph`
- **Features**: `NormalizeFeatures`, `NormalizeScale`, `Center`
- **Sampling**: `RandomNodeSplit`, `RandomLinkSplit`
- **Positional Encoding**: `AddLaplacianEigenvectorPE`, `AddRandomWalkPE`
See `references/transforms_reference.md` for the full list.
### Model Explainability
PyG provides explainability tools to understand model predictions:
```python
from torch_geometric.explain import Explainer, GNNExplainer
# Create explainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model', # or 'phenomenon'
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
# Generate explanation for a specific node
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)
# Visualize
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
```
### Pooling Operations
For hierarchical graph representations:
```python
from torch_geometric.nn import TopKPooling, global_mean_pool
class HierarchicalGNN(torch.nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = GCNConv(num_features, 64)
self.pool1 = TopKPooling(64, ratio=0.8)
self.conv2 = GCNConv(64, 64)
self.pool2 = TopKPooling(64, ratio=0.8)
self.lin = torch.nn.Linear(64, num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
x = F.relu(self.conv2(x, edge_index))
x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
x = global_mean_pool(x, batch)
x = self.lin(x)
return F.log_softmax(x, dim=1)
```
## Common Patterns and Best Practices
### Check Graph Properties
```python
from torch_geometric.utils import is_undirected, contains_self_loops, to_scipy_sparse_matrix
from scipy.sparse.csgraph import connected_components

# Undirected check
print(f"Is undirected: {is_undirected(data.edge_index)}")
# Contains self-loops
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
# Connected components, computed via SciPy
adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
num_components, _ = connected_components(adj)
print(f"Connected components: {num_components}")
```
### GPU Training
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
# For DataLoader
for batch in loader:
batch = batch.to(device)
# Train...
```
### Save and Load Models
```python
# Save
torch.save(model.state_dict(), 'model.pth')
# Load
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
```
### Layer Capabilities
When choosing layers, consider these capabilities:
- **SparseTensor**: Supports efficient sparse matrix operations
- **edge_weight**: Handles one-dimensional edge weights
- **edge_attr**: Processes multi-dimensional edge features
- **Bipartite**: Works with bipartite graphs (different source/target dimensions)
- **Lazy**: Enables initialization without specifying input dimensions
See the GNN cheatsheet at `references/layer_capabilities.md`.
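For example, lazy layers accept `-1` in place of unknown input dimensions and infer them on the first forward pass:
```python
from torch_geometric.nn import GCNConv

conv = GCNConv(-1, 64)  # in_channels inferred from the first input
```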
## Resources
### Bundled References
This skill includes detailed reference documentation:
- **`references/layers_reference.md`**: Complete listing of all 40+ GNN layers with descriptions and capabilities
- **`references/datasets_reference.md`**: Comprehensive dataset catalog organized by category
- **`references/transforms_reference.md`**: All available transforms and their use cases
- **`references/api_patterns.md`**: Common API patterns and coding examples
### Scripts
Utility scripts are provided in `scripts/`:
- **`scripts/visualize_graph.py`**: Visualize graph structure using networkx and matplotlib
- **`scripts/create_gnn_template.py`**: Generate boilerplate code for common GNN architectures
- **`scripts/benchmark_model.py`**: Benchmark model performance on standard datasets
Execute scripts directly or read them for implementation patterns.
### Official Resources
- **Documentation**: https://pytorch-geometric.readthedocs.io/
- **GitHub**: https://github.com/pyg-team/pytorch_geometric
- **Tutorials**: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
- **Examples**: https://github.com/pyg-team/pytorch_geometric/tree/master/examples

# PyTorch Geometric Datasets Reference
This document provides a comprehensive catalog of all datasets available in `torch_geometric.datasets`.
## Citation Networks
### Planetoid
**Usage**: Node classification, semi-supervised learning
**Networks**: Cora, CiteSeer, PubMed
**Description**: Citation networks where nodes are papers and edges are citations
- **Cora**: 2,708 nodes, 5,429 edges, 7 classes, 1,433 features
- **CiteSeer**: 3,327 nodes, 4,732 edges, 6 classes, 3,703 features
- **PubMed**: 19,717 nodes, 44,338 edges, 3 classes, 500 features
```python
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
```
### Coauthor
**Usage**: Node classification on collaboration networks
**Networks**: CS, Physics
**Description**: Co-authorship networks from Microsoft Academic Graph
- **CS**: 18,333 nodes, 81,894 edges, 15 classes (computer science)
- **Physics**: 34,493 nodes, 247,962 edges, 5 classes (physics)
```python
from torch_geometric.datasets import Coauthor
dataset = Coauthor(root='/tmp/CS', name='CS')
```
### Amazon
**Usage**: Node classification on product networks
**Networks**: Computers, Photo
**Description**: Amazon co-purchase networks where nodes are products
- **Computers**: 13,752 nodes, 245,861 edges, 10 classes
- **Photo**: 7,650 nodes, 119,081 edges, 8 classes
```python
from torch_geometric.datasets import Amazon
dataset = Amazon(root='/tmp/Computers', name='Computers')
```
### CitationFull
**Usage**: Citation network analysis
**Networks**: Cora, Cora_ML, DBLP, PubMed
**Description**: Full citation networks without sampling
```python
from torch_geometric.datasets import CitationFull
dataset = CitationFull(root='/tmp/Cora', name='Cora')
```
## Graph Classification
### TUDataset
**Usage**: Graph classification, graph kernel benchmarks
**Description**: Collection of 120+ graph classification datasets
- **MUTAG**: 188 graphs, 2 classes (molecular compounds)
- **PROTEINS**: 1,113 graphs, 2 classes (protein structures)
- **ENZYMES**: 600 graphs, 6 classes (protein enzymes)
- **IMDB-BINARY**: 1,000 graphs, 2 classes (social networks)
- **REDDIT-BINARY**: 2,000 graphs, 2 classes (discussion threads)
- **COLLAB**: 5,000 graphs, 3 classes (scientific collaborations)
- **NCI1**: 4,110 graphs, 2 classes (chemical compounds)
- **DD**: 1,178 graphs, 2 classes (protein structures)
```python
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
```
### MoleculeNet
**Usage**: Molecular property prediction
**Datasets**: Over 10 molecular benchmark datasets
**Description**: Comprehensive molecular machine learning benchmarks
- **ESOL**: Aqueous solubility (regression)
- **FreeSolv**: Hydration free energy (regression)
- **Lipophilicity**: Octanol/water distribution (regression)
- **BACE**: Binding results (classification)
- **BBBP**: Blood-brain barrier penetration (classification)
- **HIV**: HIV inhibition (classification)
- **Tox21**: Toxicity prediction (multi-task classification)
- **ToxCast**: Toxicology forecasting (multi-task classification)
- **SIDER**: Side effects (multi-task classification)
- **ClinTox**: Clinical trial toxicity (multi-task classification)
```python
from torch_geometric.datasets import MoleculeNet
dataset = MoleculeNet(root='/tmp/ESOL', name='ESOL')
```
## Molecular and Chemical Datasets
### QM7b
**Usage**: Molecular property prediction (quantum mechanics)
**Description**: 7,211 molecules with up to 7 heavy atoms
- Properties: Atomization energies, electronic properties
```python
from torch_geometric.datasets import QM7b
dataset = QM7b(root='/tmp/QM7b')
```
### QM9
**Usage**: Molecular property prediction (quantum mechanics)
**Description**: ~130,000 molecules with up to 9 heavy atoms (C, O, N, F)
- Properties: 19 quantum chemical properties including HOMO, LUMO, gap, energy
```python
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')
```
### ZINC
**Usage**: Molecular generation, property prediction
**Description**: ~250,000 drug-like molecular graphs
- Properties: Constrained solubility, molecular weight
```python
from torch_geometric.datasets import ZINC
dataset = ZINC(root='/tmp/ZINC', subset=True)
```
### AQSOL
**Usage**: Aqueous solubility prediction
**Description**: ~10,000 molecules with solubility measurements
```python
from torch_geometric.datasets import AQSOL
dataset = AQSOL(root='/tmp/AQSOL')
```
### MD17
**Usage**: Molecular dynamics, force field learning
**Description**: Molecular dynamics trajectories for small molecules
- Molecules: Benzene, Uracil, Naphthalene, Aspirin, Salicylic acid, etc.
```python
from torch_geometric.datasets import MD17
dataset = MD17(root='/tmp/MD17', name='benzene')
```
### PCQM4Mv2
**Usage**: Large-scale molecular property prediction
**Description**: 3.8M molecules from PubChem for quantum chemistry
- Part of OGB Large-Scale Challenge
```python
from torch_geometric.datasets import PCQM4Mv2
dataset = PCQM4Mv2(root='/tmp/PCQM4Mv2')
```
## Social Networks
### Reddit
**Usage**: Large-scale node classification
**Description**: Reddit posts from September 2014
- 232,965 nodes, 11,606,919 edges, 41 classes
- Features: TF-IDF of post content
```python
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
```
### Reddit2
**Usage**: Large-scale node classification
**Description**: Updated Reddit dataset with more posts
```python
from torch_geometric.datasets import Reddit2
dataset = Reddit2(root='/tmp/Reddit2')
```
### Twitch
**Usage**: Node classification, social network analysis
**Networks**: DE, EN, ES, FR, PT, RU
**Description**: Twitch user networks by language
```python
from torch_geometric.datasets import Twitch
dataset = Twitch(root='/tmp/Twitch', name='DE')
```
### Facebook
**Usage**: Social network analysis, node classification
**Description**: Facebook page-page networks
```python
from torch_geometric.datasets import FacebookPagePage
dataset = FacebookPagePage(root='/tmp/Facebook')
```
### GitHub
**Usage**: Social network analysis
**Description**: GitHub developer networks
```python
from torch_geometric.datasets import GitHub
dataset = GitHub(root='/tmp/GitHub')
```
## Knowledge Graphs
### Entities
**Usage**: Link prediction, knowledge graph embeddings
**Datasets**: AIFB, MUTAG, BGS, AM
**Description**: RDF knowledge graphs with typed relations
```python
from torch_geometric.datasets import Entities
dataset = Entities(root='/tmp/AIFB', name='AIFB')
```
### WordNet18
**Usage**: Link prediction on semantic networks
**Description**: Subset of WordNet with 18 relations
- 40,943 entities, 151,442 triplets
```python
from torch_geometric.datasets import WordNet18
dataset = WordNet18(root='/tmp/WordNet18')
```
### WordNet18RR
**Usage**: Link prediction (no inverse relations)
**Description**: Refined version without inverse relations
```python
from torch_geometric.datasets import WordNet18RR
dataset = WordNet18RR(root='/tmp/WordNet18RR')
```
### FB15k-237
**Usage**: Link prediction on Freebase
**Description**: Subset of Freebase with 237 relations
- 14,541 entities, 310,116 triplets
```python
from torch_geometric.datasets import FB15k_237
dataset = FB15k_237(root='/tmp/FB15k')
```
## Heterogeneous Graphs
### OGB_MAG
**Usage**: Heterogeneous graph learning, node classification
**Description**: Microsoft Academic Graph with multiple node/edge types
- Node types: paper, author, institution, field of study
- 1M+ nodes, 21M+ edges
```python
from torch_geometric.datasets import OGB_MAG
dataset = OGB_MAG(root='/tmp/OGB_MAG')
```
### MovieLens
**Usage**: Recommendation systems, link prediction
**Variants**: `MovieLens` (encodes movie titles with a sentence transformer), `MovieLens100K`, `MovieLens1M`
**Description**: User-movie rating networks
- Node types: user, movie
- Edge types: rates
```python
from torch_geometric.datasets import MovieLens100K
dataset = MovieLens100K(root='/tmp/MovieLens100K')
```
### IMDB
**Usage**: Heterogeneous graph learning
**Description**: IMDB movie network
- Node types: movie, actor, director
```python
from torch_geometric.datasets import IMDB
dataset = IMDB(root='/tmp/IMDB')
```
### DBLP
**Usage**: Heterogeneous graph learning, node classification
**Description**: DBLP bibliography network
- Node types: author, paper, term, conference
```python
from torch_geometric.datasets import DBLP
dataset = DBLP(root='/tmp/DBLP')
```
### LastFM
**Usage**: Heterogeneous recommendation
**Description**: LastFM music network
- Node types: user, artist, tag
```python
from torch_geometric.datasets import LastFM
dataset = LastFM(root='/tmp/LastFM')
```
## Temporal Graphs
### BitcoinOTC
**Usage**: Temporal link prediction, trust networks
**Description**: Bitcoin OTC trust network over time
```python
from torch_geometric.datasets import BitcoinOTC
dataset = BitcoinOTC(root='/tmp/BitcoinOTC')
```
### ICEWS18
**Usage**: Temporal knowledge graph completion
**Description**: Integrated Crisis Early Warning System events
```python
from torch_geometric.datasets import ICEWS18
dataset = ICEWS18(root='/tmp/ICEWS18')
```
### GDELT
**Usage**: Temporal event forecasting
**Description**: Global Database of Events, Language, and Tone
```python
from torch_geometric.datasets import GDELT
dataset = GDELT(root='/tmp/GDELT')
```
### JODIEDataset
**Usage**: Dynamic graph learning
**Datasets**: Reddit, Wikipedia, MOOC, LastFM
**Description**: Temporal interaction networks
```python
from torch_geometric.datasets import JODIEDataset
dataset = JODIEDataset(root='/tmp/JODIE', name='Reddit')
```
## 3D Meshes and Point Clouds
### ShapeNet
**Usage**: 3D shape classification and segmentation
**Description**: Large-scale 3D CAD model dataset
- 16,881 models across 16 categories
- Part-level segmentation labels
```python
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
```
### ModelNet
**Usage**: 3D shape classification
**Versions**: ModelNet10, ModelNet40
**Description**: CAD models for 3D object classification
- ModelNet10: 4,899 models, 10 categories
- ModelNet40: 12,311 models, 40 categories
```python
from torch_geometric.datasets import ModelNet
dataset = ModelNet(root='/tmp/ModelNet', name='10')
```
### FAUST
**Usage**: 3D shape matching, correspondence
**Description**: Human body scans for shape analysis
- 100 meshes of 10 people in 10 poses
```python
from torch_geometric.datasets import FAUST
dataset = FAUST(root='/tmp/FAUST')
```
### CoMA
**Usage**: 3D mesh deformation
**Description**: Facial expression meshes
- 20,466 3D face scans with expressions
```python
from torch_geometric.datasets import CoMA
dataset = CoMA(root='/tmp/CoMA')
```
### S3DIS
**Usage**: 3D semantic segmentation
**Description**: Stanford Large-Scale 3D Indoor Spaces
- 6 areas, 271 rooms, point cloud data
```python
from torch_geometric.datasets import S3DIS
dataset = S3DIS(root='/tmp/S3DIS', test_area=6)
```
## Image and Vision Datasets
### MNISTSuperpixels
**Usage**: Graph-based image classification
**Description**: MNIST images as superpixel graphs
- 70,000 graphs (60k train, 10k test)
```python
from torch_geometric.datasets import MNISTSuperpixels
dataset = MNISTSuperpixels(root='/tmp/MNIST')
```
### Flickr
**Usage**: Node classification
**Description**: Flickr image network; nodes are images, edges connect images sharing common properties (e.g., tags)
- 89,250 nodes, 899,756 edges, 7 classes
```python
from torch_geometric.datasets import Flickr
dataset = Flickr(root='/tmp/Flickr')
```
## Small Classic Graphs
### KarateClub
**Usage**: Community detection, visualization
**Description**: Zachary's karate club network
- 34 nodes, 78 edges, 2 communities
```python
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
```
## Open Graph Benchmark (OGB)
PyG integrates seamlessly with OGB datasets:
### Node Property Prediction
- **ogbn-products**: Amazon product network (2.4M nodes)
- **ogbn-proteins**: Protein association network (132K nodes)
- **ogbn-arxiv**: Citation network (169K nodes)
- **ogbn-papers100M**: Large citation network (111M nodes)
- **ogbn-mag**: Heterogeneous academic graph
### Link Property Prediction
- **ogbl-ppa**: Protein association networks
- **ogbl-collab**: Collaboration networks
- **ogbl-ddi**: Drug-drug interaction network
- **ogbl-citation2**: Citation network
- **ogbl-wikikg2**: Wikidata knowledge graph
### Graph Property Prediction
- **ogbg-molhiv**: Molecular HIV activity prediction
- **ogbg-molpcba**: Molecular bioassays (multi-task)
- **ogbg-ppa**: Protein function prediction
- **ogbg-code2**: Code abstract syntax trees
```python
# OGB datasets load through the ogb package:
from ogb.nodeproppred import PygNodePropPredDataset
dataset = PygNodePropPredDataset(name='ogbn-arxiv')
# A few are also wrapped directly in PyG, e.g.:
from torch_geometric.datasets import OGB_MAG
dataset = OGB_MAG(root='/tmp/OGB_MAG')
```
## Synthetic Datasets
### FakeDataset
**Usage**: Testing, debugging
**Description**: Generates random graph data
```python
from torch_geometric.datasets import FakeDataset
dataset = FakeDataset(num_graphs=100, avg_num_nodes=50)
```
### StochasticBlockModelDataset
**Usage**: Community detection benchmarks
**Description**: Graphs generated from stochastic block models
```python
from torch_geometric.datasets import StochasticBlockModelDataset
dataset = StochasticBlockModelDataset(
    root='/tmp/SBM',
    block_sizes=[25, 25, 25],
    edge_probs=[[0.25, 0.05, 0.02],
                [0.05, 0.25, 0.02],
                [0.02, 0.02, 0.25]],
    num_graphs=1000,
)
```
### ExplainerDataset
**Usage**: Testing explainability methods
**Description**: Synthetic graphs with known explanation ground truth
```python
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.datasets.motif_generator import HouseMotif
dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator=HouseMotif(),
    num_motifs=80,
)
```
## Materials Science
### QM8
**Usage**: Molecular property prediction
**Description**: Electronic spectra of small molecules (part of the MoleculeNet benchmark suite)
```python
# QM8 is not bundled in torch_geometric.datasets;
# load it via external sources (e.g., DeepChem's MoleculeNet loaders)
```
## Biological Networks
### PPI (Protein-Protein Interaction)
**Usage**: Multi-label node classification, inductive learning
**Description**: Protein-protein interaction networks across human tissues
- 24 graphs, ~2,373 nodes per graph on average, 121 target labels
```python
from torch_geometric.datasets import PPI
dataset = PPI(root='/tmp/PPI', split='train')
```
### STRING
**Usage**: Protein interaction networks
**Description**: Known and predicted protein-protein interactions
```python
# Available through external sources or custom loading
```
## Usage Tips
1. **Start with small datasets**: Use Cora, KarateClub, or ENZYMES for prototyping
2. **Citation networks**: Planetoid datasets are perfect for node classification
3. **Graph classification**: TUDataset provides diverse benchmarks
4. **Molecular**: QM9, ZINC, MoleculeNet for chemistry applications
5. **Large-scale**: Use Reddit, OGB datasets with NeighborLoader
6. **Heterogeneous**: OGB_MAG, MovieLens, IMDB for multi-type graphs
7. **Temporal**: JODIE, ICEWS for dynamic graph learning
8. **3D**: ShapeNet, ModelNet, S3DIS for geometric learning
## Common Patterns
### Loading with Transforms
```python
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='/tmp/Cora', name='Cora',
transform=NormalizeFeatures())
```
### Train/Val/Test Splits
```python
# For node-level datasets with pre-defined splits, boolean masks index
# node-level tensors; the Data object itself is not sliced
data = dataset[0]
train_y = data.y[data.train_mask]
val_y = data.y[data.val_mask]
test_y = data.y[data.test_mask]
# For graph classification, slice the dataset itself
from torch_geometric.loader import DataLoader
train_dataset = dataset[:int(len(dataset) * 0.8)]
test_dataset = dataset[int(len(dataset) * 0.8):]
train_loader = DataLoader(train_dataset, batch_size=32)
```
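For datasets that ship without masks, `RandomNodeSplit` (see the transforms reference) can generate them:
```python
from torch_geometric.transforms import RandomNodeSplit

data = RandomNodeSplit(num_val=0.1, num_test=0.2)(data)
print(data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum())
```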
### Custom Data Loading
```python
from torch_geometric.data import Data, Dataset
class MyCustomDataset(Dataset):
def __init__(self, root, transform=None):
        super().__init__(root, transform)
        self.data_list = []  # populate with Data objects, e.g., in process()
def len(self):
return len(self.data_list)
def get(self, idx):
# Load and return data object
return self.data_list[idx]
```

# PyTorch Geometric Neural Network Layers Reference
This document provides a comprehensive reference of all neural network layers available in `torch_geometric.nn`.
## Layer Capability Flags
When selecting layers, consider these capability flags:
- **SparseTensor**: Supports `torch_sparse.SparseTensor` format for efficient sparse operations
- **edge_weight**: Handles one-dimensional edge weight data
- **edge_attr**: Processes multi-dimensional edge feature information
- **Bipartite**: Works with bipartite graphs (different source/target node dimensions)
- **Static**: Operates on static graphs with batched node features
- **Lazy**: Enables initialization without specifying input channel dimensions
## Convolutional Layers
### Standard Graph Convolutions
**GCNConv** - Graph Convolutional Network layer
- Implements spectral graph convolution with symmetric normalization
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Citation networks, social networks, general graph learning
- Example: `GCNConv(in_channels, out_channels, improved=False, cached=True)`
**SAGEConv** - GraphSAGE layer
- Inductive learning via neighborhood sampling and aggregation
- Supports: SparseTensor, Bipartite, Lazy
- Use for: Large graphs, inductive learning, heterogeneous features
- Example: `SAGEConv(in_channels, out_channels, aggr='mean')`
**GATConv** - Graph Attention Network layer
- Multi-head attention mechanism for adaptive neighbor weighting
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Tasks requiring variable neighbor importance
- Example: `GATConv(in_channels, out_channels, heads=8, dropout=0.6)`
**GraphConv** - Simple graph convolution (Morris et al.)
- Basic message passing with optional edge weights
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Baseline models, simple graph structures
- Example: `GraphConv(in_channels, out_channels, aggr='add')`
**GINConv** - Graph Isomorphism Network layer
- Maximally powerful GNN for graph isomorphism testing
- Supports: Bipartite
- Use for: Graph classification, molecular property prediction
- Example: `GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))`
**TransformerConv** - Graph Transformer layer
- Combines graph structure with transformer attention
- Supports: SparseTensor, Bipartite, Lazy
- Use for: Long-range dependencies, complex graphs
- Example: `TransformerConv(in_channels, out_channels, heads=8, beta=True)`
**ChebConv** - Chebyshev spectral graph convolution
- Uses Chebyshev polynomials for efficient spectral filtering
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Spectral graph learning, efficient convolutions
- Example: `ChebConv(in_channels, out_channels, K=3)`
**SGConv** - Simplified Graph Convolution
- Pre-computes fixed number of propagation steps
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Fast training, shallow models
- Example: `SGConv(in_channels, out_channels, K=2)`
**APPNP** - Approximate Personalized Propagation of Neural Predictions
- Separates feature transformation from propagation
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Deep propagation without oversmoothing
- Example: `APPNP(K=10, alpha=0.1)`
**ARMAConv** - ARMA graph convolution
- Uses ARMA filters for graph filtering
- Supports: SparseTensor, edge_weight, Bipartite, Lazy
- Use for: Advanced spectral methods
- Example: `ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)`
**GATv2Conv** - Improved Graph Attention Network
- Fixes static attention computation issue in GAT
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Better attention learning than original GAT
- Example: `GATv2Conv(in_channels, out_channels, heads=8)`
**SuperGATConv** - Self-supervised Graph Attention
- Adds self-supervised attention mechanism
- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
- Use for: Self-supervised learning, limited labels
- Example: `SuperGATConv(in_channels, out_channels, heads=8)`
**GMMConv** - Gaussian Mixture Model Convolution
- Uses Gaussian kernels in pseudo-coordinate space
- Supports: Bipartite
- Use for: Point clouds, spatial data
- Example: `GMMConv(in_channels, out_channels, dim=3, kernel_size=5)`
**SplineConv** - Spline-based convolution
- B-spline basis functions for spatial filtering
- Supports: Bipartite
- Use for: Irregular grids, continuous spaces
- Example: `SplineConv(in_channels, out_channels, dim=2, kernel_size=5)`
**NNConv** - Neural Network Convolution
- Edge features processed by neural networks
- Supports: edge_attr, Bipartite
- Use for: Rich edge features, molecular graphs
- Example: `NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')`
**CGConv** - Crystal Graph Convolution
- Designed for crystalline materials
- Supports: Bipartite
- Use for: Materials science, crystal structures
- Example: `CGConv(in_channels, dim=3, batch_norm=True)`
**EdgeConv** - Edge Convolution (Dynamic Graph CNN)
- Dynamically computes edges based on feature space
- Supports: Static
- Use for: Point clouds, dynamic graphs
- Example: `EdgeConv(nn=edge_nn, aggr='max')`
**PointNetConv** - PointNet++ convolution
- Local and global feature learning for point clouds
- Use for: 3D point cloud processing
- Example: `PointNetConv(local_nn, global_nn)`
**ResGatedGraphConv** - Residual Gated Graph Convolution
- Gating mechanism with residual connections
- Supports: edge_attr, Bipartite, Lazy
- Use for: Deep GNNs, complex features
- Example: `ResGatedGraphConv(in_channels, out_channels)`
**GENConv** - Generalized Graph Convolution
- Generalizes multiple GNN variants
- Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy
- Use for: Flexible architecture exploration
- Example: `GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)`
**FiLMConv** - Feature-wise Linear Modulation
- Conditions on global features
- Supports: Bipartite, Lazy
- Use for: Conditional generation, multi-task learning
- Example: `FiLMConv(in_channels, out_channels, num_relations=5)`
**PANConv** - Path Integral Based Convolution
- Weights multi-hop paths via a path-integral formulation
- Supports: SparseTensor, Lazy
- Use for: Complex connectivity patterns
- Example: `PANConv(in_channels, out_channels, filter_size=3)`
**ClusterGCNConv** - Cluster-GCN convolution
- Efficient training via graph clustering
- Supports: edge_attr, Lazy
- Use for: Very large graphs
- Example: `ClusterGCNConv(in_channels, out_channels)`
**MFConv** - Molecular Fingerprint Convolution
- Degree-specific weight matrices, from neural molecular fingerprints (Duvenaud et al.)
- Supports: SparseTensor, Lazy
- Use for: Molecular fingerprints, small molecular graphs
- Example: `MFConv(in_channels, out_channels)`
**RGCNConv** - Relational Graph Convolution
- Handles multiple edge types
- Supports: SparseTensor, edge_weight, Lazy
- Use for: Knowledge graphs, heterogeneous graphs
- Example: `RGCNConv(in_channels, out_channels, num_relations=10)`
**FAConv** - Frequency Adaptive Convolution
- Adaptive filtering in spectral domain
- Supports: SparseTensor, Lazy
- Use for: Spectral graph learning
- Example: `FAConv(in_channels, eps=0.1, dropout=0.5)`
### Molecular and 3D Convolutions
**SchNet** - Continuous-filter convolutional layer
- Designed for molecular dynamics
- Use for: Molecular property prediction, 3D molecules
- Example: `SchNet(hidden_channels=128, num_filters=64, num_interactions=6)`
**DimeNet** - Directional Message Passing
- Uses directional information and angles
- Use for: 3D molecular structures, chemical properties
- Example: `DimeNet(hidden_channels=128, out_channels=1, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6)`
**PointTransformerConv** - Point cloud transformer
- Transformer for 3D point clouds
- Use for: 3D vision, point cloud segmentation
- Example: `PointTransformerConv(in_channels, out_channels)`
### Hypergraph Convolutions
**HypergraphConv** - Hypergraph convolution
- Operates on hyperedges (edges connecting multiple nodes)
- Supports: Lazy
- Use for: Multi-way relationships, chemical reactions
- Example: `HypergraphConv(in_channels, out_channels)`
**HGTConv** - Heterogeneous Graph Transformer
- Transformer for heterogeneous graphs with multiple types
- Supports: Lazy
- Use for: Heterogeneous networks, knowledge graphs
- Example: `HGTConv(in_channels, out_channels, metadata, heads=8)`
## Aggregation Operators
**Aggregation** - Base aggregation class
- Abstract interface behind all aggregation operators
**SumAggregation** - Sum aggregation
- Example: `SumAggregation()`
**MeanAggregation** - Mean aggregation
- Example: `MeanAggregation()`
**MaxAggregation** - Max aggregation
- Example: `MaxAggregation()`
**SoftmaxAggregation** - Softmax-weighted aggregation
- Learnable attention weights
- Example: `SoftmaxAggregation(learn=True)`
**PowerMeanAggregation** - Power mean aggregation
- Learnable power parameter
- Example: `PowerMeanAggregation(learn=True)`
**LSTMAggregation** - LSTM-based aggregation
- Sequential processing of neighbors
- Example: `LSTMAggregation(in_channels, out_channels)`
**SetTransformerAggregation** - Set Transformer aggregation
- Transformer for permutation-invariant aggregation
- Example: `SetTransformerAggregation(channels=64, heads=4)`
**MultiAggregation** - Multiple aggregations
- Combines multiple aggregation methods
- Example: `MultiAggregation(['mean', 'max', 'std'])`
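These modules plug into many convolution layers via the `aggr` argument (a sketch assuming PyG ≥ 2.1, where conv layers accept `Aggregation` instances and lists of aggregation names):
```python
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.aggr import MultiAggregation

conv = SAGEConv(64, 64, aggr=MultiAggregation(['mean', 'max']))
# shorthand form: SAGEConv(64, 64, aggr=['mean', 'max'])
```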
## Pooling Layers
### Global Pooling
**global_mean_pool** - Global mean pooling
- Averages node features per graph
- Example: `global_mean_pool(x, batch)`
**global_max_pool** - Global max pooling
- Max over node features per graph
- Example: `global_max_pool(x, batch)`
**global_add_pool** - Global sum pooling
- Sums node features per graph
- Example: `global_add_pool(x, batch)`
**global_sort_pool** - Global sort pooling
- Sorts and concatenates top-k nodes
- Example: `global_sort_pool(x, batch, k=30)`
**GlobalAttention** - Global attention pooling
- Learnable attention weights for aggregation
- Example: `GlobalAttention(gate_nn)`
**Set2Set** - Set2Set pooling
- LSTM-based attention mechanism
- Example: `Set2Set(in_channels, processing_steps=3)`
### Hierarchical Pooling
**TopKPooling** - Top-k pooling
- Keeps top-k nodes based on projection scores
- Example: `TopKPooling(in_channels, ratio=0.5)`
**SAGPooling** - Self-Attention Graph Pooling
- Uses self-attention for node selection
- Example: `SAGPooling(in_channels, ratio=0.5)`
**ASAPooling** - Adaptive Structure Aware Pooling
- Structure-aware node selection
- Example: `ASAPooling(in_channels, ratio=0.5)`
**PANPooling** - Path Integral Based Pooling
- Selects nodes using scores from the path-integral transition matrix
- Example: `PANPooling(in_channels, ratio=0.5)`
**EdgePooling** - Edge contraction pooling
- Pools by contracting edges
- Example: `EdgePooling(in_channels)`
**MemPooling** - Memory-based pooling
- Learnable cluster assignments
- Example: `MemPooling(in_channels, out_channels, heads=4, num_clusters=10)`
**avg_pool** / **max_pool** - Average/Max pool with clustering
- Pools nodes within clusters
- Example: `avg_pool(cluster, data)`
## Normalization Layers
**BatchNorm** - Batch normalization
- Normalizes features across batch
- Example: `BatchNorm(in_channels)`
**LayerNorm** - Layer normalization
- Normalizes features per sample
- Example: `LayerNorm(in_channels)`
**InstanceNorm** - Instance normalization
- Normalizes per sample and graph
- Example: `InstanceNorm(in_channels)`
**GraphNorm** - Graph normalization
- Graph-specific normalization
- Example: `GraphNorm(in_channels)`
**PairNorm** - Pair normalization
- Prevents oversmoothing in deep GNNs
- Example: `PairNorm(scale_individually=False)`
**MessageNorm** - Message normalization
- Normalizes messages during passing
- Example: `MessageNorm(learn_scale=True)`
**DiffGroupNorm** - Differentiable Group Normalization
- Learnable grouping for normalization
- Example: `DiffGroupNorm(in_channels, groups=10)`
## Model Architectures
### Pre-Built Models
**GCN** - Complete Graph Convolutional Network
- Multi-layer GCN with dropout
- Example: `GCN(in_channels, hidden_channels, num_layers, out_channels)`
**GraphSAGE** - Complete GraphSAGE model
- Multi-layer SAGE with dropout
- Example: `GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)`
**GIN** - Complete Graph Isomorphism Network
- Multi-layer GIN for graph classification
- Example: `GIN(in_channels, hidden_channels, num_layers, out_channels)`
**GAT** - Complete Graph Attention Network
- Multi-layer GAT with attention
- Example: `GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)`
**PNA** - Principal Neighbourhood Aggregation
- Combines multiple aggregators and scalers
- Example: `PNA(in_channels, hidden_channels, num_layers, out_channels)`
**EdgeCNN** - Edge Convolution CNN
- Dynamic graph CNN built from EdgeConv layers
- Example: `EdgeCNN(in_channels, hidden_channels, num_layers, out_channels)`
### Auto-Encoders
**GAE** - Graph Auto-Encoder
- Encodes graphs into latent space
- Example: `GAE(encoder)`
**VGAE** - Variational Graph Auto-Encoder
- Probabilistic graph encoding
- Example: `VGAE(encoder)`
**ARGA** - Adversarially Regularized Graph Auto-Encoder
- GAE with adversarial regularization
- Example: `ARGA(encoder, discriminator)`
**ARGVA** - Adversarially Regularized Variational Graph Auto-Encoder
- VGAE with adversarial regularization
- Example: `ARGVA(encoder, discriminator)`
### Knowledge Graph Embeddings
**TransE** - Translating embeddings
- Learns entity and relation embeddings
- Example: `TransE(num_nodes, num_relations, hidden_channels)`
**RotatE** - Rotational embeddings
- Embeddings in complex space
- Example: `RotatE(num_nodes, num_relations, hidden_channels)`
**ComplEx** - Complex embeddings
- Complex-valued embeddings
- Example: `ComplEx(num_nodes, num_relations, hidden_channels)`
**DistMult** - Bilinear diagonal model
- Simplified bilinear model
- Example: `DistMult(num_nodes, num_relations, hidden_channels)`
## Utility Layers
**Sequential** - Sequential container
- Chains multiple layers
- Example: `Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])`
**JumpingKnowledge** - Jumping knowledge connections
- Combines representations from all layers
- Modes: 'cat', 'max', 'lstm'
- Example: `JumpingKnowledge(mode='cat')`
**DeepGCNLayer** - Deep GCN layer wrapper
- Enables very deep GNNs with skip connections
- Example: `DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)`
**MLP** - Multi-layer perceptron
- Standard feedforward network
- Example: `MLP([in_channels, 64, 64, out_channels], dropout=0.5)`
**Linear** - Lazy linear layer
- Linear transformation with lazy initialization
- Example: `Linear(in_channels, out_channels, bias=True)`
## Dense Layers
For dense (non-sparse) graph representations:
**DenseGCNConv** - Dense GCN layer
**DenseSAGEConv** - Dense SAGE layer
**DenseGINConv** - Dense GIN layer
**DenseGraphConv** - Dense graph convolution
These are useful when working with small, fully-connected, or densely represented graphs.
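To move between the two representations, PyG ships utilities that batch sparse graphs into dense tensors:
```python
from torch_geometric.utils import to_dense_adj, to_dense_batch

adj = to_dense_adj(edge_index, batch)     # [num_graphs, max_nodes, max_nodes]
x_dense, mask = to_dense_batch(x, batch)  # [num_graphs, max_nodes, channels]
```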
## Usage Tips
1. **Start simple**: Begin with GCNConv or GATConv for most tasks
2. **Consider data type**: Use molecular layers (SchNet, DimeNet) for 3D structures
3. **Check capabilities**: Match layer capabilities to your data (edge features, bipartite, etc.)
4. **Deep networks**: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs
5. **Large graphs**: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling
6. **Heterogeneous**: Use RGCNConv, HGTConv, or to_hetero() conversion
7. **Lazy initialization**: Use lazy layers when input dimensions vary or are unknown
## Common Patterns
### Basic GNN
```python
import torch
from torch_geometric.nn import GCNConv, global_mean_pool
class GNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return global_mean_pool(x, batch)
```
### Deep GNN with Normalization
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, JumpingKnowledge, LayerNorm, global_mean_pool

class DeepGNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
super().__init__()
self.convs = torch.nn.ModuleList()
self.norms = torch.nn.ModuleList()
self.convs.append(GCNConv(in_channels, hidden_channels))
self.norms.append(LayerNorm(hidden_channels))
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.norms.append(LayerNorm(hidden_channels))
self.convs.append(GCNConv(hidden_channels, out_channels))
self.jk = JumpingKnowledge(mode='cat')
def forward(self, x, edge_index, batch):
xs = []
for conv, norm in zip(self.convs[:-1], self.norms):
x = conv(x, edge_index)
x = norm(x)
x = F.relu(x)
xs.append(x)
x = self.convs[-1](x, edge_index)
xs.append(x)
x = self.jk(xs)
return global_mean_pool(x, batch)
```

# PyTorch Geometric Transforms Reference
This document provides a comprehensive reference of all transforms available in `torch_geometric.transforms`.
## Overview
Transforms modify `Data` or `HeteroData` objects before or during training. Apply them via:
```python
# During dataset loading
dataset = MyDataset(root='/tmp', transform=MyTransform())
# Apply to individual data
transform = MyTransform()
data = transform(data)
# Compose multiple transforms
from torch_geometric.transforms import Compose
transform = Compose([Transform1(), Transform2(), Transform3()])
```
## General Transforms
### NormalizeFeatures
**Purpose**: Row-normalizes node features to sum to 1
**Use case**: Feature scaling, probability-like features
```python
from torch_geometric.transforms import NormalizeFeatures
transform = NormalizeFeatures()
```
### ToDevice
**Purpose**: Transfers data to specified device (CPU/GPU)
**Use case**: GPU training, device management
```python
from torch_geometric.transforms import ToDevice
transform = ToDevice('cuda')
```
### RandomNodeSplit
**Purpose**: Creates train/val/test node masks
**Use case**: Node classification splits
**Parameters**: `split='train_rest'`, `num_splits`, `num_val`, `num_test`
```python
from torch_geometric.transforms import RandomNodeSplit
transform = RandomNodeSplit(num_val=0.1, num_test=0.2)
```
### RandomLinkSplit
**Purpose**: Creates train/val/test edge splits
**Use case**: Link prediction
**Parameters**: `num_val`, `num_test`, `is_undirected`, `split_labels`
```python
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(num_val=0.1, num_test=0.2)
```
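Note that, unlike most transforms, `RandomLinkSplit` returns a tuple of three objects rather than a single one:
```python
train_data, val_data, test_data = transform(data)
```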
### IndexToMask
**Purpose**: Converts indices to boolean masks
**Use case**: Data preprocessing
```python
from torch_geometric.transforms import IndexToMask
transform = IndexToMask()
```
### MaskToIndex
**Purpose**: Converts boolean masks to indices
**Use case**: Data preprocessing
```python
from torch_geometric.transforms import MaskToIndex
transform = MaskToIndex()
```
### FixedPoints
**Purpose**: Samples a fixed number of points
**Use case**: Point cloud subsampling
**Parameters**: `num`, `replace`, `allow_duplicates`
```python
from torch_geometric.transforms import FixedPoints
transform = FixedPoints(1024)
```
### ToDense
**Purpose**: Converts to dense adjacency matrices
**Use case**: Small graphs, dense operations
```python
from torch_geometric.transforms import ToDense
transform = ToDense(num_nodes=100)
```
### ToSparseTensor
**Purpose**: Converts edge_index to SparseTensor
**Use case**: Efficient sparse operations
**Parameters**: `remove_edge_index`, `fill_cache`
```python
from torch_geometric.transforms import ToSparseTensor
transform = ToSparseTensor()
```
## Graph Structure Transforms
### ToUndirected
**Purpose**: Converts directed graph to undirected
**Use case**: Undirected graph algorithms
**Parameters**: `reduce='add'` (how to handle duplicate edges)
```python
from torch_geometric.transforms import ToUndirected
transform = ToUndirected()
```
### AddSelfLoops
**Purpose**: Adds self-loops to all nodes
**Use case**: GCN-style convolutions
**Parameters**: `fill_value` (edge attribute for self-loops)
```python
from torch_geometric.transforms import AddSelfLoops
transform = AddSelfLoops()
```
### RemoveSelfLoops
**Purpose**: Removes all self-loops
**Use case**: Cleaning graph structure
```python
from torch_geometric.transforms import RemoveSelfLoops
transform = RemoveSelfLoops()
```
### RemoveIsolatedNodes
**Purpose**: Removes nodes without edges
**Use case**: Graph cleaning
```python
from torch_geometric.transforms import RemoveIsolatedNodes
transform = RemoveIsolatedNodes()
```
### RemoveDuplicatedEdges
**Purpose**: Removes duplicate edges
**Use case**: Graph cleaning
```python
from torch_geometric.transforms import RemoveDuplicatedEdges
transform = RemoveDuplicatedEdges()
```
### LargestConnectedComponents
**Purpose**: Keeps only the largest connected component
**Use case**: Focus on main graph structure
**Parameters**: `num_components` (how many components to keep)
```python
from torch_geometric.transforms import LargestConnectedComponents
transform = LargestConnectedComponents(num_components=1)
```
### KNNGraph
**Purpose**: Creates edges based on k-nearest neighbors
**Use case**: Point clouds, spatial data
**Parameters**: `k`, `loop`, `force_undirected`, `flow`
```python
from torch_geometric.transforms import KNNGraph
transform = KNNGraph(k=6)
```
### RadiusGraph
**Purpose**: Creates edges within a radius
**Use case**: Point clouds, spatial data
**Parameters**: `r`, `loop`, `max_num_neighbors`, `flow`
```python
from torch_geometric.transforms import RadiusGraph
transform = RadiusGraph(r=0.1)
```
### Delaunay
**Purpose**: Computes Delaunay triangulation
**Use case**: 2D/3D spatial graphs
```python
from torch_geometric.transforms import Delaunay
transform = Delaunay()
```
### FaceToEdge
**Purpose**: Converts mesh faces to edges
**Use case**: Mesh processing
```python
from torch_geometric.transforms import FaceToEdge
transform = FaceToEdge()
```
### LineGraph
**Purpose**: Converts graph to its line graph
**Use case**: Edge-centric analysis
**Parameters**: `force_directed`
```python
from torch_geometric.transforms import LineGraph
transform = LineGraph()
```
### GDC
**Purpose**: Graph Diffusion Convolution preprocessing
**Use case**: Improved message passing
**Parameters**: `self_loop_weight`, `normalization_in`, `normalization_out`, `diffusion_kwargs`
```python
from torch_geometric.transforms import GDC
transform = GDC(self_loop_weight=1, normalization_in='sym',
diffusion_kwargs=dict(method='ppr', alpha=0.15))
```
### SIGN
**Purpose**: Scalable Inception Graph Neural Networks preprocessing
**Use case**: Efficient multi-scale features
**Parameters**: `K` (number of hops)
```python
from torch_geometric.transforms import SIGN
transform = SIGN(K=3)
```
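The transform stores the i-hop propagated features as `data.x1 ... data.xK`, so a plain MLP can use multi-hop information without message passing at training time. A sketch on Cora:
```python
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import SIGN

dataset = Planetoid(root='/tmp/Cora', name='Cora', pre_transform=SIGN(K=3))
data = dataset[0]
xs = torch.cat([data.x] + [data[f'x{i}'] for i in range(1, 4)], dim=-1)
```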
## Feature Transforms
### OneHotDegree
**Purpose**: One-hot encodes node degree
**Use case**: Degree as feature
**Parameters**: `max_degree`, `cat` (concatenate with existing features)
```python
from torch_geometric.transforms import OneHotDegree
transform = OneHotDegree(max_degree=100)
```
### LocalDegreeProfile
**Purpose**: Appends local degree profile
**Use case**: Structural node features
```python
from torch_geometric.transforms import LocalDegreeProfile
transform = LocalDegreeProfile()
```
### Constant
**Purpose**: Adds constant features to nodes
**Use case**: Featureless graphs
**Parameters**: `value`, `cat`
```python
from torch_geometric.transforms import Constant
transform = Constant(value=1.0)
```
### TargetIndegree
**Purpose**: Saves the (normalized) in-degree of target nodes as an edge attribute
**Use case**: Degree-based edge features
**Parameters**: `norm`, `max_value`
```python
from torch_geometric.transforms import TargetIndegree
transform = TargetIndegree(norm=False)
```
### AddRandomWalkPE
**Purpose**: Adds random walk positional encoding
**Use case**: Positional information
**Parameters**: `walk_length`, `attr_name`
```python
from torch_geometric.transforms import AddRandomWalkPE
transform = AddRandomWalkPE(walk_length=20)
```
### AddLaplacianEigenvectorPE
**Purpose**: Adds Laplacian eigenvector positional encoding
**Use case**: Spectral positional information
**Parameters**: `k` (number of eigenvectors), `attr_name`
```python
from torch_geometric.transforms import AddLaplacianEigenvectorPE
transform = AddLaplacianEigenvectorPE(k=10)
```
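The encoding is stored under `attr_name` (default: `laplacian_eigenvector_pe`); a common pattern is to concatenate it onto the node features. A toy sketch (note `k` must be smaller than the number of nodes):
```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE

data = Data(x=torch.randn(20, 8), edge_index=torch.randint(0, 20, (2, 60)))
data = AddLaplacianEigenvectorPE(k=4)(data)
data.x = torch.cat([data.x, data.laplacian_eigenvector_pe], dim=-1)  # [20, 12]
```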
### AddMetaPaths
**Purpose**: Adds meta-path induced edges
**Use case**: Heterogeneous graphs
**Parameters**: `metapaths`, `drop_orig_edges`, `drop_unconnected_nodes`
```python
from torch_geometric.transforms import AddMetaPaths
metapaths = [[('author', 'paper'), ('paper', 'author')]] # Co-authorship
transform = AddMetaPaths(metapaths)
```
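After the transform, the induced relations appear under generated edge types such as `('author', 'metapath_0', 'author')`, indexed by their position in the `metapaths` list.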
### SVDFeatureReduction
**Purpose**: Reduces feature dimensionality via SVD
**Use case**: Dimensionality reduction
**Parameters**: `out_channels`
```python
from torch_geometric.transforms import SVDFeatureReduction
transform = SVDFeatureReduction(out_channels=64)
```
## Vision/Spatial Transforms
### Center
**Purpose**: Centers node positions
**Use case**: Point cloud preprocessing
```python
from torch_geometric.transforms import Center
transform = Center()
```
### NormalizeScale
**Purpose**: Normalizes positions to unit sphere
**Use case**: Point cloud normalization
```python
from torch_geometric.transforms import NormalizeScale
transform = NormalizeScale()
```
### NormalizeRotation
**Purpose**: Rotates to principal components
**Use case**: Rotation-invariant learning
**Parameters**: `max_points`
```python
from torch_geometric.transforms import NormalizeRotation
transform = NormalizeRotation()
```
### Distance
**Purpose**: Saves Euclidean distance as edge attribute
**Use case**: Spatial graphs
**Parameters**: `norm`, `max_value`, `cat`
```python
from torch_geometric.transforms import Distance
transform = Distance(norm=False, cat=False)
```
### Cartesian
**Purpose**: Saves relative Cartesian coordinates as edge attributes
**Use case**: Spatial relationships
**Parameters**: `norm`, `max_value`, `cat`
```python
from torch_geometric.transforms import Cartesian
transform = Cartesian(norm=False)
```
### Polar
**Purpose**: Saves polar coordinates as edge attributes
**Use case**: 2D spatial graphs
**Parameters**: `norm`, `max_value`, `cat`
```python
from torch_geometric.transforms import Polar
transform = Polar(norm=False)
```
### Spherical
**Purpose**: Saves spherical coordinates as edge attributes
**Use case**: 3D spatial graphs
**Parameters**: `norm`, `max_value`, `cat`
```python
from torch_geometric.transforms import Spherical
transform = Spherical(norm=False)
```
### LocalCartesian
**Purpose**: Saves coordinates in local coordinate system
**Use case**: Local spatial features
**Parameters**: `norm`, `cat`
```python
from torch_geometric.transforms import LocalCartesian
transform = LocalCartesian()
```
### PointPairFeatures
**Purpose**: Computes point pair features
**Use case**: 3D registration, correspondence
**Parameters**: `cat`
```python
from torch_geometric.transforms import PointPairFeatures
transform = PointPairFeatures()
```
## Data Augmentation
### RandomJitter
**Purpose**: Randomly jitters node positions
**Use case**: Point cloud augmentation
**Parameters**: `translate`, `scale`
```python
from torch_geometric.transforms import RandomJitter
transform = RandomJitter(0.01)
```
### RandomFlip
**Purpose**: Randomly flips positions along axis
**Use case**: Geometric augmentation
**Parameters**: `axis`, `p` (probability)
```python
from torch_geometric.transforms import RandomFlip
transform = RandomFlip(axis=0, p=0.5)
```
### RandomScale
**Purpose**: Randomly scales positions
**Use case**: Scale augmentation
**Parameters**: `scales` (min, max)
```python
from torch_geometric.transforms import RandomScale
transform = RandomScale((0.9, 1.1))
```
### RandomRotate
**Purpose**: Randomly rotates positions
**Use case**: Rotation augmentation
**Parameters**: `degrees` (range), `axis` (rotation axis)
```python
from torch_geometric.transforms import RandomRotate
transform = RandomRotate(degrees=15, axis=2)
```
### RandomShear
**Purpose**: Randomly shears positions
**Use case**: Geometric augmentation
**Parameters**: `shear` (range)
```python
from torch_geometric.transforms import RandomShear
transform = RandomShear(0.1)
```
### RandomTranslate
**Purpose**: Randomly translates positions
**Use case**: Translation augmentation
**Parameters**: `translate` (range)
```python
from torch_geometric.transforms import RandomTranslate
transform = RandomTranslate(0.1)
```
### LinearTransformation
**Purpose**: Applies linear transformation matrix
**Use case**: Custom geometric transforms
**Parameters**: `matrix`
```python
from torch_geometric.transforms import LinearTransformation
import torch
matrix = torch.eye(3)
transform = LinearTransformation(matrix)
```
## Mesh Processing
### SamplePoints
**Purpose**: Samples points uniformly from mesh
**Use case**: Mesh to point cloud conversion
**Parameters**: `num`, `remove_faces`, `include_normals`
```python
from torch_geometric.transforms import SamplePoints
transform = SamplePoints(num=1024)
```
### GenerateMeshNormals
**Purpose**: Generates face/vertex normals
**Use case**: Mesh processing
```python
from torch_geometric.transforms import GenerateMeshNormals
transform = GenerateMeshNormals()
```
### FaceToEdge
**Purpose**: Converts mesh faces to edges
**Use case**: Mesh to graph conversion
**Parameters**: `remove_faces`
```python
from torch_geometric.transforms import FaceToEdge
transform = FaceToEdge()
```
## Sampling and Splitting
### GridSampling
**Purpose**: Clusters points in voxel grid
**Use case**: Point cloud downsampling
**Parameters**: `size` (voxel size), `start`, `end`
```python
from torch_geometric.transforms import GridSampling
transform = GridSampling(size=0.1)
```
### FixedPoints
**Purpose**: Samples fixed number of points
**Use case**: Uniform point cloud size
**Parameters**: `num`, `replace`, `allow_duplicates`
```python
from torch_geometric.transforms import FixedPoints
transform = FixedPoints(num=2048, replace=False)
```
### VirtualNode
**Purpose**: Adds a virtual node connected to all nodes
**Use case**: Global information propagation
```python
from torch_geometric.transforms import VirtualNode
transform = VirtualNode()
```
## Specialized Transforms
### ToSLIC
**Purpose**: Converts images to superpixel graphs (SLIC algorithm)
**Use case**: Image as graph
**Parameters**: `n_segments`, `compactness`, `add_seg`, `add_img` (remaining kwargs are forwarded to `skimage.segmentation.slic`)
```python
from torch_geometric.transforms import ToSLIC
transform = ToSLIC(n_segments=75)
```
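A sketch of the usual pattern, assuming `torchvision` and `scikit-image` are installed; the transform is applied to each image tensor and yields one superpixel graph per image:
```python
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch_geometric.transforms import ToSLIC

transform = T.Compose([T.ToTensor(), ToSLIC(n_segments=75)])
dataset = MNIST('/tmp/MNIST', download=True, transform=transform)
data, target = dataset[0]  # `data` is a Data object with pos and x
```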
### GCNNorm
**Purpose**: Applies GCN-style normalization to edges
**Use case**: Preprocessing for GCN
**Parameters**: `add_self_loops`
```python
from torch_geometric.transforms import GCNNorm
transform = GCNNorm(add_self_loops=True)
```
### LaplacianLambdaMax
**Purpose**: Computes largest Laplacian eigenvalue
**Use case**: ChebConv preprocessing
**Parameters**: `normalization`, `is_undirected`
```python
from torch_geometric.transforms import LaplacianLambdaMax
transform = LaplacianLambdaMax(normalization='sym')
```
## Compose and Apply
### Compose
**Purpose**: Chains multiple transforms
**Use case**: Complex preprocessing pipelines
```python
from torch_geometric.transforms import (
    Compose, Center, NormalizeScale, KNNGraph, Distance,
)

transform = Compose([
    Center(),
    NormalizeScale(),
    KNNGraph(k=6),
    Distance(norm=False),
])
```
### BaseTransform
**Purpose**: Base class for custom transforms
**Use case**: Implementing custom transforms
```python
from torch_geometric.transforms import BaseTransform
class MyTransform(BaseTransform):
def __init__(self, param):
self.param = param
def __call__(self, data):
# Modify data
data.x = data.x * self.param
return data
```
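An instance is then passed like any built-in transform (dataset choice is illustrative):
```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=MyTransform(2.0))
data = dataset[0]  # node features are scaled by 2.0 on every access
```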
## Common Transform Combinations
### Node Classification Preprocessing
```python
from torch_geometric.transforms import Compose, NormalizeFeatures, RandomNodeSplit

transform = Compose([
    NormalizeFeatures(),
    RandomNodeSplit(num_val=0.1, num_test=0.2),
])
```
### Point Cloud Processing
```python
from torch_geometric.transforms import (
    Compose, Center, NormalizeScale, RandomRotate, RandomJitter,
    KNNGraph, Distance,
)

transform = Compose([
    Center(),
    NormalizeScale(),
    RandomRotate(degrees=15, axis=2),
    RandomJitter(0.01),
    KNNGraph(k=6),
    Distance(norm=False),
])
```
### Mesh to Graph
```python
from torch_geometric.transforms import Compose, GenerateMeshNormals, FaceToEdge, Distance

transform = Compose([
    GenerateMeshNormals(),          # needs data.face, so run before FaceToEdge
    FaceToEdge(remove_faces=True),
    Distance(norm=True),
])
```
### Graph Structure Enhancement
```python
from torch_geometric.transforms import (
    Compose, ToUndirected, RemoveIsolatedNodes, AddSelfLoops, GCNNorm,
)

transform = Compose([
    ToUndirected(),
    RemoveIsolatedNodes(),  # run before AddSelfLoops, which would hide isolated nodes
    AddSelfLoops(),
    GCNNorm(),
])
```
### Heterogeneous Graph Preprocessing
```python
from torch_geometric.transforms import Compose, AddMetaPaths, RandomNodeSplit

transform = Compose([
    AddMetaPaths(metapaths=[
        [('author', 'paper'), ('paper', 'author')],  # co-authorship
        [('author', 'paper'), ('paper', 'conference'),
         ('conference', 'paper'), ('paper', 'author')],
    ]),
    RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.2),
])
```
### Link Prediction
```python
from torch_geometric.transforms import Compose, NormalizeFeatures, RandomLinkSplit

transform = Compose([
    NormalizeFeatures(),
    RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True),
])
```
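Note that `RandomLinkSplit` returns a `(train_data, val_data, test_data)` tuple rather than a single object, so keep it as the last step in the pipeline:
```python
train_data, val_data, test_data = transform(data)
```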
## Usage Tips
1. **Order matters**: Apply structural transforms before feature transforms
2. **Caching**: Some transforms (like GDC) are expensive—apply once
3. **Augmentation**: Use Random* transforms during training only
4. **Compose sparingly**: Too many transforms slow down data loading
5. **Custom transforms**: Inherit from `BaseTransform` for custom logic
6. **Pre-transforms**: Apply expensive transforms once during dataset processing:
```python
dataset = MyDataset(root='/tmp', pre_transform=ExpensiveTransform())
```
7. **Dynamic transforms**: Apply cheap transforms during training:
```python
dataset = MyDataset(root='/tmp', transform=CheapTransform())
```
## Performance Considerations
**Expensive transforms** (apply as pre_transform):
- GDC
- SIGN
- KNNGraph (for large point clouds)
- AddLaplacianEigenvectorPE
- SVDFeatureReduction
**Cheap transforms** (apply as transform):
- NormalizeFeatures
- ToUndirected
- AddSelfLoops
- Random* augmentations
- ToDevice
**Example**:
```python
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import Compose, GDC, NormalizeFeatures
# Expensive preprocessing done once
pre_transform = GDC(
self_loop_weight=1,
normalization_in='sym',
diffusion_kwargs=dict(method='ppr', alpha=0.15)
)
# Cheap transform applied each time
transform = NormalizeFeatures()
dataset = Planetoid(
root='/tmp/Cora',
name='Cora',
pre_transform=pre_transform,
transform=transform
)
```

View File

@@ -0,0 +1,309 @@
#!/usr/bin/env python3
"""
Benchmark GNN models on standard datasets.
This script provides a simple way to benchmark different GNN architectures
on common datasets and compare their performance.
Usage:
python benchmark_model.py --models gcn gat --dataset Cora
python benchmark_model.py --models gcn --dataset Cora --epochs 200 --runs 10
"""
import argparse
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
import time
import numpy as np
class GCN(torch.nn.Module):
def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
super().__init__()
self.conv1 = GCNConv(num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, num_classes)
self.dropout = dropout
def forward(self, x, edge_index, batch=None):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
if batch is not None:
x = global_mean_pool(x, batch)
return F.log_softmax(x, dim=1)
class GAT(torch.nn.Module):
def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6):
super().__init__()
self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout)
self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1,
concat=False, dropout=dropout)
self.dropout = dropout
def forward(self, x, edge_index, batch=None):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
if batch is not None:
x = global_mean_pool(x, batch)
return F.log_softmax(x, dim=1)
class GraphSAGE(torch.nn.Module):
def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
super().__init__()
self.conv1 = SAGEConv(num_features, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, num_classes)
self.dropout = dropout
def forward(self, x, edge_index, batch=None):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
if batch is not None:
x = global_mean_pool(x, batch)
return F.log_softmax(x, dim=1)
MODELS = {
'gcn': GCN,
'gat': GAT,
'graphsage': GraphSAGE,
}
def train_node_classification(model, data, optimizer):
"""Train for node classification."""
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def test_node_classification(model, data):
"""Test for node classification."""
model.eval()
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = (pred[mask] == data.y[mask]).sum()
accs.append(float(correct) / int(mask.sum()))
return accs
def train_graph_classification(model, loader, optimizer, device):
"""Train for graph classification."""
model.train()
total_loss = 0
for data in loader:
data = data.to(device)
optimizer.zero_grad()
out = model(data.x, data.edge_index, data.batch)
loss = F.nll_loss(out, data.y)
loss.backward()
optimizer.step()
total_loss += loss.item() * data.num_graphs
return total_loss / len(loader.dataset)
@torch.no_grad()
def test_graph_classification(model, loader, device):
"""Test for graph classification."""
model.eval()
correct = 0
for data in loader:
data = data.to(device)
out = model(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += (pred == data.y).sum().item()
return correct / len(loader.dataset)
def benchmark_node_classification(model_name, dataset_name, epochs, lr, weight_decay, device):
"""Benchmark a model on node classification."""
# Load dataset
dataset = Planetoid(root=f'/tmp/{dataset_name}', name=dataset_name)
data = dataset[0].to(device)
# Create model
model_class = MODELS[model_name]
model = model_class(
num_features=dataset.num_features,
hidden_channels=64,
num_classes=dataset.num_classes
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Training
start_time = time.time()
best_val_acc = 0
best_test_acc = 0
for epoch in range(1, epochs + 1):
loss = train_node_classification(model, data, optimizer)
train_acc, val_acc, test_acc = test_node_classification(model, data)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
train_time = time.time() - start_time
return {
'train_acc': train_acc,
'val_acc': best_val_acc,
'test_acc': best_test_acc,
'train_time': train_time,
}
def benchmark_graph_classification(model_name, dataset_name, epochs, lr, device):
"""Benchmark a model on graph classification."""
# Load dataset
dataset = TUDataset(root=f'/tmp/{dataset_name}', name=dataset_name)
# Split dataset
dataset = dataset.shuffle()
train_dataset = dataset[:int(len(dataset) * 0.8)]
test_dataset = dataset[int(len(dataset) * 0.8):]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
# Create model
model_class = MODELS[model_name]
model = model_class(
num_features=dataset.num_features,
hidden_channels=64,
num_classes=dataset.num_classes
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Training
start_time = time.time()
for epoch in range(1, epochs + 1):
loss = train_graph_classification(model, train_loader, optimizer, device)
# Final evaluation
train_acc = test_graph_classification(model, train_loader, device)
test_acc = test_graph_classification(model, test_loader, device)
train_time = time.time() - start_time
return {
'train_acc': train_acc,
'test_acc': test_acc,
'train_time': train_time,
}
def run_benchmark(args):
"""Run benchmark experiments."""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Determine task type
if args.dataset in ['Cora', 'CiteSeer', 'PubMed']:
task = 'node_classification'
else:
task = 'graph_classification'
print(f"\\nDataset: {args.dataset}")
print(f"Task: {task}")
print(f"Models: {', '.join(args.models)}")
print(f"Epochs: {args.epochs}")
print(f"Runs: {args.runs}")
print("=" * 60)
results = {model: [] for model in args.models}
# Run experiments
for run in range(args.runs):
print(f"\\nRun {run + 1}/{args.runs}")
print("-" * 60)
for model_name in args.models:
if model_name not in MODELS:
print(f"Unknown model: {model_name}")
continue
print(f" Training {model_name.upper()}...", end=" ")
try:
if task == 'node_classification':
result = benchmark_node_classification(
model_name, args.dataset, args.epochs,
args.lr, args.weight_decay, device
)
print(f"Test Acc: {result['test_acc']:.4f}, "
f"Time: {result['train_time']:.2f}s")
else:
result = benchmark_graph_classification(
model_name, args.dataset, args.epochs, args.lr, device
)
print(f"Test Acc: {result['test_acc']:.4f}, "
f"Time: {result['train_time']:.2f}s")
results[model_name].append(result)
except Exception as e:
print(f"Error: {e}")
# Print summary
print("\\n" + "=" * 60)
print("BENCHMARK RESULTS")
print("=" * 60)
for model_name in args.models:
if not results[model_name]:
continue
test_accs = [r['test_acc'] for r in results[model_name]]
times = [r['train_time'] for r in results[model_name]]
print(f"\\n{model_name.upper()}")
print(f" Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}")
print(f" Training Time: {np.mean(times):.2f} ± {np.std(times):.2f}s")
def main():
parser = argparse.ArgumentParser(description="Benchmark GNN models")
parser.add_argument('--models', nargs='+', default=['gcn'],
help='Model types to benchmark (gcn, gat, graphsage)')
parser.add_argument('--dataset', type=str, default='Cora',
help='Dataset name (Cora, CiteSeer, PubMed, ENZYMES, PROTEINS)')
parser.add_argument('--epochs', type=int, default=200,
help='Number of training epochs')
parser.add_argument('--runs', type=int, default=5,
help='Number of runs to average over')
parser.add_argument('--lr', type=float, default=0.01,
help='Learning rate')
parser.add_argument('--weight-decay', type=float, default=5e-4,
help='Weight decay for node classification')
args = parser.parse_args()
run_benchmark(args)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,529 @@
#!/usr/bin/env python3
"""
Generate boilerplate code for common GNN architectures in PyTorch Geometric.
This script creates ready-to-use GNN model templates with training loops,
evaluation metrics, and proper data handling.
Usage:
python create_gnn_template.py --model gcn --task node_classification --output my_model.py
python create_gnn_template.py --model gat --task graph_classification --output graph_classifier.py
"""
import argparse
from pathlib import Path
TEMPLATES = {
'node_classification': {
'gcn': '''import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
class GCN(torch.nn.Module):
"""Graph Convolutional Network for node classification."""
def __init__(self, num_features, hidden_channels, num_classes, num_layers=2, dropout=0.5):
super().__init__()
self.convs = torch.nn.ModuleList()
# First layer
self.convs.append(GCNConv(num_features, hidden_channels))
# Hidden layers
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
# Output layer
self.convs.append(GCNConv(hidden_channels, num_classes))
self.dropout = dropout
def forward(self, data):
x, edge_index = data.x, data.edge_index
# Apply conv layers with ReLU and dropout
for conv in self.convs[:-1]:
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# Final layer without activation
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=1)
def train(model, data, optimizer):
"""Train the model for one epoch."""
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def test(model, data):
"""Evaluate the model."""
model.eval()
out = model(data)
pred = out.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = (pred[mask] == data.y[mask]).sum()
accs.append(int(correct) / int(mask.sum()))
return accs
def main():
# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(
num_features=dataset.num_features,
hidden_channels=64,
num_classes=dataset.num_classes,
num_layers=3,
dropout=0.5
).to(device)
data = data.to(device)
# Setup optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training loop
print("Training GCN model...")
    best_val_acc = best_test_acc = 0
for epoch in range(1, 201):
loss = train(model, data, optimizer)
train_acc, val_acc, test_acc = test(model, data)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
if epoch % 10 == 0:
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
print(f'\\nBest Test Accuracy: {best_test_acc:.4f}')
if __name__ == '__main__':
main()
''',
'gat': '''import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
class GAT(torch.nn.Module):
"""Graph Attention Network for node classification."""
def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6):
super().__init__()
self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout)
self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1,
concat=False, dropout=dropout)
self.dropout = dropout
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
def train(model, data, optimizer):
"""Train the model for one epoch."""
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def test(model, data):
"""Evaluate the model."""
model.eval()
out = model(data)
pred = out.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = (pred[mask] == data.y[mask]).sum()
accs.append(int(correct) / int(mask.sum()))
return accs
def main():
# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(
num_features=dataset.num_features,
hidden_channels=8,
num_classes=dataset.num_classes,
heads=8,
dropout=0.6
).to(device)
data = data.to(device)
# Setup optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
# Training loop
print("Training GAT model...")
    best_val_acc = best_test_acc = 0
for epoch in range(1, 201):
loss = train(model, data, optimizer)
train_acc, val_acc, test_acc = test(model, data)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
if epoch % 10 == 0:
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
print(f'\\nBest Test Accuracy: {best_test_acc:.4f}')
if __name__ == '__main__':
main()
''',
'graphsage': '''import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
class GraphSAGE(torch.nn.Module):
"""GraphSAGE for node classification."""
def __init__(self, num_features, hidden_channels, num_classes, num_layers=2, dropout=0.5):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv(num_features, hidden_channels))
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.convs.append(SAGEConv(hidden_channels, num_classes))
self.dropout = dropout
def forward(self, data):
x, edge_index = data.x, data.edge_index
for conv in self.convs[:-1]:
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=1)
def train(model, data, optimizer):
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def test(model, data):
model.eval()
out = model(data)
pred = out.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = (pred[mask] == data.y[mask]).sum()
accs.append(int(correct) / int(mask.sum()))
return accs
def main():
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
num_features=dataset.num_features,
hidden_channels=64,
num_classes=dataset.num_classes,
num_layers=2,
dropout=0.5
).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
print("Training GraphSAGE model...")
    best_val_acc = best_test_acc = 0
for epoch in range(1, 201):
loss = train(model, data, optimizer)
train_acc, val_acc, test_acc = test(model, data)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
if epoch % 10 == 0:
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
print(f'\\nBest Test Accuracy: {best_test_acc:.4f}')
if __name__ == '__main__':
main()
''',
},
'graph_classification': {
'gin': '''import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
class GIN(torch.nn.Module):
"""Graph Isomorphism Network for graph classification."""
def __init__(self, num_features, hidden_channels, num_classes, num_layers=3, dropout=0.5):
super().__init__()
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
# Create MLP for first layer
nn = torch.nn.Sequential(
torch.nn.Linear(num_features, hidden_channels),
torch.nn.ReLU(),
torch.nn.Linear(hidden_channels, hidden_channels)
)
self.convs.append(GINConv(nn))
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))
        # Remaining layers (num_layers convolutions in total)
        for _ in range(num_layers - 1):
nn = torch.nn.Sequential(
torch.nn.Linear(hidden_channels, hidden_channels),
torch.nn.ReLU(),
torch.nn.Linear(hidden_channels, hidden_channels)
)
self.convs.append(GINConv(nn))
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))
# Output MLP
self.lin = torch.nn.Linear(hidden_channels, num_classes)
self.dropout = dropout
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
for conv, batch_norm in zip(self.convs, self.batch_norms):
x = conv(x, edge_index)
x = batch_norm(x)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# Global pooling
x = global_add_pool(x, batch)
# Output layer
x = self.lin(x)
return F.log_softmax(x, dim=1)
def train(model, loader, optimizer, device):
"""Train the model for one epoch."""
model.train()
total_loss = 0
for data in loader:
data = data.to(device)
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out, data.y)
loss.backward()
optimizer.step()
total_loss += loss.item() * data.num_graphs
return total_loss / len(loader.dataset)
@torch.no_grad()
def test(model, loader, device):
"""Evaluate the model."""
model.eval()
correct = 0
for data in loader:
data = data.to(device)
out = model(data)
pred = out.argmax(dim=1)
correct += (pred == data.y).sum().item()
return correct / len(loader.dataset)
def main():
# Load dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")
# Shuffle and split
dataset = dataset.shuffle()
train_dataset = dataset[:int(len(dataset) * 0.8)]
test_dataset = dataset[int(len(dataset) * 0.8):]
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GIN(
num_features=dataset.num_features,
hidden_channels=64,
num_classes=dataset.num_classes,
num_layers=3,
dropout=0.5
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training loop
print("\\nTraining GIN model...")
for epoch in range(1, 101):
loss = train(model, train_loader, optimizer, device)
train_acc = test(model, train_loader, device)
test_acc = test(model, test_loader, device)
if epoch % 10 == 0:
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
if __name__ == '__main__':
main()
''',
},
}
def generate_template(model_type: str, task: str, output_path: str):
"""Generate a GNN template file."""
if task not in TEMPLATES:
raise ValueError(f"Unknown task: {task}. Available: {list(TEMPLATES.keys())}")
if model_type not in TEMPLATES[task]:
raise ValueError(f"Model {model_type} not available for task {task}. "
f"Available: {list(TEMPLATES[task].keys())}")
template = TEMPLATES[task][model_type]
# Write to file
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'w') as f:
f.write(template)
print(f"✓ Generated {model_type.upper()} template for {task}")
print(f" Saved to: {output_path}")
print(f"\\nTo run the template:")
print(f" python {output_path}")
def list_templates():
"""List all available templates."""
print("Available GNN Templates")
print("=" * 50)
for task, models in TEMPLATES.items():
print(f"\\n{task.upper()}")
print("-" * 50)
for model in models.keys():
print(f" - {model}")
print()
def main():
parser = argparse.ArgumentParser(
description="Generate GNN model templates",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python create_gnn_template.py --model gcn --task node_classification --output gcn_model.py
python create_gnn_template.py --model gin --task graph_classification --output gin_model.py
python create_gnn_template.py --list
"""
)
parser.add_argument('--model', type=str,
help='Model type (gcn, gat, graphsage, gin)')
parser.add_argument('--task', type=str,
help='Task type (node_classification, graph_classification)')
parser.add_argument('--output', type=str, default='gnn_model.py',
help='Output file path (default: gnn_model.py)')
parser.add_argument('--list', action='store_true',
help='List all available templates')
args = parser.parse_args()
if args.list:
list_templates()
return
if not args.model or not args.task:
parser.print_help()
print("\\n" + "=" * 50)
list_templates()
return
try:
generate_template(args.model, args.task, args.output)
except ValueError as e:
print(f"Error: {e}")
print("\\nUse --list to see available templates")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,313 @@
#!/usr/bin/env python3
"""
Visualize PyTorch Geometric graph structures using networkx and matplotlib.
This script provides utilities to visualize Data objects, including:
- Graph structure (nodes and edges)
- Node features (as colors)
- Edge attributes (as edge colors/widths)
- Community/cluster assignments
Usage:
python visualize_graph.py --dataset Cora --output graph.png
Or import and use:
from scripts.visualize_graph import visualize_data
visualize_data(data, title="My Graph", show_labels=True)
"""
import argparse
import matplotlib.pyplot as plt
import networkx as nx
import torch
from typing import Optional
import numpy as np
def visualize_data(
data,
title: str = "Graph Visualization",
node_color_attr: Optional[str] = None,
edge_color_attr: Optional[str] = None,
show_labels: bool = False,
node_size: int = 300,
figsize: tuple = (12, 10),
layout: str = "spring",
output_path: Optional[str] = None,
max_nodes: Optional[int] = None,
):
"""
Visualize a PyTorch Geometric Data object.
Args:
data: PyTorch Geometric Data object
title: Plot title
node_color_attr: Data attribute to use for node colors (e.g., 'y', 'train_mask')
edge_color_attr: Data attribute to use for edge colors
show_labels: Whether to show node labels
node_size: Size of nodes in visualization
figsize: Figure size (width, height)
layout: Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')
output_path: Path to save figure (if None, displays interactively)
max_nodes: Maximum number of nodes to visualize (samples if exceeded)
"""
# Sample nodes if graph is too large
if max_nodes and data.num_nodes > max_nodes:
print(f"Graph has {data.num_nodes} nodes. Sampling {max_nodes} nodes for visualization.")
node_indices = torch.randperm(data.num_nodes)[:max_nodes]
data = data.subgraph(node_indices)
# Convert to networkx graph
G = nx.Graph() if is_undirected(data.edge_index) else nx.DiGraph()
# Add nodes
G.add_nodes_from(range(data.num_nodes))
# Add edges
edge_index = data.edge_index.cpu().numpy()
edges = list(zip(edge_index[0], edge_index[1]))
G.add_edges_from(edges)
# Setup figure
fig, ax = plt.subplots(figsize=figsize)
# Choose layout
if layout == "spring":
pos = nx.spring_layout(G, k=0.5, iterations=50)
elif layout == "circular":
pos = nx.circular_layout(G)
elif layout == "kamada_kawai":
pos = nx.kamada_kawai_layout(G)
elif layout == "spectral":
pos = nx.spectral_layout(G)
else:
raise ValueError(f"Unknown layout: {layout}")
# Determine node colors
if node_color_attr and hasattr(data, node_color_attr):
node_colors = getattr(data, node_color_attr).cpu().numpy()
if node_colors.dtype == bool:
node_colors = node_colors.astype(int)
if len(node_colors.shape) > 1:
# Multi-dimensional features - use first dimension
node_colors = node_colors[:, 0]
else:
node_colors = 'skyblue'
# Determine edge colors
if edge_color_attr and hasattr(data, edge_color_attr):
edge_colors = getattr(data, edge_color_attr).cpu().numpy()
if len(edge_colors.shape) > 1:
edge_colors = edge_colors[:, 0]
else:
edge_colors = 'gray'
# Draw graph
nx.draw_networkx_nodes(
G, pos,
node_color=node_colors,
node_size=node_size,
cmap=plt.cm.viridis,
ax=ax
)
nx.draw_networkx_edges(
G, pos,
edge_color=edge_colors,
alpha=0.3,
arrows=isinstance(G, nx.DiGraph),
arrowsize=10,
ax=ax
)
if show_labels:
nx.draw_networkx_labels(G, pos, font_size=8, ax=ax)
ax.set_title(title, fontsize=16, fontweight='bold')
ax.axis('off')
# Add colorbar if using numeric node colors
if node_color_attr and isinstance(node_colors, np.ndarray):
sm = plt.cm.ScalarMappable(
cmap=plt.cm.viridis,
norm=plt.Normalize(vmin=node_colors.min(), vmax=node_colors.max())
)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(node_color_attr, rotation=270, labelpad=20)
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Figure saved to {output_path}")
else:
plt.show()
plt.close()
def is_undirected(edge_index):
"""Check if graph is undirected."""
row, col = edge_index
num_edges = edge_index.size(1)
# Create a set of edges and reverse edges
edges = set(zip(row.tolist(), col.tolist()))
reverse_edges = set(zip(col.tolist(), row.tolist()))
# Check if all edges have their reverse
return edges == reverse_edges
def plot_degree_distribution(data, output_path: Optional[str] = None):
"""Plot the degree distribution of the graph."""
from torch_geometric.utils import degree
row, col = data.edge_index
deg = degree(col, data.num_nodes).cpu().numpy()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Histogram
ax1.hist(deg, bins=50, edgecolor='black', alpha=0.7)
ax1.set_xlabel('Degree', fontsize=12)
ax1.set_ylabel('Frequency', fontsize=12)
ax1.set_title('Degree Distribution', fontsize=14, fontweight='bold')
ax1.grid(alpha=0.3)
# Log-log plot
unique_degrees, counts = np.unique(deg, return_counts=True)
ax2.loglog(unique_degrees, counts, 'o-', alpha=0.7)
ax2.set_xlabel('Degree (log scale)', fontsize=12)
ax2.set_ylabel('Frequency (log scale)', fontsize=12)
ax2.set_title('Degree Distribution (Log-Log)', fontsize=14, fontweight='bold')
ax2.grid(alpha=0.3)
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Degree distribution saved to {output_path}")
else:
plt.show()
plt.close()
def plot_graph_statistics(data, output_path: Optional[str] = None):
"""Plot various graph statistics."""
from torch_geometric.utils import degree, contains_self_loops, is_undirected as check_undirected
# Compute statistics
row, col = data.edge_index
deg = degree(col, data.num_nodes).cpu().numpy()
stats = {
'Nodes': data.num_nodes,
'Edges': data.num_edges,
'Avg Degree': deg.mean(),
'Max Degree': deg.max(),
'Self-loops': contains_self_loops(data.edge_index),
'Undirected': check_undirected(data.edge_index),
}
if hasattr(data, 'num_node_features'):
stats['Node Features'] = data.num_node_features
if hasattr(data, 'num_edge_features') and data.edge_attr is not None:
stats['Edge Features'] = data.num_edge_features
if hasattr(data, 'y'):
if data.y.dim() == 1:
stats['Classes'] = int(data.y.max().item()) + 1
# Create text plot
fig, ax = plt.subplots(figsize=(8, 6))
ax.axis('off')
text = "Graph Statistics\n" + "=" * 40 + "\n\n"
for key, value in stats.items():
text += f"{key:20s}: {value}\n"
ax.text(0.1, 0.5, text, fontsize=14, family='monospace',
verticalalignment='center', transform=ax.transAxes)
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Statistics saved to {output_path}")
else:
plt.show()
plt.close()
# Print to console as well
print("\n" + text)
def main():
parser = argparse.ArgumentParser(description="Visualize PyTorch Geometric graphs")
parser.add_argument('--dataset', type=str, default='Cora',
help='Dataset name (e.g., Cora, CiteSeer, ENZYMES)')
parser.add_argument('--output', type=str, default=None,
help='Output file path for visualization')
parser.add_argument('--node-color', type=str, default='y',
help='Attribute to use for node colors')
parser.add_argument('--layout', type=str, default='spring',
choices=['spring', 'circular', 'kamada_kawai', 'spectral'],
help='Graph layout algorithm')
parser.add_argument('--show-labels', action='store_true',
help='Show node labels')
parser.add_argument('--max-nodes', type=int, default=500,
help='Maximum nodes to visualize')
parser.add_argument('--stats', action='store_true',
help='Show graph statistics')
parser.add_argument('--degree', action='store_true',
help='Show degree distribution')
args = parser.parse_args()
# Load dataset
print(f"Loading dataset: {args.dataset}")
try:
# Try Planetoid datasets
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root=f'/tmp/{args.dataset}', name=args.dataset)
data = dataset[0]
    except Exception:
try:
# Try TUDataset
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root=f'/tmp/{args.dataset}', name=args.dataset)
data = dataset[0]
except Exception as e:
print(f"Error loading dataset: {e}")
print("Supported datasets: Cora, CiteSeer, PubMed, ENZYMES, PROTEINS, etc.")
return
print(f"Loaded {args.dataset}: {data.num_nodes} nodes, {data.num_edges} edges")
# Generate visualizations
if args.stats:
stats_output = args.output.replace('.png', '_stats.png') if args.output else None
plot_graph_statistics(data, stats_output)
if args.degree:
degree_output = args.output.replace('.png', '_degree.png') if args.output else None
plot_degree_distribution(data, degree_output)
# Main visualization
visualize_data(
data,
title=f"{args.dataset} Graph",
node_color_attr=args.node_color,
show_labels=args.show_labels,
layout=args.layout,
output_path=args.output,
max_nodes=args.max_nodes
)
if __name__ == '__main__':
main()