Initial commit

310  skills/extract_from_pdfs/scripts/01_organize_metadata.py  Normal file
@@ -0,0 +1,310 @@
#!/usr/bin/env python3
"""
Organize PDFs and metadata from various sources (BibTeX, RIS, directory, DOI list).
Standardizes file naming and creates a unified metadata JSON for downstream processing.
"""

import argparse
import json
import re
import shutil
from pathlib import Path
from typing import Dict, List, Optional

try:
    from pybtex.database.input import bibtex
    BIBTEX_AVAILABLE = True
except ImportError:
    BIBTEX_AVAILABLE = False
    print("Warning: pybtex not installed. BibTeX support disabled.")

try:
    import rispy
    RIS_AVAILABLE = True
except ImportError:
    RIS_AVAILABLE = False
    print("Warning: rispy not installed. RIS support disabled.")


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Organize PDFs and metadata from various sources'
    )
    parser.add_argument(
        '--source-type',
        choices=['bibtex', 'ris', 'directory', 'doi_list'],
        required=True,
        help='Type of source data'
    )
    parser.add_argument(
        '--source',
        required=True,
        help='Path to source file (BibTeX/RIS file, directory, or DOI list)'
    )
    parser.add_argument(
        '--pdf-dir',
        help='Directory containing PDFs (for bibtex/ris with relative paths)'
    )
    parser.add_argument(
        '--output',
        default='metadata.json',
        help='Output metadata JSON file'
    )
    parser.add_argument(
        '--organize-pdfs',
        action='store_true',
        help='Copy PDFs to standardized directory structure'
    )
    parser.add_argument(
        '--pdf-output-dir',
        default='organized_pdfs',
        help='Directory for organized PDFs'
    )
    return parser.parse_args()


def load_bibtex_metadata(bib_path: Path, pdf_base_dir: Optional[Path] = None) -> List[Dict]:
    """Load metadata from BibTeX file"""
    if not BIBTEX_AVAILABLE:
        raise ImportError("pybtex is required for BibTeX support. Install with: pip install pybtex")

    parser = bibtex.Parser()
    bib_data = parser.parse_file(str(bib_path))

    metadata = []
    for key, entry in bib_data.entries.items():
        record = {
            'id': key,
            'type': entry.type,
            'title': entry.fields.get('title', ''),
            'year': entry.fields.get('year', ''),
            'doi': entry.fields.get('doi', ''),
            'abstract': entry.fields.get('abstract', ''),
            'journal': entry.fields.get('journal', ''),
            'authors': ', '.join(
                ' '.join(person.last_names + person.first_names)
                for person in entry.persons.get('author', [])
            ),
            'keywords': entry.fields.get('keywords', ''),
            'pdf_path': None
        }

        # Extract PDF path from the file field
        # (Zotero/Mendeley typically use "description:path:mime-type" entries)
        if 'file' in entry.fields:
            file_field = entry.fields['file']
            if file_field.startswith('{') and file_field.endswith('}'):
                file_field = file_field[1:-1]

            for file_entry in file_field.split(';'):
                parts = file_entry.strip().split(':')
                if len(parts) >= 3 and parts[2].lower() == 'application/pdf':
                    pdf_path = parts[1].strip()
                    if pdf_base_dir:
                        pdf_path = str(pdf_base_dir / pdf_path)
                    record['pdf_path'] = pdf_path
                    break

        metadata.append(record)

    print(f"Loaded {len(metadata)} entries from BibTeX file")
    return metadata


def load_ris_metadata(ris_path: Path, pdf_base_dir: Optional[Path] = None) -> List[Dict]:
    """Load metadata from RIS file"""
    if not RIS_AVAILABLE:
        raise ImportError("rispy is required for RIS support. Install with: pip install rispy")

    with open(ris_path, 'r', encoding='utf-8') as f:
        entries = rispy.load(f)

    metadata = []
    for i, entry in enumerate(entries):
        # Generate ID from first author and year, with the index as a tiebreaker
        # (guard against a present-but-empty authors list)
        first_author = (entry.get('authors') or [None])[0] or 'Unknown'
        year = entry.get('year', 'NoYear')
        entry_id = f"{first_author.split()[-1]}{year}_{i}"

        record = {
            'id': entry_id,
            'type': entry.get('type_of_reference', 'article'),
            'title': entry.get('title', ''),
            'year': str(entry.get('year', '')),
            'doi': entry.get('doi', ''),
            'abstract': entry.get('abstract', ''),
            'journal': entry.get('journal_name', ''),
            'authors': '; '.join(entry.get('authors', [])),
            'keywords': '; '.join(entry.get('keywords', [])),
            'pdf_path': None
        }

        # Try to find PDF in standard locations
        if pdf_base_dir:
            # Common patterns: FirstAuthorYear.pdf, doi_cleaned.pdf, etc.
            pdf_candidates = [
                f"{entry_id}.pdf",
                f"{first_author.split()[-1]}_{year}.pdf"
            ]
            if record['doi']:
                safe_doi = re.sub(r'[^\w\-_]', '_', record['doi'])
                pdf_candidates.append(f"{safe_doi}.pdf")

            for candidate in pdf_candidates:
                pdf_path = pdf_base_dir / candidate
                if pdf_path.exists():
                    record['pdf_path'] = str(pdf_path)
                    break

        metadata.append(record)

    print(f"Loaded {len(metadata)} entries from RIS file")
    return metadata


def load_directory_metadata(dir_path: Path) -> List[Dict]:
    """Load metadata by scanning directory for PDFs"""
    pdf_files = list(dir_path.glob('**/*.pdf'))

    metadata = []
    for pdf_path in pdf_files:
        # Generate ID from filename
        entry_id = pdf_path.stem

        record = {
            'id': entry_id,
            'type': 'article',
            'title': entry_id.replace('_', ' '),
            'year': '',
            'doi': '',
            'abstract': '',
            'journal': '',
            'authors': '',
            'keywords': '',
            'pdf_path': str(pdf_path)
        }

        # Try to extract DOI from filename if present
        doi_match = re.search(r'10\.\d{4,}/[^\s]+', entry_id)
        if doi_match:
            record['doi'] = doi_match.group(0)

        metadata.append(record)

    print(f"Found {len(metadata)} PDFs in directory")
    return metadata


def load_doi_list_metadata(doi_list_path: Path) -> List[Dict]:
    """Load metadata from a list of DOIs (full metadata must be fetched separately)"""
    with open(doi_list_path, 'r') as f:
        dois = [line.strip() for line in f if line.strip()]

    metadata = []
    for doi in dois:
        safe_doi = re.sub(r'[^\w\-_]', '_', doi)
        record = {
            'id': safe_doi,
            'type': 'article',
            'title': '',
            'year': '',
            'doi': doi,
            'abstract': '',
            'journal': '',
            'authors': '',
            'keywords': '',
            'pdf_path': None
        }
        metadata.append(record)

    print(f"Loaded {len(metadata)} DOIs")
    print("Note: You'll need to fetch full metadata and PDFs separately")
    return metadata


def organize_pdfs(metadata: List[Dict], output_dir: Path) -> List[Dict]:
    """Copy and rename PDFs to standardized directory structure"""
    output_dir.mkdir(parents=True, exist_ok=True)

    organized_metadata = []
    stats = {'copied': 0, 'missing': 0, 'total': len(metadata)}

    for record in metadata:
        if record['pdf_path'] and Path(record['pdf_path']).exists():
            source_path = Path(record['pdf_path'])
            dest_path = output_dir / f"{record['id']}.pdf"

            try:
                shutil.copy2(source_path, dest_path)
                record['pdf_path'] = str(dest_path)
                stats['copied'] += 1
            except Exception as e:
                print(f"Error copying {source_path}: {e}")
                stats['missing'] += 1
        else:
            if record['pdf_path']:
                print(f"PDF not found: {record['pdf_path']}")
            stats['missing'] += 1

        organized_metadata.append(record)

    print("\nPDF Organization Summary:")
    print(f"  Total entries: {stats['total']}")
    print(f"  PDFs copied: {stats['copied']}")
    print(f"  PDFs missing: {stats['missing']}")

    return organized_metadata


def save_metadata(metadata: List[Dict], output_path: Path):
    """Save metadata to JSON file"""
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    print(f"\nMetadata saved to: {output_path}")


def main():
    args = parse_args()

    source_path = Path(args.source)
    pdf_base_dir = Path(args.pdf_dir) if args.pdf_dir else None
    output_path = Path(args.output)

    # Load metadata based on source type
    if args.source_type == 'bibtex':
        metadata = load_bibtex_metadata(source_path, pdf_base_dir)
    elif args.source_type == 'ris':
        metadata = load_ris_metadata(source_path, pdf_base_dir)
    elif args.source_type == 'directory':
        metadata = load_directory_metadata(source_path)
    elif args.source_type == 'doi_list':
        metadata = load_doi_list_metadata(source_path)
    else:
        raise ValueError(f"Unknown source type: {args.source_type}")

    # Organize PDFs if requested
    if args.organize_pdfs:
        pdf_output_dir = Path(args.pdf_output_dir)
        metadata = organize_pdfs(metadata, pdf_output_dir)

    # Save metadata
    save_metadata(metadata, output_path)

    # Print summary statistics
    total = len(metadata)
    with_pdfs = sum(1 for r in metadata if r['pdf_path'])
    with_abstracts = sum(1 for r in metadata if r['abstract'])
    with_dois = sum(1 for r in metadata if r['doi'])

    print("\nMetadata Summary:")
    print(f"  Total entries: {total}")
    print(f"  With PDFs: {with_pdfs}")
    print(f"  With abstracts: {with_abstracts}")
    print(f"  With DOIs: {with_dois}")


if __name__ == '__main__':
    main()
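
# ------------------------------------------------------------------
# Example invocations (illustrative file names; the flags are the ones
# defined in parse_args above):
#
#   python 01_organize_metadata.py --source-type bibtex --source refs.bib \
#       --pdf-dir ./pdfs --organize-pdfs --output metadata.json
#
#   python 01_organize_metadata.py --source-type directory --source ./pdf_folder
# ------------------------------------------------------------------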
468  skills/extract_from_pdfs/scripts/02_filter_abstracts.py  Normal file
@@ -0,0 +1,468 @@
#!/usr/bin/env python3
"""
Filter papers based on abstract content using Claude API or local models.
Reduces processing costs by identifying relevant papers before full PDF extraction.
This script template needs to be customized with your specific filtering criteria.

Supports:
- Claude Haiku (cheap, fast API option)
- Claude Sonnet (more accurate API option)
- Local models via Ollama (free, private, requires local setup)
"""

import argparse
import json
import os
import re
import time
from pathlib import Path
from typing import Dict, List, Optional

from anthropic import Anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

try:
    import requests
    REQUESTS_AVAILABLE = True
except ImportError:
    REQUESTS_AVAILABLE = False


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Filter papers by analyzing abstracts with Claude or local models',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Backend options:
  anthropic-haiku  : Claude 3.5 Haiku (cheap, fast, ~$0.80/million input tokens)
  anthropic-sonnet : Claude 3.5 Sonnet (more accurate, ~$3/million input tokens)
  ollama           : Local model via Ollama (free, requires local setup)

Local model setup (Ollama):
  1. Install Ollama: https://ollama.com
  2. Pull a model: ollama pull llama3.1:8b
  3. Run server: ollama serve (usually starts automatically)
  4. Use --backend ollama --ollama-model llama3.1:8b

Recommended models for Ollama:
  - llama3.1:8b (good balance)
  - llama3.1:70b (better accuracy, needs more RAM)
  - mistral:7b (fast, good for simple filtering)
  - qwen2.5:7b (good multilingual support)
"""
    )
    parser.add_argument(
        '--metadata',
        required=True,
        help='Input metadata JSON file from step 01'
    )
    parser.add_argument(
        '--output',
        default='filtered_papers.json',
        help='Output JSON file with filter results'
    )
    parser.add_argument(
        '--backend',
        choices=['anthropic-haiku', 'anthropic-sonnet', 'ollama'],
        default='anthropic-haiku',
        help='Model backend to use (default: anthropic-haiku for cost efficiency)'
    )
    parser.add_argument(
        '--ollama-model',
        default='llama3.1:8b',
        help='Ollama model name (default: llama3.1:8b)'
    )
    parser.add_argument(
        '--ollama-url',
        default='http://localhost:11434',
        help='Ollama server URL (default: http://localhost:11434)'
    )
    parser.add_argument(
        '--use-batches',
        action='store_true',
        help='Use Anthropic Batches API (only for anthropic backends)'
    )
    parser.add_argument(
        '--test',
        action='store_true',
        help='Run in test mode (process only 10 records)'
    )
    return parser.parse_args()


def load_metadata(metadata_path: Path) -> List[Dict]:
    """Load metadata from JSON file"""
    with open(metadata_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_existing_results(output_path: Path) -> Dict:
    """Load existing filter results if available"""
    if output_path.exists():
        with open(output_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}


def save_results(results: Dict, output_path: Path):
    """Save filter results to JSON file"""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def create_filter_prompt(title: str, abstract: str) -> str:
    """
    Create the filtering prompt.

    TODO: CUSTOMIZE THIS PROMPT FOR YOUR SPECIFIC USE CASE

    This is a template. Replace the example criteria with your own.
    """
    return f"""You are analyzing scientific literature to identify relevant papers for a research project.

<title>
{title}
</title>

<abstract>
{abstract}
</abstract>

Your task is to determine if this paper meets the following criteria:

TODO: Replace these example criteria with your own:

1. Does the paper contain PRIMARY empirical data (not review/meta-analysis)?
2. Does the paper report [YOUR SPECIFIC DATA TYPE, e.g., "field observations", "experimental measurements", "clinical outcomes"]?
3. Is the geographic/temporal/taxonomic scope relevant to [YOUR STUDY SYSTEM]?

Important considerations:
- Be conservative: when in doubt, include the paper (false positives are better than false negatives)
- Distinguish between primary data and citations of others' work
- Consider whether the abstract suggests the full paper likely contains the data of interest

Provide your determination as a JSON object with these boolean fields:
1. "has_relevant_data": true if the paper likely contains the data type of interest
2. "is_primary_research": true if the paper reports original empirical data
3. "meets_scope": true if the study system/scope is relevant

Also provide:
4. "confidence": your confidence level (high/medium/low)
5. "reasoning": brief explanation (1-2 sentences)

Wrap your response in <output> tags. Example:
<output>
{{
  "has_relevant_data": true,
  "is_primary_research": true,
  "meets_scope": true,
  "confidence": "high",
  "reasoning": "Abstract explicitly mentions field observations of the target phenomenon in the relevant geographic region."
}}
</output>

Base your determination solely on the title and abstract provided."""


def extract_json_from_xml(text: str) -> Optional[Dict]:
    """Extract JSON from XML output tags in Claude's response"""
    match = re.search(r'<output>\s*(\{.*?\})\s*</output>', text, re.DOTALL)
    if match:
        json_str = match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON: {e}")
            print(f"JSON string: {json_str}")
            return None
    return None


def filter_paper_ollama(record: Dict, ollama_url: str, ollama_model: str) -> Dict:
    """Use local Ollama model to filter a single paper"""
    if not REQUESTS_AVAILABLE:
        return {
            'status': 'error',
            'reason': 'requests library not available. Install with: pip install requests'
        }

    if not record.get('title') or not record.get('abstract'):
        return {
            'status': 'skipped',
            'reason': 'missing_title_or_abstract'
        }

    max_retries = 3
    for attempt in range(max_retries):
        try:
            # Ollama's native chat endpoint (an OpenAI-compatible endpoint also
            # exists at /v1/chat/completions, but /api/chat is used here)
            response = requests.post(
                f"{ollama_url}/api/chat",
                json={
                    "model": ollama_model,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses."
                        },
                        {
                            "role": "user",
                            "content": create_filter_prompt(record['title'], record['abstract'])
                        }
                    ],
                    "stream": False,
                    "options": {
                        "temperature": 0,
                        "num_predict": 2048
                    }
                },
                timeout=60
            )

            if response.status_code == 200:
                data = response.json()
                content = data.get('message', {}).get('content', '')
                result = extract_json_from_xml(content)

                if result:
                    return {
                        'status': 'success',
                        'filter_result': result,
                        'model_used': ollama_model
                    }
                else:
                    return {
                        'status': 'error',
                        'reason': 'failed_to_parse_json',
                        'raw_response': content[:500]
                    }
            else:
                return {
                    'status': 'error',
                    'reason': f'Ollama API error: {response.status_code} {response.text[:200]}'
                }

        except requests.exceptions.ConnectionError:
            return {
                'status': 'error',
                'reason': f'Cannot connect to Ollama at {ollama_url}. Make sure Ollama is running: ollama serve'
            }
        except Exception as e:
            if attempt == max_retries - 1:
                return {
                    'status': 'error',
                    'reason': str(e)
                }
            time.sleep(2 ** attempt)  # Exponential backoff before retrying


def filter_paper_direct(client: Anthropic, record: Dict, model: str) -> Dict:
    """Use Claude API directly to filter a single paper"""
    if not record.get('title') or not record.get('abstract'):
        return {
            'status': 'skipped',
            'reason': 'missing_title_or_abstract'
        }

    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = client.messages.create(
                model=model,
                max_tokens=2048,
                temperature=0,
                system="You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses.",
                messages=[{
                    "role": "user",
                    "content": create_filter_prompt(record['title'], record['abstract'])
                }]
            )

            result = extract_json_from_xml(response.content[0].text)
            if result:
                return {
                    'status': 'success',
                    'filter_result': result,
                    'model_used': model
                }
            else:
                return {
                    'status': 'error',
                    'reason': 'failed_to_parse_json'
                }

        except Exception as e:
            if attempt == max_retries - 1:
                return {
                    'status': 'error',
                    'reason': str(e)
                }
            time.sleep(2 ** attempt)


def filter_papers_batch(client: Anthropic, records: List[Dict], model: str) -> Dict[str, Dict]:
    """Use Claude Batches API to filter multiple papers efficiently"""
    # Named batch_requests to avoid shadowing the imported requests module
    batch_requests = []

    for record in records:
        if not record.get('title') or not record.get('abstract'):
            continue

        batch_requests.append(Request(
            custom_id=record['id'],
            params=MessageCreateParamsNonStreaming(
                model=model,
                max_tokens=2048,
                temperature=0,
                system="You are a scientific literature analyst specializing in identifying relevant papers for systematic reviews and meta-analyses.",
                messages=[{
                    "role": "user",
                    "content": create_filter_prompt(record['title'], record['abstract'])
                }]
            )
        ))

    if not batch_requests:
        print("No papers to process (missing titles or abstracts)")
        return {}

    # Create batch
    print(f"Creating batch with {len(batch_requests)} requests...")
    message_batch = client.messages.batches.create(requests=batch_requests)
    print(f"Batch created: {message_batch.id}")

    # Poll for completion
    while message_batch.processing_status == "in_progress":
        print("Waiting for batch processing...")
        time.sleep(30)
        message_batch = client.messages.batches.retrieve(message_batch.id)

    # Process results
    results = {}
    if message_batch.processing_status == "ended":
        print("Batch completed. Processing results...")
        for result in client.messages.batches.results(message_batch.id):
            if result.result.type == "succeeded":
                filter_result = extract_json_from_xml(
                    result.result.message.content[0].text
                )
                if filter_result:
                    results[result.custom_id] = {
                        'status': 'success',
                        'filter_result': filter_result
                    }
                else:
                    results[result.custom_id] = {
                        'status': 'error',
                        'reason': 'failed_to_parse_json'
                    }
            else:
                results[result.custom_id] = {
                    'status': 'error',
                    'reason': f"{result.result.type}: {getattr(result.result, 'error', 'unknown error')}"
                }
    else:
        print(f"Batch failed with status: {message_batch.processing_status}")

    return results


def get_model_name(backend: str) -> str:
    """Get the appropriate model name for the backend"""
    if backend == 'anthropic-haiku':
        return 'claude-3-5-haiku-20241022'
    elif backend == 'anthropic-sonnet':
        return 'claude-3-5-sonnet-20241022'
    return backend


def main():
    args = parse_args()

    # Backend-specific setup
    client = None
    model = None
    if args.backend.startswith('anthropic'):
        if not os.getenv('ANTHROPIC_API_KEY'):
            raise ValueError("Please set ANTHROPIC_API_KEY environment variable for Anthropic backends")
        client = Anthropic()
        model = get_model_name(args.backend)
        print(f"Using Anthropic backend: {model}")
    elif args.backend == 'ollama':
        if args.use_batches:
            print("Warning: Batches API not available for Ollama. Processing sequentially.")
            args.use_batches = False
        print(f"Using Ollama backend: {args.ollama_model} at {args.ollama_url}")
        print("Make sure Ollama is running: ollama serve")

    # Load metadata
    metadata = load_metadata(Path(args.metadata))
    print(f"Loaded {len(metadata)} metadata records")

    # Apply test mode if specified
    if args.test:
        metadata = metadata[:10]
        print(f"Test mode: processing {len(metadata)} records")

    # Load existing results
    output_path = Path(args.output)
    results = load_existing_results(output_path)
    print(f"Loaded {len(results)} existing results")

    # Identify papers to process
    to_process = [r for r in metadata if r['id'] not in results]
    print(f"Papers to process: {len(to_process)}")

    if not to_process:
        print("All papers already processed!")
        return

    # Process papers based on backend
    if args.backend == 'ollama':
        print("Processing papers with Ollama...")
        for record in to_process:
            print(f"Processing: {record['id']}")
            result = filter_paper_ollama(record, args.ollama_url, args.ollama_model)
            results[record['id']] = result
            save_results(results, output_path)
            # No sleep needed for local models
    elif args.use_batches:
        print("Using Batches API...")
        batch_results = filter_papers_batch(client, to_process, model)
        results.update(batch_results)
    else:
        print("Processing papers sequentially with Anthropic API...")
        for record in to_process:
            print(f"Processing: {record['id']}")
            result = filter_paper_direct(client, record, model)
            results[record['id']] = result
            save_results(results, output_path)
            time.sleep(1)  # Rate limiting

    # Save final results
    save_results(results, output_path)

    # Print summary statistics
    total = len(results)
    successful = sum(1 for r in results.values() if r.get('status') == 'success')
    relevant = sum(
        1 for r in results.values()
        if r.get('status') == 'success' and r.get('filter_result', {}).get('has_relevant_data')
    )

    print(f"\n{'='*60}")
    print("Filtering Summary")
    print(f"{'='*60}")
    print(f"Total papers processed: {total}")
    print(f"Successfully analyzed: {successful}")
    print(f"Papers with relevant data: {relevant}")
    print(f"Relevance rate: {relevant/successful*100:.1f}%" if successful > 0 else "Relevance rate: N/A")
    print(f"\nResults saved to: {output_path}")
    print("\nNext step: Review results and proceed to PDF extraction for relevant papers")


if __name__ == '__main__':
    main()
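
# ------------------------------------------------------------------
# For reference, a sketch of the JSON this script writes to --output,
# keyed by record ID (illustrative values; the field names come from
# the result dicts built above):
#
# {
#   "Smith2020_0": {
#     "status": "success",
#     "filter_result": {
#       "has_relevant_data": true,
#       "is_primary_research": true,
#       "meets_scope": false,
#       "confidence": "medium",
#       "reasoning": "..."
#     },
#     "model_used": "claude-3-5-haiku-20241022"
#   }
# }
# ------------------------------------------------------------------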
478  skills/extract_from_pdfs/scripts/03_extract_from_pdfs.py  Normal file
@@ -0,0 +1,478 @@
#!/usr/bin/env python3
"""
Extract structured data from PDFs using Claude API.
Supports multiple PDF processing methods and prompt caching for efficiency.
This script template needs to be customized with your specific extraction schema.
"""

import argparse
import base64
import json
import os
import re
import time
from pathlib import Path
from typing import Dict, List, Optional

from anthropic import Anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request


# Configuration
BATCH_SIZE = 5                  # requests per batch
SIMULTANEOUS_BATCHES = 4        # batches submitted per window
BATCH_CHECK_INTERVAL = 30       # seconds between batch status polls
BATCH_SUBMISSION_INTERVAL = 20  # seconds between batch submissions


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Extract structured data from PDFs using Claude'
    )
    parser.add_argument(
        '--metadata',
        required=True,
        help='Input metadata JSON file (from step 01 or 02)'
    )
    parser.add_argument(
        '--schema',
        required=True,
        help='JSON file defining extraction schema and prompts'
    )
    parser.add_argument(
        '--output',
        default='extracted_data.json',
        help='Output JSON file with extraction results'
    )
    parser.add_argument(
        '--method',
        choices=['base64', 'files_api', 'batches'],
        default='batches',
        help='PDF processing method (default: batches)'
    )
    parser.add_argument(
        '--use-caching',
        action='store_true',
        help='Enable prompt caching (reduces costs by ~90%% for repeated queries)'
    )
    parser.add_argument(
        '--test',
        action='store_true',
        help='Run in test mode (process only 3 PDFs)'
    )
    parser.add_argument(
        '--model',
        default='claude-3-5-sonnet-20241022',
        help='Claude model to use'
    )
    parser.add_argument(
        '--filter-results',
        help='Optional: JSON file with filter results from step 02 (only process relevant papers)'
    )
    return parser.parse_args()


def load_metadata(metadata_path: Path) -> List[Dict]:
    """Load metadata from JSON file"""
    with open(metadata_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_schema(schema_path: Path) -> Dict:
    """Load extraction schema definition"""
    with open(schema_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_filter_results(filter_path: Path) -> Dict:
    """Load filter results from step 02"""
    with open(filter_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_existing_results(output_path: Path) -> Dict:
    """Load existing extraction results if available"""
    if output_path.exists():
        with open(output_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}


def save_results(results: Dict, output_path: Path):
    """Save extraction results to JSON file"""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def create_extraction_prompt(schema: Dict) -> str:
    """
    Create extraction prompt from schema definition.

    The schema JSON should contain:
    - system_context: Description of the analysis task
    - instructions: Step-by-step analysis instructions
    - output_schema: JSON schema for the output
    - output_example: Example of desired output

    TODO: Customize schema.json for your specific use case
    """
    prompt_parts = []

    # Add objective
    if 'objective' in schema:
        prompt_parts.append(f"Your objective is to {schema['objective']}\n")

    # Add instructions
    if 'instructions' in schema:
        prompt_parts.append("Please follow these steps:\n")
        for i, instruction in enumerate(schema['instructions'], 1):
            prompt_parts.append(f"{i}. {instruction}")
        prompt_parts.append("")

    # Add analysis framework
    if 'analysis_steps' in schema:
        prompt_parts.append("<analysis_framework>")
        for step in schema['analysis_steps']:
            prompt_parts.append(f"- {step}")
        prompt_parts.append("</analysis_framework>\n")
        prompt_parts.append(
            "Your analysis must be wrapped within <analysis> tags. "
            "Be thorough and explicit in your reasoning.\n"
        )

    # Add output schema explanation
    if 'output_schema' in schema:
        prompt_parts.append("<output_schema>")
        prompt_parts.append(json.dumps(schema['output_schema'], indent=2))
        prompt_parts.append("</output_schema>\n")

    # Add output example
    if 'output_example' in schema:
        prompt_parts.append("<output_example>")
        prompt_parts.append(json.dumps(schema['output_example'], indent=2))
        prompt_parts.append("</output_example>\n")

    # Add important notes
    if 'important_notes' in schema:
        prompt_parts.append("Important considerations:")
        for note in schema['important_notes']:
            prompt_parts.append(f"- {note}")
        prompt_parts.append("")

    # Add final instruction
    prompt_parts.append(
        "After your analysis, provide the final output in the following JSON format, "
        "wrapped in <output> tags. The output must be valid, parseable JSON.\n"
    )

    return "\n".join(prompt_parts)
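
# A minimal example of what the schema.json consumed above might look
# like. The keys are exactly the ones create_extraction_prompt and the
# processing functions read; all values are illustrative placeholders:
#
# {
#   "system_context": "You are a scientific research assistant extracting data for a meta-analysis.",
#   "objective": "extract every reported measurement of [YOUR VARIABLE] from this paper",
#   "instructions": ["Read the methods section", "Locate all reported values"],
#   "analysis_steps": ["Identify the study system", "List candidate measurements"],
#   "output_schema": {"type": "object", "properties": {"measurements": {"type": "array"}}},
#   "output_example": {"measurements": []},
#   "important_notes": ["Report units exactly as given in the paper"]
# }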

def extract_json_from_response(text: str) -> Optional[Dict]:
    """Extract JSON from XML output tags in Claude's response"""
    match = re.search(r'<output>\s*(\{.*?\})\s*</output>', text, re.DOTALL)
    if match:
        json_str = match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON: {e}")
            return None
    return None


def extract_analysis_from_response(text: str) -> Optional[str]:
    """Extract analysis from XML tags in Claude's response"""
    match = re.search(r'<analysis>(.*?)</analysis>', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def build_request_content(pdf_data: str, schema: Dict, use_caching: bool) -> List[Dict]:
    """
    Build the user-message content blocks for one PDF.

    When use_caching is set, the PDF document block is marked with
    cache_control so repeated queries against the same PDF can reuse the
    cached prefix. Note that caching only saves money when the same prefix
    is actually sent more than once, and only when it exceeds the model's
    minimum cacheable length.
    """
    document_block = {
        "type": "document",
        "source": {
            "type": "base64",
            "media_type": "application/pdf",
            "data": pdf_data
        }
    }
    if use_caching:
        document_block["cache_control"] = {"type": "ephemeral"}
    return [
        document_block,
        {
            "type": "text",
            "text": create_extraction_prompt(schema)
        }
    ]


def process_pdf_base64(
    client: Anthropic,
    pdf_path: Path,
    schema: Dict,
    model: str,
    use_caching: bool = False
) -> Dict:
    """Process a single PDF using base64 encoding (direct upload)"""
    if not pdf_path.exists():
        return {
            'status': 'error',
            'error': f'PDF not found: {pdf_path}'
        }

    # Check file size (the API rejects PDFs over 32MB)
    file_size = pdf_path.stat().st_size
    if file_size > 32 * 1024 * 1024:
        return {
            'status': 'error',
            'error': f'PDF exceeds 32MB limit: {file_size / 1024 / 1024:.1f}MB'
        }

    try:
        # Read and encode PDF
        with open(pdf_path, 'rb') as f:
            pdf_data = base64.b64encode(f.read()).decode('utf-8')

        # Create message (max_tokens capped at 8192, the output limit for
        # claude-3-5-sonnet-20241022)
        response = client.messages.create(
            model=model,
            max_tokens=8192,
            temperature=0,
            system=schema.get('system_context', 'You are a scientific research assistant.'),
            messages=[{
                "role": "user",
                "content": build_request_content(pdf_data, schema, use_caching)
            }]
        )

        response_text = response.content[0].text

        return {
            'status': 'success',
            'extracted_data': extract_json_from_response(response_text),
            'analysis': extract_analysis_from_response(response_text),
            'input_tokens': response.usage.input_tokens,
            'output_tokens': response.usage.output_tokens
        }

    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }


def process_pdfs_batch(
    client: Anthropic,
    records: List[tuple],
    schema: Dict,
    model: str,
    use_caching: bool = False
) -> Dict[str, Dict]:
    """Process multiple PDFs using Batches API for efficiency"""
    all_results = {}

    # Submit batches in windows of SIMULTANEOUS_BATCHES at a time
    for window_start in range(0, len(records), SIMULTANEOUS_BATCHES * BATCH_SIZE):
        window_records = records[window_start:window_start + (SIMULTANEOUS_BATCHES * BATCH_SIZE)]
        print(f"\nProcessing window starting at index {window_start} ({len(window_records)} PDFs)")

        active_batches = {}

        for batch_start in range(0, len(window_records), BATCH_SIZE):
            batch_records = window_records[batch_start:batch_start + BATCH_SIZE]
            batch_requests = []

            for record_id, pdf_data in batch_records:
                batch_requests.append(Request(
                    custom_id=record_id,
                    params=MessageCreateParamsNonStreaming(
                        model=model,
                        max_tokens=8192,
                        temperature=0,
                        system=schema.get('system_context', 'You are a scientific research assistant.'),
                        messages=[{
                            "role": "user",
                            "content": build_request_content(pdf_data, schema, use_caching)
                        }]
                    )
                ))

            try:
                message_batch = client.messages.batches.create(requests=batch_requests)
                print(f"Created batch {message_batch.id} with {len(batch_requests)} requests")
                # Request is a TypedDict, so index it rather than using attribute access
                active_batches[message_batch.id] = {r['custom_id'] for r in batch_requests}
                time.sleep(BATCH_SUBMISSION_INTERVAL)
            except Exception as e:
                print(f"Error creating batch: {e}")

        # Wait for batches
        window_results = wait_for_batches(client, list(active_batches.keys()))
        all_results.update(window_results)

    return all_results


def wait_for_batches(
    client: Anthropic,
    batch_ids: List[str]
) -> Dict[str, Dict]:
    """Wait for batches to complete and return results"""
    print(f"\nWaiting for {len(batch_ids)} batches to complete...")

    incomplete = set(batch_ids)

    while incomplete:
        time.sleep(BATCH_CHECK_INTERVAL)

        for batch_id in list(incomplete):
            batch = client.messages.batches.retrieve(batch_id)
            if batch.processing_status != "in_progress":
                incomplete.remove(batch_id)
                print(f"Batch {batch_id} completed: {batch.processing_status}")

    # Collect results
    results = {}
    for batch_id in batch_ids:
        batch = client.messages.batches.retrieve(batch_id)
        if batch.processing_status == "ended":
            for result in client.messages.batches.results(batch_id):
                if result.result.type == "succeeded":
                    response_text = result.result.message.content[0].text
                    results[result.custom_id] = {
                        'status': 'success',
                        'extracted_data': extract_json_from_response(response_text),
                        'analysis': extract_analysis_from_response(response_text),
                        'input_tokens': result.result.message.usage.input_tokens,
                        'output_tokens': result.result.message.usage.output_tokens
                    }
                else:
                    results[result.custom_id] = {
                        'status': 'error',
                        'error': str(getattr(result.result, 'error', 'Unknown error'))
                    }

    return results


def main():
    args = parse_args()

    # Check for API key
    if not os.getenv('ANTHROPIC_API_KEY'):
        raise ValueError("Please set ANTHROPIC_API_KEY environment variable")

    client = Anthropic()

    # Load inputs
    metadata = load_metadata(Path(args.metadata))
    schema = load_schema(Path(args.schema))
    print(f"Loaded {len(metadata)} metadata records")

    # Filter by relevance if filter results provided
    if args.filter_results:
        filter_results = load_filter_results(Path(args.filter_results))
        relevant_ids = {
            record_id for record_id, result in filter_results.items()
            if result.get('status') == 'success'
            and result.get('filter_result', {}).get('has_relevant_data')
        }
        metadata = [r for r in metadata if r['id'] in relevant_ids]
        print(f"Filtered to {len(metadata)} relevant papers")

    # Apply test mode
    if args.test:
        metadata = metadata[:3]
        print(f"Test mode: processing {len(metadata)} PDFs")

    # Load existing results
    output_path = Path(args.output)
    results = load_existing_results(output_path)
    print(f"Loaded {len(results)} existing results")

    # Prepare PDFs to process
    to_process = []
    for record in metadata:
        if record['id'] in results:
            continue
        if not record.get('pdf_path'):
            print(f"Skipping {record['id']}: no PDF path")
            continue
        pdf_path = Path(record['pdf_path'])
        if not pdf_path.exists():
            print(f"Skipping {record['id']}: PDF not found")
            continue

        # Read and encode PDF
        try:
            with open(pdf_path, 'rb') as f:
                pdf_data = base64.b64encode(f.read()).decode('utf-8')
            to_process.append((record['id'], pdf_data))
        except Exception as e:
            print(f"Error reading {pdf_path}: {e}")

    print(f"PDFs to process: {len(to_process)}")

    if not to_process:
        print("All PDFs already processed!")
        return

    # Process PDFs
    if args.method == 'batches':
        print("Using Batches API...")
        batch_results = process_pdfs_batch(client, to_process, schema, args.model, args.use_caching)
        results.update(batch_results)
    else:
        # Note: 'files_api' is not implemented in this template; both remaining
        # methods fall back to sequential base64 upload
        print("Processing PDFs sequentially...")
        for record_id, pdf_data in to_process:
            print(f"Processing: {record_id}")
            # For sequential processing, reconstruct Path
            record = next(r for r in metadata if r['id'] == record_id)
            result = process_pdf_base64(
                client, Path(record['pdf_path']), schema, args.model, args.use_caching
            )
            results[record_id] = result
            save_results(results, output_path)
            time.sleep(2)

    # Save final results
    save_results(results, output_path)

    # Print summary
    total = len(results)
    successful = sum(1 for r in results.values() if r.get('status') == 'success')
    total_input_tokens = sum(
        r.get('input_tokens', 0) for r in results.values()
        if r.get('status') == 'success'
    )
    total_output_tokens = sum(
        r.get('output_tokens', 0) for r in results.values()
        if r.get('status') == 'success'
    )

    print(f"\n{'='*60}")
    print("Extraction Summary")
    print(f"{'='*60}")
    print(f"Total PDFs processed: {total}")
    print(f"Successful extractions: {successful}")
    print(f"Failed extractions: {total - successful}")
    print("\nToken usage:")
    print(f"  Input tokens: {total_input_tokens:,}")
    print(f"  Output tokens: {total_output_tokens:,}")
    print(f"  Total tokens: {total_input_tokens + total_output_tokens:,}")
    print(f"\nResults saved to: {output_path}")
    print("\nNext step: Repair and validate JSON outputs")


if __name__ == '__main__':
    main()
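
# ------------------------------------------------------------------
# A typical invocation of this step (illustrative paths, assuming the
# outputs of steps 01 and 02):
#
#   python 03_extract_from_pdfs.py --metadata metadata.json --schema schema.json \
#       --filter-results filtered_papers.json --method batches --test
# ------------------------------------------------------------------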
227  skills/extract_from_pdfs/scripts/04_repair_json.py  Normal file
@@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Repair and validate JSON extractions using the json_repair library.
Handles common JSON parsing issues and validates against a schema.
"""

import argparse
import json
from pathlib import Path
from typing import Any, Dict, Optional
import jsonschema

try:
    from json_repair import repair_json
    JSON_REPAIR_AVAILABLE = True
except ImportError:
    JSON_REPAIR_AVAILABLE = False
    print("Warning: json_repair not installed. Install with: pip install json-repair")


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Repair and validate JSON extractions'
    )
    parser.add_argument(
        '--input',
        required=True,
        help='Input JSON file with extraction results from step 03'
    )
    parser.add_argument(
        '--output',
        default='cleaned_extractions.json',
        help='Output JSON file with cleaned results'
    )
    parser.add_argument(
        '--schema',
        help='Optional: JSON schema file for validation'
    )
    parser.add_argument(
        '--strict',
        action='store_true',
        help='Strict mode: reject records that fail validation'
    )
    return parser.parse_args()


def load_results(input_path: Path) -> Dict:
    """Load extraction results from JSON file"""
    with open(input_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_schema(schema_path: Path) -> Dict:
    """Load JSON schema for validation"""
    with open(schema_path, 'r', encoding='utf-8') as f:
        schema_data = json.load(f)
    return schema_data.get('output_schema', schema_data)


def save_results(results: Dict, output_path: Path):
    """Save cleaned results to JSON file"""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def repair_json_data(data: Any) -> tuple[Any, bool]:
    """
    Attempt to repair JSON data using the json_repair library.
    Returns (repaired_data, success).

    Repair mainly matters when the extracted data arrived as a raw string;
    data that already parsed into Python objects is passed through a
    dump/repair/load round-trip, which is effectively a no-op.
    """
    if not JSON_REPAIR_AVAILABLE:
        return data, True  # Skip repair if library not available

    try:
        if isinstance(data, str):
            # Raw string: repair it directly and return the parsed object
            repaired_data = repair_json(data, return_objects=True)
            return repaired_data, True
        # Already-parsed data: round-trip through a JSON string
        json_str = json.dumps(data)
        repaired_str = repair_json(json_str, return_objects=False)
        repaired_data = json.loads(repaired_str)
        return repaired_data, True
    except Exception as e:
        print(f"Failed to repair JSON: {e}")
        return data, False
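
# Quick illustration of the kind of near-JSON that json_repair is built to
# fix, such as single quotes and trailing commas (illustrative REPL session;
# exact output may vary by library version):
#
#   >>> from json_repair import repair_json
#   >>> repair_json("{'confidence': 'high', 'has_relevant_data': true,}", return_objects=True)
#   {'confidence': 'high', 'has_relevant_data': True}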

def validate_against_schema(data: Any, schema: Dict) -> tuple[bool, Optional[str]]:
    """
    Validate data against JSON schema.
    Returns (is_valid, error_message)
    """
    try:
        jsonschema.validate(instance=data, schema=schema)
        return True, None
    except jsonschema.exceptions.ValidationError as e:
        return False, str(e)
    except Exception as e:
        return False, f"Validation error: {str(e)}"


def clean_extraction_result(
    result: Dict,
    schema: Optional[Dict] = None,
    strict: bool = False
) -> Dict:
    """
    Clean and validate a single extraction result.

    Returns the updated result with:
    - extracted_data: replaced with repaired JSON if repair changed it
    - validation_status: 'valid', 'invalid', or 'repaired'
    - validation_errors: list of validation errors, if any
    """
    if result.get('status') != 'success':
        return result  # Skip non-successful results

    extracted_data = result.get('extracted_data')
    if not extracted_data:
        result['validation_status'] = 'invalid'
        result['validation_errors'] = ['No extracted data found']
        if strict:
            result['status'] = 'failed_validation'
        return result

    # Try to repair JSON
    repaired_data, repair_success = repair_json_data(extracted_data)

    # Validate against schema if provided
    validation_errors = []
    if schema:
        is_valid, error_msg = validate_against_schema(repaired_data, schema)
        if not is_valid:
            validation_errors.append(error_msg)
            if strict:
                result['status'] = 'failed_validation'

    # Update result
    if repaired_data != extracted_data and repair_success:
        result['extracted_data'] = repaired_data
        result['validation_status'] = 'repaired'
    elif validation_errors:
        result['validation_status'] = 'invalid'
    else:
        result['validation_status'] = 'valid'

    if validation_errors:
        result['validation_errors'] = validation_errors

    return result


def main():
    args = parse_args()

    # Load inputs
    results = load_results(Path(args.input))
    print(f"Loaded {len(results)} extraction results")

    schema = None
    if args.schema:
        schema = load_schema(Path(args.schema))
        print(f"Loaded validation schema from {args.schema}")

    # Clean each result
    cleaned_results = {}
    stats = {
        'total': len(results),
        'valid': 0,
        'repaired': 0,
        'invalid': 0,
        'failed': 0
    }

    for record_id, result in results.items():
        cleaned_result = clean_extraction_result(result, schema, args.strict)
        cleaned_results[record_id] = cleaned_result

        # Update statistics
        if cleaned_result.get('status') == 'success':
            status = cleaned_result.get('validation_status', 'unknown')
            if status == 'valid':
                stats['valid'] += 1
            elif status == 'repaired':
                stats['repaired'] += 1
            elif status == 'invalid':
                stats['invalid'] += 1
        else:
            stats['failed'] += 1

    # Save cleaned results
    output_path = Path(args.output)
    save_results(cleaned_results, output_path)

    # Print summary
    print(f"\n{'='*60}")
    print("JSON Repair and Validation Summary")
    print(f"{'='*60}")
    print(f"Total records: {stats['total']}")
    print(f"Valid JSON: {stats['valid']}")
    print(f"Repaired JSON: {stats['repaired']}")
    print(f"Invalid JSON: {stats['invalid']}")
    print(f"Failed extractions: {stats['failed']}")

    if schema and stats['total'] > 0:
        validation_rate = (stats['valid'] + stats['repaired']) / stats['total'] * 100
        print(f"\nValidation rate: {validation_rate:.1f}%")

    print(f"\nCleaned results saved to: {output_path}")

    # Print examples of validation errors
    if stats['invalid'] > 0:
        print("\nShowing first 3 validation errors:")
        error_count = 0
        for record_id, result in cleaned_results.items():
            if result.get('validation_errors'):
                print(f"\n{record_id}:")
                for error in result['validation_errors'][:2]:
                    print(f"  - {error[:200]}")
                error_count += 1
                if error_count >= 3:
                    break

    print("\nNext step: Validate and enrich data with external APIs")


if __name__ == '__main__':
    main()
390  skills/extract_from_pdfs/scripts/05_validate_with_apis.py  Normal file
@@ -0,0 +1,390 @@
#!/usr/bin/env python3
"""
Validate and enrich extracted data using external API databases.
Supports common scientific databases for taxonomy, geography, chemistry, etc.

This script template includes examples for common databases. Customize for your needs.
"""

import argparse
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import requests


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Validate and enrich data with external APIs'
    )
    parser.add_argument(
        '--input',
        required=True,
        help='Input JSON file with cleaned extraction results from step 04'
    )
    parser.add_argument(
        '--output',
        default='validated_data.json',
        help='Output JSON file with validated and enriched data'
    )
    parser.add_argument(
        '--apis',
        required=True,
        help='JSON configuration file specifying which APIs to use and for which fields'
    )
    parser.add_argument(
        '--skip-validation',
        action='store_true',
        help='Skip API calls, only load and structure data'
    )
    return parser.parse_args()


def load_results(input_path: Path) -> Dict:
    """Load extraction results from JSON file"""
    with open(input_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_api_config(config_path: Path) -> Dict:
    """Load API configuration"""
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
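
# The --apis file maps extracted field names to validator keys from the
# API_VALIDATORS registry defined below. Its exact shape is whatever the
# orchestration code expects; a plausible minimal example (illustrative
# field names only):
#
# {
#   "species_name": "gbif_taxonomy",
#   "study_site": "geonames",
#   "compound_name": "pubchem"
# }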
|
||||
|
||||
def save_results(results: Dict, output_path: Path):
|
||||
"""Save validated results to JSON file"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Taxonomy validation functions
|
||||
# ==============================================================================
|
||||
|
||||
def validate_gbif_taxonomy(scientific_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Validate taxonomic name using GBIF (Global Biodiversity Information Facility).
|
||||
Returns standardized taxonomy if found.
|
||||
"""
|
||||
url = f"https://api.gbif.org/v1/species/match?name={quote(scientific_name)}"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('matchType') != 'NONE':
|
||||
return {
|
||||
'matched_name': data.get('canonicalName', scientific_name),
|
||||
'scientific_name': data.get('scientificName'),
|
||||
'rank': data.get('rank'),
|
||||
'kingdom': data.get('kingdom'),
|
||||
'phylum': data.get('phylum'),
|
||||
'class': data.get('class'),
|
||||
'order': data.get('order'),
|
||||
'family': data.get('family'),
|
||||
'genus': data.get('genus'),
|
||||
'gbif_id': data.get('usageKey'),
|
||||
'confidence': data.get('confidence'),
|
||||
'match_type': data.get('matchType'),
|
||||
'status': data.get('status')
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"GBIF API error for '{scientific_name}': {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_wfo_plant(scientific_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Validate plant name using World Flora Online.
|
||||
Returns standardized plant taxonomy if found.
|
||||
"""
|
||||
# WFO requires name parsing - this is a simplified example
|
||||
url = f"http://www.worldfloraonline.org/api/1.0/search?query={quote(scientific_name)}"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('results'):
|
||||
first_result = data['results'][0]
|
||||
return {
|
||||
'matched_name': first_result.get('name'),
|
||||
'scientific_name': first_result.get('scientificName'),
|
||||
'authors': first_result.get('authors'),
|
||||
'family': first_result.get('family'),
|
||||
'wfo_id': first_result.get('wfoId'),
|
||||
'status': first_result.get('status')
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"WFO API error for '{scientific_name}': {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Geography validation functions
|
||||
# ==============================================================================
|
||||
|
||||
def validate_geonames(location: str, country: Optional[str] = None) -> Optional[Dict]:
|
||||
"""
|
||||
Validate location using GeoNames.
|
||||
Note: Requires free GeoNames account and username.
|
||||
Set GEONAMES_USERNAME environment variable.
|
||||
"""
|
||||
import os
|
||||
username = os.getenv('GEONAMES_USERNAME')
|
||||
if not username:
|
||||
print("Warning: GEONAMES_USERNAME not set. Skipping GeoNames validation.")
|
||||
return None
|
||||
|
||||
url = f"http://api.geonames.org/searchJSON?q={quote(location)}&maxRows=1&username={username}"
|
||||
if country:
|
||||
url += f"&country={country[:2]}" # Country code
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('geonames'):
|
||||
place = data['geonames'][0]
|
||||
return {
|
||||
'matched_name': place.get('name'),
|
||||
'country': place.get('countryName'),
|
||||
'country_code': place.get('countryCode'),
|
||||
'admin1': place.get('adminName1'),
|
||||
'admin2': place.get('adminName2'),
|
||||
'latitude': place.get('lat'),
|
||||
'longitude': place.get('lng'),
|
||||
'geonames_id': place.get('geonameId')
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"GeoNames API error for '{location}': {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def geocode_location(address: str) -> Optional[Dict]:
    """
    Geocode an address using OpenStreetMap Nominatim (free, no API key needed).
    Please use responsibly - add delays between calls.
    """
    url = f"https://nominatim.openstreetmap.org/search?q={quote(address)}&format=json&limit=1"
    headers = {'User-Agent': 'Scientific-PDF-Extraction/1.0'}

    try:
        time.sleep(1)  # Be nice to OSM
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if data:
                place = data[0]
                return {
                    'display_name': place.get('display_name'),
                    'latitude': place.get('lat'),
                    'longitude': place.get('lon'),
                    'osm_type': place.get('osm_type'),
                    'osm_id': place.get('osm_id'),
                    'place_rank': place.get('place_rank')
                }
    except Exception as e:
        print(f"Nominatim error for '{address}': {e}")

    return None


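# Illustrative usage sketch (assumed input): geocode_location() already sleeps
# 1 second per call to respect Nominatim's usage policy, so loops over many
# addresses need no extra throttling.
def _example_geocode_usage():
    result = geocode_location('Kruger National Park, South Africa')
    if result:
        print(result['latitude'], result['longitude'])

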
# ==============================================================================
# Chemistry validation functions
# ==============================================================================

def validate_pubchem_compound(compound_name: str) -> Optional[Dict]:
    """
    Validate chemical compound using PubChem.
    Returns standardized compound information.
    """
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(compound_name)}/JSON"

    try:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if 'PC_Compounds' in data and data['PC_Compounds']:
                compound = data['PC_Compounds'][0]
                cid = compound['id']['id']['cid']
                # Look up the molecular formula by its property label; the first
                # entry in 'props' is not guaranteed to be the formula
                formula = None
                for prop in compound.get('props', []):
                    if prop.get('urn', {}).get('label') == 'Molecular Formula':
                        formula = prop.get('value', {}).get('sval')
                        break
                return {
                    'cid': cid,
                    'molecular_formula': formula,
                    'pubchem_url': f"https://pubchem.ncbi.nlm.nih.gov/compound/{cid}"
                }
    except Exception as e:
        print(f"PubChem API error for '{compound_name}': {e}")

    return None


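# Illustrative usage sketch (assumed input): PubChem resolves common names as
# well as IUPAC names, so 'aspirin' and 'acetylsalicylic acid' hit the same CID.
def _example_pubchem_usage():
    compound = validate_pubchem_compound('aspirin')
    if compound:
        print(compound['cid'], compound['molecular_formula'])

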
# ==============================================================================
# Gene/Protein validation functions
# ==============================================================================

def validate_ncbi_gene(gene_symbol: str, organism: Optional[str] = None) -> Optional[Dict]:
    """
    Validate gene using NCBI Gene database.
    """
    query = gene_symbol
    if organism:
        query += f"[Gene Name] AND {organism}[Organism]"

    search_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=gene&term={quote(query)}&retmode=json"

    try:
        response = requests.get(search_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if data.get('esearchresult', {}).get('idlist'):
                gene_id = data['esearchresult']['idlist'][0]
                return {
                    'gene_id': gene_id,
                    'ncbi_url': f"https://www.ncbi.nlm.nih.gov/gene/{gene_id}"
                }
    except Exception as e:
        print(f"NCBI Gene API error for '{gene_symbol}': {e}")

    return None


# ==============================================================================
# Main validation orchestration
# ==============================================================================

API_VALIDATORS = {
    'gbif_taxonomy': validate_gbif_taxonomy,
    'wfo_plants': validate_wfo_plant,
    'geonames': validate_geonames,
    'geocode': geocode_location,
    'pubchem': validate_pubchem_compound,
    'ncbi_gene': validate_ncbi_gene
}


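# Illustrative dispatch sketch: all validators share the shape
# (value, **extra_params) -> Optional[Dict], so they can be looked up by name.
# 'BRCA1' and 'human' are assumed inputs, not pipeline defaults.
def _example_dispatch():
    validator = API_VALIDATORS['ncbi_gene']
    print(validator('BRCA1', organism='human'))

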
def validate_field(value: Any, api_name: str, extra_params: Optional[Dict] = None) -> Optional[Dict]:
    """
    Validate a single field value using the specified API.
    """
    # 'not value' already covers the empty string
    if not value or value == 'none':
        return None

    validator = API_VALIDATORS.get(api_name)
    if not validator:
        print(f"Unknown API: {api_name}")
        return None

    try:
        if extra_params:
            return validator(value, **extra_params)
        else:
            return validator(value)
    except Exception as e:
        print(f"Validation error for {api_name} with value '{value}': {e}")
        return None


def process_record(
    record_data: Dict,
    api_config: Dict,
    skip_validation: bool = False
) -> Dict:
    """
    Process a single record, validating specified fields.

    api_config should map field names to API names:
    {
        "field_mappings": {
            "species": {"api": "gbif_taxonomy", "output_field": "validated_species"},
            "location": {"api": "geocode", "output_field": "geocoded_location"}
        }
    }
    """
    if skip_validation:
        return record_data

    field_mappings = api_config.get('field_mappings', {})

    for field_name, field_config in field_mappings.items():
        api_name = field_config.get('api')
        output_field = field_config.get('output_field', f'validated_{field_name}')
        extra_params = field_config.get('extra_params', {})

        # Handle nested fields (e.g., 'records.species')
        if '.' in field_name:
            # This is a simplified example - you'd need to implement proper nested access
            continue

        value = record_data.get(field_name)
        if value:
            validated = validate_field(value, api_name, extra_params)
            if validated:
                record_data[output_field] = validated

    return record_data


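# Illustrative call sketch (assumed field names): one record validated against
# one API; fields without a mapping pass through unchanged.
def _example_process_record():
    config = {
        "field_mappings": {
            "species": {"api": "gbif_taxonomy", "output_field": "validated_species"}
        }
    }
    record = {"species": "Apis mellifera", "site": "left as-is"}
    print(process_record(record, config))

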
def main():
    args = parse_args()

    # Load inputs
    results = load_results(Path(args.input))
    api_config = load_api_config(Path(args.apis))
    print(f"Loaded {len(results)} extraction results")

    # Process each result
    validated_results = {}
    stats = {'total': 0, 'validated': 0, 'failed': 0}

    for record_id, result in results.items():
        if result.get('status') != 'success':
            validated_results[record_id] = result
            stats['failed'] += 1
            continue

        stats['total'] += 1

        # Get extracted data
        extracted_data = result.get('extracted_data', {})

        # Process/validate the data
        validated_data = process_record(
            extracted_data.copy(),
            api_config,
            args.skip_validation
        )

        # Update result
        result['validated_data'] = validated_data
        validated_results[record_id] = result
        stats['validated'] += 1

        # Rate limiting
        if not args.skip_validation:
            time.sleep(0.5)

    # Save results
    output_path = Path(args.output)
    save_results(validated_results, output_path)

    # Print summary
    print(f"\n{'='*60}")
    print("Validation and Enrichment Summary")
    print(f"{'='*60}")
    print(f"Total records: {len(results)}")
    print(f"Successfully validated: {stats['validated']}")
    print(f"Failed extractions: {stats['failed']}")
    print(f"\nResults saved to: {output_path}")
    print("\nNext step: Export to analysis format")


if __name__ == '__main__':
    main()
345
skills/extract_from_pdfs/scripts/06_export_database.py
Normal file
@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Export validated data to various analysis formats.
Supports Python (pandas/SQLite), R (RDS/CSV), Excel, and more.
"""

import argparse
import json
import csv
from pathlib import Path
from typing import Dict, List, Any
import sys


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Export validated data to analysis format'
    )
    parser.add_argument(
        '--input',
        required=True,
        help='Input JSON file with validated data from step 05'
    )
    parser.add_argument(
        '--format',
        choices=['python', 'r', 'csv', 'json', 'excel', 'sqlite'],
        required=True,
        help='Output format'
    )
    parser.add_argument(
        '--output',
        required=True,
        help='Output file path (without extension for some formats)'
    )
    parser.add_argument(
        '--flatten',
        action='store_true',
        help='Flatten nested JSON structures for tabular formats'
    )
    parser.add_argument(
        '--include-metadata',
        action='store_true',
        help='Include original paper metadata in output'
    )
    return parser.parse_args()


def load_results(input_path: Path) -> Dict:
    """Load validated results from JSON file"""
    with open(input_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def flatten_dict(d: Dict, parent_key: str = '', sep: str = '_') -> Dict:
    """
    Flatten nested dictionary structure.
    Useful for converting JSON to tabular format.
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        elif isinstance(v, list):
            # Convert lists to comma-separated strings
            if v and isinstance(v[0], dict):
                # List of dicts - create numbered columns
                for i, item in enumerate(v):
                    items.extend(flatten_dict(item, f"{new_key}_{i}", sep=sep).items())
            else:
                # Simple list
                items.append((new_key, ', '.join(str(x) for x in v)))
        else:
            items.append((new_key, v))
    return dict(items)


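# Worked sketch of flatten_dict (illustrative input/output):
#   {'site': {'name': 'A', 'coords': [1, 2]}}
#   -> {'site_name': 'A', 'site_coords': '1, 2'}
def _example_flatten():
    print(flatten_dict({'site': {'name': 'A', 'coords': [1, 2]}}))

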
def extract_records(results: Dict, flatten: bool = False, include_metadata: bool = False) -> List[Dict]:
    """
    Extract records from results structure.
    Returns a list of dictionaries suitable for tabular export.
    """
    records = []

    for paper_id, result in results.items():
        if result.get('status') != 'success':
            continue

        # Get the validated data (or fall back to extracted data)
        data = result.get('validated_data', result.get('extracted_data', {}))

        if not data:
            continue

        # Check if data contains nested records or is a single record
        if 'records' in data and isinstance(data['records'], list):
            # Multiple records per paper
            for record in data['records']:
                record_dict = record.copy() if isinstance(record, dict) else {'value': record}

                # Add paper-level fields
                if include_metadata:
                    record_dict['paper_id'] = paper_id
                    for key in data:
                        if key != 'records':
                            record_dict[f'paper_{key}'] = data[key]

                if flatten:
                    record_dict = flatten_dict(record_dict)

                records.append(record_dict)
        else:
            # Single record per paper
            record_dict = data.copy()
            if include_metadata:
                record_dict['paper_id'] = paper_id

            if flatten:
                record_dict = flatten_dict(record_dict)

            records.append(record_dict)

    return records


def export_to_csv(records: List[Dict], output_path: Path):
    """Export to CSV format"""
    if not records:
        print("No records to export")
        return

    # Get all possible field names
    fieldnames = set()
    for record in records:
        fieldnames.update(record.keys())
    fieldnames = sorted(fieldnames)

    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(records)

    print(f"Exported {len(records)} records to CSV: {output_path}")


def export_to_json(records: List[Dict], output_path: Path):
    """Export to JSON format"""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, indent=2, ensure_ascii=False)

    print(f"Exported {len(records)} records to JSON: {output_path}")


def export_to_python(records: List[Dict], output_path: Path):
    """Export to Python format (pandas DataFrame pickle)"""
    try:
        import pandas as pd
    except ImportError:
        print("Error: pandas is required for Python export. Install with: pip install pandas")
        sys.exit(1)

    df = pd.DataFrame(records)

    # Save as pickle
    pickle_path = output_path.with_suffix('.pkl')
    df.to_pickle(pickle_path)
    print(f"Exported {len(records)} records to pandas pickle: {pickle_path}")

    # Also create a Python script to load it
    script_path = output_path.with_suffix('.py')
    script_content = f'''#!/usr/bin/env python3
"""
Data loading script
Generated by extract_from_pdfs skill
"""

import pandas as pd

# Load the data
df = pd.read_pickle('{pickle_path.name}')

print(f"Loaded {{len(df)}} records")
print(f"Columns: {{list(df.columns)}}")
print("\\nFirst few rows:")
print(df.head())

# Example analyses:
# df.describe()
# df.groupby('some_column').size()
# df.to_csv('output.csv', index=False)
'''

    with open(script_path, 'w') as f:
        f.write(script_content)

    print(f"Created loading script: {script_path}")


def export_to_r(records: List[Dict], output_path: Path):
    """Export to R format (RDS file)"""
    try:
        import pandas as pd
        import pyreadr
    except ImportError:
        print("Error: pandas and pyreadr are required for R export.")
        print("Install with: pip install pandas pyreadr")
        sys.exit(1)

    df = pd.DataFrame(records)

    # Save as RDS
    rds_path = output_path.with_suffix('.rds')
    pyreadr.write_rds(rds_path, df)
    print(f"Exported {len(records)} records to RDS: {rds_path}")

    # Also create an R script to load it
    script_path = output_path.with_suffix('.R')
    script_content = f'''# Data loading script
# Generated by extract_from_pdfs skill

# Load the data
data <- readRDS('{rds_path.name}')

cat(sprintf("Loaded %d records\\n", nrow(data)))
cat(sprintf("Columns: %s\\n", paste(colnames(data), collapse=", ")))
cat("\\nFirst few rows:\\n")
print(head(data))

# Example analyses:
# summary(data)
# table(data$some_column)
# write.csv(data, 'output.csv', row.names=FALSE)
'''

    with open(script_path, 'w') as f:
        f.write(script_content)

    print(f"Created loading script: {script_path}")


def export_to_excel(records: List[Dict], output_path: Path):
    """Export to Excel format"""
    try:
        import pandas as pd
    except ImportError:
        print("Error: pandas is required for Excel export. Install with: pip install pandas openpyxl")
        sys.exit(1)

    df = pd.DataFrame(records)

    # Save as Excel
    excel_path = output_path.with_suffix('.xlsx')
    df.to_excel(excel_path, index=False, engine='openpyxl')
    print(f"Exported {len(records)} records to Excel: {excel_path}")


def export_to_sqlite(records: List[Dict], output_path: Path):
    """Export to SQLite database"""
    try:
        import pandas as pd
        import sqlite3
    except ImportError:
        print("Error: pandas is required for SQLite export. Install with: pip install pandas")
        sys.exit(1)

    df = pd.DataFrame(records)

    # Create database
    db_path = output_path.with_suffix('.db')
    conn = sqlite3.connect(db_path)

    # Write to database
    table_name = 'extracted_data'
    df.to_sql(table_name, conn, if_exists='replace', index=False)

    conn.close()
    print(f"Exported {len(records)} records to SQLite database: {db_path}")
    print(f"Table name: {table_name}")

    # Create SQL script with example queries
    sql_script_path = output_path.with_suffix('.sql')
    sql_content = f'''-- Example SQL queries for {db_path.name}
-- Generated by extract_from_pdfs skill

-- View all records
SELECT * FROM {table_name} LIMIT 10;

-- Count total records
SELECT COUNT(*) as total_records FROM {table_name};

-- Example: Group by a column (adjust column name as needed)
-- SELECT column_name, COUNT(*) as count
-- FROM {table_name}
-- GROUP BY column_name
-- ORDER BY count DESC;
'''

    with open(sql_script_path, 'w') as f:
        f.write(sql_content)

    print(f"Created SQL example script: {sql_script_path}")


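# Illustrative follow-up sketch: querying the exported database from Python.
# 'output.db' is an assumed path; the table name 'extracted_data' matches the
# one written by export_to_sqlite() above.
def _example_query_sqlite(db_path='output.db'):
    import sqlite3
    conn = sqlite3.connect(db_path)
    for row in conn.execute('SELECT COUNT(*) FROM extracted_data'):
        print(row[0])
    conn.close()

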
def main():
    args = parse_args()

    # Load validated results
    results = load_results(Path(args.input))
    print(f"Loaded {len(results)} results")

    # Extract records
    records = extract_records(
        results,
        flatten=args.flatten,
        include_metadata=args.include_metadata
    )
    print(f"Extracted {len(records)} records")

    if not records:
        print("No records to export. Check your data.")
        return

    # Export based on format
    output_path = Path(args.output)

    if args.format == 'csv':
        export_to_csv(records, output_path)
    elif args.format == 'json':
        export_to_json(records, output_path)
    elif args.format == 'python':
        export_to_python(records, output_path)
    elif args.format == 'r':
        export_to_r(records, output_path)
    elif args.format == 'excel':
        export_to_excel(records, output_path)
    elif args.format == 'sqlite':
        export_to_sqlite(records, output_path)

    print("\nExport complete!")
    print(f"Your data is ready for analysis in {args.format.upper()} format.")


if __name__ == '__main__':
    main()
280
skills/extract_from_pdfs/scripts/07_prepare_validation_set.py
Normal file
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
"""
Prepare a validation set for evaluating extraction quality.

This script helps you:
1. Sample a subset of papers for manual annotation
2. Set up a structured annotation file
3. Guide the annotation process

The validation set is used to calculate precision and recall metrics.
"""

import argparse
import json
import random
from pathlib import Path
from typing import Dict, List, Any
import sys


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Prepare validation set for extraction quality evaluation',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Workflow:
  1. This script samples papers from your extraction results
  2. It creates an annotation template based on your schema
  3. You manually annotate the sampled papers with ground truth
  4. Use 08_calculate_validation_metrics.py to compare automated vs. manual extraction

Sampling strategies:
  random    : Random sample (good for overall quality)
  stratified: Sample by extraction characteristics (good for identifying weaknesses)
  diverse   : Sample to maximize diversity (good for comprehensive evaluation)
"""
    )
    parser.add_argument(
        '--extraction-results',
        required=True,
        help='JSON file with extraction results from step 03 or 04'
    )
    parser.add_argument(
        '--schema',
        required=True,
        help='Extraction schema JSON file used in step 03'
    )
    parser.add_argument(
        '--output',
        default='validation_set.json',
        help='Output file for validation annotations'
    )
    parser.add_argument(
        '--sample-size',
        type=int,
        default=20,
        help='Number of papers to sample (default: 20, recommended: 20-50)'
    )
    parser.add_argument(
        '--strategy',
        choices=['random', 'stratified', 'diverse'],
        default='random',
        help='Sampling strategy (default: random)'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    return parser.parse_args()


def load_results(results_path: Path) -> Dict:
    """Load extraction results"""
    with open(results_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_schema(schema_path: Path) -> Dict:
    """Load extraction schema"""
    with open(schema_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def sample_random(results: Dict, sample_size: int, seed: int) -> List[str]:
    """Random sampling strategy"""
    # Only sample from successful extractions
    successful = [
        paper_id for paper_id, result in results.items()
        if result.get('status') == 'success' and result.get('extracted_data')
    ]

    if len(successful) < sample_size:
        print(f"Warning: Only {len(successful)} successful extractions available")
        sample_size = len(successful)

    random.seed(seed)
    return random.sample(successful, sample_size)


def sample_stratified(results: Dict, sample_size: int, seed: int) -> List[str]:
    """
    Stratified sampling: sample papers with different characteristics
    E.g., papers with many records vs. few records, different data completeness
    """
    successful = {}
    for paper_id, result in results.items():
        if result.get('status') == 'success' and result.get('extracted_data'):
            data = result['extracted_data']
            # Count records if present
            num_records = len(data.get('records', [])) if 'records' in data else 0
            successful[paper_id] = num_records

    if not successful:
        print("No successful extractions found")
        return []

    # Create strata based on number of records
    strata = {
        'zero': [],
        'few': [],     # 1-2 records
        'medium': [],  # 3-5 records
        'many': []     # 6+ records
    }

    for paper_id, count in successful.items():
        if count == 0:
            strata['zero'].append(paper_id)
        elif count <= 2:
            strata['few'].append(paper_id)
        elif count <= 5:
            strata['medium'].append(paper_id)
        else:
            strata['many'].append(paper_id)

    # Sample proportionally from each stratum
    random.seed(seed)
    sampled = []
    total_papers = len(successful)

    for stratum_name, papers in strata.items():
        if not papers:
            continue
        # Sample proportionally, at least 1 from each non-empty stratum
        stratum_sample_size = max(1, int(len(papers) / total_papers * sample_size))
        stratum_sample_size = min(stratum_sample_size, len(papers))
        sampled.extend(random.sample(papers, stratum_sample_size))

    # If we haven't reached sample_size, add more randomly
    if len(sampled) < sample_size:
        remaining = [p for p in successful.keys() if p not in sampled]
        additional = min(sample_size - len(sampled), len(remaining))
        sampled.extend(random.sample(remaining, additional))

    return sampled[:sample_size]


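# Worked allocation sketch (assumed counts, illustrative only): with 100
# successful papers split 10/40/30/20 across the zero/few/medium/many strata
# and sample_size=20, the proportional draw above yields
# max(1, 10/100*20)=2, then 8, 6, and 4 papers respectively - 20 in total.

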
def sample_diverse(results: Dict, sample_size: int, seed: int) -> List[str]:
    """
    Diverse sampling: maximize diversity in sampled papers
    This is a simplified version - could be enhanced with actual diversity metrics
    """
    # For now, use stratified sampling as a proxy for diversity
    return sample_stratified(results, sample_size, seed)


def create_annotation_template(
    sampled_ids: List[str],
    results: Dict,
    schema: Dict
) -> Dict:
    """
    Create annotation template for manual validation.

    Structure:
    {
        "paper_id": {
            "automated_extraction": {...},
            "ground_truth": null,  # To be filled manually
            "notes": "",
            "annotator": "",
            "annotation_date": ""
        }
    }
    """
    template = {
        "_instructions": {
            "overview": "This is a validation annotation file. For each paper, review the PDF and fill in the ground_truth field with the correct extraction.",
            "steps": [
                "1. Read the PDF for each paper_id",
                "2. Extract data according to the schema, filling the 'ground_truth' field",
                "3. The 'ground_truth' should have the same structure as 'automated_extraction'",
                "4. Add your name in 'annotator' and date in 'annotation_date'",
                "5. Use 'notes' field for any comments or ambiguities",
                "6. Once complete, use 08_calculate_validation_metrics.py to compare"
            ],
            "schema_reference": schema.get('output_schema', {}),
            "tips": [
                "Be thorough: extract ALL relevant information, even if automated extraction missed it",
                "Be precise: use exact values as they appear in the paper",
                "Be consistent: follow the same schema structure",
                "Mark ambiguous cases in notes field"
            ]
        },
        "validation_papers": {}
    }

    for paper_id in sampled_ids:
        result = results[paper_id]
        template["validation_papers"][paper_id] = {
            "automated_extraction": result.get('extracted_data', {}),
            "ground_truth": None,  # To be filled by annotator
            "notes": "",
            "annotator": "",
            "annotation_date": "",
            "_pdf_path": None,  # Will try to find from metadata
            "_extraction_metadata": {
                "extraction_status": result.get('status'),
                "validation_status": result.get('validation_status'),
                "has_analysis": bool(result.get('analysis'))
            }
        }

    return template


def save_template(template: Dict, output_path: Path):
    """Save annotation template to JSON file"""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(template, f, indent=2, ensure_ascii=False)


def main():
    args = parse_args()

    # Load inputs
    results = load_results(Path(args.extraction_results))
    schema = load_schema(Path(args.schema))
    print(f"Loaded {len(results)} extraction results")

    # Sample papers
    if args.strategy == 'random':
        sampled = sample_random(results, args.sample_size, args.seed)
    elif args.strategy == 'stratified':
        sampled = sample_stratified(results, args.sample_size, args.seed)
    elif args.strategy == 'diverse':
        sampled = sample_diverse(results, args.sample_size, args.seed)

    print(f"Sampled {len(sampled)} papers using '{args.strategy}' strategy")

    # Create annotation template
    template = create_annotation_template(sampled, results, schema)

    # Save template
    output_path = Path(args.output)
    save_template(template, output_path)

    print(f"\n{'='*60}")
    print("Validation Set Preparation Complete")
    print(f"{'='*60}")
    print(f"Annotation file created: {output_path}")
    print(f"Papers to annotate: {len(sampled)}")
    print("\nNext steps:")
    print(f"1. Open {output_path} in a text editor")
    print("2. For each paper, read the PDF and fill in the 'ground_truth' field")
    print("3. Follow the schema structure shown in '_instructions'")
    print("4. Save your annotations")
    print(f"5. Run: python 08_calculate_validation_metrics.py --annotations {output_path}")
    print("\nTips for efficient annotation:")
    print("- Work in batches of 5-10 papers")
    print("- Use the automated extraction as a starting point to check")
    print("- Document any ambiguous cases in the notes field")
    print("- Consider having 2+ annotators for inter-rater reliability")


if __name__ == '__main__':
    main()
513
skills/extract_from_pdfs/scripts/08_calculate_validation_metrics.py
Normal file
@@ -0,0 +1,513 @@
#!/usr/bin/env python3
"""
Calculate validation metrics (precision, recall, F1) for extraction quality.

Compares automated extraction against ground truth annotations to evaluate:
- Field-level precision and recall
- Record-level accuracy
- Overall extraction quality

Handles different data types appropriately:
- Boolean: exact match
- Numeric: exact match or tolerance
- String: exact match or fuzzy matching
- Lists: set-based precision/recall
- Nested objects: recursive comparison
"""

import argparse
import json
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional
from collections import defaultdict
import sys


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Calculate validation metrics for extraction quality',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Metrics calculated:
  Precision : Of extracted items, how many are correct?
  Recall    : Of true items, how many were extracted?
  F1 Score  : Harmonic mean of precision and recall
  Accuracy  : Overall correctness (for boolean/categorical fields)

Field type handling:
  Boolean/Categorical : Exact match
  Numeric             : Exact match or within tolerance
  String              : Exact match or fuzzy (normalized)
  Lists               : Set-based precision/recall
  Nested objects      : Recursive field-by-field comparison

Output:
  - Overall metrics
  - Per-field metrics
  - Per-paper detailed comparison
  - Common error patterns
"""
    )
    parser.add_argument(
        '--annotations',
        required=True,
        help='Annotation file from 07_prepare_validation_set.py (with ground truth filled in)'
    )
    parser.add_argument(
        '--output',
        default='validation_metrics.json',
        help='Output file for detailed metrics'
    )
    parser.add_argument(
        '--report',
        default='validation_report.txt',
        help='Human-readable validation report'
    )
    parser.add_argument(
        '--numeric-tolerance',
        type=float,
        default=0.0,
        help='Tolerance for numeric comparisons (default: 0.0 for exact match)'
    )
    parser.add_argument(
        '--fuzzy-strings',
        action='store_true',
        help='Use fuzzy string matching (normalize whitespace, case)'
    )
    parser.add_argument(
        '--list-order-matters',
        action='store_true',
        help='Consider order in list comparisons (default: treat as sets)'
    )
    return parser.parse_args()


def load_annotations(annotations_path: Path) -> Dict:
    """Load annotations file"""
    with open(annotations_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def normalize_string(s: str, fuzzy: bool = False) -> str:
    """Normalize string for comparison"""
    if not isinstance(s, str):
        return str(s)
    if fuzzy:
        return ' '.join(s.lower().split())
    return s


def compare_boolean(automated: Any, truth: Any) -> Dict[str, int]:
    """Compare boolean values"""
    if automated and truth:
        return {'tp': 1, 'fp': 0, 'fn': 0, 'tn': 0}
    elif automated and not truth:
        return {'tp': 0, 'fp': 1, 'fn': 0, 'tn': 0}
    elif not automated and truth:
        return {'tp': 0, 'fp': 0, 'fn': 1, 'tn': 0}
    else:
        # Both False: a correct negative, not a true positive
        return {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 1}


def compare_numeric(automated: Any, truth: Any, tolerance: float = 0.0) -> bool:
    """Compare numeric values with optional tolerance"""
    try:
        a = float(automated) if automated is not None else None
        t = float(truth) if truth is not None else None

        if a is None and t is None:
            return True
        if a is None or t is None:
            return False

        if tolerance > 0:
            return abs(a - t) <= tolerance
        else:
            return a == t
    except (ValueError, TypeError):
        return automated == truth


def compare_string(automated: Any, truth: Any, fuzzy: bool = False) -> bool:
    """Compare string values"""
    if automated is None and truth is None:
        return True
    if automated is None or truth is None:
        return False

    a = normalize_string(automated, fuzzy)
    t = normalize_string(truth, fuzzy)
    return a == t


def compare_list(
    automated: List,
    truth: List,
    order_matters: bool = False,
    fuzzy: bool = False
) -> Dict[str, int]:
    """
    Compare lists and calculate precision/recall.

    Returns counts of true positives, false positives, and false negatives.
    """
    if automated is None:
        automated = []
    if truth is None:
        truth = []

    if not isinstance(automated, list):
        automated = [automated]
    if not isinstance(truth, list):
        truth = [truth]

    if order_matters:
        # Ordered comparison
        tp = sum(1 for a, t in zip(automated, truth) if compare_string(a, t, fuzzy))
        fp = max(0, len(automated) - len(truth))
        fn = max(0, len(truth) - len(automated))
    else:
        # Set-based comparison. Dicts and lists are unhashable, so serialize
        # them to a canonical JSON string first; this also covers the record
        # arrays passed in from evaluate_paper().
        def _norm(x):
            if isinstance(x, (dict, list)):
                return json.dumps(x, sort_keys=True)
            return normalize_string(x, fuzzy) if fuzzy else x

        auto_set = {_norm(x) for x in automated}
        truth_set = {_norm(x) for x in truth}

        tp = len(auto_set & truth_set)   # Intersection
        fp = len(auto_set - truth_set)   # In automated but not in truth
        fn = len(truth_set - auto_set)   # In truth but not in automated

    return {'tp': tp, 'fp': fp, 'fn': fn}


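# Worked sketch of the set-based branch (illustrative values):
#   automated=['a', 'b', 'c'], truth=['b', 'c', 'd']
#   -> tp=2 ('b', 'c'), fp=1 ('a'), fn=1 ('d')
def _example_compare_list():
    print(compare_list(['a', 'b', 'c'], ['b', 'c', 'd']))

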
def calculate_metrics(tp: int, fp: int, fn: int) -> Dict[str, float]:
    """Calculate precision, recall, and F1 from counts"""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'fn': fn
    }


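# Worked sketch (illustrative counts): tp=8, fp=2, fn=4 gives
# precision 8/10 = 0.80, recall 8/12 = 0.67, and F1 = 2*0.80*0.67/1.47 = 0.73.
def _example_calculate_metrics():
    print(calculate_metrics(tp=8, fp=2, fn=4))

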
def compare_field(
    automated: Any,
    truth: Any,
    field_name: str,
    config: Dict
) -> Dict[str, Any]:
    """
    Compare a single field between automated and ground truth.

    Returns metrics appropriate for the field type.
    """
    # Determine field type
    if isinstance(truth, bool):
        return compare_boolean(automated, truth)
    elif isinstance(truth, (int, float)):
        match = compare_numeric(automated, truth, config['numeric_tolerance'])
        return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
    elif isinstance(truth, str):
        match = compare_string(automated, truth, config['fuzzy_strings'])
        return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
    elif isinstance(truth, list):
        return compare_list(automated, truth, config['list_order_matters'], config['fuzzy_strings'])
    elif isinstance(truth, dict):
        # Recursive comparison for nested objects
        return compare_nested(automated or {}, truth, config)
    elif truth is None:
        # Field should be empty/null
        if automated is None or automated == "" or automated == []:
            return {'tp': 1, 'fp': 0, 'fn': 0}
        else:
            return {'tp': 0, 'fp': 1, 'fn': 0}
    else:
        # Fallback to exact match
        match = automated == truth
        return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}


def compare_nested(automated: Dict, truth: Dict, config: Dict) -> Dict[str, int]:
    """Recursively compare nested objects"""
    total_counts = {'tp': 0, 'fp': 0, 'fn': 0}

    all_fields = set(automated.keys()) | set(truth.keys())

    for field in all_fields:
        auto_val = automated.get(field)
        truth_val = truth.get(field)

        field_counts = compare_field(auto_val, truth_val, field, config)

        for key in ['tp', 'fp', 'fn']:
            total_counts[key] += field_counts.get(key, 0)

    return total_counts


def evaluate_paper(
    paper_id: str,
    automated: Dict,
    truth: Dict,
    config: Dict
) -> Dict[str, Any]:
    """
    Evaluate extraction for a single paper.

    Returns field-level and overall metrics.
    """
    if truth is None:
        return {
            'status': 'not_annotated',
            'message': 'Ground truth not provided'
        }

    field_metrics = {}
    all_fields = set(automated.keys()) | set(truth.keys())

    for field in all_fields:
        if field == 'records':
            # Special handling for records arrays
            auto_records = automated.get('records', [])
            truth_records = truth.get('records', [])

            # Overall record count comparison
            record_counts = compare_list(auto_records, truth_records, order_matters=False)

            # Detailed record-level comparison
            record_details = []
            for i, (auto_rec, truth_rec) in enumerate(zip(auto_records, truth_records)):
                rec_comparison = compare_nested(auto_rec, truth_rec, config)
                record_details.append({
                    'record_index': i,
                    'metrics': calculate_metrics(**rec_comparison)
                })

            field_metrics['records'] = {
                'count_metrics': calculate_metrics(**record_counts),
                'record_details': record_details
            }
        else:
            auto_val = automated.get(field)
            truth_val = truth.get(field)
            counts = compare_field(auto_val, truth_val, field, config)
            # compare_boolean() also reports a 'tn' count, which precision and
            # recall do not use, so pass only the counts calculate_metrics() accepts
            field_metrics[field] = calculate_metrics(
                counts.get('tp', 0), counts.get('fp', 0), counts.get('fn', 0)
            )

    # Calculate overall metrics
    total_tp = sum(
        m.get('tp', 0) if isinstance(m, dict) and 'tp' in m
        else m.get('count_metrics', {}).get('tp', 0)
        for m in field_metrics.values()
    )
    total_fp = sum(
        m.get('fp', 0) if isinstance(m, dict) and 'fp' in m
        else m.get('count_metrics', {}).get('fp', 0)
        for m in field_metrics.values()
    )
    total_fn = sum(
        m.get('fn', 0) if isinstance(m, dict) and 'fn' in m
        else m.get('count_metrics', {}).get('fn', 0)
        for m in field_metrics.values()
    )

    overall = calculate_metrics(total_tp, total_fp, total_fn)

    return {
        'status': 'evaluated',
        'field_metrics': field_metrics,
        'overall': overall
    }


def aggregate_metrics(paper_evaluations: Dict[str, Dict]) -> Dict[str, Any]:
    """Aggregate metrics across all papers"""
    # Collect field-level metrics
    field_aggregates = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})

    evaluated_papers = [
        p for p in paper_evaluations.values()
        if p.get('status') == 'evaluated'
    ]

    for paper_eval in evaluated_papers:
        for field, metrics in paper_eval.get('field_metrics', {}).items():
            if isinstance(metrics, dict):
                if 'tp' in metrics:
                    # Simple field
                    field_aggregates[field]['tp'] += metrics['tp']
                    field_aggregates[field]['fp'] += metrics['fp']
                    field_aggregates[field]['fn'] += metrics['fn']
                elif 'count_metrics' in metrics:
                    # Records field
                    field_aggregates[field]['tp'] += metrics['count_metrics']['tp']
                    field_aggregates[field]['fp'] += metrics['count_metrics']['fp']
                    field_aggregates[field]['fn'] += metrics['count_metrics']['fn']

    # Calculate metrics for each field
    field_metrics = {}
    for field, counts in field_aggregates.items():
        field_metrics[field] = calculate_metrics(**counts)

    # Overall aggregated metrics
    total_tp = sum(counts['tp'] for counts in field_aggregates.values())
    total_fp = sum(counts['fp'] for counts in field_aggregates.values())
    total_fn = sum(counts['fn'] for counts in field_aggregates.values())

    overall = calculate_metrics(total_tp, total_fp, total_fn)

    return {
        'overall': overall,
        'by_field': field_metrics,
        'num_papers_evaluated': len(evaluated_papers)
    }


def generate_report(
    paper_evaluations: Dict[str, Dict],
    aggregated: Dict,
    output_path: Path
):
    """Generate human-readable validation report"""
    lines = []
    lines.append("="*80)
    lines.append("EXTRACTION VALIDATION REPORT")
    lines.append("="*80)
    lines.append("")

    # Overall summary
    lines.append("OVERALL METRICS")
    lines.append("-"*80)
    overall = aggregated['overall']
    lines.append(f"Papers evaluated: {aggregated['num_papers_evaluated']}")
    lines.append(f"Precision: {overall['precision']:.2%}")
    lines.append(f"Recall: {overall['recall']:.2%}")
    lines.append(f"F1 Score: {overall['f1']:.2%}")
    lines.append(f"True Positives: {overall['tp']}")
    lines.append(f"False Positives: {overall['fp']}")
    lines.append(f"False Negatives: {overall['fn']}")
    lines.append("")

    # Per-field metrics
    lines.append("METRICS BY FIELD")
    lines.append("-"*80)
    lines.append(f"{'Field':<30} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    lines.append("-"*80)

    for field, metrics in sorted(aggregated['by_field'].items()):
        lines.append(
            f"{field:<30} "
            f"{metrics['precision']:>9.1%} "
            f"{metrics['recall']:>9.1%} "
            f"{metrics['f1']:>9.1%}"
        )
    lines.append("")

    # Top errors
    lines.append("COMMON ISSUES")
    lines.append("-"*80)

    # Fields with low recall (missed information)
    low_recall = [
        (field, metrics) for field, metrics in aggregated['by_field'].items()
        if metrics['recall'] < 0.7 and metrics['fn'] > 0
    ]
    if low_recall:
        lines.append("\nFields with low recall (missed information):")
        for field, metrics in sorted(low_recall, key=lambda x: x[1]['recall']):
            lines.append(f"  - {field}: {metrics['recall']:.1%} recall, {metrics['fn']} missed items")

    # Fields with low precision (incorrect extractions)
    low_precision = [
        (field, metrics) for field, metrics in aggregated['by_field'].items()
        if metrics['precision'] < 0.7 and metrics['fp'] > 0
    ]
    if low_precision:
        lines.append("\nFields with low precision (incorrect extractions):")
        for field, metrics in sorted(low_precision, key=lambda x: x[1]['precision']):
            lines.append(f"  - {field}: {metrics['precision']:.1%} precision, {metrics['fp']} incorrect items")

    lines.append("")
    lines.append("="*80)

    # Write report
    report_text = "\n".join(lines)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(report_text)

    # Also print to console
    print(report_text)


def main():
    args = parse_args()

    # Load annotations
    annotations = load_annotations(Path(args.annotations))
    validation_papers = annotations.get('validation_papers', {})

    print(f"Loaded {len(validation_papers)} validation papers")

    # Check how many have ground truth
    annotated = sum(1 for p in validation_papers.values() if p.get('ground_truth') is not None)
    print(f"Papers with ground truth: {annotated}")

    if annotated == 0:
        print("\nError: No ground truth annotations found!")
        print("Please fill in the 'ground_truth' field for each paper in the annotation file.")
        sys.exit(1)

    # Configuration for comparisons
    config = {
        'numeric_tolerance': args.numeric_tolerance,
        'fuzzy_strings': args.fuzzy_strings,
        'list_order_matters': args.list_order_matters
    }

    # Evaluate each paper
    paper_evaluations = {}
    for paper_id, paper_data in validation_papers.items():
        automated = paper_data.get('automated_extraction', {})
        truth = paper_data.get('ground_truth')

        evaluation = evaluate_paper(paper_id, automated, truth, config)
        paper_evaluations[paper_id] = evaluation

        if evaluation['status'] == 'evaluated':
            overall = evaluation['overall']
            print(f"{paper_id}: P={overall['precision']:.2%} R={overall['recall']:.2%} F1={overall['f1']:.2%}")

    # Aggregate metrics
    aggregated = aggregate_metrics(paper_evaluations)

    # Save detailed metrics
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    detailed_output = {
        'summary': aggregated,
        'by_paper': paper_evaluations,
        'config': config
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(detailed_output, f, indent=2, ensure_ascii=False)

    print(f"\nDetailed metrics saved to: {output_path}")

    # Generate report
    report_path = Path(args.report)
    generate_report(paper_evaluations, aggregated, report_path)
    print(f"Validation report saved to: {report_path}")


if __name__ == '__main__':
    main()