Initial commit

Zhongwei Li
2025-11-29 18:02:40 +08:00
commit 69617b598e
25 changed files with 5790 additions and 0 deletions


@@ -0,0 +1,310 @@
#!/usr/bin/env python3
"""
Organize PDFs and metadata from various sources (BibTeX, RIS, directory, DOI list).
Standardizes file naming and creates a unified metadata JSON for downstream processing.
"""
import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, List, Optional
import re
try:
from pybtex.database.input import bibtex
BIBTEX_AVAILABLE = True
except ImportError:
BIBTEX_AVAILABLE = False
print("Warning: pybtex not installed. BibTeX support disabled.")
try:
import rispy
RIS_AVAILABLE = True
except ImportError:
RIS_AVAILABLE = False
print("Warning: rispy not installed. RIS support disabled.")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Organize PDFs and metadata from various sources'
)
parser.add_argument(
'--source-type',
choices=['bibtex', 'ris', 'directory', 'doi_list'],
required=True,
help='Type of source data'
)
parser.add_argument(
'--source',
required=True,
help='Path to source file (BibTeX/RIS file, directory, or DOI list)'
)
parser.add_argument(
'--pdf-dir',
help='Directory containing PDFs (for bibtex/ris with relative paths)'
)
parser.add_argument(
'--output',
default='metadata.json',
help='Output metadata JSON file'
)
parser.add_argument(
'--organize-pdfs',
action='store_true',
help='Copy PDFs to standardized directory structure'
)
parser.add_argument(
'--pdf-output-dir',
default='organized_pdfs',
help='Directory for organized PDFs'
)
return parser.parse_args()
def load_bibtex_metadata(bib_path: Path, pdf_base_dir: Optional[Path] = None) -> List[Dict]:
"""Load metadata from BibTeX file"""
if not BIBTEX_AVAILABLE:
raise ImportError("pybtex is required for BibTeX support. Install with: pip install pybtex")
parser = bibtex.Parser()
bib_data = parser.parse_file(str(bib_path))
metadata = []
for key, entry in bib_data.entries.items():
record = {
'id': key,
'type': entry.type,
'title': entry.fields.get('title', ''),
'year': entry.fields.get('year', ''),
'doi': entry.fields.get('doi', ''),
'abstract': entry.fields.get('abstract', ''),
'journal': entry.fields.get('journal', ''),
'authors': ', '.join(
[' '.join([p for p in person.last_names + person.first_names])
for person in entry.persons.get('author', [])]
),
'keywords': entry.fields.get('keywords', ''),
'pdf_path': None
}
# Extract PDF path from file field
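# Zotero/JabRef-style 'file' fields look like 'Full Text:path/to/paper.pdf:application/pdf',
# optionally with several such entries separated by ';'.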
if 'file' in entry.fields:
file_field = entry.fields['file']
if file_field.startswith('{') and file_field.endswith('}'):
file_field = file_field[1:-1]
for file_entry in file_field.split(';'):
parts = file_entry.strip().split(':')
if len(parts) >= 3 and parts[2].lower() == 'application/pdf':
pdf_path = parts[1].strip()
if pdf_base_dir:
pdf_path = str(pdf_base_dir / pdf_path)
record['pdf_path'] = pdf_path
break
metadata.append(record)
print(f"Loaded {len(metadata)} entries from BibTeX file")
return metadata
def load_ris_metadata(ris_path: Path, pdf_base_dir: Optional[Path] = None) -> List[Dict]:
"""Load metadata from RIS file"""
if not RIS_AVAILABLE:
raise ImportError("rispy is required for RIS support. Install with: pip install rispy")
with open(ris_path, 'r', encoding='utf-8') as f:
entries = rispy.load(f)
metadata = []
for i, entry in enumerate(entries):
# Generate ID from first author and year or use index
first_author = (entry.get('authors') or ['Unknown'])[0] or 'Unknown'  # guard against missing or empty author lists
year = entry.get('year', 'NoYear')
entry_id = f"{first_author.split()[-1]}{year}_{i}"
record = {
'id': entry_id,
'type': entry.get('type_of_reference', 'article'),
'title': entry.get('title', ''),
'year': str(entry.get('year', '')),
'doi': entry.get('doi', ''),
'abstract': entry.get('abstract', ''),
'journal': entry.get('journal_name', ''),
'authors': '; '.join(entry.get('authors', [])),
'keywords': '; '.join(entry.get('keywords', [])),
'pdf_path': None
}
# Try to find PDF in standard locations
if pdf_base_dir:
# Common patterns: FirstAuthorYear.pdf, doi_cleaned.pdf, etc.
pdf_candidates = [
f"{entry_id}.pdf",
f"{first_author.split()[-1]}_{year}.pdf"
]
if record['doi']:
safe_doi = re.sub(r'[^\w\-_]', '_', record['doi'])
pdf_candidates.append(f"{safe_doi}.pdf")
for candidate in pdf_candidates:
pdf_path = pdf_base_dir / candidate
if pdf_path.exists():
record['pdf_path'] = str(pdf_path)
break
metadata.append(record)
print(f"Loaded {len(metadata)} entries from RIS file")
return metadata
def load_directory_metadata(dir_path: Path) -> List[Dict]:
"""Load metadata by scanning directory for PDFs"""
pdf_files = list(dir_path.glob('**/*.pdf'))
metadata = []
for pdf_path in pdf_files:
# Generate ID from filename
entry_id = pdf_path.stem
record = {
'id': entry_id,
'type': 'article',
'title': entry_id.replace('_', ' '),
'year': '',
'doi': '',
'abstract': '',
'journal': '',
'authors': '',
'keywords': '',
'pdf_path': str(pdf_path)
}
# Try to extract DOI from filename if present. Filenames cannot contain '/',
# so DOIs embedded in names are usually encoded with '_' in its place.
doi_match = re.search(r'10\.\d{4,}[/_]\S+', entry_id)
if doi_match:
doi = doi_match.group(0)
record['doi'] = doi if '/' in doi else doi.replace('_', '/', 1)  # restore the prefix/suffix separator
metadata.append(record)
print(f"Found {len(metadata)} PDFs in directory")
return metadata
def load_doi_list_metadata(doi_list_path: Path) -> List[Dict]:
"""Load metadata from a list of DOIs (will need to fetch metadata separately)"""
with open(doi_list_path, 'r') as f:
dois = [line.strip() for line in f if line.strip()]
metadata = []
for doi in dois:
safe_doi = re.sub(r'[^\w\-_]', '_', doi)
record = {
'id': safe_doi,
'type': 'article',
'title': '',
'year': '',
'doi': doi,
'abstract': '',
'journal': '',
'authors': '',
'keywords': '',
'pdf_path': None
}
metadata.append(record)
print(f"Loaded {len(metadata)} DOIs")
print("Note: You'll need to fetch full metadata and PDFs separately")
return metadata
def organize_pdfs(metadata: List[Dict], output_dir: Path) -> List[Dict]:
"""Copy and rename PDFs to standardized directory structure"""
output_dir.mkdir(parents=True, exist_ok=True)
organized_metadata = []
stats = {'copied': 0, 'missing': 0, 'total': len(metadata)}
for record in metadata:
if record['pdf_path'] and Path(record['pdf_path']).exists():
source_path = Path(record['pdf_path'])
dest_path = output_dir / f"{record['id']}.pdf"
try:
shutil.copy2(source_path, dest_path)
record['pdf_path'] = str(dest_path)
stats['copied'] += 1
except Exception as e:
print(f"Error copying {source_path}: {e}")
stats['missing'] += 1
else:
if record['pdf_path']:
print(f"PDF not found: {record['pdf_path']}")
stats['missing'] += 1
organized_metadata.append(record)
print(f"\nPDF Organization Summary:")
print(f" Total entries: {stats['total']}")
print(f" PDFs copied: {stats['copied']}")
print(f" PDFs missing: {stats['missing']}")
return organized_metadata
def save_metadata(metadata: List[Dict], output_path: Path):
"""Save metadata to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
print(f"\nMetadata saved to: {output_path}")
def main():
args = parse_args()
source_path = Path(args.source)
pdf_base_dir = Path(args.pdf_dir) if args.pdf_dir else None
output_path = Path(args.output)
# Load metadata based on source type
if args.source_type == 'bibtex':
metadata = load_bibtex_metadata(source_path, pdf_base_dir)
elif args.source_type == 'ris':
metadata = load_ris_metadata(source_path, pdf_base_dir)
elif args.source_type == 'directory':
metadata = load_directory_metadata(source_path)
elif args.source_type == 'doi_list':
metadata = load_doi_list_metadata(source_path)
else:
raise ValueError(f"Unknown source type: {args.source_type}")
# Organize PDFs if requested
if args.organize_pdfs:
pdf_output_dir = Path(args.pdf_output_dir)
metadata = organize_pdfs(metadata, pdf_output_dir)
# Save metadata
save_metadata(metadata, output_path)
# Print summary statistics
total = len(metadata)
with_pdfs = sum(1 for r in metadata if r['pdf_path'])
with_abstracts = sum(1 for r in metadata if r['abstract'])
with_dois = sum(1 for r in metadata if r['doi'])
print(f"\nMetadata Summary:")
print(f" Total entries: {total}")
print(f" With PDFs: {with_pdfs}")
print(f" With abstracts: {with_abstracts}")
print(f" With DOIs: {with_dois}")
if __name__ == '__main__':
main()
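For orientation, each entry in the metadata.json written by this script carries the fields assembled above; the file itself is a JSON list of such objects. A minimal illustrative record (all values invented):

{
  "id": "Smith2020_0",
  "type": "article",
  "title": "Example paper title",
  "year": "2020",
  "doi": "10.1234/example.doi",
  "abstract": "One-paragraph abstract text...",
  "journal": "Journal of Examples",
  "authors": "Smith Jane; Doe John",
  "keywords": "example; template",
  "pdf_path": "organized_pdfs/Smith2020_0.pdf"
}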


@@ -0,0 +1,468 @@
#!/usr/bin/env python3
"""
Filter papers based on abstract content using Claude API or local models.
Reduces processing costs by identifying relevant papers before full PDF extraction.
This script template needs to be customized with your specific filtering criteria.
Supports:
- Claude Haiku (cheap, fast API option)
- Claude Sonnet (more accurate API option)
- Local models via Ollama (free, private, requires local setup)
"""
import argparse
import json
import os
import re
import time
from pathlib import Path
from typing import Dict, List, Optional
from anthropic import Anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
try:
import requests
REQUESTS_AVAILABLE = True
except ImportError:
REQUESTS_AVAILABLE = False
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Filter papers by analyzing abstracts with Claude or local models',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Backend options:
anthropic-haiku : Claude 3.5 Haiku (cheap, fast, lower cost per token than Sonnet)
anthropic-sonnet : Claude 3.5 Sonnet (more accurate, ~$3/million input tokens)
ollama : Local model via Ollama (free, requires local setup)
Local model setup (Ollama):
1. Install Ollama: https://ollama.com
2. Pull a model: ollama pull llama3.1:8b
3. Run server: ollama serve (usually starts automatically)
4. Use --backend ollama --ollama-model llama3.1:8b
Recommended models for Ollama:
- llama3.1:8b (good balance)
- llama3.1:70b (better accuracy, needs more RAM)
- mistral:7b (fast, good for simple filtering)
- qwen2.5:7b (good multilingual support)
"""
)
parser.add_argument(
'--metadata',
required=True,
help='Input metadata JSON file from step 01'
)
parser.add_argument(
'--output',
default='filtered_papers.json',
help='Output JSON file with filter results'
)
parser.add_argument(
'--backend',
choices=['anthropic-haiku', 'anthropic-sonnet', 'ollama'],
default='anthropic-haiku',
help='Model backend to use (default: anthropic-haiku for cost efficiency)'
)
parser.add_argument(
'--ollama-model',
default='llama3.1:8b',
help='Ollama model name (default: llama3.1:8b)'
)
parser.add_argument(
'--ollama-url',
default='http://localhost:11434',
help='Ollama server URL (default: http://localhost:11434)'
)
parser.add_argument(
'--use-batches',
action='store_true',
help='Use Anthropic Batches API (only for anthropic backends)'
)
parser.add_argument(
'--test',
action='store_true',
help='Run in test mode (process only 10 records)'
)
return parser.parse_args()
def load_metadata(metadata_path: Path) -> List[Dict]:
"""Load metadata from JSON file"""
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_existing_results(output_path: Path) -> Dict:
"""Load existing filter results if available"""
if output_path.exists():
with open(output_path, 'r', encoding='utf-8') as f:
return json.load(f)
return {}
def save_results(results: Dict, output_path: Path):
"""Save filter results to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
def create_filter_prompt(title: str, abstract: str) -> str:
"""
Create the filtering prompt.
TODO: CUSTOMIZE THIS PROMPT FOR YOUR SPECIFIC USE CASE
This is a template. Replace the example criteria with your own.
"""
return f"""You are analyzing scientific literature to identify relevant papers for a research project.
<title>
{title}
</title>
<abstract>
{abstract}
</abstract>
Your task is to determine if this paper meets the following criteria:
TODO: Replace these example criteria with your own:
1. Does the paper contain PRIMARY empirical data (not review/meta-analysis)?
2. Does the paper report [YOUR SPECIFIC DATA TYPE, e.g., "field observations", "experimental measurements", "clinical outcomes"]?
3. Is the geographic/temporal/taxonomic scope relevant to [YOUR STUDY SYSTEM]?
Important considerations:
- Be conservative: when in doubt, include the paper (false positives are better than false negatives)
- Distinguish between primary data and citations of others' work
- Consider whether the abstract suggests the full paper likely contains the data of interest
Provide your determination as a JSON object with these boolean fields:
1. "has_relevant_data": true if the paper likely contains the data type of interest
2. "is_primary_research": true if the paper reports original empirical data
3. "meets_scope": true if the study system/scope is relevant
Also provide:
4. "confidence": your confidence level (high/medium/low)
5. "reasoning": brief explanation (1-2 sentences)
Wrap your response in <output> tags. Example:
<output>
{{
"has_relevant_data": true,
"is_primary_research": true,
"meets_scope": true,
"confidence": "high",
"reasoning": "Abstract explicitly mentions field observations of the target phenomenon in the relevant geographic region."
}}
</output>
Base your determination solely on the title and abstract provided."""
def extract_json_from_xml(text: str) -> Dict:
"""Extract JSON from XML output tags in Claude's response"""
match = re.search(r'<output>\s*(\{.*?\})\s*</output>', text, re.DOTALL)
if match:
json_str = match.group(1)
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
print(f"JSON string: {json_str}")
return None
return None
def filter_paper_ollama(record: Dict, ollama_url: str, ollama_model: str) -> Dict:
"""Use local Ollama model to filter a single paper"""
if not REQUESTS_AVAILABLE:
return {
'status': 'error',
'reason': 'requests library not available. Install with: pip install requests'
}
if not record.get('title') or not record.get('abstract'):
return {
'status': 'skipped',
'reason': 'missing_title_or_abstract'
}
max_retries = 3
for attempt in range(max_retries):
try:
# Ollama uses OpenAI-compatible chat API
response = requests.post(
f"{ollama_url}/api/chat",
json={
"model": ollama_model,
"messages": [
{
"role": "system",
"content": "You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses."
},
{
"role": "user",
"content": create_filter_prompt(record['title'], record['abstract'])
}
],
"stream": False,
"options": {
"temperature": 0,
"num_predict": 2048
}
},
timeout=60
)
if response.status_code == 200:
data = response.json()
content = data.get('message', {}).get('content', '')
result = extract_json_from_xml(content)
if result:
return {
'status': 'success',
'filter_result': result,
'model_used': ollama_model
}
else:
return {
'status': 'error',
'reason': 'failed_to_parse_json',
'raw_response': content[:500]
}
else:
return {
'status': 'error',
'reason': f'Ollama API error: {response.status_code} {response.text[:200]}'
}
except requests.exceptions.ConnectionError:
return {
'status': 'error',
'reason': f'Cannot connect to Ollama at {ollama_url}. Make sure Ollama is running: ollama serve'
}
except Exception as e:
if attempt == max_retries - 1:
return {
'status': 'error',
'reason': str(e)
}
time.sleep(2 ** attempt)
def filter_paper_direct(client: Anthropic, record: Dict, model: str) -> Dict:
"""Use Claude API directly to filter a single paper"""
if not record.get('title') or not record.get('abstract'):
return {
'status': 'skipped',
'reason': 'missing_title_or_abstract'
}
max_retries = 3
for attempt in range(max_retries):
try:
response = client.messages.create(
model=model,
max_tokens=2048,
temperature=0,
system="You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses.",
messages=[{
"role": "user",
"content": create_filter_prompt(record['title'], record['abstract'])
}]
)
result = extract_json_from_xml(response.content[0].text)
if result:
return {
'status': 'success',
'filter_result': result,
'model_used': model
}
else:
return {
'status': 'error',
'reason': 'failed_to_parse_json'
}
except Exception as e:
if attempt == max_retries - 1:
return {
'status': 'error',
'reason': str(e)
}
time.sleep(2 ** attempt)
def filter_papers_batch(client: Anthropic, records: List[Dict], model: str) -> Dict[str, Dict]:
"""Use Claude Batches API to filter multiple papers efficiently"""
batch_requests = []
for record in records:
if not record.get('title') or not record.get('abstract'):
continue
batch_requests.append(Request(
custom_id=record['id'],
params=MessageCreateParamsNonStreaming(
model=model,
max_tokens=2048,
temperature=0,
system="You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses.",
messages=[{
"role": "user",
"content": create_filter_prompt(record['title'], record['abstract'])
}]
)
))
if not batch_requests:
print("No papers to process (missing titles or abstracts)")
return {}
# Create batch
print(f"Creating batch with {len(requests)} requests...")
message_batch = client.messages.batches.create(requests=requests)
print(f"Batch created: {message_batch.id}")
# Poll for completion
while message_batch.processing_status == "in_progress":
print("Waiting for batch processing...")
time.sleep(30)
message_batch = client.messages.batches.retrieve(message_batch.id)
# Process results
results = {}
if message_batch.processing_status == "ended":
print("Batch completed. Processing results...")
for result in client.messages.batches.results(message_batch.id):
if result.result.type == "succeeded":
filter_result = extract_json_from_xml(
result.result.message.content[0].text
)
if filter_result:
results[result.custom_id] = {
'status': 'success',
'filter_result': filter_result
}
else:
results[result.custom_id] = {
'status': 'error',
'reason': 'failed_to_parse_json'
}
else:
results[result.custom_id] = {
'status': 'error',
'reason': f"{result.result.type}: {getattr(result.result, 'error', 'unknown error')}"
}
else:
print(f"Batch failed with status: {message_batch.processing_status}")
return results
def get_model_name(backend: str) -> str:
"""Get the appropriate model name for the backend"""
if backend == 'anthropic-haiku':
return 'claude-3-5-haiku-20241022'
elif backend == 'anthropic-sonnet':
return 'claude-3-5-sonnet-20241022'
return backend
def main():
args = parse_args()
# Backend-specific setup
client = None
if args.backend.startswith('anthropic'):
if not os.getenv('ANTHROPIC_API_KEY'):
raise ValueError("Please set ANTHROPIC_API_KEY environment variable for Anthropic backends")
client = Anthropic()
model = get_model_name(args.backend)
print(f"Using Anthropic backend: {model}")
elif args.backend == 'ollama':
if args.use_batches:
print("Warning: Batches API not available for Ollama. Processing sequentially.")
args.use_batches = False
print(f"Using Ollama backend: {args.ollama_model} at {args.ollama_url}")
print("Make sure Ollama is running: ollama serve")
# Load metadata
metadata = load_metadata(Path(args.metadata))
print(f"Loaded {len(metadata)} metadata records")
# Apply test mode if specified
if args.test:
metadata = metadata[:10]
print(f"Test mode: processing {len(metadata)} records")
# Load existing results
output_path = Path(args.output)
results = load_existing_results(output_path)
print(f"Loaded {len(results)} existing results")
# Identify papers to process
to_process = [r for r in metadata if r['id'] not in results]
print(f"Papers to process: {len(to_process)}")
if not to_process:
print("All papers already processed!")
return
# Process papers based on backend
if args.backend == 'ollama':
print("Processing papers with Ollama...")
for record in to_process:
print(f"Processing: {record['id']}")
result = filter_paper_ollama(record, args.ollama_url, args.ollama_model)
results[record['id']] = result
save_results(results, output_path)
# No sleep needed for local models
elif args.use_batches:
print("Using Batches API...")
batch_results = filter_papers_batch(client, to_process, model)
results.update(batch_results)
else:
print("Processing papers sequentially with Anthropic API...")
for record in to_process:
print(f"Processing: {record['id']}")
result = filter_paper_direct(client, record, model)
results[record['id']] = result
save_results(results, output_path)
time.sleep(1) # Rate limiting
# Save final results
save_results(results, output_path)
# Print summary statistics
total = len(results)
successful = sum(1 for r in results.values() if r.get('status') == 'success')
relevant = sum(
1 for r in results.values()
if r.get('status') == 'success' and r.get('filter_result', {}).get('has_relevant_data')
)
print(f"\n{'='*60}")
print("Filtering Summary")
print(f"{'='*60}")
print(f"Total papers processed: {total}")
print(f"Successfully analyzed: {successful}")
print(f"Papers with relevant data: {relevant}")
print(f"Relevance rate: {relevant/successful*100:.1f}%" if successful > 0 else "N/A")
print(f"\nResults saved to: {output_path}")
print(f"\nNext step: Review results and proceed to PDF extraction for relevant papers")
if __name__ == '__main__':
main()
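Before moving on to PDF extraction, it helps to see which papers were flagged as relevant. A minimal sketch, assuming the default filtered_papers.json output name, using the same fields that step 03 reads:

import json

with open("filtered_papers.json", encoding="utf-8") as f:
    results = json.load(f)

# Collect IDs of papers that were analyzed successfully and flagged as relevant
relevant = [
    paper_id
    for paper_id, r in results.items()
    if r.get("status") == "success"
    and r.get("filter_result", {}).get("has_relevant_data")
]
print(f"{len(relevant)} of {len(results)} papers flagged as relevant")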


@@ -0,0 +1,478 @@
#!/usr/bin/env python3
"""
Extract structured data from PDFs using Claude API.
Supports multiple PDF processing methods and prompt caching for efficiency.
This script template needs to be customized with your specific extraction schema.
"""
import argparse
import base64
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Optional
import re
from anthropic import Anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
# Configuration
BATCH_SIZE = 5
SIMULTANEOUS_BATCHES = 4
BATCH_CHECK_INTERVAL = 30
BATCH_SUBMISSION_INTERVAL = 20
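# BATCH_SIZE: PDFs bundled into a single Batches API request
# SIMULTANEOUS_BATCHES: batches kept in flight per processing window
# BATCH_CHECK_INTERVAL: seconds between polls while waiting for batches to finish
# BATCH_SUBMISSION_INTERVAL: seconds to pause between submitting successive batches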
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Extract structured data from PDFs using Claude'
)
parser.add_argument(
'--metadata',
required=True,
help='Input metadata JSON file (from step 01 or 02)'
)
parser.add_argument(
'--schema',
required=True,
help='JSON file defining extraction schema and prompts'
)
parser.add_argument(
'--output',
default='extracted_data.json',
help='Output JSON file with extraction results'
)
parser.add_argument(
'--method',
choices=['base64', 'files_api', 'batches'],
default='batches',
help='PDF processing method (default: batches)'
)
parser.add_argument(
'--use-caching',
action='store_true',
help='Enable prompt caching (note: this template does not act on this flag yet)'
)
parser.add_argument(
'--test',
action='store_true',
help='Run in test mode (process only 3 PDFs)'
)
parser.add_argument(
'--model',
default='claude-3-5-sonnet-20241022',
help='Claude model to use'
)
parser.add_argument(
'--filter-results',
help='Optional: JSON file with filter results from step 02 (only process relevant papers)'
)
return parser.parse_args()
def load_metadata(metadata_path: Path) -> List[Dict]:
"""Load metadata from JSON file"""
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_schema(schema_path: Path) -> Dict:
"""Load extraction schema definition"""
with open(schema_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_filter_results(filter_path: Path) -> Dict:
"""Load filter results from step 02"""
with open(filter_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_existing_results(output_path: Path) -> Dict:
"""Load existing extraction results if available"""
if output_path.exists():
with open(output_path, 'r', encoding='utf-8') as f:
return json.load(f)
return {}
def save_results(results: Dict, output_path: Path):
"""Save extraction results to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
def create_extraction_prompt(schema: Dict) -> str:
"""
Create extraction prompt from schema definition.
The schema JSON should contain:
- system_context: Description of the analysis task
- instructions: Step-by-step analysis instructions
- output_schema: JSON schema for the output
- output_example: Example of desired output
TODO: Customize schema.json for your specific use case
"""
prompt_parts = []
# Add objective
if 'objective' in schema:
prompt_parts.append(f"Your objective is to {schema['objective']}\n")
# Add instructions
if 'instructions' in schema:
prompt_parts.append("Please follow these steps:\n")
for i, instruction in enumerate(schema['instructions'], 1):
prompt_parts.append(f"{i}. {instruction}")
prompt_parts.append("")
# Add analysis framework
if 'analysis_steps' in schema:
prompt_parts.append("<analysis_framework>")
for step in schema['analysis_steps']:
prompt_parts.append(f"- {step}")
prompt_parts.append("</analysis_framework>\n")
prompt_parts.append(
"Your analysis must be wrapped within <analysis> tags. "
"Be thorough and explicit in your reasoning.\n"
)
# Add output schema explanation
if 'output_schema' in schema:
prompt_parts.append("<output_schema>")
prompt_parts.append(json.dumps(schema['output_schema'], indent=2))
prompt_parts.append("</output_schema>\n")
# Add output example
if 'output_example' in schema:
prompt_parts.append("<output_example>")
prompt_parts.append(json.dumps(schema['output_example'], indent=2))
prompt_parts.append("</output_example>\n")
# Add important notes
if 'important_notes' in schema:
prompt_parts.append("Important considerations:")
for note in schema['important_notes']:
prompt_parts.append(f"- {note}")
prompt_parts.append("")
# Add final instruction
prompt_parts.append(
"After your analysis, provide the final output in the following JSON format, "
"wrapped in <output> tags. The output must be valid, parseable JSON.\n"
)
return "\n".join(prompt_parts)
def extract_json_from_response(text: str) -> Optional[Dict]:
"""Extract JSON from XML output tags in Claude's response"""
match = re.search(r'<output>\s*(\{.*?\})\s*</output>', text, re.DOTALL)
if match:
json_str = match.group(1)
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return None
return None
def extract_analysis_from_response(text: str) -> Optional[str]:
"""Extract analysis from XML tags in Claude's response"""
match = re.search(r'<analysis>(.*?)</analysis>', text, re.DOTALL)
if match:
return match.group(1).strip()
return None
def process_pdf_base64(
client: Anthropic,
pdf_path: Path,
schema: Dict,
model: str
) -> Dict:
"""Process a single PDF using base64 encoding (direct upload)"""
if not pdf_path.exists():
return {
'status': 'error',
'error': f'PDF not found: {pdf_path}'
}
# Check file size (32MB limit)
file_size = pdf_path.stat().st_size
if file_size > 32 * 1024 * 1024:
return {
'status': 'error',
'error': f'PDF exceeds 32MB limit: {file_size / 1024 / 1024:.1f}MB'
}
try:
# Read and encode PDF
with open(pdf_path, 'rb') as f:
pdf_data = base64.b64encode(f.read()).decode('utf-8')
# Create message
response = client.messages.create(
model=model,
max_tokens=16384,
temperature=0,
system=schema.get('system_context', 'You are a scientific research assistant.'),
messages=[{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "base64",
"media_type": "application/pdf",
"data": pdf_data
}
},
{
"type": "text",
"text": create_extraction_prompt(schema)
}
]
}]
)
response_text = response.content[0].text
return {
'status': 'success',
'extracted_data': extract_json_from_response(response_text),
'analysis': extract_analysis_from_response(response_text),
'input_tokens': response.usage.input_tokens,
'output_tokens': response.usage.output_tokens
}
except Exception as e:
return {
'status': 'error',
'error': str(e)
}
def process_pdfs_batch(
client: Anthropic,
records: List[tuple],
schema: Dict,
model: str
) -> Dict[str, Dict]:
"""Process multiple PDFs using Batches API for efficiency"""
all_results = {}
for window_start in range(0, len(records), SIMULTANEOUS_BATCHES * BATCH_SIZE):
window_records = records[window_start:window_start + (SIMULTANEOUS_BATCHES * BATCH_SIZE)]
print(f"\nProcessing window starting at index {window_start} ({len(window_records)} PDFs)")
active_batches = {}
for batch_start in range(0, len(window_records), BATCH_SIZE):
batch_records = window_records[batch_start:batch_start + BATCH_SIZE]
requests = []
for record_id, pdf_data in batch_records:
requests.append(Request(
custom_id=record_id,
params=MessageCreateParamsNonStreaming(
model=model,
max_tokens=16384,
temperature=0,
system=schema.get('system_context', 'You are a scientific research assistant.'),
messages=[{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "base64",
"media_type": "application/pdf",
"data": pdf_data
}
},
{
"type": "text",
"text": create_extraction_prompt(schema)
}
]
}]
)
))
try:
message_batch = client.messages.batches.create(requests=requests)
print(f"Created batch {message_batch.id} with {len(requests)} requests")
active_batches[message_batch.id] = {r["custom_id"] for r in requests}  # Request entries are TypedDicts, so index rather than attribute access
time.sleep(BATCH_SUBMISSION_INTERVAL)
except Exception as e:
print(f"Error creating batch: {e}")
# Wait for batches
window_results = wait_for_batches(client, list(active_batches.keys()), schema)
all_results.update(window_results)
return all_results
def wait_for_batches(
client: Anthropic,
batch_ids: List[str],
schema: Dict
) -> Dict[str, Dict]:
"""Wait for batches to complete and return results"""
print(f"\nWaiting for {len(batch_ids)} batches to complete...")
incomplete = set(batch_ids)
while incomplete:
time.sleep(BATCH_CHECK_INTERVAL)
for batch_id in list(incomplete):
batch = client.messages.batches.retrieve(batch_id)
if batch.processing_status != "in_progress":
incomplete.remove(batch_id)
print(f"Batch {batch_id} completed: {batch.processing_status}")
# Collect results
results = {}
for batch_id in batch_ids:
batch = client.messages.batches.retrieve(batch_id)
if batch.processing_status == "ended":
for result in client.messages.batches.results(batch_id):
if result.result.type == "succeeded":
response_text = result.result.message.content[0].text
results[result.custom_id] = {
'status': 'success',
'extracted_data': extract_json_from_response(response_text),
'analysis': extract_analysis_from_response(response_text),
'input_tokens': result.result.message.usage.input_tokens,
'output_tokens': result.result.message.usage.output_tokens
}
else:
results[result.custom_id] = {
'status': 'error',
'error': str(getattr(result.result, 'error', 'Unknown error'))
}
return results
def main():
args = parse_args()
# Check for API key
if not os.getenv('ANTHROPIC_API_KEY'):
raise ValueError("Please set ANTHROPIC_API_KEY environment variable")
client = Anthropic()
# Load inputs
metadata = load_metadata(Path(args.metadata))
schema = load_schema(Path(args.schema))
print(f"Loaded {len(metadata)} metadata records")
# Filter by relevance if filter results provided
if args.filter_results:
filter_results = load_filter_results(Path(args.filter_results))
relevant_ids = {
id for id, result in filter_results.items()
if result.get('status') == 'success'
and result.get('filter_result', {}).get('has_relevant_data')
}
metadata = [r for r in metadata if r['id'] in relevant_ids]
print(f"Filtered to {len(metadata)} relevant papers")
# Apply test mode
if args.test:
metadata = metadata[:3]
print(f"Test mode: processing {len(metadata)} PDFs")
# Load existing results
output_path = Path(args.output)
results = load_existing_results(output_path)
print(f"Loaded {len(results)} existing results")
# Prepare PDFs to process
to_process = []
for record in metadata:
if record['id'] in results:
continue
if not record.get('pdf_path'):
print(f"Skipping {record['id']}: no PDF path")
continue
pdf_path = Path(record['pdf_path'])
if not pdf_path.exists():
print(f"Skipping {record['id']}: PDF not found")
continue
# Read and encode PDF
try:
with open(pdf_path, 'rb') as f:
pdf_data = base64.b64encode(f.read()).decode('utf-8')
to_process.append((record['id'], pdf_data))
except Exception as e:
print(f"Error reading {pdf_path}: {e}")
print(f"PDFs to process: {len(to_process)}")
if not to_process:
print("All PDFs already processed!")
return
# Process PDFs
if args.method == 'batches':
print("Using Batches API...")
batch_results = process_pdfs_batch(client, to_process, schema, args.model)
results.update(batch_results)
else:
# Note: 'base64' and 'files_api' both take this sequential path; a separate Files API upload is not implemented in this template
print("Processing PDFs sequentially...")
for record_id, pdf_data in to_process:
print(f"Processing: {record_id}")
# For sequential processing, reconstruct Path
record = next(r for r in metadata if r['id'] == record_id)
result = process_pdf_base64(
client, Path(record['pdf_path']), schema, args.model
)
results[record_id] = result
save_results(results, output_path)
time.sleep(2)
# Save final results
save_results(results, output_path)
# Print summary
total = len(results)
successful = sum(1 for r in results.values() if r.get('status') == 'success')
total_input_tokens = sum(
r.get('input_tokens', 0) for r in results.values()
if r.get('status') == 'success'
)
total_output_tokens = sum(
r.get('output_tokens', 0) for r in results.values()
if r.get('status') == 'success'
)
print(f"\n{'='*60}")
print("Extraction Summary")
print(f"{'='*60}")
print(f"Total PDFs processed: {total}")
print(f"Successful extractions: {successful}")
print(f"Failed extractions: {total - successful}")
print(f"\nToken usage:")
print(f" Input tokens: {total_input_tokens:,}")
print(f" Output tokens: {total_output_tokens:,}")
print(f" Total tokens: {total_input_tokens + total_output_tokens:,}")
print(f"\nResults saved to: {output_path}")
print(f"\nNext step: Repair and validate JSON outputs")
if __name__ == '__main__':
main()
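A minimal sketch of the schema file this script expects: the top-level keys match those read by create_extraction_prompt and process_pdf_base64, while everything inside output_schema and output_example is a placeholder to replace with your own fields:

{
  "system_context": "You are a scientific research assistant extracting data for a systematic review.",
  "objective": "extract every reported measurement of the quantity of interest from the attached paper",
  "instructions": [
    "Read the full paper, including tables, figures, and appendices",
    "Record one entry per reported measurement"
  ],
  "analysis_steps": [
    "Identify the study system and methods",
    "Locate every reported measurement",
    "Note units, sample sizes, and uncertainty"
  ],
  "important_notes": [
    "Report values exactly as printed; do not convert units"
  ],
  "output_schema": {
    "type": "object",
    "properties": {
      "records": {"type": "array"}
    }
  },
  "output_example": {
    "records": [
      {"value": 1.23, "unit": "mm", "context": "Table 2"}
    ]
  }
}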


@@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Repair and validate JSON extractions using json_repair library.
Handles common JSON parsing issues and validates against schema.
"""
import argparse
import json
from pathlib import Path
from typing import Dict, Any, Optional
import jsonschema
try:
from json_repair import repair_json
JSON_REPAIR_AVAILABLE = True
except ImportError:
JSON_REPAIR_AVAILABLE = False
print("Warning: json_repair not installed. Install with: pip install json-repair")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Repair and validate JSON extractions'
)
parser.add_argument(
'--input',
required=True,
help='Input JSON file with extraction results from step 03'
)
parser.add_argument(
'--output',
default='cleaned_extractions.json',
help='Output JSON file with cleaned results'
)
parser.add_argument(
'--schema',
help='Optional: JSON schema file for validation'
)
parser.add_argument(
'--strict',
action='store_true',
help='Strict mode: reject records that fail validation'
)
return parser.parse_args()
def load_results(input_path: Path) -> Dict:
"""Load extraction results from JSON file"""
with open(input_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_schema(schema_path: Path) -> Dict:
"""Load JSON schema for validation"""
with open(schema_path, 'r', encoding='utf-8') as f:
schema_data = json.load(f)
return schema_data.get('output_schema', schema_data)
def save_results(results: Dict, output_path: Path):
"""Save cleaned results to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
def repair_json_data(data: Any) -> tuple[Any, bool]:
"""
Attempt to repair JSON data using json_repair library.
Returns (repaired_data, success)
"""
if not JSON_REPAIR_AVAILABLE:
return data, True # Skip repair if library not available
try:
# Convert to JSON string and back to repair
json_str = json.dumps(data)
repaired_str = repair_json(json_str, return_objects=False)
repaired_data = json.loads(repaired_str)
return repaired_data, True
except Exception as e:
print(f"Failed to repair JSON: {e}")
return data, False
def validate_against_schema(data: Any, schema: Dict) -> tuple[bool, Optional[str]]:
"""
Validate data against JSON schema.
Returns (is_valid, error_message)
"""
try:
jsonschema.validate(instance=data, schema=schema)
return True, None
except jsonschema.exceptions.ValidationError as e:
return False, str(e)
except Exception as e:
return False, f"Validation error: {str(e)}"
def clean_extraction_result(
result: Dict,
schema: Optional[Dict] = None,
strict: bool = False
) -> Dict:
"""
Clean and validate a single extraction result.
Returns updated result with:
- repaired_data: Repaired JSON if repair was needed
- validation_status: 'valid', 'invalid', or 'repaired'
- validation_errors: List of validation errors if any
"""
if result.get('status') != 'success':
return result # Skip non-successful results
extracted_data = result.get('extracted_data')
if not extracted_data:
result['validation_status'] = 'invalid'
result['validation_errors'] = ['No extracted data found']
if strict:
result['status'] = 'failed_validation'
return result
# Try to repair JSON
repaired_data, repair_success = repair_json_data(extracted_data)
# Validate against schema if provided
validation_errors = []
if schema:
is_valid, error_msg = validate_against_schema(repaired_data, schema)
if not is_valid:
validation_errors.append(error_msg)
if strict:
result['status'] = 'failed_validation'
# Update result
if repaired_data != extracted_data and repair_success:
result['extracted_data'] = repaired_data
result['validation_status'] = 'repaired'
elif validation_errors:
result['validation_status'] = 'invalid'
else:
result['validation_status'] = 'valid'
if validation_errors:
result['validation_errors'] = validation_errors
return result
def main():
args = parse_args()
# Load inputs
results = load_results(Path(args.input))
print(f"Loaded {len(results)} extraction results")
schema = None
if args.schema:
schema = load_schema(Path(args.schema))
print(f"Loaded validation schema from {args.schema}")
# Clean each result
cleaned_results = {}
stats = {
'total': len(results),
'valid': 0,
'repaired': 0,
'invalid': 0,
'failed': 0
}
for record_id, result in results.items():
cleaned_result = clean_extraction_result(result, schema, args.strict)
cleaned_results[record_id] = cleaned_result
# Update statistics
if cleaned_result.get('status') == 'success':
status = cleaned_result.get('validation_status', 'unknown')
if status == 'valid':
stats['valid'] += 1
elif status == 'repaired':
stats['repaired'] += 1
elif status == 'invalid':
stats['invalid'] += 1
else:
stats['failed'] += 1
# Save cleaned results
output_path = Path(args.output)
save_results(cleaned_results, output_path)
# Print summary
print(f"\n{'='*60}")
print("JSON Repair and Validation Summary")
print(f"{'='*60}")
print(f"Total records: {stats['total']}")
print(f"Valid JSON: {stats['valid']}")
print(f"Repaired JSON: {stats['repaired']}")
print(f"Invalid JSON: {stats['invalid']}")
print(f"Failed extractions: {stats['failed']}")
if schema:
validation_rate = (stats['valid'] + stats['repaired']) / stats['total'] * 100 if stats['total'] else 0.0
print(f"\nValidation rate: {validation_rate:.1f}%")
print(f"\nCleaned results saved to: {output_path}")
# Print examples of validation errors
if stats['invalid'] > 0:
print(f"\nShowing first 3 validation errors:")
error_count = 0
for record_id, result in cleaned_results.items():
if result.get('validation_errors'):
print(f"\n{record_id}:")
for error in result['validation_errors'][:2]:
print(f" - {error[:200]}")
error_count += 1
if error_count >= 3:
break
print(f"\nNext step: Validate and enrich data with external APIs")
if __name__ == '__main__':
main()
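The --schema argument accepts the same schema.json used in step 03: load_schema validates records against whatever sits under its output_schema key (or the whole file if that key is absent). A minimal jsonschema-style sketch with illustrative field names:

{
  "output_schema": {
    "type": "object",
    "required": ["records"],
    "properties": {
      "records": {
        "type": "array",
        "items": {
          "type": "object",
          "required": ["value", "unit"],
          "properties": {
            "value": {"type": "number"},
            "unit": {"type": "string"},
            "context": {"type": "string"}
          }
        }
      }
    }
  }
}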


@@ -0,0 +1,390 @@
#!/usr/bin/env python3
"""
Validate and enrich extracted data using external API databases.
Supports common scientific databases for taxonomy, geography, chemistry, etc.
This script template includes examples for common databases. Customize for your needs.
"""
import argparse
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
import requests
from urllib.parse import quote
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Validate and enrich data with external APIs'
)
parser.add_argument(
'--input',
required=True,
help='Input JSON file with cleaned extraction results from step 04'
)
parser.add_argument(
'--output',
default='validated_data.json',
help='Output JSON file with validated and enriched data'
)
parser.add_argument(
'--apis',
required=True,
help='JSON configuration file specifying which APIs to use and for which fields'
)
parser.add_argument(
'--skip-validation',
action='store_true',
help='Skip API calls, only load and structure data'
)
return parser.parse_args()
def load_results(input_path: Path) -> Dict:
"""Load extraction results from JSON file"""
with open(input_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_api_config(config_path: Path) -> Dict:
"""Load API configuration"""
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
def save_results(results: Dict, output_path: Path):
"""Save validated results to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
# ==============================================================================
# Taxonomy validation functions
# ==============================================================================
def validate_gbif_taxonomy(scientific_name: str) -> Optional[Dict]:
"""
Validate taxonomic name using GBIF (Global Biodiversity Information Facility).
Returns standardized taxonomy if found.
"""
url = f"https://api.gbif.org/v1/species/match?name={quote(scientific_name)}"
try:
response = requests.get(url, timeout=10)
if response.status_code == 200:
data = response.json()
if data.get('matchType') != 'NONE':
return {
'matched_name': data.get('canonicalName', scientific_name),
'scientific_name': data.get('scientificName'),
'rank': data.get('rank'),
'kingdom': data.get('kingdom'),
'phylum': data.get('phylum'),
'class': data.get('class'),
'order': data.get('order'),
'family': data.get('family'),
'genus': data.get('genus'),
'gbif_id': data.get('usageKey'),
'confidence': data.get('confidence'),
'match_type': data.get('matchType'),
'status': data.get('status')
}
except Exception as e:
print(f"GBIF API error for '{scientific_name}': {e}")
return None
def validate_wfo_plant(scientific_name: str) -> Optional[Dict]:
"""
Validate plant name using World Flora Online.
Returns standardized plant taxonomy if found.
"""
# WFO requires name parsing - this is a simplified example
url = f"http://www.worldfloraonline.org/api/1.0/search?query={quote(scientific_name)}"
try:
response = requests.get(url, timeout=10)
if response.status_code == 200:
data = response.json()
if data.get('results'):
first_result = data['results'][0]
return {
'matched_name': first_result.get('name'),
'scientific_name': first_result.get('scientificName'),
'authors': first_result.get('authors'),
'family': first_result.get('family'),
'wfo_id': first_result.get('wfoId'),
'status': first_result.get('status')
}
except Exception as e:
print(f"WFO API error for '{scientific_name}': {e}")
return None
# ==============================================================================
# Geography validation functions
# ==============================================================================
def validate_geonames(location: str, country: Optional[str] = None) -> Optional[Dict]:
"""
Validate location using GeoNames.
Note: Requires free GeoNames account and username.
Set GEONAMES_USERNAME environment variable.
"""
import os
username = os.getenv('GEONAMES_USERNAME')
if not username:
print("Warning: GEONAMES_USERNAME not set. Skipping GeoNames validation.")
return None
url = f"http://api.geonames.org/searchJSON?q={quote(location)}&maxRows=1&username={username}"
if country:
url += f"&country={country[:2]}"  # expects an ISO 3166-1 alpha-2 country code, not a full country name
try:
response = requests.get(url, timeout=10)
if response.status_code == 200:
data = response.json()
if data.get('geonames'):
place = data['geonames'][0]
return {
'matched_name': place.get('name'),
'country': place.get('countryName'),
'country_code': place.get('countryCode'),
'admin1': place.get('adminName1'),
'admin2': place.get('adminName2'),
'latitude': place.get('lat'),
'longitude': place.get('lng'),
'geonames_id': place.get('geonameId')
}
except Exception as e:
print(f"GeoNames API error for '{location}': {e}")
return None
def geocode_location(address: str) -> Optional[Dict]:
"""
Geocode an address using OpenStreetMap Nominatim (free, no API key needed).
Please use responsibly - add delays between calls.
"""
url = f"https://nominatim.openstreetmap.org/search?q={quote(address)}&format=json&limit=1"
headers = {'User-Agent': 'Scientific-PDF-Extraction/1.0'}
try:
time.sleep(1) # Be nice to OSM
response = requests.get(url, headers=headers, timeout=10)
if response.status_code == 200:
data = response.json()
if data:
place = data[0]
return {
'display_name': place.get('display_name'),
'latitude': place.get('lat'),
'longitude': place.get('lon'),
'osm_type': place.get('osm_type'),
'osm_id': place.get('osm_id'),
'place_rank': place.get('place_rank')
}
except Exception as e:
print(f"Nominatim error for '{address}': {e}")
return None
# ==============================================================================
# Chemistry validation functions
# ==============================================================================
def validate_pubchem_compound(compound_name: str) -> Optional[Dict]:
"""
Validate chemical compound using PubChem.
Returns standardized compound information.
"""
url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(compound_name)}/JSON"
try:
response = requests.get(url, timeout=10)
if response.status_code == 200:
data = response.json()
if 'PC_Compounds' in data and data['PC_Compounds']:
compound = data['PC_Compounds'][0]
return {
'cid': compound['id']['id']['cid'],
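# Note: props[0] is not guaranteed to hold the molecular formula; a more robust
# version would scan props for the entry whose urn label is 'Molecular Formula'.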
'molecular_formula': compound.get('props', [{}])[0].get('value', {}).get('sval'),
'pubchem_url': f"https://pubchem.ncbi.nlm.nih.gov/compound/{compound['id']['id']['cid']}"
}
except Exception as e:
print(f"PubChem API error for '{compound_name}': {e}")
return None
# ==============================================================================
# Gene/Protein validation functions
# ==============================================================================
def validate_ncbi_gene(gene_symbol: str, organism: Optional[str] = None) -> Optional[Dict]:
"""
Validate gene using NCBI Gene database.
"""
query = gene_symbol
if organism:
query += f"[Gene Name] AND {organism}[Organism]"
search_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=gene&term={quote(query)}&retmode=json"
try:
response = requests.get(search_url, timeout=10)
if response.status_code == 200:
data = response.json()
if data.get('esearchresult', {}).get('idlist'):
gene_id = data['esearchresult']['idlist'][0]
return {
'gene_id': gene_id,
'ncbi_url': f"https://www.ncbi.nlm.nih.gov/gene/{gene_id}"
}
except Exception as e:
print(f"NCBI Gene API error for '{gene_symbol}': {e}")
return None
# ==============================================================================
# Main validation orchestration
# ==============================================================================
API_VALIDATORS = {
'gbif_taxonomy': validate_gbif_taxonomy,
'wfo_plants': validate_wfo_plant,
'geonames': validate_geonames,
'geocode': geocode_location,
'pubchem': validate_pubchem_compound,
'ncbi_gene': validate_ncbi_gene
}
def validate_field(value: Any, api_name: str, extra_params: Optional[Dict] = None) -> Optional[Dict]:
"""
Validate a single field value using the specified API.
"""
if not value or value == 'none' or value == '':
return None
validator = API_VALIDATORS.get(api_name)
if not validator:
print(f"Unknown API: {api_name}")
return None
try:
if extra_params:
return validator(value, **extra_params)
else:
return validator(value)
except Exception as e:
print(f"Validation error for {api_name} with value '{value}': {e}")
return None
def process_record(
record_data: Dict,
api_config: Dict,
skip_validation: bool = False
) -> Dict:
"""
Process a single record, validating specified fields.
api_config should map field names to API names:
{
"field_mappings": {
"species": {"api": "gbif_taxonomy", "output_field": "validated_species"},
"location": {"api": "geocode", "output_field": "geocoded_location"}
}
}
"""
if skip_validation:
return record_data
field_mappings = api_config.get('field_mappings', {})
for field_name, field_config in field_mappings.items():
api_name = field_config.get('api')
output_field = field_config.get('output_field', f'validated_{field_name}')
extra_params = field_config.get('extra_params', {})
# Handle nested fields (e.g., 'records.species')
if '.' in field_name:
# This is a simplified example - you'd need to implement proper nested access
continue
value = record_data.get(field_name)
if value:
validated = validate_field(value, api_name, extra_params)
if validated:
record_data[output_field] = validated
return record_data
def main():
args = parse_args()
# Load inputs
results = load_results(Path(args.input))
api_config = load_api_config(Path(args.apis))
print(f"Loaded {len(results)} extraction results")
# Process each result
validated_results = {}
stats = {'total': 0, 'validated': 0, 'failed': 0}
for record_id, result in results.items():
if result.get('status') != 'success':
validated_results[record_id] = result
stats['failed'] += 1
continue
stats['total'] += 1
# Get extracted data
extracted_data = result.get('extracted_data', {})
# Process/validate the data
validated_data = process_record(
extracted_data.copy(),
api_config,
args.skip_validation
)
# Update result
result['validated_data'] = validated_data
validated_results[record_id] = result
stats['validated'] += 1
# Rate limiting
if not args.skip_validation:
time.sleep(0.5)
# Save results
output_path = Path(args.output)
save_results(validated_results, output_path)
# Print summary
print(f"\n{'='*60}")
print("Validation and Enrichment Summary")
print(f"{'='*60}")
print(f"Total records: {len(results)}")
print(f"Successfully validated: {stats['validated']}")
print(f"Failed extractions: {stats['failed']}")
print(f"\nResults saved to: {output_path}")
print(f"\nNext step: Export to analysis format")
if __name__ == '__main__':
main()
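An API configuration along these lines drives the per-field validation; the api values come from API_VALIDATORS above, while the field names, output fields, and extra_params are illustrative:

{
  "field_mappings": {
    "species": {"api": "gbif_taxonomy", "output_field": "validated_species"},
    "study_location": {"api": "geocode", "output_field": "geocoded_location"},
    "compound": {"api": "pubchem", "output_field": "validated_compound"},
    "gene": {"api": "ncbi_gene", "output_field": "validated_gene", "extra_params": {"organism": "Homo sapiens"}}
  }
}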


@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Export validated data to various analysis formats.
Supports Python (pandas/SQLite), R (RDS/CSV), Excel, and more.
"""
import argparse
import json
import csv
from pathlib import Path
from typing import Dict, List, Any
import sys
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Export validated data to analysis format'
)
parser.add_argument(
'--input',
required=True,
help='Input JSON file with validated data from step 05'
)
parser.add_argument(
'--format',
choices=['python', 'r', 'csv', 'json', 'excel', 'sqlite'],
required=True,
help='Output format'
)
parser.add_argument(
'--output',
required=True,
help='Output file path (without extension for some formats)'
)
parser.add_argument(
'--flatten',
action='store_true',
help='Flatten nested JSON structures for tabular formats'
)
parser.add_argument(
'--include-metadata',
action='store_true',
help='Include original paper metadata in output'
)
return parser.parse_args()
def load_results(input_path: Path) -> Dict:
"""Load validated results from JSON file"""
with open(input_path, 'r', encoding='utf-8') as f:
return json.load(f)
def flatten_dict(d: Dict, parent_key: str = '', sep: str = '_') -> Dict:
"""
Flatten nested dictionary structure.
Useful for converting JSON to tabular format.
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
elif isinstance(v, list):
# Convert lists to comma-separated strings
if v and isinstance(v[0], dict):
# List of dicts - create numbered columns
for i, item in enumerate(v):
items.extend(flatten_dict(item, f"{new_key}_{i}", sep=sep).items())
else:
# Simple list
items.append((new_key, ', '.join(str(x) for x in v)))
else:
items.append((new_key, v))
return dict(items)
def extract_records(results: Dict, flatten: bool = False, include_metadata: bool = False) -> List[Dict]:
"""
Extract records from results structure.
Returns a list of dictionaries suitable for tabular export.
"""
records = []
for paper_id, result in results.items():
if result.get('status') != 'success':
continue
# Get the validated data (or fall back to extracted data)
data = result.get('validated_data', result.get('extracted_data', {}))
if not data:
continue
# Check if data contains nested records or is a single record
if 'records' in data and isinstance(data['records'], list):
# Multiple records per paper
for record in data['records']:
record_dict = record.copy() if isinstance(record, dict) else {'value': record}
# Add paper-level fields
if include_metadata:
record_dict['paper_id'] = paper_id
for key in data:
if key != 'records':
record_dict[f'paper_{key}'] = data[key]
if flatten:
record_dict = flatten_dict(record_dict)
records.append(record_dict)
else:
# Single record per paper
record_dict = data.copy()
if include_metadata:
record_dict['paper_id'] = paper_id
if flatten:
record_dict = flatten_dict(record_dict)
records.append(record_dict)
return records
def export_to_csv(records: List[Dict], output_path: Path):
"""Export to CSV format"""
if not records:
print("No records to export")
return
# Get all possible field names
fieldnames = set()
for record in records:
fieldnames.update(record.keys())
fieldnames = sorted(fieldnames)
with open(output_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(records)
print(f"Exported {len(records)} records to CSV: {output_path}")
def export_to_json(records: List[Dict], output_path: Path):
"""Export to JSON format"""
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(records, f, indent=2, ensure_ascii=False)
print(f"Exported {len(records)} records to JSON: {output_path}")
def export_to_python(records: List[Dict], output_path: Path):
"""Export to Python format (pandas DataFrame pickle)"""
try:
import pandas as pd
except ImportError:
print("Error: pandas is required for Python export. Install with: pip install pandas")
sys.exit(1)
df = pd.DataFrame(records)
# Save as pickle
pickle_path = output_path.with_suffix('.pkl')
df.to_pickle(pickle_path)
print(f"Exported {len(records)} records to pandas pickle: {pickle_path}")
# Also create a Python script to load it
script_path = output_path.with_suffix('.py')
script_content = f'''#!/usr/bin/env python3
"""
Data loading script
Generated by extract_from_pdfs skill
"""
import pandas as pd
# Load the data
df = pd.read_pickle('{pickle_path.name}')
print(f"Loaded {{len(df)}} records")
print(f"Columns: {{list(df.columns)}}")
print("\\nFirst few rows:")
print(df.head())
# Example analyses:
# df.describe()
# df.groupby('some_column').size()
# df.to_csv('output.csv', index=False)
'''
with open(script_path, 'w') as f:
f.write(script_content)
print(f"Created loading script: {script_path}")
def export_to_r(records: List[Dict], output_path: Path):
"""Export to R format (RDS file)"""
try:
import pandas as pd
import pyreadr
except ImportError:
print("Error: pandas and pyreadr are required for R export.")
print("Install with: pip install pandas pyreadr")
sys.exit(1)
df = pd.DataFrame(records)
# Save as RDS
rds_path = output_path.with_suffix('.rds')
pyreadr.write_rds(rds_path, df)
print(f"Exported {len(records)} records to RDS: {rds_path}")
# Also create an R script to load it
script_path = output_path.with_suffix('.R')
script_content = f'''# Data loading script
# Generated by extract_from_pdfs skill
# Load the data
data <- readRDS('{rds_path.name}')
cat(sprintf("Loaded %d records\\n", nrow(data)))
cat(sprintf("Columns: %s\\n", paste(colnames(data), collapse=", ")))
cat("\\nFirst few rows:\\n")
print(head(data))
# Example analyses:
# summary(data)
# table(data$some_column)
# write.csv(data, 'output.csv', row.names=FALSE)
'''
with open(script_path, 'w') as f:
f.write(script_content)
print(f"Created loading script: {script_path}")
def export_to_excel(records: List[Dict], output_path: Path):
"""Export to Excel format"""
try:
import pandas as pd
except ImportError:
print("Error: pandas is required for Excel export. Install with: pip install pandas openpyxl")
sys.exit(1)
df = pd.DataFrame(records)
# Save as Excel
excel_path = output_path.with_suffix('.xlsx')
df.to_excel(excel_path, index=False, engine='openpyxl')
print(f"Exported {len(records)} records to Excel: {excel_path}")
def export_to_sqlite(records: List[Dict], output_path: Path):
"""Export to SQLite database"""
try:
import pandas as pd
import sqlite3
except ImportError:
print("Error: pandas is required for SQLite export. Install with: pip install pandas")
sys.exit(1)
df = pd.DataFrame(records)
# Create database
db_path = output_path.with_suffix('.db')
conn = sqlite3.connect(db_path)
# Write to database
table_name = 'extracted_data'
df.to_sql(table_name, conn, if_exists='replace', index=False)
conn.close()
print(f"Exported {len(records)} records to SQLite database: {db_path}")
print(f"Table name: {table_name}")
# Create SQL script with example queries
sql_script_path = output_path.with_suffix('.sql')
sql_content = f'''-- Example SQL queries for {db_path.name}
-- Generated by extract_from_pdfs skill
-- View all records
SELECT * FROM {table_name} LIMIT 10;
-- Count total records
SELECT COUNT(*) as total_records FROM {table_name};
-- Example: Group by a column (adjust column name as needed)
-- SELECT column_name, COUNT(*) as count
-- FROM {table_name}
-- GROUP BY column_name
-- ORDER BY count DESC;
'''
with open(sql_script_path, 'w') as f:
f.write(sql_content)
print(f"Created SQL example script: {sql_script_path}")
def main():
args = parse_args()
# Load validated results
results = load_results(Path(args.input))
print(f"Loaded {len(results)} results")
# Extract records
records = extract_records(
results,
flatten=args.flatten,
include_metadata=args.include_metadata
)
print(f"Extracted {len(records)} records")
if not records:
print("No records to export. Check your data.")
return
# Export based on format
output_path = Path(args.output)
if args.format == 'csv':
export_to_csv(records, output_path)
elif args.format == 'json':
export_to_json(records, output_path)
elif args.format == 'python':
export_to_python(records, output_path)
elif args.format == 'r':
export_to_r(records, output_path)
elif args.format == 'excel':
export_to_excel(records, output_path)
elif args.format == 'sqlite':
export_to_sqlite(records, output_path)
print(f"\nExport complete!")
print(f"Your data is ready for analysis in {args.format.upper()} format.")
if __name__ == '__main__':
main()
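# Example invocation (illustrative only): the script filename and input file are
# assumptions, and the flag spellings are inferred from the attributes used in
# main() above, so they may differ from the actual parser definition:
#   python export_data.py --input validated_results.json --format sqlite \
#       --output extracted_data --flatten --include-metadata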

View File

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
"""
Prepare a validation set for evaluating extraction quality.
This script helps you:
1. Sample a subset of papers for manual annotation
2. Set up a structured annotation file
3. Guide the annotation process
The validation set is used to calculate precision and recall metrics.
"""
import argparse
import json
import random
from pathlib import Path
from typing import Dict, List, Any
import sys
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Prepare validation set for extraction quality evaluation',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Workflow:
1. This script samples papers from your extraction results
2. It creates an annotation template based on your schema
3. You manually annotate the sampled papers with ground truth
4. Use 08_calculate_validation_metrics.py to compare automated vs. manual extraction
Sampling strategies:
random : Random sample (good for overall quality)
stratified: Sample by extraction characteristics (good for identifying weaknesses)
diverse : Sample to maximize diversity (good for comprehensive evaluation)
"""
)
parser.add_argument(
'--extraction-results',
required=True,
help='JSON file with extraction results from step 03 or 04'
)
parser.add_argument(
'--schema',
required=True,
help='Extraction schema JSON file used in step 03'
)
parser.add_argument(
'--output',
default='validation_set.json',
help='Output file for validation annotations'
)
parser.add_argument(
'--sample-size',
type=int,
default=20,
help='Number of papers to sample (default: 20, recommended: 20-50)'
)
parser.add_argument(
'--strategy',
choices=['random', 'stratified', 'diverse'],
default='random',
help='Sampling strategy (default: random)'
)
parser.add_argument(
'--seed',
type=int,
default=42,
help='Random seed for reproducibility'
)
return parser.parse_args()
def load_results(results_path: Path) -> Dict:
"""Load extraction results"""
with open(results_path, 'r', encoding='utf-8') as f:
return json.load(f)
def load_schema(schema_path: Path) -> Dict:
"""Load extraction schema"""
with open(schema_path, 'r', encoding='utf-8') as f:
return json.load(f)
def sample_random(results: Dict, sample_size: int, seed: int) -> List[str]:
"""Random sampling strategy"""
# Only sample from successful extractions
successful = [
paper_id for paper_id, result in results.items()
if result.get('status') == 'success' and result.get('extracted_data')
]
if len(successful) < sample_size:
print(f"Warning: Only {len(successful)} successful extractions available")
sample_size = len(successful)
random.seed(seed)
return random.sample(successful, sample_size)
def sample_stratified(results: Dict, sample_size: int, seed: int) -> List[str]:
"""
    Stratified sampling: sample papers with different characteristics,
    e.g., papers with many vs. few extracted records or differing data completeness.
"""
successful = {}
for paper_id, result in results.items():
if result.get('status') == 'success' and result.get('extracted_data'):
data = result['extracted_data']
# Count records if present
num_records = len(data.get('records', [])) if 'records' in data else 0
successful[paper_id] = num_records
if not successful:
print("No successful extractions found")
return []
# Create strata based on number of records
strata = {
'zero': [],
'few': [], # 1-2 records
'medium': [], # 3-5 records
'many': [] # 6+ records
}
for paper_id, count in successful.items():
if count == 0:
strata['zero'].append(paper_id)
elif count <= 2:
strata['few'].append(paper_id)
elif count <= 5:
strata['medium'].append(paper_id)
else:
strata['many'].append(paper_id)
# Sample proportionally from each stratum
random.seed(seed)
sampled = []
total_papers = len(successful)
for stratum_name, papers in strata.items():
if not papers:
continue
# Sample proportionally, at least 1 from each non-empty stratum
stratum_sample_size = max(1, int(len(papers) / total_papers * sample_size))
stratum_sample_size = min(stratum_sample_size, len(papers))
sampled.extend(random.sample(papers, stratum_sample_size))
# If we haven't reached sample_size, add more randomly
if len(sampled) < sample_size:
remaining = [p for p in successful.keys() if p not in sampled]
additional = min(sample_size - len(sampled), len(remaining))
sampled.extend(random.sample(remaining, additional))
return sampled[:sample_size]
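# Worked example (illustrative): with 100 successful papers split into strata of
# 10 / 40 / 30 / 20 and sample_size=20, the proportional allocation above yields
# 2 + 8 + 6 + 4 = 20 sampled papers; smaller strata are never dropped entirely
# because each non-empty stratum contributes at least one paper.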
def sample_diverse(results: Dict, sample_size: int, seed: int) -> List[str]:
"""
    Diverse sampling: maximize diversity among the sampled papers.
    This is a simplified placeholder; it could be enhanced with real diversity metrics.
"""
# For now, use stratified sampling as a proxy for diversity
return sample_stratified(results, sample_size, seed)
def create_annotation_template(
sampled_ids: List[str],
results: Dict,
schema: Dict
) -> Dict:
"""
Create annotation template for manual validation.
Structure:
{
"paper_id": {
"automated_extraction": {...},
"ground_truth": null, # To be filled manually
"notes": "",
"annotator": "",
"annotation_date": ""
}
}
"""
template = {
"_instructions": {
"overview": "This is a validation annotation file. For each paper, review the PDF and fill in the ground_truth field with the correct extraction.",
"steps": [
"1. Read the PDF for each paper_id",
"2. Extract data according to the schema, filling the 'ground_truth' field",
"3. The 'ground_truth' should have the same structure as 'automated_extraction'",
"4. Add your name in 'annotator' and date in 'annotation_date'",
"5. Use 'notes' field for any comments or ambiguities",
"6. Once complete, use 08_calculate_validation_metrics.py to compare"
],
"schema_reference": schema.get('output_schema', {}),
"tips": [
"Be thorough: extract ALL relevant information, even if automated extraction missed it",
"Be precise: use exact values as they appear in the paper",
"Be consistent: follow the same schema structure",
"Mark ambiguous cases in notes field"
]
},
"validation_papers": {}
}
for paper_id in sampled_ids:
result = results[paper_id]
template["validation_papers"][paper_id] = {
"automated_extraction": result.get('extracted_data', {}),
"ground_truth": None, # To be filled by annotator
"notes": "",
"annotator": "",
"annotation_date": "",
"_pdf_path": None, # Will try to find from metadata
"_extraction_metadata": {
"extraction_status": result.get('status'),
"validation_status": result.get('validation_status'),
"has_analysis": bool(result.get('analysis'))
}
}
return template
def save_template(template: Dict, output_path: Path):
"""Save annotation template to JSON file"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(template, f, indent=2, ensure_ascii=False)
def main():
args = parse_args()
# Load inputs
results = load_results(Path(args.extraction_results))
schema = load_schema(Path(args.schema))
print(f"Loaded {len(results)} extraction results")
# Sample papers
if args.strategy == 'random':
sampled = sample_random(results, args.sample_size, args.seed)
elif args.strategy == 'stratified':
sampled = sample_stratified(results, args.sample_size, args.seed)
elif args.strategy == 'diverse':
sampled = sample_diverse(results, args.sample_size, args.seed)
print(f"Sampled {len(sampled)} papers using '{args.strategy}' strategy")
# Create annotation template
template = create_annotation_template(sampled, results, schema)
# Save template
output_path = Path(args.output)
save_template(template, output_path)
print(f"\n{'='*60}")
print("Validation Set Preparation Complete")
print(f"{'='*60}")
print(f"Annotation file created: {output_path}")
print(f"Papers to annotate: {len(sampled)}")
print(f"\nNext steps:")
print(f"1. Open {output_path} in a text editor")
print(f"2. For each paper, read the PDF and fill in the 'ground_truth' field")
print(f"3. Follow the schema structure shown in '_instructions'")
print(f"4. Save your annotations")
print(f"5. Run: python 08_calculate_validation_metrics.py --annotations {output_path}")
print(f"\nTips for efficient annotation:")
print(f"- Work in batches of 5-10 papers")
print(f"- Use the automated extraction as a starting point to check")
print(f"- Document any ambiguous cases in the notes field")
print(f"- Consider having 2+ annotators for inter-rater reliability")
if __name__ == '__main__':
main()
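# Example invocation (illustrative; the input filenames are assumptions, the
# flags are those defined in parse_args above):
#   python 07_prepare_validation_set.py --extraction-results extraction_results.json \
#       --schema extraction_schema.json --sample-size 30 --strategy stratified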

View File

@@ -0,0 +1,513 @@
#!/usr/bin/env python3
"""
Calculate validation metrics (precision, recall, F1) for extraction quality.
Compares automated extraction against ground truth annotations to evaluate:
- Field-level precision and recall
- Record-level accuracy
- Overall extraction quality
Handles different data types appropriately:
- Boolean: exact match
- Numeric: exact match or tolerance
- String: exact match or fuzzy matching
- Lists: set-based precision/recall
- Nested objects: recursive comparison
"""
import argparse
import json
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional
from collections import defaultdict
import sys
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Calculate validation metrics for extraction quality',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Metrics calculated:
Precision : Of extracted items, how many are correct?
Recall : Of true items, how many were extracted?
F1 Score : Harmonic mean of precision and recall
Accuracy : Overall correctness (for boolean/categorical fields)
Field type handling:
Boolean/Categorical : Exact match
Numeric : Exact match or within tolerance
String : Exact match or fuzzy (normalized)
Lists : Set-based precision/recall
Nested objects : Recursive field-by-field comparison
Output:
- Overall metrics
- Per-field metrics
- Per-paper detailed comparison
- Common error patterns
"""
)
parser.add_argument(
'--annotations',
required=True,
help='Annotation file from 07_prepare_validation_set.py (with ground truth filled in)'
)
parser.add_argument(
'--output',
default='validation_metrics.json',
help='Output file for detailed metrics'
)
parser.add_argument(
'--report',
default='validation_report.txt',
help='Human-readable validation report'
)
parser.add_argument(
'--numeric-tolerance',
type=float,
default=0.0,
help='Tolerance for numeric comparisons (default: 0.0 for exact match)'
)
parser.add_argument(
'--fuzzy-strings',
action='store_true',
help='Use fuzzy string matching (normalize whitespace, case)'
)
parser.add_argument(
'--list-order-matters',
action='store_true',
help='Consider order in list comparisons (default: treat as sets)'
)
return parser.parse_args()
def load_annotations(annotations_path: Path) -> Dict:
"""Load annotations file"""
with open(annotations_path, 'r', encoding='utf-8') as f:
return json.load(f)
def normalize_string(s: str, fuzzy: bool = False) -> str:
"""Normalize string for comparison"""
if not isinstance(s, str):
return str(s)
if fuzzy:
return ' '.join(s.lower().split())
return s
def compare_boolean(automated: Any, truth: Any) -> Dict[str, int]:
"""Compare boolean values"""
if automated == truth:
return {'tp': 1, 'fp': 0, 'fn': 0, 'tn': 0}
elif automated and not truth:
return {'tp': 0, 'fp': 1, 'fn': 0, 'tn': 0}
elif not automated and truth:
return {'tp': 0, 'fp': 0, 'fn': 1, 'tn': 0}
else:
return {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 1}
def compare_numeric(automated: Any, truth: Any, tolerance: float = 0.0) -> bool:
"""Compare numeric values with optional tolerance"""
try:
a = float(automated) if automated is not None else None
t = float(truth) if truth is not None else None
if a is None and t is None:
return True
if a is None or t is None:
return False
if tolerance > 0:
return abs(a - t) <= tolerance
else:
return a == t
except (ValueError, TypeError):
return automated == truth
def compare_string(automated: Any, truth: Any, fuzzy: bool = False) -> bool:
"""Compare string values"""
if automated is None and truth is None:
return True
if automated is None or truth is None:
return False
a = normalize_string(automated, fuzzy)
t = normalize_string(truth, fuzzy)
return a == t
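# Illustrative behaviour: with fuzzy=True, "Randomized  Controlled Trial" and
# "randomized controlled trial" both normalize to "randomized controlled trial"
# and compare equal; with fuzzy=False they do not.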
def compare_list(
automated: List,
truth: List,
order_matters: bool = False,
fuzzy: bool = False
) -> Dict[str, int]:
"""
Compare lists and calculate precision/recall.
Returns counts of true positives, false positives, and false negatives.
"""
if automated is None:
automated = []
if truth is None:
truth = []
if not isinstance(automated, list):
automated = [automated]
if not isinstance(truth, list):
truth = [truth]
    if order_matters:
        # Ordered comparison: items are paired by position; unmatched or
        # surplus items count as errors on both sides
        tp = sum(1 for a, t in zip(automated, truth) if compare_string(a, t, fuzzy))
        fp = len(automated) - tp
        fn = len(truth) - tp
    else:
        # Set-based comparison; serialize unhashable items (e.g., dict records)
        # so they can be placed in a set
        def to_key(x):
            if isinstance(x, (dict, list)):
                return json.dumps(x, sort_keys=True)
            return normalize_string(x, fuzzy) if fuzzy else x
        auto_set = {to_key(x) for x in automated}
        truth_set = {to_key(x) for x in truth}
        tp = len(auto_set & truth_set)   # In both automated and truth
        fp = len(auto_set - truth_set)   # In automated but not in truth
        fn = len(truth_set - auto_set)   # In truth but not in automated
return {'tp': tp, 'fp': fp, 'fn': fn}
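# Worked example (illustrative), set-based comparison:
#   automated = ['aspirin', 'ibuprofen', 'placebo']
#   truth     = ['ibuprofen', 'placebo', 'metformin']
#   -> tp=2 (ibuprofen, placebo), fp=1 (aspirin), fn=1 (metformin)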
def calculate_metrics(tp: int, fp: int, fn: int) -> Dict[str, float]:
"""Calculate precision, recall, and F1 from counts"""
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
return {
'precision': precision,
'recall': recall,
'f1': f1,
'tp': tp,
'fp': fp,
'fn': fn
}
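# Worked example (illustrative): tp=8, fp=2, fn=4 gives
#   precision = 8/10 = 0.80, recall = 8/12 ≈ 0.667,
#   f1 = 2*0.80*0.667/(0.80+0.667) ≈ 0.727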
def compare_field(
automated: Any,
truth: Any,
field_name: str,
config: Dict
) -> Dict[str, Any]:
"""
Compare a single field between automated and ground truth.
Returns metrics appropriate for the field type.
"""
# Determine field type
if isinstance(truth, bool):
return compare_boolean(automated, truth)
elif isinstance(truth, (int, float)):
match = compare_numeric(automated, truth, config['numeric_tolerance'])
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
elif isinstance(truth, str):
match = compare_string(automated, truth, config['fuzzy_strings'])
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
elif isinstance(truth, list):
return compare_list(automated, truth, config['list_order_matters'], config['fuzzy_strings'])
elif isinstance(truth, dict):
# Recursive comparison for nested objects
return compare_nested(automated or {}, truth, config)
elif truth is None:
# Field should be empty/null
if automated is None or automated == "" or automated == []:
return {'tp': 1, 'fp': 0, 'fn': 0}
else:
return {'tp': 0, 'fp': 1, 'fn': 0}
else:
# Fallback to exact match
match = automated == truth
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
def compare_nested(automated: Dict, truth: Dict, config: Dict) -> Dict[str, int]:
"""Recursively compare nested objects"""
total_counts = {'tp': 0, 'fp': 0, 'fn': 0}
all_fields = set(automated.keys()) | set(truth.keys())
for field in all_fields:
auto_val = automated.get(field)
truth_val = truth.get(field)
field_counts = compare_field(auto_val, truth_val, field, config)
for key in ['tp', 'fp', 'fn']:
total_counts[key] += field_counts.get(key, 0)
return total_counts
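# Illustrative behaviour of the recursive comparison above:
#   automated = {'n_patients': 120, 'drug': 'aspirin'}
#   truth     = {'n_patients': 120, 'drug': 'ibuprofen', 'dose_mg': 100}
# 'n_patients' matches (tp=1); the mismatched 'drug' and the missing 'dose_mg'
# each count as one false positive and one false negative for scalar fields,
# giving tp=1, fp=2, fn=2.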
def evaluate_paper(
paper_id: str,
automated: Dict,
truth: Dict,
config: Dict
) -> Dict[str, Any]:
"""
Evaluate extraction for a single paper.
Returns field-level and overall metrics.
"""
if truth is None:
return {
'status': 'not_annotated',
'message': 'Ground truth not provided'
}
field_metrics = {}
all_fields = set(automated.keys()) | set(truth.keys())
for field in all_fields:
if field == 'records':
# Special handling for records arrays
auto_records = automated.get('records', [])
truth_records = truth.get('records', [])
# Overall record count comparison
record_counts = compare_list(auto_records, truth_records, order_matters=False)
            # Detailed record-level comparison (records are paired by position)
record_details = []
for i, (auto_rec, truth_rec) in enumerate(zip(auto_records, truth_records)):
rec_comparison = compare_nested(auto_rec, truth_rec, config)
record_details.append({
'record_index': i,
'metrics': calculate_metrics(**rec_comparison)
})
field_metrics['records'] = {
'count_metrics': calculate_metrics(**record_counts),
'record_details': record_details
}
else:
auto_val = automated.get(field)
truth_val = truth.get(field)
counts = compare_field(auto_val, truth_val, field, config)
field_metrics[field] = calculate_metrics(**counts)
# Calculate overall metrics
total_tp = sum(
m.get('tp', 0) if isinstance(m, dict) and 'tp' in m
else m.get('count_metrics', {}).get('tp', 0)
for m in field_metrics.values()
)
total_fp = sum(
m.get('fp', 0) if isinstance(m, dict) and 'fp' in m
else m.get('count_metrics', {}).get('fp', 0)
for m in field_metrics.values()
)
total_fn = sum(
m.get('fn', 0) if isinstance(m, dict) and 'fn' in m
else m.get('count_metrics', {}).get('fn', 0)
for m in field_metrics.values()
)
overall = calculate_metrics(total_tp, total_fp, total_fn)
return {
'status': 'evaluated',
'field_metrics': field_metrics,
'overall': overall
}
def aggregate_metrics(paper_evaluations: Dict[str, Dict]) -> Dict[str, Any]:
"""Aggregate metrics across all papers"""
# Collect field-level metrics
field_aggregates = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
evaluated_papers = [
p for p in paper_evaluations.values()
if p.get('status') == 'evaluated'
]
for paper_eval in evaluated_papers:
for field, metrics in paper_eval.get('field_metrics', {}).items():
if isinstance(metrics, dict):
if 'tp' in metrics:
# Simple field
field_aggregates[field]['tp'] += metrics['tp']
field_aggregates[field]['fp'] += metrics['fp']
field_aggregates[field]['fn'] += metrics['fn']
elif 'count_metrics' in metrics:
# Records field
field_aggregates[field]['tp'] += metrics['count_metrics']['tp']
field_aggregates[field]['fp'] += metrics['count_metrics']['fp']
field_aggregates[field]['fn'] += metrics['count_metrics']['fn']
# Calculate metrics for each field
field_metrics = {}
for field, counts in field_aggregates.items():
field_metrics[field] = calculate_metrics(**counts)
# Overall aggregated metrics
total_tp = sum(counts['tp'] for counts in field_aggregates.values())
total_fp = sum(counts['fp'] for counts in field_aggregates.values())
total_fn = sum(counts['fn'] for counts in field_aggregates.values())
overall = calculate_metrics(total_tp, total_fp, total_fn)
return {
'overall': overall,
'by_field': field_metrics,
'num_papers_evaluated': len(evaluated_papers)
}
def generate_report(
paper_evaluations: Dict[str, Dict],
aggregated: Dict,
output_path: Path
):
"""Generate human-readable validation report"""
lines = []
lines.append("="*80)
lines.append("EXTRACTION VALIDATION REPORT")
lines.append("="*80)
lines.append("")
# Overall summary
lines.append("OVERALL METRICS")
lines.append("-"*80)
overall = aggregated['overall']
lines.append(f"Papers evaluated: {aggregated['num_papers_evaluated']}")
lines.append(f"Precision: {overall['precision']:.2%}")
lines.append(f"Recall: {overall['recall']:.2%}")
lines.append(f"F1 Score: {overall['f1']:.2%}")
lines.append(f"True Positives: {overall['tp']}")
lines.append(f"False Positives: {overall['fp']}")
lines.append(f"False Negatives: {overall['fn']}")
lines.append("")
# Per-field metrics
lines.append("METRICS BY FIELD")
lines.append("-"*80)
lines.append(f"{'Field':<30} {'Precision':>10} {'Recall':>10} {'F1':>10}")
lines.append("-"*80)
for field, metrics in sorted(aggregated['by_field'].items()):
lines.append(
f"{field:<30} "
f"{metrics['precision']:>9.1%} "
f"{metrics['recall']:>9.1%} "
f"{metrics['f1']:>9.1%}"
)
lines.append("")
# Top errors
lines.append("COMMON ISSUES")
lines.append("-"*80)
# Fields with low recall (missed information)
low_recall = [
(field, metrics) for field, metrics in aggregated['by_field'].items()
if metrics['recall'] < 0.7 and metrics['fn'] > 0
]
if low_recall:
lines.append("\nFields with low recall (missed information):")
for field, metrics in sorted(low_recall, key=lambda x: x[1]['recall']):
lines.append(f" - {field}: {metrics['recall']:.1%} recall, {metrics['fn']} missed items")
# Fields with low precision (incorrect extractions)
low_precision = [
(field, metrics) for field, metrics in aggregated['by_field'].items()
if metrics['precision'] < 0.7 and metrics['fp'] > 0
]
if low_precision:
lines.append("\nFields with low precision (incorrect extractions):")
for field, metrics in sorted(low_precision, key=lambda x: x[1]['precision']):
lines.append(f" - {field}: {metrics['precision']:.1%} precision, {metrics['fp']} incorrect items")
lines.append("")
lines.append("="*80)
# Write report
report_text = "\n".join(lines)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(report_text)
# Also print to console
print(report_text)
def main():
args = parse_args()
# Load annotations
annotations = load_annotations(Path(args.annotations))
validation_papers = annotations.get('validation_papers', {})
print(f"Loaded {len(validation_papers)} validation papers")
# Check how many have ground truth
annotated = sum(1 for p in validation_papers.values() if p.get('ground_truth') is not None)
print(f"Papers with ground truth: {annotated}")
if annotated == 0:
print("\nError: No ground truth annotations found!")
print("Please fill in the 'ground_truth' field for each paper in the annotation file.")
sys.exit(1)
# Configuration for comparisons
config = {
'numeric_tolerance': args.numeric_tolerance,
'fuzzy_strings': args.fuzzy_strings,
'list_order_matters': args.list_order_matters
}
# Evaluate each paper
paper_evaluations = {}
for paper_id, paper_data in validation_papers.items():
automated = paper_data.get('automated_extraction', {})
truth = paper_data.get('ground_truth')
evaluation = evaluate_paper(paper_id, automated, truth, config)
paper_evaluations[paper_id] = evaluation
if evaluation['status'] == 'evaluated':
overall = evaluation['overall']
print(f"{paper_id}: P={overall['precision']:.2%} R={overall['recall']:.2%} F1={overall['f1']:.2%}")
# Aggregate metrics
aggregated = aggregate_metrics(paper_evaluations)
# Save detailed metrics
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
detailed_output = {
'summary': aggregated,
'by_paper': paper_evaluations,
'config': config
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(detailed_output, f, indent=2, ensure_ascii=False)
print(f"\nDetailed metrics saved to: {output_path}")
# Generate report
report_path = Path(args.report)
generate_report(paper_evaluations, aggregated, report_path)
print(f"Validation report saved to: {report_path}")
if __name__ == '__main__':
main()
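# Example invocation (illustrative; validation_set.json is the default output of
# 07_prepare_validation_set.py, the flags are those defined in parse_args above):
#   python 08_calculate_validation_metrics.py --annotations validation_set.json \
#       --fuzzy-strings --numeric-tolerance 0.01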