Knowledge Graph Reasoning
Overview
Knowledge graphs represent structured information as entities and relations in a graph format. TorchDrug provides comprehensive support for knowledge graph completion (link prediction) using embedding-based models and neural reasoning approaches.
Available Datasets
General Knowledge Graphs
FB15k (Freebase subset):
- 14,951 entities
- 1,345 relation types
- 592,213 triples
- General world knowledge from Freebase
FB15k-237:
- 14,541 entities
- 237 relation types
- 310,116 triples
- Filtered version removing inverse relations
- More challenging benchmark
WN18 (WordNet):
- 40,943 entities (word senses)
- 18 relation types (lexical relations)
- 151,442 triples
- Linguistic knowledge graph
WN18RR:
- 40,943 entities
- 11 relation types
- 93,003 triples
- Filtered WordNet removing easy inverse patterns
Biomedical Knowledge Graphs
Hetionet:
- 45,158 entities (genes, compounds, diseases, pathways, etc.)
- 24 relation types (treats, causes, binds, etc.)
- 2,250,197 edges
- Integrates 29 public biomedical databases
- Designed for drug repurposing and disease understanding
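Recent TorchDrug releases expose Hetionet as a knowledge graph dataset; a minimal loading sketch (the class name and download-on-first-use behavior are assumed from the TorchDrug dataset registry):

from torchdrug import datasets

# Downloads and caches Hetionet on first use (assumed behavior,
# matching the other TorchDrug knowledge graph datasets)
dataset = datasets.Hetionet("~/kg-datasets/")
print(dataset.num_entity, dataset.num_relation, len(dataset))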
Task: KnowledgeGraphCompletion
The primary task for knowledge graphs is link prediction - given a head entity and relation, predict the tail entity (or vice versa).
Task Modes
Head Prediction:
- Given (?, relation, tail), predict head entity
- "What can cause Disease X?"
Tail Prediction:
- Given (head, relation, ?), predict tail entity
- "What diseases does Gene X cause?"
Both:
- Predict both head and tail
- Standard evaluation protocol
Evaluation Metrics
Ranking Metrics:
- Mean Rank (MR): Average rank of correct entity
- Mean Reciprocal Rank (MRR): Average of 1/rank
- Hits@K: Fraction of test queries where the correct entity ranks in the top K
- Typically reported for K=1, 3, 10
Filtered vs Raw:
- Filtered: Remove other known true triples from the candidate list before ranking
- Raw: Rank among all entities, including other known true triples
- Filtered is standard for evaluation
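Once the (filtered) rank of the correct entity has been computed for each test query, the metrics above reduce to a few lines; a sketch assuming a 1-indexed tensor of ranks:

import torch

def ranking_metrics(ranks, ks=(1, 3, 10)):
    # ranks: 1-indexed rank of the true entity for each test query
    ranks = ranks.float()
    metrics = {"MR": ranks.mean().item(),
               "MRR": (1.0 / ranks).mean().item()}
    for k in ks:
        metrics[f"Hits@{k}"] = (ranks <= k).float().mean().item()
    return metrics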
Embedding Models
Translational Models
TransE (Translation Embedding):
- Represents relations as translations in embedding space
- h + r ≈ t (head + relation ≈ tail)
- Simple and effective baseline
- Works well for 1-to-1 relations
- Struggles with N-to-N relations
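A minimal TransE scoring sketch (generic PyTorch, not TorchDrug's internal implementation; the distance is negated so that higher scores mean more plausible triples):

import torch

def transe_score(h, r, t, p=1):
    # h, r, t: (batch, dim) embeddings; score = -||h + r - t||_p
    return -torch.norm(h + r - t, p=p, dim=-1)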
RotatE (Rotation Embedding):
- Relations as rotations in complex space
- Better handles symmetric and inverse relations
- State-of-the-art on many benchmarks
- Can model composition patterns
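RotatE's scoring function in the same style: entities are complex vectors and each relation is an element-wise phase rotation (a sketch, with max_score playing the role of the margin in the original paper):

import torch

def rotate_score(h, r_phase, t, max_score=12.0):
    # h, t: complex tensors of shape (batch, dim); r_phase: real angles
    rotation = torch.polar(torch.ones_like(r_phase), r_phase)  # e^{i*phase}
    return max_score - (h * rotation - t).abs().sum(dim=-1)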
Semantic Matching Models
DistMult:
- Bilinear scoring function
- Handles symmetric relations naturally
- Cannot model asymmetric relations
- Fast and memory efficient
ComplEx:
- Complex-valued embeddings
- Models asymmetric and inverse relations
- Better than DistMult for most graphs
- Balances expressiveness and efficiency
SimplE:
- Extends DistMult with inverse relations
- Fully expressive (can represent any relation pattern)
- Two embeddings per entity (one for the head role, one for the tail role) and an inverse embedding per relation
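For comparison, the three bilinear scoring functions side by side (a sketch; all inputs are (batch, dim) embeddings, complex-valued for ComplEx):

import torch

def distmult_score(h, r, t):
    # Trilinear product <h, r, t>; symmetric in h and t
    return (h * r * t).sum(dim=-1)

def complex_score(h, r, t):
    # Re(<h, r, conj(t)>); conjugation breaks the h/t symmetry
    return (h * r * t.conj()).sum(dim=-1).real

def simple_score(h_head, h_tail, t_head, t_tail, r, r_inv):
    # Average of forward and inverse-relation DistMult-style scores
    forward = (h_head * r * t_tail).sum(dim=-1)
    backward = (t_head * r_inv * h_tail).sum(dim=-1)
    return 0.5 * (forward + backward)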
Neural Logic Models
NeuralLP (Neural Logic Programming):
- Learns logical rules through differentiable operations
- Interprets predictions via learned rules
- Good for sparse knowledge graphs
- Computationally more expensive
KBGAT (Knowledge Base Graph Attention):
- Graph attention networks for KG completion
- Learns entity representations from neighborhood
- Handles unseen entities through inductive learning
- Better for incomplete graphs
Training Workflow
Basic Pipeline
from torchdrug import core, datasets, models, tasks
import torch

# Load dataset and use its standard splits
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# Define model
model = models.RotatE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=2000,
    max_score=9
)

# Define task
task = tasks.KnowledgeGraphCompletion(
    model,
    num_negative=128,
    adversarial_temperature=2,
    criterion="bce"
)

# Train with TorchDrug's built-in solver (a custom PyTorch loop also works;
# hyperparameters here are illustrative)
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=20)
solver.evaluate("valid")
Negative Sampling
Strategies:
- Uniform: Sample entities uniformly at random
- Self-Adversarial: Weight samples by current model's scores
- Type-Constrained: Sample only valid entity types for relation
Parameters:
- num_negative: Number of negative samples per positive triple
- adversarial_temperature: Temperature for self-adversarial weighting; higher temperature puts more focus on hard negatives
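A sketch of self-adversarial weighting as in the RotatE paper (generic PyTorch; here the temperature multiplies the scores, matching the "higher temperature = harder focus" convention above):

import torch
import torch.nn.functional as F

def self_adversarial_weights(neg_scores, temperature=1.0):
    # neg_scores: (batch, num_negative) scores of negative triples.
    # Harder (higher-scoring) negatives get larger weights; no gradient
    # flows through the weights themselves.
    with torch.no_grad():
        return F.softmax(temperature * neg_scores, dim=-1)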
Loss Functions
Binary Cross-Entropy (BCE):
- Treats each triple independently
- Balanced classification between positive and negative
Margin Loss:
- Ensures positive scores higher than negative by margin
max(0, margin + score_neg - score_pos)
Logistic Loss:
- Smooth version of margin loss
- Better gradient properties
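Hedged sketches of the three objectives, for score tensors pos of shape (batch,) and neg of shape (batch, num_negative), with higher scores meaning more plausible (reductions and sign conventions vary across implementations):

import torch
import torch.nn.functional as F

def bce_loss(pos, neg):
    # Positive triples labeled 1, negatives labeled 0
    return (F.binary_cross_entropy_with_logits(pos, torch.ones_like(pos))
            + F.binary_cross_entropy_with_logits(neg, torch.zeros_like(neg)))

def margin_loss(pos, neg, margin=6.0):
    # max(0, margin + score_neg - score_pos), averaged over negatives
    return F.relu(margin + neg - pos.unsqueeze(-1)).mean()

def logistic_loss(pos, neg):
    # softplus(-x) = -log sigmoid(x): a smooth analogue of the margin loss
    return F.softplus(-pos).mean() + F.softplus(neg).mean()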
Model Selection Guide
By Relation Patterns
1-to-1 Relations:
- TransE works well
- Any model will likely succeed
1-to-N Relations:
- DistMult, ComplEx, SimplE
- Avoid TransE
N-to-1 Relations:
- DistMult, ComplEx, SimplE
- Avoid TransE
N-to-N Relations:
- ComplEx, SimplE, RotatE
- Most challenging pattern
Symmetric Relations:
- DistMult, ComplEx
- RotatE (rotations by 0 or π yield symmetric relations)
Antisymmetric Relations:
- ComplEx, SimplE, RotatE
- Avoid DistMult
Inverse Relations:
- ComplEx, SimplE, RotatE
- Important for bidirectional reasoning
Composition:
- RotatE (best)
- TransE (reasonable)
- Captures multi-hop paths
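The composition claim for RotatE follows from the fact that rotations compose by adding phases; a quick numeric check (purely illustrative):

import torch

theta1, theta2 = torch.tensor(0.3), torch.tensor(1.1)
r1 = torch.polar(torch.tensor(1.0), theta1)
r2 = torch.polar(torch.tensor(1.0), theta2)
r12 = torch.polar(torch.tensor(1.0), theta1 + theta2)
assert torch.allclose(r1 * r2, r12)  # applying r1 then r2 == one rotation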
By Dataset Characteristics
Small Graphs (< 50k entities):
- ComplEx or SimplE
- Lower embedding dimensions (200-500)
Large Graphs (> 100k entities):
- DistMult for efficiency
- RotatE for accuracy
- Higher dimensions (500-2000)
Sparse Graphs:
- NeuralLP (learns rules from limited data)
- Pre-train embeddings on larger graphs
Dense, Complete Graphs:
- Any embedding model works well
- Choose based on relation patterns
Biomedical/Domain Graphs:
- Consider type constraints in sampling
- Use domain-specific negative sampling
- Hetionet benefits from relation-specific models
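A generic sketch of type-constrained tail corruption (independent of TorchDrug; valid_tails is a hypothetical lookup table built from the graph's schema, mapping each relation to its type-compatible tail entities):

import torch

def type_constrained_tails(relation_id, valid_tails, num_samples):
    # valid_tails: dict of relation id -> 1-D tensor of compatible entity ids
    candidates = valid_tails[relation_id]
    idx = torch.randint(len(candidates), (num_samples,))
    return candidates[idx]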
Advanced Techniques
Multi-Hop Reasoning
Chain multiple relations to answer complex queries:
- "What drugs treat diseases caused by gene X?"
- Requires path-based or rule-based reasoning
- NeuralLP naturally supports this
Temporal Knowledge Graphs
Extend to time-varying facts:
- Add temporal information to triples
- Predict future facts
- Requires temporal encoding in models
Few-Shot Learning
Handle relations with few examples:
- Meta-learning approaches
- Transfer from related relations
- Important for emerging knowledge
Inductive Learning
Generalize to unseen entities:
- KBGAT and other GNN-based methods
- Use entity features/descriptions
- Critical for evolving knowledge graphs
Biomedical Applications
Drug Repurposing
Predict "drug treats disease" links in Hetionet:
- Train on known drug-disease associations
- Predict new treatment candidates
- Filter by mechanism (gene, pathway involvement)
- Validate predictions experimentally
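A sketch of the candidate-ranking step with any trained scoring function: score one compound against all diseases under the treats relation, mask known treatments, and keep the top-k (score_fn, treats_id, and the index tensors are placeholders for your trained model and graph):

import torch

def repurposing_candidates(score_fn, drug_id, treats_id, disease_ids,
                           known_diseases, k=10):
    # Score (drug, treats, d) for every candidate disease d
    h = torch.full_like(disease_ids, drug_id)
    r = torch.full_like(disease_ids, treats_id)
    scores = score_fn(h, r, disease_ids)
    # Filtered ranking: exclude treatments already in the graph
    scores[torch.isin(disease_ids, known_diseases)] = float("-inf")
    return disease_ids[scores.topk(k).indices]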
Disease Gene Discovery
Identify genes associated with diseases:
- Model gene-disease-pathway networks
- Predict missing gene-disease links
- Incorporate protein interactions, expression data
- Prioritize candidates for validation
Protein Function Prediction
Link proteins to biological processes:
- Integrate protein interactions, GO terms
- Predict missing GO annotations
- Transfer function from similar proteins
Common Issues and Solutions
Issue: Poor performance on specific relation types
- Solution: Analyze relation patterns, choose appropriate model, or use relation-specific models
Issue: Overfitting on small graphs
- Solution: Reduce embedding dimension, increase regularization, or use simpler models
Issue: Slow training on large graphs
- Solution: Reduce negative samples, use DistMult for efficiency, or implement mini-batch training
Issue: Cannot handle new entities
- Solution: Use inductive models (KBGAT), incorporate entity features, or pre-compute embeddings for new entities based on their neighbors
Best Practices
- Start with ComplEx or RotatE for most tasks
- Use self-adversarial negative sampling
- Tune embedding dimension (typically 500-2000)
- Apply regularization to prevent overfitting
- Use filtered evaluation metrics
- Analyze performance per relation type
- Consider relation-specific models for heterogeneous graphs
- Validate predictions with domain experts