#!/usr/bin/env python3
"""
Calculate validation metrics (precision, recall, F1) for extraction quality.
Compares automated extraction against ground truth annotations to evaluate:
- Field-level precision and recall
- Record-level accuracy
- Overall extraction quality
Handles different data types appropriately:
- Boolean: exact match
- Numeric: exact match or tolerance
- String: exact match or fuzzy matching
- Lists: set-based precision/recall
- Nested objects: recursive comparison
"""
import argparse
import json
from pathlib import Path
from typing import Dict, List, Any
from collections import defaultdict
import sys
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Calculate validation metrics for extraction quality',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Metrics calculated:
Precision : Of extracted items, how many are correct?
Recall : Of true items, how many were extracted?
F1 Score : Harmonic mean of precision and recall
Accuracy : Overall correctness (for boolean/categorical fields)
Field type handling:
Boolean/Categorical : Exact match
Numeric : Exact match or within tolerance
String : Exact match or fuzzy (normalized)
Lists : Set-based precision/recall
Nested objects : Recursive field-by-field comparison
Output:
- Overall metrics
- Per-field metrics
- Per-paper detailed comparison
- Common error patterns
"""
)
parser.add_argument(
'--annotations',
required=True,
help='Annotation file from 07_prepare_validation_set.py (with ground truth filled in)'
)
parser.add_argument(
'--output',
default='validation_metrics.json',
help='Output file for detailed metrics'
)
parser.add_argument(
'--report',
default='validation_report.txt',
help='Human-readable validation report'
)
parser.add_argument(
'--numeric-tolerance',
type=float,
default=0.0,
help='Tolerance for numeric comparisons (default: 0.0 for exact match)'
)
parser.add_argument(
'--fuzzy-strings',
action='store_true',
help='Use fuzzy string matching (normalize whitespace, case)'
)
parser.add_argument(
'--list-order-matters',
action='store_true',
help='Consider order in list comparisons (default: treat as sets)'
)
return parser.parse_args()
def load_annotations(annotations_path: Path) -> Dict:
"""Load annotations file"""
with open(annotations_path, 'r', encoding='utf-8') as f:
return json.load(f)
def normalize_string(s: str, fuzzy: bool = False) -> str:
"""Normalize string for comparison"""
if not isinstance(s, str):
return str(s)
if fuzzy:
return ' '.join(s.lower().split())
return s
def compare_boolean(automated: Any, truth: Any) -> Dict[str, int]:
"""Compare boolean values"""
if automated == truth:
return {'tp': 1, 'fp': 0, 'fn': 0, 'tn': 0}
elif automated and not truth:
return {'tp': 0, 'fp': 1, 'fn': 0, 'tn': 0}
elif not automated and truth:
return {'tp': 0, 'fp': 0, 'fn': 1, 'tn': 0}
else:
return {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 1}
def compare_numeric(automated: Any, truth: Any, tolerance: float = 0.0) -> bool:
"""Compare numeric values with optional tolerance"""
try:
a = float(automated) if automated is not None else None
t = float(truth) if truth is not None else None
if a is None and t is None:
return True
if a is None or t is None:
return False
if tolerance > 0:
return abs(a - t) <= tolerance
else:
return a == t
except (ValueError, TypeError):
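        # Values that cannot be cast to float (e.g. "unknown") fall back to raw equality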
return automated == truth
def compare_string(automated: Any, truth: Any, fuzzy: bool = False) -> bool:
"""Compare string values"""
if automated is None and truth is None:
return True
if automated is None or truth is None:
return False
a = normalize_string(automated, fuzzy)
t = normalize_string(truth, fuzzy)
return a == t
def compare_list(
automated: List,
truth: List,
order_matters: bool = False,
fuzzy: bool = False
) -> Dict[str, int]:
"""
Compare lists and calculate precision/recall.
Returns counts of true positives, false positives, and false negatives.
"""
if automated is None:
automated = []
if truth is None:
truth = []
if not isinstance(automated, list):
automated = [automated]
if not isinstance(truth, list):
truth = [truth]
    if order_matters:
        # Ordered comparison: items are matched by position; a positional mismatch
        # counts as both a false positive and a false negative
        tp = sum(1 for a, t in zip(automated, truth) if compare_string(a, t, fuzzy))
        fp = len(automated) - tp
        fn = len(truth) - tp
else:
# Set-based comparison
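        # Assumes list elements are hashable scalars (strings/numbers); lists of
        # objects go through the 'records' branch in evaluate_paper() instead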
if fuzzy:
auto_set = {normalize_string(x, fuzzy) for x in automated}
truth_set = {normalize_string(x, fuzzy) for x in truth}
else:
auto_set = set(automated)
truth_set = set(truth)
tp = len(auto_set & truth_set) # Intersection
fp = len(auto_set - truth_set) # In automated but not in truth
fn = len(truth_set - auto_set) # In truth but not in automated
return {'tp': tp, 'fp': fp, 'fn': fn}
def calculate_metrics(tp: int, fp: int, fn: int, tn: int = 0) -> Dict[str, float]:
    """Calculate precision, recall, F1 (and accuracy, where true negatives are tracked) from counts"""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    # Accept tn so boolean comparisons (which report true negatives) can be unpacked
    # directly; accuracy is only meaningful for those fields, elsewhere tn stays 0
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy,
        'tp': tp,
        'fp': fp,
        'fn': fn
    }
def compare_field(
automated: Any,
truth: Any,
field_name: str,
config: Dict
) -> Dict[str, Any]:
"""
Compare a single field between automated and ground truth.
Returns metrics appropriate for the field type.
"""
# Determine field type
if isinstance(truth, bool):
return compare_boolean(automated, truth)
elif isinstance(truth, (int, float)):
match = compare_numeric(automated, truth, config['numeric_tolerance'])
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
elif isinstance(truth, str):
match = compare_string(automated, truth, config['fuzzy_strings'])
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
elif isinstance(truth, list):
return compare_list(automated, truth, config['list_order_matters'], config['fuzzy_strings'])
    elif isinstance(truth, dict):
        # Recursive comparison for nested objects; a non-dict automated value is
        # treated as an empty object so its expected sub-fields are scored as misses
        auto_dict = automated if isinstance(automated, dict) else {}
        return compare_nested(auto_dict, truth, config)
elif truth is None:
# Field should be empty/null
if automated is None or automated == "" or automated == []:
return {'tp': 1, 'fp': 0, 'fn': 0}
else:
return {'tp': 0, 'fp': 1, 'fn': 0}
else:
# Fallback to exact match
match = automated == truth
return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1}
def compare_nested(automated: Dict, truth: Dict, config: Dict) -> Dict[str, int]:
"""Recursively compare nested objects"""
total_counts = {'tp': 0, 'fp': 0, 'fn': 0}
all_fields = set(automated.keys()) | set(truth.keys())
for field in all_fields:
auto_val = automated.get(field)
truth_val = truth.get(field)
field_counts = compare_field(auto_val, truth_val, field, config)
for key in ['tp', 'fp', 'fn']:
total_counts[key] += field_counts.get(key, 0)
return total_counts
def evaluate_paper(
paper_id: str,
automated: Dict,
truth: Dict,
config: Dict
) -> Dict[str, Any]:
"""
Evaluate extraction for a single paper.
Returns field-level and overall metrics.
"""
if truth is None:
return {
'status': 'not_annotated',
'message': 'Ground truth not provided'
}
field_metrics = {}
all_fields = set(automated.keys()) | set(truth.keys())
for field in all_fields:
if field == 'records':
# Special handling for records arrays
auto_records = automated.get('records', [])
truth_records = truth.get('records', [])
            # Overall record count comparison (records are dicts, so only their
            # counts are compared here; field content is scored per record below)
            n_auto, n_truth = len(auto_records), len(truth_records)
            record_counts = {'tp': min(n_auto, n_truth),
                             'fp': max(0, n_auto - n_truth),
                             'fn': max(0, n_truth - n_auto)}
            # Detailed record-level comparison; records are paired by position, so
            # the automated and ground-truth lists are assumed to share an ordering
            record_details = []
            for i, (auto_rec, truth_rec) in enumerate(zip(auto_records, truth_records)):
rec_comparison = compare_nested(auto_rec, truth_rec, config)
record_details.append({
'record_index': i,
'metrics': calculate_metrics(**rec_comparison)
})
field_metrics['records'] = {
'count_metrics': calculate_metrics(**record_counts),
'record_details': record_details
}
else:
auto_val = automated.get(field)
truth_val = truth.get(field)
counts = compare_field(auto_val, truth_val, field, config)
field_metrics[field] = calculate_metrics(**counts)
# Calculate overall metrics
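    # Only the record-count tp/fp/fn feed the paper-level totals; per-record field
    # scores are reported separately under record_details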
total_tp = sum(
m.get('tp', 0) if isinstance(m, dict) and 'tp' in m
else m.get('count_metrics', {}).get('tp', 0)
for m in field_metrics.values()
)
total_fp = sum(
m.get('fp', 0) if isinstance(m, dict) and 'fp' in m
else m.get('count_metrics', {}).get('fp', 0)
for m in field_metrics.values()
)
total_fn = sum(
m.get('fn', 0) if isinstance(m, dict) and 'fn' in m
else m.get('count_metrics', {}).get('fn', 0)
for m in field_metrics.values()
)
overall = calculate_metrics(total_tp, total_fp, total_fn)
return {
'status': 'evaluated',
'field_metrics': field_metrics,
'overall': overall
}
def aggregate_metrics(paper_evaluations: Dict[str, Dict]) -> Dict[str, Any]:
"""Aggregate metrics across all papers"""
# Collect field-level metrics
field_aggregates = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
evaluated_papers = [
p for p in paper_evaluations.values()
if p.get('status') == 'evaluated'
]
for paper_eval in evaluated_papers:
for field, metrics in paper_eval.get('field_metrics', {}).items():
if isinstance(metrics, dict):
if 'tp' in metrics:
# Simple field
field_aggregates[field]['tp'] += metrics['tp']
field_aggregates[field]['fp'] += metrics['fp']
field_aggregates[field]['fn'] += metrics['fn']
elif 'count_metrics' in metrics:
# Records field
field_aggregates[field]['tp'] += metrics['count_metrics']['tp']
field_aggregates[field]['fp'] += metrics['count_metrics']['fp']
field_aggregates[field]['fn'] += metrics['count_metrics']['fn']
# Calculate metrics for each field
field_metrics = {}
for field, counts in field_aggregates.items():
field_metrics[field] = calculate_metrics(**counts)
# Overall aggregated metrics
total_tp = sum(counts['tp'] for counts in field_aggregates.values())
total_fp = sum(counts['fp'] for counts in field_aggregates.values())
total_fn = sum(counts['fn'] for counts in field_aggregates.values())
overall = calculate_metrics(total_tp, total_fp, total_fn)
return {
'overall': overall,
'by_field': field_metrics,
'num_papers_evaluated': len(evaluated_papers)
}
def generate_report(
paper_evaluations: Dict[str, Dict],
aggregated: Dict,
output_path: Path
):
"""Generate human-readable validation report"""
lines = []
lines.append("="*80)
lines.append("EXTRACTION VALIDATION REPORT")
lines.append("="*80)
lines.append("")
# Overall summary
lines.append("OVERALL METRICS")
lines.append("-"*80)
overall = aggregated['overall']
lines.append(f"Papers evaluated: {aggregated['num_papers_evaluated']}")
lines.append(f"Precision: {overall['precision']:.2%}")
lines.append(f"Recall: {overall['recall']:.2%}")
lines.append(f"F1 Score: {overall['f1']:.2%}")
lines.append(f"True Positives: {overall['tp']}")
lines.append(f"False Positives: {overall['fp']}")
lines.append(f"False Negatives: {overall['fn']}")
lines.append("")
# Per-field metrics
lines.append("METRICS BY FIELD")
lines.append("-"*80)
lines.append(f"{'Field':<30} {'Precision':>10} {'Recall':>10} {'F1':>10}")
lines.append("-"*80)
for field, metrics in sorted(aggregated['by_field'].items()):
        lines.append(
            f"{field:<30} "
            f"{metrics['precision']:>10.1%} "
            f"{metrics['recall']:>10.1%} "
            f"{metrics['f1']:>10.1%}"
        )
lines.append("")
# Top errors
lines.append("COMMON ISSUES")
lines.append("-"*80)
# Fields with low recall (missed information)
low_recall = [
(field, metrics) for field, metrics in aggregated['by_field'].items()
if metrics['recall'] < 0.7 and metrics['fn'] > 0
]
if low_recall:
lines.append("\nFields with low recall (missed information):")
for field, metrics in sorted(low_recall, key=lambda x: x[1]['recall']):
lines.append(f" - {field}: {metrics['recall']:.1%} recall, {metrics['fn']} missed items")
# Fields with low precision (incorrect extractions)
low_precision = [
(field, metrics) for field, metrics in aggregated['by_field'].items()
if metrics['precision'] < 0.7 and metrics['fp'] > 0
]
if low_precision:
lines.append("\nFields with low precision (incorrect extractions):")
for field, metrics in sorted(low_precision, key=lambda x: x[1]['precision']):
lines.append(f" - {field}: {metrics['precision']:.1%} precision, {metrics['fp']} incorrect items")
lines.append("")
lines.append("="*80)
# Write report
report_text = "\n".join(lines)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(report_text)
# Also print to console
print(report_text)
def main():
args = parse_args()
# Load annotations
annotations = load_annotations(Path(args.annotations))
validation_papers = annotations.get('validation_papers', {})
print(f"Loaded {len(validation_papers)} validation papers")
# Check how many have ground truth
annotated = sum(1 for p in validation_papers.values() if p.get('ground_truth') is not None)
print(f"Papers with ground truth: {annotated}")
if annotated == 0:
print("\nError: No ground truth annotations found!")
print("Please fill in the 'ground_truth' field for each paper in the annotation file.")
sys.exit(1)
# Configuration for comparisons
config = {
'numeric_tolerance': args.numeric_tolerance,
'fuzzy_strings': args.fuzzy_strings,
'list_order_matters': args.list_order_matters
}
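    # The comparison config is echoed into the JSON output so runs can be reproduced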
# Evaluate each paper
paper_evaluations = {}
for paper_id, paper_data in validation_papers.items():
automated = paper_data.get('automated_extraction', {})
truth = paper_data.get('ground_truth')
evaluation = evaluate_paper(paper_id, automated, truth, config)
paper_evaluations[paper_id] = evaluation
if evaluation['status'] == 'evaluated':
overall = evaluation['overall']
print(f"{paper_id}: P={overall['precision']:.2%} R={overall['recall']:.2%} F1={overall['f1']:.2%}")
# Aggregate metrics
aggregated = aggregate_metrics(paper_evaluations)
# Save detailed metrics
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
detailed_output = {
'summary': aggregated,
'by_paper': paper_evaluations,
'config': config
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(detailed_output, f, indent=2, ensure_ascii=False)
print(f"\nDetailed metrics saved to: {output_path}")
# Generate report
report_path = Path(args.report)
generate_report(paper_evaluations, aggregated, report_path)
print(f"Validation report saved to: {report_path}")
if __name__ == '__main__':
main()