#!/usr/bin/env python3
"""Benchmark GNN models on standard datasets.

This script provides a simple way to benchmark different GNN architectures
on common datasets and compare their performance.

Usage:
    python benchmark_model.py --models gcn gat --dataset Cora
    python benchmark_model.py --models gcn --dataset Cora --epochs 200 --runs 10
"""

import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GCNConv, GINConv, SAGEConv, global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6):
        super().__init__()
        # hidden_channels is per attention head, so the first layer
        # outputs hidden_channels * heads features.
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1,
                             concat=False, dropout=dropout)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(num_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


MODELS = {
    'gcn': GCN,
    'gat': GAT,
    'graphsage': GraphSAGE,
}
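

# GINConv is imported above but no GIN model is defined. The class below is a
# minimal sketch of how one could be added alongside the other architectures.
# Assumptions (not part of the original script): two GINConv layers wrapping
# small two-layer MLPs, with the same dropout and mean-pooling scheme as the
# models above. Illustrative, not tuned.
class GIN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        # GINConv wraps a neural network that transforms the aggregated
        # neighbor features (here, a two-layer MLP).
        self.conv1 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(num_features, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        ))
        self.conv2 = GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, num_classes),
        ))
        self.dropout = dropout

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return F.log_softmax(x, dim=1)


# Opt-in registration: uncomment to benchmark GIN via `--models gin`.
# MODELS['gin'] = GIN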


def train_node_classification(model, data, optimizer):
    """Train for node classification."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test_node_classification(model, data):
    """Test for node classification."""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = (pred[mask] == data.y[mask]).sum()
        accs.append(float(correct) / int(mask.sum()))
    return accs


def train_graph_classification(model, loader, optimizer, device):
    """Train for graph classification."""
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)


@torch.no_grad()
def test_graph_classification(model, loader, device):
    """Test for graph classification."""
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
    return correct / len(loader.dataset)


def benchmark_node_classification(model_name, dataset_name, epochs, lr,
                                  weight_decay, device):
    """Benchmark a model on node classification."""
    # Load dataset
    dataset = Planetoid(root=f'/tmp/{dataset_name}', name=dataset_name)
    data = dataset[0].to(device)

    # Create model
    model_class = MODELS[model_name]
    model = model_class(
        num_features=dataset.num_features,
        hidden_channels=64,
        num_classes=dataset.num_classes,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # Training, with model selection on validation accuracy
    start_time = time.time()
    best_val_acc = 0
    best_test_acc = 0
    for epoch in range(1, epochs + 1):
        train_node_classification(model, data, optimizer)
        train_acc, val_acc, test_acc = test_node_classification(model, data)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
    train_time = time.time() - start_time

    return {
        'train_acc': train_acc,
        'val_acc': best_val_acc,
        'test_acc': best_test_acc,
        'train_time': train_time,
    }


def benchmark_graph_classification(model_name, dataset_name, epochs, lr, device):
    """Benchmark a model on graph classification."""
    # Load dataset
    dataset = TUDataset(root=f'/tmp/{dataset_name}', name=dataset_name)

    # Split dataset (80/20 train/test)
    dataset = dataset.shuffle()
    split = int(len(dataset) * 0.8)
    train_dataset = dataset[:split]
    test_dataset = dataset[split:]

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Create model
    model_class = MODELS[model_name]
    model = model_class(
        num_features=dataset.num_features,
        hidden_channels=64,
        num_classes=dataset.num_classes,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training
    start_time = time.time()
    for epoch in range(1, epochs + 1):
        train_graph_classification(model, train_loader, optimizer, device)

    # Final evaluation
    train_acc = test_graph_classification(model, train_loader, device)
    test_acc = test_graph_classification(model, test_loader, device)
    train_time = time.time() - start_time

    return {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_time': train_time,
    }


def run_benchmark(args):
    """Run benchmark experiments."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Determine task type
    if args.dataset in ['Cora', 'CiteSeer', 'PubMed']:
        task = 'node_classification'
    else:
        task = 'graph_classification'

    print(f"\nDataset: {args.dataset}")
    print(f"Task: {task}")
    print(f"Models: {', '.join(args.models)}")
    print(f"Epochs: {args.epochs}")
    print(f"Runs: {args.runs}")
    print("=" * 60)

    results = {model: [] for model in args.models}

    # Run experiments
    for run in range(args.runs):
        print(f"\nRun {run + 1}/{args.runs}")
        print("-" * 60)

        for model_name in args.models:
            if model_name not in MODELS:
                print(f"Unknown model: {model_name}")
                continue

            print(f"  Training {model_name.upper()}...", end=" ")

            try:
                if task == 'node_classification':
                    result = benchmark_node_classification(
                        model_name, args.dataset, args.epochs,
                        args.lr, args.weight_decay, device
                    )
                else:
                    result = benchmark_graph_classification(
                        model_name, args.dataset, args.epochs, args.lr, device
                    )
                print(f"Test Acc: {result['test_acc']:.4f}, "
                      f"Time: {result['train_time']:.2f}s")
                results[model_name].append(result)
            except Exception as e:
                print(f"Error: {e}")

    # Print summary
    print("\n" + "=" * 60)
    print("BENCHMARK RESULTS")
    print("=" * 60)

    for model_name in args.models:
        if not results[model_name]:
            continue
        test_accs = [r['test_acc'] for r in results[model_name]]
        times = [r['train_time'] for r in results[model_name]]
        print(f"\n{model_name.upper()}")
        print(f"  Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}")
        print(f"  Training Time: {np.mean(times):.2f} ± {np.std(times):.2f}s")
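

# Reproducibility note: each run above reshuffles data and reinitializes
# weights with whatever RNG state PyTorch happens to have, so reported
# mean ± std varies across invocations. A helper like the one below could be
# called at the top of each run, e.g. `set_seed(run)` inside run_benchmark.
# This is a sketch: the name `set_seed` and per-run-index seeding are
# assumptions, not part of the original script, and it is not called anywhere.
def set_seed(seed):
    """Seed the NumPy and PyTorch RNGs for (mostly) reproducible runs."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)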
print(f"Error: {e}") # Print summary print("\\n" + "=" * 60) print("BENCHMARK RESULTS") print("=" * 60) for model_name in args.models: if not results[model_name]: continue test_accs = [r['test_acc'] for r in results[model_name]] times = [r['train_time'] for r in results[model_name]] print(f"\\n{model_name.upper()}") print(f" Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}") print(f" Training Time: {np.mean(times):.2f} ± {np.std(times):.2f}s") def main(): parser = argparse.ArgumentParser(description="Benchmark GNN models") parser.add_argument('--models', nargs='+', default=['gcn'], help='Model types to benchmark (gcn, gat, graphsage)') parser.add_argument('--dataset', type=str, default='Cora', help='Dataset name (Cora, CiteSeer, PubMed, ENZYMES, PROTEINS)') parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs') parser.add_argument('--runs', type=int, default=5, help='Number of runs to average over') parser.add_argument('--lr', type=float, default=0.01, help='Learning rate') parser.add_argument('--weight-decay', type=float, default=5e-4, help='Weight decay for node classification') args = parser.parse_args() run_benchmark(args) if __name__ == '__main__': main()