#!/usr/bin/env python3
"""
OpenAlex API Client with rate limiting and error handling.
Provides a robust client for interacting with the OpenAlex API with:
- Automatic rate limiting (polite pool: 10 req/sec)
- Exponential backoff retry logic
- Pagination support
- Batch operations support
"""

import time
import requests
from typing import Dict, List, Optional, Any
from urllib.parse import urljoin


class OpenAlexClient:
    """Client for OpenAlex API with rate limiting and error handling."""

    BASE_URL = "https://api.openalex.org"

    def __init__(self, email: Optional[str] = None, requests_per_second: int = 10):
        """
        Initialize OpenAlex client.

        Args:
            email: Email for the polite pool (faster, more consistent responses)
            requests_per_second: Max requests per second (default: 10, the API's rate limit)
        """
        self.email = email
        self.requests_per_second = requests_per_second
        self.min_delay = 1.0 / requests_per_second
        self.last_request_time = 0.0

    def _rate_limit(self):
        """Ensure requests don't exceed rate limit."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_delay:
            time.sleep(self.min_delay - time_since_last)
        self.last_request_time = time.time()

    def _make_request(
        self,
        endpoint: str,
        params: Optional[Dict] = None,
        max_retries: int = 5
    ) -> Dict[str, Any]:
        """
        Make an API request with retry logic.

        Args:
            endpoint: API endpoint (e.g., '/works', '/authors')
            params: Query parameters
            max_retries: Maximum number of retry attempts

        Returns:
            JSON response as dictionary
        """
        if params is None:
            params = {}

        # Add email to params to join the polite pool
        if self.email:
            params['mailto'] = self.email

        url = urljoin(self.BASE_URL, endpoint)

        for attempt in range(max_retries):
            try:
                self._rate_limit()
                response = requests.get(url, params=params, timeout=30)

                if response.status_code == 200:
                    return response.json()
                elif response.status_code in (403, 429):
                    # Rate limited: back off exponentially
                    wait_time = 2 ** attempt
                    print(f"Rate limited. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                elif response.status_code >= 500:
                    # Server error: back off exponentially
                    wait_time = 2 ** attempt
                    print(f"Server error. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                else:
                    # Other client error: don't retry
                    response.raise_for_status()
            except requests.exceptions.Timeout:
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"Request timeout. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                else:
                    raise

        raise RuntimeError(f"Failed after {max_retries} retries")

    def search_works(
        self,
        search: Optional[str] = None,
        filter_params: Optional[Dict] = None,
        per_page: int = 200,
        page: int = 1,
        sort: Optional[str] = None,
        select: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Search works with filters.

        Args:
            search: Full-text search query
            filter_params: Dictionary of filter parameters
            per_page: Results per page (max: 200)
            page: Page number
            sort: Sort parameter (e.g., 'cited_by_count:desc')
            select: List of fields to return

        Returns:
            API response with meta and results
        """
        params = {
            'per-page': min(per_page, 200),
            'page': page
        }
        if search:
            params['search'] = search
        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            params['filter'] = filter_str
        if sort:
            params['sort'] = sort
        if select:
            params['select'] = ','.join(select)

        return self._make_request('/works', params)
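
    # Filter format sketch: with filter_params={'publication_year': '2023',
    # 'is_oa': 'true'}, the string sent to the API is
    # 'publication_year:2023,is_oa:true' (comma-joined filters are ANDed).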

    def get_entity(self, entity_type: str, entity_id: str) -> Dict[str, Any]:
        """
        Get single entity by ID.

        Args:
            entity_type: Type of entity ('works', 'authors', 'institutions', etc.)
            entity_id: OpenAlex ID or external ID (DOI, ORCID, etc.)

        Returns:
            Entity object
        """
        endpoint = f"/{entity_type}/{entity_id}"
        return self._make_request(endpoint)
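
    # Usage sketch (illustrative, not executed here): OpenAlex also resolves
    # external IDs in the entity path, e.g. a work looked up by its DOI URL:
    #   client.get_entity('works', 'https://doi.org/10.7717/peerj.4375')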

    def batch_lookup(
        self,
        entity_type: str,
        ids: List[str],
        id_field: str = 'openalex_id'
    ) -> List[Dict[str, Any]]:
        """
        Look up multiple entities by ID efficiently.

        Args:
            entity_type: Type of entity ('works', 'authors', etc.)
            ids: List of IDs (up to 50 per batch)
            id_field: ID field name ('openalex_id', 'doi', 'orcid', etc.)

        Returns:
            List of entity objects
        """
        all_results = []

        # Process in batches of 50
        for i in range(0, len(ids), 50):
            batch = ids[i:i+50]
            filter_value = '|'.join(batch)
            params = {
                'filter': f"{id_field}:{filter_value}",
                'per-page': 50
            }
            response = self._make_request(f"/{entity_type}", params)
            all_results.extend(response.get('results', []))

        return all_results
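
    # Usage sketch (the second ID is a placeholder, not a known work):
    #   client.batch_lookup('works', ['W2741809807', 'W0000000000'])
    # issues one request with filter=openalex_id:W2741809807|W0000000000;
    # '|' ORs the values together, and the API accepts up to 50 per filter.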

    def paginate_all(
        self,
        endpoint: str,
        params: Optional[Dict] = None,
        max_results: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Paginate through all results.

        Note: OpenAlex basic (page-based) paging only reaches the first
        10,000 results; cursor paging is required for larger result sets.

        Args:
            endpoint: API endpoint
            params: Query parameters
            max_results: Maximum number of results to retrieve (None for all)

        Returns:
            List of all results
        """
        if params is None:
            params = {}

        params['per-page'] = 200  # Use maximum page size
        params['page'] = 1
        all_results = []

        while True:
            response = self._make_request(endpoint, params)
            results = response.get('results', [])
            all_results.extend(results)

            # Check if we've hit max_results
            if max_results and len(all_results) >= max_results:
                return all_results[:max_results]

            # Check if there are more pages
            meta = response.get('meta', {})
            total_count = meta.get('count', 0)
            current_count = len(all_results)
            if current_count >= total_count:
                break

            params['page'] += 1

        return all_results
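
    # Usage sketch: collect the first 1,000 works from 2023, stopping early
    # once max_results is reached:
    #   client.paginate_all('/works', {'filter': 'publication_year:2023'},
    #                       max_results=1000)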

    def sample_works(
        self,
        sample_size: int,
        seed: Optional[int] = None,
        filter_params: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Get random sample of works.

        Args:
            sample_size: Number of samples to retrieve
            seed: Random seed for reproducibility
            filter_params: Optional filters to apply

        Returns:
            List of sampled works
        """
        base_params: Dict[str, Any] = {'per-page': 200}
        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            base_params['filter'] = filter_str

        all_samples: List[Dict[str, Any]] = []
        seen_ids = set()

        # The API caps 'sample' at 10,000 per request, so larger samples need
        # multiple passes with different seeds, deduplicated by work ID.
        num_passes = (sample_size // 10000) + 1
        for i in range(num_passes):
            current_seed = seed + i if seed is not None else i
            params = dict(base_params)
            params['sample'] = min(10000, sample_size - len(all_samples))
            if seed is not None or params['sample'] > 200:
                # A fixed seed is required to page consistently through a sample
                params['seed'] = current_seed

            # Results come back at most 200 per page, so fetch pages until
            # this pass's sample is exhausted.
            page = 1
            fetched = 0
            while fetched < params['sample']:
                params['page'] = page
                response = self._make_request('/works', params)
                results = response.get('results', [])
                if not results:
                    break
                fetched += len(results)
                # Deduplicate across passes
                for result in results:
                    work_id = result.get('id')
                    if work_id not in seen_ids:
                        seen_ids.add(work_id)
                        all_samples.append(result)
                page += 1

            if len(all_samples) >= sample_size:
                break

        return all_samples[:sample_size]

    def group_by(
        self,
        entity_type: str,
        group_field: str,
        filter_params: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Aggregate results by field.

        Args:
            entity_type: Type of entity ('works', 'authors', etc.)
            group_field: Field to group by
            filter_params: Optional filters

        Returns:
            List of grouped results with counts
        """
        params = {
            'group_by': group_field
        }
        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            params['filter'] = filter_str

        response = self._make_request(f"/{entity_type}", params)
        return response.get('group_by', [])
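
    # Usage sketch: client.group_by('works', 'publication_year') returns
    # entries shaped like {'key': '2023', 'key_display_name': '2023',
    # 'count': ...}, one per distinct value of the grouped field.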


if __name__ == "__main__":
    # Example usage
    client = OpenAlexClient(email="your-email@example.com")

    # Search for works about machine learning
    results = client.search_works(
        search="machine learning",
        filter_params={"publication_year": "2023"},
        per_page=10
    )
    print(f"Found {results['meta']['count']} works")
    for work in results['results']:
        print(f"- {work['title']}")