
Knowledge Graph Reasoning

Overview

Knowledge graphs represent structured information as entities and relations in a graph format. TorchDrug provides comprehensive support for knowledge graph completion (link prediction) using embedding-based models and neural reasoning approaches.

Available Datasets

General Knowledge Graphs

FB15k (Freebase subset):

  • 14,951 entities
  • 1,345 relation types
  • 592,213 triples
  • General world knowledge from Freebase

FB15k-237:

  • 14,541 entities
  • 237 relation types
  • 310,116 triples
  • Filtered version of FB15k with inverse and near-duplicate relations removed
  • More challenging benchmark

WN18 (WordNet):

  • 40,943 entities (word senses)
  • 18 relation types (lexical relations)
  • 151,442 triples
  • Linguistic knowledge graph

WN18RR:

  • 40,943 entities
  • 11 relation types
  • 93,003 triples
  • Filtered WordNet removing easy inverse patterns

Biomedical Knowledge Graphs

Hetionet:

  • 45,158 entities (genes, compounds, diseases, pathways, etc.)
  • 24 relation types (treats, causes, binds, etc.)
  • 2,250,197 edges
  • Integrates 29 public biomedical databases
  • Designed for drug repurposing and disease understanding

Task: KnowledgeGraphCompletion

The primary task for knowledge graphs is link prediction: given a head entity and a relation, predict the tail entity (or vice versa).

Task Modes

Head Prediction:

  • Given (?, relation, tail), predict head entity
  • "What can cause Disease X?"

Tail Prediction:

  • Given (head, relation, ?), predict tail entity
  • "What diseases does Gene X cause?"

Both:

  • Predict both head and tail
  • Standard evaluation protocol

Evaluation Metrics

Ranking Metrics:

  • Mean Rank (MR): Average rank of correct entity
  • Mean Reciprocal Rank (MRR): Average of 1/rank
  • Hits@K: Percentage of correct entities in top K predictions
    • Typically reported for K=1, 3, 10

Filtered vs Raw:

  • Filtered: Remove other known true triples from ranking
  • Raw: Rank among all possible entities
  • Filtered is standard for evaluation
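
For intuition, all three ranking metrics can be computed from a list of filtered ranks in a few lines. A minimal sketch (the helper below is illustrative, not part of TorchDrug, which reports these metrics automatically during evaluation):

import torch

def ranking_metrics(ranks, ks=(1, 3, 10)):
    # ranks: 1-based filtered ranks of the correct entity, one per test triple
    ranks = torch.as_tensor(ranks, dtype=torch.float)
    metrics = {"MR": ranks.mean().item(), "MRR": (1 / ranks).mean().item()}
    for k in ks:
        metrics["hits@%d" % k] = (ranks <= k).float().mean().item()
    return metrics

# ranking_metrics([1, 4, 2, 57]) -> MR 16.0, MRR ~0.44, hits@10 0.75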

Embedding Models

Translational Models

TransE (Translation Embedding):

  • Represents relations as translations in embedding space
  • h + r ≈ t (head + relation ≈ tail)
  • Simple and effective baseline
  • Works well for 1-to-1 relations
  • Struggles with N-to-N relations
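
In code, the translation idea is just a distance check. A minimal PyTorch sketch of the TransE score (real implementations add margins and norm constraints):

import torch

def transe_score(h, r, t):
    # a triple is plausible when h + r lands close to t,
    # so the score is the negated L1 distance
    return -(h + r - t).norm(p=1, dim=-1)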

RotatE (Rotation Embedding):

  • Relations as rotations in complex space
  • Better handles symmetric and inverse relations
  • Strong results on standard benchmarks
  • Can model composition patterns
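
The rotation view can be sketched the same way, assuming complex-valued entity embeddings and a phase vector per relation (margin handling via max_score is omitted):

import torch

def rotate_score(h, r_phase, t):
    # h and t are complex tensors; the relation rotates each coordinate
    # of h on the unit circle before comparing against t
    rotation = torch.polar(torch.ones_like(r_phase), r_phase)
    return -(h * rotation - t).abs().sum(dim=-1)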

Semantic Matching Models

DistMult:

  • Bilinear scoring function
  • Handles symmetric relations naturally
  • Cannot model asymmetric relations (see the sketch below)
  • Fast and memory efficient

ComplEx:

  • Complex-valued embeddings
  • Models asymmetric and inverse relations
  • Better than DistMult for most graphs
  • Balances expressiveness and efficiency
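
A minimal sketch contrasting the two scoring functions (complex tensors are assumed here for clarity; library implementations typically store real and imaginary parts as separate halves of a real vector):

import torch

def distmult_score(h, r, t):
    # trilinear product; swapping h and t leaves the score unchanged,
    # which is why DistMult only captures symmetric relations
    return (h * r * t).sum(dim=-1)

def complex_score(h, r, t):
    # conjugating t breaks the h/t symmetry, so asymmetric
    # and inverse relations become representable
    return (h * r * t.conj()).real.sum(dim=-1)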

SimplE:

  • Extends DistMult with inverse relations
  • Fully expressive (can represent any relation pattern)
  • Learns two embeddings per entity (one for the head role, one for the tail role)

Neural Logic Models

NeuralLP (Neural Logic Programming):

  • Learns logical rules through differentiable operations
  • Interprets predictions via learned rules
  • Good for sparse knowledge graphs
  • Computationally more expensive

KBGAT (Knowledge Base Graph Attention):

  • Graph attention networks for KG completion
  • Learns entity representations from neighborhood
  • Handles unseen entities through inductive learning
  • Better for incomplete graphs

Training Workflow

Basic Pipeline

import torch
from torchdrug import core, datasets, models, tasks

# Load dataset and use its standard split
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# Define model
model = models.RotatE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=2000,
    max_score=9
)

# Define task
task = tasks.KnowledgeGraphCompletion(
    model,
    num_negative=128,
    adversarial_temperature=2,
    criterion="bce"
)

# Train and evaluate with TorchDrug's Engine
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=200)
solver.evaluate("valid")

Negative Sampling

Strategies:

  • Uniform: Sample entities uniformly at random
  • Self-Adversarial: Weight samples by current model's scores
  • Type-Constrained: Sample only valid entity types for relation

Parameters:

  • num_negative: Number of negative samples per positive triple
  • adversarial_temperature: Temperature for self-adversarial weighting
  • Lower temperature sharpens the weighting toward the hardest negatives; higher temperature flattens it toward uniform (see the sketch below)
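
A minimal sketch of the self-adversarial weighting (names are illustrative; TorchDrug applies this weighting inside KnowledgeGraphCompletion when adversarial_temperature is set):

import torch

def self_adversarial_weights(neg_scores, temperature=2.0):
    # harder negatives (higher model scores) receive larger weights;
    # detach so the weights act as constants during backpropagation
    return torch.softmax(neg_scores / temperature, dim=-1).detach()

neg_scores = torch.randn(8, 128)   # one row of 128 negatives per positive
weights = self_adversarial_weights(neg_scores)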

Loss Functions

Binary Cross-Entropy (BCE):

  • Treats each triple independently
  • Balanced classification between positive and negative

Margin Loss:

  • Ensures positive scores higher than negative by margin
  • max(0, margin + score_neg - score_pos)
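
This is one line in PyTorch; a minimal sketch assuming higher scores mean more plausible triples:

import torch.nn.functional as F

def margin_loss(pos_score, neg_score, margin=6.0):
    # loss reaches zero once the positive outscores the negative by margin
    return F.relu(margin + neg_score - pos_score).mean()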

Logistic Loss:

  • Smooth version of margin loss
  • Better gradient properties

Model Selection Guide

By Relation Patterns

1-to-1 Relations:

  • TransE works well
  • Any model will likely succeed

1-to-N Relations:

  • DistMult, ComplEx, SimplE
  • Avoid TransE

N-to-1 Relations:

  • DistMult, ComplEx, SimplE
  • Avoid TransE

N-to-N Relations:

  • ComplEx, SimplE, RotatE
  • Most challenging pattern

Symmetric Relations:

  • DistMult, ComplEx
  • RotatE (symmetric relations correspond to rotations by 0 or π)

Antisymmetric Relations:

  • ComplEx, SimplE, RotatE
  • Avoid DistMult

Inverse Relations:

  • ComplEx, SimplE, RotatE
  • Important for bidirectional reasoning

Composition:

  • RotatE (best)
  • TransE (reasonable)
  • Captures multi-hop paths

By Dataset Characteristics

Small Graphs (< 50k entities):

  • ComplEx or SimplE
  • Lower embedding dimensions (200-500)

Large Graphs (> 100k entities):

  • DistMult for efficiency
  • RotatE for accuracy
  • Higher dimensions (500-2000)

Sparse Graphs:

  • NeuralLP (learns rules from limited data)
  • Pre-train embeddings on larger graphs

Dense, Complete Graphs:

  • Any embedding model works well
  • Choose based on relation patterns

Biomedical/Domain Graphs:

  • Consider type constraints in sampling
  • Use domain-specific negative sampling
  • Hetionet benefits from relation-specific models

Advanced Techniques

Multi-Hop Reasoning

Chain multiple relations to answer complex queries:

  • "What drugs treat diseases caused by gene X?"
  • Requires path-based or rule-based reasoning
  • NeuralLP naturally supports this
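
As a rough illustration of why composition-friendly embeddings help here, under a TransE-style model a two-hop query can be approximated by adding relation vectors. Everything below is a heuristic sketch with illustrative names, not a TorchDrug API:

import torch

def two_hop_scores(gene_emb, rel_causes, rel_treated_by, entity_embs):
    # gene --causes--> disease --treated_by--> drug: translate twice,
    # then rank all entities by closeness to the composed query
    # (rel_treated_by is the inverse of the treats relation)
    query = gene_emb + rel_causes + rel_treated_by
    return -(entity_embs - query).norm(p=1, dim=-1)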

Temporal Knowledge Graphs

Extend to time-varying facts:

  • Add temporal information to triples
  • Predict future facts
  • Requires temporal encoding in models

Few-Shot Learning

Handle relations with few examples:

  • Meta-learning approaches
  • Transfer from related relations
  • Important for emerging knowledge

Inductive Learning

Generalize to unseen entities:

  • KBGAT and other GNN-based methods
  • Use entity features/descriptions
  • Critical for evolving knowledge graphs

Biomedical Applications

Drug Repurposing

Predict "drug treats disease" links in Hetionet:

  1. Train on known drug-disease associations
  2. Predict new treatment candidates
  3. Filter by mechanism (gene, pathway involvement)
  4. Validate predictions experimentally
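
Step 2 can be sketched with learned embeddings: score (compound, treats, ?) against every entity and keep the top candidates. The embedding tables and indices below are illustrative, and a TransE-style score stands in for whichever model was actually trained:

import torch

def rank_candidate_diseases(entity_emb, relation_emb, compound_idx, treats_idx, k=10):
    # compose the query, score all entities, return the k best tails
    query = entity_emb[compound_idx] + relation_emb[treats_idx]
    scores = -(entity_emb - query).norm(p=1, dim=-1)
    return scores.topk(k)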

Disease Gene Discovery

Identify genes associated with diseases:

  1. Model gene-disease-pathway networks
  2. Predict missing gene-disease links
  3. Incorporate protein interactions, expression data
  4. Prioritize candidates for validation

Protein Function Prediction

Link proteins to biological processes:

  1. Integrate protein interactions, GO terms
  2. Predict missing GO annotations
  3. Transfer function from similar proteins

Common Issues and Solutions

Issue: Poor performance on specific relation types

  • Solution: Analyze relation patterns, choose appropriate model, or use relation-specific models

Issue: Overfitting on small graphs

  • Solution: Reduce embedding dimension, increase regularization, or use simpler models

Issue: Slow training on large graphs

  • Solution: Reduce negative samples, use DistMult for efficiency, or implement mini-batch training

Issue: Cannot handle new entities

  • Solution: Use inductive models (KBGAT), incorporate entity features, or pre-compute embeddings for new entities based on their neighbors

Best Practices

  1. Start with ComplEx or RotatE for most tasks
  2. Use self-adversarial negative sampling
  3. Tune embedding dimension (typically 500-2000)
  4. Apply regularization to prevent overfitting
  5. Use filtered evaluation metrics
  6. Analyze performance per relation type
  7. Consider relation-specific models for heterogeneous graphs
  8. Validate predictions with domain experts