#!/usr/bin/env python3 """ Calculate validation metrics (precision, recall, F1) for extraction quality. Compares automated extraction against ground truth annotations to evaluate: - Field-level precision and recall - Record-level accuracy - Overall extraction quality Handles different data types appropriately: - Boolean: exact match - Numeric: exact match or tolerance - String: exact match or fuzzy matching - Lists: set-based precision/recall - Nested objects: recursive comparison """ import argparse import json from pathlib import Path from typing import Dict, List, Any, Tuple, Optional from collections import defaultdict import sys def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser( description='Calculate validation metrics for extraction quality', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Metrics calculated: Precision : Of extracted items, how many are correct? Recall : Of true items, how many were extracted? F1 Score : Harmonic mean of precision and recall Accuracy : Overall correctness (for boolean/categorical fields) Field type handling: Boolean/Categorical : Exact match Numeric : Exact match or within tolerance String : Exact match or fuzzy (normalized) Lists : Set-based precision/recall Nested objects : Recursive field-by-field comparison Output: - Overall metrics - Per-field metrics - Per-paper detailed comparison - Common error patterns """ ) parser.add_argument( '--annotations', required=True, help='Annotation file from 07_prepare_validation_set.py (with ground truth filled in)' ) parser.add_argument( '--output', default='validation_metrics.json', help='Output file for detailed metrics' ) parser.add_argument( '--report', default='validation_report.txt', help='Human-readable validation report' ) parser.add_argument( '--numeric-tolerance', type=float, default=0.0, help='Tolerance for numeric comparisons (default: 0.0 for exact match)' ) parser.add_argument( '--fuzzy-strings', action='store_true', help='Use fuzzy string matching (normalize whitespace, case)' ) parser.add_argument( '--list-order-matters', action='store_true', help='Consider order in list comparisons (default: treat as sets)' ) return parser.parse_args() def load_annotations(annotations_path: Path) -> Dict: """Load annotations file""" with open(annotations_path, 'r', encoding='utf-8') as f: return json.load(f) def normalize_string(s: str, fuzzy: bool = False) -> str: """Normalize string for comparison""" if not isinstance(s, str): return str(s) if fuzzy: return ' '.join(s.lower().split()) return s def compare_boolean(automated: Any, truth: Any) -> Dict[str, int]: """Compare boolean values""" if automated == truth: return {'tp': 1, 'fp': 0, 'fn': 0, 'tn': 0} elif automated and not truth: return {'tp': 0, 'fp': 1, 'fn': 0, 'tn': 0} elif not automated and truth: return {'tp': 0, 'fp': 0, 'fn': 1, 'tn': 0} else: return {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 1} def compare_numeric(automated: Any, truth: Any, tolerance: float = 0.0) -> bool: """Compare numeric values with optional tolerance""" try: a = float(automated) if automated is not None else None t = float(truth) if truth is not None else None if a is None and t is None: return True if a is None or t is None: return False if tolerance > 0: return abs(a - t) <= tolerance else: return a == t except (ValueError, TypeError): return automated == truth def compare_string(automated: Any, truth: Any, fuzzy: bool = False) -> bool: """Compare string values""" if automated is None and truth is None: return True if automated is None or truth is None: return False a = normalize_string(automated, fuzzy) t = normalize_string(truth, fuzzy) return a == t def compare_list( automated: List, truth: List, order_matters: bool = False, fuzzy: bool = False ) -> Dict[str, int]: """ Compare lists and calculate precision/recall. Returns counts of true positives, false positives, and false negatives. """ if automated is None: automated = [] if truth is None: truth = [] if not isinstance(automated, list): automated = [automated] if not isinstance(truth, list): truth = [truth] if order_matters: # Ordered comparison tp = sum(1 for a, t in zip(automated, truth) if compare_string(a, t, fuzzy)) fp = max(0, len(automated) - len(truth)) fn = max(0, len(truth) - len(automated)) else: # Set-based comparison if fuzzy: auto_set = {normalize_string(x, fuzzy) for x in automated} truth_set = {normalize_string(x, fuzzy) for x in truth} else: auto_set = set(automated) truth_set = set(truth) tp = len(auto_set & truth_set) # Intersection fp = len(auto_set - truth_set) # In automated but not in truth fn = len(truth_set - auto_set) # In truth but not in automated return {'tp': tp, 'fp': fp, 'fn': fn} def calculate_metrics(tp: int, fp: int, fn: int) -> Dict[str, float]: """Calculate precision, recall, and F1 from counts""" precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 return { 'precision': precision, 'recall': recall, 'f1': f1, 'tp': tp, 'fp': fp, 'fn': fn } def compare_field( automated: Any, truth: Any, field_name: str, config: Dict ) -> Dict[str, Any]: """ Compare a single field between automated and ground truth. Returns metrics appropriate for the field type. """ # Determine field type if isinstance(truth, bool): return compare_boolean(automated, truth) elif isinstance(truth, (int, float)): match = compare_numeric(automated, truth, config['numeric_tolerance']) return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1} elif isinstance(truth, str): match = compare_string(automated, truth, config['fuzzy_strings']) return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1} elif isinstance(truth, list): return compare_list(automated, truth, config['list_order_matters'], config['fuzzy_strings']) elif isinstance(truth, dict): # Recursive comparison for nested objects return compare_nested(automated or {}, truth, config) elif truth is None: # Field should be empty/null if automated is None or automated == "" or automated == []: return {'tp': 1, 'fp': 0, 'fn': 0} else: return {'tp': 0, 'fp': 1, 'fn': 0} else: # Fallback to exact match match = automated == truth return {'tp': 1 if match else 0, 'fp': 0 if match else 1, 'fn': 0 if match else 1} def compare_nested(automated: Dict, truth: Dict, config: Dict) -> Dict[str, int]: """Recursively compare nested objects""" total_counts = {'tp': 0, 'fp': 0, 'fn': 0} all_fields = set(automated.keys()) | set(truth.keys()) for field in all_fields: auto_val = automated.get(field) truth_val = truth.get(field) field_counts = compare_field(auto_val, truth_val, field, config) for key in ['tp', 'fp', 'fn']: total_counts[key] += field_counts.get(key, 0) return total_counts def evaluate_paper( paper_id: str, automated: Dict, truth: Dict, config: Dict ) -> Dict[str, Any]: """ Evaluate extraction for a single paper. Returns field-level and overall metrics. """ if truth is None: return { 'status': 'not_annotated', 'message': 'Ground truth not provided' } field_metrics = {} all_fields = set(automated.keys()) | set(truth.keys()) for field in all_fields: if field == 'records': # Special handling for records arrays auto_records = automated.get('records', []) truth_records = truth.get('records', []) # Overall record count comparison record_counts = compare_list(auto_records, truth_records, order_matters=False) # Detailed record-level comparison record_details = [] for i, (auto_rec, truth_rec) in enumerate(zip(auto_records, truth_records)): rec_comparison = compare_nested(auto_rec, truth_rec, config) record_details.append({ 'record_index': i, 'metrics': calculate_metrics(**rec_comparison) }) field_metrics['records'] = { 'count_metrics': calculate_metrics(**record_counts), 'record_details': record_details } else: auto_val = automated.get(field) truth_val = truth.get(field) counts = compare_field(auto_val, truth_val, field, config) field_metrics[field] = calculate_metrics(**counts) # Calculate overall metrics total_tp = sum( m.get('tp', 0) if isinstance(m, dict) and 'tp' in m else m.get('count_metrics', {}).get('tp', 0) for m in field_metrics.values() ) total_fp = sum( m.get('fp', 0) if isinstance(m, dict) and 'fp' in m else m.get('count_metrics', {}).get('fp', 0) for m in field_metrics.values() ) total_fn = sum( m.get('fn', 0) if isinstance(m, dict) and 'fn' in m else m.get('count_metrics', {}).get('fn', 0) for m in field_metrics.values() ) overall = calculate_metrics(total_tp, total_fp, total_fn) return { 'status': 'evaluated', 'field_metrics': field_metrics, 'overall': overall } def aggregate_metrics(paper_evaluations: Dict[str, Dict]) -> Dict[str, Any]: """Aggregate metrics across all papers""" # Collect field-level metrics field_aggregates = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0}) evaluated_papers = [ p for p in paper_evaluations.values() if p.get('status') == 'evaluated' ] for paper_eval in evaluated_papers: for field, metrics in paper_eval.get('field_metrics', {}).items(): if isinstance(metrics, dict): if 'tp' in metrics: # Simple field field_aggregates[field]['tp'] += metrics['tp'] field_aggregates[field]['fp'] += metrics['fp'] field_aggregates[field]['fn'] += metrics['fn'] elif 'count_metrics' in metrics: # Records field field_aggregates[field]['tp'] += metrics['count_metrics']['tp'] field_aggregates[field]['fp'] += metrics['count_metrics']['fp'] field_aggregates[field]['fn'] += metrics['count_metrics']['fn'] # Calculate metrics for each field field_metrics = {} for field, counts in field_aggregates.items(): field_metrics[field] = calculate_metrics(**counts) # Overall aggregated metrics total_tp = sum(counts['tp'] for counts in field_aggregates.values()) total_fp = sum(counts['fp'] for counts in field_aggregates.values()) total_fn = sum(counts['fn'] for counts in field_aggregates.values()) overall = calculate_metrics(total_tp, total_fp, total_fn) return { 'overall': overall, 'by_field': field_metrics, 'num_papers_evaluated': len(evaluated_papers) } def generate_report( paper_evaluations: Dict[str, Dict], aggregated: Dict, output_path: Path ): """Generate human-readable validation report""" lines = [] lines.append("="*80) lines.append("EXTRACTION VALIDATION REPORT") lines.append("="*80) lines.append("") # Overall summary lines.append("OVERALL METRICS") lines.append("-"*80) overall = aggregated['overall'] lines.append(f"Papers evaluated: {aggregated['num_papers_evaluated']}") lines.append(f"Precision: {overall['precision']:.2%}") lines.append(f"Recall: {overall['recall']:.2%}") lines.append(f"F1 Score: {overall['f1']:.2%}") lines.append(f"True Positives: {overall['tp']}") lines.append(f"False Positives: {overall['fp']}") lines.append(f"False Negatives: {overall['fn']}") lines.append("") # Per-field metrics lines.append("METRICS BY FIELD") lines.append("-"*80) lines.append(f"{'Field':<30} {'Precision':>10} {'Recall':>10} {'F1':>10}") lines.append("-"*80) for field, metrics in sorted(aggregated['by_field'].items()): lines.append( f"{field:<30} " f"{metrics['precision']:>9.1%} " f"{metrics['recall']:>9.1%} " f"{metrics['f1']:>9.1%}" ) lines.append("") # Top errors lines.append("COMMON ISSUES") lines.append("-"*80) # Fields with low recall (missed information) low_recall = [ (field, metrics) for field, metrics in aggregated['by_field'].items() if metrics['recall'] < 0.7 and metrics['fn'] > 0 ] if low_recall: lines.append("\nFields with low recall (missed information):") for field, metrics in sorted(low_recall, key=lambda x: x[1]['recall']): lines.append(f" - {field}: {metrics['recall']:.1%} recall, {metrics['fn']} missed items") # Fields with low precision (incorrect extractions) low_precision = [ (field, metrics) for field, metrics in aggregated['by_field'].items() if metrics['precision'] < 0.7 and metrics['fp'] > 0 ] if low_precision: lines.append("\nFields with low precision (incorrect extractions):") for field, metrics in sorted(low_precision, key=lambda x: x[1]['precision']): lines.append(f" - {field}: {metrics['precision']:.1%} precision, {metrics['fp']} incorrect items") lines.append("") lines.append("="*80) # Write report report_text = "\n".join(lines) with open(output_path, 'w', encoding='utf-8') as f: f.write(report_text) # Also print to console print(report_text) def main(): args = parse_args() # Load annotations annotations = load_annotations(Path(args.annotations)) validation_papers = annotations.get('validation_papers', {}) print(f"Loaded {len(validation_papers)} validation papers") # Check how many have ground truth annotated = sum(1 for p in validation_papers.values() if p.get('ground_truth') is not None) print(f"Papers with ground truth: {annotated}") if annotated == 0: print("\nError: No ground truth annotations found!") print("Please fill in the 'ground_truth' field for each paper in the annotation file.") sys.exit(1) # Configuration for comparisons config = { 'numeric_tolerance': args.numeric_tolerance, 'fuzzy_strings': args.fuzzy_strings, 'list_order_matters': args.list_order_matters } # Evaluate each paper paper_evaluations = {} for paper_id, paper_data in validation_papers.items(): automated = paper_data.get('automated_extraction', {}) truth = paper_data.get('ground_truth') evaluation = evaluate_paper(paper_id, automated, truth, config) paper_evaluations[paper_id] = evaluation if evaluation['status'] == 'evaluated': overall = evaluation['overall'] print(f"{paper_id}: P={overall['precision']:.2%} R={overall['recall']:.2%} F1={overall['f1']:.2%}") # Aggregate metrics aggregated = aggregate_metrics(paper_evaluations) # Save detailed metrics output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) detailed_output = { 'summary': aggregated, 'by_paper': paper_evaluations, 'config': config } with open(output_path, 'w', encoding='utf-8') as f: json.dump(detailed_output, f, indent=2, ensure_ascii=False) print(f"\nDetailed metrics saved to: {output_path}") # Generate report report_path = Path(args.report) generate_report(paper_evaluations, aggregated, report_path) print(f"Validation report saved to: {report_path}") if __name__ == '__main__': main()