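"""evaluate_eval.py -- score classifier page predictions against ground truth.

For every GT file under ./dataset/<split>/GTs/, the matching prediction in
./dataset/<split>/classifier_output/ (same file stem) is scored per statement
type (balance sheet, profit and loss, cash flow) and scope (consolidated,
standalone): a per-file page-set Jaccard score, plus aggregate precision,
recall and F1 computed from page-level TP/FP/FN counts.

Usage:
    python evaluate_eval.py --split eval
"""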
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, Set, Tuple
TARGETS = ["balance_sheet", "profit_and_loss", "cash_flow"]
SCOPES = ["consolidated", "standalone"]
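
# Expected input shapes, inferred from the parsing logic below (real files may
# carry additional fields):
#   GT JSON:     {"balance_sheet": {"consolidated": [12, 13], "standalone": [45]},
#                 "pnl": [78], ...}          # a bare list is treated as consolidated
#   Prediction:  {"balance_sheet": [{"scope": "consolidated", "pages": [12, 13]},
#                                   {"scope": "standalone", "start_page": 45, "end_page": 46}], ...}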

def load_json(p: Path):
    with open(p, "r", encoding="utf-8") as fh:
        return json.load(fh)

def to_set_pages(obj) -> Set[int]:
    """Normalize a GT or predicted pages value into a set of ints."""
    if obj is None:
        return set()
    if isinstance(obj, (int, float)):
        return {int(obj)}
    if isinstance(obj, str):
        if obj.isdigit():
            return {int(obj)}
        return set()
    if isinstance(obj, (list, tuple, set)):
        return set(int(x) for x in obj if isinstance(x, (int, float)) or (isinstance(x, str) and x.isdigit()))
    # fallback: attempt to parse iterable
    try:
        return set(int(x) for x in obj)
    except Exception:
        return set()
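
# Quick reference for to_set_pages (follows directly from the branches above):
#   to_set_pages("12")           -> {12}
#   to_set_pages([3, "4", None]) -> {3, 4}    (non-numeric entries are dropped)
#   to_set_pages({"pages": 1})   -> set()     (unparseable values fall back to an empty set)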

def jaccard(a: Set[int], b: Set[int]) -> float:
    if not a and not b:
        return 1.0
    if not a and b:
        return 0.0
    inter = len(a & b)
    union = len(a | b)
    return inter / union if union > 0 else 0.0

def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1
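
# Worked example for the two metrics above: GT pages {10, 11, 12} vs. predicted
# {11, 12, 13} give tp=2, fp=1, fn=1 -> P = R = F1 = 2/3, while Jaccard = 2/4 = 0.5.
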
def evaluate_file(gt_path: Path, pred_path: Path) -> Dict:
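    """Score one prediction file against one ground-truth file.

    Per statement: if the GT lists both scopes, each scope is scored with
    page-set Jaccard and the two scores are averaged; if the GT has a single
    scope, its pages are compared against all predicted pages regardless of
    scope; if the GT has no entry, the statement scores 1.0 (nothing to find)
    but any predicted pages are still counted as false positives. Page-level
    TP/FP/FN counts are keyed by (statement, scope) tuples.
    """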
    gt = load_json(gt_path)
    pred = load_json(pred_path)
    per_stmt_scores = {}
    per_stmt_counts = {}
    # Confusion counts aggregated by (statement, scope)
    counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}
    for stmt in TARGETS:
        # The GT sometimes uses the 'pnl' key for profit_and_loss
        raw_gt = None
        if stmt in gt:
            raw_gt = gt.get(stmt)
        elif stmt == "profit_and_loss" and "pnl" in gt:
            raw_gt = gt.get("pnl")
        # Normalize GT scopes -> sets of pages
        gt_scopes: Dict[str, Set[int]] = {}
        if isinstance(raw_gt, dict):
            for scope in SCOPES:
                if scope in raw_gt and raw_gt[scope]:
                    gt_scopes[scope] = to_set_pages(raw_gt[scope])
        else:
            # If the GT is a bare list (no scope), treat it as a single 'consolidated' scope
            if isinstance(raw_gt, list):
                gt_scopes["consolidated"] = to_set_pages(raw_gt)
        # Predictions: predicted blocks per statement
        pred_blocks = pred.get(stmt) or []
        pred_by_scope: Dict[str, Set[int]] = {"consolidated": set(), "standalone": set(), "unknown": set()}
        for b in pred_blocks:
            if not isinstance(b, dict):
                continue
            scope = (b.get("scope") or "unknown").lower()
            # Try 'pages' first, then the 'start_page'..'end_page' range
            pages = to_set_pages(b.get("pages") or [])
            if not pages:
                sp = b.get("start_page")
                ep = b.get("end_page")
                if isinstance(sp, int) and isinstance(ep, int):
                    pages = set(range(sp, ep + 1))
            if scope not in pred_by_scope:
                pred_by_scope[scope] = set()
            pred_by_scope[scope] |= pages
        pred_any_scope = set().union(*pred_by_scope.values())
        # Scoring logic per statement
        stmt_scores = []
        if gt_scopes:
            # If the GT has both scopes, score each separately and average
            if all(s in gt_scopes for s in SCOPES):
                for scope in SCOPES:
                    gt_pages = gt_scopes.get(scope, set())
                    pred_pages = pred_by_scope.get(scope, set())
                    # Jaccard over page sets
                    j = jaccard(gt_pages, pred_pages)
                    stmt_scores.append(j)
                    # Update page-level TP/FP/FN counts
                    tp = len(gt_pages & pred_pages)
                    fp = len(pred_pages - gt_pages)
                    fn = len(gt_pages - pred_pages)
                    counts[(stmt, scope)]["tp"] += tp
                    counts[(stmt, scope)]["fp"] += fp
                    counts[(stmt, scope)]["fn"] += fn
            else:
                # Single scope in GT: compare its pages against all predicted pages
                # (scope-agnostic) and attribute the counts to the GT scope
                gt_scope = next(iter(gt_scopes.keys()))
                gt_pages = gt_scopes[gt_scope]
                pred_pages = pred_any_scope
                j = jaccard(gt_pages, pred_pages)
                stmt_scores.append(j)
                tp = len(gt_pages & pred_pages)
                fp = len(pred_pages - gt_pages)
                fn = len(gt_pages - pred_pages)
                counts[(stmt, gt_scope)]["tp"] += tp
                counts[(stmt, gt_scope)]["fp"] += fp
                counts[(stmt, gt_scope)]["fn"] += fn
        else:
            # No GT for this statement: treat it as not applicable for scoring,
            # but still penalize false positives; any predicted pages are
            # counted as FPs under 'consolidated'.
            pred_count = len(pred_any_scope)
            if pred_count > 0:
                counts[(stmt, "consolidated")]["fp"] += pred_count
            stmt_scores.append(1.0)  # neutral score: there was nothing to predict
        per_stmt_scores[stmt] = sum(stmt_scores) / max(1, len(stmt_scores))
        # Store a per-scope copy of the counts for this statement
        per_stmt_counts[stmt] = {s: counts[(stmt, s)].copy() for s in SCOPES} if stmt_scores else {}
    return {
        "gt_path": str(gt_path),
        "pred_path": str(pred_path),
        "per_stmt_scores": per_stmt_scores,
        "per_stmt_counts": per_stmt_counts,
        "counts": counts,
    }
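
# Illustrative evaluate_file(...) return value (all numbers are made up):
#   {"gt_path": ".../GTs/<stem>.json",
#    "pred_path": ".../classifier_output/<stem>.json",
#    "per_stmt_scores": {"balance_sheet": 1.0, "profit_and_loss": 0.5, "cash_flow": 1.0},
#    "per_stmt_counts": {"balance_sheet": {"consolidated": {"tp": 2, "fp": 0, "fn": 0}, ...}, ...},
#    "counts": {("balance_sheet", "consolidated"): {"tp": 2, "fp": 0, "fn": 0}, ...}}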

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", default="eval", help="Which split folder under dataset/ to use (default: eval)")
    args = ap.parse_args()
    base = Path("./dataset")
    split = base / args.split
    gt_dir = split / "GTs"
    pred_dir = split / "classifier_output"
    if not gt_dir.exists():
        raise FileNotFoundError(f"GTs dir not found: {gt_dir}")
    if not pred_dir.exists():
        raise FileNotFoundError(f"Predictions dir not found: {pred_dir}")
    gt_files = sorted([p for p in gt_dir.iterdir() if p.suffix.lower() == ".json"])
    if not gt_files:
        print("No GT files found.")
        return
    total_counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}
    per_file_scores = []
    for gt_p in gt_files:
        stem = gt_p.stem
        pred_p = pred_dir / f"{stem}.json"
        if not pred_p.exists():
            print(f"WARN: prediction missing for {stem}, skipping")
            continue
        res = evaluate_file(gt_p, pred_p)
        per_file_scores.append((stem, res["per_stmt_scores"]))
        # accumulate counts
        for k, v in res["counts"].items():
            total_counts[k]["tp"] += v["tp"]
            total_counts[k]["fp"] += v["fp"]
            total_counts[k]["fn"] += v["fn"]
        # print per-file breakdown
        print(f"\nFile: {stem}")
        for stmt, score in res["per_stmt_scores"].items():
            print(f" {stmt}: Jaccard={score:.3f}")
    # Aggregate metrics
    print("\n=== Aggregate metrics ===")
    stmt_scope_results: Dict[Tuple[str, str], Tuple[float, float, float]] = {}
    for stmt in TARGETS:
        for scope in SCOPES:
            tp = total_counts[(stmt, scope)]["tp"]
            fp = total_counts[(stmt, scope)]["fp"]
            fn = total_counts[(stmt, scope)]["fn"]
            p, r, f1 = precision_recall_f1(tp, fp, fn)
            stmt_scope_results[(stmt, scope)] = (p, r, f1)
            print(f"{stmt}/{scope}: TP={tp} FP={fp} FN={fn} P={p:.3f} R={r:.3f} F1={f1:.3f}")
    # Mean Jaccard across files and statements
    all_scores = []
    for _, per in per_file_scores:
        for stmt in TARGETS:
            if stmt in per:
                all_scores.append(per[stmt])
    mean_jaccard = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nMean per-statement Jaccard (averaged over files and statements): {mean_jaccard:.3f}")


if __name__ == "__main__":
    main()