#!/usr/bin/env python3
"""
Benchmark GNN models on standard datasets.

This script provides a simple way to benchmark different GNN architectures
on common datasets and compare their performance.

Usage:
    python benchmark_model.py --models gcn gat --dataset Cora
    python benchmark_model.py --models gcn --dataset Cora --epochs 200 --runs 10
"""

import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, global_mean_pool


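# Each model below is a two-layer GNN that returns log-probabilities, so the
# training loops can use F.nll_loss directly. The optional `batch` argument
# switches on graph-level mean pooling, letting the same class serve both
# node- and graph-classification tasks.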
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


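# GAT applies dropout to both node features and attention coefficients; the
# first layer's heads are concatenated, so the second layer's input width is
# hidden_channels * heads.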
class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6):
        super().__init__()
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1,
                             concat=False, dropout=dropout)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


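# GraphSAGE with SAGEConv's default mean aggregation; no neighbor sampling is
# used here, so each layer aggregates over the full neighborhood.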
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(num_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


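# Registry mapping CLI model names to classes; add an entry here to include a
# new architecture in the benchmark.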
MODELS = {
    'gcn': GCN,
    'gat': GAT,
    'graphsage': GraphSAGE,
}


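# GINConv is imported above but no corresponding model was defined; the
# following is a minimal two-layer GIN sketch in the same style as the models
# above (the MLP widths and dropout rate are assumptions, not tuned values).
class GIN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        # GINConv wraps an MLP that transforms the summed neighbor features.
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(num_features, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, num_classes),
        ))
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


MODELS['gin'] = GIN

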
def train_node_classification(model, data, optimizer):
    """Run one full-batch training step for node classification; return the loss."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test_node_classification(model, data):
    """Return accuracy on the train, validation, and test masks."""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = (pred[mask] == data.y[mask]).sum()
        accs.append(float(correct) / int(mask.sum()))

    return accs


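# Graph classification uses PyG mini-batching: the DataLoader collates graphs
# into one large disconnected graph, and data.batch maps each node back to its
# source graph for pooling.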
def train_graph_classification(model, loader, optimizer, device):
    """Run one training epoch over mini-batched graphs; return the mean loss per graph."""
    model.train()
    total_loss = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(loader.dataset)


@torch.no_grad()
def test_graph_classification(model, loader, device):
    """Return classification accuracy over all graphs in the loader."""
    model.eval()
    correct = 0

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()

    return correct / len(loader.dataset)


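# Node-classification protocol: train on the standard Planetoid split and
# report the test accuracy from the epoch with the best validation accuracy.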
def benchmark_node_classification(model_name, dataset_name, epochs, lr, weight_decay, device):
    """Benchmark a model on node classification."""
    # Load dataset
    dataset = Planetoid(root=f'/tmp/{dataset_name}', name=dataset_name)
    data = dataset[0].to(device)

    # Create model
    model_class = MODELS[model_name]
    model = model_class(
        num_features=dataset.num_features,
        hidden_channels=64,
        num_classes=dataset.num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Training: select the model by best validation accuracy
    start_time = time.time()
    best_val_acc = 0
    best_test_acc = 0

    for epoch in range(1, epochs + 1):
        train_node_classification(model, data, optimizer)
        train_acc, val_acc, test_acc = test_node_classification(model, data)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

    train_time = time.time() - start_time

    return {
        'train_acc': train_acc,
        'val_acc': best_val_acc,
        'test_acc': best_test_acc,
        'train_time': train_time,
    }


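# Graph-classification protocol: a single random 80/20 train/test split with
# no validation set; accuracy is measured once after the final epoch.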
def benchmark_graph_classification(model_name, dataset_name, epochs, lr, device):
    """Benchmark a model on graph classification."""
    # Load dataset
    dataset = TUDataset(root=f'/tmp/{dataset_name}', name=dataset_name)

    # Shuffle and split 80/20 into train/test
    dataset = dataset.shuffle()
    train_dataset = dataset[:int(len(dataset) * 0.8)]
    test_dataset = dataset[int(len(dataset) * 0.8):]

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Create model
    model_class = MODELS[model_name]
    model = model_class(
        num_features=dataset.num_features,
        hidden_channels=64,
        num_classes=dataset.num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        train_graph_classification(model, train_loader, optimizer, device)

    # Final evaluation
    train_acc = test_graph_classification(model, train_loader, device)
    test_acc = test_graph_classification(model, test_loader, device)
    train_time = time.time() - start_time

    return {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_time': train_time,
    }


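# The task is inferred from the dataset name: the Planetoid citation networks
# are node classification; any other name is treated as a TUDataset
# graph-classification benchmark.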
def run_benchmark(args):
    """Run benchmark experiments."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Determine task type
    if args.dataset in ['Cora', 'CiteSeer', 'PubMed']:
        task = 'node_classification'
    else:
        task = 'graph_classification'

    print(f"\nDataset: {args.dataset}")
    print(f"Task: {task}")
    print(f"Models: {', '.join(args.models)}")
    print(f"Epochs: {args.epochs}")
    print(f"Runs: {args.runs}")
    print("=" * 60)

    results = {model: [] for model in args.models}

    # Run experiments
    for run in range(args.runs):
        print(f"\nRun {run + 1}/{args.runs}")
        print("-" * 60)

        for model_name in args.models:
            if model_name not in MODELS:
                print(f"Unknown model: {model_name}")
                continue

            print(f"  Training {model_name.upper()}...", end=" ")

            try:
                if task == 'node_classification':
                    result = benchmark_node_classification(
                        model_name, args.dataset, args.epochs,
                        args.lr, args.weight_decay, device
                    )
                else:
                    result = benchmark_graph_classification(
                        model_name, args.dataset, args.epochs, args.lr, device
                    )
                print(f"Test Acc: {result['test_acc']:.4f}, "
                      f"Time: {result['train_time']:.2f}s")
                results[model_name].append(result)
            except Exception as e:
                print(f"Error: {e}")

    # Print summary
    print("\n" + "=" * 60)
    print("BENCHMARK RESULTS")
    print("=" * 60)

    for model_name in args.models:
        if not results[model_name]:
            continue

        test_accs = [r['test_acc'] for r in results[model_name]]
        times = [r['train_time'] for r in results[model_name]]

        print(f"\n{model_name.upper()}")
        print(f"  Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}")
        print(f"  Training Time: {np.mean(times):.2f} ± {np.std(times):.2f}s")


def main():
    parser = argparse.ArgumentParser(description="Benchmark GNN models")
    parser.add_argument('--models', nargs='+', default=['gcn'],
                        help='Model types to benchmark (gcn, gat, graphsage, gin)')
    parser.add_argument('--dataset', type=str, default='Cora',
                        help='Dataset name (Cora, CiteSeer, PubMed, ENZYMES, PROTEINS)')
    parser.add_argument('--epochs', type=int, default=200,
                        help='Number of training epochs')
    parser.add_argument('--runs', type=int, default=5,
                        help='Number of runs to average over')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='Learning rate')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='Weight decay (used for node classification only)')

    args = parser.parse_args()
    run_benchmark(args)


if __name__ == '__main__':
    main()