#!/usr/bin/env python3
"""
OpenAlex API Client with rate limiting and error handling.

Provides a robust client for interacting with the OpenAlex API with:
- Automatic rate limiting (OpenAlex limit: 10 req/sec)
- Exponential backoff retry logic
- Pagination support
- Batch operations support
"""

import time
import requests
from typing import Dict, List, Optional, Any
from urllib.parse import urljoin


class OpenAlexClient:
    """Client for OpenAlex API with rate limiting and error handling."""

    BASE_URL = "https://api.openalex.org"

    def __init__(self, email: Optional[str] = None, requests_per_second: int = 10):
        """
        Initialize OpenAlex client.

        Args:
            email: Email for the polite pool (faster, more consistent response times)
            requests_per_second: Max requests per second (default: 10, the API limit)
        """
        self.email = email
        self.requests_per_second = requests_per_second
        self.min_delay = 1.0 / requests_per_second
        self.last_request_time = 0.0

    def _rate_limit(self):
        """Sleep as needed so requests don't exceed the rate limit."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_delay:
            time.sleep(self.min_delay - time_since_last)
        self.last_request_time = time.time()

    def _make_request(
        self,
        endpoint: str,
        params: Optional[Dict] = None,
        max_retries: int = 5
    ) -> Dict[str, Any]:
        """
        Make API request with retry logic.

        Args:
            endpoint: API endpoint (e.g., '/works', '/authors')
            params: Query parameters
            max_retries: Maximum number of retry attempts

        Returns:
            JSON response as dictionary
        """
        if params is None:
            params = {}

        # Add email to params for the polite pool
        if self.email:
            params['mailto'] = self.email

        url = urljoin(self.BASE_URL, endpoint)

        for attempt in range(max_retries):
            try:
                self._rate_limit()
                response = requests.get(url, params=params, timeout=30)

                if response.status_code == 200:
                    return response.json()
                elif response.status_code in (403, 429):
                    # Rate limited - back off exponentially
                    wait_time = 2 ** attempt
                    print(f"Rate limited. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                elif response.status_code >= 500:
                    # Server error - back off exponentially
                    wait_time = 2 ** attempt
                    print(f"Server error. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                else:
                    # Other client error - don't retry
                    response.raise_for_status()

            except requests.exceptions.Timeout:
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"Request timeout. Waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                else:
                    raise

        raise RuntimeError(f"Failed after {max_retries} retries")

    def search_works(
        self,
        search: Optional[str] = None,
        filter_params: Optional[Dict] = None,
        per_page: int = 200,
        page: int = 1,
        sort: Optional[str] = None,
        select: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Search works with filters.

        Args:
            search: Full-text search query
            filter_params: Dictionary of filter parameters
            per_page: Results per page (max: 200)
            page: Page number
            sort: Sort parameter (e.g., 'cited_by_count:desc')
            select: List of fields to return

        Returns:
            API response with meta and results
        """
        params = {
            'per-page': min(per_page, 200),
            'page': page
        }

        if search:
            params['search'] = search

        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            params['filter'] = filter_str

        if sort:
            params['sort'] = sort

        if select:
            params['select'] = ','.join(select)

        return self._make_request('/works', params)

    def get_entity(self, entity_type: str, entity_id: str) -> Dict[str, Any]:
        """
        Get single entity by ID.

        Args:
            entity_type: Type of entity ('works', 'authors', 'institutions', etc.)
            entity_id: OpenAlex ID or external ID (DOI, ORCID, etc.)

        Returns:
            Entity object
        """
        endpoint = f"/{entity_type}/{entity_id}"
        return self._make_request(endpoint)

    def batch_lookup(
        self,
        entity_type: str,
        ids: List[str],
        id_field: str = 'openalex_id'
    ) -> List[Dict[str, Any]]:
        """
        Look up multiple entities by ID efficiently.

        Args:
            entity_type: Type of entity ('works', 'authors', etc.)
            ids: List of IDs (up to 50 per batch)
            id_field: ID field name ('openalex_id', 'doi', 'orcid', etc.)

        Returns:
            List of entity objects
        """
        all_results = []

        # Process in batches of 50 (maximum number of values per filter)
        for i in range(0, len(ids), 50):
            batch = ids[i:i + 50]
            filter_value = '|'.join(batch)

            params = {
                'filter': f"{id_field}:{filter_value}",
                'per-page': 50
            }

            response = self._make_request(f"/{entity_type}", params)
            all_results.extend(response.get('results', []))

        return all_results

    def paginate_all(
        self,
        endpoint: str,
        params: Optional[Dict] = None,
        max_results: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Paginate through all results.

        Note: basic page-based paging only reaches the first 10,000 results;
        larger result sets require cursor paging.

        Args:
            endpoint: API endpoint
            params: Query parameters
            max_results: Maximum number of results to retrieve (None for all)

        Returns:
            List of all results
        """
        if params is None:
            params = {}

        params['per-page'] = 200  # Use maximum page size
        params['page'] = 1

        all_results = []

        while True:
            response = self._make_request(endpoint, params)
            results = response.get('results', [])

            # Stop if the API returns an empty page
            if not results:
                break

            all_results.extend(results)

            # Check if we've hit max_results
            if max_results and len(all_results) >= max_results:
                return all_results[:max_results]

            # Check if there are more pages
            meta = response.get('meta', {})
            total_count = meta.get('count', 0)

            if len(all_results) >= total_count:
                break

            params['page'] += 1

        return all_results

    def sample_works(
        self,
        sample_size: int,
        seed: Optional[int] = None,
        filter_params: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Get random sample of works.

        Args:
            sample_size: Number of samples to retrieve
            seed: Random seed for reproducibility
            filter_params: Optional filters to apply

        Returns:
            List of sampled works
        """
        params = {
            'sample': min(sample_size, 10000),  # API limit per request
            'per-page': 200
        }

        if seed is not None:
            params['seed'] = seed

        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            params['filter'] = filter_str

        # For large samples, need multiple requests with different seeds
        if sample_size > 10000:
            all_samples = []
            seen_ids = set()

            for i in range((sample_size // 10000) + 1):
                current_seed = seed + i if seed is not None else i
                params['seed'] = current_seed
                params['sample'] = min(10000, sample_size - len(all_samples))

                # Page through the sampled set (each page returns at most 200 works)
                results = self.paginate_all('/works', dict(params),
                                            max_results=params['sample'])

                # Deduplicate across seeds
                for result in results:
                    work_id = result.get('id')
                    if work_id not in seen_ids:
                        seen_ids.add(work_id)
                        all_samples.append(result)

                if len(all_samples) >= sample_size:
                    break

            return all_samples[:sample_size]
        else:
            # Page through the sampled set (each page returns at most 200 works)
            return self.paginate_all('/works', params, max_results=sample_size)

    def group_by(
        self,
        entity_type: str,
        group_field: str,
        filter_params: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Aggregate results by field.

        Args:
            entity_type: Type of entity ('works', 'authors', etc.)
            group_field: Field to group by
            filter_params: Optional filters

        Returns:
            List of grouped results with counts
        """
        params = {
            'group_by': group_field
        }

        if filter_params:
            filter_str = ','.join([f"{k}:{v}" for k, v in filter_params.items()])
            params['filter'] = filter_str

        response = self._make_request(f"/{entity_type}", params)
        return response.get('group_by', [])


if __name__ == "__main__":
    # Example usage
    client = OpenAlexClient(email="your-email@example.com")

    # Search for works about machine learning
    results = client.search_works(
        search="machine learning",
        filter_params={"publication_year": "2023"},
        per_page=10
    )

    print(f"Found {results['meta']['count']} works")
    for work in results['results']:
        print(f"- {work['title']}")
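
    # Further usage sketches (illustrative only): the group-by field and
    # filter values below are examples, not part of the original script.

    # Aggregate 2023 works by open-access status; each group in the response
    # carries 'key', 'key_display_name', and 'count'.
    oa_groups = client.group_by(
        "works",
        group_field="oa_status",
        filter_params={"publication_year": "2023"}
    )
    for group in oa_groups[:5]:
        print(f"{group['key']}: {group['count']}")

    # Batch-resolve works by DOI (substitute real DOIs before running):
    # works = client.batch_lookup("works", ["10.xxxx/example"], id_field="doi")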